mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-21 21:22:11 -03:00
Add _normalize_checkpoint_entry method to handle legacy checkpoint data formats (strings, tuples) by converting them to dictionaries. This prevents errors during enrichment when checkpoint data is not in the expected dictionary format. Invalid checkpoint entries are now removed instead of causing processing failures. - Update get_paginated_data and get_recipe_by_id methods to use normalization - Add test cases for legacy string and tuple checkpoint formats - Ensure backward compatibility with existing checkpoint handling
352 lines
11 KiB
Python
352 lines
11 KiB
Python
import asyncio
|
|
import json
|
|
import os
|
|
from pathlib import Path
|
|
from types import SimpleNamespace
|
|
|
|
import pytest
|
|
|
|
from py.config import config
|
|
from py.services.recipe_scanner import RecipeScanner
|
|
from py.utils.utils import calculate_recipe_fingerprint
|
|
|
|
|
|
class StubHashIndex:
|
|
def __init__(self) -> None:
|
|
self._hash_to_path: dict[str, str] = {}
|
|
|
|
def get_path(self, hash_value: str) -> str | None:
|
|
return self._hash_to_path.get(hash_value)
|
|
|
|
|
|
class StubLoraScanner:
    """In-memory replacement for the real LoRA scanner used by these tests."""

    def __init__(self) -> None:
        self._hash_index = StubHashIndex()
        # Lowercased sha256 -> {"path": ..., "preview_url": ...}.
        self._hash_meta: dict[str, dict[str, str]] = {}
        # Model name -> full info dict, exactly as passed to register_model().
        self._models_by_name: dict[str, dict] = {}
        # Mirrors the shape of the real scanner's cache object.
        self._cache = SimpleNamespace(raw_data=[], version_index={})

    async def get_cached_data(self):
        """Return the cache namespace (raw_data + version_index)."""
        return self._cache

    def has_hash(self, hash_value: str) -> bool:
        """True when a model with this sha256 (case-insensitive) is registered."""
        return hash_value.lower() in self._hash_meta

    def get_preview_url_by_hash(self, hash_value: str) -> str:
        """Preview URL recorded for the hash, or "" when the hash is unknown."""
        entry = self._hash_meta.get(hash_value.lower())
        if entry is None:
            return ""
        return entry.get("preview_url", "")

    def get_path_by_hash(self, hash_value: str) -> str | None:
        """File path recorded for the hash, or None when the hash is unknown."""
        entry = self._hash_meta.get(hash_value.lower())
        if entry is None:
            return None
        return entry.get("path")

    async def get_model_info_by_name(self, name: str):
        """Full registered info dict for *name*, or None when unregistered."""
        return self._models_by_name.get(name)

    def register_model(self, name: str, info: dict) -> None:
        """Record *info* under *name* and index it by hash and Civitai version id."""
        self._models_by_name[name] = info
        lowered_hash = (info.get("sha256") or "").lower()
        civitai_meta = info.get("civitai", {})
        version_id = civitai_meta.get("id")
        file_path = info.get("file_path", "")
        if lowered_hash:
            self._hash_meta[lowered_hash] = {
                "path": file_path,
                "preview_url": info.get("preview_url", ""),
            }
            self._hash_index._hash_to_path[lowered_hash] = file_path
        if version_id is not None:
            self._cache.version_index[int(version_id)] = {
                "file_path": file_path,
                "sha256": lowered_hash,
                "preview_url": info.get("preview_url", ""),
                "civitai": civitai_meta,
            }
        self._cache.raw_data.append({
            "sha256": info.get("sha256", ""),
            "path": file_path,
            "civitai": civitai_meta,
        })
