From 5264e49f2ae9cb529c7670312db98ec82fe5e4fc Mon Sep 17 00:00:00 2001 From: pixelpaws Date: Tue, 28 Oct 2025 18:39:37 +0800 Subject: [PATCH] feat(cache): index versions by model id --- py/services/model_cache.py | 64 +++++++++++++++++++++++-- py/services/model_scanner.py | 17 ++----- tests/routes/test_misc_routes.py | 51 ++++++++++++++++++++ tests/services/test_model_cache.py | 76 +++++++++++++++++++----------- 4 files changed, 164 insertions(+), 44 deletions(-) diff --git a/py/services/model_cache.py b/py/services/model_cache.py index 648d2386..abd8b68b 100644 --- a/py/services/model_cache.py +++ b/py/services/model_cache.py @@ -25,6 +25,7 @@ class ModelCache: raw_data: List[Dict] folders: List[str] version_index: Dict[int, Dict] = field(default_factory=dict) + model_id_index: Dict[int, List[Dict[str, Any]]] = field(default_factory=dict) name_display_mode: str = "model_name" def __post_init__(self): @@ -97,14 +98,15 @@ class ModelCache: return None def rebuild_version_index(self) -> None: - """Rebuild the version index from the current raw data.""" + """Rebuild the version and model indexes from the current raw data.""" self.version_index = {} + self.model_id_index = {} for item in self.raw_data: self.add_to_version_index(item) def add_to_version_index(self, item: Dict) -> None: - """Register a cache item in the version index if possible.""" + """Register a cache item in the version/model indexes if possible.""" civitai_data = item.get('civitai') if isinstance(item, dict) else None if not isinstance(civitai_data, dict): @@ -116,8 +118,24 @@ class ModelCache: self.version_index[version_id] = item + model_id = self._normalize_version_id(civitai_data.get('modelId')) + if model_id is None: + return + + descriptor = self._build_version_descriptor(item, civitai_data, version_id) + if descriptor is None: + return + + versions = self.model_id_index.setdefault(model_id, []) + for index, existing in enumerate(versions): + if existing.get('versionId') == descriptor['versionId']: + versions[index] = descriptor + break + else: + versions.append(descriptor) + def remove_from_version_index(self, item: Dict) -> None: - """Remove a cache item from the version index if present.""" + """Remove a cache item from the version/model indexes if present.""" civitai_data = item.get('civitai') if isinstance(item, dict) else None if not isinstance(civitai_data, dict): @@ -134,6 +152,46 @@ class ModelCache: ): self.version_index.pop(version_id, None) + model_id = self._normalize_version_id(civitai_data.get('modelId')) + if model_id is None: + return + + versions = self.model_id_index.get(model_id) + if not versions: + return + + filtered = [v for v in versions if v.get('versionId') != version_id] + if filtered: + self.model_id_index[model_id] = filtered + else: + self.model_id_index.pop(model_id, None) + + def _build_version_descriptor( + self, + item: Dict, + civitai_data: Dict[str, Any], + version_id: int, + ) -> Optional[Dict[str, Any]]: + """Create a lightweight descriptor for a version entry.""" + + model_name = self._ensure_string(civitai_data.get('name')) + file_name = self._ensure_string(item.get('file_name')) + return { + 'versionId': version_id, + 'name': model_name, + 'fileName': file_name, + } + + def get_versions_by_model_id(self, model_id: Any) -> List[Dict[str, Any]]: + """Return cached version descriptors for a given model ID.""" + + normalized_id = self._normalize_version_id(model_id) + if normalized_id is None: + return [] + + versions = self.model_id_index.get(normalized_id, []) + return [dict(version) for version in versions] + async def resort(self): """Resort cached data according to last sort mode if set""" async with self._lock: diff --git a/py/services/model_scanner.py b/py/services/model_scanner.py index f035ea34..6e1e8a0a 100644 --- a/py/services/model_scanner.py +++ b/py/services/model_scanner.py @@ -1520,21 +1520,10 @@ class ModelScanner: """ try: cache = await self.get_cached_data() - if not cache or not cache.raw_data: + if not cache: return [] - - versions = [] - for item in cache.raw_data: - if (item.get('civitai') and - item['civitai'].get('modelId') == model_id and - item['civitai'].get('id')): - versions.append({ - 'versionId': item['civitai'].get('id'), - 'name': item['civitai'].get('name'), - 'fileName': item.get('file_name', '') - }) - - return versions + + return cache.get_versions_by_model_id(model_id) except Exception as e: logger.error(f"Error getting model versions: {e}") return [] diff --git a/tests/routes/test_misc_routes.py b/tests/routes/test_misc_routes.py index 77dda37e..c0565f0c 100644 --- a/tests/routes/test_misc_routes.py +++ b/tests/routes/test_misc_routes.py @@ -339,6 +339,19 @@ async def fake_scanner_factory(): return FakeScanner() +class RecordingVersionScanner: + def __init__(self, versions): + self._versions = versions + self.version_calls: list[int] = [] + + async def check_model_version_exists(self, _version_id): + return False + + async def get_model_versions_by_id(self, model_id): + self.version_calls.append(model_id) + return self._versions + + class FakeExistenceScanner: def __init__(self, existing=None): self._existing = set(existing or []) @@ -714,6 +727,44 @@ def test_ensure_handler_mapping_caches_result(): assert len(call_records) == 1, "Handler set factory should only be invoked once" +@pytest.mark.asyncio +async def test_check_model_exists_returns_local_versions(): + versions = [ + {'versionId': 11, 'name': 'v1', 'fileName': 'model-one'}, + {'versionId': 12, 'name': 'v2', 'fileName': 'model-two'}, + ] + + lora_scanner = RecordingVersionScanner(versions) + checkpoint_scanner = RecordingVersionScanner([]) + embedding_scanner = RecordingVersionScanner([]) + + async def lora_factory(): + return lora_scanner + + async def checkpoint_factory(): + return checkpoint_scanner + + async def embedding_factory(): + return embedding_scanner + + handler = ModelLibraryHandler( + ServiceRegistryAdapter( + get_lora_scanner=lora_factory, + get_checkpoint_scanner=checkpoint_factory, + get_embedding_scanner=embedding_factory, + ), + metadata_provider_factory=fake_metadata_provider_factory, + ) + + response = await handler.check_model_exists(FakeRequest(query={'modelId': '5'})) + payload = json.loads(response.text) + + assert payload['success'] is True + assert payload['modelType'] == 'lora' + assert payload['versions'] == versions + assert lora_scanner.version_calls == [5] + + def test_create_handler_set_uses_provided_dependencies(): recorded_handlers: list[dict] = [] diff --git a/tests/services/test_model_cache.py b/tests/services/test_model_cache.py index cc06ca21..4b6a8b57 100644 --- a/tests/services/test_model_cache.py +++ b/tests/services/test_model_cache.py @@ -4,38 +4,60 @@ from py.services.model_cache import ModelCache @pytest.mark.asyncio -async def test_name_sort_respects_file_name_display(): - items = [ - {"model_name": "Bravo", "file_name": "zulu", "folder": "", "size": 1, "modified": 1}, - {"model_name": "Alpha", "file_name": "alpha", "folder": "", "size": 1, "modified": 1}, - {"model_name": "Charlie", "file_name": "echo", "folder": "", "size": 1, "modified": 1}, +async def test_model_cache_tracks_versions_by_model_id(): + item_one = { + 'file_path': '/models/a.safetensors', + 'file_name': 'model-a-v1', + 'folder': '', + 'civitai': {'id': 101, 'modelId': 1, 'name': 'Alpha'}, + } + item_two = { + 'file_path': '/models/a_v2.safetensors', + 'file_name': 'model-a-v2', + 'folder': '', + 'civitai': {'id': 102, 'modelId': 1, 'name': 'Beta'}, + } + item_three = { + 'file_path': '/models/b.safetensors', + 'file_name': 'model-b', + 'folder': '', + 'civitai': {'id': 201, 'modelId': 2, 'name': 'Gamma'}, + } + + cache = ModelCache( + raw_data=[item_one, item_two, item_three], + folders=[], + name_display_mode='model_name', + ) + + versions = cache.get_versions_by_model_id(1) + assert versions == [ + {'versionId': 101, 'name': 'Alpha', 'fileName': 'model-a-v1'}, + {'versionId': 102, 'name': 'Beta', 'fileName': 'model-a-v2'}, ] - cache = ModelCache(raw_data=items, folders=[], name_display_mode="file_name") + # Returned descriptors should not allow external mutation of the cache index + versions[0]['name'] = 'mutated' + assert cache.model_id_index[1][0]['name'] == 'Alpha' - sorted_items = await cache.get_sorted_data("name", "asc") - - assert [item["file_name"] for item in sorted_items] == [ - "alpha", - "echo", - "zulu", + # Removing entries updates both indexes + cache.remove_from_version_index(item_one) + assert cache.get_versions_by_model_id(1) == [ + {'versionId': 102, 'name': 'Beta', 'fileName': 'model-a-v2'}, ] + cache.remove_from_version_index(item_two) + assert cache.get_versions_by_model_id(1) == [] + assert 1 not in cache.model_id_index -@pytest.mark.asyncio -async def test_update_name_display_mode_resorts_cached_name_order(): - items = [ - {"model_name": "Zulu", "file_name": "alpha", "folder": "", "size": 1, "modified": 1}, - {"model_name": "Alpha", "file_name": "zulu", "folder": "", "size": 1, "modified": 1}, + # Re-adding should not introduce duplicates + cache.add_to_version_index(item_two) + cache.add_to_version_index(item_two) + assert cache.get_versions_by_model_id('1') == [ + {'versionId': 102, 'name': 'Beta', 'fileName': 'model-a-v2'}, ] - cache = ModelCache(raw_data=items, folders=[], name_display_mode="model_name") - - initial = await cache.get_sorted_data("name", "asc") - assert [item["model_name"] for item in initial] == ["Alpha", "Zulu"] - - await cache.update_name_display_mode("file_name") - - # The cached name sort should refresh immediately based on the new mode - updated = await cache.get_sorted_data("name", "asc") - assert [item["file_name"] for item in updated] == ["alpha", "zulu"] + # Other model IDs remain accessible + assert cache.get_versions_by_model_id(2) == [ + {'versionId': 201, 'name': 'Gamma', 'fileName': 'model-b'}, + ]