Mirror of https://github.com/willmiao/ComfyUI-Lora-Manager.git (synced 2026-03-24 14:42:11 -03:00)
feat(persistent-cache): implement SQLite-based persistent model cache with loading and saving functionality
@@ -5,7 +5,7 @@ import asyncio
 import time
 import shutil
 from dataclasses import dataclass
-from typing import Awaitable, Callable, List, Dict, Optional, Type, Set
+from typing import Any, Awaitable, Callable, Dict, List, Mapping, Optional, Set, Type, Union

 from ..utils.models import BaseModelMetadata
 from ..config import config
@@ -17,6 +17,7 @@ from ..utils.constants import PREVIEW_EXTENSIONS
 from .model_lifecycle_service import delete_model_artifacts
 from .service_registry import ServiceRegistry
 from .websocket_manager import ws_manager
+from .persistent_model_cache import get_persistent_cache

 logger = logging.getLogger(__name__)

@@ -79,6 +80,7 @@ class ModelScanner:
         self._tags_count = {}  # Dictionary to store tag counts
         self._is_initializing = False  # Flag to track initialization state
         self._excluded_models = []  # List to track excluded models
+        self._persistent_cache = get_persistent_cache()
         self._initialized = True

         # Register this service
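
Note: the `get_persistent_cache()` accessor imported above lives in `persistent_model_cache.py`, which is not part of this excerpt. A minimal sketch of one plausible shape, assuming a lazily created module-level singleton (the class name is an assumption):

    # Hypothetical sketch; the real persistent_model_cache module is not shown here.
    class PersistentModelCache:  # assumed class name
        """Placeholder for the SQLite-backed snapshot store."""

    _cache_instance = None

    def get_persistent_cache() -> PersistentModelCache:
        # One shared instance, so every scanner (lora, checkpoint, ...)
        # writes through the same SQLite-backed object.
        global _cache_instance
        if _cache_instance is None:
            _cache_instance = PersistentModelCache()
        return _cache_instance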
@@ -89,6 +91,95 @@ class ModelScanner:
         service_name = f"{self.model_type}_scanner"
         await ServiceRegistry.register_service(service_name, self)

+    def _slim_civitai_payload(self, civitai: Optional[Mapping[str, Any]]) -> Optional[Dict[str, Any]]:
+        """Return a lightweight civitai payload containing only frequently used keys."""
+        if not isinstance(civitai, Mapping) or not civitai:
+            return None
+
+        slim: Dict[str, Any] = {}
+        for key in ('id', 'modelId', 'name'):
+            value = civitai.get(key)
+            if value not in (None, '', []):
+                slim[key] = value
+
+        trained_words = civitai.get('trainedWords')
+        if trained_words:
+            slim['trainedWords'] = list(trained_words) if isinstance(trained_words, list) else trained_words
+
+        return slim or None
+
+    def _build_cache_entry(
+        self,
+        source: Union[BaseModelMetadata, Mapping[str, Any]],
+        *,
+        folder: Optional[str] = None,
+        file_path_override: Optional[str] = None
+    ) -> Dict[str, Any]:
+        """Project metadata into the lightweight cache representation."""
+        is_mapping = isinstance(source, Mapping)
+
+        def get_value(key: str, default: Any = None) -> Any:
+            if is_mapping:
+                return source.get(key, default)
+            return getattr(source, key, default)
+
+        file_path = file_path_override or get_value('file_path', '') or ''
+        normalized_path = file_path.replace('\\', '/')
+
+        folder_value = folder if folder is not None else get_value('folder', '') or ''
+        normalized_folder = folder_value.replace('\\', '/')
+
+        tags_value = get_value('tags') or []
+        if isinstance(tags_value, list):
+            tags_list = list(tags_value)
+        elif isinstance(tags_value, (set, tuple)):
+            tags_list = list(tags_value)
+        else:
+            tags_list = []
+
+        preview_url = get_value('preview_url', '') or ''
+        if isinstance(preview_url, str):
+            preview_url = preview_url.replace('\\', '/')
+        else:
+            preview_url = ''
+
+        civitai_slim = self._slim_civitai_payload(get_value('civitai'))
+        usage_tips = get_value('usage_tips', '') or ''
+        if not isinstance(usage_tips, str):
+            usage_tips = str(usage_tips)
+        notes = get_value('notes', '') or ''
+        if not isinstance(notes, str):
+            notes = str(notes)
+
+        entry: Dict[str, Any] = {
+            'file_path': normalized_path,
+            'file_name': get_value('file_name', '') or '',
+            'model_name': get_value('model_name', '') or '',
+            'folder': normalized_folder,
+            'size': int(get_value('size', 0) or 0),
+            'modified': float(get_value('modified', 0.0) or 0.0),
+            'sha256': (get_value('sha256', '') or '').lower(),
+            'base_model': get_value('base_model', '') or '',
+            'preview_url': preview_url,
+            'preview_nsfw_level': int(get_value('preview_nsfw_level', 0) or 0),
+            'from_civitai': bool(get_value('from_civitai', True)),
+            'favorite': bool(get_value('favorite', False)),
+            'notes': notes,
+            'usage_tips': usage_tips,
+            'exclude': bool(get_value('exclude', False)),
+            'db_checked': bool(get_value('db_checked', False)),
+            'last_checked_at': float(get_value('last_checked_at', 0.0) or 0.0),
+            'tags': tags_list,
+            'civitai': civitai_slim,
+            'civitai_deleted': bool(get_value('civitai_deleted', False)),
+        }
+
+        model_type = get_value('model_type', None)
+        if model_type:
+            entry['model_type'] = model_type
+
+        return entry
+
     async def initialize_in_background(self) -> None:
         """Initialize cache in background using thread pool"""
         try:
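
`_build_cache_entry` (added above) is the single projection used throughout the rest of this commit: it accepts either a `BaseModelMetadata` instance or a plain mapping, normalizes path separators to `/`, lowercases the hash, and slims the civitai payload down to `id`, `modelId`, `name`, and `trainedWords`. A usage sketch with made-up values (`scanner` stands for any `ModelScanner` instance):

    # Illustrative values only.
    entry = scanner._build_cache_entry(
        {
            'file_path': 'D:\\models\\loras\\example.safetensors',
            'model_name': 'Example LoRA',
            'sha256': 'ABC123',
            'tags': ['style'],
            'civitai': {'id': 1, 'modelId': 2, 'name': 'Example',
                        'description': '<large HTML blob>'},  # dropped by the slimmer
        },
        folder='loras',
    )
    assert entry['file_path'] == 'D:/models/loras/example.safetensors'
    assert entry['sha256'] == 'abc123'
    assert entry['civitai'] == {'id': 1, 'modelId': 2, 'name': 'Example'}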
@@ -113,7 +204,9 @@ class ModelScanner:
                 'scanner_type': self.model_type,
                 'pageType': page_type
             })

+            await self._load_persisted_cache(page_type)
+
             # If cache loading failed, proceed with full scan
             await ws_manager.broadcast_init_progress({
                 'stage': 'scan_folders',
@@ -150,7 +243,8 @@ class ModelScanner:

             if scan_result:
                 await self._apply_scan_result(scan_result)
+                await self._save_persistent_cache(scan_result)

             # Send final progress update
             await ws_manager.broadcast_init_progress({
                 'stage': 'finalizing',
@@ -179,6 +273,105 @@ class ModelScanner:
             # Always clear the initializing flag when done
             self._is_initializing = False

+    async def _load_persisted_cache(self, page_type: str) -> bool:
+        """Attempt to hydrate the in-memory cache from the SQLite snapshot."""
+        if not getattr(self, '_persistent_cache', None):
+            return False
+
+        loop = asyncio.get_event_loop()
+        try:
+            persisted = await loop.run_in_executor(
+                None,
+                self._persistent_cache.load_cache,
+                self.model_type
+            )
+        except FileNotFoundError:
+            return False
+        except Exception as exc:
+            logger.debug("%s Scanner: Could not load persisted cache: %s", self.model_type.capitalize(), exc)
+            return False
+
+        if not persisted or not persisted.raw_data:
+            return False
+
+        hash_index = ModelHashIndex()
+        for sha_value, path in persisted.hash_rows:
+            if sha_value and path:
+                hash_index.add_entry(sha_value.lower(), path)
+
+        tags_count: Dict[str, int] = {}
+        for item in persisted.raw_data:
+            for tag in item.get('tags') or []:
+                tags_count[tag] = tags_count.get(tag, 0) + 1
+
+        scan_result = CacheBuildResult(
+            raw_data=list(persisted.raw_data),
+            hash_index=hash_index,
+            tags_count=tags_count,
+            excluded_models=list(persisted.excluded_models)
+        )
+
+        await self._apply_scan_result(scan_result)
+
+        await ws_manager.broadcast_init_progress({
+            'stage': 'loading_cache',
+            'progress': 1,
+            'details': f"Loaded cached {self.model_type} data from disk",
+            'scanner_type': self.model_type,
+            'pageType': page_type
+        })
+        return True
+
+    async def _save_persistent_cache(self, scan_result: CacheBuildResult) -> None:
+        if not scan_result or not getattr(self, '_persistent_cache', None):
+            return
+
+        hash_snapshot = self._build_hash_index_snapshot(scan_result.hash_index)
+        loop = asyncio.get_event_loop()
+        try:
+            await loop.run_in_executor(
+                None,
+                self._persistent_cache.save_cache,
+                self.model_type,
+                list(scan_result.raw_data),
+                hash_snapshot,
+                list(scan_result.excluded_models)
+            )
+        except Exception as exc:
+            logger.warning("%s Scanner: Failed to persist cache: %s", self.model_type.capitalize(), exc)
+
+    def _build_hash_index_snapshot(self, hash_index: Optional[ModelHashIndex]) -> Dict[str, List[str]]:
+        snapshot: Dict[str, List[str]] = {}
+        if not hash_index:
+            return snapshot
+
+        for sha_value, path in getattr(hash_index, '_hash_to_path', {}).items():
+            if not sha_value or not path:
+                continue
+            bucket = snapshot.setdefault(sha_value.lower(), [])
+            if path not in bucket:
+                bucket.append(path)
+
+        for sha_value, paths in getattr(hash_index, '_duplicate_hashes', {}).items():
+            if not sha_value:
+                continue
+            bucket = snapshot.setdefault(sha_value.lower(), [])
+            for path in paths:
+                if path and path not in bucket:
+                    bucket.append(path)
+        return snapshot
+
+    async def _persist_current_cache(self) -> None:
+        if self._cache is None or not getattr(self, '_persistent_cache', None):
+            return
+
+        snapshot = CacheBuildResult(
+            raw_data=list(self._cache.raw_data),
+            hash_index=self._hash_index,
+            tags_count=dict(self._tags_count),
+            excluded_models=list(self._excluded_models)
+        )
+        await self._save_persistent_cache(snapshot)
+
     def _count_model_files(self) -> int:
         """Count all model files with supported extensions in all roots
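
The call sites above pin down the interface that `persistent_model_cache` must expose, even though its implementation is not in this diff: `load_cache(model_type)` returns a snapshot (or raises `FileNotFoundError` on first run), and `save_cache(...)` replaces it. An inferred sketch; names and exact shapes are assumptions:

    from dataclasses import dataclass, field
    from typing import Any, Dict, List, Tuple

    @dataclass
    class PersistedSnapshot:  # assumed name for load_cache()'s result
        raw_data: List[Dict[str, Any]]    # one dict per model, as built by _build_cache_entry
        hash_rows: List[Tuple[str, str]]  # (sha256, file_path) pairs for the hash index
        excluded_models: List[str] = field(default_factory=list)

    class PersistentModelCache:  # assumed class name
        def load_cache(self, model_type: str) -> PersistedSnapshot:
            """Read the snapshot for one model type; raises FileNotFoundError if absent."""
            raise NotImplementedError

        def save_cache(self, model_type: str, raw_data: List[Dict[str, Any]],
                       hash_snapshot: Dict[str, List[str]],
                       excluded_models: List[str]) -> None:
            """Replace the stored snapshot for this model type."""
            raise NotImplementedError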
@@ -300,6 +493,7 @@ class ModelScanner:
         # Scan for new data
         scan_result = await self._gather_model_data()
         await self._apply_scan_result(scan_result)
+        await self._save_persistent_cache(scan_result)

         logger.info(
             f"{self.model_type.capitalize()} Scanner: Cache initialization completed in {time.time() - start_time:.2f} seconds, "
@@ -594,8 +788,12 @@ class ModelScanner:
             # Hook: allow subclasses to adjust metadata
             metadata = self.adjust_metadata(metadata, file_path, root_path)

-            model_data = metadata.to_dict()
+            rel_path = os.path.relpath(file_path, root_path)
+            folder = os.path.dirname(rel_path)
+            normalized_folder = folder.replace(os.path.sep, '/')
+
+            model_data = self._build_cache_entry(metadata, folder=normalized_folder)

             # Skip excluded models
             if model_data.get('exclude', False):
                 excluded_models.append(model_data['file_path'])
@@ -609,10 +807,6 @@ class ModelScanner:
             # if existing_path and existing_path != file_path:
             #     logger.warning(f"Duplicate filename detected: '{filename}' - files: '{existing_path}' and '{file_path}'")

-            rel_path = os.path.relpath(file_path, root_path)
-            folder = os.path.dirname(rel_path)
-            model_data['folder'] = folder.replace(os.path.sep, '/')
-
             return model_data

     async def _apply_scan_result(self, scan_result: CacheBuildResult) -> None:
@@ -749,6 +943,7 @@ class ModelScanner:

             # Update the hash index
             self._hash_index.add_entry(metadata_dict['sha256'], metadata_dict['file_path'])
+            await self._persist_current_cache()
             return True
         except Exception as e:
             logger.error(f"Error adding model to cache: {e}")
@@ -894,28 +1089,30 @@ class ModelScanner:
         ]

         if metadata:
-            if original_path == new_path:
-                existing_folder = next((item['folder'] for item in cache.raw_data
-                                        if item['file_path'] == original_path), None)
-                if existing_folder:
-                    metadata['folder'] = existing_folder
-                else:
-                    metadata['folder'] = self._calculate_folder(new_path)
+            normalized_new_path = new_path.replace(os.sep, '/')
+            if original_path == new_path and existing_item:
+                folder_value = existing_item.get('folder', self._calculate_folder(new_path))
             else:
-                metadata['folder'] = self._calculate_folder(new_path)
-
-            cache.raw_data.append(metadata)
-
-            if 'sha256' in metadata:
-                self._hash_index.add_entry(metadata['sha256'].lower(), new_path)
+                folder_value = self._calculate_folder(new_path)
+
+            cache_entry = self._build_cache_entry(
+                metadata,
+                folder=folder_value,
+                file_path_override=normalized_new_path,
+            )
+
+            cache.raw_data.append(cache_entry)
+
+            sha_value = cache_entry.get('sha256')
+            if sha_value:
+                self._hash_index.add_entry(sha_value.lower(), normalized_new_path)

             all_folders = set(item['folder'] for item in cache.raw_data)
             cache.folders = sorted(list(all_folders), key=lambda x: x.lower())

-            if 'tags' in metadata:
-                for tag in metadata.get('tags', []):
-                    self._tags_count[tag] = self._tags_count.get(tag, 0) + 1
+            for tag in cache_entry.get('tags', []):
+                self._tags_count[tag] = self._tags_count.get(tag, 0) + 1

             await cache.resort()

             return True
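
On a move or rename (hunk above), the entry is rebuilt through `_build_cache_entry` instead of mutated in place, so the in-memory row stays identical to what `_save_persistent_cache` later writes out; `file_path_override` pins the rebuilt entry to the new location. An illustration with made-up values (`scanner` and `metadata` as in the hunk):

    cache_entry = scanner._build_cache_entry(
        metadata,  # dict still carrying the pre-rename file_path
        folder='loras/characters',  # made-up folder
        file_path_override='models/loras/characters/new_name.safetensors',
    )
    assert cache_entry['file_path'] == 'models/loras/characters/new_name.safetensors'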
@@ -1019,7 +1216,10 @@ class ModelScanner:
         if self._cache is None:
             return False

-        return await self._cache.update_preview_url(file_path, preview_url, preview_nsfw_level)
+        updated = await self._cache.update_preview_url(file_path, preview_url, preview_nsfw_level)
+        if updated:
+            await self._persist_current_cache()
+        return updated

     async def bulk_delete_models(self, file_paths: List[str]) -> Dict:
         """Delete multiple models and update cache in a batch operation
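
Every mutation path in this commit (add, rename, preview update) funnels through `loop.run_in_executor(None, ...)` because SQLite I/O is blocking. A minimal standalone sketch of that offload pattern (table layout and function names are made up, not the plugin's actual schema):

    import asyncio
    import sqlite3

    def write_snapshot(db_path: str, model_type: str, payload: str) -> None:
        # Blocking SQLite work; runs on the default thread pool, off the event loop.
        conn = sqlite3.connect(db_path)
        try:
            conn.execute('CREATE TABLE IF NOT EXISTS snapshots '
                         '(model_type TEXT PRIMARY KEY, payload TEXT)')
            conn.execute('INSERT OR REPLACE INTO snapshots VALUES (?, ?)',
                         (model_type, payload))
            conn.commit()
        finally:
            conn.close()

    async def persist(db_path: str, model_type: str, payload: str) -> None:
        loop = asyncio.get_event_loop()
        await loop.run_in_executor(None, write_snapshot, db_path, model_type, payload)

    # e.g. asyncio.run(persist('model_cache.db', 'lora', '{"raw_data": []}'))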