Files
ComfyUI-Lora-Manager/tests/services/test_metadata_sync_service.py
Will Miao f842ea990e feat(metadata): prevent overwriting high-quality Civitai API metadata with archive data. See #764
- Update `is_civitai_api_metadata` to exclude both "archive_db" and "civarchive" sources
- Skip Civitai metadata updates when existing metadata is higher quality than incoming archive data
- Add test to verify API metadata is preserved when CivArchive provides lower-quality data
2026-01-09 19:47:32 +08:00

537 lines
18 KiB
Python

from types import SimpleNamespace
from typing import Any, Dict
from unittest.mock import AsyncMock
import pytest
from py.services.errors import RateLimitError
from py.services.metadata_sync_service import MetadataSyncService
class DummySettings:
def __init__(self, values: dict | None = None) -> None:
self._values = values or {}
def get(self, key: str, default=None):
return self._values.get(key, default)
def build_service(
    *,
    settings_values: dict | None = None,
    default_provider: SimpleNamespace | None = None,
    provider_selector: AsyncMock | None = None,
):
    """Construct a ``MetadataSyncService`` wired to in-memory test doubles.

    Args:
        settings_values: Raw settings exposed to the service via ``DummySettings``.
        default_provider: Provider returned by the default provider factory
            and, unless ``provider_selector`` is given, by the selector too.
            When omitted, a stub is created whose ``get_model_by_hash``
            resolves to ``(None, None)``.
        provider_selector: Async callable the service uses to pick a provider
            by name. When omitted, one that always returns the default
            provider is created.

    Returns:
        A ``SimpleNamespace`` exposing the service and every collaborator
        (``metadata_manager``, ``preview_service``, ``default_provider``,
        ``default_provider_factory``, ``provider_selector``) so tests can
        configure return values and assert on awaits.
    """
    metadata_manager = SimpleNamespace(
        save_metadata=AsyncMock(),
        # Hydration is an identity pass-through unless a test overrides it.
        hydrate_model_data=AsyncMock(side_effect=lambda payload: payload),
    )
    preview_service = SimpleNamespace(ensure_preview_for_metadata=AsyncMock())
    settings = DummySettings(settings_values)

    # Compare against None explicitly: a caller-supplied double must never be
    # swapped out merely because it happens to evaluate as falsy.
    if default_provider is None:
        provider = SimpleNamespace(
            get_model_by_hash=AsyncMock(return_value=(None, None)),
            get_model_version=AsyncMock(),
        )
    else:
        provider = default_provider

    default_provider_factory = AsyncMock(return_value=provider)
    if provider_selector is None:
        provider_selector = AsyncMock(return_value=provider)

    service = MetadataSyncService(
        metadata_manager=metadata_manager,
        preview_service=preview_service,
        settings=settings,
        default_metadata_provider_factory=default_provider_factory,
        metadata_provider_selector=provider_selector,
    )
    return SimpleNamespace(
        service=service,
        metadata_manager=metadata_manager,
        preview_service=preview_service,
        default_provider=provider,
        default_provider_factory=default_provider_factory,
        provider_selector=provider_selector,
    )
@pytest.mark.asyncio
async def test_update_model_metadata_merges_and_persists():
    """Remote Civitai data is merged into local metadata and saved exactly once.

    Trained words are unioned; the remote model's name, description and tags
    win; the base model "sdxl" becomes "SDXL 1.0"; license flags stay nested
    under ``civitai.model`` and never appear at the top level.
    """
    ctx = build_service()
    metadata_path = "path/to/model.metadata.json"
    license_keys = (
        "allowNoCredit",
        "allowCommercialUse",
        "allowDerivatives",
        "allowDifferentLicense",
    )

    local = {
        "civitai": {"trainedWords": ["alpha"], "creator": {"id": 1}},
        "modelDescription": "",
        "tags": [],
        "model_name": "Local",
    }
    remote_model = {
        "name": "Remote Model",
        "description": "desc",
        "tags": ["style"],
        "creator": {"id": 2},
        "allowNoCredit": False,
        "allowCommercialUse": ["Image"],
        "allowDerivatives": False,
        "allowDifferentLicense": True,
    }
    remote = {
        "source": "api",
        "trainedWords": ["beta"],
        "model": remote_model,
        "baseModel": "sdxl",
        "images": ["img"],
    }

    merged = await ctx.service.update_model_metadata(
        metadata_path,
        local,
        remote,
        ctx.default_provider,
    )

    # Top-level fields come from the remote payload.
    assert set(merged["civitai"]["trainedWords"]) == {"alpha", "beta"}
    assert merged["model_name"] == "Remote Model"
    assert merged["modelDescription"] == "desc"
    assert merged["tags"] == ["style"]
    assert merged["base_model"] == "SDXL 1.0"

    # License flags live only under civitai.model.
    nested = merged["civitai"]["model"]
    assert nested["allowNoCredit"] is False
    assert nested["allowCommercialUse"] == ["Image"]
    assert nested["allowDerivatives"] is False
    assert nested["allowDifferentLicense"] is True
    assert not [key for key in license_keys if key in merged]

    ctx.preview_service.ensure_preview_for_metadata.assert_awaited_once()
    ctx.metadata_manager.save_metadata.assert_awaited_once_with(metadata_path, merged)
