diff --git a/py/services/model_cache.py b/py/services/model_cache.py index f67b2444..e3c94cee 100644 --- a/py/services/model_cache.py +++ b/py/services/model_cache.py @@ -1,6 +1,6 @@ import asyncio -from typing import List, Dict, Tuple -from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple +from dataclasses import dataclass, field from operator import itemgetter from natsort import natsorted @@ -17,10 +17,12 @@ SUPPORTED_SORT_MODES = [ @dataclass class ModelCache: - """Cache structure for model data with extensible sorting""" + """Cache structure for model data with extensible sorting.""" + raw_data: List[Dict] folders: List[str] - + version_index: Dict[int, Dict] = field(default_factory=dict) + def __post_init__(self): self._lock = asyncio.Lock() # Cache for last sort: (sort_key, order) -> sorted list @@ -28,6 +30,58 @@ class ModelCache: self._last_sorted_data: List[Dict] = [] # Default sort on init asyncio.create_task(self.resort()) + self.rebuild_version_index() + + @staticmethod + def _normalize_version_id(value: Any) -> Optional[int]: + """Normalize a potential version identifier into an integer.""" + + if isinstance(value, int): + return value + if isinstance(value, str): + try: + return int(value) + except ValueError: + return None + return None + + def rebuild_version_index(self) -> None: + """Rebuild the version index from the current raw data.""" + + self.version_index = {} + for item in self.raw_data: + self.add_to_version_index(item) + + def add_to_version_index(self, item: Dict) -> None: + """Register a cache item in the version index if possible.""" + + civitai_data = item.get('civitai') if isinstance(item, dict) else None + if not isinstance(civitai_data, dict): + return + + version_id = self._normalize_version_id(civitai_data.get('id')) + if version_id is None: + return + + self.version_index[version_id] = item + + def remove_from_version_index(self, item: Dict) -> None: + """Remove a cache item from the version index if present.""" + + civitai_data = item.get('civitai') if isinstance(item, dict) else None + if not isinstance(civitai_data, dict): + return + + version_id = self._normalize_version_id(civitai_data.get('id')) + if version_id is None: + return + + existing = self.version_index.get(version_id) + if existing is item or ( + isinstance(existing, dict) + and existing.get('file_path') == item.get('file_path') + ): + self.version_index.pop(version_id, None) async def resort(self): """Resort cached data according to last sort mode if set""" @@ -41,6 +95,7 @@ class ModelCache: all_folders = set(l['folder'] for l in self.raw_data) self.folders = sorted(list(all_folders), key=lambda x: x.lower()) + self.rebuild_version_index() def _sort_data(self, data: List[Dict], sort_key: str, order: str) -> List[Dict]: """Sort data by sort_key and order""" diff --git a/py/services/model_scanner.py b/py/services/model_scanner.py index 5822b223..98eb61b6 100644 --- a/py/services/model_scanner.py +++ b/py/services/model_scanner.py @@ -634,7 +634,8 @@ class ModelScanner: if model_data: # Add to cache self._cache.raw_data.append(model_data) - + self._cache.add_to_version_index(model_data) + # Update hash index if available if 'sha256' in model_data and 'file_path' in model_data: self._hash_index.add_entry(model_data['sha256'].lower(), model_data['file_path']) @@ -661,7 +662,9 @@ class ModelScanner: for path in missing_files: try: model_to_remove = path_to_item[path] - + + self._cache.remove_from_version_index(model_to_remove) + # Update tags count for tag in model_to_remove.get('tags', []): if tag in self._tags_count: @@ -684,6 +687,8 @@ class ModelScanner: all_folders = set(item.get('folder', '') for item in self._cache.raw_data) self._cache.folders = sorted(list(all_folders), key=lambda x: x.lower()) + self._cache.rebuild_version_index() + # Resort cache await self._cache.resort() @@ -829,6 +834,8 @@ class ModelScanner: else: self._cache.raw_data = list(scan_result.raw_data) + self._cache.rebuild_version_index() + await self._cache.resort() async def _gather_model_data( @@ -934,7 +941,8 @@ class ModelScanner: # Add to cache self._cache.raw_data.append(metadata_dict) - + self._cache.add_to_version_index(metadata_dict) + # Resort cache data await self._cache.resort() @@ -1076,6 +1084,9 @@ class ModelScanner: cache = await self.get_cached_data() existing_item = next((item for item in cache.raw_data if item['file_path'] == original_path), None) + if existing_item: + cache.remove_from_version_index(existing_item) + if existing_item and 'tags' in existing_item: for tag in existing_item.get('tags', []): if tag in self._tags_count: @@ -1106,6 +1117,7 @@ class ModelScanner: ) cache.raw_data.append(cache_entry) + cache.add_to_version_index(cache_entry) sha_value = cache_entry.get('sha256') if sha_value: @@ -1117,6 +1129,8 @@ class ModelScanner: for tag in cache_entry.get('tags', []): self._tags_count[tag] = self._tags_count.get(tag, 0) + 1 + cache.rebuild_version_index() + await cache.resort() if cache_modified: @@ -1339,11 +1353,12 @@ class ModelScanner: # Update hash index for model in models_to_remove: file_path = model['file_path'] + self._cache.remove_from_version_index(model) if hasattr(self, '_hash_index') and self._hash_index: # Get the hash and filename before removal for duplicate checking file_name = os.path.splitext(os.path.basename(file_path))[0] hash_val = model.get('sha256', '').lower() - + # Remove from hash index self._hash_index.remove_by_path(file_path, hash_val) @@ -1352,8 +1367,9 @@ class ModelScanner: # Update cache data self._cache.raw_data = [item for item in self._cache.raw_data if item['file_path'] not in file_paths] - + # Resort cache + self._cache.rebuild_version_index() await self._cache.resort() await self._persist_current_cache() @@ -1393,16 +1409,17 @@ class ModelScanner: Returns: bool: True if the model version exists, False otherwise """ + try: + normalized_id = int(model_version_id) + except (TypeError, ValueError): + return False + try: cache = await self.get_cached_data() - if not cache or not cache.raw_data: + if not cache: return False - for item in cache.raw_data: - if item.get('civitai') and item['civitai'].get('id') == model_version_id: - return True - - return False + return normalized_id in cache.version_index except Exception as e: logger.error(f"Error checking model version existence: {e}") return False diff --git a/tests/services/test_model_scanner.py b/tests/services/test_model_scanner.py index a077caeb..78505f3c 100644 --- a/tests/services/test_model_scanner.py +++ b/tests/services/test_model_scanner.py @@ -243,6 +243,7 @@ async def test_initialize_in_background_uses_persisted_cache_without_full_scan(t cache = await scanner.get_cached_data() assert len(cache.raw_data) == 1 assert cache.raw_data[0]['file_path'] == normalized + assert cache.version_index[11]['file_path'] == normalized assert scanner._hash_index.get_path('hash-one') == normalized @@ -301,6 +302,7 @@ async def test_load_persisted_cache_populates_cache(tmp_path: Path, monkeypatch) assert entry['file_path'] == normalized assert entry['tags'] == ['alpha'] assert entry['civitai']['trainedWords'] == ['abc'] + assert cache.version_index[11]['file_path'] == normalized assert scanner._hash_index.get_path('hash-one') == normalized assert scanner._tags_count == {'alpha': 1} assert ws_stub.payloads[-1]['stage'] == 'loading_cache' @@ -381,6 +383,66 @@ async def test_batch_delete_persists_removal(tmp_path: Path, monkeypatch): assert remaining == 0 +@pytest.mark.asyncio +async def test_version_index_tracks_version_ids(tmp_path: Path): + scanner = DummyScanner(tmp_path) + + first_path = _normalize_path(tmp_path / 'alpha.txt') + second_path = _normalize_path(tmp_path / 'beta.txt') + + first_entry = { + 'file_path': first_path, + 'file_name': 'alpha', + 'model_name': 'alpha', + 'folder': '', + 'size': 1, + 'modified': 1.0, + 'sha256': 'hash-alpha', + 'tags': [], + 'civitai': {'id': 101, 'modelId': 1, 'name': 'alpha'}, + } + + second_entry = { + 'file_path': second_path, + 'file_name': 'beta', + 'model_name': 'beta', + 'folder': '', + 'size': 1, + 'modified': 1.0, + 'sha256': 'hash-beta', + 'tags': [], + 'civitai': {'id': 202, 'modelId': 2, 'name': 'beta'}, + } + + hash_index = ModelHashIndex() + hash_index.add_entry('hash-alpha', first_path) + hash_index.add_entry('hash-beta', second_path) + + scan_result = CacheBuildResult( + raw_data=[first_entry, second_entry], + hash_index=hash_index, + tags_count={}, + excluded_models=[], + ) + + await scanner._apply_scan_result(scan_result) + + cache = await scanner.get_cached_data() + assert cache.version_index[101]['file_path'] == first_path + assert cache.version_index[202]['file_path'] == second_path + + assert await scanner.check_model_version_exists(101) is True + assert await scanner.check_model_version_exists('202') is True + assert await scanner.check_model_version_exists(999) is False + + removed = await scanner._batch_update_cache_for_deleted_models([first_path]) + assert removed is True + + cache_after = await scanner.get_cached_data() + assert 101 not in cache_after.version_index + assert await scanner.check_model_version_exists(101) is False + + @pytest.mark.asyncio async def test_reconcile_cache_adds_new_files_and_updates_hash_index(tmp_path: Path): first, _, _ = _create_files(tmp_path)