mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-25 07:05:43 -03:00
feat(persistent-cache): implement SQLite-based persistent model cache with loading and saving functionality
This commit is contained in:
@@ -4,6 +4,7 @@ import logging
|
|||||||
import os
|
import os
|
||||||
|
|
||||||
from ..utils.models import BaseModelMetadata
|
from ..utils.models import BaseModelMetadata
|
||||||
|
from ..utils.metadata_manager import MetadataManager
|
||||||
from .model_query import FilterCriteria, ModelCacheRepository, ModelFilterSet, SearchStrategy, SettingsProvider
|
from .model_query import FilterCriteria, ModelCacheRepository, ModelFilterSet, SearchStrategy, SettingsProvider
|
||||||
from .settings_manager import settings as default_settings
|
from .settings_manager import settings as default_settings
|
||||||
|
|
||||||
@@ -313,24 +314,24 @@ class BaseModelService(ABC):
|
|||||||
return {'civitai_url': None, 'model_id': None, 'version_id': None}
|
return {'civitai_url': None, 'model_id': None, 'version_id': None}
|
||||||
|
|
||||||
async def get_model_metadata(self, file_path: str) -> Optional[Dict]:
|
async def get_model_metadata(self, file_path: str) -> Optional[Dict]:
|
||||||
"""Get filtered CivitAI metadata for a model by file path"""
|
"""Load full metadata for a single model.
|
||||||
cache = await self.scanner.get_cached_data()
|
|
||||||
|
Listing/search endpoints return lightweight cache entries; this method performs
|
||||||
for model in cache.raw_data:
|
a lazy read of the on-disk metadata snapshot when callers need full detail.
|
||||||
if model.get('file_path') == file_path:
|
"""
|
||||||
return self.filter_civitai_data(model.get("civitai", {}))
|
metadata, should_skip = await MetadataManager.load_metadata(file_path, self.metadata_class)
|
||||||
|
if should_skip or metadata is None:
|
||||||
return None
|
return None
|
||||||
|
return self.filter_civitai_data(metadata.to_dict().get("civitai", {}))
|
||||||
|
|
||||||
|
|
||||||
async def get_model_description(self, file_path: str) -> Optional[str]:
|
async def get_model_description(self, file_path: str) -> Optional[str]:
|
||||||
"""Get model description by file path"""
|
"""Return the stored modelDescription field for a model."""
|
||||||
cache = await self.scanner.get_cached_data()
|
metadata, should_skip = await MetadataManager.load_metadata(file_path, self.metadata_class)
|
||||||
|
if should_skip or metadata is None:
|
||||||
for model in cache.raw_data:
|
return None
|
||||||
if model.get('file_path') == file_path:
|
return metadata.modelDescription or ''
|
||||||
return model.get('modelDescription', '')
|
|
||||||
|
|
||||||
return None
|
|
||||||
|
|
||||||
async def search_relative_paths(self, search_term: str, limit: int = 15) -> List[str]:
|
async def search_relative_paths(self, search_term: str, limit: int = 15) -> List[str]:
|
||||||
"""Search model relative file paths for autocomplete functionality"""
|
"""Search model relative file paths for autocomplete functionality"""
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ import asyncio
|
|||||||
import time
|
import time
|
||||||
import shutil
|
import shutil
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Awaitable, Callable, List, Dict, Optional, Type, Set
|
from typing import Any, Awaitable, Callable, Dict, List, Mapping, Optional, Set, Type, Union
|
||||||
|
|
||||||
from ..utils.models import BaseModelMetadata
|
from ..utils.models import BaseModelMetadata
|
||||||
from ..config import config
|
from ..config import config
|
||||||
@@ -17,6 +17,7 @@ from ..utils.constants import PREVIEW_EXTENSIONS
|
|||||||
from .model_lifecycle_service import delete_model_artifacts
|
from .model_lifecycle_service import delete_model_artifacts
|
||||||
from .service_registry import ServiceRegistry
|
from .service_registry import ServiceRegistry
|
||||||
from .websocket_manager import ws_manager
|
from .websocket_manager import ws_manager
|
||||||
|
from .persistent_model_cache import get_persistent_cache
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -79,6 +80,7 @@ class ModelScanner:
|
|||||||
self._tags_count = {} # Dictionary to store tag counts
|
self._tags_count = {} # Dictionary to store tag counts
|
||||||
self._is_initializing = False # Flag to track initialization state
|
self._is_initializing = False # Flag to track initialization state
|
||||||
self._excluded_models = [] # List to track excluded models
|
self._excluded_models = [] # List to track excluded models
|
||||||
|
self._persistent_cache = get_persistent_cache()
|
||||||
self._initialized = True
|
self._initialized = True
|
||||||
|
|
||||||
# Register this service
|
# Register this service
|
||||||
@@ -89,6 +91,95 @@ class ModelScanner:
|
|||||||
service_name = f"{self.model_type}_scanner"
|
service_name = f"{self.model_type}_scanner"
|
||||||
await ServiceRegistry.register_service(service_name, self)
|
await ServiceRegistry.register_service(service_name, self)
|
||||||
|
|
||||||
|
def _slim_civitai_payload(self, civitai: Optional[Mapping[str, Any]]) -> Optional[Dict[str, Any]]:
|
||||||
|
"""Return a lightweight civitai payload containing only frequently used keys."""
|
||||||
|
if not isinstance(civitai, Mapping) or not civitai:
|
||||||
|
return None
|
||||||
|
|
||||||
|
slim: Dict[str, Any] = {}
|
||||||
|
for key in ('id', 'modelId', 'name'):
|
||||||
|
value = civitai.get(key)
|
||||||
|
if value not in (None, '', []):
|
||||||
|
slim[key] = value
|
||||||
|
|
||||||
|
trained_words = civitai.get('trainedWords')
|
||||||
|
if trained_words:
|
||||||
|
slim['trainedWords'] = list(trained_words) if isinstance(trained_words, list) else trained_words
|
||||||
|
|
||||||
|
return slim or None
|
||||||
|
|
||||||
|
def _build_cache_entry(
|
||||||
|
self,
|
||||||
|
source: Union[BaseModelMetadata, Mapping[str, Any]],
|
||||||
|
*,
|
||||||
|
folder: Optional[str] = None,
|
||||||
|
file_path_override: Optional[str] = None
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""Project metadata into the lightweight cache representation."""
|
||||||
|
is_mapping = isinstance(source, Mapping)
|
||||||
|
|
||||||
|
def get_value(key: str, default: Any = None) -> Any:
|
||||||
|
if is_mapping:
|
||||||
|
return source.get(key, default)
|
||||||
|
return getattr(source, key, default)
|
||||||
|
|
||||||
|
file_path = file_path_override or get_value('file_path', '') or ''
|
||||||
|
normalized_path = file_path.replace('\\', '/')
|
||||||
|
|
||||||
|
folder_value = folder if folder is not None else get_value('folder', '') or ''
|
||||||
|
normalized_folder = folder_value.replace('\\', '/')
|
||||||
|
|
||||||
|
tags_value = get_value('tags') or []
|
||||||
|
if isinstance(tags_value, list):
|
||||||
|
tags_list = list(tags_value)
|
||||||
|
elif isinstance(tags_value, (set, tuple)):
|
||||||
|
tags_list = list(tags_value)
|
||||||
|
else:
|
||||||
|
tags_list = []
|
||||||
|
|
||||||
|
preview_url = get_value('preview_url', '') or ''
|
||||||
|
if isinstance(preview_url, str):
|
||||||
|
preview_url = preview_url.replace('\\', '/')
|
||||||
|
else:
|
||||||
|
preview_url = ''
|
||||||
|
|
||||||
|
civitai_slim = self._slim_civitai_payload(get_value('civitai'))
|
||||||
|
usage_tips = get_value('usage_tips', '') or ''
|
||||||
|
if not isinstance(usage_tips, str):
|
||||||
|
usage_tips = str(usage_tips)
|
||||||
|
notes = get_value('notes', '') or ''
|
||||||
|
if not isinstance(notes, str):
|
||||||
|
notes = str(notes)
|
||||||
|
|
||||||
|
entry: Dict[str, Any] = {
|
||||||
|
'file_path': normalized_path,
|
||||||
|
'file_name': get_value('file_name', '') or '',
|
||||||
|
'model_name': get_value('model_name', '') or '',
|
||||||
|
'folder': normalized_folder,
|
||||||
|
'size': int(get_value('size', 0) or 0),
|
||||||
|
'modified': float(get_value('modified', 0.0) or 0.0),
|
||||||
|
'sha256': (get_value('sha256', '') or '').lower(),
|
||||||
|
'base_model': get_value('base_model', '') or '',
|
||||||
|
'preview_url': preview_url,
|
||||||
|
'preview_nsfw_level': int(get_value('preview_nsfw_level', 0) or 0),
|
||||||
|
'from_civitai': bool(get_value('from_civitai', True)),
|
||||||
|
'favorite': bool(get_value('favorite', False)),
|
||||||
|
'notes': notes,
|
||||||
|
'usage_tips': usage_tips,
|
||||||
|
'exclude': bool(get_value('exclude', False)),
|
||||||
|
'db_checked': bool(get_value('db_checked', False)),
|
||||||
|
'last_checked_at': float(get_value('last_checked_at', 0.0) or 0.0),
|
||||||
|
'tags': tags_list,
|
||||||
|
'civitai': civitai_slim,
|
||||||
|
'civitai_deleted': bool(get_value('civitai_deleted', False)),
|
||||||
|
}
|
||||||
|
|
||||||
|
model_type = get_value('model_type', None)
|
||||||
|
if model_type:
|
||||||
|
entry['model_type'] = model_type
|
||||||
|
|
||||||
|
return entry
|
||||||
|
|
||||||
async def initialize_in_background(self) -> None:
|
async def initialize_in_background(self) -> None:
|
||||||
"""Initialize cache in background using thread pool"""
|
"""Initialize cache in background using thread pool"""
|
||||||
try:
|
try:
|
||||||
@@ -113,7 +204,9 @@ class ModelScanner:
|
|||||||
'scanner_type': self.model_type,
|
'scanner_type': self.model_type,
|
||||||
'pageType': page_type
|
'pageType': page_type
|
||||||
})
|
})
|
||||||
|
|
||||||
|
await self._load_persisted_cache(page_type)
|
||||||
|
|
||||||
# If cache loading failed, proceed with full scan
|
# If cache loading failed, proceed with full scan
|
||||||
await ws_manager.broadcast_init_progress({
|
await ws_manager.broadcast_init_progress({
|
||||||
'stage': 'scan_folders',
|
'stage': 'scan_folders',
|
||||||
@@ -150,7 +243,8 @@ class ModelScanner:
|
|||||||
|
|
||||||
if scan_result:
|
if scan_result:
|
||||||
await self._apply_scan_result(scan_result)
|
await self._apply_scan_result(scan_result)
|
||||||
|
await self._save_persistent_cache(scan_result)
|
||||||
|
|
||||||
# Send final progress update
|
# Send final progress update
|
||||||
await ws_manager.broadcast_init_progress({
|
await ws_manager.broadcast_init_progress({
|
||||||
'stage': 'finalizing',
|
'stage': 'finalizing',
|
||||||
@@ -179,6 +273,105 @@ class ModelScanner:
|
|||||||
# Always clear the initializing flag when done
|
# Always clear the initializing flag when done
|
||||||
self._is_initializing = False
|
self._is_initializing = False
|
||||||
|
|
||||||
|
async def _load_persisted_cache(self, page_type: str) -> bool:
|
||||||
|
"""Attempt to hydrate the in-memory cache from the SQLite snapshot."""
|
||||||
|
if not getattr(self, '_persistent_cache', None):
|
||||||
|
return False
|
||||||
|
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
try:
|
||||||
|
persisted = await loop.run_in_executor(
|
||||||
|
None,
|
||||||
|
self._persistent_cache.load_cache,
|
||||||
|
self.model_type
|
||||||
|
)
|
||||||
|
except FileNotFoundError:
|
||||||
|
return False
|
||||||
|
except Exception as exc:
|
||||||
|
logger.debug("%s Scanner: Could not load persisted cache: %s", self.model_type.capitalize(), exc)
|
||||||
|
return False
|
||||||
|
|
||||||
|
if not persisted or not persisted.raw_data:
|
||||||
|
return False
|
||||||
|
|
||||||
|
hash_index = ModelHashIndex()
|
||||||
|
for sha_value, path in persisted.hash_rows:
|
||||||
|
if sha_value and path:
|
||||||
|
hash_index.add_entry(sha_value.lower(), path)
|
||||||
|
|
||||||
|
tags_count: Dict[str, int] = {}
|
||||||
|
for item in persisted.raw_data:
|
||||||
|
for tag in item.get('tags') or []:
|
||||||
|
tags_count[tag] = tags_count.get(tag, 0) + 1
|
||||||
|
|
||||||
|
scan_result = CacheBuildResult(
|
||||||
|
raw_data=list(persisted.raw_data),
|
||||||
|
hash_index=hash_index,
|
||||||
|
tags_count=tags_count,
|
||||||
|
excluded_models=list(persisted.excluded_models)
|
||||||
|
)
|
||||||
|
|
||||||
|
await self._apply_scan_result(scan_result)
|
||||||
|
|
||||||
|
await ws_manager.broadcast_init_progress({
|
||||||
|
'stage': 'loading_cache',
|
||||||
|
'progress': 1,
|
||||||
|
'details': f"Loaded cached {self.model_type} data from disk",
|
||||||
|
'scanner_type': self.model_type,
|
||||||
|
'pageType': page_type
|
||||||
|
})
|
||||||
|
return True
|
||||||
|
|
||||||
|
async def _save_persistent_cache(self, scan_result: CacheBuildResult) -> None:
|
||||||
|
if not scan_result or not getattr(self, '_persistent_cache', None):
|
||||||
|
return
|
||||||
|
|
||||||
|
hash_snapshot = self._build_hash_index_snapshot(scan_result.hash_index)
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
try:
|
||||||
|
await loop.run_in_executor(
|
||||||
|
None,
|
||||||
|
self._persistent_cache.save_cache,
|
||||||
|
self.model_type,
|
||||||
|
list(scan_result.raw_data),
|
||||||
|
hash_snapshot,
|
||||||
|
list(scan_result.excluded_models)
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("%s Scanner: Failed to persist cache: %s", self.model_type.capitalize(), exc)
|
||||||
|
|
||||||
|
def _build_hash_index_snapshot(self, hash_index: Optional[ModelHashIndex]) -> Dict[str, List[str]]:
|
||||||
|
snapshot: Dict[str, List[str]] = {}
|
||||||
|
if not hash_index:
|
||||||
|
return snapshot
|
||||||
|
|
||||||
|
for sha_value, path in getattr(hash_index, '_hash_to_path', {}).items():
|
||||||
|
if not sha_value or not path:
|
||||||
|
continue
|
||||||
|
bucket = snapshot.setdefault(sha_value.lower(), [])
|
||||||
|
if path not in bucket:
|
||||||
|
bucket.append(path)
|
||||||
|
|
||||||
|
for sha_value, paths in getattr(hash_index, '_duplicate_hashes', {}).items():
|
||||||
|
if not sha_value:
|
||||||
|
continue
|
||||||
|
bucket = snapshot.setdefault(sha_value.lower(), [])
|
||||||
|
for path in paths:
|
||||||
|
if path and path not in bucket:
|
||||||
|
bucket.append(path)
|
||||||
|
return snapshot
|
||||||
|
|
||||||
|
async def _persist_current_cache(self) -> None:
|
||||||
|
if self._cache is None or not getattr(self, '_persistent_cache', None):
|
||||||
|
return
|
||||||
|
|
||||||
|
snapshot = CacheBuildResult(
|
||||||
|
raw_data=list(self._cache.raw_data),
|
||||||
|
hash_index=self._hash_index,
|
||||||
|
tags_count=dict(self._tags_count),
|
||||||
|
excluded_models=list(self._excluded_models)
|
||||||
|
)
|
||||||
|
await self._save_persistent_cache(snapshot)
|
||||||
def _count_model_files(self) -> int:
|
def _count_model_files(self) -> int:
|
||||||
"""Count all model files with supported extensions in all roots
|
"""Count all model files with supported extensions in all roots
|
||||||
|
|
||||||
@@ -300,6 +493,7 @@ class ModelScanner:
|
|||||||
# Scan for new data
|
# Scan for new data
|
||||||
scan_result = await self._gather_model_data()
|
scan_result = await self._gather_model_data()
|
||||||
await self._apply_scan_result(scan_result)
|
await self._apply_scan_result(scan_result)
|
||||||
|
await self._save_persistent_cache(scan_result)
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"{self.model_type.capitalize()} Scanner: Cache initialization completed in {time.time() - start_time:.2f} seconds, "
|
f"{self.model_type.capitalize()} Scanner: Cache initialization completed in {time.time() - start_time:.2f} seconds, "
|
||||||
@@ -594,8 +788,12 @@ class ModelScanner:
|
|||||||
# Hook: allow subclasses to adjust metadata
|
# Hook: allow subclasses to adjust metadata
|
||||||
metadata = self.adjust_metadata(metadata, file_path, root_path)
|
metadata = self.adjust_metadata(metadata, file_path, root_path)
|
||||||
|
|
||||||
model_data = metadata.to_dict()
|
rel_path = os.path.relpath(file_path, root_path)
|
||||||
|
folder = os.path.dirname(rel_path)
|
||||||
|
normalized_folder = folder.replace(os.path.sep, '/')
|
||||||
|
|
||||||
|
model_data = self._build_cache_entry(metadata, folder=normalized_folder)
|
||||||
|
|
||||||
# Skip excluded models
|
# Skip excluded models
|
||||||
if model_data.get('exclude', False):
|
if model_data.get('exclude', False):
|
||||||
excluded_models.append(model_data['file_path'])
|
excluded_models.append(model_data['file_path'])
|
||||||
@@ -609,10 +807,6 @@ class ModelScanner:
|
|||||||
# if existing_path and existing_path != file_path:
|
# if existing_path and existing_path != file_path:
|
||||||
# logger.warning(f"Duplicate filename detected: '{filename}' - files: '{existing_path}' and '{file_path}'")
|
# logger.warning(f"Duplicate filename detected: '{filename}' - files: '{existing_path}' and '{file_path}'")
|
||||||
|
|
||||||
rel_path = os.path.relpath(file_path, root_path)
|
|
||||||
folder = os.path.dirname(rel_path)
|
|
||||||
model_data['folder'] = folder.replace(os.path.sep, '/')
|
|
||||||
|
|
||||||
return model_data
|
return model_data
|
||||||
|
|
||||||
async def _apply_scan_result(self, scan_result: CacheBuildResult) -> None:
|
async def _apply_scan_result(self, scan_result: CacheBuildResult) -> None:
|
||||||
@@ -749,6 +943,7 @@ class ModelScanner:
|
|||||||
|
|
||||||
# Update the hash index
|
# Update the hash index
|
||||||
self._hash_index.add_entry(metadata_dict['sha256'], metadata_dict['file_path'])
|
self._hash_index.add_entry(metadata_dict['sha256'], metadata_dict['file_path'])
|
||||||
|
await self._persist_current_cache()
|
||||||
return True
|
return True
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error adding model to cache: {e}")
|
logger.error(f"Error adding model to cache: {e}")
|
||||||
@@ -894,28 +1089,30 @@ class ModelScanner:
|
|||||||
]
|
]
|
||||||
|
|
||||||
if metadata:
|
if metadata:
|
||||||
if original_path == new_path:
|
normalized_new_path = new_path.replace(os.sep, '/')
|
||||||
existing_folder = next((item['folder'] for item in cache.raw_data
|
if original_path == new_path and existing_item:
|
||||||
if item['file_path'] == original_path), None)
|
folder_value = existing_item.get('folder', self._calculate_folder(new_path))
|
||||||
if existing_folder:
|
|
||||||
metadata['folder'] = existing_folder
|
|
||||||
else:
|
|
||||||
metadata['folder'] = self._calculate_folder(new_path)
|
|
||||||
else:
|
else:
|
||||||
metadata['folder'] = self._calculate_folder(new_path)
|
folder_value = self._calculate_folder(new_path)
|
||||||
|
|
||||||
cache.raw_data.append(metadata)
|
cache_entry = self._build_cache_entry(
|
||||||
|
metadata,
|
||||||
if 'sha256' in metadata:
|
folder=folder_value,
|
||||||
self._hash_index.add_entry(metadata['sha256'].lower(), new_path)
|
file_path_override=normalized_new_path,
|
||||||
|
)
|
||||||
|
|
||||||
|
cache.raw_data.append(cache_entry)
|
||||||
|
|
||||||
|
sha_value = cache_entry.get('sha256')
|
||||||
|
if sha_value:
|
||||||
|
self._hash_index.add_entry(sha_value.lower(), normalized_new_path)
|
||||||
|
|
||||||
all_folders = set(item['folder'] for item in cache.raw_data)
|
all_folders = set(item['folder'] for item in cache.raw_data)
|
||||||
cache.folders = sorted(list(all_folders), key=lambda x: x.lower())
|
cache.folders = sorted(list(all_folders), key=lambda x: x.lower())
|
||||||
|
|
||||||
if 'tags' in metadata:
|
for tag in cache_entry.get('tags', []):
|
||||||
for tag in metadata.get('tags', []):
|
self._tags_count[tag] = self._tags_count.get(tag, 0) + 1
|
||||||
self._tags_count[tag] = self._tags_count.get(tag, 0) + 1
|
|
||||||
|
|
||||||
await cache.resort()
|
await cache.resort()
|
||||||
|
|
||||||
return True
|
return True
|
||||||
@@ -1019,7 +1216,10 @@ class ModelScanner:
|
|||||||
if self._cache is None:
|
if self._cache is None:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
return await self._cache.update_preview_url(file_path, preview_url, preview_nsfw_level)
|
updated = await self._cache.update_preview_url(file_path, preview_url, preview_nsfw_level)
|
||||||
|
if updated:
|
||||||
|
await self._persist_current_cache()
|
||||||
|
return updated
|
||||||
|
|
||||||
async def bulk_delete_models(self, file_paths: List[str]) -> Dict:
|
async def bulk_delete_models(self, file_paths: List[str]) -> Dict:
|
||||||
"""Delete multiple models and update cache in a batch operation
|
"""Delete multiple models and update cache in a batch operation
|
||||||
|
|||||||
346
py/services/persistent_model_cache.py
Normal file
346
py/services/persistent_model_cache.py
Normal file
@@ -0,0 +1,346 @@
|
|||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import sqlite3
|
||||||
|
import threading
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Dict, List, Optional, Sequence, Tuple
|
||||||
|
|
||||||
|
from ..utils.settings_paths import get_settings_dir
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class PersistedCacheData:
|
||||||
|
"""Lightweight structure returned by the persistent cache."""
|
||||||
|
|
||||||
|
raw_data: List[Dict]
|
||||||
|
hash_rows: List[Tuple[str, str]]
|
||||||
|
excluded_models: List[str]
|
||||||
|
|
||||||
|
|
||||||
|
class PersistentModelCache:
|
||||||
|
"""Persist core model metadata and hash index data in SQLite."""
|
||||||
|
|
||||||
|
_DEFAULT_FILENAME = "model_cache.sqlite"
|
||||||
|
_instance: Optional["PersistentModelCache"] = None
|
||||||
|
_instance_lock = threading.Lock()
|
||||||
|
|
||||||
|
def __init__(self, db_path: Optional[str] = None) -> None:
|
||||||
|
self._db_path = db_path or self._resolve_default_path()
|
||||||
|
self._db_lock = threading.Lock()
|
||||||
|
self._schema_initialized = False
|
||||||
|
try:
|
||||||
|
directory = os.path.dirname(self._db_path)
|
||||||
|
if directory:
|
||||||
|
os.makedirs(directory, exist_ok=True)
|
||||||
|
except Exception as exc: # pragma: no cover - defensive guard
|
||||||
|
logger.warning("Could not create cache directory %s: %s", directory, exc)
|
||||||
|
if self.is_enabled():
|
||||||
|
self._initialize_schema()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_default(cls) -> "PersistentModelCache":
|
||||||
|
with cls._instance_lock:
|
||||||
|
if cls._instance is None:
|
||||||
|
cls._instance = cls()
|
||||||
|
return cls._instance
|
||||||
|
|
||||||
|
def is_enabled(self) -> bool:
|
||||||
|
return os.environ.get("LORA_MANAGER_DISABLE_PERSISTENT_CACHE", "0") != "1"
|
||||||
|
|
||||||
|
def load_cache(self, model_type: str) -> Optional[PersistedCacheData]:
|
||||||
|
if not self.is_enabled():
|
||||||
|
return None
|
||||||
|
if not self._schema_initialized:
|
||||||
|
self._initialize_schema()
|
||||||
|
if not self._schema_initialized:
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
with self._db_lock:
|
||||||
|
conn = self._connect(readonly=True)
|
||||||
|
try:
|
||||||
|
rows = conn.execute(
|
||||||
|
"SELECT file_path, file_name, model_name, folder, size, modified, sha256, base_model,"
|
||||||
|
" preview_url, preview_nsfw_level, from_civitai, favorite, notes, usage_tips,"
|
||||||
|
" civitai_id, civitai_model_id, civitai_name, trained_words, exclude, db_checked,"
|
||||||
|
" last_checked_at"
|
||||||
|
" FROM models WHERE model_type = ?",
|
||||||
|
(model_type,),
|
||||||
|
).fetchall()
|
||||||
|
|
||||||
|
if not rows:
|
||||||
|
return None
|
||||||
|
|
||||||
|
tags = self._load_tags(conn, model_type)
|
||||||
|
hash_rows = conn.execute(
|
||||||
|
"SELECT sha256, file_path FROM hash_index WHERE model_type = ?",
|
||||||
|
(model_type,),
|
||||||
|
).fetchall()
|
||||||
|
excluded = conn.execute(
|
||||||
|
"SELECT file_path FROM excluded_models WHERE model_type = ?",
|
||||||
|
(model_type,),
|
||||||
|
).fetchall()
|
||||||
|
finally:
|
||||||
|
conn.close()
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("Failed to load persisted cache for %s: %s", model_type, exc)
|
||||||
|
return None
|
||||||
|
|
||||||
|
raw_data: List[Dict] = []
|
||||||
|
for row in rows:
|
||||||
|
file_path: str = row["file_path"]
|
||||||
|
trained_words = []
|
||||||
|
if row["trained_words"]:
|
||||||
|
try:
|
||||||
|
trained_words = json.loads(row["trained_words"])
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
trained_words = []
|
||||||
|
|
||||||
|
civitai: Optional[Dict] = None
|
||||||
|
if any(row[col] is not None for col in ("civitai_id", "civitai_model_id", "civitai_name")):
|
||||||
|
civitai = {}
|
||||||
|
if row["civitai_id"] is not None:
|
||||||
|
civitai["id"] = row["civitai_id"]
|
||||||
|
if row["civitai_model_id"] is not None:
|
||||||
|
civitai["modelId"] = row["civitai_model_id"]
|
||||||
|
if row["civitai_name"]:
|
||||||
|
civitai["name"] = row["civitai_name"]
|
||||||
|
if trained_words:
|
||||||
|
civitai["trainedWords"] = trained_words
|
||||||
|
|
||||||
|
item = {
|
||||||
|
"file_path": file_path,
|
||||||
|
"file_name": row["file_name"],
|
||||||
|
"model_name": row["model_name"],
|
||||||
|
"folder": row["folder"] or "",
|
||||||
|
"size": row["size"] or 0,
|
||||||
|
"modified": row["modified"] or 0.0,
|
||||||
|
"sha256": row["sha256"] or "",
|
||||||
|
"base_model": row["base_model"] or "",
|
||||||
|
"preview_url": row["preview_url"] or "",
|
||||||
|
"preview_nsfw_level": row["preview_nsfw_level"] or 0,
|
||||||
|
"from_civitai": bool(row["from_civitai"]),
|
||||||
|
"favorite": bool(row["favorite"]),
|
||||||
|
"notes": row["notes"] or "",
|
||||||
|
"usage_tips": row["usage_tips"] or "",
|
||||||
|
"exclude": bool(row["exclude"]),
|
||||||
|
"db_checked": bool(row["db_checked"]),
|
||||||
|
"last_checked_at": row["last_checked_at"] or 0.0,
|
||||||
|
"tags": tags.get(file_path, []),
|
||||||
|
"civitai": civitai,
|
||||||
|
}
|
||||||
|
raw_data.append(item)
|
||||||
|
|
||||||
|
hash_pairs = [(entry["sha256"].lower(), entry["file_path"]) for entry in hash_rows if entry["sha256"]]
|
||||||
|
if not hash_pairs:
|
||||||
|
# Fall back to hashes stored on the model rows
|
||||||
|
for item in raw_data:
|
||||||
|
sha_value = item.get("sha256")
|
||||||
|
if sha_value:
|
||||||
|
hash_pairs.append((sha_value.lower(), item["file_path"]))
|
||||||
|
|
||||||
|
excluded_paths = [row["file_path"] for row in excluded]
|
||||||
|
return PersistedCacheData(raw_data=raw_data, hash_rows=hash_pairs, excluded_models=excluded_paths)
|
||||||
|
|
||||||
|
def save_cache(self, model_type: str, raw_data: Sequence[Dict], hash_index: Dict[str, List[str]], excluded_models: Sequence[str]) -> None:
|
||||||
|
if not self.is_enabled():
|
||||||
|
return
|
||||||
|
if not self._schema_initialized:
|
||||||
|
self._initialize_schema()
|
||||||
|
if not self._schema_initialized:
|
||||||
|
return
|
||||||
|
try:
|
||||||
|
with self._db_lock:
|
||||||
|
conn = self._connect()
|
||||||
|
try:
|
||||||
|
conn.execute("PRAGMA foreign_keys = ON")
|
||||||
|
conn.execute("DELETE FROM models WHERE model_type = ?", (model_type,))
|
||||||
|
conn.execute("DELETE FROM model_tags WHERE model_type = ?", (model_type,))
|
||||||
|
conn.execute("DELETE FROM hash_index WHERE model_type = ?", (model_type,))
|
||||||
|
conn.execute("DELETE FROM excluded_models WHERE model_type = ?", (model_type,))
|
||||||
|
|
||||||
|
model_rows = [self._prepare_model_row(model_type, item) for item in raw_data]
|
||||||
|
conn.executemany(self._insert_model_sql(), model_rows)
|
||||||
|
|
||||||
|
tag_rows = []
|
||||||
|
for item in raw_data:
|
||||||
|
file_path = item.get("file_path")
|
||||||
|
if not file_path:
|
||||||
|
continue
|
||||||
|
for tag in item.get("tags") or []:
|
||||||
|
tag_rows.append((model_type, file_path, tag))
|
||||||
|
if tag_rows:
|
||||||
|
conn.executemany(
|
||||||
|
"INSERT INTO model_tags (model_type, file_path, tag) VALUES (?, ?, ?)",
|
||||||
|
tag_rows,
|
||||||
|
)
|
||||||
|
|
||||||
|
hash_rows: List[Tuple[str, str, str]] = []
|
||||||
|
for sha_value, paths in hash_index.items():
|
||||||
|
for path in paths:
|
||||||
|
if not sha_value or not path:
|
||||||
|
continue
|
||||||
|
hash_rows.append((model_type, sha_value.lower(), path))
|
||||||
|
if hash_rows:
|
||||||
|
conn.executemany(
|
||||||
|
"INSERT OR IGNORE INTO hash_index (model_type, sha256, file_path) VALUES (?, ?, ?)",
|
||||||
|
hash_rows,
|
||||||
|
)
|
||||||
|
|
||||||
|
excluded_rows = [(model_type, path) for path in excluded_models]
|
||||||
|
if excluded_rows:
|
||||||
|
conn.executemany(
|
||||||
|
"INSERT OR IGNORE INTO excluded_models (model_type, file_path) VALUES (?, ?)",
|
||||||
|
excluded_rows,
|
||||||
|
)
|
||||||
|
conn.commit()
|
||||||
|
finally:
|
||||||
|
conn.close()
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("Failed to persist cache for %s: %s", model_type, exc)
|
||||||
|
|
||||||
|
# Internal helpers -------------------------------------------------
|
||||||
|
|
||||||
|
def _resolve_default_path(self) -> str:
|
||||||
|
override = os.environ.get("LORA_MANAGER_CACHE_DB")
|
||||||
|
if override:
|
||||||
|
return override
|
||||||
|
try:
|
||||||
|
settings_dir = get_settings_dir(create=True)
|
||||||
|
except Exception as exc: # pragma: no cover - defensive guard
|
||||||
|
logger.warning("Falling back to project directory for cache: %s", exc)
|
||||||
|
settings_dir = os.path.dirname(os.path.dirname(self._db_path)) if hasattr(self, "_db_path") else os.getcwd()
|
||||||
|
return os.path.join(settings_dir, self._DEFAULT_FILENAME)
|
||||||
|
|
||||||
|
def _initialize_schema(self) -> None:
|
||||||
|
with self._db_lock:
|
||||||
|
if self._schema_initialized:
|
||||||
|
return
|
||||||
|
try:
|
||||||
|
with self._connect() as conn:
|
||||||
|
conn.execute("PRAGMA journal_mode=WAL")
|
||||||
|
conn.execute("PRAGMA foreign_keys = ON")
|
||||||
|
conn.executescript(
|
||||||
|
"""
|
||||||
|
CREATE TABLE IF NOT EXISTS models (
|
||||||
|
model_type TEXT NOT NULL,
|
||||||
|
file_path TEXT NOT NULL,
|
||||||
|
file_name TEXT,
|
||||||
|
model_name TEXT,
|
||||||
|
folder TEXT,
|
||||||
|
size INTEGER,
|
||||||
|
modified REAL,
|
||||||
|
sha256 TEXT,
|
||||||
|
base_model TEXT,
|
||||||
|
preview_url TEXT,
|
||||||
|
preview_nsfw_level INTEGER,
|
||||||
|
from_civitai INTEGER,
|
||||||
|
favorite INTEGER,
|
||||||
|
notes TEXT,
|
||||||
|
usage_tips TEXT,
|
||||||
|
civitai_id INTEGER,
|
||||||
|
civitai_model_id INTEGER,
|
||||||
|
civitai_name TEXT,
|
||||||
|
trained_words TEXT,
|
||||||
|
exclude INTEGER,
|
||||||
|
db_checked INTEGER,
|
||||||
|
last_checked_at REAL,
|
||||||
|
PRIMARY KEY (model_type, file_path)
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS model_tags (
|
||||||
|
model_type TEXT NOT NULL,
|
||||||
|
file_path TEXT NOT NULL,
|
||||||
|
tag TEXT NOT NULL,
|
||||||
|
PRIMARY KEY (model_type, file_path, tag)
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS hash_index (
|
||||||
|
model_type TEXT NOT NULL,
|
||||||
|
sha256 TEXT NOT NULL,
|
||||||
|
file_path TEXT NOT NULL,
|
||||||
|
PRIMARY KEY (model_type, sha256, file_path)
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS excluded_models (
|
||||||
|
model_type TEXT NOT NULL,
|
||||||
|
file_path TEXT NOT NULL,
|
||||||
|
PRIMARY KEY (model_type, file_path)
|
||||||
|
);
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
conn.commit()
|
||||||
|
self._schema_initialized = True
|
||||||
|
except Exception as exc: # pragma: no cover - defensive guard
|
||||||
|
logger.warning("Failed to initialize persistent cache schema: %s", exc)
|
||||||
|
|
||||||
|
def _connect(self, readonly: bool = False) -> sqlite3.Connection:
|
||||||
|
uri = False
|
||||||
|
path = self._db_path
|
||||||
|
if readonly:
|
||||||
|
if not os.path.exists(path):
|
||||||
|
raise FileNotFoundError(path)
|
||||||
|
path = f"file:{path}?mode=ro"
|
||||||
|
uri = True
|
||||||
|
conn = sqlite3.connect(path, check_same_thread=False, uri=uri, detect_types=sqlite3.PARSE_DECLTYPES)
|
||||||
|
conn.row_factory = sqlite3.Row
|
||||||
|
return conn
|
||||||
|
|
||||||
|
def _prepare_model_row(self, model_type: str, item: Dict) -> Tuple:
|
||||||
|
civitai = item.get("civitai") or {}
|
||||||
|
trained_words = civitai.get("trainedWords")
|
||||||
|
if isinstance(trained_words, str):
|
||||||
|
trained_words_json = trained_words
|
||||||
|
elif trained_words is None:
|
||||||
|
trained_words_json = None
|
||||||
|
else:
|
||||||
|
trained_words_json = json.dumps(trained_words)
|
||||||
|
|
||||||
|
return (
|
||||||
|
model_type,
|
||||||
|
item.get("file_path"),
|
||||||
|
item.get("file_name"),
|
||||||
|
item.get("model_name"),
|
||||||
|
item.get("folder"),
|
||||||
|
int(item.get("size") or 0),
|
||||||
|
float(item.get("modified") or 0.0),
|
||||||
|
(item.get("sha256") or "").lower() or None,
|
||||||
|
item.get("base_model"),
|
||||||
|
item.get("preview_url"),
|
||||||
|
int(item.get("preview_nsfw_level") or 0),
|
||||||
|
1 if item.get("from_civitai", True) else 0,
|
||||||
|
1 if item.get("favorite") else 0,
|
||||||
|
item.get("notes"),
|
||||||
|
item.get("usage_tips"),
|
||||||
|
civitai.get("id"),
|
||||||
|
civitai.get("modelId"),
|
||||||
|
civitai.get("name"),
|
||||||
|
trained_words_json,
|
||||||
|
1 if item.get("exclude") else 0,
|
||||||
|
1 if item.get("db_checked") else 0,
|
||||||
|
float(item.get("last_checked_at") or 0.0),
|
||||||
|
)
|
||||||
|
|
||||||
|
def _insert_model_sql(self) -> str:
|
||||||
|
return (
|
||||||
|
"INSERT INTO models (model_type, file_path, file_name, model_name, folder, size, modified, sha256,"
|
||||||
|
" base_model, preview_url, preview_nsfw_level, from_civitai, favorite, notes, usage_tips,"
|
||||||
|
" civitai_id, civitai_model_id, civitai_name, trained_words, exclude, db_checked, last_checked_at)"
|
||||||
|
" VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)"
|
||||||
|
)
|
||||||
|
|
||||||
|
def _load_tags(self, conn: sqlite3.Connection, model_type: str) -> Dict[str, List[str]]:
|
||||||
|
tag_rows = conn.execute(
|
||||||
|
"SELECT file_path, tag FROM model_tags WHERE model_type = ?",
|
||||||
|
(model_type,),
|
||||||
|
).fetchall()
|
||||||
|
result: Dict[str, List[str]] = {}
|
||||||
|
for row in tag_rows:
|
||||||
|
result.setdefault(row["file_path"], []).append(row["tag"])
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def get_persistent_cache() -> PersistentModelCache:
|
||||||
|
return PersistentModelCache.get_default()
|
||||||
@@ -345,14 +345,19 @@ class DownloadManager:
|
|||||||
self._progress['processed_models'].add(model_hash)
|
self._progress['processed_models'].add(model_hash)
|
||||||
return False # Return False to indicate no remote download happened
|
return False # Return False to indicate no remote download happened
|
||||||
|
|
||||||
|
full_model = await MetadataUpdater.get_updated_model(
|
||||||
|
model_hash, scanner
|
||||||
|
)
|
||||||
|
civitai_payload = (full_model or {}).get('civitai', {}) if full_model else {}
|
||||||
|
|
||||||
# If no local images, try to download from remote
|
# If no local images, try to download from remote
|
||||||
elif model.get('civitai') and model.get('civitai', {}).get('images'):
|
if civitai_payload.get('images'):
|
||||||
images = model.get('civitai', {}).get('images', [])
|
images = civitai_payload.get('images', [])
|
||||||
|
|
||||||
success, is_stale = await ExampleImagesProcessor.download_model_images(
|
success, is_stale = await ExampleImagesProcessor.download_model_images(
|
||||||
model_hash, model_name, images, model_dir, optimize, downloader
|
model_hash, model_name, images, model_dir, optimize, downloader
|
||||||
)
|
)
|
||||||
|
|
||||||
# If metadata is stale, try to refresh it
|
# If metadata is stale, try to refresh it
|
||||||
if is_stale and model_hash not in self._progress['refreshed_models']:
|
if is_stale and model_hash not in self._progress['refreshed_models']:
|
||||||
await MetadataUpdater.refresh_model_metadata(
|
await MetadataUpdater.refresh_model_metadata(
|
||||||
@@ -363,16 +368,17 @@ class DownloadManager:
|
|||||||
updated_model = await MetadataUpdater.get_updated_model(
|
updated_model = await MetadataUpdater.get_updated_model(
|
||||||
model_hash, scanner
|
model_hash, scanner
|
||||||
)
|
)
|
||||||
|
updated_civitai = (updated_model or {}).get('civitai', {}) if updated_model else {}
|
||||||
|
|
||||||
if updated_model and updated_model.get('civitai', {}).get('images'):
|
if updated_civitai.get('images'):
|
||||||
# Retry download with updated metadata
|
# Retry download with updated metadata
|
||||||
updated_images = updated_model.get('civitai', {}).get('images', [])
|
updated_images = updated_civitai.get('images', [])
|
||||||
success, _ = await ExampleImagesProcessor.download_model_images(
|
success, _ = await ExampleImagesProcessor.download_model_images(
|
||||||
model_hash, model_name, updated_images, model_dir, optimize, downloader
|
model_hash, model_name, updated_images, model_dir, optimize, downloader
|
||||||
)
|
)
|
||||||
|
|
||||||
self._progress['refreshed_models'].add(model_hash)
|
self._progress['refreshed_models'].add(model_hash)
|
||||||
|
|
||||||
# Mark as processed if successful, or as failed if unsuccessful after refresh
|
# Mark as processed if successful, or as failed if unsuccessful after refresh
|
||||||
if success:
|
if success:
|
||||||
self._progress['processed_models'].add(model_hash)
|
self._progress['processed_models'].add(model_hash)
|
||||||
@@ -381,13 +387,13 @@ class DownloadManager:
|
|||||||
if model_hash in self._progress['refreshed_models']:
|
if model_hash in self._progress['refreshed_models']:
|
||||||
self._progress['failed_models'].add(model_hash)
|
self._progress['failed_models'].add(model_hash)
|
||||||
logger.info(f"Marking model {model_name} as failed after metadata refresh")
|
logger.info(f"Marking model {model_name} as failed after metadata refresh")
|
||||||
|
|
||||||
return True # Return True to indicate a remote download happened
|
return True # Return True to indicate a remote download happened
|
||||||
else:
|
else:
|
||||||
# No civitai data or images available, mark as failed to avoid future attempts
|
# No civitai data or images available, mark as failed to avoid future attempts
|
||||||
self._progress['failed_models'].add(model_hash)
|
self._progress['failed_models'].add(model_hash)
|
||||||
logger.debug(f"No civitai images available for model {model_name}, marking as failed")
|
logger.debug(f"No civitai images available for model {model_name}, marking as failed")
|
||||||
|
|
||||||
# Save progress periodically
|
# Save progress periodically
|
||||||
if self._progress['completed'] % 10 == 0 or self._progress['completed'] == self._progress['total'] - 1:
|
if self._progress['completed'] % 10 == 0 or self._progress['completed'] == self._progress['total'] - 1:
|
||||||
self._save_progress(output_dir)
|
self._save_progress(output_dir)
|
||||||
@@ -627,51 +633,59 @@ class DownloadManager:
|
|||||||
self._progress['processed_models'].add(model_hash)
|
self._progress['processed_models'].add(model_hash)
|
||||||
return False # Return False to indicate no remote download happened
|
return False # Return False to indicate no remote download happened
|
||||||
|
|
||||||
|
full_model = await MetadataUpdater.get_updated_model(
|
||||||
|
model_hash, scanner
|
||||||
|
)
|
||||||
|
civitai_payload = (full_model or {}).get('civitai', {}) if full_model else {}
|
||||||
|
|
||||||
# If no local images, try to download from remote
|
# If no local images, try to download from remote
|
||||||
elif model.get('civitai') and model.get('civitai', {}).get('images'):
|
if civitai_payload.get('images'):
|
||||||
images = model.get('civitai', {}).get('images', [])
|
images = civitai_payload.get('images', [])
|
||||||
|
|
||||||
success, is_stale, failed_images = await ExampleImagesProcessor.download_model_images_with_tracking(
|
success, is_stale, failed_images = await ExampleImagesProcessor.download_model_images_with_tracking(
|
||||||
model_hash, model_name, images, model_dir, optimize, downloader
|
model_hash, model_name, images, model_dir, optimize, downloader
|
||||||
)
|
)
|
||||||
|
|
||||||
# If metadata is stale, try to refresh it
|
# If metadata is stale, try to refresh it
|
||||||
if is_stale and model_hash not in self._progress['refreshed_models']:
|
if is_stale and model_hash not in self._progress['refreshed_models']:
|
||||||
await MetadataUpdater.refresh_model_metadata(
|
await MetadataUpdater.refresh_model_metadata(
|
||||||
model_hash, model_name, scanner_type, scanner, self._progress
|
model_hash, model_name, scanner_type, scanner, self._progress
|
||||||
)
|
)
|
||||||
|
|
||||||
# Get the updated model data
|
# Get the updated model data
|
||||||
updated_model = await MetadataUpdater.get_updated_model(
|
updated_model = await MetadataUpdater.get_updated_model(
|
||||||
model_hash, scanner
|
model_hash, scanner
|
||||||
)
|
)
|
||||||
|
updated_civitai = (updated_model or {}).get('civitai', {}) if updated_model else {}
|
||||||
if updated_model and updated_model.get('civitai', {}).get('images'):
|
|
||||||
|
if updated_civitai.get('images'):
|
||||||
# Retry download with updated metadata
|
# Retry download with updated metadata
|
||||||
updated_images = updated_model.get('civitai', {}).get('images', [])
|
updated_images = updated_civitai.get('images', [])
|
||||||
success, _, additional_failed_images = await ExampleImagesProcessor.download_model_images_with_tracking(
|
success, _, additional_failed_images = await ExampleImagesProcessor.download_model_images_with_tracking(
|
||||||
model_hash, model_name, updated_images, model_dir, optimize, downloader
|
model_hash, model_name, updated_images, model_dir, optimize, downloader
|
||||||
)
|
)
|
||||||
|
|
||||||
# Combine failed images from both attempts
|
# Combine failed images from both attempts
|
||||||
failed_images.extend(additional_failed_images)
|
failed_images.extend(additional_failed_images)
|
||||||
|
|
||||||
self._progress['refreshed_models'].add(model_hash)
|
self._progress['refreshed_models'].add(model_hash)
|
||||||
|
|
||||||
# For forced downloads, remove failed images from metadata
|
# For forced downloads, remove failed images from metadata
|
||||||
if failed_images:
|
if failed_images:
|
||||||
# Create a copy of images excluding failed ones
|
# Create a copy of images excluding failed ones
|
||||||
await self._remove_failed_images_from_metadata(
|
await self._remove_failed_images_from_metadata(
|
||||||
model_hash, model_name, failed_images, scanner
|
model_hash, model_name, failed_images, scanner
|
||||||
)
|
)
|
||||||
|
|
||||||
# Mark as processed
|
# Mark as processed
|
||||||
if success or failed_images: # Mark as processed if we successfully downloaded some images or removed failed ones
|
if success or failed_images: # Mark as processed if we successfully downloaded some images or removed failed ones
|
||||||
self._progress['processed_models'].add(model_hash)
|
self._progress['processed_models'].add(model_hash)
|
||||||
|
|
||||||
return True # Return True to indicate a remote download happened
|
return True # Return True to indicate a remote download happened
|
||||||
else:
|
else:
|
||||||
logger.debug(f"No civitai images available for model {model_name}")
|
logger.debug(f"No civitai images available for model {model_name}")
|
||||||
|
|
||||||
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
@@ -95,21 +95,35 @@ class MetadataUpdater:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def get_updated_model(model_hash, scanner):
|
async def get_updated_model(model_hash, scanner):
|
||||||
"""Get updated model data
|
"""Load the most recent metadata for a model identified by hash."""
|
||||||
|
|
||||||
Args:
|
|
||||||
model_hash: SHA256 hash of the model
|
|
||||||
scanner: Scanner instance
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
dict: Updated model data or None if not found
|
|
||||||
"""
|
|
||||||
cache = await scanner.get_cached_data()
|
cache = await scanner.get_cached_data()
|
||||||
|
target = None
|
||||||
for item in cache.raw_data:
|
for item in cache.raw_data:
|
||||||
if item.get('sha256') == model_hash:
|
if item.get('sha256') == model_hash:
|
||||||
return item
|
target = item
|
||||||
return None
|
break
|
||||||
|
|
||||||
|
if not target:
|
||||||
|
return None
|
||||||
|
|
||||||
|
file_path = target.get('file_path')
|
||||||
|
if not file_path:
|
||||||
|
return target
|
||||||
|
|
||||||
|
model_cls = getattr(scanner, 'model_class', None)
|
||||||
|
if model_cls is None:
|
||||||
|
metadata, should_skip = await MetadataManager.load_metadata(file_path)
|
||||||
|
else:
|
||||||
|
metadata, should_skip = await MetadataManager.load_metadata(file_path, model_cls)
|
||||||
|
|
||||||
|
if should_skip or metadata is None:
|
||||||
|
return target
|
||||||
|
|
||||||
|
rich_metadata = metadata.to_dict()
|
||||||
|
rich_metadata.setdefault('folder', target.get('folder', ''))
|
||||||
|
return rich_metadata
|
||||||
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def update_metadata_from_local_examples(model_hash, model, scanner_type, scanner, model_dir):
|
async def update_metadata_from_local_examples(model_hash, model, scanner_type, scanner, model_dir):
|
||||||
"""Update model metadata with local example image information
|
"""Update model metadata with local example image information
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ from py.services import model_scanner
|
|||||||
from py.services.model_cache import ModelCache
|
from py.services.model_cache import ModelCache
|
||||||
from py.services.model_hash_index import ModelHashIndex
|
from py.services.model_hash_index import ModelHashIndex
|
||||||
from py.services.model_scanner import CacheBuildResult, ModelScanner
|
from py.services.model_scanner import CacheBuildResult, ModelScanner
|
||||||
|
from py.services.persistent_model_cache import PersistentModelCache
|
||||||
from py.utils.models import BaseModelMetadata
|
from py.utils.models import BaseModelMetadata
|
||||||
|
|
||||||
|
|
||||||
@@ -78,6 +79,11 @@ def reset_model_scanner_singletons():
|
|||||||
ModelScanner._locks.clear()
|
ModelScanner._locks.clear()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def disable_persistent_cache_env(monkeypatch):
|
||||||
|
monkeypatch.setenv('LORA_MANAGER_DISABLE_PERSISTENT_CACHE', '1')
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
@pytest.fixture(autouse=True)
|
||||||
def stub_register_service(monkeypatch):
|
def stub_register_service(monkeypatch):
|
||||||
async def noop(*_args, **_kwargs):
|
async def noop(*_args, **_kwargs):
|
||||||
@@ -175,3 +181,60 @@ async def test_initialize_in_background_applies_scan_result(tmp_path: Path, monk
|
|||||||
assert scanner._tags_count == {"alpha": 1, "beta": 1}
|
assert scanner._tags_count == {"alpha": 1, "beta": 1}
|
||||||
assert scanner._excluded_models == [_normalize_path(tmp_path / "skip-file.txt")]
|
assert scanner._excluded_models == [_normalize_path(tmp_path / "skip-file.txt")]
|
||||||
assert ws_stub.payloads[-1]["progress"] == 100
|
assert ws_stub.payloads[-1]["progress"] == 100
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_load_persisted_cache_populates_cache(tmp_path: Path, monkeypatch):
|
||||||
|
# Enable persistence for this specific test and back it with a temp database
|
||||||
|
monkeypatch.setenv('LORA_MANAGER_DISABLE_PERSISTENT_CACHE', '0')
|
||||||
|
db_path = tmp_path / 'cache.sqlite'
|
||||||
|
store = PersistentModelCache(db_path=str(db_path))
|
||||||
|
|
||||||
|
file_path = tmp_path / 'one.txt'
|
||||||
|
file_path.write_text('one', encoding='utf-8')
|
||||||
|
normalized = _normalize_path(file_path)
|
||||||
|
|
||||||
|
raw_model = {
|
||||||
|
'file_path': normalized,
|
||||||
|
'file_name': 'one',
|
||||||
|
'model_name': 'one',
|
||||||
|
'folder': '',
|
||||||
|
'size': 3,
|
||||||
|
'modified': 123.0,
|
||||||
|
'sha256': 'hash-one',
|
||||||
|
'base_model': 'test',
|
||||||
|
'preview_url': '',
|
||||||
|
'preview_nsfw_level': 0,
|
||||||
|
'from_civitai': True,
|
||||||
|
'favorite': False,
|
||||||
|
'notes': '',
|
||||||
|
'usage_tips': '',
|
||||||
|
'exclude': False,
|
||||||
|
'db_checked': False,
|
||||||
|
'last_checked_at': 0.0,
|
||||||
|
'tags': ['alpha'],
|
||||||
|
'civitai': {'id': 11, 'modelId': 22, 'name': 'ver', 'trainedWords': ['abc']},
|
||||||
|
}
|
||||||
|
|
||||||
|
store.save_cache('dummy', [raw_model], {'hash-one': [normalized]}, [])
|
||||||
|
|
||||||
|
monkeypatch.setattr(model_scanner, 'get_persistent_cache', lambda: store)
|
||||||
|
|
||||||
|
scanner = DummyScanner(tmp_path)
|
||||||
|
ws_stub = RecordingWebSocketManager()
|
||||||
|
monkeypatch.setattr(model_scanner, 'ws_manager', ws_stub)
|
||||||
|
|
||||||
|
loaded = await scanner._load_persisted_cache('dummy')
|
||||||
|
assert loaded is True
|
||||||
|
|
||||||
|
cache = await scanner.get_cached_data()
|
||||||
|
assert len(cache.raw_data) == 1
|
||||||
|
entry = cache.raw_data[0]
|
||||||
|
assert entry['file_path'] == normalized
|
||||||
|
assert entry['tags'] == ['alpha']
|
||||||
|
assert entry['civitai']['trainedWords'] == ['abc']
|
||||||
|
assert scanner._hash_index.get_path('hash-one') == normalized
|
||||||
|
assert scanner._tags_count == {'alpha': 1}
|
||||||
|
assert ws_stub.payloads[-1]['stage'] == 'loading_cache'
|
||||||
|
assert ws_stub.payloads[-1]['progress'] == 1
|
||||||
|
|||||||
92
tests/services/test_persistent_model_cache.py
Normal file
92
tests/services/test_persistent_model_cache.py
Normal file
@@ -0,0 +1,92 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from py.services.persistent_model_cache import PersistentModelCache
|
||||||
|
|
||||||
|
|
||||||
|
def test_persistent_cache_roundtrip(tmp_path: Path, monkeypatch) -> None:
|
||||||
|
monkeypatch.setenv('LORA_MANAGER_DISABLE_PERSISTENT_CACHE', '0')
|
||||||
|
db_path = tmp_path / 'cache.sqlite'
|
||||||
|
store = PersistentModelCache(db_path=str(db_path))
|
||||||
|
|
||||||
|
file_a = (tmp_path / 'a.txt').as_posix()
|
||||||
|
file_b = (tmp_path / 'b.txt').as_posix()
|
||||||
|
duplicate_path = f"{file_b}.copy"
|
||||||
|
|
||||||
|
raw_data = [
|
||||||
|
{
|
||||||
|
'file_path': file_a,
|
||||||
|
'file_name': 'a',
|
||||||
|
'model_name': 'Model A',
|
||||||
|
'folder': '',
|
||||||
|
'size': 10,
|
||||||
|
'modified': 100.0,
|
||||||
|
'sha256': 'hash-a',
|
||||||
|
'base_model': 'base',
|
||||||
|
'preview_url': 'preview/a.png',
|
||||||
|
'preview_nsfw_level': 1,
|
||||||
|
'from_civitai': True,
|
||||||
|
'favorite': True,
|
||||||
|
'notes': 'note',
|
||||||
|
'usage_tips': '{}',
|
||||||
|
'exclude': False,
|
||||||
|
'db_checked': True,
|
||||||
|
'last_checked_at': 200.0,
|
||||||
|
'tags': ['alpha', 'beta'],
|
||||||
|
'civitai': {'id': 1, 'modelId': 2, 'name': 'verA', 'trainedWords': ['word1']},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
'file_path': file_b,
|
||||||
|
'file_name': 'b',
|
||||||
|
'model_name': 'Model B',
|
||||||
|
'folder': 'folder',
|
||||||
|
'size': 20,
|
||||||
|
'modified': 120.0,
|
||||||
|
'sha256': 'hash-b',
|
||||||
|
'base_model': '',
|
||||||
|
'preview_url': '',
|
||||||
|
'preview_nsfw_level': 0,
|
||||||
|
'from_civitai': False,
|
||||||
|
'favorite': False,
|
||||||
|
'notes': '',
|
||||||
|
'usage_tips': '',
|
||||||
|
'exclude': True,
|
||||||
|
'db_checked': False,
|
||||||
|
'last_checked_at': 0.0,
|
||||||
|
'tags': [],
|
||||||
|
'civitai': None,
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
hash_index = {
|
||||||
|
'hash-a': [file_a],
|
||||||
|
'hash-b': [file_b, duplicate_path],
|
||||||
|
}
|
||||||
|
excluded = [duplicate_path]
|
||||||
|
|
||||||
|
store.save_cache('dummy', raw_data, hash_index, excluded)
|
||||||
|
|
||||||
|
persisted = store.load_cache('dummy')
|
||||||
|
assert persisted is not None
|
||||||
|
assert len(persisted.raw_data) == 2
|
||||||
|
|
||||||
|
items = {item['file_path']: item for item in persisted.raw_data}
|
||||||
|
assert set(items.keys()) == {file_a, file_b}
|
||||||
|
first = items[file_a]
|
||||||
|
assert first['favorite'] is True
|
||||||
|
assert first['civitai']['id'] == 1
|
||||||
|
assert first['civitai']['trainedWords'] == ['word1']
|
||||||
|
assert first['tags'] == ['alpha', 'beta']
|
||||||
|
|
||||||
|
second = items[file_b]
|
||||||
|
assert second['exclude'] is True
|
||||||
|
assert second['civitai'] is None
|
||||||
|
|
||||||
|
expected_hash_pairs = {
|
||||||
|
('hash-a', file_a),
|
||||||
|
('hash-b', file_b),
|
||||||
|
('hash-b', duplicate_path),
|
||||||
|
}
|
||||||
|
assert set((sha, path) for sha, path in persisted.hash_rows) == expected_hash_pairs
|
||||||
|
assert persisted.excluded_models == excluded
|
||||||
Reference in New Issue
Block a user