mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-21 21:22:11 -03:00
Merge pull request #545 from willmiao/codex/evaluate-sqlite-cache-indexing-necessity
feat: index cached models by version id
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
import asyncio
|
||||
from typing import List, Dict, Tuple
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
from dataclasses import dataclass, field
|
||||
from operator import itemgetter
|
||||
from natsort import natsorted
|
||||
|
||||
@@ -17,10 +17,12 @@ SUPPORTED_SORT_MODES = [
|
||||
|
||||
@dataclass
|
||||
class ModelCache:
|
||||
"""Cache structure for model data with extensible sorting"""
|
||||
"""Cache structure for model data with extensible sorting."""
|
||||
|
||||
raw_data: List[Dict]
|
||||
folders: List[str]
|
||||
|
||||
version_index: Dict[int, Dict] = field(default_factory=dict)
|
||||
|
||||
def __post_init__(self):
|
||||
self._lock = asyncio.Lock()
|
||||
# Cache for last sort: (sort_key, order) -> sorted list
|
||||
@@ -28,6 +30,58 @@ class ModelCache:
|
||||
self._last_sorted_data: List[Dict] = []
|
||||
# Default sort on init
|
||||
asyncio.create_task(self.resort())
|
||||
self.rebuild_version_index()
|
||||
|
||||
@staticmethod
def _normalize_version_id(value: Any) -> Optional[int]:
    """Normalize a potential version identifier into an integer.

    Args:
        value: Candidate identifier; ints and strings are considered.

    Returns:
        The identifier as an ``int``, or ``None`` when ``value`` is a
        bool, a string that ``int()`` cannot parse, or any other type.
    """
    # bool is a subclass of int; without this guard True/False would be
    # accepted and indexed as version ids 1/0.
    if isinstance(value, bool):
        return None
    if isinstance(value, int):
        return value
    if isinstance(value, str):
        try:
            return int(value)
        except ValueError:
            return None
    return None
|
||||
|
||||
def rebuild_version_index(self) -> None:
    """Discard the current version index and re-register every raw item.

    The index is replaced with a fresh dict, then each entry of
    ``self.raw_data`` is passed to ``add_to_version_index`` in order.
    """
    fresh_index: Dict[int, Dict] = {}
    self.version_index = fresh_index

    register = self.add_to_version_index
    for record in self.raw_data:
        register(record)
|
||||
|
||||
def add_to_version_index(self, item: Dict) -> None:
    """Index ``item`` under its version id when one can be derived.

    The id is read from ``item['civitai']['id']`` and passed through
    ``_normalize_version_id``. Items that are not dicts, whose ``'civitai'``
    value is not a dict, or whose id does not normalize are ignored. An
    entry already stored under the same id is overwritten.
    """
    if not isinstance(item, dict):
        return

    metadata = item.get('civitai')
    if isinstance(metadata, dict):
        key = self._normalize_version_id(metadata.get('id'))
        if key is not None:
            self.version_index[key] = item
|
||||
|
||||
def remove_from_version_index(self, item: Dict) -> None:
|
||||
"""Remove a cache item from the version index if present."""
|
||||
|
||||
civitai_data = item.get('civitai') if isinstance(item, dict) else None
|
||||
if not isinstance(civitai_data, dict):
|
||||
return
|
||||
|
||||
version_id = self._normalize_version_id(civitai_data.get('id'))
|
||||
if version_id is None:
|
||||
return
|
||||
|
||||
existing = self.version_index.get(version_id)
|
||||
if existing is item or (
|
||||
isinstance(existing, dict)
|
||||
and existing.get('file_path') == item.get('file_path')
|
||||
):
|
||||
self.version_index.pop(version_id, None)
|
||||
|
||||
async def resort(self):
|
||||
"""Resort cached data according to last sort mode if set"""
|
||||
@@ -41,6 +95,7 @@ class ModelCache:
|
||||
|
||||
all_folders = set(l['folder'] for l in self.raw_data)
|
||||
self.folders = sorted(list(all_folders), key=lambda x: x.lower())
|
||||
self.rebuild_version_index()
|
||||
|
||||
def _sort_data(self, data: List[Dict], sort_key: str, order: str) -> List[Dict]:
|
||||
"""Sort data by sort_key and order"""
|
||||
|
||||
@@ -634,7 +634,8 @@ class ModelScanner:
|
||||
if model_data:
|
||||
# Add to cache
|
||||
self._cache.raw_data.append(model_data)
|
||||
|
||||
self._cache.add_to_version_index(model_data)
|
||||
|
||||
# Update hash index if available
|
||||
if 'sha256' in model_data and 'file_path' in model_data:
|
||||
self._hash_index.add_entry(model_data['sha256'].lower(), model_data['file_path'])
|
||||
@@ -661,7 +662,9 @@ class ModelScanner:
|
||||
for path in missing_files:
|
||||
try:
|
||||
model_to_remove = path_to_item[path]
|
||||
|
||||
|
||||
self._cache.remove_from_version_index(model_to_remove)
|
||||
|
||||
# Update tags count
|
||||
for tag in model_to_remove.get('tags', []):
|
||||
if tag in self._tags_count:
|
||||
@@ -684,6 +687,8 @@ class ModelScanner:
|
||||
all_folders = set(item.get('folder', '') for item in self._cache.raw_data)
|
||||
self._cache.folders = sorted(list(all_folders), key=lambda x: x.lower())
|
||||
|
||||
self._cache.rebuild_version_index()
|
||||
|
||||
# Resort cache
|
||||
await self._cache.resort()
|
||||
|
||||
@@ -829,6 +834,8 @@ class ModelScanner:
|
||||
else:
|
||||
self._cache.raw_data = list(scan_result.raw_data)
|
||||
|
||||
self._cache.rebuild_version_index()
|
||||
|
||||
await self._cache.resort()
|
||||
|
||||
async def _gather_model_data(
|
||||
@@ -934,7 +941,8 @@ class ModelScanner:
|
||||
|
||||
# Add to cache
|
||||
self._cache.raw_data.append(metadata_dict)
|
||||
|
||||
self._cache.add_to_version_index(metadata_dict)
|
||||
|
||||
# Resort cache data
|
||||
await self._cache.resort()
|
||||
|
||||
@@ -1076,6 +1084,9 @@ class ModelScanner:
|
||||
cache = await self.get_cached_data()
|
||||
|
||||
existing_item = next((item for item in cache.raw_data if item['file_path'] == original_path), None)
|
||||
if existing_item:
|
||||
cache.remove_from_version_index(existing_item)
|
||||
|
||||
if existing_item and 'tags' in existing_item:
|
||||
for tag in existing_item.get('tags', []):
|
||||
if tag in self._tags_count:
|
||||
@@ -1106,6 +1117,7 @@ class ModelScanner:
|
||||
)
|
||||
|
||||
cache.raw_data.append(cache_entry)
|
||||
cache.add_to_version_index(cache_entry)
|
||||
|
||||
sha_value = cache_entry.get('sha256')
|
||||
if sha_value:
|
||||
@@ -1117,6 +1129,8 @@ class ModelScanner:
|
||||
for tag in cache_entry.get('tags', []):
|
||||
self._tags_count[tag] = self._tags_count.get(tag, 0) + 1
|
||||
|
||||
cache.rebuild_version_index()
|
||||
|
||||
await cache.resort()
|
||||
|
||||
if cache_modified:
|
||||
@@ -1339,11 +1353,12 @@ class ModelScanner:
|
||||
# Update hash index
|
||||
for model in models_to_remove:
|
||||
file_path = model['file_path']
|
||||
self._cache.remove_from_version_index(model)
|
||||
if hasattr(self, '_hash_index') and self._hash_index:
|
||||
# Get the hash and filename before removal for duplicate checking
|
||||
file_name = os.path.splitext(os.path.basename(file_path))[0]
|
||||
hash_val = model.get('sha256', '').lower()
|
||||
|
||||
|
||||
# Remove from hash index
|
||||
self._hash_index.remove_by_path(file_path, hash_val)
|
||||
|
||||
@@ -1352,8 +1367,9 @@ class ModelScanner:
|
||||
|
||||
# Update cache data
|
||||
self._cache.raw_data = [item for item in self._cache.raw_data if item['file_path'] not in file_paths]
|
||||
|
||||
|
||||
# Resort cache
|
||||
self._cache.rebuild_version_index()
|
||||
await self._cache.resort()
|
||||
|
||||
await self._persist_current_cache()
|
||||
@@ -1393,16 +1409,17 @@ class ModelScanner:
|
||||
Returns:
|
||||
bool: True if the model version exists, False otherwise
|
||||
"""
|
||||
try:
|
||||
normalized_id = int(model_version_id)
|
||||
except (TypeError, ValueError):
|
||||
return False
|
||||
|
||||
try:
|
||||
cache = await self.get_cached_data()
|
||||
if not cache or not cache.raw_data:
|
||||
if not cache:
|
||||
return False
|
||||
|
||||
for item in cache.raw_data:
|
||||
if item.get('civitai') and item['civitai'].get('id') == model_version_id:
|
||||
return True
|
||||
|
||||
return False
|
||||
return normalized_id in cache.version_index
|
||||
except Exception as e:
|
||||
logger.error(f"Error checking model version existence: {e}")
|
||||
return False
|
||||
|
||||
@@ -243,6 +243,7 @@ async def test_initialize_in_background_uses_persisted_cache_without_full_scan(t
|
||||
cache = await scanner.get_cached_data()
|
||||
assert len(cache.raw_data) == 1
|
||||
assert cache.raw_data[0]['file_path'] == normalized
|
||||
assert cache.version_index[11]['file_path'] == normalized
|
||||
|
||||
assert scanner._hash_index.get_path('hash-one') == normalized
|
||||
|
||||
@@ -301,6 +302,7 @@ async def test_load_persisted_cache_populates_cache(tmp_path: Path, monkeypatch)
|
||||
assert entry['file_path'] == normalized
|
||||
assert entry['tags'] == ['alpha']
|
||||
assert entry['civitai']['trainedWords'] == ['abc']
|
||||
assert cache.version_index[11]['file_path'] == normalized
|
||||
assert scanner._hash_index.get_path('hash-one') == normalized
|
||||
assert scanner._tags_count == {'alpha': 1}
|
||||
assert ws_stub.payloads[-1]['stage'] == 'loading_cache'
|
||||
@@ -381,6 +383,66 @@ async def test_batch_delete_persists_removal(tmp_path: Path, monkeypatch):
|
||||
assert remaining == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_version_index_tracks_version_ids(tmp_path: Path):
    """Check that the version index mirrors scan results and deletions."""
    scanner = DummyScanner(tmp_path)

    alpha_path = _normalize_path(tmp_path / 'alpha.txt')
    beta_path = _normalize_path(tmp_path / 'beta.txt')

    alpha_entry = {
        'file_path': alpha_path,
        'file_name': 'alpha',
        'model_name': 'alpha',
        'folder': '',
        'size': 1,
        'modified': 1.0,
        'sha256': 'hash-alpha',
        'tags': [],
        'civitai': {'id': 101, 'modelId': 1, 'name': 'alpha'},
    }
    beta_entry = {
        'file_path': beta_path,
        'file_name': 'beta',
        'model_name': 'beta',
        'folder': '',
        'size': 1,
        'modified': 1.0,
        'sha256': 'hash-beta',
        'tags': [],
        'civitai': {'id': 202, 'modelId': 2, 'name': 'beta'},
    }

    hash_index = ModelHashIndex()
    for digest, location in (('hash-alpha', alpha_path), ('hash-beta', beta_path)):
        hash_index.add_entry(digest, location)

    await scanner._apply_scan_result(
        CacheBuildResult(
            raw_data=[alpha_entry, beta_entry],
            hash_index=hash_index,
            tags_count={},
            excluded_models=[],
        )
    )

    # Both scanned entries must be reachable through their version ids.
    cache = await scanner.get_cached_data()
    for version_id, expected_path in ((101, alpha_path), (202, beta_path)):
        assert cache.version_index[version_id]['file_path'] == expected_path

    # Lookups are exercised with an int id, a numeric-string id, and an
    # id that is not in the cache.
    for candidate, expected in ((101, True), ('202', True), (999, False)):
        assert await scanner.check_model_version_exists(candidate) is expected

    removed = await scanner._batch_update_cache_for_deleted_models([alpha_path])
    assert removed is True

    # Deleting the file must also drop its version id from the index.
    cache_after = await scanner.get_cached_data()
    assert 101 not in cache_after.version_index
    assert await scanner.check_model_version_exists(101) is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reconcile_cache_adds_new_files_and_updates_hash_index(tmp_path: Path):
|
||||
first, _, _ = _create_files(tmp_path)
|
||||
|
||||
Reference in New Issue
Block a user