Merge pull request #545 from willmiao/codex/evaluate-sqlite-cache-indexing-necessity

feat: index cached models by version id
This commit is contained in:
pixelpaws
2025-10-09 13:54:46 +08:00
committed by GitHub
3 changed files with 149 additions and 15 deletions

View File

@@ -1,6 +1,6 @@
import asyncio
from typing import List, Dict, Tuple
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple
from dataclasses import dataclass, field
from operator import itemgetter
from natsort import natsorted
@@ -17,10 +17,12 @@ SUPPORTED_SORT_MODES = [
@dataclass
class ModelCache:
"""Cache structure for model data with extensible sorting"""
"""Cache structure for model data with extensible sorting."""
raw_data: List[Dict]
folders: List[str]
version_index: Dict[int, Dict] = field(default_factory=dict)
def __post_init__(self):
self._lock = asyncio.Lock()
# Cache for last sort: (sort_key, order) -> sorted list
@@ -28,6 +30,58 @@ class ModelCache:
self._last_sorted_data: List[Dict] = []
# Default sort on init
asyncio.create_task(self.resort())
self.rebuild_version_index()
@staticmethod
def _normalize_version_id(value: Any) -> Optional[int]:
    """Normalize a potential version identifier into an integer.

    Args:
        value: Candidate identifier, e.g. the ``id`` field read from an
            item's ``civitai`` metadata.

    Returns:
        The identifier as an ``int``, or ``None`` when ``value`` is neither
        an integer nor a string that ``int()`` can parse.
    """
    # bool is a subclass of int: without this guard True/False would be
    # returned unchanged and end up indexed under the keys 1/0.
    if isinstance(value, bool):
        return None
    if isinstance(value, int):
        return value
    if isinstance(value, str):
        try:
            return int(value)
        except ValueError:
            return None
    return None
def rebuild_version_index(self) -> None:
    """Discard the current version index and re-register every raw_data item."""
    # Start from a fresh mapping so entries for items no longer present in
    # raw_data do not survive the rebuild.
    self.version_index = dict()
    register = self.add_to_version_index
    for entry in self.raw_data:
        register(entry)
def add_to_version_index(self, item: Dict) -> None:
    """Index ``item`` under the version id found in its ``civitai`` metadata.

    Nothing is registered when ``item`` is not a dict, when its ``civitai``
    value is not a dict, or when the ``id`` cannot be normalized. An entry
    already stored under the same version id is overwritten.
    """
    if not isinstance(item, dict):
        return
    metadata = item.get('civitai')
    if isinstance(metadata, dict):
        key = self._normalize_version_id(metadata.get('id'))
        if key is not None:
            self.version_index[key] = item
def remove_from_version_index(self, item: Dict) -> None:
    """Drop ``item``'s version id from the index when the index points at it.

    The entry is removed only if the indexed value is ``item`` itself or a
    dict whose ``file_path`` equals ``item``'s; an entry belonging to a
    different file that shares the version id is left untouched.
    """
    if not isinstance(item, dict):
        return
    metadata = item.get('civitai')
    if not isinstance(metadata, dict):
        return
    key = self._normalize_version_id(metadata.get('id'))
    if key is None:
        return

    indexed = self.version_index.get(key)
    if indexed is item:
        same_entry = True
    elif isinstance(indexed, dict):
        # Not the same object, so fall back to comparing file paths.
        same_entry = indexed.get('file_path') == item.get('file_path')
    else:
        same_entry = False

    if same_entry:
        self.version_index.pop(key, None)
async def resort(self):
"""Resort cached data according to last sort mode if set"""
@@ -41,6 +95,7 @@ class ModelCache:
all_folders = set(l['folder'] for l in self.raw_data)
self.folders = sorted(list(all_folders), key=lambda x: x.lower())
self.rebuild_version_index()
def _sort_data(self, data: List[Dict], sort_key: str, order: str) -> List[Dict]:
"""Sort data by sort_key and order"""

View File

@@ -634,7 +634,8 @@ class ModelScanner:
if model_data:
# Add to cache
self._cache.raw_data.append(model_data)
self._cache.add_to_version_index(model_data)
# Update hash index if available
if 'sha256' in model_data and 'file_path' in model_data:
self._hash_index.add_entry(model_data['sha256'].lower(), model_data['file_path'])
@@ -661,7 +662,9 @@ class ModelScanner:
for path in missing_files:
try:
model_to_remove = path_to_item[path]
self._cache.remove_from_version_index(model_to_remove)
# Update tags count
for tag in model_to_remove.get('tags', []):
if tag in self._tags_count:
@@ -684,6 +687,8 @@ class ModelScanner:
all_folders = set(item.get('folder', '') for item in self._cache.raw_data)
self._cache.folders = sorted(list(all_folders), key=lambda x: x.lower())
self._cache.rebuild_version_index()
# Resort cache
await self._cache.resort()
@@ -829,6 +834,8 @@ class ModelScanner:
else:
self._cache.raw_data = list(scan_result.raw_data)
self._cache.rebuild_version_index()
await self._cache.resort()
async def _gather_model_data(
@@ -934,7 +941,8 @@ class ModelScanner:
# Add to cache
self._cache.raw_data.append(metadata_dict)
self._cache.add_to_version_index(metadata_dict)
# Resort cache data
await self._cache.resort()
@@ -1076,6 +1084,9 @@ class ModelScanner:
cache = await self.get_cached_data()
existing_item = next((item for item in cache.raw_data if item['file_path'] == original_path), None)
if existing_item:
cache.remove_from_version_index(existing_item)
if existing_item and 'tags' in existing_item:
for tag in existing_item.get('tags', []):
if tag in self._tags_count:
@@ -1106,6 +1117,7 @@ class ModelScanner:
)
cache.raw_data.append(cache_entry)
cache.add_to_version_index(cache_entry)
sha_value = cache_entry.get('sha256')
if sha_value:
@@ -1117,6 +1129,8 @@ class ModelScanner:
for tag in cache_entry.get('tags', []):
self._tags_count[tag] = self._tags_count.get(tag, 0) + 1
cache.rebuild_version_index()
await cache.resort()
if cache_modified:
@@ -1339,11 +1353,12 @@ class ModelScanner:
# Update hash index
for model in models_to_remove:
file_path = model['file_path']
self._cache.remove_from_version_index(model)
if hasattr(self, '_hash_index') and self._hash_index:
# Get the hash and filename before removal for duplicate checking
file_name = os.path.splitext(os.path.basename(file_path))[0]
hash_val = model.get('sha256', '').lower()
# Remove from hash index
self._hash_index.remove_by_path(file_path, hash_val)
@@ -1352,8 +1367,9 @@ class ModelScanner:
# Update cache data
self._cache.raw_data = [item for item in self._cache.raw_data if item['file_path'] not in file_paths]
# Resort cache
self._cache.rebuild_version_index()
await self._cache.resort()
await self._persist_current_cache()
@@ -1393,16 +1409,17 @@ class ModelScanner:
Returns:
bool: True if the model version exists, False otherwise
"""
try:
normalized_id = int(model_version_id)
except (TypeError, ValueError):
return False
try:
cache = await self.get_cached_data()
if not cache or not cache.raw_data:
if not cache:
return False
for item in cache.raw_data:
if item.get('civitai') and item['civitai'].get('id') == model_version_id:
return True
return False
return normalized_id in cache.version_index
except Exception as e:
logger.error(f"Error checking model version existence: {e}")
return False

View File

@@ -243,6 +243,7 @@ async def test_initialize_in_background_uses_persisted_cache_without_full_scan(t
cache = await scanner.get_cached_data()
assert len(cache.raw_data) == 1
assert cache.raw_data[0]['file_path'] == normalized
assert cache.version_index[11]['file_path'] == normalized
assert scanner._hash_index.get_path('hash-one') == normalized
@@ -301,6 +302,7 @@ async def test_load_persisted_cache_populates_cache(tmp_path: Path, monkeypatch)
assert entry['file_path'] == normalized
assert entry['tags'] == ['alpha']
assert entry['civitai']['trainedWords'] == ['abc']
assert cache.version_index[11]['file_path'] == normalized
assert scanner._hash_index.get_path('hash-one') == normalized
assert scanner._tags_count == {'alpha': 1}
assert ws_stub.payloads[-1]['stage'] == 'loading_cache'
@@ -381,6 +383,66 @@ async def test_batch_delete_persists_removal(tmp_path: Path, monkeypatch):
assert remaining == 0
@pytest.mark.asyncio
async def test_version_index_tracks_version_ids(tmp_path: Path):
    """Version index follows scan results, lookups and batch deletions."""
    scanner = DummyScanner(tmp_path)

    alpha_path = _normalize_path(tmp_path / 'alpha.txt')
    beta_path = _normalize_path(tmp_path / 'beta.txt')

    def make_entry(path, name, sha256, version_id, model_id):
        # Minimal cache record; only the identifying fields vary per model.
        return {
            'file_path': path,
            'file_name': name,
            'model_name': name,
            'folder': '',
            'size': 1,
            'modified': 1.0,
            'sha256': sha256,
            'tags': [],
            'civitai': {'id': version_id, 'modelId': model_id, 'name': name},
        }

    entries = [
        make_entry(alpha_path, 'alpha', 'hash-alpha', 101, 1),
        make_entry(beta_path, 'beta', 'hash-beta', 202, 2),
    ]

    hash_index = ModelHashIndex()
    for entry in entries:
        hash_index.add_entry(entry['sha256'], entry['file_path'])

    await scanner._apply_scan_result(
        CacheBuildResult(
            raw_data=entries,
            hash_index=hash_index,
            tags_count={},
            excluded_models=[],
        )
    )

    # Both versions are indexed after the scan result is applied.
    cache = await scanner.get_cached_data()
    assert cache.version_index[101]['file_path'] == alpha_path
    assert cache.version_index[202]['file_path'] == beta_path

    # Lookups accept int and numeric-string ids and reject unknown ones.
    assert await scanner.check_model_version_exists(101) is True
    assert await scanner.check_model_version_exists('202') is True
    assert await scanner.check_model_version_exists(999) is False

    # Deleting a model drops its version id from the index.
    removed = await scanner._batch_update_cache_for_deleted_models([alpha_path])
    assert removed is True

    cache_after = await scanner.get_cached_data()
    assert 101 not in cache_after.version_index
    assert await scanner.check_model_version_exists(101) is False
@pytest.mark.asyncio
async def test_reconcile_cache_adds_new_files_and_updates_hash_index(tmp_path: Path):
first, _, _ = _create_files(tmp_path)