mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-21 21:22:11 -03:00
Merge pull request #545 from willmiao/codex/evaluate-sqlite-cache-indexing-necessity
feat: index cached models by version id
This commit is contained in:
@@ -1,6 +1,6 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
from typing import List, Dict, Tuple
|
from typing import Any, Dict, List, Optional, Tuple
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass, field
|
||||||
from operator import itemgetter
|
from operator import itemgetter
|
||||||
from natsort import natsorted
|
from natsort import natsorted
|
||||||
|
|
||||||
@@ -17,10 +17,12 @@ SUPPORTED_SORT_MODES = [
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ModelCache:
|
class ModelCache:
|
||||||
"""Cache structure for model data with extensible sorting"""
|
"""Cache structure for model data with extensible sorting."""
|
||||||
|
|
||||||
raw_data: List[Dict]
|
raw_data: List[Dict]
|
||||||
folders: List[str]
|
folders: List[str]
|
||||||
|
version_index: Dict[int, Dict] = field(default_factory=dict)
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
self._lock = asyncio.Lock()
|
self._lock = asyncio.Lock()
|
||||||
# Cache for last sort: (sort_key, order) -> sorted list
|
# Cache for last sort: (sort_key, order) -> sorted list
|
||||||
@@ -28,6 +30,58 @@ class ModelCache:
|
|||||||
self._last_sorted_data: List[Dict] = []
|
self._last_sorted_data: List[Dict] = []
|
||||||
# Default sort on init
|
# Default sort on init
|
||||||
asyncio.create_task(self.resort())
|
asyncio.create_task(self.resort())
|
||||||
|
self.rebuild_version_index()
|
||||||
|
|
||||||
|
@staticmethod
def _normalize_version_id(value: Any) -> Optional[int]:
    """Normalize a potential version identifier into an integer.

    Args:
        value: Candidate identifier; integers and integer-like strings
            are accepted.

    Returns:
        The identifier as an ``int``, or ``None`` when ``value`` cannot be
        interpreted as one.
    """
    # bool is a subclass of int; without this guard True/False would be
    # accepted and silently collide with version ids 1 and 0.
    if isinstance(value, bool):
        return None
    if isinstance(value, int):
        return value
    if isinstance(value, str):
        try:
            return int(value)
        except ValueError:
            return None
    return None
|
||||||
|
|
||||||
|
def rebuild_version_index(self) -> None:
    """Replace the version index with one built from ``raw_data``.

    A fresh dictionary is bound to ``version_index`` and every item in
    ``raw_data`` is passed to ``add_to_version_index`` in list order.
    """
    self.version_index = {}
    for entry in self.raw_data:
        self.add_to_version_index(entry)
|
||||||
|
|
||||||
|
def add_to_version_index(self, item: Dict) -> None:
    """Index ``item`` under its normalized ``civitai['id']``.

    Nothing is registered when ``item`` is not a dict, has no dict-valued
    ``'civitai'`` entry, or its id does not normalize to an integer. An
    entry already stored under the same id is overwritten.
    """
    if not isinstance(item, dict):
        return

    metadata = item.get('civitai')
    if isinstance(metadata, dict):
        key = self._normalize_version_id(metadata.get('id'))
        if key is not None:
            self.version_index[key] = item
|
||||||
|
|
||||||
|
def remove_from_version_index(self, item: Dict) -> None:
    """Drop ``item`` from the version index when it is the indexed entry.

    The entry stored under the item's normalized ``civitai['id']`` is
    removed only if it is the very same object as ``item`` or a dict with
    an equal ``'file_path'``; an entry that matches neither test is left
    in place.
    """
    if not isinstance(item, dict):
        return

    metadata = item.get('civitai')
    if not isinstance(metadata, dict):
        return

    key = self._normalize_version_id(metadata.get('id'))
    if key is None:
        return

    indexed = self.version_index.get(key)
    # Identity first; fall back to comparing file paths only when needed.
    is_same_entry = indexed is item
    if not is_same_entry and isinstance(indexed, dict):
        is_same_entry = indexed.get('file_path') == item.get('file_path')

    if is_same_entry:
        self.version_index.pop(key, None)
|
||||||
|
|
||||||
async def resort(self):
|
async def resort(self):
|
||||||
"""Resort cached data according to last sort mode if set"""
|
"""Resort cached data according to last sort mode if set"""
|
||||||
@@ -41,6 +95,7 @@ class ModelCache:
|
|||||||
|
|
||||||
all_folders = set(l['folder'] for l in self.raw_data)
|
all_folders = set(l['folder'] for l in self.raw_data)
|
||||||
self.folders = sorted(list(all_folders), key=lambda x: x.lower())
|
self.folders = sorted(list(all_folders), key=lambda x: x.lower())
|
||||||
|
self.rebuild_version_index()
|
||||||
|
|
||||||
def _sort_data(self, data: List[Dict], sort_key: str, order: str) -> List[Dict]:
|
def _sort_data(self, data: List[Dict], sort_key: str, order: str) -> List[Dict]:
|
||||||
"""Sort data by sort_key and order"""
|
"""Sort data by sort_key and order"""
|
||||||
|
|||||||
@@ -634,7 +634,8 @@ class ModelScanner:
|
|||||||
if model_data:
|
if model_data:
|
||||||
# Add to cache
|
# Add to cache
|
||||||
self._cache.raw_data.append(model_data)
|
self._cache.raw_data.append(model_data)
|
||||||
|
self._cache.add_to_version_index(model_data)
|
||||||
|
|
||||||
# Update hash index if available
|
# Update hash index if available
|
||||||
if 'sha256' in model_data and 'file_path' in model_data:
|
if 'sha256' in model_data and 'file_path' in model_data:
|
||||||
self._hash_index.add_entry(model_data['sha256'].lower(), model_data['file_path'])
|
self._hash_index.add_entry(model_data['sha256'].lower(), model_data['file_path'])
|
||||||
@@ -661,7 +662,9 @@ class ModelScanner:
|
|||||||
for path in missing_files:
|
for path in missing_files:
|
||||||
try:
|
try:
|
||||||
model_to_remove = path_to_item[path]
|
model_to_remove = path_to_item[path]
|
||||||
|
|
||||||
|
self._cache.remove_from_version_index(model_to_remove)
|
||||||
|
|
||||||
# Update tags count
|
# Update tags count
|
||||||
for tag in model_to_remove.get('tags', []):
|
for tag in model_to_remove.get('tags', []):
|
||||||
if tag in self._tags_count:
|
if tag in self._tags_count:
|
||||||
@@ -684,6 +687,8 @@ class ModelScanner:
|
|||||||
all_folders = set(item.get('folder', '') for item in self._cache.raw_data)
|
all_folders = set(item.get('folder', '') for item in self._cache.raw_data)
|
||||||
self._cache.folders = sorted(list(all_folders), key=lambda x: x.lower())
|
self._cache.folders = sorted(list(all_folders), key=lambda x: x.lower())
|
||||||
|
|
||||||
|
self._cache.rebuild_version_index()
|
||||||
|
|
||||||
# Resort cache
|
# Resort cache
|
||||||
await self._cache.resort()
|
await self._cache.resort()
|
||||||
|
|
||||||
@@ -829,6 +834,8 @@ class ModelScanner:
|
|||||||
else:
|
else:
|
||||||
self._cache.raw_data = list(scan_result.raw_data)
|
self._cache.raw_data = list(scan_result.raw_data)
|
||||||
|
|
||||||
|
self._cache.rebuild_version_index()
|
||||||
|
|
||||||
await self._cache.resort()
|
await self._cache.resort()
|
||||||
|
|
||||||
async def _gather_model_data(
|
async def _gather_model_data(
|
||||||
@@ -934,7 +941,8 @@ class ModelScanner:
|
|||||||
|
|
||||||
# Add to cache
|
# Add to cache
|
||||||
self._cache.raw_data.append(metadata_dict)
|
self._cache.raw_data.append(metadata_dict)
|
||||||
|
self._cache.add_to_version_index(metadata_dict)
|
||||||
|
|
||||||
# Resort cache data
|
# Resort cache data
|
||||||
await self._cache.resort()
|
await self._cache.resort()
|
||||||
|
|
||||||
@@ -1076,6 +1084,9 @@ class ModelScanner:
|
|||||||
cache = await self.get_cached_data()
|
cache = await self.get_cached_data()
|
||||||
|
|
||||||
existing_item = next((item for item in cache.raw_data if item['file_path'] == original_path), None)
|
existing_item = next((item for item in cache.raw_data if item['file_path'] == original_path), None)
|
||||||
|
if existing_item:
|
||||||
|
cache.remove_from_version_index(existing_item)
|
||||||
|
|
||||||
if existing_item and 'tags' in existing_item:
|
if existing_item and 'tags' in existing_item:
|
||||||
for tag in existing_item.get('tags', []):
|
for tag in existing_item.get('tags', []):
|
||||||
if tag in self._tags_count:
|
if tag in self._tags_count:
|
||||||
@@ -1106,6 +1117,7 @@ class ModelScanner:
|
|||||||
)
|
)
|
||||||
|
|
||||||
cache.raw_data.append(cache_entry)
|
cache.raw_data.append(cache_entry)
|
||||||
|
cache.add_to_version_index(cache_entry)
|
||||||
|
|
||||||
sha_value = cache_entry.get('sha256')
|
sha_value = cache_entry.get('sha256')
|
||||||
if sha_value:
|
if sha_value:
|
||||||
@@ -1117,6 +1129,8 @@ class ModelScanner:
|
|||||||
for tag in cache_entry.get('tags', []):
|
for tag in cache_entry.get('tags', []):
|
||||||
self._tags_count[tag] = self._tags_count.get(tag, 0) + 1
|
self._tags_count[tag] = self._tags_count.get(tag, 0) + 1
|
||||||
|
|
||||||
|
cache.rebuild_version_index()
|
||||||
|
|
||||||
await cache.resort()
|
await cache.resort()
|
||||||
|
|
||||||
if cache_modified:
|
if cache_modified:
|
||||||
@@ -1339,11 +1353,12 @@ class ModelScanner:
|
|||||||
# Update hash index
|
# Update hash index
|
||||||
for model in models_to_remove:
|
for model in models_to_remove:
|
||||||
file_path = model['file_path']
|
file_path = model['file_path']
|
||||||
|
self._cache.remove_from_version_index(model)
|
||||||
if hasattr(self, '_hash_index') and self._hash_index:
|
if hasattr(self, '_hash_index') and self._hash_index:
|
||||||
# Get the hash and filename before removal for duplicate checking
|
# Get the hash and filename before removal for duplicate checking
|
||||||
file_name = os.path.splitext(os.path.basename(file_path))[0]
|
file_name = os.path.splitext(os.path.basename(file_path))[0]
|
||||||
hash_val = model.get('sha256', '').lower()
|
hash_val = model.get('sha256', '').lower()
|
||||||
|
|
||||||
# Remove from hash index
|
# Remove from hash index
|
||||||
self._hash_index.remove_by_path(file_path, hash_val)
|
self._hash_index.remove_by_path(file_path, hash_val)
|
||||||
|
|
||||||
@@ -1352,8 +1367,9 @@ class ModelScanner:
|
|||||||
|
|
||||||
# Update cache data
|
# Update cache data
|
||||||
self._cache.raw_data = [item for item in self._cache.raw_data if item['file_path'] not in file_paths]
|
self._cache.raw_data = [item for item in self._cache.raw_data if item['file_path'] not in file_paths]
|
||||||
|
|
||||||
# Resort cache
|
# Resort cache
|
||||||
|
self._cache.rebuild_version_index()
|
||||||
await self._cache.resort()
|
await self._cache.resort()
|
||||||
|
|
||||||
await self._persist_current_cache()
|
await self._persist_current_cache()
|
||||||
@@ -1393,16 +1409,17 @@ class ModelScanner:
|
|||||||
Returns:
|
Returns:
|
||||||
bool: True if the model version exists, False otherwise
|
bool: True if the model version exists, False otherwise
|
||||||
"""
|
"""
|
||||||
|
try:
|
||||||
|
normalized_id = int(model_version_id)
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
return False
|
||||||
|
|
||||||
try:
|
try:
|
||||||
cache = await self.get_cached_data()
|
cache = await self.get_cached_data()
|
||||||
if not cache or not cache.raw_data:
|
if not cache:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
for item in cache.raw_data:
|
return normalized_id in cache.version_index
|
||||||
if item.get('civitai') and item['civitai'].get('id') == model_version_id:
|
|
||||||
return True
|
|
||||||
|
|
||||||
return False
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error checking model version existence: {e}")
|
logger.error(f"Error checking model version existence: {e}")
|
||||||
return False
|
return False
|
||||||
|
|||||||
@@ -243,6 +243,7 @@ async def test_initialize_in_background_uses_persisted_cache_without_full_scan(t
|
|||||||
cache = await scanner.get_cached_data()
|
cache = await scanner.get_cached_data()
|
||||||
assert len(cache.raw_data) == 1
|
assert len(cache.raw_data) == 1
|
||||||
assert cache.raw_data[0]['file_path'] == normalized
|
assert cache.raw_data[0]['file_path'] == normalized
|
||||||
|
assert cache.version_index[11]['file_path'] == normalized
|
||||||
|
|
||||||
assert scanner._hash_index.get_path('hash-one') == normalized
|
assert scanner._hash_index.get_path('hash-one') == normalized
|
||||||
|
|
||||||
@@ -301,6 +302,7 @@ async def test_load_persisted_cache_populates_cache(tmp_path: Path, monkeypatch)
|
|||||||
assert entry['file_path'] == normalized
|
assert entry['file_path'] == normalized
|
||||||
assert entry['tags'] == ['alpha']
|
assert entry['tags'] == ['alpha']
|
||||||
assert entry['civitai']['trainedWords'] == ['abc']
|
assert entry['civitai']['trainedWords'] == ['abc']
|
||||||
|
assert cache.version_index[11]['file_path'] == normalized
|
||||||
assert scanner._hash_index.get_path('hash-one') == normalized
|
assert scanner._hash_index.get_path('hash-one') == normalized
|
||||||
assert scanner._tags_count == {'alpha': 1}
|
assert scanner._tags_count == {'alpha': 1}
|
||||||
assert ws_stub.payloads[-1]['stage'] == 'loading_cache'
|
assert ws_stub.payloads[-1]['stage'] == 'loading_cache'
|
||||||
@@ -381,6 +383,66 @@ async def test_batch_delete_persists_removal(tmp_path: Path, monkeypatch):
|
|||||||
assert remaining == 0
|
assert remaining == 0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_version_index_tracks_version_ids(tmp_path: Path):
    """Version ids are indexed after a scan, found as int or str, and dropped on delete."""
    scanner = DummyScanner(tmp_path)

    def make_entry(name: str, version_id: int, model_id: int) -> dict:
        # Minimal cache record for one model file under tmp_path.
        return {
            'file_path': _normalize_path(tmp_path / f'{name}.txt'),
            'file_name': name,
            'model_name': name,
            'folder': '',
            'size': 1,
            'modified': 1.0,
            'sha256': f'hash-{name}',
            'tags': [],
            'civitai': {'id': version_id, 'modelId': model_id, 'name': name},
        }

    alpha = make_entry('alpha', 101, 1)
    beta = make_entry('beta', 202, 2)

    hash_index = ModelHashIndex()
    for entry in (alpha, beta):
        hash_index.add_entry(entry['sha256'], entry['file_path'])

    await scanner._apply_scan_result(
        CacheBuildResult(
            raw_data=[alpha, beta],
            hash_index=hash_index,
            tags_count={},
            excluded_models=[],
        )
    )

    # Both entries must be reachable through the index by version id.
    cache = await scanner.get_cached_data()
    assert cache.version_index[101]['file_path'] == alpha['file_path']
    assert cache.version_index[202]['file_path'] == beta['file_path']

    # Lookups accept integer and string identifiers alike.
    assert await scanner.check_model_version_exists(101) is True
    assert await scanner.check_model_version_exists('202') is True
    assert await scanner.check_model_version_exists(999) is False

    # Deleting the first model must also evict its version id.
    removed = await scanner._batch_update_cache_for_deleted_models([alpha['file_path']])
    assert removed is True

    cache_after = await scanner.get_cached_data()
    assert 101 not in cache_after.version_index
    assert await scanner.check_model_version_exists(101) is False
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_reconcile_cache_adds_new_files_and_updates_hash_index(tmp_path: Path):
|
async def test_reconcile_cache_adds_new_files_and_updates_hash_index(tmp_path: Path):
|
||||||
first, _, _ = _create_files(tmp_path)
|
first, _, _ = _create_files(tmp_path)
|
||||||
|
|||||||
Reference in New Issue
Block a user