Merge pull request #545 from willmiao/codex/evaluate-sqlite-cache-indexing-necessity

feat: index cached models by version id
This commit is contained in:
pixelpaws
2025-10-09 13:54:46 +08:00
committed by GitHub
3 changed files with 149 additions and 15 deletions

View File

@@ -1,6 +1,6 @@
import asyncio import asyncio
from typing import List, Dict, Tuple from typing import Any, Dict, List, Optional, Tuple
from dataclasses import dataclass from dataclasses import dataclass, field
from operator import itemgetter from operator import itemgetter
from natsort import natsorted from natsort import natsorted
@@ -17,10 +17,12 @@ SUPPORTED_SORT_MODES = [
@dataclass @dataclass
class ModelCache: class ModelCache:
"""Cache structure for model data with extensible sorting""" """Cache structure for model data with extensible sorting."""
raw_data: List[Dict] raw_data: List[Dict]
folders: List[str] folders: List[str]
version_index: Dict[int, Dict] = field(default_factory=dict)
def __post_init__(self): def __post_init__(self):
self._lock = asyncio.Lock() self._lock = asyncio.Lock()
# Cache for last sort: (sort_key, order) -> sorted list # Cache for last sort: (sort_key, order) -> sorted list
@@ -28,6 +30,58 @@ class ModelCache:
self._last_sorted_data: List[Dict] = [] self._last_sorted_data: List[Dict] = []
# Default sort on init # Default sort on init
asyncio.create_task(self.resort()) asyncio.create_task(self.resort())
self.rebuild_version_index()
@staticmethod
def _normalize_version_id(value: Any) -> Optional[int]:
"""Normalize a potential version identifier into an integer."""
if isinstance(value, int):
return value
if isinstance(value, str):
try:
return int(value)
except ValueError:
return None
return None
def rebuild_version_index(self) -> None:
    """Drop and repopulate the version index from ``self.raw_data``.

    Starts from an empty mapping so that entries for models no longer in
    ``raw_data`` cannot survive a rebuild.
    """
    self.version_index = {}
    for entry in self.raw_data:
        self.add_to_version_index(entry)
def add_to_version_index(self, item: Dict) -> None:
    """Register a cache item in the version index if possible.

    Items without a ``civitai`` dict, or whose ``civitai['id']`` does not
    normalize to an integer, are silently ignored.
    """
    if not isinstance(item, dict):
        return
    civitai = item.get('civitai')
    if not isinstance(civitai, dict):
        return
    key = self._normalize_version_id(civitai.get('id'))
    if key is not None:
        self.version_index[key] = item
def remove_from_version_index(self, item: Dict) -> None:
    """Remove a cache item from the version index if present.

    The entry is only dropped when the indexed item is the very same object,
    or refers to the same ``file_path`` — a different model may legitimately
    own this version id by now, and must not be evicted.
    """
    if not isinstance(item, dict):
        return
    civitai = item.get('civitai')
    if not isinstance(civitai, dict):
        return
    key = self._normalize_version_id(civitai.get('id'))
    if key is None:
        return
    indexed = self.version_index.get(key)
    same_object = indexed is item
    same_file = (
        isinstance(indexed, dict)
        and indexed.get('file_path') == item.get('file_path')
    )
    if same_object or same_file:
        self.version_index.pop(key, None)
async def resort(self): async def resort(self):
"""Resort cached data according to last sort mode if set""" """Resort cached data according to last sort mode if set"""
@@ -41,6 +95,7 @@ class ModelCache:
all_folders = set(l['folder'] for l in self.raw_data) all_folders = set(l['folder'] for l in self.raw_data)
self.folders = sorted(list(all_folders), key=lambda x: x.lower()) self.folders = sorted(list(all_folders), key=lambda x: x.lower())
self.rebuild_version_index()
def _sort_data(self, data: List[Dict], sort_key: str, order: str) -> List[Dict]: def _sort_data(self, data: List[Dict], sort_key: str, order: str) -> List[Dict]:
"""Sort data by sort_key and order""" """Sort data by sort_key and order"""

View File

@@ -634,7 +634,8 @@ class ModelScanner:
if model_data: if model_data:
# Add to cache # Add to cache
self._cache.raw_data.append(model_data) self._cache.raw_data.append(model_data)
self._cache.add_to_version_index(model_data)
# Update hash index if available # Update hash index if available
if 'sha256' in model_data and 'file_path' in model_data: if 'sha256' in model_data and 'file_path' in model_data:
self._hash_index.add_entry(model_data['sha256'].lower(), model_data['file_path']) self._hash_index.add_entry(model_data['sha256'].lower(), model_data['file_path'])
@@ -661,7 +662,9 @@ class ModelScanner:
for path in missing_files: for path in missing_files:
try: try:
model_to_remove = path_to_item[path] model_to_remove = path_to_item[path]
self._cache.remove_from_version_index(model_to_remove)
# Update tags count # Update tags count
for tag in model_to_remove.get('tags', []): for tag in model_to_remove.get('tags', []):
if tag in self._tags_count: if tag in self._tags_count:
@@ -684,6 +687,8 @@ class ModelScanner:
all_folders = set(item.get('folder', '') for item in self._cache.raw_data) all_folders = set(item.get('folder', '') for item in self._cache.raw_data)
self._cache.folders = sorted(list(all_folders), key=lambda x: x.lower()) self._cache.folders = sorted(list(all_folders), key=lambda x: x.lower())
self._cache.rebuild_version_index()
# Resort cache # Resort cache
await self._cache.resort() await self._cache.resort()
@@ -829,6 +834,8 @@ class ModelScanner:
else: else:
self._cache.raw_data = list(scan_result.raw_data) self._cache.raw_data = list(scan_result.raw_data)
self._cache.rebuild_version_index()
await self._cache.resort() await self._cache.resort()
async def _gather_model_data( async def _gather_model_data(
@@ -934,7 +941,8 @@ class ModelScanner:
# Add to cache # Add to cache
self._cache.raw_data.append(metadata_dict) self._cache.raw_data.append(metadata_dict)
self._cache.add_to_version_index(metadata_dict)
# Resort cache data # Resort cache data
await self._cache.resort() await self._cache.resort()
@@ -1076,6 +1084,9 @@ class ModelScanner:
cache = await self.get_cached_data() cache = await self.get_cached_data()
existing_item = next((item for item in cache.raw_data if item['file_path'] == original_path), None) existing_item = next((item for item in cache.raw_data if item['file_path'] == original_path), None)
if existing_item:
cache.remove_from_version_index(existing_item)
if existing_item and 'tags' in existing_item: if existing_item and 'tags' in existing_item:
for tag in existing_item.get('tags', []): for tag in existing_item.get('tags', []):
if tag in self._tags_count: if tag in self._tags_count:
@@ -1106,6 +1117,7 @@ class ModelScanner:
) )
cache.raw_data.append(cache_entry) cache.raw_data.append(cache_entry)
cache.add_to_version_index(cache_entry)
sha_value = cache_entry.get('sha256') sha_value = cache_entry.get('sha256')
if sha_value: if sha_value:
@@ -1117,6 +1129,8 @@ class ModelScanner:
for tag in cache_entry.get('tags', []): for tag in cache_entry.get('tags', []):
self._tags_count[tag] = self._tags_count.get(tag, 0) + 1 self._tags_count[tag] = self._tags_count.get(tag, 0) + 1
cache.rebuild_version_index()
await cache.resort() await cache.resort()
if cache_modified: if cache_modified:
@@ -1339,11 +1353,12 @@ class ModelScanner:
# Update hash index # Update hash index
for model in models_to_remove: for model in models_to_remove:
file_path = model['file_path'] file_path = model['file_path']
self._cache.remove_from_version_index(model)
if hasattr(self, '_hash_index') and self._hash_index: if hasattr(self, '_hash_index') and self._hash_index:
# Get the hash and filename before removal for duplicate checking # Get the hash and filename before removal for duplicate checking
file_name = os.path.splitext(os.path.basename(file_path))[0] file_name = os.path.splitext(os.path.basename(file_path))[0]
hash_val = model.get('sha256', '').lower() hash_val = model.get('sha256', '').lower()
# Remove from hash index # Remove from hash index
self._hash_index.remove_by_path(file_path, hash_val) self._hash_index.remove_by_path(file_path, hash_val)
@@ -1352,8 +1367,9 @@ class ModelScanner:
# Update cache data # Update cache data
self._cache.raw_data = [item for item in self._cache.raw_data if item['file_path'] not in file_paths] self._cache.raw_data = [item for item in self._cache.raw_data if item['file_path'] not in file_paths]
# Resort cache # Resort cache
self._cache.rebuild_version_index()
await self._cache.resort() await self._cache.resort()
await self._persist_current_cache() await self._persist_current_cache()
@@ -1393,16 +1409,17 @@ class ModelScanner:
Returns: Returns:
bool: True if the model version exists, False otherwise bool: True if the model version exists, False otherwise
""" """
try:
normalized_id = int(model_version_id)
except (TypeError, ValueError):
return False
try: try:
cache = await self.get_cached_data() cache = await self.get_cached_data()
if not cache or not cache.raw_data: if not cache:
return False return False
for item in cache.raw_data: return normalized_id in cache.version_index
if item.get('civitai') and item['civitai'].get('id') == model_version_id:
return True
return False
except Exception as e: except Exception as e:
logger.error(f"Error checking model version existence: {e}") logger.error(f"Error checking model version existence: {e}")
return False return False

View File

@@ -243,6 +243,7 @@ async def test_initialize_in_background_uses_persisted_cache_without_full_scan(t
cache = await scanner.get_cached_data() cache = await scanner.get_cached_data()
assert len(cache.raw_data) == 1 assert len(cache.raw_data) == 1
assert cache.raw_data[0]['file_path'] == normalized assert cache.raw_data[0]['file_path'] == normalized
assert cache.version_index[11]['file_path'] == normalized
assert scanner._hash_index.get_path('hash-one') == normalized assert scanner._hash_index.get_path('hash-one') == normalized
@@ -301,6 +302,7 @@ async def test_load_persisted_cache_populates_cache(tmp_path: Path, monkeypatch)
assert entry['file_path'] == normalized assert entry['file_path'] == normalized
assert entry['tags'] == ['alpha'] assert entry['tags'] == ['alpha']
assert entry['civitai']['trainedWords'] == ['abc'] assert entry['civitai']['trainedWords'] == ['abc']
assert cache.version_index[11]['file_path'] == normalized
assert scanner._hash_index.get_path('hash-one') == normalized assert scanner._hash_index.get_path('hash-one') == normalized
assert scanner._tags_count == {'alpha': 1} assert scanner._tags_count == {'alpha': 1}
assert ws_stub.payloads[-1]['stage'] == 'loading_cache' assert ws_stub.payloads[-1]['stage'] == 'loading_cache'
@@ -381,6 +383,66 @@ async def test_batch_delete_persists_removal(tmp_path: Path, monkeypatch):
assert remaining == 0 assert remaining == 0
@pytest.mark.asyncio
async def test_version_index_tracks_version_ids(tmp_path: Path):
    """End-to-end check that the version index tracks adds and deletions."""

    def make_entry(path, stem, version_id, model_id):
        # Minimal cache entry shaped like a scanned model record.
        return {
            'file_path': path,
            'file_name': stem,
            'model_name': stem,
            'folder': '',
            'size': 1,
            'modified': 1.0,
            'sha256': f'hash-{stem}',
            'tags': [],
            'civitai': {'id': version_id, 'modelId': model_id, 'name': stem},
        }

    scanner = DummyScanner(tmp_path)
    alpha_path = _normalize_path(tmp_path / 'alpha.txt')
    beta_path = _normalize_path(tmp_path / 'beta.txt')
    alpha = make_entry(alpha_path, 'alpha', 101, 1)
    beta = make_entry(beta_path, 'beta', 202, 2)

    hash_index = ModelHashIndex()
    hash_index.add_entry('hash-alpha', alpha_path)
    hash_index.add_entry('hash-beta', beta_path)

    await scanner._apply_scan_result(CacheBuildResult(
        raw_data=[alpha, beta],
        hash_index=hash_index,
        tags_count={},
        excluded_models=[],
    ))

    cache = await scanner.get_cached_data()
    assert cache.version_index[101]['file_path'] == alpha_path
    assert cache.version_index[202]['file_path'] == beta_path
    # Lookups accept both int and numeric-string ids.
    assert await scanner.check_model_version_exists(101) is True
    assert await scanner.check_model_version_exists('202') is True
    assert await scanner.check_model_version_exists(999) is False

    # Deleting a model must also evict it from the version index.
    removed = await scanner._batch_update_cache_for_deleted_models([alpha_path])
    assert removed is True
    cache_after = await scanner.get_cached_data()
    assert 101 not in cache_after.version_index
    assert await scanner.check_model_version_exists(101) is False
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_reconcile_cache_adds_new_files_and_updates_hash_index(tmp_path: Path): async def test_reconcile_cache_adds_new_files_and_updates_hash_index(tmp_path: Path):
first, _, _ = _create_files(tmp_path) first, _, _ = _create_files(tmp_path)