feat(persistent-cache): implement SQLite-based persistent model cache with loading and saving functionality

This commit is contained in:
Will Miao
2025-10-03 11:00:51 +08:00
parent 3b1990e97a
commit 77bbf85b52
7 changed files with 809 additions and 79 deletions

View File

@@ -5,7 +5,7 @@ import asyncio
import time
import shutil
from dataclasses import dataclass
from typing import Awaitable, Callable, List, Dict, Optional, Type, Set
from typing import Any, Awaitable, Callable, Dict, List, Mapping, Optional, Set, Type, Union
from ..utils.models import BaseModelMetadata
from ..config import config
@@ -17,6 +17,7 @@ from ..utils.constants import PREVIEW_EXTENSIONS
from .model_lifecycle_service import delete_model_artifacts
from .service_registry import ServiceRegistry
from .websocket_manager import ws_manager
from .persistent_model_cache import get_persistent_cache
logger = logging.getLogger(__name__)
@@ -79,6 +80,7 @@ class ModelScanner:
self._tags_count = {} # Dictionary to store tag counts
self._is_initializing = False # Flag to track initialization state
self._excluded_models = [] # List to track excluded models
self._persistent_cache = get_persistent_cache()
self._initialized = True
# Register this service
@@ -89,6 +91,95 @@ class ModelScanner:
service_name = f"{self.model_type}_scanner"
await ServiceRegistry.register_service(service_name, self)
def _slim_civitai_payload(self, civitai: Optional[Mapping[str, Any]]) -> Optional[Dict[str, Any]]:
"""Return a lightweight civitai payload containing only frequently used keys."""
if not isinstance(civitai, Mapping) or not civitai:
return None
slim: Dict[str, Any] = {}
for key in ('id', 'modelId', 'name'):
value = civitai.get(key)
if value not in (None, '', []):
slim[key] = value
trained_words = civitai.get('trainedWords')
if trained_words:
slim['trainedWords'] = list(trained_words) if isinstance(trained_words, list) else trained_words
return slim or None
def _build_cache_entry(
self,
source: Union[BaseModelMetadata, Mapping[str, Any]],
*,
folder: Optional[str] = None,
file_path_override: Optional[str] = None
) -> Dict[str, Any]:
"""Project metadata into the lightweight cache representation."""
is_mapping = isinstance(source, Mapping)
def get_value(key: str, default: Any = None) -> Any:
if is_mapping:
return source.get(key, default)
return getattr(source, key, default)
file_path = file_path_override or get_value('file_path', '') or ''
normalized_path = file_path.replace('\\', '/')
folder_value = folder if folder is not None else get_value('folder', '') or ''
normalized_folder = folder_value.replace('\\', '/')
tags_value = get_value('tags') or []
if isinstance(tags_value, list):
tags_list = list(tags_value)
elif isinstance(tags_value, (set, tuple)):
tags_list = list(tags_value)
else:
tags_list = []
preview_url = get_value('preview_url', '') or ''
if isinstance(preview_url, str):
preview_url = preview_url.replace('\\', '/')
else:
preview_url = ''
civitai_slim = self._slim_civitai_payload(get_value('civitai'))
usage_tips = get_value('usage_tips', '') or ''
if not isinstance(usage_tips, str):
usage_tips = str(usage_tips)
notes = get_value('notes', '') or ''
if not isinstance(notes, str):
notes = str(notes)
entry: Dict[str, Any] = {
'file_path': normalized_path,
'file_name': get_value('file_name', '') or '',
'model_name': get_value('model_name', '') or '',
'folder': normalized_folder,
'size': int(get_value('size', 0) or 0),
'modified': float(get_value('modified', 0.0) or 0.0),
'sha256': (get_value('sha256', '') or '').lower(),
'base_model': get_value('base_model', '') or '',
'preview_url': preview_url,
'preview_nsfw_level': int(get_value('preview_nsfw_level', 0) or 0),
'from_civitai': bool(get_value('from_civitai', True)),
'favorite': bool(get_value('favorite', False)),
'notes': notes,
'usage_tips': usage_tips,
'exclude': bool(get_value('exclude', False)),
'db_checked': bool(get_value('db_checked', False)),
'last_checked_at': float(get_value('last_checked_at', 0.0) or 0.0),
'tags': tags_list,
'civitai': civitai_slim,
'civitai_deleted': bool(get_value('civitai_deleted', False)),
}
model_type = get_value('model_type', None)
if model_type:
entry['model_type'] = model_type
return entry
async def initialize_in_background(self) -> None:
"""Initialize cache in background using thread pool"""
try:
@@ -113,7 +204,9 @@ class ModelScanner:
'scanner_type': self.model_type,
'pageType': page_type
})
await self._load_persisted_cache(page_type)
# If cache loading failed, proceed with full scan
await ws_manager.broadcast_init_progress({
'stage': 'scan_folders',
@@ -150,7 +243,8 @@ class ModelScanner:
if scan_result:
await self._apply_scan_result(scan_result)
await self._save_persistent_cache(scan_result)
# Send final progress update
await ws_manager.broadcast_init_progress({
'stage': 'finalizing',
@@ -179,6 +273,105 @@ class ModelScanner:
# Always clear the initializing flag when done
self._is_initializing = False
async def _load_persisted_cache(self, page_type: str) -> bool:
"""Attempt to hydrate the in-memory cache from the SQLite snapshot."""
if not getattr(self, '_persistent_cache', None):
return False
loop = asyncio.get_event_loop()
try:
persisted = await loop.run_in_executor(
None,
self._persistent_cache.load_cache,
self.model_type
)
except FileNotFoundError:
return False
except Exception as exc:
logger.debug("%s Scanner: Could not load persisted cache: %s", self.model_type.capitalize(), exc)
return False
if not persisted or not persisted.raw_data:
return False
hash_index = ModelHashIndex()
for sha_value, path in persisted.hash_rows:
if sha_value and path:
hash_index.add_entry(sha_value.lower(), path)
tags_count: Dict[str, int] = {}
for item in persisted.raw_data:
for tag in item.get('tags') or []:
tags_count[tag] = tags_count.get(tag, 0) + 1
scan_result = CacheBuildResult(
raw_data=list(persisted.raw_data),
hash_index=hash_index,
tags_count=tags_count,
excluded_models=list(persisted.excluded_models)
)
await self._apply_scan_result(scan_result)
await ws_manager.broadcast_init_progress({
'stage': 'loading_cache',
'progress': 1,
'details': f"Loaded cached {self.model_type} data from disk",
'scanner_type': self.model_type,
'pageType': page_type
})
return True
async def _save_persistent_cache(self, scan_result: CacheBuildResult) -> None:
if not scan_result or not getattr(self, '_persistent_cache', None):
return
hash_snapshot = self._build_hash_index_snapshot(scan_result.hash_index)
loop = asyncio.get_event_loop()
try:
await loop.run_in_executor(
None,
self._persistent_cache.save_cache,
self.model_type,
list(scan_result.raw_data),
hash_snapshot,
list(scan_result.excluded_models)
)
except Exception as exc:
logger.warning("%s Scanner: Failed to persist cache: %s", self.model_type.capitalize(), exc)
def _build_hash_index_snapshot(self, hash_index: Optional[ModelHashIndex]) -> Dict[str, List[str]]:
snapshot: Dict[str, List[str]] = {}
if not hash_index:
return snapshot
for sha_value, path in getattr(hash_index, '_hash_to_path', {}).items():
if not sha_value or not path:
continue
bucket = snapshot.setdefault(sha_value.lower(), [])
if path not in bucket:
bucket.append(path)
for sha_value, paths in getattr(hash_index, '_duplicate_hashes', {}).items():
if not sha_value:
continue
bucket = snapshot.setdefault(sha_value.lower(), [])
for path in paths:
if path and path not in bucket:
bucket.append(path)
return snapshot
async def _persist_current_cache(self) -> None:
if self._cache is None or not getattr(self, '_persistent_cache', None):
return
snapshot = CacheBuildResult(
raw_data=list(self._cache.raw_data),
hash_index=self._hash_index,
tags_count=dict(self._tags_count),
excluded_models=list(self._excluded_models)
)
await self._save_persistent_cache(snapshot)
def _count_model_files(self) -> int:
"""Count all model files with supported extensions in all roots
@@ -300,6 +493,7 @@ class ModelScanner:
# Scan for new data
scan_result = await self._gather_model_data()
await self._apply_scan_result(scan_result)
await self._save_persistent_cache(scan_result)
logger.info(
f"{self.model_type.capitalize()} Scanner: Cache initialization completed in {time.time() - start_time:.2f} seconds, "
@@ -594,8 +788,12 @@ class ModelScanner:
# Hook: allow subclasses to adjust metadata
metadata = self.adjust_metadata(metadata, file_path, root_path)
model_data = metadata.to_dict()
rel_path = os.path.relpath(file_path, root_path)
folder = os.path.dirname(rel_path)
normalized_folder = folder.replace(os.path.sep, '/')
model_data = self._build_cache_entry(metadata, folder=normalized_folder)
# Skip excluded models
if model_data.get('exclude', False):
excluded_models.append(model_data['file_path'])
@@ -609,10 +807,6 @@ class ModelScanner:
# if existing_path and existing_path != file_path:
# logger.warning(f"Duplicate filename detected: '{filename}' - files: '{existing_path}' and '{file_path}'")
rel_path = os.path.relpath(file_path, root_path)
folder = os.path.dirname(rel_path)
model_data['folder'] = folder.replace(os.path.sep, '/')
return model_data
async def _apply_scan_result(self, scan_result: CacheBuildResult) -> None:
@@ -749,6 +943,7 @@ class ModelScanner:
# Update the hash index
self._hash_index.add_entry(metadata_dict['sha256'], metadata_dict['file_path'])
await self._persist_current_cache()
return True
except Exception as e:
logger.error(f"Error adding model to cache: {e}")
@@ -894,28 +1089,30 @@ class ModelScanner:
]
if metadata:
if original_path == new_path:
existing_folder = next((item['folder'] for item in cache.raw_data
if item['file_path'] == original_path), None)
if existing_folder:
metadata['folder'] = existing_folder
else:
metadata['folder'] = self._calculate_folder(new_path)
normalized_new_path = new_path.replace(os.sep, '/')
if original_path == new_path and existing_item:
folder_value = existing_item.get('folder', self._calculate_folder(new_path))
else:
metadata['folder'] = self._calculate_folder(new_path)
cache.raw_data.append(metadata)
if 'sha256' in metadata:
self._hash_index.add_entry(metadata['sha256'].lower(), new_path)
folder_value = self._calculate_folder(new_path)
cache_entry = self._build_cache_entry(
metadata,
folder=folder_value,
file_path_override=normalized_new_path,
)
cache.raw_data.append(cache_entry)
sha_value = cache_entry.get('sha256')
if sha_value:
self._hash_index.add_entry(sha_value.lower(), normalized_new_path)
all_folders = set(item['folder'] for item in cache.raw_data)
cache.folders = sorted(list(all_folders), key=lambda x: x.lower())
if 'tags' in metadata:
for tag in metadata.get('tags', []):
self._tags_count[tag] = self._tags_count.get(tag, 0) + 1
for tag in cache_entry.get('tags', []):
self._tags_count[tag] = self._tags_count.get(tag, 0) + 1
await cache.resort()
return True
@@ -1019,7 +1216,10 @@ class ModelScanner:
if self._cache is None:
return False
return await self._cache.update_preview_url(file_path, preview_url, preview_nsfw_level)
updated = await self._cache.update_preview_url(file_path, preview_url, preview_nsfw_level)
if updated:
await self._persist_current_cache()
return updated
async def bulk_delete_models(self, file_paths: List[str]) -> Dict:
"""Delete multiple models and update cache in a batch operation