Files
ComfyUI-Lora-Manager/tests/services/test_recipe_scanner.py
Will Miao 9198a23ba9 feat: normalize and validate checkpoint entries before enrichment
Add _normalize_checkpoint_entry method to handle legacy checkpoint data formats (strings, tuples) by converting them to dictionaries. This prevents errors during enrichment when checkpoint data is not in the expected dictionary format. Invalid checkpoint entries are now removed instead of causing processing failures.

- Update get_paginated_data and get_recipe_by_id methods to use normalization
- Add test cases for legacy string and tuple checkpoint formats
- Ensure backward compatibility with existing checkpoint handling
2025-11-21 23:36:32 +08:00

352 lines
11 KiB
Python

import asyncio
import json
import os
from pathlib import Path
from types import SimpleNamespace
import pytest
from py.config import config
from py.services.recipe_scanner import RecipeScanner
from py.utils.utils import calculate_recipe_fingerprint
class StubHashIndex:
def __init__(self) -> None:
self._hash_to_path: dict[str, str] = {}
def get_path(self, hash_value: str) -> str | None:
return self._hash_to_path.get(hash_value)
class StubLoraScanner:
def __init__(self) -> None:
self._hash_index = StubHashIndex()
self._hash_meta: dict[str, dict[str, str]] = {}
self._models_by_name: dict[str, dict] = {}
self._cache = SimpleNamespace(raw_data=[], version_index={})
async def get_cached_data(self):
return self._cache
def has_hash(self, hash_value: str) -> bool:
return hash_value.lower() in self._hash_meta
def get_preview_url_by_hash(self, hash_value: str) -> str:
meta = self._hash_meta.get(hash_value.lower())
return meta.get("preview_url", "") if meta else ""
def get_path_by_hash(self, hash_value: str) -> str | None:
meta = self._hash_meta.get(hash_value.lower())
return meta.get("path") if meta else None
async def get_model_info_by_name(self, name: str):
return self._models_by_name.get(name)
def register_model(self, name: str, info: dict) -> None:
self._models_by_name[name] = info
hash_value = (info.get("sha256") or "").lower()
version_id = info.get("civitai", {}).get("id")
if hash_value:
self._hash_meta[hash_value] = {
"path": info.get("file_path", ""),
"preview_url": info.get("preview_url", ""),
}
self._hash_index._hash_to_path[hash_value] = info.get("file_path", "")
if version_id is not None:
self._cache.version_index[int(version_id)] = {
"file_path": info.get("file_path", ""),
"sha256": hash_value,
"preview_url": info.get("preview_url", ""),
"civitai": info.get("civitai", {}),
}
self._cache.raw_data.append({
"sha256": info.get("sha256", ""),
"path": info.get("file_path", ""),
"civitai": info.get("civitai", {}),
})
@pytest.fixture
def recipe_scanner(tmp_path: Path, monkeypatch):
    """Yield ``(scanner, stub)``: a ``RecipeScanner`` wired to a ``StubLoraScanner``.

    ``config.loras_roots`` is pointed at ``tmp_path`` for the duration of the
    test, and the scanner cache is force-refreshed before the pair is yielded.
    """
    # Reset the class-level ``_instance`` slot before building the scanner ...
    RecipeScanner._instance = None
    monkeypatch.setattr(config, "loras_roots", [str(tmp_path)])

    lora_stub = StubLoraScanner()
    instance = RecipeScanner(lora_scanner=lora_stub)
    asyncio.run(instance.refresh_cache(force=True))

    yield instance, lora_stub

    # ... and again on teardown so the next test starts from a clean slot.
    RecipeScanner._instance = None
