feat(persistent-cache): implement SQLite-based persistent model cache with loading and saving functionality

This commit is contained in:
Will Miao
2025-10-03 11:00:51 +08:00
parent 3b1990e97a
commit 77bbf85b52
7 changed files with 809 additions and 79 deletions

View File

@@ -4,6 +4,7 @@ import logging
import os
from ..utils.models import BaseModelMetadata
from ..utils.metadata_manager import MetadataManager
from .model_query import FilterCriteria, ModelCacheRepository, ModelFilterSet, SearchStrategy, SettingsProvider
from .settings_manager import settings as default_settings
@@ -313,24 +314,24 @@ class BaseModelService(ABC):
return {'civitai_url': None, 'model_id': None, 'version_id': None}
async def get_model_metadata(self, file_path: str) -> Optional[Dict]:
"""Get filtered CivitAI metadata for a model by file path"""
cache = await self.scanner.get_cached_data()
for model in cache.raw_data:
if model.get('file_path') == file_path:
return self.filter_civitai_data(model.get("civitai", {}))
return None
"""Load full metadata for a single model.
Listing/search endpoints return lightweight cache entries; this method performs
a lazy read of the on-disk metadata snapshot when callers need full detail.
"""
metadata, should_skip = await MetadataManager.load_metadata(file_path, self.metadata_class)
if should_skip or metadata is None:
return None
return self.filter_civitai_data(metadata.to_dict().get("civitai", {}))
async def get_model_description(self, file_path: str) -> Optional[str]:
"""Get model description by file path"""
cache = await self.scanner.get_cached_data()
for model in cache.raw_data:
if model.get('file_path') == file_path:
return model.get('modelDescription', '')
return None
"""Return the stored modelDescription field for a model."""
metadata, should_skip = await MetadataManager.load_metadata(file_path, self.metadata_class)
if should_skip or metadata is None:
return None
return metadata.modelDescription or ''
async def search_relative_paths(self, search_term: str, limit: int = 15) -> List[str]:
"""Search model relative file paths for autocomplete functionality"""

View File