|
|
|
|
|
|
@pytest.fixture
def recipe_scanner(tmp_path: Path, monkeypatch):
    """Yield a (RecipeScanner, StubLoraScanner) pair rooted at tmp_path."""
    # Drop any singleton left over from a previous test.
    RecipeScanner._instance = None
    monkeypatch.setattr(config, "loras_roots", [str(tmp_path)])
    lora_stub = StubLoraScanner()
    scanner = RecipeScanner(lora_scanner=lora_stub)
    asyncio.run(scanner.refresh_cache(force=True))
    yield scanner, lora_stub
    # Reset the singleton so later tests start from a clean slate.
    RecipeScanner._instance = None
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_add_recipe_during_concurrent_reads(recipe_scanner):
    """Adding a recipe while readers iterate the cache must not corrupt it.

    Marked with ``pytest.mark.asyncio`` for consistency with the other async
    tests in this module; without the marker this coroutine is skipped or
    errors under strict asyncio mode.
    """
    scanner, _ = recipe_scanner

    initial_recipe = {
        "id": "one",
        "file_path": "path/a.png",
        "title": "First",
        "modified": 1.0,
        "created_date": 1.0,
        "loras": [],
    }
    await scanner.add_recipe(initial_recipe)

    new_recipe = {
        "id": "two",
        "file_path": "path/b.png",
        "title": "Second",
        "modified": 2.0,
        "created_date": 2.0,
        "loras": [],
    }

    async def reader_task():
        # Repeatedly read while the writer mutates the cache concurrently.
        for _ in range(5):
            cache = await scanner.get_cached_data()
            _ = [item["id"] for item in cache.raw_data]
            await asyncio.sleep(0)

    await asyncio.gather(reader_task(), reader_task(), scanner.add_recipe(new_recipe))
    await asyncio.sleep(0)
    cache = await scanner.get_cached_data()

    assert {item["id"] for item in cache.raw_data} == {"one", "two"}
    # Derived indexes must stay in sync with the raw data.
    assert len(cache.sorted_by_name) == len(cache.raw_data)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_remove_recipe_during_reads(recipe_scanner):
    """Removing a recipe while readers iterate the cache must not corrupt it.

    Marked with ``pytest.mark.asyncio`` for consistency with the other async
    tests in this module; without the marker this coroutine is skipped or
    errors under strict asyncio mode.
    """
    scanner, _ = recipe_scanner

    recipe_ids = ["alpha", "beta", "gamma"]
    for index, recipe_id in enumerate(recipe_ids):
        await scanner.add_recipe({
            "id": recipe_id,
            "file_path": f"path/{recipe_id}.png",
            "title": recipe_id,
            "modified": float(index),
            "created_date": float(index),
            "loras": [],
        })

    async def reader_task():
        # Repeatedly read while the removal runs concurrently.
        for _ in range(5):
            cache = await scanner.get_cached_data()
            _ = list(cache.sorted_by_date)
            await asyncio.sleep(0)

    await asyncio.gather(reader_task(), scanner.remove_recipe("beta"))
    await asyncio.sleep(0)
    cache = await scanner.get_cached_data()

    assert {item["id"] for item in cache.raw_data} == {"alpha", "gamma"}
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_update_lora_entry_updates_cache_and_file(tmp_path: Path, recipe_scanner):
    """update_lora_entry must rewrite the recipe file, its fingerprint, and the cache.

    Marked with ``pytest.mark.asyncio`` for consistency with the other async
    tests in this module; without the marker this coroutine is skipped or
    errors under strict asyncio mode.
    """
    scanner, stub = recipe_scanner
    recipes_dir = Path(config.loras_roots[0]) / "recipes"
    recipes_dir.mkdir(parents=True, exist_ok=True)

    recipe_id = "recipe-1"
    recipe_path = recipes_dir / f"{recipe_id}.recipe.json"
    recipe_data = {
        "id": recipe_id,
        "file_path": str(tmp_path / "image.png"),
        "title": "Original",
        "modified": 0.0,
        "created_date": 0.0,
        "loras": [
            {"file_name": "old", "strength": 1.0, "hash": "", "isDeleted": True, "exclude": True},
        ],
    }
    recipe_path.write_text(json.dumps(recipe_data))

    await scanner.add_recipe(dict(recipe_data))

    target_hash = "abc123"
    target_info = {
        "sha256": target_hash,
        "file_path": str(tmp_path / "loras" / "target.safetensors"),
        "preview_url": "preview.png",
        "civitai": {"id": 42, "name": "v1", "model": {"name": "Target"}},
    }
    stub.register_model("target", target_info)

    updated_recipe, updated_lora = await scanner.update_lora_entry(
        recipe_id,
        0,
        target_name="target",
        target_lora=target_info,
    )

    # The returned lora entry should now point at the library model.
    assert updated_lora["inLibrary"] is True
    assert updated_lora["localPath"] == target_info["file_path"]
    assert updated_lora["hash"] == target_hash

    with recipe_path.open("r", encoding="utf-8") as file_obj:
        persisted = json.load(file_obj)

    # The persisted file must carry a fingerprint derived from the new loras.
    expected_fingerprint = calculate_recipe_fingerprint(persisted["loras"])
    assert persisted["fingerprint"] == expected_fingerprint

    # The in-memory cache must reflect the same update.
    cache = await scanner.get_cached_data()
    cached_recipe = next(item for item in cache.raw_data if item["id"] == recipe_id)
    assert cached_recipe["loras"][0]["hash"] == target_hash
    assert cached_recipe["fingerprint"] == expected_fingerprint
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_load_recipe_rewrites_missing_image_path(tmp_path: Path, recipe_scanner):
    """A recipe pointing at a stale image path is rewritten to the local copy."""
    scanner, _ = recipe_scanner
    recipes_dir = Path(config.loras_roots[0]) / "recipes"
    recipes_dir.mkdir(parents=True, exist_ok=True)

    recipe_id = "moved"
    # The stored file_path references a root directory that no longer exists.
    stale_image_path = tmp_path / "old_root" / "recipes" / f"{recipe_id}.webp"
    recipe_file = recipes_dir / f"{recipe_id}.recipe.json"
    local_image = recipes_dir / f"{recipe_id}.webp"
    local_image.write_bytes(b"image-bytes")

    recipe_file.write_text(json.dumps({
        "id": recipe_id,
        "file_path": str(stale_image_path),
        "title": "Relocated",
        "modified": 0.0,
        "created_date": 0.0,
        "loras": [],
    }))

    loaded = await scanner._load_recipe_file(str(recipe_file))

    rewritten = os.path.normpath(str(local_image))
    # Both the returned data and the file on disk must carry the new path.
    assert loaded["file_path"] == rewritten
    assert json.loads(recipe_file.read_text())["file_path"] == rewritten
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_load_recipe_upgrades_string_checkpoint(tmp_path: Path, recipe_scanner):
    """A legacy string checkpoint is upgraded to the dict format on load."""
    scanner, _ = recipe_scanner
    recipes_dir = Path(config.loras_roots[0]) / "recipes"
    recipes_dir.mkdir(parents=True, exist_ok=True)

    recipe_id = "legacy-checkpoint"
    legacy_payload = {
        "id": recipe_id,
        "file_path": str(recipes_dir / f"{recipe_id}.webp"),
        "title": "Legacy",
        "modified": 0.0,
        "created_date": 0.0,
        "loras": [],
        # Old recipes stored the checkpoint as a bare filename string.
        "checkpoint": "sd15.safetensors",
    }
    recipe_file = recipes_dir / f"{recipe_id}.recipe.json"
    recipe_file.write_text(json.dumps(legacy_payload))

    loaded = await scanner._load_recipe_file(str(recipe_file))

    checkpoint = loaded["checkpoint"]
    assert isinstance(checkpoint, dict)
    assert checkpoint["name"] == "sd15.safetensors"
    assert checkpoint["file_name"] == "sd15"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_get_paginated_data_normalizes_legacy_checkpoint(recipe_scanner):
    """Pagination normalizes list-style legacy checkpoints into dicts."""
    scanner, _ = recipe_scanner
    legacy_image = Path(config.loras_roots[0]) / "legacy.webp"
    legacy_recipe = {
        "id": "legacy-checkpoint",
        "file_path": str(legacy_image),
        "title": "Legacy",
        "modified": 0.0,
        "created_date": 0.0,
        "loras": [],
        # Legacy format: checkpoint stored as a single-element list.
        "checkpoint": ["legacy.safetensors"],
    }
    await scanner.add_recipe(legacy_recipe)
    await asyncio.sleep(0)

    result = await scanner.get_paginated_data(page=1, page_size=5)

    normalized = result["items"][0]["checkpoint"]
    assert normalized["name"] == "legacy.safetensors"
    assert normalized["file_name"] == "legacy"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_get_recipe_by_id_handles_non_dict_checkpoint(recipe_scanner):
    """Lookup by id normalizes tuple-style legacy checkpoints into dicts."""
    scanner, _ = recipe_scanner
    image_file = Path(config.loras_roots[0]) / "by-id.webp"
    await scanner.add_recipe({
        "id": "by-id-checkpoint",
        "file_path": str(image_file),
        "title": "ById",
        "modified": 0.0,
        "created_date": 0.0,
        "loras": [],
        # Legacy format: checkpoint stored as a tuple.
        "checkpoint": ("by-id.safetensors",),
    })

    recipe = await scanner.get_recipe_by_id("by-id-checkpoint")

    checkpoint = recipe["checkpoint"]
    assert checkpoint["name"] == "by-id.safetensors"
    assert checkpoint["file_name"] == "by-id"