@pytest.mark.asyncio
async def test_add_recipe_during_concurrent_reads(recipe_scanner):
    """Adding a recipe while readers iterate the cache leaves the cache consistent.

    Marked with ``pytest.mark.asyncio`` like the other coroutine tests in this
    module so it is collected and awaited the same way.
    """
    scanner, _ = recipe_scanner

    initial_recipe = {
        "id": "one",
        "file_path": "path/a.png",
        "title": "First",
        "modified": 1.0,
        "created_date": 1.0,
        "loras": [],
    }
    await scanner.add_recipe(initial_recipe)

    new_recipe = {
        "id": "two",
        "file_path": "path/b.png",
        "title": "Second",
        "modified": 2.0,
        "created_date": 2.0,
        "loras": [],
    }

    async def reader_task():
        # Repeatedly walk raw_data, yielding to the loop between passes so the
        # add_recipe coroutine can interleave with the reads.
        for _ in range(5):
            cache = await scanner.get_cached_data()
            _ = [item["id"] for item in cache.raw_data]
            await asyncio.sleep(0)

    await asyncio.gather(reader_task(), reader_task(), scanner.add_recipe(new_recipe))
    # Give any work scheduled by add_recipe one more loop iteration to run.
    await asyncio.sleep(0)

    cache = await scanner.get_cached_data()
    assert {item["id"] for item in cache.raw_data} == {"one", "two"}
    assert len(cache.sorted_by_name) == len(cache.raw_data)
@pytest.mark.asyncio
async def test_remove_recipe_during_reads(recipe_scanner):
    """Removing a recipe while a reader iterates the cache drops only that recipe.

    Marked with ``pytest.mark.asyncio`` like the other coroutine tests in this
    module so it is collected and awaited the same way.
    """
    scanner, _ = recipe_scanner

    recipe_ids = ["alpha", "beta", "gamma"]
    for index, recipe_id in enumerate(recipe_ids):
        await scanner.add_recipe({
            "id": recipe_id,
            "file_path": f"path/{recipe_id}.png",
            "title": recipe_id,
            "modified": float(index),
            "created_date": float(index),
            "loras": [],
        })

    async def reader_task():
        # Snapshot sorted_by_date several times, yielding between passes so
        # the removal can interleave with the reads.
        for _ in range(5):
            cache = await scanner.get_cached_data()
            _ = list(cache.sorted_by_date)
            await asyncio.sleep(0)

    await asyncio.gather(reader_task(), scanner.remove_recipe("beta"))
    # Give any work scheduled by remove_recipe one more loop iteration to run.
    await asyncio.sleep(0)

    cache = await scanner.get_cached_data()
    assert {item["id"] for item in cache.raw_data} == {"alpha", "gamma"}
@pytest.mark.asyncio
async def test_update_lora_entry_updates_cache_and_file(tmp_path: Path, recipe_scanner):
    """``update_lora_entry`` rewrites the LoRA in both the recipe file and the cache.

    Marked with ``pytest.mark.asyncio`` like the other coroutine tests in this
    module so it is collected and awaited the same way.
    """
    scanner, stub = recipe_scanner
    recipes_dir = Path(config.loras_roots[0]) / "recipes"
    recipes_dir.mkdir(parents=True, exist_ok=True)

    recipe_id = "recipe-1"
    recipe_path = recipes_dir / f"{recipe_id}.recipe.json"
    recipe_data = {
        "id": recipe_id,
        "file_path": str(tmp_path / "image.png"),
        "title": "Original",
        "modified": 0.0,
        "created_date": 0.0,
        "loras": [
            {"file_name": "old", "strength": 1.0, "hash": "", "isDeleted": True, "exclude": True},
        ],
    }
    # Write with the same explicit encoding the file is read back with below.
    recipe_path.write_text(json.dumps(recipe_data), encoding="utf-8")

    # Pass a copy so the scanner does not share the dict that was serialised.
    await scanner.add_recipe(dict(recipe_data))

    target_hash = "abc123"
    target_info = {
        "sha256": target_hash,
        "file_path": str(tmp_path / "loras" / "target.safetensors"),
        "preview_url": "preview.png",
        "civitai": {"id": 42, "name": "v1", "model": {"name": "Target"}},
    }
    stub.register_model("target", target_info)

    updated_recipe, updated_lora = await scanner.update_lora_entry(
        recipe_id,
        0,
        target_name="target",
        target_lora=target_info,
    )

    assert updated_lora["inLibrary"] is True
    assert updated_lora["localPath"] == target_info["file_path"]
    assert updated_lora["hash"] == target_hash

    with recipe_path.open("r", encoding="utf-8") as file_obj:
        persisted = json.load(file_obj)

    # The stored fingerprint must match one recomputed from the persisted LoRAs.
    expected_fingerprint = calculate_recipe_fingerprint(persisted["loras"])
    assert persisted["fingerprint"] == expected_fingerprint

    cache = await scanner.get_cached_data()
    cached_recipe = next(item for item in cache.raw_data if item["id"] == recipe_id)
    assert cached_recipe["loras"][0]["hash"] == target_hash
    assert cached_recipe["fingerprint"] == expected_fingerprint
