mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-21 21:22:11 -03:00
283 lines
9.8 KiB
Python
283 lines
9.8 KiB
Python
import pytest
|
|
import asyncio
|
|
from unittest.mock import AsyncMock, MagicMock
|
|
from py.services.recipe_scanner import RecipeScanner
|
|
from types import SimpleNamespace
|
|
|
|
# Minimal interface stubs, defined here so they can be passed as spec= to mocks if needed
|
|
class MockCivitaiClient:
    """Interface stub for the Civitai client.

    Used in this module as the ``spec=`` of a ``MagicMock`` so that the mock
    only exposes the attributes declared here.
    """

    async def get_image_info(self, image_id):
        """Placeholder coroutine; performs no work and returns ``None``."""
|
|
|
|
class MockPersistenceService:
    """Interface stub for a recipe persistence service (available for ``spec=``)."""

    async def save_recipe(self, recipe):
        """Placeholder coroutine; performs no work and returns ``None``."""
|
|
|
|
@pytest.fixture
def mock_civitai_client():
    """Return a mock limited to ``MockCivitaiClient``'s interface.

    ``get_image_info`` is replaced by an ``AsyncMock`` so tests can await it
    and set its return value.
    """
    stub = MagicMock(spec=MockCivitaiClient)
    stub.get_image_info = AsyncMock()
    return stub
|
|
|
|
@pytest.fixture
def mock_metadata_provider():
    """Return a metadata-provider mock whose lookups default to ``(None, None)``.

    Both ``get_model_version_info`` and ``get_model_by_hash`` are awaitable;
    individual tests override their ``return_value`` as needed.
    """
    return MagicMock(
        get_model_version_info=AsyncMock(return_value=(None, None)),
        get_model_by_hash=AsyncMock(return_value=(None, None)),
    )
|
|
|
|
@pytest.fixture
def recipe_scanner():
    """Build a ``RecipeScanner`` wired to a mocked LoRA scanner.

    The mocked scanner's ``get_cached_data`` coroutine yields an object whose
    ``raw_data`` list is empty.
    """
    empty_cache = SimpleNamespace(raw_data=[])
    fake_lora_scanner = MagicMock()
    fake_lora_scanner.get_cached_data = AsyncMock(return_value=empty_cache)
    return RecipeScanner(lora_scanner=fake_lora_scanner)
|
|
|
|
@pytest.fixture
def setup_scanner(recipe_scanner, mock_civitai_client, mock_metadata_provider, monkeypatch):
    """Patch a ``RecipeScanner`` so repair tests run without network or disk access.

    Returns:
        A ``(recipe_scanner, mock_civitai_client, mock_metadata_provider)`` tuple.
    """
    monkeypatch.setattr(
        recipe_scanner,
        "_get_civitai_client",
        AsyncMock(return_value=mock_civitai_client),
    )

    # Spy on persistence: the real method still runs (via side_effect) while
    # the AsyncMock wrapper records the calls for later assertions.
    monkeypatch.setattr(
        recipe_scanner,
        "_save_recipe_persistently",
        AsyncMock(side_effect=recipe_scanner._save_recipe_persistently),
    )

    monkeypatch.setattr(
        "py.recipes.enrichment.get_default_metadata_provider",
        AsyncMock(return_value=mock_metadata_provider),
    )

    # Fixed JSON path so the tests do not depend on the real file system layout.
    recipe_scanner.get_recipe_json_path = AsyncMock(return_value="/tmp/test_recipe.json")

    # Neutralise file writing; reporting every path as missing avoids the EXIF logic.
    global_patches = (
        ("builtins.open", MagicMock()),
        ("json.dump", MagicMock()),
        ("os.path.exists", MagicMock(return_value=False)),
    )
    for dotted_name, replacement in global_patches:
        monkeypatch.setattr(dotted_name, replacement)

    return recipe_scanner, mock_civitai_client, mock_metadata_provider
