feat(context-menu): refresh missing license metadata

This commit is contained in:
Will Miao
2025-11-11 14:24:59 +08:00
parent 4557da8b63
commit 29bb85359e
20 changed files with 633 additions and 10 deletions

View File

@@ -1,3 +1,4 @@
import copy
import json
import logging
from types import SimpleNamespace
@@ -6,6 +7,7 @@ import pytest
from py.config import config
from py.routes.handlers.model_handlers import ModelUpdateHandler
from py.utils.metadata_manager import MetadataManager
from py.services.model_update_service import ModelUpdateRecord, ModelVersionRecord
@@ -264,3 +266,171 @@ async def test_refresh_model_updates_accepts_snake_case_ids():
call = update_service.calls[0]
assert call["target_model_ids"] == [3, 4]
@pytest.mark.asyncio
async def test_fetch_missing_license_data_updates_metadata(monkeypatch):
cache = SimpleNamespace(
raw_data=[
{"file_path": "/tmp/model1.safetensors", "civitai": {"modelId": 10}},
{"file_path": "/tmp/model2.safetensors", "civitai": {"modelId": 10}},
{"file_path": "/tmp/model3.safetensors", "civitai": {"modelId": 20}},
],
version_index={},
)
metadata_store = {
"/tmp/model1.safetensors": {"civitai": {"model": {}}},
"/tmp/model2.safetensors": {"civitai": {"model": {}}},
"/tmp/model3.safetensors": {"civitai": {"model": {}}},
}
async def fake_load(path: str):
data = metadata_store.get(path)
if data is None:
return None, False
return SimpleNamespace(to_dict=lambda: copy.deepcopy(data)), False
saved: list[tuple[str, dict]] = []
async def fake_save(path: str, metadata: dict):
saved.append((path, copy.deepcopy(metadata)))
return True
monkeypatch.setattr(MetadataManager, "load_metadata", staticmethod(fake_load))
monkeypatch.setattr(MetadataManager, "save_metadata", staticmethod(fake_save))
provider_calls: list[list[int]] = []
async def fake_bulk(model_ids):
provider_calls.append(list(model_ids))
return {
10: {
"allowNoCredit": True,
"allowCommercialUse": ["Sell"],
"allowDerivatives": True,
"allowDifferentLicense": True,
},
20: {
"allowNoCredit": False,
"allowCommercialUse": ["Image"],
"allowDerivatives": False,
"allowDifferentLicense": False,
},
}
provider = SimpleNamespace()
provider.get_model_versions_bulk = fake_bulk
async def metadata_selector(name):
assert name == "civitai_api"
return provider
handler = ModelUpdateHandler(
service=DummyService(cache),
update_service=SimpleNamespace(),
metadata_provider_selector=metadata_selector,
logger=logging.getLogger(__name__),
)
class DummyRequest:
can_read_body = True
query = {}
async def json(self):
return {}
response = await handler.fetch_missing_civitai_license_data(DummyRequest())
assert response.status == 200
payload = json.loads(response.text)
assert payload["success"] is True
assert len(payload["updated"]) == 3
assert provider_calls == [[10, 20]]
assert len(saved) == 3
first_metadata = saved[0][1]
assert first_metadata["civitai"]["model"]["allowNoCredit"] is True
assert first_metadata["civitai"]["model"]["allowCommercialUse"] == ["Sell"]
assert "missingModelIds" not in payload
assert "errors" not in payload
@pytest.mark.asyncio
async def test_fetch_missing_license_data_filters_model_ids(monkeypatch):
cache = SimpleNamespace(
raw_data=[
{"file_path": "/tmp/model1.safetensors", "civitai": {"modelId": 10}},
{"file_path": "/tmp/model2.safetensors", "civitai": {"modelId": 20}},
],
version_index={},
)
metadata_store = {
"/tmp/model1.safetensors": {"civitai": {"model": {}}},
"/tmp/model2.safetensors": {"civitai": {"model": {}}},
}
async def fake_load(path: str):
data = metadata_store.get(path)
if data is None:
return None, False
return SimpleNamespace(to_dict=lambda: copy.deepcopy(data)), False
saved: list[tuple[str, dict]] = []
async def fake_save(path: str, metadata: dict):
saved.append((path, copy.deepcopy(metadata)))
return True
monkeypatch.setattr(MetadataManager, "load_metadata", staticmethod(fake_load))
monkeypatch.setattr(MetadataManager, "save_metadata", staticmethod(fake_save))
provider_calls: list[list[int]] = []
async def fake_bulk(model_ids):
provider_calls.append(list(model_ids))
return {
10: {
"allowNoCredit": True,
"allowCommercialUse": ["Sell"],
"allowDerivatives": True,
"allowDifferentLicense": True,
},
20: {
"allowNoCredit": False,
"allowCommercialUse": ["Image"],
"allowDerivatives": False,
"allowDifferentLicense": False,
},
}
provider = SimpleNamespace()
provider.get_model_versions_bulk = fake_bulk
async def metadata_selector(name):
assert name == "civitai_api"
return provider
handler = ModelUpdateHandler(
service=DummyService(cache),
update_service=SimpleNamespace(),
metadata_provider_selector=metadata_selector,
logger=logging.getLogger(__name__),
)
class DummyRequest:
can_read_body = True
query = {}
async def json(self):
return {"modelIds": [20]}
response = await handler.fetch_missing_civitai_license_data(DummyRequest())
assert response.status == 200
payload = json.loads(response.text)
assert payload["success"] is True
assert len(payload["updated"]) == 1
assert provider_calls == [[20]]
assert len(saved) == 1