@pytest.mark.asyncio
async def test_load_recipe_rewrites_missing_image_path(tmp_path: Path, recipe_scanner):
    """A recorded image path that no longer exists is replaced by the image
    sitting next to the recipe file, both in the returned data and on disk."""
    scanner, _ = recipe_scanner
    recipe_id = "moved"

    recipes_dir = Path(config.loras_roots[0]) / "recipes"
    recipes_dir.mkdir(parents=True, exist_ok=True)
    recipe_path = recipes_dir / f"{recipe_id}.recipe.json"

    # The recorded location lives under a directory that is never created.
    old_root = tmp_path / "old_root"
    old_path = old_root / "recipes" / f"{recipe_id}.webp"

    # The real image sits beside the recipe JSON.
    current_image = recipes_dir / f"{recipe_id}.webp"
    current_image.write_bytes(b"image-bytes")

    serialized = json.dumps(
        {
            "id": recipe_id,
            "file_path": str(old_path),
            "title": "Relocated",
            "modified": 0.0,
            "created_date": 0.0,
            "loras": [],
        }
    )
    recipe_path.write_text(serialized)

    loaded = await scanner._load_recipe_file(str(recipe_path))

    expected_path = os.path.normpath(str(current_image))
    assert loaded["file_path"] == expected_path

    on_disk = json.loads(recipe_path.read_text())
    assert on_disk["file_path"] == expected_path
@pytest.mark.asyncio
async def test_load_recipe_upgrades_string_checkpoint(tmp_path: Path, recipe_scanner):
    """A checkpoint stored as a bare string is loaded as a dict carrying
    ``name`` and the extension-less ``file_name``."""
    scanner, _ = recipe_scanner
    recipe_id = "legacy-checkpoint"

    recipes_dir = Path(config.loras_roots[0]) / "recipes"
    recipes_dir.mkdir(parents=True, exist_ok=True)
    image_path = recipes_dir / f"{recipe_id}.webp"
    recipe_path = recipes_dir / f"{recipe_id}.recipe.json"

    legacy_recipe = {
        "id": recipe_id,
        "file_path": str(image_path),
        "title": "Legacy",
        "modified": 0.0,
        "created_date": 0.0,
        "loras": [],
        "checkpoint": "sd15.safetensors",
    }
    recipe_path.write_text(json.dumps(legacy_recipe))

    loaded = await scanner._load_recipe_file(str(recipe_path))

    assert isinstance(loaded["checkpoint"], dict)
    assert loaded["checkpoint"]["name"] == "sd15.safetensors"
    assert loaded["checkpoint"]["file_name"] == "sd15"