@pytest.mark.asyncio
async def test_fetch_and_update_model_success_updates_cache(tmp_path):
    """A Civitai API hit marks the model as Civitai-sourced, persists it and caches it.

    The record gains license flags under ``civitai.model`` (none at the top
    level), is saved to ``<model>.metadata.json`` and is pushed to the cache.
    """
    env = build_service()
    license_keys = (
        "allowNoCredit",
        "allowCommercialUse",
        "allowDerivatives",
        "allowDifferentLicense",
    )
    api_payload = {
        "source": "api",
        "model": {"name": "Remote", "description": "", "tags": ["tag"]},
        "images": [],
        "baseModel": "sdxl",
    }
    env.default_provider.get_model_by_hash.return_value = (api_payload, None)
    model_file = tmp_path / "model.safetensors"

    async def mark_hydrated(payload: Dict[str, Any]) -> Dict[str, Any]:
        payload["hydrated"] = True
        return payload

    env.metadata_manager.hydrate_model_data.side_effect = mark_hydrated

    model_data = {
        "model_name": "Local",
        "folder": "root",
        "file_path": str(model_file),
    }
    update_cache = AsyncMock(return_value=True)
    # Hydrate up front without going through the mock, then clear the mock's
    # history, so the later assert_not_awaited only reflects the service's calls.
    await mark_hydrated(model_data)
    env.metadata_manager.hydrate_model_data.reset_mock()

    ok, error = await env.service.fetch_and_update_model(
        sha256="abc",
        file_path=str(model_file),
        model_data=model_data,
        update_cache_func=update_cache,
    )

    assert ok and error is None
    assert model_data["from_civitai"] is True
    assert model_data["civitai_deleted"] is False
    assert "civitai" in model_data
    assert model_data["metadata_source"] == "civitai_api"

    in_memory = model_data["civitai"]["model"]
    assert in_memory["allowNoCredit"] is True
    assert in_memory["allowDerivatives"] is True
    assert in_memory["allowDifferentLicense"] is True
    assert in_memory["allowCommercialUse"] == ["Sell"]
    assert not [key for key in license_keys if key in model_data]

    env.metadata_manager.hydrate_model_data.assert_not_awaited()
    assert model_data["hydrated"] is True

    saves = env.metadata_manager.save_metadata.await_args_list
    assert saves, "expected metadata to be persisted"
    final_save = saves[-1]
    assert final_save.args[0] == str(model_file.with_suffix(".metadata.json"))
    saved_payload = final_save.args[1]
    assert saved_payload["hydrated"] is True
    on_disk = saved_payload["civitai"]["model"]
    assert on_disk["allowNoCredit"] is True
    assert on_disk["allowCommercialUse"] == ["Sell"]
    assert not [key for key in license_keys if key in saved_payload]

    update_cache.assert_awaited_once()