|
|
|
|
@pytest.mark.asyncio
async def test_repair_all_recipes_skip_up_to_date(setup_scanner):
    """A recipe already at ``REPAIR_VERSION`` is skipped and never saved."""
    scanner = setup_scanner[0]

    current_recipe = {
        "id": "r1",
        "repair_version": RecipeScanner.REPAIR_VERSION,
        "title": "Up to date",
    }
    scanner._cache = SimpleNamespace(raw_data=[current_recipe])

    summary = await scanner.repair_all_recipes()

    assert summary["repaired"] == 0
    assert summary["skipped"] == 1
    scanner._save_recipe_persistently.assert_not_called()
|
|
|
|
@pytest.mark.asyncio
async def test_repair_all_recipes_with_enriched_checkpoint_id(setup_scanner):
    """Checkpoint data is enriched from the model version id found in image info."""
    scanner, civitai, provider = setup_scanner

    scanner._cache = SimpleNamespace(
        raw_data=[
            {
                "id": "r1",
                "title": "Old Recipe",
                "source_url": "https://civitai.com/images/12345",
                "checkpoint": None,
                "gen_params": {"prompt": ""},
            }
        ]
    )

    # Image info carries the modelVersionId that drives the metadata lookup.
    civitai.get_image_info.return_value = {
        "modelVersionId": 5678,
        "meta": {"prompt": "a beautiful forest", "Checkpoint": "basic_name.safetensors"},
    }

    # The metadata provider answers with a full version record and no error.
    version_info = {
        "id": 5678,
        "modelId": 1234,
        "name": "v1.0",
        "model": {"name": "Full Model Name"},
        "baseModel": "SDXL 1.0",
        "images": [{"url": "https://image.url/thumb.jpg"}],
        "files": [
            {
                "type": "Model",
                "hashes": {"SHA256": "ABCDEF"},
                "name": "full_filename.safetensors",
            }
        ],
    }
    provider.get_model_version_info.return_value = (version_info, None)

    summary = await scanner.repair_all_recipes()

    assert summary["repaired"] == 1
    provider.get_model_version_info.assert_called_with("5678")

    persisted = scanner._save_recipe_persistently.call_args[0][0]
    checkpoint = persisted["checkpoint"]

    expected_fields = {
        "modelName": "Full Model Name",
        "modelVersionName": "v1.0",
        "modelId": 1234,
        "modelVersionId": 5678,
        "type": "checkpoint",
    }
    for field, value in expected_fields.items():
        assert checkpoint[field] == value

    # These keys must be absent from the saved checkpoint.
    for absent_key in ("name", "version", "hash", "file_name"):
        assert absent_key not in checkpoint
|
|
|
|
@pytest.mark.asyncio
async def test_repair_all_recipes_with_enriched_checkpoint_hash(setup_scanner):
    """Checkpoint data is enriched via a hash lookup when only ``Model hash`` is present."""
    scanner, _civitai, provider = setup_scanner

    scanner._cache = SimpleNamespace(
        raw_data=[
            {
                "id": "r1",
                "title": "Embedded Only",
                "checkpoint": None,
                "gen_params": {"prompt": "", "Model hash": "hash123"},
            }
        ]
    )

    # The hash lookup resolves to a full version record with no error.
    hashed_version = {
        "id": 999,
        "modelId": 888,
        "name": "v2.0",
        "model": {"name": "Hashed Model"},
        "baseModel": "SD 1.5",
        "files": [
            {
                "type": "Model",
                "hashes": {"SHA256": "hash123"},
                "name": "hashed.safetensors",
            }
        ],
    }
    provider.get_model_by_hash.return_value = (hashed_version, None)

    summary = await scanner.repair_all_recipes()

    assert summary["repaired"] == 1
    provider.get_model_by_hash.assert_called_with("hash123")

    persisted = scanner._save_recipe_persistently.call_args[0][0]
    checkpoint = persisted["checkpoint"]

    expected_fields = {
        "modelName": "Hashed Model",
        "modelVersionName": "v2.0",
        "modelId": 888,
        "type": "checkpoint",
    }
    for field, value in expected_fields.items():
        assert checkpoint[field] == value
