mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-21 21:22:11 -03:00
feat: Add recipe metadata repair functionality with UI, API, and progress tracking.
This commit is contained in:
278
tests/services/test_recipe_repair.py
Normal file
278
tests/services/test_recipe_repair.py
Normal file
@@ -0,0 +1,278 @@
|
||||
import pytest
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
from py.services.recipe_scanner import RecipeScanner
|
||||
from types import SimpleNamespace
|
||||
|
||||
# We define these here to help with spec= if needed
|
||||
class MockCivitaiClient:
|
||||
async def get_image_info(self, image_id):
|
||||
pass
|
||||
|
||||
class MockPersistenceService:
|
||||
async def save_recipe(self, recipe):
|
||||
pass
|
||||
|
||||
@pytest.fixture
|
||||
def mock_civitai_client():
|
||||
client = MagicMock(spec=MockCivitaiClient)
|
||||
client.get_image_info = AsyncMock()
|
||||
return client
|
||||
|
||||
@pytest.fixture
|
||||
def mock_metadata_provider():
|
||||
provider = MagicMock()
|
||||
provider.get_model_version_info = AsyncMock(return_value=(None, None))
|
||||
provider.get_model_by_hash = AsyncMock(return_value=(None, None))
|
||||
return provider
|
||||
|
||||
@pytest.fixture
|
||||
def recipe_scanner():
|
||||
lora_scanner = MagicMock()
|
||||
lora_scanner.get_cached_data = AsyncMock(return_value=SimpleNamespace(raw_data=[]))
|
||||
|
||||
scanner = RecipeScanner(lora_scanner=lora_scanner)
|
||||
return scanner
|
||||
|
||||
@pytest.fixture
|
||||
def setup_scanner(recipe_scanner, mock_civitai_client, mock_metadata_provider, monkeypatch):
|
||||
monkeypatch.setattr(recipe_scanner, "_get_civitai_client", AsyncMock(return_value=mock_civitai_client))
|
||||
|
||||
# Wrap the real method with a mock so we can check calls but still execute it
|
||||
real_save = recipe_scanner._save_recipe_persistently
|
||||
mock_save = AsyncMock(side_effect=real_save)
|
||||
monkeypatch.setattr(recipe_scanner, "_save_recipe_persistently", mock_save)
|
||||
|
||||
monkeypatch.setattr("py.services.recipe_scanner.get_default_metadata_provider", AsyncMock(return_value=mock_metadata_provider))
|
||||
|
||||
# Mock get_recipe_json_path to avoid file system issues in tests
|
||||
recipe_scanner.get_recipe_json_path = AsyncMock(return_value="/tmp/test_recipe.json")
|
||||
# Mock open to avoid actual file writing
|
||||
monkeypatch.setattr("builtins.open", MagicMock())
|
||||
monkeypatch.setattr("json.dump", MagicMock())
|
||||
monkeypatch.setattr("os.path.exists", MagicMock(return_value=False)) # avoid EXIF logic
|
||||
|
||||
return recipe_scanner, mock_civitai_client, mock_metadata_provider
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_repair_all_recipes_skip_up_to_date(setup_scanner):
|
||||
recipe_scanner, _, _ = setup_scanner
|
||||
|
||||
recipe_scanner._cache = SimpleNamespace(raw_data=[
|
||||
{"id": "r1", "repair_version": RecipeScanner.REPAIR_VERSION, "title": "Up to date"}
|
||||
])
|
||||
|
||||
# Run
|
||||
results = await recipe_scanner.repair_all_recipes()
|
||||
|
||||
# Verify
|
||||
assert results["repaired"] == 0
|
||||
assert results["skipped"] == 1
|
||||
recipe_scanner._save_recipe_persistently.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_repair_all_recipes_with_enriched_checkpoint_id(setup_scanner):
|
||||
recipe_scanner, mock_civitai_client, mock_metadata_provider = setup_scanner
|
||||
|
||||
recipe = {
|
||||
"id": "r1",
|
||||
"title": "Old Recipe",
|
||||
"source_url": "https://civitai.com/images/12345",
|
||||
"checkpoint": None,
|
||||
"gen_params": {"prompt": ""}
|
||||
}
|
||||
recipe_scanner._cache = SimpleNamespace(raw_data=[recipe])
|
||||
|
||||
# Mock image info returning modelVersionId
|
||||
mock_civitai_client.get_image_info.return_value = {
|
||||
"modelVersionId": 5678,
|
||||
"meta": {"prompt": "a beautiful forest", "Checkpoint": "basic_name.safetensors"}
|
||||
}
|
||||
|
||||
# Mock metadata provider returning full info
|
||||
mock_metadata_provider.get_model_version_info.return_value = ({
|
||||
"id": 5678,
|
||||
"modelId": 1234,
|
||||
"name": "v1.0",
|
||||
"model": {"name": "Full Model Name"},
|
||||
"baseModel": "SDXL 1.0",
|
||||
"images": [{"url": "https://image.url/thumb.jpg"}],
|
||||
"files": [{"type": "Model", "hashes": {"SHA256": "ABCDEF"}, "name": "full_filename.safetensors"}]
|
||||
}, None)
|
||||
|
||||
# Run
|
||||
results = await recipe_scanner.repair_all_recipes()
|
||||
|
||||
# Verify
|
||||
assert results["repaired"] == 1
|
||||
mock_metadata_provider.get_model_version_info.assert_called_with("5678")
|
||||
|
||||
saved_recipe = recipe_scanner._save_recipe_persistently.call_args[0][0]
|
||||
checkpoint = saved_recipe["checkpoint"]
|
||||
assert checkpoint["name"] == "Full Model Name"
|
||||
assert checkpoint["version"] == "v1.0"
|
||||
assert checkpoint["modelId"] == 1234
|
||||
assert checkpoint["id"] == 5678
|
||||
assert checkpoint["hash"] == "abcdef"
|
||||
assert checkpoint["file_name"] == "full_filename"
|
||||
assert "thumbnailUrl" not in checkpoint # Stripped during sanitation
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_repair_all_recipes_with_enriched_checkpoint_hash(setup_scanner):
|
||||
recipe_scanner, mock_civitai_client, mock_metadata_provider = setup_scanner
|
||||
|
||||
recipe = {
|
||||
"id": "r1",
|
||||
"title": "Embedded Only",
|
||||
"checkpoint": None,
|
||||
"gen_params": {
|
||||
"prompt": "",
|
||||
"Model hash": "hash123"
|
||||
}
|
||||
}
|
||||
recipe_scanner._cache = SimpleNamespace(raw_data=[recipe])
|
||||
|
||||
# Mock metadata provider lookup by hash
|
||||
mock_metadata_provider.get_model_by_hash.return_value = ({
|
||||
"id": 999,
|
||||
"modelId": 888,
|
||||
"name": "v2.0",
|
||||
"model": {"name": "Hashed Model"},
|
||||
"baseModel": "SD 1.5",
|
||||
"files": [{"type": "Model", "hashes": {"SHA256": "hash123"}, "name": "hashed.safetensors"}]
|
||||
}, None)
|
||||
|
||||
# Run
|
||||
results = await recipe_scanner.repair_all_recipes()
|
||||
|
||||
# Verify
|
||||
assert results["repaired"] == 1
|
||||
mock_metadata_provider.get_model_by_hash.assert_called_with("hash123")
|
||||
|
||||
saved_recipe = recipe_scanner._save_recipe_persistently.call_args[0][0]
|
||||
checkpoint = saved_recipe["checkpoint"]
|
||||
assert checkpoint["name"] == "Hashed Model"
|
||||
assert checkpoint["version"] == "v2.0"
|
||||
assert checkpoint["modelId"] == 888
|
||||
assert checkpoint["hash"] == "hash123"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_repair_all_recipes_fallback_to_basic(setup_scanner):
|
||||
recipe_scanner, mock_civitai_client, mock_metadata_provider = setup_scanner
|
||||
|
||||
recipe = {
|
||||
"id": "r1",
|
||||
"title": "No Meta Lookup",
|
||||
"checkpoint": None,
|
||||
"gen_params": {
|
||||
"prompt": "",
|
||||
"Checkpoint": "just_a_name.safetensors"
|
||||
}
|
||||
}
|
||||
recipe_scanner._cache = SimpleNamespace(raw_data=[recipe])
|
||||
|
||||
# Mock metadata provider returning nothing
|
||||
mock_metadata_provider.get_model_by_hash.return_value = (None, "Model not found")
|
||||
|
||||
# Run
|
||||
results = await recipe_scanner.repair_all_recipes()
|
||||
|
||||
# Verify
|
||||
assert results["repaired"] == 1
|
||||
saved_recipe = recipe_scanner._save_recipe_persistently.call_args[0][0]
|
||||
assert saved_recipe["checkpoint"]["name"] == "just_a_name.safetensors"
|
||||
assert "modelId" not in saved_recipe["checkpoint"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_repair_all_recipes_progress_callback(setup_scanner):
|
||||
recipe_scanner, _, _ = setup_scanner
|
||||
|
||||
recipe_scanner._cache = SimpleNamespace(raw_data=[
|
||||
{"id": "r1", "title": "R1", "checkpoint": None},
|
||||
{"id": "r2", "title": "R2", "checkpoint": None}
|
||||
])
|
||||
|
||||
progress_calls = []
|
||||
async def progress_callback(data):
|
||||
progress_calls.append(data)
|
||||
|
||||
# Run
|
||||
await recipe_scanner.repair_all_recipes(
|
||||
progress_callback=progress_callback
|
||||
)
|
||||
|
||||
# Verify
|
||||
assert len(progress_calls) >= 2
|
||||
assert progress_calls[-1]["status"] == "completed"
|
||||
assert progress_calls[-1]["total"] == 2
|
||||
assert progress_calls[-1]["repaired"] == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_repair_all_recipes_strips_runtime_fields(setup_scanner):
|
||||
recipe_scanner, mock_civitai_client, mock_metadata_provider = setup_scanner
|
||||
|
||||
# Recipe with runtime fields
|
||||
recipe = {
|
||||
"id": "r1",
|
||||
"title": "Cleanup Test",
|
||||
"checkpoint": {
|
||||
"name": "CP",
|
||||
"inLibrary": True,
|
||||
"localPath": "/path/to/cp",
|
||||
"thumbnailUrl": "thumb.jpg"
|
||||
},
|
||||
"loras": [
|
||||
{
|
||||
"name": "L1",
|
||||
"weight": 0.8,
|
||||
"inLibrary": True,
|
||||
"localPath": "/path/to/l1",
|
||||
"preview_url": "p.jpg"
|
||||
}
|
||||
],
|
||||
"gen_params": {"prompt": ""}
|
||||
}
|
||||
recipe_scanner._cache = SimpleNamespace(raw_data=[recipe])
|
||||
# Set high version to trigger repair if needed (or just ensure it processes)
|
||||
recipe["repair_version"] = 0
|
||||
|
||||
# Run
|
||||
await recipe_scanner.repair_all_recipes()
|
||||
|
||||
# Verify sanitation
|
||||
assert recipe_scanner._save_recipe_persistently.called
|
||||
saved_recipe = recipe_scanner._save_recipe_persistently.call_args[0][0]
|
||||
|
||||
# 1. Check LORA
|
||||
lora = saved_recipe["loras"][0]
|
||||
assert "inLibrary" not in lora
|
||||
assert "localPath" not in lora
|
||||
assert "preview_url" not in lora
|
||||
assert "strength" in lora # weight renamed to strength
|
||||
assert lora["strength"] == 0.8
|
||||
|
||||
# 2. Check Checkpoint
|
||||
cp = saved_recipe["checkpoint"]
|
||||
assert "inLibrary" not in cp
|
||||
assert "localPath" not in cp
|
||||
assert "thumbnailUrl" not in cp
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sanitize_recipe_for_storage(recipe_scanner):
|
||||
import sys
|
||||
import py.services.recipe_scanner
|
||||
print(f"\nDEBUG_ENV: sys.path: {sys.path}")
|
||||
print(f"DEBUG_ENV: recipe_scanner file: {py.services.recipe_scanner.__file__}")
|
||||
|
||||
recipe = {
|
||||
"loras": [{"name": "L1", "inLibrary": True, "weight": 0.5}],
|
||||
"checkpoint": {"name": "CP", "localPath": "/tmp/cp"}
|
||||
}
|
||||
|
||||
clean = recipe_scanner._sanitize_recipe_for_storage(recipe)
|
||||
|
||||
assert "inLibrary" not in clean["loras"][0]
|
||||
assert "strength" in clean["loras"][0]
|
||||
assert clean["loras"][0]["strength"] == 0.5
|
||||
assert "localPath" not in clean["checkpoint"]
|
||||
assert clean["checkpoint"]["name"] == "CP"
|
||||
Reference in New Issue
Block a user