From 77bbf85b5246b70ef63b57b62021abaf6d47f55a Mon Sep 17 00:00:00 2001 From: Will Miao <13051207myq@gmail.com> Date: Fri, 3 Oct 2025 11:00:51 +0800 Subject: [PATCH] feat(persistent-cache): implement SQLite-based persistent model cache with loading and saving functionality --- py/services/base_model_service.py | 33 +- py/services/model_scanner.py | 258 +++++++++++-- py/services/persistent_model_cache.py | 346 ++++++++++++++++++ py/utils/example_images_download_manager.py | 58 +-- py/utils/example_images_metadata.py | 38 +- tests/services/test_model_scanner.py | 63 ++++ tests/services/test_persistent_model_cache.py | 92 +++++ 7 files changed, 809 insertions(+), 79 deletions(-) create mode 100644 py/services/persistent_model_cache.py create mode 100644 tests/services/test_persistent_model_cache.py diff --git a/py/services/base_model_service.py b/py/services/base_model_service.py index 2c2c0ad8..9e067a67 100644 --- a/py/services/base_model_service.py +++ b/py/services/base_model_service.py @@ -4,6 +4,7 @@ import logging import os from ..utils.models import BaseModelMetadata +from ..utils.metadata_manager import MetadataManager from .model_query import FilterCriteria, ModelCacheRepository, ModelFilterSet, SearchStrategy, SettingsProvider from .settings_manager import settings as default_settings @@ -313,24 +314,24 @@ class BaseModelService(ABC): return {'civitai_url': None, 'model_id': None, 'version_id': None} async def get_model_metadata(self, file_path: str) -> Optional[Dict]: - """Get filtered CivitAI metadata for a model by file path""" - cache = await self.scanner.get_cached_data() - - for model in cache.raw_data: - if model.get('file_path') == file_path: - return self.filter_civitai_data(model.get("civitai", {})) - - return None + """Load full metadata for a single model. + + Listing/search endpoints return lightweight cache entries; this method performs + a lazy read of the on-disk metadata snapshot when callers need full detail. + """ + metadata, should_skip = await MetadataManager.load_metadata(file_path, self.metadata_class) + if should_skip or metadata is None: + return None + return self.filter_civitai_data(metadata.to_dict().get("civitai", {})) + async def get_model_description(self, file_path: str) -> Optional[str]: - """Get model description by file path""" - cache = await self.scanner.get_cached_data() - - for model in cache.raw_data: - if model.get('file_path') == file_path: - return model.get('modelDescription', '') - - return None + """Return the stored modelDescription field for a model.""" + metadata, should_skip = await MetadataManager.load_metadata(file_path, self.metadata_class) + if should_skip or metadata is None: + return None + return metadata.modelDescription or '' + async def search_relative_paths(self, search_term: str, limit: int = 15) -> List[str]: """Search model relative file paths for autocomplete functionality""" diff --git a/py/services/model_scanner.py b/py/services/model_scanner.py index ccba3b2c..84f46237 100644 --- a/py/services/model_scanner.py +++ b/py/services/model_scanner.py @@ -5,7 +5,7 @@ import asyncio import time import shutil from dataclasses import dataclass -from typing import Awaitable, Callable, List, Dict, Optional, Type, Set +from typing import Any, Awaitable, Callable, Dict, List, Mapping, Optional, Set, Type, Union from ..utils.models import BaseModelMetadata from ..config import config @@ -17,6 +17,7 @@ from ..utils.constants import PREVIEW_EXTENSIONS from .model_lifecycle_service import delete_model_artifacts from .service_registry import ServiceRegistry from .websocket_manager import ws_manager +from .persistent_model_cache import get_persistent_cache logger = logging.getLogger(__name__) @@ -79,6 +80,7 @@ class ModelScanner: self._tags_count = {} # Dictionary to store tag counts self._is_initializing = False # Flag to track initialization state self._excluded_models = [] # List to track excluded models + self._persistent_cache = get_persistent_cache() self._initialized = True # Register this service @@ -89,6 +91,95 @@ class ModelScanner: service_name = f"{self.model_type}_scanner" await ServiceRegistry.register_service(service_name, self) + def _slim_civitai_payload(self, civitai: Optional[Mapping[str, Any]]) -> Optional[Dict[str, Any]]: + """Return a lightweight civitai payload containing only frequently used keys.""" + if not isinstance(civitai, Mapping) or not civitai: + return None + + slim: Dict[str, Any] = {} + for key in ('id', 'modelId', 'name'): + value = civitai.get(key) + if value not in (None, '', []): + slim[key] = value + + trained_words = civitai.get('trainedWords') + if trained_words: + slim['trainedWords'] = list(trained_words) if isinstance(trained_words, list) else trained_words + + return slim or None + + def _build_cache_entry( + self, + source: Union[BaseModelMetadata, Mapping[str, Any]], + *, + folder: Optional[str] = None, + file_path_override: Optional[str] = None + ) -> Dict[str, Any]: + """Project metadata into the lightweight cache representation.""" + is_mapping = isinstance(source, Mapping) + + def get_value(key: str, default: Any = None) -> Any: + if is_mapping: + return source.get(key, default) + return getattr(source, key, default) + + file_path = file_path_override or get_value('file_path', '') or '' + normalized_path = file_path.replace('\\', '/') + + folder_value = folder if folder is not None else get_value('folder', '') or '' + normalized_folder = folder_value.replace('\\', '/') + + tags_value = get_value('tags') or [] + if isinstance(tags_value, list): + tags_list = list(tags_value) + elif isinstance(tags_value, (set, tuple)): + tags_list = list(tags_value) + else: + tags_list = [] + + preview_url = get_value('preview_url', '') or '' + if isinstance(preview_url, str): + preview_url = preview_url.replace('\\', '/') + else: + preview_url = '' + + civitai_slim = self._slim_civitai_payload(get_value('civitai')) + usage_tips = get_value('usage_tips', '') or '' + if not isinstance(usage_tips, str): + usage_tips = str(usage_tips) + notes = get_value('notes', '') or '' + if not isinstance(notes, str): + notes = str(notes) + + entry: Dict[str, Any] = { + 'file_path': normalized_path, + 'file_name': get_value('file_name', '') or '', + 'model_name': get_value('model_name', '') or '', + 'folder': normalized_folder, + 'size': int(get_value('size', 0) or 0), + 'modified': float(get_value('modified', 0.0) or 0.0), + 'sha256': (get_value('sha256', '') or '').lower(), + 'base_model': get_value('base_model', '') or '', + 'preview_url': preview_url, + 'preview_nsfw_level': int(get_value('preview_nsfw_level', 0) or 0), + 'from_civitai': bool(get_value('from_civitai', True)), + 'favorite': bool(get_value('favorite', False)), + 'notes': notes, + 'usage_tips': usage_tips, + 'exclude': bool(get_value('exclude', False)), + 'db_checked': bool(get_value('db_checked', False)), + 'last_checked_at': float(get_value('last_checked_at', 0.0) or 0.0), + 'tags': tags_list, + 'civitai': civitai_slim, + 'civitai_deleted': bool(get_value('civitai_deleted', False)), + } + + model_type = get_value('model_type', None) + if model_type: + entry['model_type'] = model_type + + return entry + async def initialize_in_background(self) -> None: """Initialize cache in background using thread pool""" try: @@ -113,7 +204,9 @@ class ModelScanner: 'scanner_type': self.model_type, 'pageType': page_type }) - + + await self._load_persisted_cache(page_type) + # If cache loading failed, proceed with full scan await ws_manager.broadcast_init_progress({ 'stage': 'scan_folders', @@ -150,7 +243,8 @@ class ModelScanner: if scan_result: await self._apply_scan_result(scan_result) - + await self._save_persistent_cache(scan_result) + # Send final progress update await ws_manager.broadcast_init_progress({ 'stage': 'finalizing', @@ -179,6 +273,105 @@ class ModelScanner: # Always clear the initializing flag when done self._is_initializing = False + async def _load_persisted_cache(self, page_type: str) -> bool: + """Attempt to hydrate the in-memory cache from the SQLite snapshot.""" + if not getattr(self, '_persistent_cache', None): + return False + + loop = asyncio.get_event_loop() + try: + persisted = await loop.run_in_executor( + None, + self._persistent_cache.load_cache, + self.model_type + ) + except FileNotFoundError: + return False + except Exception as exc: + logger.debug("%s Scanner: Could not load persisted cache: %s", self.model_type.capitalize(), exc) + return False + + if not persisted or not persisted.raw_data: + return False + + hash_index = ModelHashIndex() + for sha_value, path in persisted.hash_rows: + if sha_value and path: + hash_index.add_entry(sha_value.lower(), path) + + tags_count: Dict[str, int] = {} + for item in persisted.raw_data: + for tag in item.get('tags') or []: + tags_count[tag] = tags_count.get(tag, 0) + 1 + + scan_result = CacheBuildResult( + raw_data=list(persisted.raw_data), + hash_index=hash_index, + tags_count=tags_count, + excluded_models=list(persisted.excluded_models) + ) + + await self._apply_scan_result(scan_result) + + await ws_manager.broadcast_init_progress({ + 'stage': 'loading_cache', + 'progress': 1, + 'details': f"Loaded cached {self.model_type} data from disk", + 'scanner_type': self.model_type, + 'pageType': page_type + }) + return True + + async def _save_persistent_cache(self, scan_result: CacheBuildResult) -> None: + if not scan_result or not getattr(self, '_persistent_cache', None): + return + + hash_snapshot = self._build_hash_index_snapshot(scan_result.hash_index) + loop = asyncio.get_event_loop() + try: + await loop.run_in_executor( + None, + self._persistent_cache.save_cache, + self.model_type, + list(scan_result.raw_data), + hash_snapshot, + list(scan_result.excluded_models) + ) + except Exception as exc: + logger.warning("%s Scanner: Failed to persist cache: %s", self.model_type.capitalize(), exc) + + def _build_hash_index_snapshot(self, hash_index: Optional[ModelHashIndex]) -> Dict[str, List[str]]: + snapshot: Dict[str, List[str]] = {} + if not hash_index: + return snapshot + + for sha_value, path in getattr(hash_index, '_hash_to_path', {}).items(): + if not sha_value or not path: + continue + bucket = snapshot.setdefault(sha_value.lower(), []) + if path not in bucket: + bucket.append(path) + + for sha_value, paths in getattr(hash_index, '_duplicate_hashes', {}).items(): + if not sha_value: + continue + bucket = snapshot.setdefault(sha_value.lower(), []) + for path in paths: + if path and path not in bucket: + bucket.append(path) + return snapshot + + async def _persist_current_cache(self) -> None: + if self._cache is None or not getattr(self, '_persistent_cache', None): + return + + snapshot = CacheBuildResult( + raw_data=list(self._cache.raw_data), + hash_index=self._hash_index, + tags_count=dict(self._tags_count), + excluded_models=list(self._excluded_models) + ) + await self._save_persistent_cache(snapshot) def _count_model_files(self) -> int: """Count all model files with supported extensions in all roots @@ -300,6 +493,7 @@ class ModelScanner: # Scan for new data scan_result = await self._gather_model_data() await self._apply_scan_result(scan_result) + await self._save_persistent_cache(scan_result) logger.info( f"{self.model_type.capitalize()} Scanner: Cache initialization completed in {time.time() - start_time:.2f} seconds, " @@ -594,8 +788,12 @@ class ModelScanner: # Hook: allow subclasses to adjust metadata metadata = self.adjust_metadata(metadata, file_path, root_path) - model_data = metadata.to_dict() - + rel_path = os.path.relpath(file_path, root_path) + folder = os.path.dirname(rel_path) + normalized_folder = folder.replace(os.path.sep, '/') + + model_data = self._build_cache_entry(metadata, folder=normalized_folder) + # Skip excluded models if model_data.get('exclude', False): excluded_models.append(model_data['file_path']) @@ -609,10 +807,6 @@ class ModelScanner: # if existing_path and existing_path != file_path: # logger.warning(f"Duplicate filename detected: '{filename}' - files: '{existing_path}' and '{file_path}'") - rel_path = os.path.relpath(file_path, root_path) - folder = os.path.dirname(rel_path) - model_data['folder'] = folder.replace(os.path.sep, '/') - return model_data async def _apply_scan_result(self, scan_result: CacheBuildResult) -> None: @@ -749,6 +943,7 @@ class ModelScanner: # Update the hash index self._hash_index.add_entry(metadata_dict['sha256'], metadata_dict['file_path']) + await self._persist_current_cache() return True except Exception as e: logger.error(f"Error adding model to cache: {e}") @@ -894,28 +1089,30 @@ class ModelScanner: ] if metadata: - if original_path == new_path: - existing_folder = next((item['folder'] for item in cache.raw_data - if item['file_path'] == original_path), None) - if existing_folder: - metadata['folder'] = existing_folder - else: - metadata['folder'] = self._calculate_folder(new_path) + normalized_new_path = new_path.replace(os.sep, '/') + if original_path == new_path and existing_item: + folder_value = existing_item.get('folder', self._calculate_folder(new_path)) else: - metadata['folder'] = self._calculate_folder(new_path) - - cache.raw_data.append(metadata) - - if 'sha256' in metadata: - self._hash_index.add_entry(metadata['sha256'].lower(), new_path) - + folder_value = self._calculate_folder(new_path) + + cache_entry = self._build_cache_entry( + metadata, + folder=folder_value, + file_path_override=normalized_new_path, + ) + + cache.raw_data.append(cache_entry) + + sha_value = cache_entry.get('sha256') + if sha_value: + self._hash_index.add_entry(sha_value.lower(), normalized_new_path) + all_folders = set(item['folder'] for item in cache.raw_data) cache.folders = sorted(list(all_folders), key=lambda x: x.lower()) - - if 'tags' in metadata: - for tag in metadata.get('tags', []): - self._tags_count[tag] = self._tags_count.get(tag, 0) + 1 - + + for tag in cache_entry.get('tags', []): + self._tags_count[tag] = self._tags_count.get(tag, 0) + 1 + await cache.resort() return True @@ -1019,7 +1216,10 @@ class ModelScanner: if self._cache is None: return False - return await self._cache.update_preview_url(file_path, preview_url, preview_nsfw_level) + updated = await self._cache.update_preview_url(file_path, preview_url, preview_nsfw_level) + if updated: + await self._persist_current_cache() + return updated async def bulk_delete_models(self, file_paths: List[str]) -> Dict: """Delete multiple models and update cache in a batch operation diff --git a/py/services/persistent_model_cache.py b/py/services/persistent_model_cache.py new file mode 100644 index 00000000..e473035b --- /dev/null +++ b/py/services/persistent_model_cache.py @@ -0,0 +1,346 @@ +import json +import logging +import os +import sqlite3 +import threading +from dataclasses import dataclass +from typing import Dict, List, Optional, Sequence, Tuple + +from ..utils.settings_paths import get_settings_dir + +logger = logging.getLogger(__name__) + + +@dataclass +class PersistedCacheData: + """Lightweight structure returned by the persistent cache.""" + + raw_data: List[Dict] + hash_rows: List[Tuple[str, str]] + excluded_models: List[str] + + +class PersistentModelCache: + """Persist core model metadata and hash index data in SQLite.""" + + _DEFAULT_FILENAME = "model_cache.sqlite" + _instance: Optional["PersistentModelCache"] = None + _instance_lock = threading.Lock() + + def __init__(self, db_path: Optional[str] = None) -> None: + self._db_path = db_path or self._resolve_default_path() + self._db_lock = threading.Lock() + self._schema_initialized = False + try: + directory = os.path.dirname(self._db_path) + if directory: + os.makedirs(directory, exist_ok=True) + except Exception as exc: # pragma: no cover - defensive guard + logger.warning("Could not create cache directory %s: %s", directory, exc) + if self.is_enabled(): + self._initialize_schema() + + @classmethod + def get_default(cls) -> "PersistentModelCache": + with cls._instance_lock: + if cls._instance is None: + cls._instance = cls() + return cls._instance + + def is_enabled(self) -> bool: + return os.environ.get("LORA_MANAGER_DISABLE_PERSISTENT_CACHE", "0") != "1" + + def load_cache(self, model_type: str) -> Optional[PersistedCacheData]: + if not self.is_enabled(): + return None + if not self._schema_initialized: + self._initialize_schema() + if not self._schema_initialized: + return None + try: + with self._db_lock: + conn = self._connect(readonly=True) + try: + rows = conn.execute( + "SELECT file_path, file_name, model_name, folder, size, modified, sha256, base_model," + " preview_url, preview_nsfw_level, from_civitai, favorite, notes, usage_tips," + " civitai_id, civitai_model_id, civitai_name, trained_words, exclude, db_checked," + " last_checked_at" + " FROM models WHERE model_type = ?", + (model_type,), + ).fetchall() + + if not rows: + return None + + tags = self._load_tags(conn, model_type) + hash_rows = conn.execute( + "SELECT sha256, file_path FROM hash_index WHERE model_type = ?", + (model_type,), + ).fetchall() + excluded = conn.execute( + "SELECT file_path FROM excluded_models WHERE model_type = ?", + (model_type,), + ).fetchall() + finally: + conn.close() + except Exception as exc: + logger.warning("Failed to load persisted cache for %s: %s", model_type, exc) + return None + + raw_data: List[Dict] = [] + for row in rows: + file_path: str = row["file_path"] + trained_words = [] + if row["trained_words"]: + try: + trained_words = json.loads(row["trained_words"]) + except json.JSONDecodeError: + trained_words = [] + + civitai: Optional[Dict] = None + if any(row[col] is not None for col in ("civitai_id", "civitai_model_id", "civitai_name")): + civitai = {} + if row["civitai_id"] is not None: + civitai["id"] = row["civitai_id"] + if row["civitai_model_id"] is not None: + civitai["modelId"] = row["civitai_model_id"] + if row["civitai_name"]: + civitai["name"] = row["civitai_name"] + if trained_words: + civitai["trainedWords"] = trained_words + + item = { + "file_path": file_path, + "file_name": row["file_name"], + "model_name": row["model_name"], + "folder": row["folder"] or "", + "size": row["size"] or 0, + "modified": row["modified"] or 0.0, + "sha256": row["sha256"] or "", + "base_model": row["base_model"] or "", + "preview_url": row["preview_url"] or "", + "preview_nsfw_level": row["preview_nsfw_level"] or 0, + "from_civitai": bool(row["from_civitai"]), + "favorite": bool(row["favorite"]), + "notes": row["notes"] or "", + "usage_tips": row["usage_tips"] or "", + "exclude": bool(row["exclude"]), + "db_checked": bool(row["db_checked"]), + "last_checked_at": row["last_checked_at"] or 0.0, + "tags": tags.get(file_path, []), + "civitai": civitai, + } + raw_data.append(item) + + hash_pairs = [(entry["sha256"].lower(), entry["file_path"]) for entry in hash_rows if entry["sha256"]] + if not hash_pairs: + # Fall back to hashes stored on the model rows + for item in raw_data: + sha_value = item.get("sha256") + if sha_value: + hash_pairs.append((sha_value.lower(), item["file_path"])) + + excluded_paths = [row["file_path"] for row in excluded] + return PersistedCacheData(raw_data=raw_data, hash_rows=hash_pairs, excluded_models=excluded_paths) + + def save_cache(self, model_type: str, raw_data: Sequence[Dict], hash_index: Dict[str, List[str]], excluded_models: Sequence[str]) -> None: + if not self.is_enabled(): + return + if not self._schema_initialized: + self._initialize_schema() + if not self._schema_initialized: + return + try: + with self._db_lock: + conn = self._connect() + try: + conn.execute("PRAGMA foreign_keys = ON") + conn.execute("DELETE FROM models WHERE model_type = ?", (model_type,)) + conn.execute("DELETE FROM model_tags WHERE model_type = ?", (model_type,)) + conn.execute("DELETE FROM hash_index WHERE model_type = ?", (model_type,)) + conn.execute("DELETE FROM excluded_models WHERE model_type = ?", (model_type,)) + + model_rows = [self._prepare_model_row(model_type, item) for item in raw_data] + conn.executemany(self._insert_model_sql(), model_rows) + + tag_rows = [] + for item in raw_data: + file_path = item.get("file_path") + if not file_path: + continue + for tag in item.get("tags") or []: + tag_rows.append((model_type, file_path, tag)) + if tag_rows: + conn.executemany( + "INSERT INTO model_tags (model_type, file_path, tag) VALUES (?, ?, ?)", + tag_rows, + ) + + hash_rows: List[Tuple[str, str, str]] = [] + for sha_value, paths in hash_index.items(): + for path in paths: + if not sha_value or not path: + continue + hash_rows.append((model_type, sha_value.lower(), path)) + if hash_rows: + conn.executemany( + "INSERT OR IGNORE INTO hash_index (model_type, sha256, file_path) VALUES (?, ?, ?)", + hash_rows, + ) + + excluded_rows = [(model_type, path) for path in excluded_models] + if excluded_rows: + conn.executemany( + "INSERT OR IGNORE INTO excluded_models (model_type, file_path) VALUES (?, ?)", + excluded_rows, + ) + conn.commit() + finally: + conn.close() + except Exception as exc: + logger.warning("Failed to persist cache for %s: %s", model_type, exc) + + # Internal helpers ------------------------------------------------- + + def _resolve_default_path(self) -> str: + override = os.environ.get("LORA_MANAGER_CACHE_DB") + if override: + return override + try: + settings_dir = get_settings_dir(create=True) + except Exception as exc: # pragma: no cover - defensive guard + logger.warning("Falling back to project directory for cache: %s", exc) + settings_dir = os.path.dirname(os.path.dirname(self._db_path)) if hasattr(self, "_db_path") else os.getcwd() + return os.path.join(settings_dir, self._DEFAULT_FILENAME) + + def _initialize_schema(self) -> None: + with self._db_lock: + if self._schema_initialized: + return + try: + with self._connect() as conn: + conn.execute("PRAGMA journal_mode=WAL") + conn.execute("PRAGMA foreign_keys = ON") + conn.executescript( + """ + CREATE TABLE IF NOT EXISTS models ( + model_type TEXT NOT NULL, + file_path TEXT NOT NULL, + file_name TEXT, + model_name TEXT, + folder TEXT, + size INTEGER, + modified REAL, + sha256 TEXT, + base_model TEXT, + preview_url TEXT, + preview_nsfw_level INTEGER, + from_civitai INTEGER, + favorite INTEGER, + notes TEXT, + usage_tips TEXT, + civitai_id INTEGER, + civitai_model_id INTEGER, + civitai_name TEXT, + trained_words TEXT, + exclude INTEGER, + db_checked INTEGER, + last_checked_at REAL, + PRIMARY KEY (model_type, file_path) + ); + + CREATE TABLE IF NOT EXISTS model_tags ( + model_type TEXT NOT NULL, + file_path TEXT NOT NULL, + tag TEXT NOT NULL, + PRIMARY KEY (model_type, file_path, tag) + ); + + CREATE TABLE IF NOT EXISTS hash_index ( + model_type TEXT NOT NULL, + sha256 TEXT NOT NULL, + file_path TEXT NOT NULL, + PRIMARY KEY (model_type, sha256, file_path) + ); + + CREATE TABLE IF NOT EXISTS excluded_models ( + model_type TEXT NOT NULL, + file_path TEXT NOT NULL, + PRIMARY KEY (model_type, file_path) + ); + """ + ) + conn.commit() + self._schema_initialized = True + except Exception as exc: # pragma: no cover - defensive guard + logger.warning("Failed to initialize persistent cache schema: %s", exc) + + def _connect(self, readonly: bool = False) -> sqlite3.Connection: + uri = False + path = self._db_path + if readonly: + if not os.path.exists(path): + raise FileNotFoundError(path) + path = f"file:{path}?mode=ro" + uri = True + conn = sqlite3.connect(path, check_same_thread=False, uri=uri, detect_types=sqlite3.PARSE_DECLTYPES) + conn.row_factory = sqlite3.Row + return conn + + def _prepare_model_row(self, model_type: str, item: Dict) -> Tuple: + civitai = item.get("civitai") or {} + trained_words = civitai.get("trainedWords") + if isinstance(trained_words, str): + trained_words_json = trained_words + elif trained_words is None: + trained_words_json = None + else: + trained_words_json = json.dumps(trained_words) + + return ( + model_type, + item.get("file_path"), + item.get("file_name"), + item.get("model_name"), + item.get("folder"), + int(item.get("size") or 0), + float(item.get("modified") or 0.0), + (item.get("sha256") or "").lower() or None, + item.get("base_model"), + item.get("preview_url"), + int(item.get("preview_nsfw_level") or 0), + 1 if item.get("from_civitai", True) else 0, + 1 if item.get("favorite") else 0, + item.get("notes"), + item.get("usage_tips"), + civitai.get("id"), + civitai.get("modelId"), + civitai.get("name"), + trained_words_json, + 1 if item.get("exclude") else 0, + 1 if item.get("db_checked") else 0, + float(item.get("last_checked_at") or 0.0), + ) + + def _insert_model_sql(self) -> str: + return ( + "INSERT INTO models (model_type, file_path, file_name, model_name, folder, size, modified, sha256," + " base_model, preview_url, preview_nsfw_level, from_civitai, favorite, notes, usage_tips," + " civitai_id, civitai_model_id, civitai_name, trained_words, exclude, db_checked, last_checked_at)" + " VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)" + ) + + def _load_tags(self, conn: sqlite3.Connection, model_type: str) -> Dict[str, List[str]]: + tag_rows = conn.execute( + "SELECT file_path, tag FROM model_tags WHERE model_type = ?", + (model_type,), + ).fetchall() + result: Dict[str, List[str]] = {} + for row in tag_rows: + result.setdefault(row["file_path"], []).append(row["tag"]) + return result + + +def get_persistent_cache() -> PersistentModelCache: + return PersistentModelCache.get_default() diff --git a/py/utils/example_images_download_manager.py b/py/utils/example_images_download_manager.py index 9ddf03a4..017fb658 100644 --- a/py/utils/example_images_download_manager.py +++ b/py/utils/example_images_download_manager.py @@ -345,14 +345,19 @@ class DownloadManager: self._progress['processed_models'].add(model_hash) return False # Return False to indicate no remote download happened + full_model = await MetadataUpdater.get_updated_model( + model_hash, scanner + ) + civitai_payload = (full_model or {}).get('civitai', {}) if full_model else {} + # If no local images, try to download from remote - elif model.get('civitai') and model.get('civitai', {}).get('images'): - images = model.get('civitai', {}).get('images', []) - + if civitai_payload.get('images'): + images = civitai_payload.get('images', []) + success, is_stale = await ExampleImagesProcessor.download_model_images( model_hash, model_name, images, model_dir, optimize, downloader ) - + # If metadata is stale, try to refresh it if is_stale and model_hash not in self._progress['refreshed_models']: await MetadataUpdater.refresh_model_metadata( @@ -363,16 +368,17 @@ class DownloadManager: updated_model = await MetadataUpdater.get_updated_model( model_hash, scanner ) + updated_civitai = (updated_model or {}).get('civitai', {}) if updated_model else {} - if updated_model and updated_model.get('civitai', {}).get('images'): + if updated_civitai.get('images'): # Retry download with updated metadata - updated_images = updated_model.get('civitai', {}).get('images', []) + updated_images = updated_civitai.get('images', []) success, _ = await ExampleImagesProcessor.download_model_images( model_hash, model_name, updated_images, model_dir, optimize, downloader ) self._progress['refreshed_models'].add(model_hash) - + # Mark as processed if successful, or as failed if unsuccessful after refresh if success: self._progress['processed_models'].add(model_hash) @@ -381,13 +387,13 @@ class DownloadManager: if model_hash in self._progress['refreshed_models']: self._progress['failed_models'].add(model_hash) logger.info(f"Marking model {model_name} as failed after metadata refresh") - + return True # Return True to indicate a remote download happened else: # No civitai data or images available, mark as failed to avoid future attempts self._progress['failed_models'].add(model_hash) logger.debug(f"No civitai images available for model {model_name}, marking as failed") - + # Save progress periodically if self._progress['completed'] % 10 == 0 or self._progress['completed'] == self._progress['total'] - 1: self._save_progress(output_dir) @@ -627,51 +633,59 @@ class DownloadManager: self._progress['processed_models'].add(model_hash) return False # Return False to indicate no remote download happened + full_model = await MetadataUpdater.get_updated_model( + model_hash, scanner + ) + civitai_payload = (full_model or {}).get('civitai', {}) if full_model else {} + # If no local images, try to download from remote - elif model.get('civitai') and model.get('civitai', {}).get('images'): - images = model.get('civitai', {}).get('images', []) - + if civitai_payload.get('images'): + images = civitai_payload.get('images', []) + success, is_stale, failed_images = await ExampleImagesProcessor.download_model_images_with_tracking( model_hash, model_name, images, model_dir, optimize, downloader ) - + # If metadata is stale, try to refresh it if is_stale and model_hash not in self._progress['refreshed_models']: await MetadataUpdater.refresh_model_metadata( model_hash, model_name, scanner_type, scanner, self._progress ) - + # Get the updated model data updated_model = await MetadataUpdater.get_updated_model( model_hash, scanner ) - - if updated_model and updated_model.get('civitai', {}).get('images'): + updated_civitai = (updated_model or {}).get('civitai', {}) if updated_model else {} + + if updated_civitai.get('images'): # Retry download with updated metadata - updated_images = updated_model.get('civitai', {}).get('images', []) + updated_images = updated_civitai.get('images', []) success, _, additional_failed_images = await ExampleImagesProcessor.download_model_images_with_tracking( model_hash, model_name, updated_images, model_dir, optimize, downloader ) - + # Combine failed images from both attempts failed_images.extend(additional_failed_images) - + self._progress['refreshed_models'].add(model_hash) - + # For forced downloads, remove failed images from metadata if failed_images: # Create a copy of images excluding failed ones await self._remove_failed_images_from_metadata( model_hash, model_name, failed_images, scanner ) - + # Mark as processed if success or failed_images: # Mark as processed if we successfully downloaded some images or removed failed ones self._progress['processed_models'].add(model_hash) - + return True # Return True to indicate a remote download happened else: logger.debug(f"No civitai images available for model {model_name}") + + return False except Exception as e: diff --git a/py/utils/example_images_metadata.py b/py/utils/example_images_metadata.py index 780eb43b..39b5ad73 100644 --- a/py/utils/example_images_metadata.py +++ b/py/utils/example_images_metadata.py @@ -95,21 +95,35 @@ class MetadataUpdater: @staticmethod async def get_updated_model(model_hash, scanner): - """Get updated model data - - Args: - model_hash: SHA256 hash of the model - scanner: Scanner instance - - Returns: - dict: Updated model data or None if not found - """ + """Load the most recent metadata for a model identified by hash.""" cache = await scanner.get_cached_data() + target = None for item in cache.raw_data: if item.get('sha256') == model_hash: - return item - return None - + target = item + break + + if not target: + return None + + file_path = target.get('file_path') + if not file_path: + return target + + model_cls = getattr(scanner, 'model_class', None) + if model_cls is None: + metadata, should_skip = await MetadataManager.load_metadata(file_path) + else: + metadata, should_skip = await MetadataManager.load_metadata(file_path, model_cls) + + if should_skip or metadata is None: + return target + + rich_metadata = metadata.to_dict() + rich_metadata.setdefault('folder', target.get('folder', '')) + return rich_metadata + + @staticmethod async def update_metadata_from_local_examples(model_hash, model, scanner_type, scanner, model_dir): """Update model metadata with local example image information diff --git a/tests/services/test_model_scanner.py b/tests/services/test_model_scanner.py index 7eff5e78..2f2bd1b4 100644 --- a/tests/services/test_model_scanner.py +++ b/tests/services/test_model_scanner.py @@ -9,6 +9,7 @@ from py.services import model_scanner from py.services.model_cache import ModelCache from py.services.model_hash_index import ModelHashIndex from py.services.model_scanner import CacheBuildResult, ModelScanner +from py.services.persistent_model_cache import PersistentModelCache from py.utils.models import BaseModelMetadata @@ -78,6 +79,11 @@ def reset_model_scanner_singletons(): ModelScanner._locks.clear() +@pytest.fixture(autouse=True) +def disable_persistent_cache_env(monkeypatch): + monkeypatch.setenv('LORA_MANAGER_DISABLE_PERSISTENT_CACHE', '1') + + @pytest.fixture(autouse=True) def stub_register_service(monkeypatch): async def noop(*_args, **_kwargs): @@ -175,3 +181,60 @@ async def test_initialize_in_background_applies_scan_result(tmp_path: Path, monk assert scanner._tags_count == {"alpha": 1, "beta": 1} assert scanner._excluded_models == [_normalize_path(tmp_path / "skip-file.txt")] assert ws_stub.payloads[-1]["progress"] == 100 + + + +@pytest.mark.asyncio +async def test_load_persisted_cache_populates_cache(tmp_path: Path, monkeypatch): + # Enable persistence for this specific test and back it with a temp database + monkeypatch.setenv('LORA_MANAGER_DISABLE_PERSISTENT_CACHE', '0') + db_path = tmp_path / 'cache.sqlite' + store = PersistentModelCache(db_path=str(db_path)) + + file_path = tmp_path / 'one.txt' + file_path.write_text('one', encoding='utf-8') + normalized = _normalize_path(file_path) + + raw_model = { + 'file_path': normalized, + 'file_name': 'one', + 'model_name': 'one', + 'folder': '', + 'size': 3, + 'modified': 123.0, + 'sha256': 'hash-one', + 'base_model': 'test', + 'preview_url': '', + 'preview_nsfw_level': 0, + 'from_civitai': True, + 'favorite': False, + 'notes': '', + 'usage_tips': '', + 'exclude': False, + 'db_checked': False, + 'last_checked_at': 0.0, + 'tags': ['alpha'], + 'civitai': {'id': 11, 'modelId': 22, 'name': 'ver', 'trainedWords': ['abc']}, + } + + store.save_cache('dummy', [raw_model], {'hash-one': [normalized]}, []) + + monkeypatch.setattr(model_scanner, 'get_persistent_cache', lambda: store) + + scanner = DummyScanner(tmp_path) + ws_stub = RecordingWebSocketManager() + monkeypatch.setattr(model_scanner, 'ws_manager', ws_stub) + + loaded = await scanner._load_persisted_cache('dummy') + assert loaded is True + + cache = await scanner.get_cached_data() + assert len(cache.raw_data) == 1 + entry = cache.raw_data[0] + assert entry['file_path'] == normalized + assert entry['tags'] == ['alpha'] + assert entry['civitai']['trainedWords'] == ['abc'] + assert scanner._hash_index.get_path('hash-one') == normalized + assert scanner._tags_count == {'alpha': 1} + assert ws_stub.payloads[-1]['stage'] == 'loading_cache' + assert ws_stub.payloads[-1]['progress'] == 1 diff --git a/tests/services/test_persistent_model_cache.py b/tests/services/test_persistent_model_cache.py new file mode 100644 index 00000000..52e53a9e --- /dev/null +++ b/tests/services/test_persistent_model_cache.py @@ -0,0 +1,92 @@ +from pathlib import Path + +import pytest + +from py.services.persistent_model_cache import PersistentModelCache + + +def test_persistent_cache_roundtrip(tmp_path: Path, monkeypatch) -> None: + monkeypatch.setenv('LORA_MANAGER_DISABLE_PERSISTENT_CACHE', '0') + db_path = tmp_path / 'cache.sqlite' + store = PersistentModelCache(db_path=str(db_path)) + + file_a = (tmp_path / 'a.txt').as_posix() + file_b = (tmp_path / 'b.txt').as_posix() + duplicate_path = f"{file_b}.copy" + + raw_data = [ + { + 'file_path': file_a, + 'file_name': 'a', + 'model_name': 'Model A', + 'folder': '', + 'size': 10, + 'modified': 100.0, + 'sha256': 'hash-a', + 'base_model': 'base', + 'preview_url': 'preview/a.png', + 'preview_nsfw_level': 1, + 'from_civitai': True, + 'favorite': True, + 'notes': 'note', + 'usage_tips': '{}', + 'exclude': False, + 'db_checked': True, + 'last_checked_at': 200.0, + 'tags': ['alpha', 'beta'], + 'civitai': {'id': 1, 'modelId': 2, 'name': 'verA', 'trainedWords': ['word1']}, + }, + { + 'file_path': file_b, + 'file_name': 'b', + 'model_name': 'Model B', + 'folder': 'folder', + 'size': 20, + 'modified': 120.0, + 'sha256': 'hash-b', + 'base_model': '', + 'preview_url': '', + 'preview_nsfw_level': 0, + 'from_civitai': False, + 'favorite': False, + 'notes': '', + 'usage_tips': '', + 'exclude': True, + 'db_checked': False, + 'last_checked_at': 0.0, + 'tags': [], + 'civitai': None, + }, + ] + + hash_index = { + 'hash-a': [file_a], + 'hash-b': [file_b, duplicate_path], + } + excluded = [duplicate_path] + + store.save_cache('dummy', raw_data, hash_index, excluded) + + persisted = store.load_cache('dummy') + assert persisted is not None + assert len(persisted.raw_data) == 2 + + items = {item['file_path']: item for item in persisted.raw_data} + assert set(items.keys()) == {file_a, file_b} + first = items[file_a] + assert first['favorite'] is True + assert first['civitai']['id'] == 1 + assert first['civitai']['trainedWords'] == ['word1'] + assert first['tags'] == ['alpha', 'beta'] + + second = items[file_b] + assert second['exclude'] is True + assert second['civitai'] is None + + expected_hash_pairs = { + ('hash-a', file_a), + ('hash-b', file_b), + ('hash-b', duplicate_path), + } + assert set((sha, path) for sha, path in persisted.hash_rows) == expected_hash_pairs + assert persisted.excluded_models == excluded