@@ -5,7 +5,7 @@ import asyncio
import time
import shutil
from dataclasses import dataclass
from typing import Awaitable, Callable, List, Dict, Optional, Type, Set
from typing import Any, Awaitable, Callable, Dict, List, Mapping, Optional, Set, Type, Union
from ..utils.models import BaseModelMetadata
from ..config import config
@@ -17,6 +17,7 @@ from ..utils.constants import PREVIEW_EXTENSIONS
from .model_lifecycle_service import delete_model_artifacts
from .service_registry import ServiceRegistry
from .websocket_manager import ws_manager
from .persistent_model_cache import get_persistent_cache
logger = logging.getLogger(__name__)
@@ -79,6 +80,7 @@ class ModelScanner:
self._tags_count = {} # Dictionary to store tag counts
self._is_initializing = False # Flag to track initialization state
self._excluded_models = [] # List to track excluded models
self._persistent_cache = get_persistent_cache()
self._initialized = True
# Register this service
@@ -89,6 +91,95 @@ class ModelScanner:
service_name = f"{self.model_type}_scanner"
await ServiceRegistry.register_service(service_name, self)
def _slim_civitai_payload(self, civitai: Optional[Mapping[str, Any]]) -> Optional[Dict[str, Any]]:
"""Return a lightweight civitai payload containing only frequently used keys."""
if not isinstance(civitai, Mapping) or not civitai:
return None
slim: Dict[str, Any] = {}
for key in ('id', 'modelId', 'name'):
value = civitai.get(key)
if value not in (None, '', []):
slim[key] = value
trained_words = civitai.get('trainedWords')
if trained_words:
slim['trainedWords'] = list(trained_words) if isinstance(trained_words, list) else trained_words
return slim or None
def _build_cache_entry(
self,
source: Union[BaseModelMetadata, Mapping[str, Any]],
*,
folder: Optional[str] = None,
file_path_override: Optional[str] = None
) -> Dict[str, Any]:
"""Project metadata into the lightweight cache representation."""
is_mapping = isinstance(source, Mapping)
def get_value(key: str, default: Any = None) -> Any:
if is_mapping:
return source.get(key, default)
return getattr(source, key, default)
file_path = file_path_override or get_value('file_path', '') or ''
normalized_path = file_path.replace('\\', '/')
folder_value = folder if folder is not None else get_value('folder', '') or ''
normalized_folder = folder_value.replace('\\', '/')
tags_value = get_value('tags') or []
if isinstance(tags_value, list):
tags_list = list(tags_value)
elif isinstance(tags_value, (set, tuple)):
tags_list = list(tags_value)
else:
tags_list = []
preview_url = get_value('preview_url', '') or ''
if isinstance(preview_url, str):
preview_url = preview_url.replace('\\', '/')
else:
preview_url = ''
civitai_slim = self._slim_civitai_payload(get_value('civitai'))
usage_tips = get_value('usage_tips', '') or ''
if not isinstance(usage_tips, str):
usage_tips = str(usage_tips)
notes = get_value('notes', '') or ''
if not isinstance(notes, str):
notes = str(notes)
entry: Dict[str, Any] = {
'file_path': normalized_path,
'file_name': get_value('file_name', '') or '',
'model_name': get_value('model_name', '') or '',
'folder': normalized_folder,
'size': int(get_value('size', 0) or 0),
'modified': float(get_value('modified', 0.0) or 0.0),
'sha256': (get_value('sha256', '') or '').lower(),
'base_model': get_value('base_model', '') or '',
'preview_url': preview_url,
'preview_nsfw_level': int(get_value('preview_nsfw_level', 0) or 0),
'from_civitai': bool(get_value('from_civitai', True)),
'favorite': bool(get_value('favorite', False)),
'notes': notes,
'usage_tips': usage_tips,
'exclude': bool(get_value('exclude', False)),
'db_checked': bool(get_value('db_checked', False)),
'last_checked_at': float(get_value('last_checked_at', 0.0) or 0.0),
'tags': tags_list,
'civitai': civitai_slim,
'civitai_deleted': bool(get_value('civitai_deleted', False)),
}
model_type = get_value('model_type', None)
if model_type:
entry['model_type'] = model_type
return entry
async def initialize_in_background(self) -> None:
"""Initialize cache in background using thread pool"""
try:
@@ -113,7 +204,9 @@ class ModelScanner:
'scanner_type': self.model_type,
'pageType': page_type
})
await self._load_persisted_cache(page_type)
# If cache loading failed, proceed with full scan
await ws_manager.broadcast_init_progress({
'stage': 'scan_folders',
@@ -150,7 +243,8 @@ class ModelScanner:
if scan_result:
await self._apply_scan_result(scan_result)
await self._save_persistent_cache(scan_result)
# Send final progress update
await ws_manager.broadcast_init_progress({
'stage': 'finalizing',
@@ -179,6 +273,105 @@ class ModelScanner:
# Always clear the initializing flag when done
self._is_initializing = False
async def _load_persisted_cache(self, page_type: str) -> bool:
"""Attempt to hydrate the in-memory cache from the SQLite snapshot."""
if not getattr(self, '_persistent_cache', None):
return False
loop = asyncio.get_event_loop()
try:
persisted = await loop.run_in_executor(
None,
self._persistent_cache.load_cache,
self.model_type
)
except FileNotFoundError:
return False
except Exception as exc:
logger.debug("%s Scanner: Could not load persisted cache: %s", self.model_type.capitalize(), exc)
return False
if not persisted or not persisted.raw_data:
return False
hash_index = ModelHashIndex()
for sha_value, path in persisted.hash_rows:
if sha_value and path:
hash_index.add_entry(sha_value.lower(), path)
tags_count: Dict[str, int] = {}
for item in persisted.raw_data:
for tag in item.get('tags') or []:
tags_count[tag] = tags_count.get(tag, 0) + 1
scan_result = CacheBuildResult(
raw_data=list(persisted.raw_data),
hash_index=hash_index,
tags_count=tags_count,
excluded_models=list(persisted.excluded_models)
)
await self._apply_scan_result(scan_result)
await ws_manager.broadcast_init_progress({
'stage': 'loading_cache',
'progress': 1,
'details': f"Loaded cached {self.model_type} data from disk",
'scanner_type': self.model_type,
'pageType': page_type
})
return True
async def _save_persistent_cache(self, scan_result: CacheBuildResult) -> None:
if not scan_result or not getattr(self, '_persistent_cache', None):
return
hash_snapshot = self._build_hash_index_snapshot(scan_result.hash_index)
loop = asyncio.get_event_loop()
try:
await loop.run_in_executor(
None,
self._persistent_cache.save_cache,
self.model_type,
list(scan_result.raw_data),
hash_snapshot,
list(scan_result.excluded_models)
)
except Exception as exc:
logger.warning("%s Scanner: Failed to persist cache: %s", self.model_type.capitalize(), exc)
def _build_hash_index_snapshot(self, hash_index: Optional[ModelHashIndex]) -> Dict[str, List[str]]:
snapshot: Dict[str, List[str]] = {}
if not hash_index:
return snapshot
for sha_value, path in getattr(hash_index, '_hash_to_path', {}).items():
if not sha_value or not path:
continue
bucket = snapshot.setdefault(sha_value.lower(), [])
if path not in bucket:
bucket.append(path)
for sha_value, paths in getattr(hash_index, '_duplicate_hashes', {}).items():
if not sha_value:
continue
bucket = snapshot.setdefault(sha_value.lower(), [])
for path in paths:
if path and path not in bucket:
bucket.append(path)
return snapshot
async def _persist_current_cache(self) -> None:
if self._cache is None or not getattr(self, '_persistent_cache', None):
return
snapshot = CacheBuildResult(
raw_data=list(self._cache.raw_data),
hash_index=self._hash_index,
tags_count=dict(self._tags_count),
excluded_models=list(self._excluded_models)
)
await self._save_persistent_cache(snapshot)
def _count_model_files(self) -> int:
"""Count all model files with supported extensions in all roots
@@ -300,6 +493,7 @@ class ModelScanner:
# Scan for new data
scan_result = await self._gather_model_data()
await self._apply_scan_result(scan_result)
await self._save_persistent_cache(scan_result)
logger.info(
f"{self.model_type.capitalize()} Scanner: Cache initialization completed in {time.time() - start_time:.2f} seconds, "
@@ -594,8 +788,12 @@ class ModelScanner:
# Hook: allow subclasses to adjust metadata
metadata = self.adjust_metadata(metadata, file_path, root_path)
model_data = metadata.to_dict()
rel_path = os.path.relpath(file_path, root_path)
folder = os.path.dirname(rel_path)
normalized_folder = folder.replace(os.path.sep, '/')
model_data = self._build_cache_entry(metadata, folder=normalized_folder)
# Skip excluded models
if model_data.get('exclude', False):
excluded_models.append(model_data['file_path'])
@@ -609,10 +807,6 @@ class ModelScanner:
# if existing_path and existing_path != file_path:
# logger.warning(f"Duplicate filename detected: '{filename}' - files: '{existing_path}' and '{file_path}'")
rel_path = os.path.relpath(file_path, root_path)
folder = os.path.dirname(rel_path)
model_data['folder'] = folder.replace(os.path.sep, '/')
return model_data
async def _apply_scan_result(self, scan_result: CacheBuildResult) -> None:
@@ -749,6 +943,7 @@ class ModelScanner:
# Update the hash index
self._hash_index.add_entry(metadata_dict['sha256'], metadata_dict['file_path'])
await self._persist_current_cache()
return True
except Exception as e:
logger.error(f"Error adding model to cache: {e}")
@@ -894,28 +1089,30 @@ class ModelScanner:
]
if metadata:
if original_path == new_path:
existing_folder = next((item['folder'] for item in cache.raw_data
if item['file_path'] == original_path), None)
if existing_folder:
metadata['folder'] = existing_folder
else:
metadata['folder'] = self._calculate_folder(new_path)
normalized_new_path = new_path.replace(os.sep, '/')
if original_path == new_path and existing_item:
folder_value = existing_item.get('folder', self._calculate_folder(new_path))
else:
metadata['folder'] = self._calculate_folder(new_path)
cache.raw_data.append(metadata)
if 'sha256' in metadata:
self._hash_index.add_entry(metadata['sha256'].lower(), new_path)
folder_value = self._calculate_folder(new_path)
cache_entry = self._build_cache_entry(
metadata,
folder=folder_value,
file_path_override=normalized_new_path,
)
cache.raw_data.append(cache_entry)
sha_value = cache_entry.get('sha256')
if sha_value:
self._hash_index.add_entry(sha_value.lower(), normalized_new_path)
all_folders = set(item['folder'] for item in cache.raw_data)
cache.folders = sorted(list(all_folders), key=lambda x: x.lower())
if 'tags' in metadata:
for tag in metadata.get('tags', []):
self._tags_count[tag] = self._tags_count.get(tag, 0) + 1
for tag in cache_entry.get('tags', []):
self._tags_count[tag] = self._tags_count.get(tag, 0) + 1
await cache.resort()
return True
@@ -1019,7 +1216,10 @@ class ModelScanner:
if self._cache is None:
return False
return await self._cache.update_preview_url(file_path, preview_url, preview_nsfw_level)
updated = await self._cache.update_preview_url(file_path, preview_url, preview_nsfw_level)
if updated:
await self._persist_current_cache()
return updated
async def bulk_delete_models(self, file_paths: List[str]) -> Dict:
"""Delete multiple models and update cache in a batch operation

View File

@@ -0,0 +1,346 @@
import json
import logging
import os
import sqlite3
import threading
from dataclasses import dataclass
from typing import Dict, List, Optional, Sequence, Tuple
from ..utils.settings_paths import get_settings_dir
logger = logging.getLogger(__name__)
@dataclass
class PersistedCacheData:
"""Lightweight structure returned by the persistent cache."""
raw_data: List[Dict]
hash_rows: List[Tuple[str, str]]
excluded_models: List[str]
class PersistentModelCache:
"""Persist core model metadata and hash index data in SQLite."""
_DEFAULT_FILENAME = "model_cache.sqlite"
_instance: Optional["PersistentModelCache"] = None
_instance_lock = threading.Lock()
def __init__(self, db_path: Optional[str] = None) -> None:
self._db_path = db_path or self._resolve_default_path()
self._db_lock = threading.Lock()
self._schema_initialized = False
try:
directory = os.path.dirname(self._db_path)
if directory:
os.makedirs(directory, exist_ok=True)
except Exception as exc: # pragma: no cover - defensive guard
logger.warning("Could not create cache directory %s: %s", directory, exc)
if self.is_enabled():
self._initialize_schema()
@classmethod
def get_default(cls) -> "PersistentModelCache":
with cls._instance_lock:
if cls._instance is None:
cls._instance = cls()
return cls._instance
def is_enabled(self) -> bool:
return os.environ.get("LORA_MANAGER_DISABLE_PERSISTENT_CACHE", "0") != "1"
def load_cache(self, model_type: str) -> Optional[PersistedCacheData]:
if not self.is_enabled():
return None
if not self._schema_initialized:
self._initialize_schema()
if not self._schema_initialized:
return None
try:
with self._db_lock:
conn = self._connect(readonly=True)
try:
rows = conn.execute(
"SELECT file_path, file_name, model_name, folder, size, modified, sha256, base_model,"
" preview_url, preview_nsfw_level, from_civitai, favorite, notes, usage_tips,"
" civitai_id, civitai_model_id, civitai_name, trained_words, exclude, db_checked,"
" last_checked_at"
" FROM models WHERE model_type = ?",
(model_type,),
).fetchall()
if not rows:
return None
tags = self._load_tags(conn, model_type)
hash_rows = conn.execute(
"SELECT sha256, file_path FROM hash_index WHERE model_type = ?",
(model_type,),
).fetchall()
excluded = conn.execute(
"SELECT file_path FROM excluded_models WHERE model_type = ?",
(model_type,),
).fetchall()
finally:
conn.close()
except Exception as exc:
logger.warning("Failed to load persisted cache for %s: %s", model_type, exc)
return None
raw_data: List[Dict] = []
for row in rows:
file_path: str = row["file_path"]
trained_words = []
if row["trained_words"]:
try:
trained_words = json.loads(row["trained_words"])
except json.JSONDecodeError:
trained_words = []
civitai: Optional[Dict] = None
if any(row[col] is not None for col in ("civitai_id", "civitai_model_id", "civitai_name")):
civitai = {}
if row["civitai_id"] is not None:
civitai["id"] = row["civitai_id"]
if row["civitai_model_id"] is not None:
civitai["modelId"] = row["civitai_model_id"]
if row["civitai_name"]:
civitai["name"] = row["civitai_name"]
if trained_words:
civitai["trainedWords"] = trained_words
item = {
"file_path": file_path,
"file_name": row["file_name"],
"model_name": row["model_name"],
"folder": row["folder"] or "",
"size": row["size"] or 0,
"modified": row["modified"] or 0.0,
"sha256": row["sha256"] or "",
"base_model": row["base_model"] or "",
"preview_url": row["preview_url"] or "",
"preview_nsfw_level": row["preview_nsfw_level"] or 0,
"from_civitai": bool(row["from_civitai"]),
"favorite": bool(row["favorite"]),
"notes": row["notes"] or "",
"usage_tips": row["usage_tips"] or "",
"exclude": bool(row["exclude"]),
"db_checked": bool(row["db_checked"]),
"last_checked_at": row["last_checked_at"] or 0.0,
"tags": tags.get(file_path, []),
"civitai": civitai,
}
raw_data.append(item)
hash_pairs = [(entry["sha256"].lower(), entry["file_path"]) for entry in hash_rows if entry["sha256"]]
if not hash_pairs:
# Fall back to hashes stored on the model rows
for item in raw_data:
sha_value = item.get("sha256")
if sha_value:
hash_pairs.append((sha_value.lower(), item["file_path"]))
excluded_paths = [row["file_path"] for row in excluded]
return PersistedCacheData(raw_data=raw_data, hash_rows=hash_pairs, excluded_models=excluded_paths)
def save_cache(self, model_type: str, raw_data: Sequence[Dict], hash_index: Dict[str, List[str]], excluded_models: Sequence[str]) -> None:
if not self.is_enabled():
return
if not self._schema_initialized:
self._initialize_schema()
if not self._schema_initialized:
return
try:
with self._db_lock:
conn = self._connect()
try:
conn.execute("PRAGMA foreign_keys = ON")
conn.execute("DELETE FROM models WHERE model_type = ?", (model_type,))
conn.execute("DELETE FROM model_tags WHERE model_type = ?", (model_type,))
conn.execute("DELETE FROM hash_index WHERE model_type = ?", (model_type,))
conn.execute("DELETE FROM excluded_models WHERE model_type = ?", (model_type,))
model_rows = [self._prepare_model_row(model_type, item) for item in raw_data]
conn.executemany(self._insert_model_sql(), model_rows)
tag_rows = []
for item in raw_data:
file_path = item.get("file_path")
if not file_path:
continue
for tag in item.get("tags") or []:
tag_rows.append((model_type, file_path, tag))
if tag_rows:
conn.executemany(
"INSERT INTO model_tags (model_type, file_path, tag) VALUES (?, ?, ?)",
tag_rows,
)
hash_rows: List[Tuple[str, str, str]] = []
for sha_value, paths in hash_index.items():
for path in paths:
if not sha_value or not path:
continue
hash_rows.append((model_type, sha_value.lower(), path))
if hash_rows:
conn.executemany(
"INSERT OR IGNORE INTO hash_index (model_type, sha256, file_path) VALUES (?, ?, ?)",
hash_rows,
)
excluded_rows = [(model_type, path) for path in excluded_models]
if excluded_rows:
conn.executemany(
"INSERT OR IGNORE INTO excluded_models (model_type, file_path) VALUES (?, ?)",
excluded_rows,
)
conn.commit()
finally:
conn.close()
except Exception as exc:
logger.warning("Failed to persist cache for %s: %s", model_type, exc)
# Internal helpers -------------------------------------------------
def _resolve_default_path(self) -> str:
override = os.environ.get("LORA_MANAGER_CACHE_DB")
if override:
return override
try:
settings_dir = get_settings_dir(create=True)
except Exception as exc: # pragma: no cover - defensive guard
logger.warning("Falling back to project directory for cache: %s", exc)
settings_dir = os.path.dirname(os.path.dirname(self._db_path)) if hasattr(self, "_db_path") else os.getcwd()
return os.path.join(settings_dir, self._DEFAULT_FILENAME)
def _initialize_schema(self) -> None:
with self._db_lock:
if self._schema_initialized:
return
try:
with self._connect() as conn:
conn.execute("PRAGMA journal_mode=WAL")
conn.execute("PRAGMA foreign_keys = ON")
conn.executescript(
"""
CREATE TABLE IF NOT EXISTS models (
model_type TEXT NOT NULL,
file_path TEXT NOT NULL,
file_name TEXT,
model_name TEXT,
folder TEXT,
size INTEGER,
modified REAL,
sha256 TEXT,
base_model TEXT,
preview_url TEXT,
preview_nsfw_level INTEGER,
from_civitai INTEGER,
favorite INTEGER,
notes TEXT,
usage_tips TEXT,
civitai_id INTEGER,
civitai_model_id INTEGER,
civitai_name TEXT,
trained_words TEXT,
exclude INTEGER,
db_checked INTEGER,
last_checked_at REAL,
PRIMARY KEY (model_type, file_path)
);
CREATE TABLE IF NOT EXISTS model_tags (
model_type TEXT NOT NULL,
file_path TEXT NOT NULL,
tag TEXT NOT NULL,
PRIMARY KEY (model_type, file_path, tag)
);
CREATE TABLE IF NOT EXISTS hash_index (
model_type TEXT NOT NULL,
sha256 TEXT NOT NULL,
file_path TEXT NOT NULL,
PRIMARY KEY (model_type, sha256, file_path)
);
CREATE TABLE IF NOT EXISTS excluded_models (
model_type TEXT NOT NULL,
file_path TEXT NOT NULL,
PRIMARY KEY (model_type, file_path)
);
"""
)
conn.commit()
self._schema_initialized = True
except Exception as exc: # pragma: no cover - defensive guard
logger.warning("Failed to initialize persistent cache schema: %s", exc)
def _connect(self, readonly: bool = False) -> sqlite3.Connection:
uri = False
path = self._db_path
if readonly:
if not os.path.exists(path):
raise FileNotFoundError(path)
path = f"file:{path}?mode=ro"
uri = True
conn = sqlite3.connect(path, check_same_thread=False, uri=uri, detect_types=sqlite3.PARSE_DECLTYPES)
conn.row_factory = sqlite3.Row
return conn
def _prepare_model_row(self, model_type: str, item: Dict) -> Tuple:
civitai = item.get("civitai") or {}
trained_words = civitai.get("trainedWords")
if isinstance(trained_words, str):
trained_words_json = trained_words
elif trained_words is None:
trained_words_json = None
else:
trained_words_json = json.dumps(trained_words)
return (
model_type,
item.get("file_path"),
item.get("file_name"),
item.get("model_name"),
item.get("folder"),
int(item.get("size") or 0),
float(item.get("modified") or 0.0),
(item.get("sha256") or "").lower() or None,
item.get("base_model"),
item.get("preview_url"),
int(item.get("preview_nsfw_level") or 0),
1 if item.get("from_civitai", True) else 0,
1 if item.get("favorite") else 0,
item.get("notes"),
item.get("usage_tips"),
civitai.get("id"),
civitai.get("modelId"),
civitai.get("name"),
trained_words_json,
1 if item.get("exclude") else 0,
1 if item.get("db_checked") else 0,
float(item.get("last_checked_at") or 0.0),
)
def _insert_model_sql(self) -> str:
return (
"INSERT INTO models (model_type, file_path, file_name, model_name, folder, size, modified, sha256,"
" base_model, preview_url, preview_nsfw_level, from_civitai, favorite, notes, usage_tips,"
" civitai_id, civitai_model_id, civitai_name, trained_words, exclude, db_checked, last_checked_at)"
" VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)"
)
def _load_tags(self, conn: sqlite3.Connection, model_type: str) -> Dict[str, List[str]]:
tag_rows = conn.execute(
"SELECT file_path, tag FROM model_tags WHERE model_type = ?",
(model_type,),
).fetchall()
result: Dict[str, List[str]] = {}
for row in tag_rows:
result.setdefault(row["file_path"], []).append(row["tag"])
return result
def get_persistent_cache() -> PersistentModelCache:
return PersistentModelCache.get_default()

View File

@@ -345,14 +345,19 @@ class DownloadManager:
self._progress['processed_models'].add(model_hash)
return False # Return False to indicate no remote download happened
full_model = await MetadataUpdater.get_updated_model(
model_hash, scanner
)
civitai_payload = (full_model or {}).get('civitai', {}) if full_model else {}
# If no local images, try to download from remote
elif model.get('civitai') and model.get('civitai', {}).get('images'):
images = model.get('civitai', {}).get('images', [])
if civitai_payload.get('images'):
images = civitai_payload.get('images', [])
success, is_stale = await ExampleImagesProcessor.download_model_images(
model_hash, model_name, images, model_dir, optimize, downloader
)
# If metadata is stale, try to refresh it
if is_stale and model_hash not in self._progress['refreshed_models']:
await MetadataUpdater.refresh_model_metadata(
@@ -363,16 +368,17 @@ class DownloadManager:
updated_model = await MetadataUpdater.get_updated_model(
model_hash, scanner
)
updated_civitai = (updated_model or {}).get('civitai', {}) if updated_model else {}
if updated_model and updated_model.get('civitai', {}).get('images'):
if updated_civitai.get('images'):
# Retry download with updated metadata
updated_images = updated_model.get('civitai', {}).get('images', [])
updated_images = updated_civitai.get('images', [])
success, _ = await ExampleImagesProcessor.download_model_images(
model_hash, model_name, updated_images, model_dir, optimize, downloader
)
self._progress['refreshed_models'].add(model_hash)
# Mark as processed if successful, or as failed if unsuccessful after refresh
if success:
self._progress['processed_models'].add(model_hash)
@@ -381,13 +387,13 @@ class DownloadManager:
if model_hash in self._progress['refreshed_models']:
self._progress['failed_models'].add(model_hash)
logger.info(f"Marking model {model_name} as failed after metadata refresh")
return True # Return True to indicate a remote download happened
else:
# No civitai data or images available, mark as failed to avoid future attempts
self._progress['failed_models'].add(model_hash)
logger.debug(f"No civitai images available for model {model_name}, marking as failed")
# Save progress periodically
if self._progress['completed'] % 10 == 0 or self._progress['completed'] == self._progress['total'] - 1:
self._save_progress(output_dir)
@@ -627,51 +633,59 @@ class DownloadManager:
self._progress['processed_models'].add(model_hash)
return False # Return False to indicate no remote download happened
full_model = await MetadataUpdater.get_updated_model(
model_hash, scanner
)
civitai_payload = (full_model or {}).get('civitai', {}) if full_model else {}
# If no local images, try to download from remote
elif model.get('civitai') and model.get('civitai', {}).get('images'):
images = model.get('civitai', {}).get('images', [])
if civitai_payload.get('images'):
images = civitai_payload.get('images', [])
success, is_stale, failed_images = await ExampleImagesProcessor.download_model_images_with_tracking(
model_hash, model_name, images, model_dir, optimize, downloader
)
# If metadata is stale, try to refresh it
if is_stale and model_hash not in self._progress['refreshed_models']:
await MetadataUpdater.refresh_model_metadata(
model_hash, model_name, scanner_type, scanner, self._progress
)
# Get the updated model data
updated_model = await MetadataUpdater.get_updated_model(
model_hash, scanner
)
if updated_model and updated_model.get('civitai', {}).get('images'):
updated_civitai = (updated_model or {}).get('civitai', {}) if updated_model else {}
if updated_civitai.get('images'):
# Retry download with updated metadata
updated_images = updated_model.get('civitai', {}).get('images', [])
updated_images = updated_civitai.get('images', [])
success, _, additional_failed_images = await ExampleImagesProcessor.download_model_images_with_tracking(
model_hash, model_name, updated_images, model_dir, optimize, downloader
)
# Combine failed images from both attempts
failed_images.extend(additional_failed_images)
self._progress['refreshed_models'].add(model_hash)
# For forced downloads, remove failed images from metadata
if failed_images:
# Create a copy of images excluding failed ones
await self._remove_failed_images_from_metadata(
model_hash, model_name, failed_images, scanner
)
# Mark as processed
if success or failed_images: # Mark as processed if we successfully downloaded some images or removed failed ones
self._progress['processed_models'].add(model_hash)
return True # Return True to indicate a remote download happened
else:
logger.debug(f"No civitai images available for model {model_name}")
return False
except Exception as e:

View File

@@ -95,21 +95,35 @@ class MetadataUpdater:
@staticmethod
async def get_updated_model(model_hash, scanner):
"""Get updated model data
Args:
model_hash: SHA256 hash of the model
scanner: Scanner instance
Returns:
dict: Updated model data or None if not found
"""
"""Load the most recent metadata for a model identified by hash."""
cache = await scanner.get_cached_data()
target = None
for item in cache.raw_data:
if item.get('sha256') == model_hash:
return item
return None
target = item
break
if not target:
return None
file_path = target.get('file_path')
if not file_path:
return target
model_cls = getattr(scanner, 'model_class', None)
if model_cls is None:
metadata, should_skip = await MetadataManager.load_metadata(file_path)
else:
metadata, should_skip = await MetadataManager.load_metadata(file_path, model_cls)
if should_skip or metadata is None:
return target
rich_metadata = metadata.to_dict()
rich_metadata.setdefault('folder', target.get('folder', ''))
return rich_metadata
@staticmethod
async def update_metadata_from_local_examples(model_hash, model, scanner_type, scanner, model_dir):
"""Update model metadata with local example image information

View File

@@ -9,6 +9,7 @@ from py.services import model_scanner
from py.services.model_cache import ModelCache
from py.services.model_hash_index import ModelHashIndex
from py.services.model_scanner import CacheBuildResult, ModelScanner
from py.services.persistent_model_cache import PersistentModelCache
from py.utils.models import BaseModelMetadata
@@ -78,6 +79,11 @@ def reset_model_scanner_singletons():
ModelScanner._locks.clear()
@pytest.fixture(autouse=True)
def disable_persistent_cache_env(monkeypatch):
monkeypatch.setenv('LORA_MANAGER_DISABLE_PERSISTENT_CACHE', '1')
@pytest.fixture(autouse=True)
def stub_register_service(monkeypatch):
async def noop(*_args, **_kwargs):
@@ -175,3 +181,60 @@ async def test_initialize_in_background_applies_scan_result(tmp_path: Path, monk
assert scanner._tags_count == {"alpha": 1, "beta": 1}
assert scanner._excluded_models == [_normalize_path(tmp_path / "skip-file.txt")]
assert ws_stub.payloads[-1]["progress"] == 100
@pytest.mark.asyncio
async def test_load_persisted_cache_populates_cache(tmp_path: Path, monkeypatch):
# Enable persistence for this specific test and back it with a temp database
monkeypatch.setenv('LORA_MANAGER_DISABLE_PERSISTENT_CACHE', '0')
db_path = tmp_path / 'cache.sqlite'
store = PersistentModelCache(db_path=str(db_path))
file_path = tmp_path / 'one.txt'
file_path.write_text('one', encoding='utf-8')
normalized = _normalize_path(file_path)
raw_model = {
'file_path': normalized,
'file_name': 'one',
'model_name': 'one',
'folder': '',
'size': 3,
'modified': 123.0,
'sha256': 'hash-one',
'base_model': 'test',
'preview_url': '',
'preview_nsfw_level': 0,
'from_civitai': True,
'favorite': False,
'notes': '',
'usage_tips': '',
'exclude': False,
'db_checked': False,
'last_checked_at': 0.0,
'tags': ['alpha'],
'civitai': {'id': 11, 'modelId': 22, 'name': 'ver', 'trainedWords': ['abc']},
}
store.save_cache('dummy', [raw_model], {'hash-one': [normalized]}, [])
monkeypatch.setattr(model_scanner, 'get_persistent_cache', lambda: store)
scanner = DummyScanner(tmp_path)
ws_stub = RecordingWebSocketManager()
monkeypatch.setattr(model_scanner, 'ws_manager', ws_stub)
loaded = await scanner._load_persisted_cache('dummy')
assert loaded is True
cache = await scanner.get_cached_data()
assert len(cache.raw_data) == 1
entry = cache.raw_data[0]
assert entry['file_path'] == normalized
assert entry['tags'] == ['alpha']
assert entry['civitai']['trainedWords'] == ['abc']
assert scanner._hash_index.get_path('hash-one') == normalized
assert scanner._tags_count == {'alpha': 1}
assert ws_stub.payloads[-1]['stage'] == 'loading_cache'
assert ws_stub.payloads[-1]['progress'] == 1

View File

@@ -0,0 +1,92 @@
from pathlib import Path
import pytest
from py.services.persistent_model_cache import PersistentModelCache
def test_persistent_cache_roundtrip(tmp_path: Path, monkeypatch) -> None:
monkeypatch.setenv('LORA_MANAGER_DISABLE_PERSISTENT_CACHE', '0')
db_path = tmp_path / 'cache.sqlite'
store = PersistentModelCache(db_path=str(db_path))
file_a = (tmp_path / 'a.txt').as_posix()
file_b = (tmp_path / 'b.txt').as_posix()
duplicate_path = f"{file_b}.copy"
raw_data = [
{
'file_path': file_a,
'file_name': 'a',
'model_name': 'Model A',
'folder': '',
'size': 10,
'modified': 100.0,
'sha256': 'hash-a',
'base_model': 'base',
'preview_url': 'preview/a.png',
'preview_nsfw_level': 1,
'from_civitai': True,
'favorite': True,
'notes': 'note',
'usage_tips': '{}',
'exclude': False,
'db_checked': True,
'last_checked_at': 200.0,
'tags': ['alpha', 'beta'],
'civitai': {'id': 1, 'modelId': 2, 'name': 'verA', 'trainedWords': ['word1']},
},
{
'file_path': file_b,
'file_name': 'b',
'model_name': 'Model B',
'folder': 'folder',
'size': 20,
'modified': 120.0,
'sha256': 'hash-b',
'base_model': '',
'preview_url': '',
'preview_nsfw_level': 0,
'from_civitai': False,
'favorite': False,
'notes': '',
'usage_tips': '',
'exclude': True,
'db_checked': False,
'last_checked_at': 0.0,
'tags': [],
'civitai': None,
},
]
hash_index = {
'hash-a': [file_a],
'hash-b': [file_b, duplicate_path],
}
excluded = [duplicate_path]
store.save_cache('dummy', raw_data, hash_index, excluded)
persisted = store.load_cache('dummy')
assert persisted is not None
assert len(persisted.raw_data) == 2
items = {item['file_path']: item for item in persisted.raw_data}
assert set(items.keys()) == {file_a, file_b}
first = items[file_a]
assert first['favorite'] is True
assert first['civitai']['id'] == 1
assert first['civitai']['trainedWords'] == ['word1']
assert first['tags'] == ['alpha', 'beta']
second = items[file_b]
assert second['exclude'] is True
assert second['civitai'] is None
expected_hash_pairs = {
('hash-a', file_a),
('hash-b', file_b),
('hash-b', duplicate_path),
}
assert set((sha, path) for sha, path in persisted.hash_rows) == expected_hash_pairs
assert persisted.excluded_models == excluded