@pytest.mark.asyncio
async def test_get_paginated_data_normalizes_legacy_checkpoint(recipe_scanner):
    """A checkpoint stored as a one-element list comes back from
    ``get_paginated_data`` as a mapping with ``name`` and ``file_name``."""
    scanner, _ = recipe_scanner
    image_path = Path(config.loras_roots[0]) / "legacy.webp"

    legacy_recipe = {
        "id": "legacy-checkpoint",
        "file_path": str(image_path),
        "title": "Legacy",
        "modified": 0.0,
        "created_date": 0.0,
        "loras": [],
        "checkpoint": ["legacy.safetensors"],
    }
    await scanner.add_recipe(legacy_recipe)
    # Yield once to the event loop before querying.
    await asyncio.sleep(0)

    page = await scanner.get_paginated_data(page=1, page_size=5)
    first_item = page["items"][0]
    checkpoint = first_item["checkpoint"]

    assert checkpoint["name"] == "legacy.safetensors"
    assert checkpoint["file_name"] == "legacy"
@pytest.mark.asyncio
async def test_get_recipe_by_id_handles_non_dict_checkpoint(recipe_scanner):
    """A checkpoint stored as a one-element tuple comes back from
    ``get_recipe_by_id`` as a mapping with ``name`` and ``file_name``."""
    scanner, _ = recipe_scanner
    image_path = Path(config.loras_roots[0]) / "by-id.webp"

    tuple_checkpoint_recipe = {
        "id": "by-id-checkpoint",
        "file_path": str(image_path),
        "title": "ById",
        "modified": 0.0,
        "created_date": 0.0,
        "loras": [],
        "checkpoint": ("by-id.safetensors",),
    }
    await scanner.add_recipe(tuple_checkpoint_recipe)

    fetched = await scanner.get_recipe_by_id("by-id-checkpoint")
    checkpoint = fetched["checkpoint"]

    assert checkpoint["name"] == "by-id.safetensors"
    assert checkpoint["file_name"] == "by-id"
def test_enrich_uses_version_index_when_hash_missing(recipe_scanner):
    """A LoRA entry with an empty hash but a known ``modelVersionId`` is
    enriched with the hash, path, file name and preview of the registered model."""
    scanner, stub = recipe_scanner
    version_id = 77

    model_path = Path(config.loras_roots[0]) / "loras" / "version-entry.safetensors"
    file_path = str(model_path)
    registered = {
        "sha256": "deadbeef",
        "file_path": file_path,
        "preview_url": "preview-from-cache.png",
        "civitai": {"id": version_id},
    }
    stub.register_model("version-entry", registered)

    # Only the version id identifies the model; hash and file name are blank.
    enriched = scanner._enrich_lora_entry(
        {"hash": "", "file_name": "", "modelVersionId": version_id, "strength": 0.5}
    )

    assert enriched["inLibrary"] is True
    assert enriched["hash"] == registered["sha256"]
    assert enriched["localPath"] == file_path
    assert enriched["file_name"] == Path(file_path).stem
    assert enriched["preview_url"] == registered["preview_url"]
def test_enrich_formats_absolute_preview_paths(recipe_scanner, tmp_path):
    """When the registered preview is an absolute file path, the enriched
    ``preview_url`` equals ``config.get_preview_static_url`` for that path."""
    scanner, stub = recipe_scanner
    version_id = 88

    loras_dir = tmp_path / "loras"
    preview_path = loras_dir / "version-entry.preview.jpeg"
    preview_path.parent.mkdir(parents=True, exist_ok=True)
    preview_path.write_text("preview")
    model_path = loras_dir / "version-entry.safetensors"
    model_path.write_text("weights")

    model_info = {
        "sha256": "feedface",
        "file_path": str(model_path),
        "preview_url": str(preview_path),
        "civitai": {"id": version_id},
    }
    stub.register_model("absolute-preview", model_info)

    # Only the version id identifies the model; hash and file name are blank.
    enriched = scanner._enrich_lora_entry(
        {"hash": "", "file_name": "", "modelVersionId": version_id, "strength": 0.5}
    )

    expected_url = config.get_preview_static_url(str(preview_path))
    assert enriched["preview_url"] == expected_url