|
|
|
|
|
|
def test_enrich_uses_version_index_when_hash_missing(recipe_scanner):
    """Enrichment falls back to the Civitai version index when no hash is set."""
    scanner, stub = recipe_scanner
    version_id = 77
    model_path = str(Path(config.loras_roots[0]) / "loras" / "version-entry.safetensors")
    model_info = {
        "sha256": "deadbeef",
        "file_path": model_path,
        "preview_url": "preview-from-cache.png",
        "civitai": {"id": version_id},
    }
    stub.register_model("version-entry", model_info)

    # Entry carries only a modelVersionId — no hash and no file name.
    entry = {"hash": "", "file_name": "", "modelVersionId": version_id, "strength": 0.5}
    enriched = scanner._enrich_lora_entry(dict(entry))

    assert enriched["inLibrary"] is True
    assert enriched["hash"] == model_info["sha256"]
    assert enriched["localPath"] == model_path
    assert enriched["file_name"] == Path(model_path).stem
    assert enriched["preview_url"] == model_info["preview_url"]
|
|
|
|
|
|
def test_enrich_formats_absolute_preview_paths(recipe_scanner, tmp_path):
    """Absolute on-disk preview paths are converted to static preview URLs."""
    scanner, stub = recipe_scanner
    version_id = 88
    preview_file = tmp_path / "loras" / "version-entry.preview.jpeg"
    preview_file.parent.mkdir(parents=True, exist_ok=True)
    preview_file.write_text("preview")
    weights_file = tmp_path / "loras" / "version-entry.safetensors"
    weights_file.write_text("weights")

    stub.register_model(
        "absolute-preview",
        {
            "sha256": "feedface",
            "file_path": str(weights_file),
            "preview_url": str(preview_file),
            "civitai": {"id": version_id},
        },
    )

    entry = {"hash": "", "file_name": "", "modelVersionId": version_id, "strength": 0.5}
    enriched = scanner._enrich_lora_entry(dict(entry))

    assert enriched["preview_url"] == config.get_preview_static_url(str(preview_file))
|