@pytest.mark.asyncio
async def test_fetch_and_update_model_keeps_deleted_flag_false_for_archive_source(tmp_path):
    """A hit served with source "archive_db" must not flag the model as deleted."""
    env = build_service()
    archive_hit = {
        "source": "archive_db",
        "model": {"name": "Recovered", "description": "", "tags": ["tag"]},
        "images": [],
        "baseModel": "sd15",
    }
    env.default_provider.get_model_by_hash.return_value = (archive_hit, None)
    model_file = str(tmp_path / "model.safetensors")
    model_data = {"model_name": "Local", "folder": "root", "file_path": model_file}
    update_cache = AsyncMock(return_value=True)

    ok, error = await env.service.fetch_and_update_model(
        sha256="abc",
        file_path=model_file,
        model_data=model_data,
        update_cache_func=update_cache,
    )

    assert ok and error is None
    assert model_data["metadata_source"] == "archive_db"
    assert model_data["civitai_deleted"] is False
    update_cache.assert_awaited_once()
@pytest.mark.asyncio
async def test_fetch_and_update_model_handles_missing_remote_metadata(tmp_path):
    """A "Model not found" lookup fails but still records the deleted state.

    The model is marked as not from Civitai and as deleted, and the metadata
    is saved exactly once without the ``folder`` key.
    """
    env = build_service()
    env.default_provider.get_model_by_hash.return_value = (None, "Model not found")
    model_file = tmp_path / "model.safetensors"

    async def mark_hydrated(payload: Dict[str, Any]) -> Dict[str, Any]:
        payload["hydrated"] = True
        return payload

    env.metadata_manager.hydrate_model_data.side_effect = mark_hydrated

    model_data = {
        "model_name": "Local",
        "folder": "sub",
        "file_path": str(model_file),
    }
    # Pre-hydrate directly and clear the mock so only service calls are counted.
    await mark_hydrated(model_data)
    env.metadata_manager.hydrate_model_data.reset_mock()

    ok, error = await env.service.fetch_and_update_model(
        sha256="missing",
        file_path=str(model_file),
        model_data=model_data,
        update_cache_func=AsyncMock(),
    )

    assert not ok
    assert "Model not found" in error
    assert model_data["from_civitai"] is False
    assert model_data["civitai_deleted"] is True
    env.metadata_manager.hydrate_model_data.assert_not_awaited()
    assert model_data["hydrated"] is True

    save = env.metadata_manager.save_metadata
    save.assert_awaited_once()
    saved_args = save.await_args.args
    assert saved_args[0].endswith("model.safetensors")
    assert "folder" not in saved_args[1]
    assert saved_args[1]["hydrated"] is True
@pytest.mark.asyncio
async def test_fetch_and_update_model_respects_deleted_without_archive():
    """A deleted model is not refetched when the metadata archive DB is disabled."""
    env = build_service(settings_values={"enable_metadata_archive_db": False})
    model_file = "/tmp/model.safetensors"
    model_data = {"civitai_deleted": True, "file_path": model_file}
    update_cache = AsyncMock()

    ok, error = await env.service.fetch_and_update_model(
        sha256="abc",
        file_path=model_file,
        model_data=model_data,
        update_cache_func=update_cache,
    )

    assert not ok
    assert "metadata archive DB is not enabled" in error
    # No provider, hydration or cache work may have happened.
    env.default_provider_factory.assert_not_awaited()
    env.metadata_manager.hydrate_model_data.assert_not_awaited()
    update_cache.assert_not_awaited()
@pytest.mark.asyncio
async def test_fetch_and_update_model_prefers_civarchive_for_deleted_models(tmp_path):
    """Deleted models previously sourced from CivArchive are refreshed from it.

    Even with the archive DB disabled, the ``civarchive_api`` provider is
    selected and queried once, the default provider factory is never used,
    and the record keeps its deleted flag and ``civarchive`` source.
    """
    fallback_provider = SimpleNamespace(
        get_model_by_hash=AsyncMock(),
        get_model_version=AsyncMock(),
    )
    civarchive_hit = {
        "source": "civarchive",
        "model": {"name": "Recovered", "description": "", "tags": []},
        "images": [],
        "baseModel": "sdxl",
    }
    civarchive_provider = SimpleNamespace(
        get_model_by_hash=AsyncMock(return_value=(civarchive_hit, None)),
        get_model_version=AsyncMock(),
    )

    async def select_provider(name: str):
        if name == "civarchive_api":
            return civarchive_provider
        return fallback_provider

    provider_selector = AsyncMock(side_effect=select_provider)
    env = build_service(
        settings_values={"enable_metadata_archive_db": False},
        default_provider=fallback_provider,
        provider_selector=provider_selector,
    )
    model_file = str(tmp_path / "model.safetensors")
    model_data = {
        "civitai_deleted": True,
        "metadata_source": "civarchive",
        "civitai": {"source": "civarchive"},
        "file_path": model_file,
    }
    update_cache = AsyncMock()

    ok, error = await env.service.fetch_and_update_model(
        sha256="deadbeef",
        file_path=model_file,
        model_data=model_data,
        update_cache_func=update_cache,
    )

    assert ok
    assert error is None
    provider_selector.assert_awaited_with("civarchive_api")
    env.default_provider_factory.assert_not_awaited()
    civarchive_provider.get_model_by_hash.assert_awaited_once_with("deadbeef")
    update_cache.assert_awaited()
    assert model_data["metadata_source"] == "civarchive"
    assert model_data["civitai_deleted"] is True
    env.metadata_manager.save_metadata.assert_awaited()