|
|
|
|
@pytest.mark.asyncio
async def test_repair_all_recipes_fallback_to_basic(setup_scanner):
    """With no metadata match, the checkpoint is built from the name in ``gen_params``."""
    scanner, _civitai, provider = setup_scanner

    scanner._cache = SimpleNamespace(
        raw_data=[
            {
                "id": "r1",
                "title": "No Meta Lookup",
                "checkpoint": None,
                "gen_params": {"prompt": "", "Checkpoint": "just_a_name.safetensors"},
            }
        ]
    )

    # The provider reports that nothing was found.
    provider.get_model_by_hash.return_value = (None, "Model not found")

    summary = await scanner.repair_all_recipes()

    assert summary["repaired"] == 1

    persisted = scanner._save_recipe_persistently.call_args[0][0]
    checkpoint = persisted["checkpoint"]
    assert checkpoint["modelName"] == "just_a_name.safetensors"
    assert checkpoint["type"] == "checkpoint"
    assert "modelId" not in checkpoint
|
|
|
|
@pytest.mark.asyncio
async def test_repair_all_recipes_progress_callback(setup_scanner):
    """The progress callback receives updates, ending with a ``completed`` summary."""
    scanner = setup_scanner[0]

    pending = [
        {"id": "r1", "title": "R1", "checkpoint": None},
        {"id": "r2", "title": "R2", "checkpoint": None},
    ]
    scanner._cache = SimpleNamespace(raw_data=pending)

    received = []

    async def record_progress(data):
        """Collect each progress payload for inspection."""
        received.append(data)

    await scanner.repair_all_recipes(progress_callback=record_progress)

    assert len(received) >= 2
    final_update = received[-1]
    assert final_update["status"] == "completed"
    assert final_update["total"] == 2
    assert final_update["repaired"] == 2
|
|
|
|
@pytest.mark.asyncio
async def test_repair_all_recipes_strips_runtime_fields(setup_scanner):
    """Runtime-only keys are removed from LoRAs and the checkpoint before saving."""
    scanner, _, _ = setup_scanner

    # Recipe polluted with runtime-only fields. repair_version is set to 0 so
    # the recipe is processed (presumably because 0 is below REPAIR_VERSION).
    dirty_recipe = {
        "id": "r1",
        "title": "Cleanup Test",
        "checkpoint": {
            "name": "CP",
            "inLibrary": True,
            "localPath": "/path/to/cp",
            "thumbnailUrl": "thumb.jpg",
        },
        "loras": [
            {
                "name": "L1",
                "weight": 0.8,
                "inLibrary": True,
                "localPath": "/path/to/l1",
                "preview_url": "p.jpg",
            }
        ],
        "gen_params": {"prompt": ""},
        "repair_version": 0,
    }
    scanner._cache = SimpleNamespace(raw_data=[dirty_recipe])

    await scanner.repair_all_recipes()

    save_spy = scanner._save_recipe_persistently
    assert save_spy.called
    persisted = save_spy.call_args[0][0]

    # 1. LoRA entry: runtime keys gone, weight exposed as strength.
    lora = persisted["loras"][0]
    for runtime_key in ("inLibrary", "localPath", "preview_url"):
        assert runtime_key not in lora
    assert "strength" in lora
    assert lora["strength"] == 0.8

    # 2. Checkpoint entry: runtime keys gone.
    checkpoint = persisted["checkpoint"]
    for runtime_key in ("inLibrary", "localPath", "thumbnailUrl"):
        assert runtime_key not in checkpoint
|
|
|
|
@pytest.mark.asyncio
async def test_sanitize_recipe_for_storage(recipe_scanner):
    """``_sanitize_recipe_for_storage`` drops runtime keys and keeps stored ones."""
    raw_recipe = {
        "loras": [{"name": "L1", "inLibrary": True, "weight": 0.5}],
        "checkpoint": {"name": "CP", "localPath": "/tmp/cp"},
    }

    sanitized = recipe_scanner._sanitize_recipe_for_storage(raw_recipe)

    lora = sanitized["loras"][0]
    assert "inLibrary" not in lora
    assert "strength" in lora
    assert lora["strength"] == 0.5

    checkpoint = sanitized["checkpoint"]
    assert "localPath" not in checkpoint
    # Only the sanitizer runs here (no enrichment), so the checkpoint's
    # pre-existing "name" key is expected to pass through unchanged.
    assert checkpoint["name"] == "CP"
|