@pytest.mark.asyncio
async def test_fetch_and_update_model_falls_back_to_sqlite_after_civarchive_failure(tmp_path):
    """When CivArchive misses a deleted model, the SQLite archive is tried next.

    Each provider is queried exactly once; the selector is asked for
    "civarchive_api" first and "sqlite" second. The SQLite hit sets
    ``db_checked`` and keeps the model flagged as deleted.
    """
    default_provider = SimpleNamespace(
        get_model_by_hash=AsyncMock(),
        get_model_version=AsyncMock(),
    )
    civarchive_provider = SimpleNamespace(
        get_model_by_hash=AsyncMock(return_value=(None, "Model not found")),
        get_model_version=AsyncMock(),
    )
    sqlite_hit = {
        "source": "archive_db",
        "model": {"name": "Recovered", "description": "", "tags": []},
        "images": [],
        "baseModel": "sdxl",
    }
    sqlite_provider = SimpleNamespace(
        get_model_by_hash=AsyncMock(return_value=(sqlite_hit, None)),
        get_model_version=AsyncMock(),
    )
    providers_by_name = {
        "civarchive_api": civarchive_provider,
        "sqlite": sqlite_provider,
    }

    async def select_provider(name: str):
        # Any other provider name resolves to the default provider.
        return providers_by_name.get(name, default_provider)

    provider_selector = AsyncMock(side_effect=select_provider)
    env = build_service(
        settings_values={"enable_metadata_archive_db": True},
        default_provider=default_provider,
        provider_selector=provider_selector,
    )
    model_file = str(tmp_path / "model.safetensors")
    model_data = {
        "civitai_deleted": True,
        "db_checked": False,
        "file_path": model_file,
    }
    update_cache = AsyncMock()

    ok, error = await env.service.fetch_and_update_model(
        sha256="cafe",
        file_path=model_file,
        model_data=model_data,
        update_cache_func=update_cache,
    )

    assert ok and error is None
    assert civarchive_provider.get_model_by_hash.await_count == 1
    assert sqlite_provider.get_model_by_hash.await_count == 1
    assert model_data["metadata_source"] == "archive_db"
    assert model_data["db_checked"] is True
    assert model_data["civitai_deleted"] is True
    selections = [call.args for call in provider_selector.await_args_list]
    assert selections[0] == ("civarchive_api",)
    assert selections[1] == ("sqlite",)
    update_cache.assert_awaited()
    env.metadata_manager.save_metadata.assert_awaited()
@pytest.mark.asyncio
async def test_fetch_and_update_model_returns_rate_limit_error(tmp_path):
    """A provider ``RateLimitError`` is reported as an error message with no side effects.

    The message mentions the rate limit and the ``retry_after`` value; nothing
    is saved, the cache is not updated and the provider selector is never used.
    """
    limited = RateLimitError("limited", retry_after=7)
    provider = SimpleNamespace(
        get_model_by_hash=AsyncMock(side_effect=limited),
        get_model_version=AsyncMock(),
    )
    env = build_service(default_provider=provider)
    model_file = str(tmp_path / "model.safetensors")
    model_data = {"file_path": model_file, "model_name": "Local"}
    update_cache = AsyncMock()

    ok, error = await env.service.fetch_and_update_model(
        sha256="deadbeef",
        file_path=model_file,
        model_data=model_data,
        update_cache_func=update_cache,
    )

    assert ok is False
    assert error is not None
    assert "Rate limited" in error
    assert "7" in error
    env.metadata_manager.save_metadata.assert_not_awaited()
    update_cache.assert_not_awaited()
    env.provider_selector.assert_not_awaited()
@pytest.mark.asyncio
async def test_relink_metadata_fetches_version_without_overwriting_sha(tmp_path):
    """Relinking pulls the requested version but keeps the local SHA256.

    The remote primary file reports hash "ABCDEF"; the relinked metadata must
    still carry the original local ``sha256`` value.
    """
    remote_version = {
        "files": [
            {"primary": True, "type": "Model", "hashes": {"SHA256": "ABCDEF"}},
        ],
        "model": {"name": "Remote"},
        "images": [],
    }
    provider = SimpleNamespace(
        get_model_by_hash=AsyncMock(),
        get_model_version=AsyncMock(return_value=remote_version),
    )
    env = build_service(default_provider=provider)

    relinked = await env.service.relink_metadata(
        file_path=str(tmp_path / "model.safetensors"),
        metadata={"model_name": "Local", "sha256": "original"},
        model_id=1,
        model_version_id=2,
    )

    assert relinked["model_name"] == "Remote"
    assert relinked["sha256"] == "original"
    env.metadata_manager.save_metadata.assert_awaited_once()
@pytest.mark.asyncio
async def test_relink_metadata_raises_when_version_missing():
    """Relinking raises ``ValueError`` when the provider returns no version."""
    missing_version_provider = SimpleNamespace(
        get_model_by_hash=AsyncMock(),
        get_model_version=AsyncMock(return_value=None),
    )
    env = build_service(default_provider=missing_version_provider)
    relink_kwargs = {
        "file_path": "/tmp/model.safetensors",
        "metadata": {},
        "model_id": 9,
        "model_version_id": None,
    }

    with pytest.raises(ValueError):
        await env.service.relink_metadata(**relink_kwargs)
@pytest.mark.asyncio
async def test_fetch_and_update_model_does_not_overwrite_api_metadata_with_archive(tmp_path):
    """Lower-quality CivArchive data must not clobber existing Civitai API metadata.

    The model already holds a Civitai API ``civitai`` block with files,
    images and trained words. When a fetch returns a ``civarchive`` payload,
    the root ``metadata_source`` is updated. The existing ``civitai`` block,
    however, is left as it was: no field of the incoming payload is merged
    into it.
    """
    helpers = build_service()
    # Existing high-quality metadata. "api" is a source value that
    # is_civitai_api_metadata treats as genuine Civitai API data.
    existing_civitai = {
        "source": "api",
        "files": [{"id": 1}],
        "images": [{"url": "img1"}],
        "name": "High Quality",
        "trainedWords": ["keyword1"],
    }
    # Incoming lower-quality payload from CivArchive (simulating a fallback):
    # no images and a different set of trained words.
    civarchive_payload = {
        "source": "civarchive",
        "model": {"name": "Low Quality", "description": "low quality", "tags": []},
        "images": [],
        "baseModel": "sdxl",
        "trainedWords": ["keyword2"],
    }
    helpers.default_provider.get_model_by_hash.return_value = (civarchive_payload, None)
    model_path = tmp_path / "model.safetensors"
    model_data = {
        "model_name": "High Quality",
        "metadata_source": "civitai_api",
        "civitai": existing_civitai,
        "file_path": str(model_path),
    }
    update_cache = AsyncMock(return_value=True)

    ok, error = await helpers.service.fetch_and_update_model(
        sha256="abc",
        file_path=str(model_path),
        model_data=model_data,
        update_cache_func=update_cache,
    )

    assert ok and error is None
    # The root source reflects the latest fetch...
    assert model_data["metadata_source"] == "civarchive"
    # ...but the civitai block keeps the API data, and nothing from the
    # archive payload (name, trained words, empty image list) is merged in.
    civitai_block = model_data["civitai"]
    assert civitai_block["name"] == "High Quality"
    assert civitai_block["trainedWords"] == ["keyword1"]
    assert civitai_block["files"] == [{"id": 1}]
    assert civitai_block["images"] == [{"url": "img1"}]
    helpers.metadata_manager.save_metadata.assert_awaited()
    update_cache.assert_awaited()