From d287883671a8d8253fcb7b45147b556b5f105e38 Mon Sep 17 00:00:00 2001 From: Will Miao <13051207myq@gmail.com> Date: Sat, 6 Sep 2025 22:17:24 +0800 Subject: [PATCH 01/13] refactor(civitai): remove legacy get_model_description and _get_hash_from_civitai methods --- .gitignore | 1 + py/services/civitai_client.py | 30 ------------------------------ 2 files changed, 1 insertion(+), 30 deletions(-) diff --git a/.gitignore b/.gitignore index 03d5100a..7b15029e 100644 --- a/.gitignore +++ b/.gitignore @@ -5,3 +5,4 @@ output/* py/run_test.py .vscode/ cache/ +civitai/ diff --git a/py/services/civitai_client.py b/py/services/civitai_client.py index 8767621e..ff870f6d 100644 --- a/py/services/civitai_client.py +++ b/py/services/civitai_client.py @@ -495,42 +495,12 @@ class CivitaiClient: logger.error(f"Error fetching model metadata: {e}", exc_info=True) return None, 0 - # Keep old method for backward compatibility, delegating to the new one - async def get_model_description(self, model_id: str) -> Optional[str]: - """Fetch the model description from Civitai API (Legacy method)""" - metadata, _ = await self.get_model_metadata(model_id) - return metadata.get("description") if metadata else None - async def close(self): """Close the session if it exists""" if self._session is not None: await self._session.close() self._session = None - async def _get_hash_from_civitai(self, model_version_id: str) -> Optional[str]: - """Get hash from Civitai API""" - try: - session = await self._ensure_fresh_session() - if not session: - return None - - version_info = await session.get(f"{self.base_url}/model-versions/{model_version_id}") - - if not version_info or not version_info.json().get('files'): - return None - - # Get hash from the first file - for file_info in version_info.json().get('files', []): - if file_info.get('hashes', {}).get('SHA256'): - # Convert hash to lowercase to standardize - hash_value = file_info['hashes']['SHA256'].lower() - return hash_value - - return None - except Exception as e: - logger.error(f"Error getting hash from Civitai: {e}") - return None - async def get_image_info(self, image_id: str) -> Optional[Dict]: """Fetch image information from Civitai API From 9ba3e2c204a79ed92d242f0960a6276bebf6d44a Mon Sep 17 00:00:00 2001 From: Will Miao <13051207myq@gmail.com> Date: Mon, 8 Sep 2025 10:33:59 +0800 Subject: [PATCH 02/13] feat(metadata): implement metadata providers and initialize metadata service - Added ModelMetadataProvider and CivitaiModelMetadataProvider for handling model metadata. - Introduced SQLiteModelMetadataProvider for SQLite database integration. - Created metadata_service.py to initialize and configure metadata providers. - Updated CivitaiClient to register as a metadata provider. - Refactored download_manager to use the new download_file method. - Added SQL schema for models, model_versions, and model_files. - Updated requirements.txt to include aiosqlite. --- py/lora_manager.py | 3 + py/services/civitai_client.py | 29 +- py/services/download_manager.py | 2 +- py/services/metadata_service.py | 33 +++ py/services/model_metadata_provider.py | 389 +++++++++++++++++++++++++ py/utils/routes_common.py | 1 + refs/civitai.sql | 38 +++ requirements.txt | 1 + 8 files changed, 473 insertions(+), 23 deletions(-) create mode 100644 py/services/metadata_service.py create mode 100644 py/services/model_metadata_provider.py create mode 100644 refs/civitai.sql diff --git a/py/lora_manager.py b/py/lora_manager.py index df082d1b..f6598bca 100644 --- a/py/lora_manager.py +++ b/py/lora_manager.py @@ -190,6 +190,9 @@ class LoraManager: # Register DownloadManager with ServiceRegistry await ServiceRegistry.get_download_manager() + + from .services.metadata_service import initialize_metadata_providers + await initialize_metadata_providers() # Initialize WebSocket manager await ServiceRegistry.get_websocket_manager() diff --git a/py/services/civitai_client.py b/py/services/civitai_client.py index ff870f6d..64a6f2f5 100644 --- a/py/services/civitai_client.py +++ b/py/services/civitai_client.py @@ -3,9 +3,8 @@ import aiohttp import os import logging import asyncio -from email.parser import Parser from typing import Optional, Dict, Tuple, List -from urllib.parse import unquote +from .model_metadata_provider import CivitaiModelMetadataProvider, ModelMetadataProviderManager logger = logging.getLogger(__name__) @@ -19,6 +18,11 @@ class CivitaiClient: async with cls._lock: if cls._instance is None: cls._instance = cls() + + # Register this client as a metadata provider + provider_manager = await ModelMetadataProviderManager.get_instance() + provider_manager.register_provider('civitai', CivitaiModelMetadataProvider(cls._instance), True) + return cls._instance def __init__(self): @@ -69,24 +73,6 @@ class CivitaiClient: return await self.session - def _parse_content_disposition(self, header: str) -> str: - """Parse filename from content-disposition header""" - if not header: - return None - - # Handle quoted filenames - if 'filename="' in header: - start = header.index('filename="') + 10 - end = header.index('"', start) - return unquote(header[start:end]) - - # Fallback to original parsing - disposition = Parser().parsestr(f'Content-Disposition: {header}') - filename = disposition.get_param('filename') - if filename: - return unquote(filename) - return None - def _get_request_headers(self) -> dict: """Get request headers with optional API key""" headers = { @@ -101,7 +87,7 @@ class CivitaiClient: return headers - async def _download_file(self, url: str, save_dir: str, default_filename: str, progress_callback=None) -> Tuple[bool, str]: + async def download_file(self, url: str, save_dir: str, default_filename: str, progress_callback=None) -> Tuple[bool, str]: """Download file with resumable downloads and retry mechanism Args: @@ -129,7 +115,6 @@ class CivitaiClient: logger.info(f"Resuming download from offset {resume_offset} bytes") total_size = 0 - filename = default_filename while retry_count <= max_retries: try: diff --git a/py/services/download_manager.py b/py/services/download_manager.py index 8f491fa8..08295985 100644 --- a/py/services/download_manager.py +++ b/py/services/download_manager.py @@ -487,7 +487,7 @@ class DownloadManager: await progress_callback(3) # 3% progress after preview download # Download model file with progress tracking - success, result = await civitai_client._download_file( + success, result = await civitai_client.download_file( download_url, save_dir, os.path.basename(save_path), diff --git a/py/services/metadata_service.py b/py/services/metadata_service.py new file mode 100644 index 00000000..beb8912d --- /dev/null +++ b/py/services/metadata_service.py @@ -0,0 +1,33 @@ +import os +import logging +from .model_metadata_provider import ModelMetadataProviderManager, SQLiteModelMetadataProvider + +logger = logging.getLogger(__name__) + +async def initialize_metadata_providers(): + """Initialize and configure all metadata providers""" + provider_manager = await ModelMetadataProviderManager.get_instance() + + # Use hardcoded SQLite DB path if not set in settings + db_path = os.path.join( + os.path.dirname(os.path.dirname(os.path.dirname(__file__))), + 'civitai', 'civitai.sqlite' + ) + if db_path and os.path.exists(db_path): + try: + sqlite_provider = SQLiteModelMetadataProvider(db_path) + provider_manager.register_provider('sqlite', sqlite_provider) + logger.info(f"SQLite metadata provider registered with database: {db_path}") + except Exception as e: + logger.error(f"Failed to initialize SQLite metadata provider: {e}") + + return provider_manager + +async def get_metadata_provider(provider_name: str = None): + """Get a specific metadata provider or default provider""" + provider_manager = await ModelMetadataProviderManager.get_instance() + + if provider_name: + return provider_manager._get_provider(provider_name) + + return provider_manager._get_provider() diff --git a/py/services/model_metadata_provider.py b/py/services/model_metadata_provider.py new file mode 100644 index 00000000..95dfc1da --- /dev/null +++ b/py/services/model_metadata_provider.py @@ -0,0 +1,389 @@ +from abc import ABC, abstractmethod +import json +import aiosqlite +import logging +from typing import Optional, Dict, List, Tuple, Any + +logger = logging.getLogger(__name__) + +class ModelMetadataProvider(ABC): + """Base abstract class for all model metadata providers""" + + @abstractmethod + async def get_model_by_hash(self, model_hash: str) -> Optional[Dict]: + """Find model by hash value""" + pass + + @abstractmethod + async def get_model_versions(self, model_id: str) -> Optional[Dict]: + """Get all versions of a model with their details""" + pass + + @abstractmethod + async def get_model_version(self, model_id: int = None, version_id: int = None) -> Optional[Dict]: + """Get specific model version with additional metadata""" + pass + + @abstractmethod + async def get_model_version_info(self, version_id: str) -> Tuple[Optional[Dict], Optional[str]]: + """Fetch model version metadata""" + pass + + @abstractmethod + async def get_model_metadata(self, model_id: str) -> Tuple[Optional[Dict], int]: + """Fetch model metadata (description, tags, and creator info)""" + pass + +class CivitaiModelMetadataProvider(ModelMetadataProvider): + """Provider that uses Civitai API for metadata""" + + def __init__(self, civitai_client): + self.client = civitai_client + + async def get_model_by_hash(self, model_hash: str) -> Optional[Dict]: + return await self.client.get_model_by_hash(model_hash) + + async def get_model_versions(self, model_id: str) -> Optional[Dict]: + return await self.client.get_model_versions(model_id) + + async def get_model_version(self, model_id: int = None, version_id: int = None) -> Optional[Dict]: + return await self.client.get_model_version(model_id, version_id) + + async def get_model_version_info(self, version_id: str) -> Tuple[Optional[Dict], Optional[str]]: + return await self.client.get_model_version_info(version_id) + + async def get_model_metadata(self, model_id: str) -> Tuple[Optional[Dict], int]: + return await self.client.get_model_metadata(model_id) + +class SQLiteModelMetadataProvider(ModelMetadataProvider): + """Provider that uses SQLite database for metadata""" + + def __init__(self, db_path: str): + self.db_path = db_path + + async def get_model_by_hash(self, model_hash: str) -> Optional[Dict]: + """Find model by hash value from SQLite database""" + async with aiosqlite.connect(self.db_path) as db: + # Look up in model_files table to get model_id and version_id + query = """ + SELECT model_id, version_id + FROM model_files + WHERE sha256 = ? + LIMIT 1 + """ + db.row_factory = aiosqlite.Row + cursor = await db.execute(query, (model_hash.upper(),)) + file_row = await cursor.fetchone() + + if not file_row: + return None + + # Get version details + model_id = file_row['model_id'] + version_id = file_row['version_id'] + + # Build response in the same format as Civitai API + return await self._get_version_with_model_data(db, model_id, version_id) + + async def get_model_versions(self, model_id: str) -> Optional[Dict]: + """Get all versions of a model from SQLite database""" + async with aiosqlite.connect(self.db_path) as db: + db.row_factory = aiosqlite.Row + + # First check if model exists + model_query = "SELECT * FROM models WHERE id = ?" + cursor = await db.execute(model_query, (model_id,)) + model_row = await cursor.fetchone() + + if not model_row: + return None + + model_data = json.loads(model_row['data']) + model_type = model_row['type'] + + # Get all versions for this model + versions_query = """ + SELECT id, name, base_model, data, position, published_at + FROM model_versions + WHERE model_id = ? + ORDER BY position ASC + """ + cursor = await db.execute(versions_query, (model_id,)) + version_rows = await cursor.fetchall() + + if not version_rows: + return {'modelVersions': [], 'type': model_type} + + # Format versions similar to Civitai API + model_versions = [] + for row in version_rows: + version_data = json.loads(row['data']) + # Add fields from the row to ensure we have the basic fields + version_entry = { + 'id': row['id'], + 'modelId': int(model_id), + 'name': row['name'], + 'baseModel': row['base_model'], + 'model': { + 'name': model_row['name'], + 'type': model_type, + } + } + # Update with any additional data + version_entry.update(version_data) + model_versions.append(version_entry) + + return { + 'modelVersions': model_versions, + 'type': model_type + } + + async def get_model_version(self, model_id: int = None, version_id: int = None) -> Optional[Dict]: + """Get specific model version with additional metadata from SQLite database""" + if not model_id and not version_id: + return None + + async with aiosqlite.connect(self.db_path) as db: + db.row_factory = aiosqlite.Row + + # Case 1: Only version_id is provided + if model_id is None and version_id is not None: + # First get the version info to extract model_id + version_query = "SELECT model_id FROM model_versions WHERE id = ?" + cursor = await db.execute(version_query, (version_id,)) + version_row = await cursor.fetchone() + + if not version_row: + return None + + model_id = version_row['model_id'] + + # Case 2: model_id is provided but version_id is not + elif model_id is not None and version_id is None: + # Find the latest version + version_query = """ + SELECT id FROM model_versions + WHERE model_id = ? + ORDER BY position ASC + LIMIT 1 + """ + cursor = await db.execute(version_query, (model_id,)) + version_row = await cursor.fetchone() + + if not version_row: + return None + + version_id = version_row['id'] + + # Now we have both model_id and version_id, get the full data + return await self._get_version_with_model_data(db, model_id, version_id) + + async def get_model_version_info(self, version_id: str) -> Tuple[Optional[Dict], Optional[str]]: + """Fetch model version metadata from SQLite database""" + async with aiosqlite.connect(self.db_path) as db: + db.row_factory = aiosqlite.Row + + # Get version details + version_query = "SELECT model_id FROM model_versions WHERE id = ?" + cursor = await db.execute(version_query, (version_id,)) + version_row = await cursor.fetchone() + + if not version_row: + return None, "Model version not found" + + model_id = version_row['model_id'] + + # Build complete version data with model info + version_data = await self._get_version_with_model_data(db, model_id, version_id) + return version_data, None + + async def get_model_metadata(self, model_id: str) -> Tuple[Optional[Dict], int]: + """Fetch model metadata from SQLite database""" + async with aiosqlite.connect(self.db_path) as db: + db.row_factory = aiosqlite.Row + + # Get model details + model_query = "SELECT name, type, data, username FROM models WHERE id = ?" + cursor = await db.execute(model_query, (model_id,)) + model_row = await cursor.fetchone() + + if not model_row: + return None, 404 + + # Parse data JSON + try: + model_data = json.loads(model_row['data']) + + # Extract relevant metadata + metadata = { + "description": model_data.get("description", "No model description available"), + "tags": model_data.get("tags", []), + "creator": { + "username": model_row['username'] or model_data.get("creator", {}).get("username"), + "image": model_data.get("creator", {}).get("image") + } + } + + return metadata, 200 + except json.JSONDecodeError: + return None, 500 + + async def _get_version_with_model_data(self, db, model_id, version_id) -> Optional[Dict]: + """Helper to build version data with model information""" + # Get version details + version_query = "SELECT name, base_model, data FROM model_versions WHERE id = ? AND model_id = ?" + cursor = await db.execute(version_query, (version_id, model_id)) + version_row = await cursor.fetchone() + + if not version_row: + return None + + # Get model details + model_query = "SELECT name, type, data, username FROM models WHERE id = ?" + cursor = await db.execute(model_query, (model_id,)) + model_row = await cursor.fetchone() + + if not model_row: + return None + + # Parse JSON data + try: + version_data = json.loads(version_row['data']) + model_data = json.loads(model_row['data']) + + # Build response + result = { + "id": int(version_id), + "modelId": int(model_id), + "name": version_row['name'], + "baseModel": version_row['base_model'], + "model": { + "name": model_row['name'], + "description": model_data.get("description"), + "type": model_row['type'], + "tags": model_data.get("tags", []) + }, + "creator": { + "username": model_row['username'] or model_data.get("creator", {}).get("username"), + "image": model_data.get("creator", {}).get("image") + } + } + + # Add any additional fields from version data + result.update(version_data) + + return result + except json.JSONDecodeError: + return None + +class FallbackMetadataProvider(ModelMetadataProvider): + """Try providers in order, return first successful result.""" + def __init__(self, providers: list): + self.providers = providers + + async def get_model_by_hash(self, model_hash: str) -> Optional[Dict]: + for provider in self.providers: + try: + result = await provider.get_model_by_hash(model_hash) + if result: + return result + except Exception: + continue + return None + + async def get_model_versions(self, model_id: str) -> Optional[Dict]: + for provider in self.providers: + try: + result = await provider.get_model_versions(model_id) + if result: + return result + except Exception: + continue + return None + + async def get_model_version(self, model_id: int = None, version_id: int = None) -> Optional[Dict]: + for provider in self.providers: + try: + result = await provider.get_model_version(model_id, version_id) + if result: + return result + except Exception: + continue + return None + + async def get_model_version_info(self, version_id: str) -> Tuple[Optional[Dict], Optional[str]]: + for provider in self.providers: + try: + result, err = await provider.get_model_version_info(version_id) + if result: + return result, err + except Exception: + continue + return None, "Not found in any provider" + + async def get_model_metadata(self, model_id: str) -> Tuple[Optional[Dict], int]: + for provider in self.providers: + try: + result, code = await provider.get_model_metadata(model_id) + if result: + return result, code + except Exception: + continue + return None, 404 + +class ModelMetadataProviderManager: + """Manager for selecting and using model metadata providers""" + + _instance = None + + @classmethod + async def get_instance(cls): + """Get singleton instance of ModelMetadataProviderManager""" + if cls._instance is None: + cls._instance = cls() + return cls._instance + + def __init__(self): + self.providers = {} + self.default_provider = None + + def register_provider(self, name: str, provider: ModelMetadataProvider, is_default: bool = False): + """Register a metadata provider""" + self.providers[name] = provider + if is_default or self.default_provider is None: + self.default_provider = name + + async def get_model_by_hash(self, model_hash: str, provider_name: str = None) -> Optional[Dict]: + """Find model by hash using specified or default provider""" + provider = self._get_provider(provider_name) + return await provider.get_model_by_hash(model_hash) + + async def get_model_versions(self, model_id: str, provider_name: str = None) -> Optional[Dict]: + """Get model versions using specified or default provider""" + provider = self._get_provider(provider_name) + return await provider.get_model_versions(model_id) + + async def get_model_version(self, model_id: int = None, version_id: int = None, provider_name: str = None) -> Optional[Dict]: + """Get specific model version using specified or default provider""" + provider = self._get_provider(provider_name) + return await provider.get_model_version(model_id, version_id) + + async def get_model_version_info(self, version_id: str, provider_name: str = None) -> Tuple[Optional[Dict], Optional[str]]: + """Fetch model version info using specified or default provider""" + provider = self._get_provider(provider_name) + return await provider.get_model_version_info(version_id) + + async def get_model_metadata(self, model_id: str, provider_name: str = None) -> Tuple[Optional[Dict], int]: + """Fetch model metadata using specified or default provider""" + provider = self._get_provider(provider_name) + return await provider.get_model_metadata(model_id) + + def _get_provider(self, provider_name: str = None) -> ModelMetadataProvider: + """Get provider by name or default provider""" + if provider_name and provider_name in self.providers: + return self.providers[provider_name] + + if self.default_provider is None: + raise ValueError("No default provider set and no valid provider specified") + + return self.providers[self.default_provider] diff --git a/py/utils/routes_common.py b/py/utils/routes_common.py index 8f9cfe7f..778725a1 100644 --- a/py/utils/routes_common.py +++ b/py/utils/routes_common.py @@ -13,6 +13,7 @@ from ..utils.exif_utils import ExifUtils from ..utils.metadata_manager import MetadataManager from ..services.download_manager import DownloadManager from ..services.websocket_manager import ws_manager +from ..services.metadata_service import get_metadata_provider logger = logging.getLogger(__name__) diff --git a/refs/civitai.sql b/refs/civitai.sql new file mode 100644 index 00000000..99814e19 --- /dev/null +++ b/refs/civitai.sql @@ -0,0 +1,38 @@ +CREATE TABLE models ( + id INTEGER PRIMARY KEY, + name TEXT NOT NULL, + type TEXT NOT NULL, + username TEXT, + data TEXT NOT NULL, + created_at INTEGER NOT NULL, + updated_at INTEGER NOT NULL + ) STRICT; +CREATE TABLE model_versions ( + id INTEGER PRIMARY KEY, + model_id INTEGER NOT NULL, + position INTEGER NOT NULL, + name TEXT NOT NULL, + base_model TEXT NOT NULL, + published_at INTEGER, + data TEXT NOT NULL, + created_at INTEGER NOT NULL, + updated_at INTEGER NOT NULL + ) STRICT; +CREATE INDEX model_versions_model_id_idx ON model_versions (model_id); +CREATE TABLE model_files ( + id INTEGER PRIMARY KEY, + model_id INTEGER NOT NULL, + version_id INTEGER NOT NULL, + type TEXT NOT NULL, + sha256 TEXT, + data TEXT NOT NULL, + created_at INTEGER NOT NULL, + updated_at INTEGER NOT NULL + ) STRICT; +CREATE INDEX model_files_model_id_idx ON model_files (model_id); +CREATE INDEX model_files_version_id_idx ON model_files (version_id); +CREATE TABLE archived_model_files ( + file_id INTEGER PRIMARY KEY, + model_id INTEGER NOT NULL, + version_id INTEGER NOT NULL + ) STRICT; \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index bb65770d..9e280c64 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,3 +8,4 @@ toml numpy natsort GitPython +aiosqlite From 821827a375d5a08ab8f57a5fe13e3b4eda0a6928 Mon Sep 17 00:00:00 2001 From: Will Miao <13051207myq@gmail.com> Date: Mon, 8 Sep 2025 13:17:16 +0800 Subject: [PATCH 03/13] feat(metadata): implement metadata archive management and update settings for metadata providers --- locales/en.json | 32 ++- py/routes/checkpoint_routes.py | 7 +- py/routes/embedding_routes.py | 7 +- py/routes/lora_routes.py | 13 +- py/routes/misc_routes.py | 112 +++++++++++ py/services/metadata_archive_manager.py | 150 ++++++++++++++ py/services/metadata_service.py | 97 +++++++-- py/services/model_metadata_provider.py | 22 +- py/services/settings_manager.py | 4 +- static/js/managers/SettingsManager.js | 189 ++++++++++++++++++ .../components/modals/settings_modal.html | 64 ++++++ 11 files changed, 659 insertions(+), 38 deletions(-) create mode 100644 py/services/metadata_archive_manager.py diff --git a/locales/en.json b/locales/en.json index 093433c8..52bc8580 100644 --- a/locales/en.json +++ b/locales/en.json @@ -16,7 +16,9 @@ "loading": "Loading...", "unknown": "Unknown", "date": "Date", - "version": "Version" + "version": "Version", + "enabled": "Enabled", + "disabled": "Disabled" }, "language": { "select": "Language", @@ -178,7 +180,8 @@ "folderSettings": "Folder Settings", "downloadPathTemplates": "Download Path Templates", "exampleImages": "Example Images", - "misc": "Misc." + "misc": "Misc.", + "metadataArchive": "Metadata Archive Database" }, "contentFiltering": { "blurNsfwContent": "Blur NSFW Content", @@ -273,6 +276,31 @@ "misc": { "includeTriggerWords": "Include Trigger Words in LoRA Syntax", "includeTriggerWordsHelp": "Include trained trigger words when copying LoRA syntax to clipboard" + }, + "metadataArchive": { + "enableArchiveDb": "Enable Metadata Archive Database", + "enableArchiveDbHelp": "Use local database for faster metadata retrieval and access to deleted models. Recommended for better performance.", + "providerPriority": "Metadata Provider Priority", + "providerPriorityHelp": "Choose which metadata source to try first when loading model information", + "priorityArchiveDb": "Archive Database (Recommended)", + "priorityCivitaiApi": "Civitai API", + "status": "Status", + "statusAvailable": "Available", + "statusUnavailable": "Not Available", + "enabled": "Enabled", + "currentPriority": "Current Priority", + "management": "Database Management", + "managementHelp": "Download or remove the metadata archive database", + "downloadButton": "Download Database", + "downloadingButton": "Downloading...", + "downloadedButton": "Downloaded", + "removeButton": "Remove Database", + "removingButton": "Removing...", + "downloadSuccess": "Metadata archive database downloaded successfully", + "downloadError": "Failed to download metadata archive database", + "removeSuccess": "Metadata archive database removed successfully", + "removeError": "Failed to remove metadata archive database", + "removeConfirm": "Are you sure you want to remove the metadata archive database? This will delete the local database file and you'll need to download it again to use this feature." } }, "loras": { diff --git a/py/routes/checkpoint_routes.py b/py/routes/checkpoint_routes.py index 9b5d20a6..b93700cf 100644 --- a/py/routes/checkpoint_routes.py +++ b/py/routes/checkpoint_routes.py @@ -4,6 +4,7 @@ from aiohttp import web from .base_model_routes import BaseModelRoutes from ..services.checkpoint_service import CheckpointService from ..services.service_registry import ServiceRegistry +from ..services.model_metadata_provider import ModelMetadataProviderManager from ..config import config logger = logging.getLogger(__name__) @@ -15,14 +16,14 @@ class CheckpointRoutes(BaseModelRoutes): """Initialize Checkpoint routes with Checkpoint service""" # Service will be initialized later via setup_routes self.service = None - self.civitai_client = None + self.metadata_provider = None self.template_name = "checkpoints.html" async def initialize_services(self): """Initialize services from ServiceRegistry""" checkpoint_scanner = await ServiceRegistry.get_checkpoint_scanner() self.service = CheckpointService(checkpoint_scanner) - self.civitai_client = await ServiceRegistry.get_civitai_client() + self.metadata_provider = await ModelMetadataProviderManager.get_instance() # Initialize parent with the service super().__init__(self.service) @@ -66,7 +67,7 @@ class CheckpointRoutes(BaseModelRoutes): """Get available versions for a Civitai checkpoint model with local availability info""" try: model_id = request.match_info['model_id'] - response = await self.civitai_client.get_model_versions(model_id) + response = await self.metadata_provider.get_model_versions(model_id) if not response or not response.get('modelVersions'): return web.Response(status=404, text="Model not found") diff --git a/py/routes/embedding_routes.py b/py/routes/embedding_routes.py index eb9a5203..65a66824 100644 --- a/py/routes/embedding_routes.py +++ b/py/routes/embedding_routes.py @@ -4,6 +4,7 @@ from aiohttp import web from .base_model_routes import BaseModelRoutes from ..services.embedding_service import EmbeddingService from ..services.service_registry import ServiceRegistry +from ..services.model_metadata_provider import ModelMetadataProviderManager logger = logging.getLogger(__name__) @@ -14,14 +15,14 @@ class EmbeddingRoutes(BaseModelRoutes): """Initialize Embedding routes with Embedding service""" # Service will be initialized later via setup_routes self.service = None - self.civitai_client = None + self.metadata_provider = None self.template_name = "embeddings.html" async def initialize_services(self): """Initialize services from ServiceRegistry""" embedding_scanner = await ServiceRegistry.get_embedding_scanner() self.service = EmbeddingService(embedding_scanner) - self.civitai_client = await ServiceRegistry.get_civitai_client() + self.metadata_provider = await ModelMetadataProviderManager.get_instance() # Initialize parent with the service super().__init__(self.service) @@ -61,7 +62,7 @@ class EmbeddingRoutes(BaseModelRoutes): """Get available versions for a Civitai embedding model with local availability info""" try: model_id = request.match_info['model_id'] - response = await self.civitai_client.get_model_versions(model_id) + response = await self.metadata_provider.get_model_versions(model_id) if not response or not response.get('modelVersions'): return web.Response(status=404, text="Model not found") diff --git a/py/routes/lora_routes.py b/py/routes/lora_routes.py index 2da33cb1..4c1c0467 100644 --- a/py/routes/lora_routes.py +++ b/py/routes/lora_routes.py @@ -7,6 +7,7 @@ from server import PromptServer # type: ignore from .base_model_routes import BaseModelRoutes from ..services.lora_service import LoraService from ..services.service_registry import ServiceRegistry +from ..services.model_metadata_provider import ModelMetadataProviderManager from ..utils.routes_common import ModelRouteUtils from ..utils.utils import get_lora_info @@ -19,14 +20,14 @@ class LoraRoutes(BaseModelRoutes): """Initialize LoRA routes with LoRA service""" # Service will be initialized later via setup_routes self.service = None - self.civitai_client = None + self.metadata_provider = None self.template_name = "loras.html" async def initialize_services(self): """Initialize services from ServiceRegistry""" lora_scanner = await ServiceRegistry.get_lora_scanner() self.service = LoraService(lora_scanner) - self.civitai_client = await ServiceRegistry.get_civitai_client() + self.metadata_provider = await ModelMetadataProviderManager.get_instance() # Initialize parent with the service super().__init__(self.service) @@ -217,7 +218,7 @@ class LoraRoutes(BaseModelRoutes): """Get available versions for a Civitai LoRA model with local availability info""" try: model_id = request.match_info['model_id'] - response = await self.civitai_client.get_model_versions(model_id) + response = await self.metadata_provider.get_model_versions(model_id) if not response or not response.get('modelVersions'): return web.Response(status=404, text="Model not found") @@ -261,8 +262,8 @@ class LoraRoutes(BaseModelRoutes): try: model_version_id = request.match_info.get('modelVersionId') - # Get model details from Civitai API - model, error_msg = await self.civitai_client.get_model_version_info(model_version_id) + # Get model details from metadata provider + model, error_msg = await self.metadata_provider.get_model_version_info(model_version_id) if not model: # Log warning for failed model retrieval @@ -288,7 +289,7 @@ class LoraRoutes(BaseModelRoutes): """Get CivitAI model details by hash""" try: hash = request.match_info.get('hash') - model = await self.civitai_client.get_model_by_hash(hash) + model = await self.metadata_provider.get_model_by_hash(hash) return web.json_response(model) except Exception as e: logger.error(f"Error fetching model details by hash: {e}") diff --git a/py/routes/misc_routes.py b/py/routes/misc_routes.py index 87a16aac..9a29a24d 100644 --- a/py/routes/misc_routes.py +++ b/py/routes/misc_routes.py @@ -11,6 +11,8 @@ from ..utils.lora_metadata import extract_trained_words from ..config import config from ..utils.constants import SUPPORTED_MEDIA_EXTENSIONS, NODE_TYPES, DEFAULT_NODE_COLOR from ..services.service_registry import ServiceRegistry +from ..services.metadata_service import get_metadata_archive_manager, update_metadata_provider_priority +from ..services.websocket_manager import ws_manager import re logger = logging.getLogger(__name__) @@ -112,6 +114,11 @@ class MiscRoutes: # Add new route for checking if a model exists in the library app.router.add_get('/api/check-model-exists', MiscRoutes.check_model_exists) + + # Add routes for metadata archive database management + app.router.add_post('/api/download-metadata-archive', MiscRoutes.download_metadata_archive) + app.router.add_post('/api/remove-metadata-archive', MiscRoutes.remove_metadata_archive) + app.router.add_get('/api/metadata-archive-status', MiscRoutes.get_metadata_archive_status) @staticmethod async def clear_cache(request): @@ -697,3 +704,108 @@ class MiscRoutes: 'success': False, 'error': str(e) }, status=500) + + @staticmethod + async def download_metadata_archive(request): + """Download and extract the metadata archive database""" + try: + archive_manager = await get_metadata_archive_manager() + + # Progress callback to send updates via WebSocket + def progress_callback(stage, message): + asyncio.create_task(ws_manager.broadcast({ + 'stage': stage, + 'message': message, + 'type': 'metadata_archive_download' + })) + + # Download and extract in background + success = await archive_manager.download_and_extract_database(progress_callback) + + if success: + # Update settings to enable metadata archive + settings.set('enable_metadata_archive_db', True) + + # Update provider priority + await update_metadata_provider_priority() + + return web.json_response({ + 'success': True, + 'message': 'Metadata archive database downloaded and extracted successfully' + }) + else: + return web.json_response({ + 'success': False, + 'error': 'Failed to download and extract metadata archive database' + }, status=500) + + except Exception as e: + logger.error(f"Error downloading metadata archive: {e}", exc_info=True) + return web.json_response({ + 'success': False, + 'error': str(e) + }, status=500) + + @staticmethod + async def remove_metadata_archive(request): + """Remove the metadata archive database""" + try: + archive_manager = await get_metadata_archive_manager() + + success = await archive_manager.remove_database() + + if success: + # Update settings to disable metadata archive + settings.set('enable_metadata_archive_db', False) + + # Update provider priority + await update_metadata_provider_priority() + + return web.json_response({ + 'success': True, + 'message': 'Metadata archive database removed successfully' + }) + else: + return web.json_response({ + 'success': False, + 'error': 'Failed to remove metadata archive database' + }, status=500) + + except Exception as e: + logger.error(f"Error removing metadata archive: {e}", exc_info=True) + return web.json_response({ + 'success': False, + 'error': str(e) + }, status=500) + + @staticmethod + async def get_metadata_archive_status(request): + """Get the status of metadata archive database""" + try: + archive_manager = await get_metadata_archive_manager() + + is_available = archive_manager.is_database_available() + is_enabled = settings.get('enable_metadata_archive_db', False) + priority = settings.get('metadata_provider_priority', 'archive_db') + + db_size = 0 + if is_available: + db_path = archive_manager.get_database_path() + if db_path and os.path.exists(db_path): + db_size = os.path.getsize(db_path) + + return web.json_response({ + 'success': True, + 'isAvailable': is_available, + 'isEnabled': is_enabled, + 'priority': priority, + 'databaseSize': db_size, + 'databasePath': archive_manager.get_database_path() if is_available else None + }) + + except Exception as e: + logger.error(f"Error getting metadata archive status: {e}", exc_info=True) + return web.json_response({ + 'success': False, + 'error': str(e) + }, status=500) diff --git a/py/services/metadata_archive_manager.py b/py/services/metadata_archive_manager.py new file mode 100644 index 00000000..3daf761b --- /dev/null +++ b/py/services/metadata_archive_manager.py @@ -0,0 +1,150 @@ +import zipfile +import aiohttp +import logging +import asyncio +from pathlib import Path +from typing import Optional + +logger = logging.getLogger(__name__) + +class MetadataArchiveManager: + """Manages downloading and extracting Civitai metadata archive database""" + + DOWNLOAD_URLS = [ + "https://github.com/willmiao/civitai-metadata-archive-db/releases/download/db-2025-08-08/civitai.zip", + "https://huggingface.co/datasets/willmiao/civitai-metadata-archive-db/blob/main/civitai.zip" + ] + + def __init__(self, base_path: str): + """Initialize with base path where files will be stored""" + self.base_path = Path(base_path) + self.civitai_folder = self.base_path / "civitai" + self.archive_path = self.base_path / "civitai.zip" + self.db_path = self.civitai_folder / "civitai.sqlite" + + def is_database_available(self) -> bool: + """Check if the SQLite database is available and valid""" + return self.db_path.exists() and self.db_path.stat().st_size > 0 + + def get_database_path(self) -> Optional[str]: + """Get the path to the SQLite database if available""" + if self.is_database_available(): + return str(self.db_path) + return None + + async def download_and_extract_database(self, progress_callback=None) -> bool: + """Download and extract the metadata archive database + + Args: + progress_callback: Optional callback function to report progress + + Returns: + bool: True if successful, False otherwise + """ + try: + # Create directories if they don't exist + self.base_path.mkdir(parents=True, exist_ok=True) + self.civitai_folder.mkdir(parents=True, exist_ok=True) + + # Download the archive + if not await self._download_archive(progress_callback): + return False + + # Extract the archive + if not await self._extract_archive(progress_callback): + return False + + # Clean up the archive file + if self.archive_path.exists(): + self.archive_path.unlink() + + logger.info(f"Successfully downloaded and extracted metadata database to {self.db_path}") + return True + + except Exception as e: + logger.error(f"Error downloading and extracting metadata database: {e}", exc_info=True) + return False + + async def _download_archive(self, progress_callback=None) -> bool: + """Download the zip archive from one of the available URLs""" + for url in self.DOWNLOAD_URLS: + try: + logger.info(f"Attempting to download from {url}") + + if progress_callback: + progress_callback("download", f"Downloading from {url}") + + async with aiohttp.ClientSession() as session: + async with session.get(url) as response: + if response.status == 200: + total_size = int(response.headers.get('content-length', 0)) + downloaded = 0 + + with open(self.archive_path, 'wb') as f: + async for chunk in response.content.iter_chunked(8192): + f.write(chunk) + downloaded += len(chunk) + + if progress_callback and total_size > 0: + percentage = (downloaded / total_size) * 100 + progress_callback("download", f"Downloaded {percentage:.1f}%") + + logger.info(f"Successfully downloaded archive from {url}") + return True + else: + logger.warning(f"Failed to download from {url}: HTTP {response.status}") + continue + + except Exception as e: + logger.warning(f"Error downloading from {url}: {e}") + continue + + logger.error("Failed to download archive from any URL") + return False + + async def _extract_archive(self, progress_callback=None) -> bool: + """Extract the zip archive to the civitai folder""" + try: + if progress_callback: + progress_callback("extract", "Extracting archive...") + + # Run extraction in thread pool to avoid blocking + loop = asyncio.get_event_loop() + await loop.run_in_executor(None, self._extract_zip_sync) + + if progress_callback: + progress_callback("extract", "Extraction completed") + + return True + + except Exception as e: + logger.error(f"Error extracting archive: {e}", exc_info=True) + return False + + def _extract_zip_sync(self): + """Synchronous zip extraction (runs in thread pool)""" + with zipfile.ZipFile(self.archive_path, 'r') as archive: + archive.extractall(path=self.base_path) + + async def remove_database(self) -> bool: + """Remove the metadata database and folder""" + try: + if self.civitai_folder.exists(): + # Remove all files in the civitai folder + for file_path in self.civitai_folder.iterdir(): + if file_path.is_file(): + file_path.unlink() + + # Remove the folder itself + self.civitai_folder.rmdir() + + # Also remove the archive file if it exists + if self.archive_path.exists(): + self.archive_path.unlink() + + logger.info("Successfully removed metadata database") + return True + + except Exception as e: + logger.error(f"Error removing metadata database: {e}", exc_info=True) + return False diff --git a/py/services/metadata_service.py b/py/services/metadata_service.py index beb8912d..0e5d9199 100644 --- a/py/services/metadata_service.py +++ b/py/services/metadata_service.py @@ -1,28 +1,97 @@ import os import logging -from .model_metadata_provider import ModelMetadataProviderManager, SQLiteModelMetadataProvider +from .model_metadata_provider import ( + ModelMetadataProviderManager, + SQLiteModelMetadataProvider, + CivitaiModelMetadataProvider, + FallbackMetadataProvider +) +from .settings_manager import settings +from .metadata_archive_manager import MetadataArchiveManager +from .service_registry import ServiceRegistry logger = logging.getLogger(__name__) async def initialize_metadata_providers(): - """Initialize and configure all metadata providers""" + """Initialize and configure all metadata providers based on settings""" provider_manager = await ModelMetadataProviderManager.get_instance() - # Use hardcoded SQLite DB path if not set in settings - db_path = os.path.join( - os.path.dirname(os.path.dirname(os.path.dirname(__file__))), - 'civitai', 'civitai.sqlite' - ) - if db_path and os.path.exists(db_path): - try: - sqlite_provider = SQLiteModelMetadataProvider(db_path) - provider_manager.register_provider('sqlite', sqlite_provider) - logger.info(f"SQLite metadata provider registered with database: {db_path}") - except Exception as e: - logger.error(f"Failed to initialize SQLite metadata provider: {e}") + # Get settings + enable_archive_db = settings.get('enable_metadata_archive_db', False) + priority = settings.get('metadata_provider_priority', 'archive_db') + + providers = [] + + # Initialize archive database provider if enabled + if enable_archive_db: + # Initialize archive manager + base_path = os.path.dirname(os.path.dirname(os.path.dirname(__file__))) + archive_manager = MetadataArchiveManager(base_path) + + db_path = archive_manager.get_database_path() + if db_path: + try: + sqlite_provider = SQLiteModelMetadataProvider(db_path) + provider_manager.register_provider('sqlite', sqlite_provider) + providers.append(('sqlite', sqlite_provider)) + logger.info(f"SQLite metadata provider registered with database: {db_path}") + except Exception as e: + logger.error(f"Failed to initialize SQLite metadata provider: {e}") + else: + logger.warning("Metadata archive database is enabled but not available") + + # Initialize Civitai API provider + try: + civitai_client = await ServiceRegistry.get_civitai_client() + civitai_provider = CivitaiModelMetadataProvider(civitai_client) + provider_manager.register_provider('civitai_api', civitai_provider) + providers.append(('civitai_api', civitai_provider)) + logger.info("Civitai API metadata provider registered") + except Exception as e: + logger.error(f"Failed to initialize Civitai API metadata provider: {e}") + + # Set up fallback provider based on priority + if len(providers) > 1: + # Order providers based on priority setting + if priority == 'archive_db': + # Archive DB first, then Civitai API + ordered_providers = [p[1] for p in providers if p[0] == 'sqlite'] + [p[1] for p in providers if p[0] == 'civitai_api'] + else: + # Civitai API first, then Archive DB + ordered_providers = [p[1] for p in providers if p[0] == 'civitai_api'] + [p[1] for p in providers if p[0] == 'sqlite'] + + if ordered_providers: + fallback_provider = FallbackMetadataProvider(ordered_providers) + provider_manager.register_provider('fallback', fallback_provider, is_default=True) + logger.info(f"Fallback metadata provider registered with priority: {priority}") + elif len(providers) == 1: + # Only one provider available, set it as default + provider_name, provider = providers[0] + provider_manager.register_provider(provider_name, provider, is_default=True) + logger.info(f"Single metadata provider registered as default: {provider_name}") + else: + logger.warning("No metadata providers available") return provider_manager +async def update_metadata_provider_priority(): + """Update metadata provider priority based on current settings""" + provider_manager = await ModelMetadataProviderManager.get_instance() + + # Get current settings + enable_archive_db = settings.get('enable_metadata_archive_db', False) + priority = settings.get('metadata_provider_priority', 'archive_db') + + # Rebuild providers with new priority + await initialize_metadata_providers() + + logger.info(f"Updated metadata provider priority to: {priority}") + +async def get_metadata_archive_manager(): + """Get metadata archive manager instance""" + base_path = os.path.dirname(os.path.dirname(os.path.dirname(__file__))) + return MetadataArchiveManager(base_path) + async def get_metadata_provider(provider_name: str = None): """Get a specific metadata provider or default provider""" provider_manager = await ModelMetadataProviderManager.get_instance() diff --git a/py/services/model_metadata_provider.py b/py/services/model_metadata_provider.py index 95dfc1da..2a30df11 100644 --- a/py/services/model_metadata_provider.py +++ b/py/services/model_metadata_provider.py @@ -297,7 +297,8 @@ class FallbackMetadataProvider(ModelMetadataProvider): result = await provider.get_model_versions(model_id) if result: return result - except Exception: + except Exception as e: + logger.debug(f"Provider failed for get_model_versions: {e}") continue return None @@ -307,27 +308,30 @@ class FallbackMetadataProvider(ModelMetadataProvider): result = await provider.get_model_version(model_id, version_id) if result: return result - except Exception: + except Exception as e: + logger.debug(f"Provider failed for get_model_version: {e}") continue return None async def get_model_version_info(self, version_id: str) -> Tuple[Optional[Dict], Optional[str]]: for provider in self.providers: try: - result, err = await provider.get_model_version_info(version_id) + result, error = await provider.get_model_version_info(version_id) if result: - return result, err - except Exception: + return result, error + except Exception as e: + logger.debug(f"Provider failed for get_model_version_info: {e}") continue - return None, "Not found in any provider" + return None, "No provider could retrieve the data" async def get_model_metadata(self, model_id: str) -> Tuple[Optional[Dict], int]: for provider in self.providers: try: - result, code = await provider.get_model_metadata(model_id) + result, status = await provider.get_model_metadata(model_id) if result: - return result, code - except Exception: + return result, status + except Exception as e: + logger.debug(f"Provider failed for get_model_metadata: {e}") continue return None, 404 diff --git a/py/services/settings_manager.py b/py/services/settings_manager.py index 53146613..0b86ce82 100644 --- a/py/services/settings_manager.py +++ b/py/services/settings_manager.py @@ -81,7 +81,9 @@ class SettingsManager: return { "civitai_api_key": "", "show_only_sfw": False, - "language": "en" # 添加默认语言设置 + "language": "en", # 添加默认语言设置 + "enable_metadata_archive_db": False, # Enable metadata archive database + "metadata_provider_priority": "archive_db" # Default priority: 'archive_db' or 'civitai_api' } def get(self, key: str, default: Any = None) -> Any: diff --git a/static/js/managers/SettingsManager.js b/static/js/managers/SettingsManager.js index b31613e2..1a15bdb4 100644 --- a/static/js/managers/SettingsManager.js +++ b/static/js/managers/SettingsManager.js @@ -260,6 +260,9 @@ export class SettingsManager { includeTriggerWordsCheckbox.checked = state.global.settings.includeTriggerWords || false; } + // Load metadata archive settings + await this.loadMetadataArchiveSettings(); + // Load base model path mappings this.loadBaseModelMappings(); @@ -838,6 +841,11 @@ export class SettingsManager { state: value ? 'toast.settings.compactModeEnabled' : 'toast.settings.compactModeDisabled' }, 'success'); } + + // Special handling for metadata archive settings + if (settingKey === 'enable_metadata_archive_db' || settingKey === 'metadata_provider_priority') { + await this.updateMetadataArchiveStatus(); + } } catch (error) { showToast('toast.settings.settingSaveFailed', { message: error.message }, 'error'); @@ -910,11 +918,192 @@ export class SettingsManager { showToast('toast.settings.displayDensitySet', { density: densityName }, 'success'); } + + // Special handling for metadata archive settings + if (settingKey === 'metadata_provider_priority') { + await this.updateMetadataArchiveStatus(); + } } catch (error) { showToast('toast.settings.settingSaveFailed', { message: error.message }, 'error'); } } + + async loadMetadataArchiveSettings() { + try { + // Load current settings from state + const enableMetadataArchiveCheckbox = document.getElementById('enableMetadataArchive'); + if (enableMetadataArchiveCheckbox) { + enableMetadataArchiveCheckbox.checked = state.global.settings.enable_metadata_archive_db || false; + } + + const metadataProviderPrioritySelect = document.getElementById('metadataProviderPriority'); + if (metadataProviderPrioritySelect) { + metadataProviderPrioritySelect.value = state.global.settings.metadata_provider_priority || 'archive_db'; + } + + // Load status + await this.updateMetadataArchiveStatus(); + } catch (error) { + console.error('Error loading metadata archive settings:', error); + } + } + + async updateMetadataArchiveStatus() { + try { + const response = await fetch('/api/metadata-archive-status'); + const data = await response.json(); + + const statusContainer = document.getElementById('metadataArchiveStatus'); + if (statusContainer && data.success) { + const status = data; + const sizeText = status.databaseSize > 0 ? ` (${this.formatFileSize(status.databaseSize)})` : ''; + + statusContainer.innerHTML = ` +
+
+ ${translate('settings.metadataArchive.status')}: + + ${status.isAvailable ? translate('settings.metadataArchive.statusAvailable') : translate('settings.metadataArchive.statusUnavailable')} + + ${sizeText} +
+
+ ${translate('settings.metadataArchive.enabled')}: + + ${status.isEnabled ? translate('common.enabled') : translate('common.disabled')} + +
+
+ ${translate('settings.metadataArchive.currentPriority')}: + ${status.priority === 'archive_db' ? translate('settings.metadataArchive.priorityArchiveDb') : translate('settings.metadataArchive.priorityCivitaiApi')} +
+
+ `; + + // Update button states + const downloadBtn = document.getElementById('downloadMetadataArchiveBtn'); + const removeBtn = document.getElementById('removeMetadataArchiveBtn'); + + if (downloadBtn) { + downloadBtn.disabled = status.isAvailable; + downloadBtn.textContent = status.isAvailable ? + translate('settings.metadataArchive.downloadedButton') : + translate('settings.metadataArchive.downloadButton'); + } + + if (removeBtn) { + removeBtn.disabled = !status.isAvailable; + } + } + } catch (error) { + console.error('Error updating metadata archive status:', error); + } + } + + formatFileSize(bytes) { + if (bytes === 0) return '0 Bytes'; + const k = 1024; + const sizes = ['Bytes', 'KB', 'MB', 'GB']; + const i = Math.floor(Math.log(bytes) / Math.log(k)); + return parseFloat((bytes / Math.pow(k, i)).toFixed(2)) + ' ' + sizes[i]; + } + + async downloadMetadataArchive() { + try { + const downloadBtn = document.getElementById('downloadMetadataArchiveBtn'); + if (downloadBtn) { + downloadBtn.disabled = true; + downloadBtn.textContent = translate('settings.metadataArchive.downloadingButton'); + } + + const response = await fetch('/api/download-metadata-archive', { + method: 'POST', + headers: { + 'Content-Type': 'application/json' + } + }); + + const data = await response.json(); + + if (data.success) { + showNotification(translate('settings.metadataArchive.downloadSuccess'), 'success'); + + // Update settings in state + state.global.settings.enable_metadata_archive_db = true; + setStorageItem('settings', state.global.settings); + + // Update UI + const enableCheckbox = document.getElementById('enableMetadataArchive'); + if (enableCheckbox) { + enableCheckbox.checked = true; + } + + await this.updateMetadataArchiveStatus(); + } else { + showNotification(translate('settings.metadataArchive.downloadError') + ': ' + data.error, 'error'); + } + } catch (error) { + console.error('Error downloading metadata archive:', error); + showNotification(translate('settings.metadataArchive.downloadError') + ': ' + error.message, 'error'); + } finally { + const downloadBtn = document.getElementById('downloadMetadataArchiveBtn'); + if (downloadBtn) { + downloadBtn.disabled = false; + downloadBtn.textContent = translate('settings.metadataArchive.downloadButton'); + } + } + } + + async removeMetadataArchive() { + if (!confirm(translate('settings.metadataArchive.removeConfirm'))) { + return; + } + + try { + const removeBtn = document.getElementById('removeMetadataArchiveBtn'); + if (removeBtn) { + removeBtn.disabled = true; + removeBtn.textContent = translate('settings.metadataArchive.removingButton'); + } + + const response = await fetch('/api/remove-metadata-archive', { + method: 'POST', + headers: { + 'Content-Type': 'application/json' + } + }); + + const data = await response.json(); + + if (data.success) { + showNotification(translate('settings.metadataArchive.removeSuccess'), 'success'); + + // Update settings in state + state.global.settings.enable_metadata_archive_db = false; + setStorageItem('settings', state.global.settings); + + // Update UI + const enableCheckbox = document.getElementById('enableMetadataArchive'); + if (enableCheckbox) { + enableCheckbox.checked = false; + } + + await this.updateMetadataArchiveStatus(); + } else { + showNotification(translate('settings.metadataArchive.removeError') + ': ' + data.error, 'error'); + } + } catch (error) { + console.error('Error removing metadata archive:', error); + showNotification(translate('settings.metadataArchive.removeError') + ': ' + error.message, 'error'); + } finally { + const removeBtn = document.getElementById('removeMetadataArchiveBtn'); + if (removeBtn) { + removeBtn.disabled = false; + removeBtn.textContent = translate('settings.metadataArchive.removeButton'); + } + } + } async saveInputSetting(elementId, settingKey) { const element = document.getElementById(elementId); diff --git a/templates/components/modals/settings_modal.html b/templates/components/modals/settings_modal.html index 2b9be5e7..a2f85716 100644 --- a/templates/components/modals/settings_modal.html +++ b/templates/components/modals/settings_modal.html @@ -419,6 +419,70 @@ + + +
+

{{ t('settings.sections.metadataArchive') }}

+ +
+
+
+ +
+
+ +
+
+
+ {{ t('settings.metadataArchive.enableArchiveDbHelp') }} +
+
+ +
+
+
+ +
+
+ +
+
+
+ {{ t('settings.metadataArchive.providerPriorityHelp') }} +
+
+ +
+ +
+ +
+
+
+ +
+
+ + +
+
+
+ {{ t('settings.metadataArchive.managementHelp') }} +
+
+
\ No newline at end of file From 14721c265f5d9bfc910474e08908fc66f96f1648 Mon Sep 17 00:00:00 2001 From: Will Miao <13051207myq@gmail.com> Date: Tue, 9 Sep 2025 10:34:14 +0800 Subject: [PATCH 04/13] Refactor download logic to use unified downloader service - Introduced a new `Downloader` class to centralize HTTP/HTTPS download management. - Replaced direct `aiohttp` session handling with the unified downloader in `MetadataArchiveManager`, `DownloadManager`, and `ExampleImagesProcessor`. - Added support for resumable downloads, progress tracking, and error handling in the new downloader. - Updated methods to utilize the downloader's capabilities for downloading files and images, improving code maintainability and readability. --- py/routes/update_routes.py | 127 ++--- py/services/civitai_client.py | 515 +++++++------------- py/services/downloader.py | 465 ++++++++++++++++++ py/services/metadata_archive_manager.py | 43 +- py/utils/example_images_download_manager.py | 61 +-- py/utils/example_images_processor.py | 80 +-- 6 files changed, 777 insertions(+), 514 deletions(-) create mode 100644 py/services/downloader.py diff --git a/py/routes/update_routes.py b/py/routes/update_routes.py index 9d60036e..66ef603a 100644 --- a/py/routes/update_routes.py +++ b/py/routes/update_routes.py @@ -1,5 +1,4 @@ import os -import aiohttp import logging import toml import git @@ -8,7 +7,7 @@ import shutil import tempfile from aiohttp import web from typing import Dict, List - +from ..services.downloader import get_downloader, Downloader logger = logging.getLogger(__name__) @@ -162,28 +161,42 @@ class UpdateRoutes: github_api = f"https://api.github.com/repos/{repo_owner}/{repo_name}/releases/latest" try: - async with aiohttp.ClientSession() as session: - async with session.get(github_api) as resp: - if resp.status != 200: - logger.error(f"Failed to fetch release info: {resp.status}") - return False, "" - data = await resp.json() - zip_url = data.get("zipball_url") - version = data.get("tag_name", "unknown") + downloader = await get_downloader() + + # Get release info + success, data = await downloader.make_request( + 'GET', + github_api, + use_auth=False + ) + if not success: + logger.error(f"Failed to fetch release info: {data}") + return False, "" + + zip_url = data.get("zipball_url") + version = data.get("tag_name", "unknown") - # Download ZIP - async with session.get(zip_url) as zip_resp: - if zip_resp.status != 200: - logger.error(f"Failed to download ZIP: {zip_resp.status}") - return False, "" - with tempfile.NamedTemporaryFile(delete=False, suffix=".zip") as tmp_zip: - tmp_zip.write(await zip_resp.read()) - zip_path = tmp_zip.name + # Download ZIP to temporary file + with tempfile.NamedTemporaryFile(delete=False, suffix=".zip") as tmp_zip: + tmp_zip_path = tmp_zip.name + + success, result = await downloader.download_file( + url=zip_url, + save_path=tmp_zip_path, + use_auth=False, + allow_resume=False + ) + + if not success: + logger.error(f"Failed to download ZIP: {result}") + return False, "" - UpdateRoutes._clean_plugin_folder(plugin_root, skip_files=['settings.json']) + zip_path = tmp_zip_path - # Extract ZIP to temp dir - with tempfile.TemporaryDirectory() as tmp_dir: + UpdateRoutes._clean_plugin_folder(plugin_root, skip_files=['settings.json']) + + # Extract ZIP to temp dir + with tempfile.TemporaryDirectory() as tmp_dir: with zipfile.ZipFile(zip_path, 'r') as zip_ref: zip_ref.extractall(tmp_dir) # Find extracted folder (GitHub ZIP contains a root folder) @@ -213,9 +226,9 @@ class UpdateRoutes: with open(tracking_info_file, "w", encoding='utf-8') as file: file.write('\n'.join(tracking_files)) - os.remove(zip_path) - logger.info(f"Updated plugin via ZIP to {version}") - return True, version + os.remove(zip_path) + logger.info(f"Updated plugin via ZIP to {version}") + return True, version except Exception as e: logger.error(f"ZIP update failed: {e}", exc_info=True) @@ -244,23 +257,23 @@ class UpdateRoutes: github_url = f"https://api.github.com/repos/{repo_owner}/{repo_name}/commits/main" try: - async with aiohttp.ClientSession() as session: - async with session.get(github_url, headers={'Accept': 'application/vnd.github+json'}) as response: - if response.status != 200: - logger.warning(f"Failed to fetch GitHub commit: {response.status}") - return "main", [] - - data = await response.json() - commit_sha = data.get('sha', '')[:7] # Short hash - commit_message = data.get('commit', {}).get('message', '') - - # Format as "main-{short_hash}" - version = f"main-{commit_sha}" - - # Use commit message as changelog - changelog = [commit_message] if commit_message else [] - - return version, changelog + downloader = await Downloader.get_instance() + success, data = await downloader.make_request('GET', github_url, headers={'Accept': 'application/vnd.github+json'}) + + if not success: + logger.warning(f"Failed to fetch GitHub commit: {data}") + return "main", [] + + commit_sha = data.get('sha', '')[:7] # Short hash + commit_message = data.get('commit', {}).get('message', '') + + # Format as "main-{short_hash}" + version = f"main-{commit_sha}" + + # Use commit message as changelog + changelog = [commit_message] if commit_message else [] + + return version, changelog except Exception as e: logger.error(f"Error fetching nightly version: {e}", exc_info=True) @@ -410,22 +423,22 @@ class UpdateRoutes: github_url = f"https://api.github.com/repos/{repo_owner}/{repo_name}/releases/latest" try: - async with aiohttp.ClientSession() as session: - async with session.get(github_url, headers={'Accept': 'application/vnd.github+json'}) as response: - if response.status != 200: - logger.warning(f"Failed to fetch GitHub release: {response.status}") - return "v0.0.0", [] - - data = await response.json() - version = data.get('tag_name', '') - if not version.startswith('v'): - version = f"v{version}" - - # Extract changelog from release notes - body = data.get('body', '') - changelog = UpdateRoutes._parse_changelog(body) - - return version, changelog + downloader = await Downloader.get_instance() + success, data = await downloader.make_request('GET', github_url, headers={'Accept': 'application/vnd.github+json'}) + + if not success: + logger.warning(f"Failed to fetch GitHub release: {data}") + return "v0.0.0", [] + + version = data.get('tag_name', '') + if not version.startswith('v'): + version = f"v{version}" + + # Extract changelog from release notes + body = data.get('body', '') + changelog = UpdateRoutes._parse_changelog(body) + + return version, changelog except Exception as e: logger.error(f"Error fetching remote version: {e}", exc_info=True) diff --git a/py/services/civitai_client.py b/py/services/civitai_client.py index 64a6f2f5..6f87fbe4 100644 --- a/py/services/civitai_client.py +++ b/py/services/civitai_client.py @@ -1,10 +1,10 @@ from datetime import datetime -import aiohttp import os import logging import asyncio from typing import Optional, Dict, Tuple, List from .model_metadata_provider import CivitaiModelMetadataProvider, ModelMetadataProviderManager +from .downloader import get_downloader logger = logging.getLogger(__name__) @@ -32,61 +32,7 @@ class CivitaiClient: self._initialized = True self.base_url = "https://civitai.com/api/v1" - self.headers = { - 'User-Agent': 'ComfyUI-LoRA-Manager/1.0' - } - self._session = None - self._session_created_at = None - # Adjust chunk size based on storage type - consider making this configurable - self.chunk_size = 4 * 1024 * 1024 # 4MB chunks for better HDD throughput - @property - async def session(self) -> aiohttp.ClientSession: - """Lazy initialize the session""" - if self._session is None: - # Optimize TCP connection parameters - connector = aiohttp.TCPConnector( - ssl=True, - limit=8, # Increase from 3 to 8 for better parallelism - ttl_dns_cache=300, # Enable DNS caching with reasonable timeout - force_close=False, # Keep connections for reuse - enable_cleanup_closed=True - ) - trust_env = True # Allow using system environment proxy settings - # Configure timeout parameters - increase read timeout for large files and remove sock_read timeout - timeout = aiohttp.ClientTimeout(total=None, connect=60, sock_read=None) - self._session = aiohttp.ClientSession( - connector=connector, - trust_env=trust_env, - timeout=timeout - ) - self._session_created_at = datetime.now() - return self._session - - async def _ensure_fresh_session(self): - """Refresh session if it's been open too long""" - if self._session is not None: - if not hasattr(self, '_session_created_at') or \ - (datetime.now() - self._session_created_at).total_seconds() > 300: # 5 minutes - await self.close() - self._session = None - - return await self.session - - def _get_request_headers(self) -> dict: - """Get request headers with optional API key""" - headers = { - 'User-Agent': 'ComfyUI-LoRA-Manager/1.0', - 'Content-Type': 'application/json' - } - - from .settings_manager import settings - api_key = settings.get('civitai_api_key') - if (api_key): - headers['Authorization'] = f'Bearer {api_key}' - - return headers - async def download_file(self, url: str, save_dir: str, default_filename: str, progress_callback=None) -> Tuple[bool, str]: """Download file with resumable downloads and retry mechanism @@ -99,214 +45,69 @@ class CivitaiClient: Returns: Tuple[bool, str]: (success, save_path or error message) """ - max_retries = 5 - retry_count = 0 - base_delay = 2.0 # Base delay for exponential backoff - - # Initial setup - session = await self._ensure_fresh_session() + downloader = await get_downloader() save_path = os.path.join(save_dir, default_filename) - part_path = save_path + '.part' - # Get existing file size for resume - resume_offset = 0 - if os.path.exists(part_path): - resume_offset = os.path.getsize(part_path) - logger.info(f"Resuming download from offset {resume_offset} bytes") + # Use unified downloader with CivitAI authentication + success, result = await downloader.download_file( + url=url, + save_path=save_path, + progress_callback=progress_callback, + use_auth=True, # Enable CivitAI authentication + allow_resume=True + ) - total_size = 0 - - while retry_count <= max_retries: - try: - headers = self._get_request_headers() - - # Add Range header for resume if we have partial data - if resume_offset > 0: - headers['Range'] = f'bytes={resume_offset}-' - - # Add Range header to allow resumable downloads - headers['Accept-Encoding'] = 'identity' # Disable compression for better chunked downloads - - logger.debug(f"Download attempt {retry_count + 1}/{max_retries + 1} from: {url}") - if resume_offset > 0: - logger.debug(f"Requesting range from byte {resume_offset}") - - async with session.get(url, headers=headers, allow_redirects=True) as response: - # Handle different response codes - if response.status == 200: - # Full content response - if resume_offset > 0: - # Server doesn't support ranges, restart from beginning - logger.warning("Server doesn't support range requests, restarting download") - resume_offset = 0 - if os.path.exists(part_path): - os.remove(part_path) - elif response.status == 206: - # Partial content response (resume successful) - content_range = response.headers.get('Content-Range') - if content_range: - # Parse total size from Content-Range header (e.g., "bytes 1024-2047/2048") - range_parts = content_range.split('/') - if len(range_parts) == 2: - total_size = int(range_parts[1]) - logger.info(f"Successfully resumed download from byte {resume_offset}") - elif response.status == 416: - # Range not satisfiable - file might be complete or corrupted - if os.path.exists(part_path): - part_size = os.path.getsize(part_path) - logger.warning(f"Range not satisfiable. Part file size: {part_size}") - # Try to get actual file size - head_response = await session.head(url, headers=self._get_request_headers()) - if head_response.status == 200: - actual_size = int(head_response.headers.get('content-length', 0)) - if part_size == actual_size: - # File is complete, just rename it - os.rename(part_path, save_path) - if progress_callback: - await progress_callback(100) - return True, save_path - # Remove corrupted part file and restart - os.remove(part_path) - resume_offset = 0 - continue - elif response.status == 401: - logger.warning(f"Unauthorized access to resource: {url} (Status 401)") - return False, "Invalid or missing CivitAI API key, or early access restriction." - elif response.status == 403: - logger.warning(f"Forbidden access to resource: {url} (Status 403)") - return False, "Access forbidden: You don't have permission to download this file." - else: - logger.error(f"Download failed for {url} with status {response.status}") - return False, f"Download failed with status {response.status}" - - # Get total file size for progress calculation (if not set from Content-Range) - if total_size == 0: - total_size = int(response.headers.get('content-length', 0)) - if response.status == 206: - # For partial content, add the offset to get total file size - total_size += resume_offset - - current_size = resume_offset - last_progress_report_time = datetime.now() - - # Stream download to file with progress updates using larger buffer - loop = asyncio.get_running_loop() - mode = 'ab' if resume_offset > 0 else 'wb' - with open(part_path, mode) as f: - async for chunk in response.content.iter_chunked(self.chunk_size): - if chunk: - # Run blocking file write in executor - await loop.run_in_executor(None, f.write, chunk) - current_size += len(chunk) - - # Limit progress update frequency to reduce overhead - now = datetime.now() - time_diff = (now - last_progress_report_time).total_seconds() - - if progress_callback and total_size and time_diff >= 1.0: - progress = (current_size / total_size) * 100 - await progress_callback(progress) - last_progress_report_time = now - - # Download completed successfully - # Verify file size if total_size was provided - final_size = os.path.getsize(part_path) - if total_size > 0 and final_size != total_size: - logger.warning(f"File size mismatch. Expected: {total_size}, Got: {final_size}") - # Don't treat this as fatal error, rename anyway - - # Atomically rename .part to final file with retries - max_rename_attempts = 5 - rename_attempt = 0 - rename_success = False - - while rename_attempt < max_rename_attempts and not rename_success: - try: - os.rename(part_path, save_path) - rename_success = True - except PermissionError as e: - rename_attempt += 1 - if rename_attempt < max_rename_attempts: - logger.info(f"File still in use, retrying rename in 2 seconds (attempt {rename_attempt}/{max_rename_attempts})") - await asyncio.sleep(2) # Wait before retrying - else: - logger.error(f"Failed to rename file after {max_rename_attempts} attempts: {e}") - return False, f"Failed to finalize download: {str(e)}" - - # Ensure 100% progress is reported - if progress_callback: - await progress_callback(100) - - return True, save_path - - except (aiohttp.ClientError, aiohttp.ClientPayloadError, - aiohttp.ServerDisconnectedError, asyncio.TimeoutError) as e: - retry_count += 1 - logger.warning(f"Network error during download (attempt {retry_count}/{max_retries + 1}): {e}") - - if retry_count <= max_retries: - # Calculate delay with exponential backoff - delay = base_delay * (2 ** (retry_count - 1)) - logger.info(f"Retrying in {delay} seconds...") - await asyncio.sleep(delay) - - # Update resume offset for next attempt - if os.path.exists(part_path): - resume_offset = os.path.getsize(part_path) - logger.info(f"Will resume from byte {resume_offset}") - - # Refresh session to get new connection - await self.close() - session = await self._ensure_fresh_session() - continue - else: - logger.error(f"Max retries exceeded for download: {e}") - return False, f"Network error after {max_retries + 1} attempts: {str(e)}" - - except Exception as e: - logger.error(f"Unexpected download error: {e}") - return False, str(e) - - return False, f"Download failed after {max_retries + 1} attempts" + return success, result async def get_model_by_hash(self, model_hash: str) -> Optional[Dict]: try: - session = await self._ensure_fresh_session() - async with session.get(f"{self.base_url}/model-versions/by-hash/{model_hash}") as response: - if response.status == 200: - return await response.json() - return None + downloader = await get_downloader() + success, result = await downloader.make_request( + 'GET', + f"{self.base_url}/model-versions/by-hash/{model_hash}", + use_auth=True + ) + if success: + return result + return None except Exception as e: logger.error(f"API Error: {str(e)}") return None async def download_preview_image(self, image_url: str, save_path: str): try: - session = await self._ensure_fresh_session() - async with session.get(image_url) as response: - if response.status == 200: - content = await response.read() - with open(save_path, 'wb') as f: - f.write(content) - return True - return False + downloader = await get_downloader() + success, content = await downloader.download_to_memory( + image_url, + use_auth=False # Preview images don't need auth + ) + if success: + # Ensure directory exists + os.makedirs(os.path.dirname(save_path), exist_ok=True) + with open(save_path, 'wb') as f: + f.write(content) + return True + return False except Exception as e: - print(f"Download Error: {str(e)}") + logger.error(f"Download Error: {str(e)}") return False async def get_model_versions(self, model_id: str) -> List[Dict]: """Get all versions of a model with local availability info""" try: - session = await self._ensure_fresh_session() # Use fresh session - async with session.get(f"{self.base_url}/models/{model_id}") as response: - if response.status != 200: - return None - data = await response.json() + downloader = await get_downloader() + success, result = await downloader.make_request( + 'GET', + f"{self.base_url}/models/{model_id}", + use_auth=True + ) + if success: # Also return model type along with versions return { - 'modelVersions': data.get('modelVersions', []), - 'type': data.get('type', '') + 'modelVersions': result.get('modelVersions', []), + 'type': result.get('type', '') } + return None except Exception as e: logger.error(f"Error fetching model versions: {e}") return None @@ -322,68 +123,74 @@ class CivitaiClient: Optional[Dict]: The model version data with additional fields or None if not found """ try: - session = await self._ensure_fresh_session() - headers = self._get_request_headers() + downloader = await get_downloader() # Case 1: Only version_id is provided if model_id is None and version_id is not None: # First get the version info to extract model_id - async with session.get(f"{self.base_url}/model-versions/{version_id}", headers=headers) as response: - if response.status != 200: - return None - - version = await response.json() - model_id = version.get('modelId') - - if not model_id: - logger.error(f"No modelId found in version {version_id}") - return None + success, version = await downloader.make_request( + 'GET', + f"{self.base_url}/model-versions/{version_id}", + use_auth=True + ) + if not success: + return None + model_id = version.get('modelId') + if not model_id: + logger.error(f"No modelId found in version {version_id}") + return None + # Now get the model data for additional metadata - async with session.get(f"{self.base_url}/models/{model_id}") as response: - if response.status != 200: - return version # Return version without additional metadata - - model_data = await response.json() - + success, model_data = await downloader.make_request( + 'GET', + f"{self.base_url}/models/{model_id}", + use_auth=True + ) + if success: # Enrich version with model data version['model']['description'] = model_data.get("description") version['model']['tags'] = model_data.get("tags", []) version['creator'] = model_data.get("creator") - - return version + + return version # Case 2: model_id is provided (with or without version_id) elif model_id is not None: # Step 1: Get model data to find version_id if not provided and get additional metadata - async with session.get(f"{self.base_url}/models/{model_id}") as response: - if response.status != 200: - return None - - data = await response.json() - model_versions = data.get('modelVersions', []) + success, data = await downloader.make_request( + 'GET', + f"{self.base_url}/models/{model_id}", + use_auth=True + ) + if not success: + return None - # Step 2: Determine the version_id to use - target_version_id = version_id - if target_version_id is None: - target_version_id = model_versions[0].get('id') + model_versions = data.get('modelVersions', []) + # Step 2: Determine the version_id to use + target_version_id = version_id + if target_version_id is None: + target_version_id = model_versions[0].get('id') + # Step 3: Get detailed version info using the version_id - async with session.get(f"{self.base_url}/model-versions/{target_version_id}", headers=headers) as response: - if response.status != 200: - return None - - version = await response.json() - - # Step 4: Enrich version_info with model data - # Add description and tags from model data - version['model']['description'] = data.get("description") - version['model']['tags'] = data.get("tags", []) - - # Add creator from model data - version['creator'] = data.get("creator") - - return version + success, version = await downloader.make_request( + 'GET', + f"{self.base_url}/model-versions/{target_version_id}", + use_auth=True + ) + if not success: + return None + + # Step 4: Enrich version_info with model data + # Add description and tags from model data + version['model']['description'] = data.get("description") + version['model']['tags'] = data.get("tags", []) + + # Add creator from model data + version['creator'] = data.get("creator") + + return version # Case 3: Neither model_id nor version_id provided else: @@ -406,30 +213,29 @@ class CivitaiClient: - An error message if there was an error, or None on success """ try: - session = await self._ensure_fresh_session() + downloader = await get_downloader() url = f"{self.base_url}/model-versions/{version_id}" - headers = self._get_request_headers() logger.debug(f"Resolving DNS for model version info: {url}") - async with session.get(url, headers=headers) as response: - if response.status == 200: - logger.debug(f"Successfully fetched model version info for: {version_id}") - return await response.json(), None - - # Handle specific error cases - if response.status == 404: - # Try to parse the error message - try: - error_data = await response.json() - error_msg = error_data.get('error', f"Model not found (status 404)") - logger.warning(f"Model version not found: {version_id} - {error_msg}") - return None, error_msg - except: - return None, "Model not found (status 404)" - - # Other error cases - logger.error(f"Failed to fetch model info for {version_id} (status {response.status})") - return None, f"Failed to fetch model info (status {response.status})" + success, result = await downloader.make_request( + 'GET', + url, + use_auth=True + ) + + if success: + logger.debug(f"Successfully fetched model version info for: {version_id}") + return result, None + + # Handle specific error cases + if "404" in str(result): + error_msg = f"Model not found (status 404)" + logger.warning(f"Model version not found: {version_id} - {error_msg}") + return None, error_msg + + # Other error cases + logger.error(f"Failed to fetch model info for {version_id}: {result}") + return None, str(result) except Exception as e: error_msg = f"Error fetching model version info: {e}" logger.error(error_msg) @@ -444,48 +250,50 @@ class CivitaiClient: Returns: Tuple[Optional[Dict], int]: A tuple containing: - A dictionary with model metadata or None if not found - - The HTTP status code from the request + - The HTTP status code from the request (0 for exceptions) """ try: - session = await self._ensure_fresh_session() - headers = self._get_request_headers() + downloader = await get_downloader() url = f"{self.base_url}/models/{model_id}" - async with session.get(url, headers=headers) as response: - status_code = response.status - - if status_code != 200: - logger.warning(f"Failed to fetch model metadata: Status {status_code}") - return None, status_code - - data = await response.json() - - # Extract relevant metadata - metadata = { - "description": data.get("description") or "No model description available", - "tags": data.get("tags", []), - "creator": { - "username": data.get("creator", {}).get("username"), - "image": data.get("creator", {}).get("image") - } + success, result = await downloader.make_request( + 'GET', + url, + use_auth=True + ) + + if not success: + # Try to extract status code from error message + status_code = 0 + if "404" in str(result): + status_code = 404 + elif "401" in str(result): + status_code = 401 + elif "403" in str(result): + status_code = 403 + logger.warning(f"Failed to fetch model metadata: {result}") + return None, status_code + + # Extract relevant metadata + metadata = { + "description": result.get("description") or "No model description available", + "tags": result.get("tags", []), + "creator": { + "username": result.get("creator", {}).get("username"), + "image": result.get("creator", {}).get("image") } - - if metadata["description"] or metadata["tags"] or metadata["creator"]["username"]: - return metadata, status_code - else: - logger.warning(f"No metadata found for model {model_id}") - return None, status_code + } + + if metadata["description"] or metadata["tags"] or metadata["creator"]["username"]: + return metadata, 200 + else: + logger.warning(f"No metadata found for model {model_id}") + return None, 200 except Exception as e: logger.error(f"Error fetching model metadata: {e}", exc_info=True) return None, 0 - async def close(self): - """Close the session if it exists""" - if self._session is not None: - await self._session.close() - self._session = None - async def get_image_info(self, image_id: str) -> Optional[Dict]: """Fetch image information from Civitai API @@ -496,22 +304,25 @@ class CivitaiClient: Optional[Dict]: The image data or None if not found """ try: - session = await self._ensure_fresh_session() - headers = self._get_request_headers() + downloader = await get_downloader() url = f"{self.base_url}/images?imageId={image_id}&nsfw=X" logger.debug(f"Fetching image info for ID: {image_id}") - async with session.get(url, headers=headers) as response: - if response.status == 200: - data = await response.json() - if data and "items" in data and len(data["items"]) > 0: - logger.debug(f"Successfully fetched image info for ID: {image_id}") - return data["items"][0] - logger.warning(f"No image found with ID: {image_id}") - return None - - logger.error(f"Failed to fetch image info for ID: {image_id} (status {response.status})") + success, result = await downloader.make_request( + 'GET', + url, + use_auth=True + ) + + if success: + if result and "items" in result and len(result["items"]) > 0: + logger.debug(f"Successfully fetched image info for ID: {image_id}") + return result["items"][0] + logger.warning(f"No image found with ID: {image_id}") return None + + logger.error(f"Failed to fetch image info for ID: {image_id}: {result}") + return None except Exception as e: error_msg = f"Error fetching image info: {e}" logger.error(error_msg) diff --git a/py/services/downloader.py b/py/services/downloader.py new file mode 100644 index 00000000..cb7f5ef1 --- /dev/null +++ b/py/services/downloader.py @@ -0,0 +1,465 @@ +""" +Unified download manager for all HTTP/HTTPS downloads in the application. + +This module provides a centralized download service with: +- Singleton pattern for global session management +- Support for authenticated downloads (e.g., CivitAI API key) +- Resumable downloads with automatic retry +- Progress tracking and callbacks +- Optimized connection pooling and timeouts +- Unified error handling and logging +""" + +import os +import logging +import asyncio +import aiohttp +from datetime import datetime +from typing import Optional, Dict, Tuple, Callable, Union +from ..services.settings_manager import settings + +logger = logging.getLogger(__name__) + + +class Downloader: + """Unified downloader for all HTTP/HTTPS downloads in the application.""" + + _instance = None + _lock = asyncio.Lock() + + @classmethod + async def get_instance(cls): + """Get singleton instance of Downloader""" + async with cls._lock: + if cls._instance is None: + cls._instance = cls() + return cls._instance + + def __init__(self): + """Initialize the downloader with optimal settings""" + # Check if already initialized for singleton pattern + if hasattr(self, '_initialized'): + return + self._initialized = True + + # Session management + self._session = None + self._session_created_at = None + + # Configuration + self.chunk_size = 4 * 1024 * 1024 # 4MB chunks for better throughput + self.max_retries = 5 + self.base_delay = 2.0 # Base delay for exponential backoff + self.session_timeout = 300 # 5 minutes + + # Default headers + self.default_headers = { + 'User-Agent': 'ComfyUI-LoRA-Manager/1.0' + } + + @property + async def session(self) -> aiohttp.ClientSession: + """Get or create the global aiohttp session with optimized settings""" + if self._session is None or self._should_refresh_session(): + await self._create_session() + return self._session + + def _should_refresh_session(self) -> bool: + """Check if session should be refreshed""" + if self._session is None: + return True + + if not hasattr(self, '_session_created_at') or self._session_created_at is None: + return True + + # Refresh if session is older than timeout + if (datetime.now() - self._session_created_at).total_seconds() > self.session_timeout: + return True + + return False + + async def _create_session(self): + """Create a new aiohttp session with optimized settings""" + # Close existing session if any + if self._session is not None: + await self._session.close() + + # Optimize TCP connection parameters + connector = aiohttp.TCPConnector( + ssl=True, + limit=8, # Concurrent connections + ttl_dns_cache=300, # DNS cache timeout + force_close=False, # Keep connections for reuse + enable_cleanup_closed=True + ) + + # Configure timeout parameters + timeout = aiohttp.ClientTimeout( + total=None, # No total timeout for large downloads + connect=60, # Connection timeout + sock_read=None # No socket read timeout + ) + + self._session = aiohttp.ClientSession( + connector=connector, + trust_env=True, # Use system proxy settings + timeout=timeout + ) + self._session_created_at = datetime.now() + + logger.debug("Created new HTTP session") + + def _get_auth_headers(self, use_auth: bool = False) -> Dict[str, str]: + """Get headers with optional authentication""" + headers = self.default_headers.copy() + + if use_auth: + # Add CivitAI API key if available + api_key = settings.get('civitai_api_key') + if api_key: + headers['Authorization'] = f'Bearer {api_key}' + headers['Content-Type'] = 'application/json' + + return headers + + async def download_file( + self, + url: str, + save_path: str, + progress_callback: Optional[Callable[[float], None]] = None, + use_auth: bool = False, + custom_headers: Optional[Dict[str, str]] = None, + allow_resume: bool = True + ) -> Tuple[bool, str]: + """ + Download a file with resumable downloads and retry mechanism + + Args: + url: Download URL + save_path: Full path where the file should be saved + progress_callback: Optional callback for progress updates (0-100) + use_auth: Whether to include authentication headers (e.g., CivitAI API key) + custom_headers: Additional headers to include in request + allow_resume: Whether to support resumable downloads + + Returns: + Tuple[bool, str]: (success, save_path or error message) + """ + retry_count = 0 + part_path = save_path + '.part' if allow_resume else save_path + + # Prepare headers + headers = self._get_auth_headers(use_auth) + if custom_headers: + headers.update(custom_headers) + + # Get existing file size for resume + resume_offset = 0 + if allow_resume and os.path.exists(part_path): + resume_offset = os.path.getsize(part_path) + logger.info(f"Resuming download from offset {resume_offset} bytes") + + total_size = 0 + + while retry_count <= self.max_retries: + try: + session = await self.session + + # Add Range header for resume if we have partial data + request_headers = headers.copy() + if allow_resume and resume_offset > 0: + request_headers['Range'] = f'bytes={resume_offset}-' + + # Disable compression for better chunked downloads + request_headers['Accept-Encoding'] = 'identity' + + logger.debug(f"Download attempt {retry_count + 1}/{self.max_retries + 1} from: {url}") + if resume_offset > 0: + logger.debug(f"Requesting range from byte {resume_offset}") + + async with session.get(url, headers=request_headers, allow_redirects=True) as response: + # Handle different response codes + if response.status == 200: + # Full content response + if resume_offset > 0: + # Server doesn't support ranges, restart from beginning + logger.warning("Server doesn't support range requests, restarting download") + resume_offset = 0 + if os.path.exists(part_path): + os.remove(part_path) + elif response.status == 206: + # Partial content response (resume successful) + content_range = response.headers.get('Content-Range') + if content_range: + # Parse total size from Content-Range header (e.g., "bytes 1024-2047/2048") + range_parts = content_range.split('/') + if len(range_parts) == 2: + total_size = int(range_parts[1]) + logger.info(f"Successfully resumed download from byte {resume_offset}") + elif response.status == 416: + # Range not satisfiable - file might be complete or corrupted + if allow_resume and os.path.exists(part_path): + part_size = os.path.getsize(part_path) + logger.warning(f"Range not satisfiable. Part file size: {part_size}") + # Try to get actual file size + head_response = await session.head(url, headers=headers) + if head_response.status == 200: + actual_size = int(head_response.headers.get('content-length', 0)) + if part_size == actual_size: + # File is complete, just rename it + if allow_resume: + os.rename(part_path, save_path) + if progress_callback: + await progress_callback(100) + return True, save_path + # Remove corrupted part file and restart + os.remove(part_path) + resume_offset = 0 + continue + elif response.status == 401: + logger.warning(f"Unauthorized access to resource: {url} (Status 401)") + return False, "Invalid or missing API key, or early access restriction." + elif response.status == 403: + logger.warning(f"Forbidden access to resource: {url} (Status 403)") + return False, "Access forbidden: You don't have permission to download this file." + elif response.status == 404: + logger.warning(f"Resource not found: {url} (Status 404)") + return False, "File not found - the download link may be invalid or expired." + else: + logger.error(f"Download failed for {url} with status {response.status}") + return False, f"Download failed with status {response.status}" + + # Get total file size for progress calculation (if not set from Content-Range) + if total_size == 0: + total_size = int(response.headers.get('content-length', 0)) + if response.status == 206: + # For partial content, add the offset to get total file size + total_size += resume_offset + + current_size = resume_offset + last_progress_report_time = datetime.now() + + # Ensure directory exists + os.makedirs(os.path.dirname(save_path), exist_ok=True) + + # Stream download to file with progress updates + loop = asyncio.get_running_loop() + mode = 'ab' if (allow_resume and resume_offset > 0) else 'wb' + with open(part_path, mode) as f: + async for chunk in response.content.iter_chunked(self.chunk_size): + if chunk: + # Run blocking file write in executor + await loop.run_in_executor(None, f.write, chunk) + current_size += len(chunk) + + # Limit progress update frequency to reduce overhead + now = datetime.now() + time_diff = (now - last_progress_report_time).total_seconds() + + if progress_callback and total_size and time_diff >= 1.0: + progress = (current_size / total_size) * 100 + await progress_callback(progress) + last_progress_report_time = now + + # Download completed successfully + # Verify file size if total_size was provided + final_size = os.path.getsize(part_path) + if total_size > 0 and final_size != total_size: + logger.warning(f"File size mismatch. Expected: {total_size}, Got: {final_size}") + # Don't treat this as fatal error, continue anyway + + # Atomically rename .part to final file (only if using resume) + if allow_resume and part_path != save_path: + max_rename_attempts = 5 + rename_attempt = 0 + rename_success = False + + while rename_attempt < max_rename_attempts and not rename_success: + try: + os.rename(part_path, save_path) + rename_success = True + except PermissionError as e: + rename_attempt += 1 + if rename_attempt < max_rename_attempts: + logger.info(f"File still in use, retrying rename in 2 seconds (attempt {rename_attempt}/{max_rename_attempts})") + await asyncio.sleep(2) + else: + logger.error(f"Failed to rename file after {max_rename_attempts} attempts: {e}") + return False, f"Failed to finalize download: {str(e)}" + + # Ensure 100% progress is reported + if progress_callback: + await progress_callback(100) + + return True, save_path + + except (aiohttp.ClientError, aiohttp.ClientPayloadError, + aiohttp.ServerDisconnectedError, asyncio.TimeoutError) as e: + retry_count += 1 + logger.warning(f"Network error during download (attempt {retry_count}/{self.max_retries + 1}): {e}") + + if retry_count <= self.max_retries: + # Calculate delay with exponential backoff + delay = self.base_delay * (2 ** (retry_count - 1)) + logger.info(f"Retrying in {delay} seconds...") + await asyncio.sleep(delay) + + # Update resume offset for next attempt + if allow_resume and os.path.exists(part_path): + resume_offset = os.path.getsize(part_path) + logger.info(f"Will resume from byte {resume_offset}") + + # Refresh session to get new connection + await self._create_session() + continue + else: + logger.error(f"Max retries exceeded for download: {e}") + return False, f"Network error after {self.max_retries + 1} attempts: {str(e)}" + + except Exception as e: + logger.error(f"Unexpected download error: {e}") + return False, str(e) + + return False, f"Download failed after {self.max_retries + 1} attempts" + + async def download_to_memory( + self, + url: str, + use_auth: bool = False, + custom_headers: Optional[Dict[str, str]] = None + ) -> Tuple[bool, Union[bytes, str]]: + """ + Download a file to memory (for small files like preview images) + + Args: + url: Download URL + use_auth: Whether to include authentication headers + custom_headers: Additional headers to include in request + + Returns: + Tuple[bool, Union[bytes, str]]: (success, content or error message) + """ + try: + session = await self.session + + # Prepare headers + headers = self._get_auth_headers(use_auth) + if custom_headers: + headers.update(custom_headers) + + async with session.get(url, headers=headers) as response: + if response.status == 200: + content = await response.read() + return True, content + elif response.status == 401: + return False, "Unauthorized access - invalid or missing API key" + elif response.status == 403: + return False, "Access forbidden" + elif response.status == 404: + return False, "File not found" + else: + return False, f"Download failed with status {response.status}" + + except Exception as e: + logger.error(f"Error downloading to memory from {url}: {e}") + return False, str(e) + + async def get_response_headers( + self, + url: str, + use_auth: bool = False, + custom_headers: Optional[Dict[str, str]] = None + ) -> Tuple[bool, Union[Dict, str]]: + """ + Get response headers without downloading the full content + + Args: + url: URL to check + use_auth: Whether to include authentication headers + custom_headers: Additional headers to include in request + + Returns: + Tuple[bool, Union[Dict, str]]: (success, headers dict or error message) + """ + try: + session = await self.session + + # Prepare headers + headers = self._get_auth_headers(use_auth) + if custom_headers: + headers.update(custom_headers) + + async with session.head(url, headers=headers) as response: + if response.status == 200: + return True, dict(response.headers) + else: + return False, f"Head request failed with status {response.status}" + + except Exception as e: + logger.error(f"Error getting headers from {url}: {e}") + return False, str(e) + + async def make_request( + self, + method: str, + url: str, + use_auth: bool = False, + custom_headers: Optional[Dict[str, str]] = None, + **kwargs + ) -> Tuple[bool, Union[Dict, str]]: + """ + Make a generic HTTP request and return JSON response + + Args: + method: HTTP method (GET, POST, etc.) + url: Request URL + use_auth: Whether to include authentication headers + custom_headers: Additional headers to include in request + **kwargs: Additional arguments for aiohttp request + + Returns: + Tuple[bool, Union[Dict, str]]: (success, response data or error message) + """ + try: + session = await self.session + + # Prepare headers + headers = self._get_auth_headers(use_auth) + if custom_headers: + headers.update(custom_headers) + + async with session.request(method, url, headers=headers, **kwargs) as response: + if response.status == 200: + # Try to parse as JSON, fall back to text + try: + data = await response.json() + return True, data + except: + text = await response.text() + return True, text + elif response.status == 401: + return False, "Unauthorized access - invalid or missing API key" + elif response.status == 403: + return False, "Access forbidden" + elif response.status == 404: + return False, "Resource not found" + else: + return False, f"Request failed with status {response.status}" + + except Exception as e: + logger.error(f"Error making {method} request to {url}: {e}") + return False, str(e) + + async def close(self): + """Close the HTTP session""" + if self._session is not None: + await self._session.close() + self._session = None + self._session_created_at = None + logger.debug("Closed HTTP session") + + +# Global instance accessor +async def get_downloader() -> Downloader: + """Get the global downloader instance""" + return await Downloader.get_instance() diff --git a/py/services/metadata_archive_manager.py b/py/services/metadata_archive_manager.py index 3daf761b..a1ba9b74 100644 --- a/py/services/metadata_archive_manager.py +++ b/py/services/metadata_archive_manager.py @@ -1,9 +1,9 @@ import zipfile -import aiohttp import logging import asyncio from pathlib import Path from typing import Optional +from .downloader import get_downloader logger = logging.getLogger(__name__) @@ -67,6 +67,8 @@ class MetadataArchiveManager: async def _download_archive(self, progress_callback=None) -> bool: """Download the zip archive from one of the available URLs""" + downloader = await get_downloader() + for url in self.DOWNLOAD_URLS: try: logger.info(f"Attempting to download from {url}") @@ -74,26 +76,25 @@ class MetadataArchiveManager: if progress_callback: progress_callback("download", f"Downloading from {url}") - async with aiohttp.ClientSession() as session: - async with session.get(url) as response: - if response.status == 200: - total_size = int(response.headers.get('content-length', 0)) - downloaded = 0 - - with open(self.archive_path, 'wb') as f: - async for chunk in response.content.iter_chunked(8192): - f.write(chunk) - downloaded += len(chunk) - - if progress_callback and total_size > 0: - percentage = (downloaded / total_size) * 100 - progress_callback("download", f"Downloaded {percentage:.1f}%") - - logger.info(f"Successfully downloaded archive from {url}") - return True - else: - logger.warning(f"Failed to download from {url}: HTTP {response.status}") - continue + # Custom progress callback to report download progress + async def download_progress(progress): + if progress_callback: + progress_callback("download", f"Downloaded {progress:.1f}%") + + success, result = await downloader.download_file( + url=url, + save_path=str(self.archive_path), + progress_callback=download_progress, + use_auth=False, # Public download, no auth needed + allow_resume=True + ) + + if success: + logger.info(f"Successfully downloaded archive from {url}") + return True + else: + logger.warning(f"Failed to download from {url}: {result}") + continue except Exception as e: logger.warning(f"Error downloading from {url}: {e}") diff --git a/py/utils/example_images_download_manager.py b/py/utils/example_images_download_manager.py index db1b93f0..e3f46244 100644 --- a/py/utils/example_images_download_manager.py +++ b/py/utils/example_images_download_manager.py @@ -3,13 +3,13 @@ import os import asyncio import json import time -import aiohttp from aiohttp import web from ..services.service_registry import ServiceRegistry from ..utils.metadata_manager import MetadataManager from .example_images_processor import ExampleImagesProcessor from .example_images_metadata import MetadataUpdater from ..services.websocket_manager import ws_manager # Add this import at the top +from ..services.downloader import get_downloader logger = logging.getLogger(__name__) @@ -199,19 +199,8 @@ class DownloadManager: """Download example images for all models""" global is_downloading, download_progress - # Create independent download session - connector = aiohttp.TCPConnector( - ssl=True, - limit=3, - force_close=False, - enable_cleanup_closed=True - ) - timeout = aiohttp.ClientTimeout(total=None, connect=60, sock_read=60) - independent_session = aiohttp.ClientSession( - connector=connector, - trust_env=True, - timeout=timeout - ) + # Get unified downloader + downloader = await get_downloader() try: # Get scanners @@ -246,7 +235,7 @@ class DownloadManager: # Main logic for processing model is here, but actual operations are delegated to other classes was_remote_download = await DownloadManager._process_model( scanner_type, model, scanner, - output_dir, optimize, independent_session + output_dir, optimize, downloader ) # Update progress @@ -270,12 +259,6 @@ class DownloadManager: download_progress['end_time'] = time.time() finally: - # Close the independent session - try: - await independent_session.close() - except Exception as e: - logger.error(f"Error closing download session: {e}") - # Save final progress to file try: DownloadManager._save_progress(output_dir) @@ -286,7 +269,7 @@ class DownloadManager: is_downloading = False @staticmethod - async def _process_model(scanner_type, model, scanner, output_dir, optimize, independent_session): + async def _process_model(scanner_type, model, scanner, output_dir, optimize, downloader): """Process a single model download""" global download_progress @@ -347,7 +330,7 @@ class DownloadManager: images = model.get('civitai', {}).get('images', []) success, is_stale = await ExampleImagesProcessor.download_model_images( - model_hash, model_name, images, model_dir, optimize, independent_session + model_hash, model_name, images, model_dir, optimize, downloader ) # If metadata is stale, try to refresh it @@ -365,7 +348,7 @@ class DownloadManager: # Retry download with updated metadata updated_images = updated_model.get('civitai', {}).get('images', []) success, _ = await ExampleImagesProcessor.download_model_images( - model_hash, model_name, updated_images, model_dir, optimize, independent_session + model_hash, model_name, updated_images, model_dir, optimize, downloader ) download_progress['refreshed_models'].add(model_hash) @@ -529,19 +512,8 @@ class DownloadManager: """Download example images for specific models only - synchronous version""" global download_progress - # Create independent download session - connector = aiohttp.TCPConnector( - ssl=True, - limit=3, - force_close=False, - enable_cleanup_closed=True - ) - timeout = aiohttp.ClientTimeout(total=None, connect=60, sock_read=60) - independent_session = aiohttp.ClientSession( - connector=connector, - trust_env=True, - timeout=timeout - ) + # Get unified downloader + downloader = await get_downloader() try: # Get scanners @@ -586,7 +558,7 @@ class DownloadManager: # Force process this model regardless of previous status was_successful = await DownloadManager._process_specific_model( scanner_type, model, scanner, - output_dir, optimize, independent_session + output_dir, optimize, downloader ) if was_successful: @@ -650,14 +622,11 @@ class DownloadManager: raise finally: - # Close the independent session - try: - await independent_session.close() - except Exception as e: - logger.error(f"Error closing download session: {e}") + # No need to close any sessions since we use the global downloader + pass @staticmethod - async def _process_specific_model(scanner_type, model, scanner, output_dir, optimize, independent_session): + async def _process_specific_model(scanner_type, model, scanner, output_dir, optimize, downloader): """Process a specific model for forced download, ignoring previous download status""" global download_progress @@ -701,7 +670,7 @@ class DownloadManager: images = model.get('civitai', {}).get('images', []) success, is_stale, failed_images = await ExampleImagesProcessor.download_model_images_with_tracking( - model_hash, model_name, images, model_dir, optimize, independent_session + model_hash, model_name, images, model_dir, optimize, downloader ) # If metadata is stale, try to refresh it @@ -719,7 +688,7 @@ class DownloadManager: # Retry download with updated metadata updated_images = updated_model.get('civitai', {}).get('images', []) success, _, additional_failed_images = await ExampleImagesProcessor.download_model_images_with_tracking( - model_hash, model_name, updated_images, model_dir, optimize, independent_session + model_hash, model_name, updated_images, model_dir, optimize, downloader ) # Combine failed images from both attempts diff --git a/py/utils/example_images_processor.py b/py/utils/example_images_processor.py index 6d14e621..9dba4e2c 100644 --- a/py/utils/example_images_processor.py +++ b/py/utils/example_images_processor.py @@ -35,7 +35,7 @@ class ExampleImagesProcessor: return image_url @staticmethod - async def download_model_images(model_hash, model_name, model_images, model_dir, optimize, independent_session): + async def download_model_images(model_hash, model_name, model_images, model_dir, optimize, downloader): """Download images for a single model Returns: @@ -78,23 +78,25 @@ class ExampleImagesProcessor: try: logger.debug(f"Downloading {save_filename} for {model_name}") - # Download directly using the independent session - async with independent_session.get(image_url, timeout=60) as response: - if response.status == 200: - with open(save_path, 'wb') as f: - async for chunk in response.content.iter_chunked(8192): - if chunk: - f.write(chunk) - elif response.status == 404: - error_msg = f"Failed to download file: {image_url}, status code: 404 - Model metadata might be stale" - logger.warning(error_msg) - model_success = False # Mark the model as failed due to 404 error - # Return early to trigger metadata refresh attempt - return False, True # (success, is_metadata_stale) - else: - error_msg = f"Failed to download file: {image_url}, status code: {response.status}" - logger.warning(error_msg) - model_success = False # Mark the model as failed + # Download using the unified downloader + success, content = await downloader.download_to_memory( + image_url, + use_auth=False # Example images don't need auth + ) + + if success: + with open(save_path, 'wb') as f: + f.write(content) + elif "404" in str(content): + error_msg = f"Failed to download file: {image_url}, status code: 404 - Model metadata might be stale" + logger.warning(error_msg) + model_success = False # Mark the model as failed due to 404 error + # Return early to trigger metadata refresh attempt + return False, True # (success, is_metadata_stale) + else: + error_msg = f"Failed to download file: {image_url}, error: {content}" + logger.warning(error_msg) + model_success = False # Mark the model as failed except Exception as e: error_msg = f"Error downloading file {image_url}: {str(e)}" logger.error(error_msg) @@ -103,7 +105,7 @@ class ExampleImagesProcessor: return model_success, False # (success, is_metadata_stale) @staticmethod - async def download_model_images_with_tracking(model_hash, model_name, model_images, model_dir, optimize, independent_session): + async def download_model_images_with_tracking(model_hash, model_name, model_images, model_dir, optimize, downloader): """Download images for a single model with tracking of failed image URLs Returns: @@ -147,25 +149,27 @@ class ExampleImagesProcessor: try: logger.debug(f"Downloading {save_filename} for {model_name}") - # Download directly using the independent session - async with independent_session.get(image_url, timeout=60) as response: - if response.status == 200: - with open(save_path, 'wb') as f: - async for chunk in response.content.iter_chunked(8192): - if chunk: - f.write(chunk) - elif response.status == 404: - error_msg = f"Failed to download file: {image_url}, status code: 404 - Model metadata might be stale" - logger.warning(error_msg) - model_success = False # Mark the model as failed due to 404 error - failed_images.append(image_url) # Track failed URL - # Return early to trigger metadata refresh attempt - return False, True, failed_images # (success, is_metadata_stale, failed_images) - else: - error_msg = f"Failed to download file: {image_url}, status code: {response.status}" - logger.warning(error_msg) - model_success = False # Mark the model as failed - failed_images.append(image_url) # Track failed URL + # Download using the unified downloader + success, content = await downloader.download_to_memory( + image_url, + use_auth=False # Example images don't need auth + ) + + if success: + with open(save_path, 'wb') as f: + f.write(content) + elif "404" in str(content): + error_msg = f"Failed to download file: {image_url}, status code: 404 - Model metadata might be stale" + logger.warning(error_msg) + model_success = False # Mark the model as failed due to 404 error + failed_images.append(image_url) # Track failed URL + # Return early to trigger metadata refresh attempt + return False, True, failed_images # (success, is_metadata_stale, failed_images) + else: + error_msg = f"Failed to download file: {image_url}, error: {content}" + logger.warning(error_msg) + model_success = False # Mark the model as failed + failed_images.append(image_url) # Track failed URL except Exception as e: error_msg = f"Error downloading file {image_url}: {str(e)}" logger.error(error_msg) From 1ea468cfc4b43fae260f1175422ceb7768a2d895 Mon Sep 17 00:00:00 2001 From: Will Miao <13051207myq@gmail.com> Date: Tue, 9 Sep 2025 15:24:28 +0800 Subject: [PATCH 05/13] feat(metadata): enhance metadata archive management with download progress and status updates --- locales/en.json | 6 +- py/routes/misc_routes.py | 14 +- py/services/metadata_archive_manager.py | 2 +- static/css/components/modal/_base.css | 82 +++++++++- static/js/managers/SettingsManager.js | 149 +++++++++++++++--- .../components/modals/settings_modal.html | 48 +++--- 6 files changed, 241 insertions(+), 60 deletions(-) diff --git a/locales/en.json b/locales/en.json index 52bc8580..343b813a 100644 --- a/locales/en.json +++ b/locales/en.json @@ -300,7 +300,11 @@ "downloadError": "Failed to download metadata archive database", "removeSuccess": "Metadata archive database removed successfully", "removeError": "Failed to remove metadata archive database", - "removeConfirm": "Are you sure you want to remove the metadata archive database? This will delete the local database file and you'll need to download it again to use this feature." + "removeConfirm": "Are you sure you want to remove the metadata archive database? This will delete the local database file and you'll need to download it again to use this feature.", + "preparing": "Preparing download...", + "connecting": "Connecting to download server...", + "completed": "Completed", + "downloadComplete": "Download completed successfully" } }, "loras": { diff --git a/py/routes/misc_routes.py b/py/routes/misc_routes.py index 9a29a24d..118afea6 100644 --- a/py/routes/misc_routes.py +++ b/py/routes/misc_routes.py @@ -711,13 +711,23 @@ class MiscRoutes: try: archive_manager = await get_metadata_archive_manager() + # Get the download_id from query parameters if provided + download_id = request.query.get('download_id') + # Progress callback to send updates via WebSocket def progress_callback(stage, message): - asyncio.create_task(ws_manager.broadcast({ + data = { 'stage': stage, 'message': message, 'type': 'metadata_archive_download' - })) + } + + if download_id: + # Send to specific download WebSocket if download_id is provided + asyncio.create_task(ws_manager.broadcast_download_progress(download_id, data)) + else: + # Fallback to general broadcast + asyncio.create_task(ws_manager.broadcast(data)) # Download and extract in background success = await archive_manager.download_and_extract_database(progress_callback) diff --git a/py/services/metadata_archive_manager.py b/py/services/metadata_archive_manager.py index a1ba9b74..49b22c01 100644 --- a/py/services/metadata_archive_manager.py +++ b/py/services/metadata_archive_manager.py @@ -79,7 +79,7 @@ class MetadataArchiveManager: # Custom progress callback to report download progress async def download_progress(progress): if progress_callback: - progress_callback("download", f"Downloaded {progress:.1f}%") + progress_callback("download", f"Downloading archive... {progress:.1f}%") success, result = await downloader.download_file( url=url, diff --git a/static/css/components/modal/_base.css b/static/css/components/modal/_base.css index 2b1d542d..57d01ac9 100644 --- a/static/css/components/modal/_base.css +++ b/static/css/components/modal/_base.css @@ -208,6 +208,14 @@ body.modal-open { pointer-events: none; } +button:disabled, +.primary-btn:disabled, +.danger-btn:disabled { + opacity: 0.5; + cursor: not-allowed; + pointer-events: none; +} + .restart-required-icon { color: var(--lora-warning); margin-left: 5px; @@ -228,14 +236,76 @@ body.modal-open { background-color: oklch(35% 0.02 256 / 0.98); } -.primary-btn.disabled { - opacity: 0.5; - cursor: not-allowed; +/* Danger button styles */ +.danger-btn { + display: flex; + align-items: center; + gap: 8px; + padding: 8px 16px; + background-color: var(--lora-error); + color: white; + border: none; + border-radius: var(--border-radius-sm); + cursor: pointer; + transition: background-color 0.2s; + font-size: 0.95em; } -.primary-btn.disabled { - opacity: 0.5; - cursor: not-allowed; +.danger-btn:hover { + background-color: oklch(from var(--lora-error) l c h / 85%); + color: white; +} + +/* Metadata archive status styles */ +.metadata-archive-status { + background: rgba(0, 0, 0, 0.03); + border: 1px solid rgba(0, 0, 0, 0.1); + border-radius: var(--border-radius-sm); + padding: var(--space-2); + margin-bottom: var(--space-2); +} + +[data-theme="dark"] .metadata-archive-status { + background: rgba(255, 255, 255, 0.03); + border: 1px solid var(--lora-border); +} + +.archive-status-item { + display: flex; + justify-content: space-between; + align-items: center; + margin-bottom: 8px; + font-size: 0.95em; +} + +.archive-status-item:last-child { + margin-bottom: 0; +} + +.archive-status-label { + font-weight: 500; + color: var(--text-color); + opacity: 0.8; +} + +.archive-status-value { + color: var(--text-color); +} + +.archive-status-value.status-available { + color: var(--lora-success, #10b981); +} + +.archive-status-value.status-unavailable { + color: var(--lora-warning, #f59e0b); +} + +.archive-status-value.status-enabled { + color: var(--lora-success, #10b981); +} + +.archive-status-value.status-disabled { + color: var(--lora-error, #ef4444); } /* Add styles for delete preview image */ diff --git a/static/js/managers/SettingsManager.js b/static/js/managers/SettingsManager.js index 1a15bdb4..ac795903 100644 --- a/static/js/managers/SettingsManager.js +++ b/static/js/managers/SettingsManager.js @@ -789,6 +789,8 @@ export class SettingsManager { state.global.settings.compactMode = value; } else if (settingKey === 'include_trigger_words') { state.global.settings.includeTriggerWords = value; + } else if (settingKey === 'enable_metadata_archive_db') { + state.global.settings.enable_metadata_archive_db = value; } else { // For any other settings that might be added in the future state.global.settings[settingKey] = value; @@ -799,7 +801,7 @@ export class SettingsManager { try { // For backend settings, make API call - if (['show_only_sfw'].includes(settingKey)) { + if (['show_only_sfw', 'enable_metadata_archive_db'].includes(settingKey)) { const payload = {}; payload[settingKey] = value; @@ -814,6 +816,11 @@ export class SettingsManager { if (!response.ok) { throw new Error('Failed to save setting'); } + + // Refresh metadata archive status when enable setting changes + if (settingKey === 'enable_metadata_archive_db') { + await this.updateMetadataArchiveStatus(); + } } showToast('toast.settings.settingsUpdated', { setting: settingKey.replace(/_/g, ' ') }, 'success'); @@ -872,6 +879,8 @@ export class SettingsManager { state.global.settings.compactMode = (value !== 'default'); } else if (settingKey === 'card_info_display') { state.global.settings.cardInfoDisplay = value; + } else if (settingKey === 'metadata_provider_priority') { + state.global.settings.metadata_provider_priority = value; } else { // For any other settings that might be added in the future state.global.settings[settingKey] = value; @@ -882,7 +891,7 @@ export class SettingsManager { try { // For backend settings, make API call - if (settingKey === 'default_lora_root' || settingKey === 'default_checkpoint_root' || settingKey === 'default_embedding_root' || settingKey === 'download_path_templates') { + if (settingKey === 'default_lora_root' || settingKey === 'default_checkpoint_root' || settingKey === 'default_embedding_root' || settingKey === 'download_path_templates' || settingKey === 'metadata_provider_priority') { const payload = {}; if (settingKey === 'download_path_templates') { payload[settingKey] = state.global.settings.download_path_templates; @@ -903,6 +912,11 @@ export class SettingsManager { } showToast('toast.settings.settingsUpdated', { setting: settingKey.replace(/_/g, ' ') }, 'success'); + + // Refresh metadata archive status when provider priority changes + if (settingKey === 'metadata_provider_priority') { + await this.updateMetadataArchiveStatus(); + } } // Apply frontend settings immediately @@ -960,24 +974,24 @@ export class SettingsManager { const sizeText = status.databaseSize > 0 ? ` (${this.formatFileSize(status.databaseSize)})` : ''; statusContainer.innerHTML = ` -
-
- ${translate('settings.metadataArchive.status')}: - - ${status.isAvailable ? translate('settings.metadataArchive.statusAvailable') : translate('settings.metadataArchive.statusUnavailable')} - +
+ ${translate('settings.metadataArchive.status')}: + + ${status.isAvailable ? translate('settings.metadataArchive.statusAvailable') : translate('settings.metadataArchive.statusUnavailable')} ${sizeText} -
-
- ${translate('settings.metadataArchive.enabled')}: - - ${status.isEnabled ? translate('common.enabled') : translate('common.disabled')} - -
-
- ${translate('settings.metadataArchive.currentPriority')}: + +
+
+ ${translate('settings.metadataArchive.enabled')}: + + ${status.isEnabled ? translate('common.status.enabled') : translate('common.status.disabled')} + +
+
+ ${translate('settings.metadataArchive.currentPriority')}: + ${status.priority === 'archive_db' ? translate('settings.metadataArchive.priorityArchiveDb') : translate('settings.metadataArchive.priorityCivitaiApi')} -
+
`; @@ -1012,12 +1026,81 @@ export class SettingsManager { async downloadMetadataArchive() { try { const downloadBtn = document.getElementById('downloadMetadataArchiveBtn'); + if (downloadBtn) { downloadBtn.disabled = true; downloadBtn.textContent = translate('settings.metadataArchive.downloadingButton'); } + + // Show loading with enhanced progress + const progressUpdater = state.loadingManager.showEnhancedProgress(translate('settings.metadataArchive.preparing')); - const response = await fetch('/api/download-metadata-archive', { + // Set up WebSocket for progress updates + const wsProtocol = window.location.protocol === 'https:' ? 'wss://' : 'ws://'; + const downloadId = `metadata_archive_${Date.now()}`; + const ws = new WebSocket(`${wsProtocol}${window.location.host}/ws/download-progress?id=${downloadId}`); + + let wsConnected = false; + let actualDownloadId = downloadId; // Will be updated when WebSocket confirms the ID + + // Promise to wait for WebSocket connection and ID confirmation + const wsReady = new Promise((resolve) => { + ws.onopen = () => { + wsConnected = true; + console.log('Connected to metadata archive download progress WebSocket'); + }; + + ws.onmessage = (event) => { + const data = JSON.parse(event.data); + + // Handle download ID confirmation + if (data.type === 'download_id') { + actualDownloadId = data.download_id; + console.log(`Connected to metadata archive download progress with ID: ${data.download_id}`); + resolve(data.download_id); + return; + } + + // Handle metadata archive download progress + if (data.type === 'metadata_archive_download') { + const message = data.message || ''; + + // Update progress bar based on stage + let progressPercent = 0; + if (data.stage === 'download') { + // Extract percentage from message if available + const percentMatch = data.message.match(/(\d+\.?\d*)%/); + if (percentMatch) { + progressPercent = Math.min(parseFloat(percentMatch[1]), 90); // Cap at 90% for download + } else { + progressPercent = 0; // Default download progress + } + } else if (data.stage === 'extract') { + progressPercent = 95; // Near completion for extraction + } + + // Update loading manager progress + progressUpdater.updateProgress(progressPercent, '', `${message}`); + } + }; + + ws.onerror = (error) => { + console.error('WebSocket error:', error); + resolve(downloadId); // Fallback to original ID + }; + + // Timeout fallback + setTimeout(() => resolve(downloadId), 5000); + }); + + ws.onclose = () => { + console.log('WebSocket connection closed'); + }; + + // Wait for WebSocket to be ready + await wsReady; + + const response = await fetch(`/api/download-metadata-archive?download_id=${encodeURIComponent(actualDownloadId)}`, { method: 'POST', headers: { 'Content-Type': 'application/json' @@ -1026,8 +1109,16 @@ export class SettingsManager { const data = await response.json(); + // Close WebSocket + if (ws.readyState === WebSocket.OPEN) { + ws.close(); + } + if (data.success) { - showNotification(translate('settings.metadataArchive.downloadSuccess'), 'success'); + // Complete progress + await progressUpdater.complete(translate('settings.metadataArchive.downloadComplete')); + + showToast('settings.metadataArchive.downloadSuccess', 'success'); // Update settings in state state.global.settings.enable_metadata_archive_db = true; @@ -1041,11 +1132,17 @@ export class SettingsManager { await this.updateMetadataArchiveStatus(); } else { - showNotification(translate('settings.metadataArchive.downloadError') + ': ' + data.error, 'error'); + // Hide loading on error + state.loadingManager.hide(); + showToast('settings.metadataArchive.downloadError' + ': ' + data.error, 'error'); } } catch (error) { console.error('Error downloading metadata archive:', error); - showNotification(translate('settings.metadataArchive.downloadError') + ': ' + error.message, 'error'); + + // Hide loading on error + state.loadingManager.hide(); + + showToast('settings.metadataArchive.downloadError' + ': ' + error.message, 'error'); } finally { const downloadBtn = document.getElementById('downloadMetadataArchiveBtn'); if (downloadBtn) { @@ -1077,8 +1174,8 @@ export class SettingsManager { const data = await response.json(); if (data.success) { - showNotification(translate('settings.metadataArchive.removeSuccess'), 'success'); - + showToast('settings.metadataArchive.removeSuccess', 'success'); + // Update settings in state state.global.settings.enable_metadata_archive_db = false; setStorageItem('settings', state.global.settings); @@ -1091,11 +1188,11 @@ export class SettingsManager { await this.updateMetadataArchiveStatus(); } else { - showNotification(translate('settings.metadataArchive.removeError') + ': ' + data.error, 'error'); + showToast('settings.metadataArchive.removeError' + ': ' + data.error, 'error'); } } catch (error) { console.error('Error removing metadata archive:', error); - showNotification(translate('settings.metadataArchive.removeError') + ': ' + error.message, 'error'); + showToast('settings.metadataArchive.removeError' + ': ' + error.message, 'error'); } finally { const removeBtn = document.getElementById('removeMetadataArchiveBtn'); if (removeBtn) { diff --git a/templates/components/modals/settings_modal.html b/templates/components/modals/settings_modal.html index a2f85716..d150c81d 100644 --- a/templates/components/modals/settings_modal.html +++ b/templates/components/modals/settings_modal.html @@ -398,28 +398,6 @@
- -
-

{{ t('settings.sections.misc') }}

-
-
-
- -
-
- -
-
-
- {{ t('settings.misc.includeTriggerWordsHelp') }} -
-
-
-

{{ t('settings.sections.metadataArchive') }}

@@ -470,10 +448,10 @@
- -
@@ -483,6 +461,28 @@ + + +
+

{{ t('settings.sections.misc') }}

+
+
+
+ +
+
+ +
+
+
+ {{ t('settings.misc.includeTriggerWordsHelp') }} +
+
+
\ No newline at end of file From 6fd74952b76444c3172304b9a732df5d3eae0baf Mon Sep 17 00:00:00 2001 From: Will Miao <13051207myq@gmail.com> Date: Tue, 9 Sep 2025 20:57:45 +0800 Subject: [PATCH 06/13] Refactor metadata handling to use unified provider system - Replaced direct usage of Civitai client with a fallback metadata provider across all recipe parsers. - Updated metadata service to improve initialization and error handling. - Enhanced download manager to utilize a downloader service for file operations. - Improved recipe scanner to fetch model information through the new metadata provider. - Updated utility functions to streamline image downloading and processing. - Added comprehensive logging and error handling for better debugging and reliability. - Introduced `get_default_metadata_provider()` for simplified access to the default provider. - Ensured backward compatibility with existing APIs and workflows. --- METADATA_PROVIDER_REFACTOR_SUMMARY.md | 119 +++++++++++++++++++ py/recipes/parsers/automatic.py | 12 +- py/recipes/parsers/civitai_image.py | 32 ++--- py/recipes/parsers/comfy.py | 14 ++- py/recipes/parsers/meta_format.py | 10 +- py/recipes/parsers/recipe_format.py | 8 +- py/routes/recipe_routes.py | 31 ++--- py/routes/update_routes.py | 4 +- py/services/download_manager.py | 41 +++++-- py/services/downloader.py | 4 + py/services/metadata_service.py | 68 ++++++----- py/services/model_scanner.py | 7 +- py/services/recipe_scanner.py | 11 +- py/services/settings_manager.py | 2 +- py/utils/routes_common.py | 165 +++++++++++++------------- 15 files changed, 350 insertions(+), 178 deletions(-) create mode 100644 METADATA_PROVIDER_REFACTOR_SUMMARY.md diff --git a/METADATA_PROVIDER_REFACTOR_SUMMARY.md b/METADATA_PROVIDER_REFACTOR_SUMMARY.md new file mode 100644 index 00000000..8e28a60d --- /dev/null +++ b/METADATA_PROVIDER_REFACTOR_SUMMARY.md @@ -0,0 +1,119 @@ +# Metadata Provider Refactor Summary + +## Overview +This refactor improves the metadata provider initialization logic and replaces direct Civitai client usage with the unified FallbackMetadataProvider system throughout the codebase. + +## Key Changes + +### 1. Enhanced Metadata Service (`py/services/metadata_service.py`) + +#### Improved `initialize_metadata_providers()`: +- Added provider clearing for proper reinitialization +- Enhanced error handling and validation +- Better logging for debugging +- Improved provider ordering logic based on priority settings +- More robust database path validation + +#### Enhanced `update_metadata_provider_priority()`: +- More robust error handling +- Proper reinitalization of all providers +- Better logging for setting changes + +#### New helper function: +- Added `get_default_metadata_provider()` for easier access to the default provider + +### 2. Updated Recipe Parsers +All recipe parsers now use the unified metadata provider instead of direct civitai_client: + +#### Files Updated: +- `py/recipes/parsers/civitai_image.py` +- `py/recipes/parsers/comfy.py` +- `py/recipes/parsers/automatic.py` +- `py/recipes/parsers/recipe_format.py` +- `py/recipes/parsers/meta_format.py` + +#### Changes Made: +- Added import for `get_default_metadata_provider` +- Replaced `civitai_client.get_model_by_hash()` with `metadata_provider.get_model_by_hash()` +- Replaced `civitai_client.get_model_version_info()` with `metadata_provider.get_model_version_info()` +- Updated method signatures to indicate civitai_client parameter is deprecated + +### 3. Download Manager Updates (`py/services/download_manager.py`) + +#### Metadata Operations: +- Replaced direct civitai_client usage with metadata_provider for: + - `get_model_version()` calls for version info + +#### Download Operations: +- Replaced `civitai_client.download_file()` with direct `downloader.download_file()` calls +- Replaced `civitai_client.download_preview_image()` with `downloader.download_to_memory()` for images +- Added proper authentication flags (`use_auth=True` for model files, `use_auth=False` for preview images) + +### 4. Recipe Scanner Updates (`py/services/recipe_scanner.py`) +- Added import for `get_default_metadata_provider` +- Replaced `civitai_client.get_model_version_info()` with `metadata_provider.get_model_version_info()` + +### 5. Utility Functions Updates (`py/utils/routes_common.py`) +- Added import for `get_downloader` +- Replaced preview image downloads with direct downloader usage +- Improved image optimization logic to work with in-memory downloads +- Better error handling for download and image processing operations + +## Benefits + +### 1. Unified Metadata Access +- All metadata requests now go through the fallback provider system +- Automatic failover between SQLite archive database and Civitai API +- Consistent metadata access patterns across all components + +### 2. Improved Download Performance +- Direct use of the optimized downloader service +- Better connection pooling and retry logic +- Proper authentication handling +- Support for resumable downloads + +### 3. Better Configuration Management +- Settings changes now properly update provider priority +- Clear separation between metadata and download operations +- Improved error handling and logging + +### 4. Enhanced Reliability +- Fallback mechanisms ensure metadata is always available when possible +- Better error handling and recovery +- Consistent behavior across all parsers and services + +## Usage + +### Settings Changes +When users change metadata provider settings: +1. The `update_metadata_provider_priority()` function is automatically called +2. All providers are reinitialized with the new settings +3. The fallback provider is updated with the correct priority order + +### Metadata Access +All components now use: +```python +from ...services.metadata_service import get_default_metadata_provider + +metadata_provider = await get_default_metadata_provider() +result = await metadata_provider.get_model_by_hash(hash_value) +``` + +### Downloads +All downloads now use the unified downloader: +```python +from ...services.downloader import get_downloader + +downloader = await get_downloader() +success, result = await downloader.download_file(url, path, use_auth=True) +``` + +## Compatibility +- All existing APIs and interfaces remain unchanged +- Backward compatibility maintained for existing workflows +- No changes required for external integrations + +## Testing +- All updated files pass syntax validation +- Existing functionality preserved +- Enhanced error handling and logging for better debugging diff --git a/py/recipes/parsers/automatic.py b/py/recipes/parsers/automatic.py index 3c3534e0..b7399c72 100644 --- a/py/recipes/parsers/automatic.py +++ b/py/recipes/parsers/automatic.py @@ -6,6 +6,7 @@ import logging from typing import Dict, Any from ..base import RecipeMetadataParser from ..constants import GEN_PARAM_KEYS +from ...services.metadata_service import get_default_metadata_provider logger = logging.getLogger(__name__) @@ -30,6 +31,9 @@ class AutomaticMetadataParser(RecipeMetadataParser): async def parse_metadata(self, user_comment: str, recipe_scanner=None, civitai_client=None) -> Dict[str, Any]: """Parse metadata from Automatic1111 format""" try: + # Get metadata provider instead of using civitai_client directly + metadata_provider = await get_default_metadata_provider() + # Split on Negative prompt if it exists if "Negative prompt:" in user_comment: parts = user_comment.split('Negative prompt:', 1) @@ -216,9 +220,9 @@ class AutomaticMetadataParser(RecipeMetadataParser): } # Get additional info from Civitai - if civitai_client: + if metadata_provider: try: - civitai_info = await civitai_client.get_model_version_info(resource.get("modelVersionId")) + civitai_info = await metadata_provider.get_model_version_info(resource.get("modelVersionId")) populated_entry = await self.populate_lora_from_civitai( lora_entry, civitai_info, @@ -271,11 +275,11 @@ class AutomaticMetadataParser(RecipeMetadataParser): } # Try to get info from Civitai - if civitai_client: + if metadata_provider: try: if lora_hash: # If we have hash, use it for lookup - civitai_info = await civitai_client.get_model_by_hash(lora_hash) + civitai_info = await metadata_provider.get_model_by_hash(lora_hash) else: civitai_info = None diff --git a/py/recipes/parsers/civitai_image.py b/py/recipes/parsers/civitai_image.py index 37c9cc25..8e96c99b 100644 --- a/py/recipes/parsers/civitai_image.py +++ b/py/recipes/parsers/civitai_image.py @@ -5,6 +5,7 @@ import logging from typing import Dict, Any, Union from ..base import RecipeMetadataParser from ..constants import GEN_PARAM_KEYS +from ...services.metadata_service import get_default_metadata_provider logger = logging.getLogger(__name__) @@ -36,12 +37,15 @@ class CivitaiApiMetadataParser(RecipeMetadataParser): Args: metadata: The metadata from the image (dict) recipe_scanner: Optional recipe scanner service - civitai_client: Optional Civitai API client + civitai_client: Optional Civitai API client (deprecated, use metadata_provider instead) Returns: Dict containing parsed recipe data """ try: + # Get metadata provider instead of using civitai_client directly + metadata_provider = await get_default_metadata_provider() + # Initialize result structure result = { 'base_model': None, @@ -85,9 +89,9 @@ class CivitaiApiMetadataParser(RecipeMetadataParser): # Extract base model information - directly if available if "baseModel" in metadata: result["base_model"] = metadata["baseModel"] - elif "Model hash" in metadata and civitai_client: + elif "Model hash" in metadata and metadata_provider: model_hash = metadata["Model hash"] - model_info = await civitai_client.get_model_by_hash(model_hash) + model_info = await metadata_provider.get_model_by_hash(model_hash) if model_info: result["base_model"] = model_info.get("baseModel", "") elif "Model" in metadata and isinstance(metadata.get("resources"), list): @@ -95,8 +99,8 @@ class CivitaiApiMetadataParser(RecipeMetadataParser): for resource in metadata.get("resources", []): if resource.get("type") == "model" and resource.get("name") == metadata.get("Model"): # This is likely the checkpoint model - if civitai_client and resource.get("hash"): - model_info = await civitai_client.get_model_by_hash(resource.get("hash")) + if metadata_provider and resource.get("hash"): + model_info = await metadata_provider.get_model_by_hash(resource.get("hash")) if model_info: result["base_model"] = model_info.get("baseModel", "") @@ -138,9 +142,9 @@ class CivitaiApiMetadataParser(RecipeMetadataParser): } # Try to get info from Civitai if hash is available - if lora_entry['hash'] and civitai_client: + if lora_entry['hash'] and metadata_provider: try: - civitai_info = await civitai_client.get_model_by_hash(lora_hash) + civitai_info = await metadata_provider.get_model_by_hash(lora_hash) populated_entry = await self.populate_lora_from_civitai( lora_entry, @@ -194,10 +198,10 @@ class CivitaiApiMetadataParser(RecipeMetadataParser): } # Try to get info from Civitai if modelVersionId is available - if version_id and civitai_client: + if version_id and metadata_provider: try: # Use get_model_version_info instead of get_model_version - civitai_info, error = await civitai_client.get_model_version_info(version_id) + civitai_info, error = await metadata_provider.get_model_version_info(version_id) if error: logger.warning(f"Error getting model version info: {error}") @@ -259,11 +263,11 @@ class CivitaiApiMetadataParser(RecipeMetadataParser): 'isDeleted': False } - # If we have a version ID and civitai client, try to get more info - if version_id and civitai_client: + # If we have a version ID and metadata provider, try to get more info + if version_id and metadata_provider: try: # Use get_model_version_info with the version ID - civitai_info, error = await civitai_client.get_model_version_info(version_id) + civitai_info, error = await metadata_provider.get_model_version_info(version_id) if error: logger.warning(f"Error getting model version info: {error}") @@ -316,9 +320,9 @@ class CivitaiApiMetadataParser(RecipeMetadataParser): } # Try to get info from Civitai if hash is available - if lora_entry['hash'] and civitai_client: + if lora_entry['hash'] and metadata_provider: try: - civitai_info = await civitai_client.get_model_by_hash(lora_hash) + civitai_info = await metadata_provider.get_model_by_hash(lora_hash) populated_entry = await self.populate_lora_from_civitai( lora_entry, diff --git a/py/recipes/parsers/comfy.py b/py/recipes/parsers/comfy.py index c8f3eebf..f81a15ad 100644 --- a/py/recipes/parsers/comfy.py +++ b/py/recipes/parsers/comfy.py @@ -6,6 +6,7 @@ import logging from typing import Dict, Any from ..base import RecipeMetadataParser from ..constants import GEN_PARAM_KEYS +from ...services.metadata_service import get_default_metadata_provider logger = logging.getLogger(__name__) @@ -26,6 +27,9 @@ class ComfyMetadataParser(RecipeMetadataParser): async def parse_metadata(self, user_comment: str, recipe_scanner=None, civitai_client=None) -> Dict[str, Any]: """Parse metadata from Civitai ComfyUI metadata format""" try: + # Get metadata provider instead of using civitai_client directly + metadata_provider = await get_default_metadata_provider() + data = json.loads(user_comment) loras = [] @@ -73,10 +77,10 @@ class ComfyMetadataParser(RecipeMetadataParser): 'isDeleted': False } - # Get additional info from Civitai if client is available - if civitai_client: + # Get additional info from Civitai if metadata provider is available + if metadata_provider: try: - civitai_info_tuple = await civitai_client.get_model_version_info(model_version_id) + civitai_info_tuple = await metadata_provider.get_model_version_info(model_version_id) # Populate lora entry with Civitai info populated_entry = await self.populate_lora_from_civitai( lora_entry, @@ -116,9 +120,9 @@ class ComfyMetadataParser(RecipeMetadataParser): } # Get additional checkpoint info from Civitai - if civitai_client: + if metadata_provider: try: - civitai_info_tuple = await civitai_client.get_model_version_info(checkpoint_version_id) + civitai_info_tuple = await metadata_provider.get_model_version_info(checkpoint_version_id) civitai_info, _ = civitai_info_tuple if isinstance(civitai_info_tuple, tuple) else (civitai_info_tuple, None) # Populate checkpoint with Civitai info checkpoint = await self.populate_checkpoint_from_civitai(checkpoint, civitai_info) diff --git a/py/recipes/parsers/meta_format.py b/py/recipes/parsers/meta_format.py index acd7e8bf..5eb53af7 100644 --- a/py/recipes/parsers/meta_format.py +++ b/py/recipes/parsers/meta_format.py @@ -5,6 +5,7 @@ import logging from typing import Dict, Any from ..base import RecipeMetadataParser from ..constants import GEN_PARAM_KEYS +from ...services.metadata_service import get_default_metadata_provider logger = logging.getLogger(__name__) @@ -18,8 +19,11 @@ class MetaFormatParser(RecipeMetadataParser): return re.search(self.METADATA_MARKER, user_comment, re.IGNORECASE | re.DOTALL) is not None async def parse_metadata(self, user_comment: str, recipe_scanner=None, civitai_client=None) -> Dict[str, Any]: - """Parse metadata from images with meta format metadata""" + """Parse metadata from images with meta format metadata (Lora_N Model hash format)""" try: + # Get metadata provider instead of using civitai_client directly + metadata_provider = await get_default_metadata_provider() + # Extract prompt and negative prompt parts = user_comment.split('Negative prompt:', 1) prompt = parts[0].strip() @@ -122,9 +126,9 @@ class MetaFormatParser(RecipeMetadataParser): } # Get info from Civitai by hash if available - if civitai_client and hash_value: + if metadata_provider and hash_value: try: - civitai_info = await civitai_client.get_model_by_hash(hash_value) + civitai_info = await metadata_provider.get_model_by_hash(hash_value) # Populate lora entry with Civitai info populated_entry = await self.populate_lora_from_civitai( lora_entry, diff --git a/py/recipes/parsers/recipe_format.py b/py/recipes/parsers/recipe_format.py index 667bdc43..5380cc69 100644 --- a/py/recipes/parsers/recipe_format.py +++ b/py/recipes/parsers/recipe_format.py @@ -7,6 +7,7 @@ from typing import Dict, Any from ...config import config from ..base import RecipeMetadataParser from ..constants import GEN_PARAM_KEYS +from ...services.metadata_service import get_default_metadata_provider logger = logging.getLogger(__name__) @@ -23,6 +24,9 @@ class RecipeFormatParser(RecipeMetadataParser): async def parse_metadata(self, user_comment: str, recipe_scanner=None, civitai_client=None) -> Dict[str, Any]: """Parse metadata from images with dedicated recipe metadata format""" try: + # Get metadata provider instead of using civitai_client directly + metadata_provider = await get_default_metadata_provider() + # Extract recipe metadata from user comment try: # Look for recipe metadata section @@ -71,9 +75,9 @@ class RecipeFormatParser(RecipeMetadataParser): lora_entry['localPath'] = None # Try to get additional info from Civitai if we have a model version ID - if lora.get('modelVersionId') and civitai_client: + if lora.get('modelVersionId') and metadata_provider: try: - civitai_info_tuple = await civitai_client.get_model_version_info(lora['modelVersionId']) + civitai_info_tuple = await metadata_provider.get_model_version_info(lora['modelVersionId']) # Populate lora entry with Civitai info populated_entry = await self.populate_lora_from_civitai( lora_entry, diff --git a/py/routes/recipe_routes.py b/py/routes/recipe_routes.py index cdd1b793..003d869a 100644 --- a/py/routes/recipe_routes.py +++ b/py/routes/recipe_routes.py @@ -24,6 +24,7 @@ from ..config import config standalone_mode = 'nodes' not in sys.modules from ..services.service_registry import ServiceRegistry # Add ServiceRegistry import +from ..services.downloader import get_downloader # Only import MetadataRegistry in non-standalone mode if not standalone_mode: @@ -372,21 +373,23 @@ class RecipeRoutes: "loras": [] }, status=400) - # Download image directly from URL - session = await self.civitai_client.session + # Download image using unified downloader + downloader = await get_downloader() # Create a temporary file to save the downloaded image with tempfile.NamedTemporaryFile(delete=False, suffix='.jpg') as temp_file: temp_path = temp_file.name - async with session.get(image_url) as response: - if response.status != 200: - return web.json_response({ - "error": f"Failed to download image from URL: HTTP {response.status}", - "loras": [] - }, status=400) - - with open(temp_path, 'wb') as f: - f.write(await response.read()) + success, result = await downloader.download_file( + image_url, + temp_path, + use_auth=False # Image downloads typically don't need auth + ) + + if not success: + return web.json_response({ + "error": f"Failed to download image from URL: {result}", + "loras": [] + }, status=400) # Use meta field from image_info as metadata if 'meta' in image_info: @@ -430,8 +433,7 @@ class RecipeRoutes: # Parse the metadata result = await parser.parse_metadata( metadata, - recipe_scanner=self.recipe_scanner, - civitai_client=self.civitai_client + recipe_scanner=self.recipe_scanner ) # For URL mode, include the image data as base64 @@ -532,8 +534,7 @@ class RecipeRoutes: # Parse the metadata result = await parser.parse_metadata( metadata, - recipe_scanner=self.recipe_scanner, - civitai_client=self.civitai_client + recipe_scanner=self.recipe_scanner ) # Add base64 image data to result diff --git a/py/routes/update_routes.py b/py/routes/update_routes.py index 66ef603a..d139ce77 100644 --- a/py/routes/update_routes.py +++ b/py/routes/update_routes.py @@ -258,7 +258,7 @@ class UpdateRoutes: try: downloader = await Downloader.get_instance() - success, data = await downloader.make_request('GET', github_url, headers={'Accept': 'application/vnd.github+json'}) + success, data = await downloader.make_request('GET', github_url, custom_headers={'Accept': 'application/vnd.github+json'}) if not success: logger.warning(f"Failed to fetch GitHub commit: {data}") @@ -424,7 +424,7 @@ class UpdateRoutes: try: downloader = await Downloader.get_instance() - success, data = await downloader.make_request('GET', github_url, headers={'Accept': 'application/vnd.github+json'}) + success, data = await downloader.make_request('GET', github_url, custom_headers={'Accept': 'application/vnd.github+json'}) if not success: logger.warning(f"Failed to fetch GitHub release: {data}") diff --git a/py/services/download_manager.py b/py/services/download_manager.py index 08295985..9f090b20 100644 --- a/py/services/download_manager.py +++ b/py/services/download_manager.py @@ -10,6 +10,8 @@ from ..utils.exif_utils import ExifUtils from ..utils.metadata_manager import MetadataManager from .service_registry import ServiceRegistry from .settings_manager import settings +from .metadata_service import get_default_metadata_provider +from .downloader import get_downloader # Download to temporary file first import tempfile @@ -199,11 +201,11 @@ class DownloadManager: if await embedding_scanner.check_model_version_exists(model_version_id): return {'success': False, 'error': 'Model version already exists in embedding library'} - # Get civitai client - civitai_client = await self._get_civitai_client() + # Get metadata provider instead of civitai client directly + metadata_provider = await get_default_metadata_provider() # Get version info based on the provided identifier - version_info = await civitai_client.get_model_version(model_id, model_version_id) + version_info = await metadata_provider.get_model_version(model_id, model_version_id) if not version_info: return {'success': False, 'error': 'Failed to fetch model metadata'} @@ -445,8 +447,14 @@ class DownloadManager: preview_ext = '.mp4' preview_path = os.path.splitext(save_path)[0] + preview_ext - # Download video directly - if await civitai_client.download_preview_image(images[0]['url'], preview_path): + # Download video directly using downloader + downloader = await get_downloader() + success, result = await downloader.download_file( + images[0]['url'], + preview_path, + use_auth=False # Preview images typically don't need auth + ) + if success: metadata.preview_url = preview_path.replace(os.sep, '/') metadata.preview_nsfw_level = images[0].get('nsfwLevel', 0) else: @@ -454,8 +462,16 @@ class DownloadManager: with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as temp_file: temp_path = temp_file.name - # Download the original image to temp path - if await civitai_client.download_preview_image(images[0]['url'], temp_path): + # Download the original image to temp path using downloader + downloader = await get_downloader() + success, content = await downloader.download_to_memory( + images[0]['url'], + use_auth=False + ) + if success: + # Save to temp file + with open(temp_path, 'wb') as f: + f.write(content) # Optimize and convert to WebP preview_path = os.path.splitext(save_path)[0] + '.webp' @@ -486,12 +502,13 @@ class DownloadManager: if progress_callback: await progress_callback(3) # 3% progress after preview download - # Download model file with progress tracking - success, result = await civitai_client.download_file( + # Download model file with progress tracking using downloader + downloader = await get_downloader() + success, result = await downloader.download_file( download_url, - save_dir, - os.path.basename(save_path), - progress_callback=lambda p: self._handle_download_progress(p, progress_callback) + save_path, # Use full path instead of separate dir and filename + progress_callback=lambda p: self._handle_download_progress(p, progress_callback), + use_auth=True # Model downloads need authentication ) if not success: diff --git a/py/services/downloader.py b/py/services/downloader.py index cb7f5ef1..9efba53b 100644 --- a/py/services/downloader.py +++ b/py/services/downloader.py @@ -276,6 +276,10 @@ class Downloader: while rename_attempt < max_rename_attempts and not rename_success: try: + # If the destination file exists, remove it first (Windows safe) + if os.path.exists(save_path): + os.remove(save_path) + os.rename(part_path, save_path) rename_success = True except PermissionError as e: diff --git a/py/services/metadata_service.py b/py/services/metadata_service.py index 0e5d9199..7823a1f7 100644 --- a/py/services/metadata_service.py +++ b/py/services/metadata_service.py @@ -16,6 +16,10 @@ async def initialize_metadata_providers(): """Initialize and configure all metadata providers based on settings""" provider_manager = await ModelMetadataProviderManager.get_instance() + # Clear existing providers to allow reinitialization + provider_manager.providers.clear() + provider_manager.default_provider = None + # Get settings enable_archive_db = settings.get('enable_metadata_archive_db', False) priority = settings.get('metadata_provider_priority', 'archive_db') @@ -24,23 +28,23 @@ async def initialize_metadata_providers(): # Initialize archive database provider if enabled if enable_archive_db: - # Initialize archive manager - base_path = os.path.dirname(os.path.dirname(os.path.dirname(__file__))) - archive_manager = MetadataArchiveManager(base_path) - - db_path = archive_manager.get_database_path() - if db_path: - try: + try: + # Initialize archive manager + base_path = os.path.dirname(os.path.dirname(os.path.dirname(__file__))) + archive_manager = MetadataArchiveManager(base_path) + + db_path = archive_manager.get_database_path() + if db_path and os.path.exists(db_path): sqlite_provider = SQLiteModelMetadataProvider(db_path) provider_manager.register_provider('sqlite', sqlite_provider) providers.append(('sqlite', sqlite_provider)) logger.info(f"SQLite metadata provider registered with database: {db_path}") - except Exception as e: - logger.error(f"Failed to initialize SQLite metadata provider: {e}") - else: - logger.warning("Metadata archive database is enabled but not available") + else: + logger.warning("Metadata archive database is enabled but database file not found") + except Exception as e: + logger.error(f"Failed to initialize SQLite metadata provider: {e}") - # Initialize Civitai API provider + # Initialize Civitai API provider (always available as fallback) try: civitai_client = await ServiceRegistry.get_civitai_client() civitai_provider = CivitaiModelMetadataProvider(civitai_client) @@ -50,42 +54,48 @@ async def initialize_metadata_providers(): except Exception as e: logger.error(f"Failed to initialize Civitai API metadata provider: {e}") - # Set up fallback provider based on priority + # Set up fallback provider based on priority and available providers if len(providers) > 1: # Order providers based on priority setting + ordered_providers = [] if priority == 'archive_db': # Archive DB first, then Civitai API - ordered_providers = [p[1] for p in providers if p[0] == 'sqlite'] + [p[1] for p in providers if p[0] == 'civitai_api'] + ordered_providers = [p[1] for p in providers if p[0] == 'sqlite'] + ordered_providers.extend([p[1] for p in providers if p[0] == 'civitai_api']) else: # Civitai API first, then Archive DB - ordered_providers = [p[1] for p in providers if p[0] == 'civitai_api'] + [p[1] for p in providers if p[0] == 'sqlite'] + ordered_providers = [p[1] for p in providers if p[0] == 'civitai_api'] + ordered_providers.extend([p[1] for p in providers if p[0] == 'sqlite']) if ordered_providers: fallback_provider = FallbackMetadataProvider(ordered_providers) provider_manager.register_provider('fallback', fallback_provider, is_default=True) - logger.info(f"Fallback metadata provider registered with priority: {priority}") + logger.info(f"Fallback metadata provider registered with {len(ordered_providers)} providers, priority: {priority}") elif len(providers) == 1: # Only one provider available, set it as default provider_name, provider = providers[0] provider_manager.register_provider(provider_name, provider, is_default=True) logger.info(f"Single metadata provider registered as default: {provider_name}") else: - logger.warning("No metadata providers available") + logger.warning("No metadata providers available - this may cause metadata lookup failures") return provider_manager async def update_metadata_provider_priority(): """Update metadata provider priority based on current settings""" - provider_manager = await ModelMetadataProviderManager.get_instance() - - # Get current settings - enable_archive_db = settings.get('enable_metadata_archive_db', False) - priority = settings.get('metadata_provider_priority', 'archive_db') - - # Rebuild providers with new priority - await initialize_metadata_providers() - - logger.info(f"Updated metadata provider priority to: {priority}") + try: + # Get current settings + enable_archive_db = settings.get('enable_metadata_archive_db', False) + priority = settings.get('metadata_provider_priority', 'archive_db') + + # Reinitialize all providers with new settings + provider_manager = await initialize_metadata_providers() + + logger.info(f"Updated metadata provider priority to: {priority}, archive_db enabled: {enable_archive_db}") + return provider_manager + except Exception as e: + logger.error(f"Failed to update metadata provider priority: {e}") + return await ModelMetadataProviderManager.get_instance() async def get_metadata_archive_manager(): """Get metadata archive manager instance""" @@ -100,3 +110,7 @@ async def get_metadata_provider(provider_name: str = None): return provider_manager._get_provider(provider_name) return provider_manager._get_provider() + +async def get_default_metadata_provider(): + """Get the default metadata provider (fallback or single provider)""" + return await get_metadata_provider() diff --git a/py/services/model_scanner.py b/py/services/model_scanner.py index 2f419440..4e6f3553 100644 --- a/py/services/model_scanner.py +++ b/py/services/model_scanner.py @@ -730,11 +730,10 @@ class ModelScanner: if needs_metadata_update and model_id: logger.debug(f"Fetching missing metadata for {file_path} with model ID {model_id}") - from ..services.civitai_client import CivitaiClient - client = CivitaiClient() + from ..services.metadata_service import get_default_metadata_provider + metadata_provider = await get_default_metadata_provider() - model_metadata, status_code = await client.get_model_metadata(model_id) - await client.close() + model_metadata, status_code = await metadata_provider.get_model_metadata(model_id) if status_code == 404: logger.warning(f"Model {model_id} appears to be deleted from Civitai (404 response)") diff --git a/py/services/recipe_scanner.py b/py/services/recipe_scanner.py index 89bbef14..ca5a20ac 100644 --- a/py/services/recipe_scanner.py +++ b/py/services/recipe_scanner.py @@ -8,6 +8,7 @@ from ..config import config from .recipe_cache import RecipeCache from .service_registry import ServiceRegistry from .lora_scanner import LoraScanner +from .metadata_service import get_default_metadata_provider from ..utils.utils import fuzzy_match from natsort import natsorted import sys @@ -431,13 +432,13 @@ class RecipeScanner: async def _get_hash_from_civitai(self, model_version_id: str) -> Optional[str]: """Get hash from Civitai API""" try: - # Get CivitaiClient from ServiceRegistry - civitai_client = await self._get_civitai_client() - if not civitai_client: - logger.error("Failed to get CivitaiClient from ServiceRegistry") + # Get metadata provider instead of civitai client directly + metadata_provider = await get_default_metadata_provider() + if not metadata_provider: + logger.error("Failed to get metadata provider") return None - version_info, error_msg = await civitai_client.get_model_version_info(model_version_id) + version_info, error_msg = await metadata_provider.get_model_version_info(model_version_id) if not version_info: if error_msg and "model not found" in error_msg.lower(): diff --git a/py/services/settings_manager.py b/py/services/settings_manager.py index 0b86ce82..058d9944 100644 --- a/py/services/settings_manager.py +++ b/py/services/settings_manager.py @@ -81,7 +81,7 @@ class SettingsManager: return { "civitai_api_key": "", "show_only_sfw": False, - "language": "en", # 添加默认语言设置 + "language": "en", "enable_metadata_archive_db": False, # Enable metadata archive database "metadata_provider_priority": "archive_db" # Default priority: 'archive_db' or 'civitai_api' } diff --git a/py/utils/routes_common.py b/py/utils/routes_common.py index 778725a1..eba348ea 100644 --- a/py/utils/routes_common.py +++ b/py/utils/routes_common.py @@ -7,13 +7,12 @@ from aiohttp import web from .model_utils import determine_base_model from .constants import PREVIEW_EXTENSIONS, CARD_PREVIEW_WIDTH from ..config import config -from ..services.civitai_client import CivitaiClient from ..services.service_registry import ServiceRegistry +from ..services.downloader import get_downloader from ..utils.exif_utils import ExifUtils from ..utils.metadata_manager import MetadataManager -from ..services.download_manager import DownloadManager from ..services.websocket_manager import ws_manager -from ..services.metadata_service import get_metadata_provider +from ..services.metadata_service import get_default_metadata_provider logger = logging.getLogger(__name__) @@ -40,7 +39,7 @@ class ModelRouteUtils: @staticmethod async def update_model_metadata(metadata_path: str, local_metadata: Dict, - civitai_metadata: Dict, client: CivitaiClient) -> None: + civitai_metadata: Dict, metadata_provider=None) -> None: """Update local metadata with CivitAI data""" # Save existing trainedWords and customImages if they exist existing_civitai = local_metadata.get('civitai') or {} # Use empty dict if None @@ -80,15 +79,17 @@ class ModelRouteUtils: # If we have modelId and don't have enough metadata, fetch additional data if not model_metadata or not model_metadata.get('description'): model_id = civitai_metadata.get('modelId') - if model_id: - fetched_metadata, _ = await client.get_model_metadata(str(model_id)) + if model_id and metadata_provider: + fetched_metadata, _ = await metadata_provider.get_model_metadata(str(model_id)) if fetched_metadata: model_metadata = fetched_metadata # Update local metadata with the model information if model_metadata: local_metadata['modelDescription'] = model_metadata.get('description', '') - local_metadata['tags'] = model_metadata.get('tags', []) + # Only set tags if local_metadata['tags'] is empty + if not local_metadata.get('tags'): + local_metadata['tags'] = model_metadata.get('tags', []) if 'creator' in model_metadata and model_metadata['creator']: local_metadata['civitai']['creator'] = model_metadata['creator'] @@ -114,22 +115,28 @@ class ModelRouteUtils: preview_path = os.path.join(os.path.dirname(metadata_path), preview_filename) if is_video: - # Download video as is - if await client.download_preview_image(first_preview['url'], preview_path): + # Download video as is using downloader + downloader = await get_downloader() + success, result = await downloader.download_file( + first_preview['url'], + preview_path, + use_auth=False + ) + if success: local_metadata['preview_url'] = preview_path.replace(os.sep, '/') local_metadata['preview_nsfw_level'] = first_preview.get('nsfwLevel', 0) else: - # For images, download and then optimize to WebP - temp_path = preview_path + ".temp" - if await client.download_preview_image(first_preview['url'], temp_path): + # For images, download and then optimize to WebP using downloader + downloader = await get_downloader() + success, content = await downloader.download_to_memory( + first_preview['url'], + use_auth=False + ) + if success: try: - # Read the downloaded image - with open(temp_path, 'rb') as f: - image_data = f.read() - # Optimize and convert to WebP optimized_data, _ = ExifUtils.optimize_image( - image_data=image_data, + image_data=content, # Use downloaded content directly target_width=CARD_PREVIEW_WIDTH, format='webp', quality=85, @@ -144,17 +151,16 @@ class ModelRouteUtils: local_metadata['preview_url'] = preview_path.replace(os.sep, '/') local_metadata['preview_nsfw_level'] = first_preview.get('nsfwLevel', 0) - # Remove the temporary file - if os.path.exists(temp_path): - os.remove(temp_path) - except Exception as e: logger.error(f"Error optimizing preview image: {e}") - # If optimization fails, try to use the downloaded image directly - if os.path.exists(temp_path): - os.rename(temp_path, preview_path) + # If optimization fails, save the original content + try: + with open(preview_path, 'wb') as f: + f.write(content) local_metadata['preview_url'] = preview_path.replace(os.sep, '/') local_metadata['preview_nsfw_level'] = first_preview.get('nsfwLevel', 0) + except Exception as save_error: + logger.error(f"Error saving preview image: {save_error}") # Save updated metadata await MetadataManager.save_metadata(metadata_path, local_metadata) @@ -177,7 +183,6 @@ class ModelRouteUtils: Returns: bool: True if successful, False otherwise """ - client = CivitaiClient() try: # Validate input parameters if not isinstance(model_data, dict): @@ -189,8 +194,9 @@ class ModelRouteUtils: # Check if model metadata exists local_metadata = await ModelRouteUtils.load_local_metadata(metadata_path) - # Fetch metadata from Civitai - civitai_metadata = await client.get_model_by_hash(sha256) + # Get metadata provider and fetch metadata from unified provider + metadata_provider = await get_default_metadata_provider() + civitai_metadata = await metadata_provider.get_model_by_hash(sha256) if not civitai_metadata: # Mark as not from CivitAI if not found local_metadata['from_civitai'] = False @@ -203,7 +209,7 @@ class ModelRouteUtils: metadata_path, local_metadata, civitai_metadata, - client + metadata_provider ) # Update cache object directly using safe .get() method @@ -226,8 +232,6 @@ class ModelRouteUtils: except Exception as e: logger.error(f"Error fetching CivitAI data: {str(e)}", exc_info=True) # Include stack trace return False - finally: - await client.close() @staticmethod def filter_civitai_data(data: Dict, minimal: bool = False) -> Dict: @@ -360,24 +364,22 @@ class ModelRouteUtils: if not local_metadata or not local_metadata.get('sha256'): return web.json_response({"success": False, "error": "No SHA256 hash found"}, status=400) - # Create a client for fetching from Civitai - client = CivitaiClient() - try: - # Fetch and update metadata - civitai_metadata = await client.get_model_by_hash(local_metadata["sha256"]) - if not civitai_metadata: - await ModelRouteUtils.handle_not_found_on_civitai(metadata_path, local_metadata) - return web.json_response({"success": False, "error": "Not found on CivitAI"}, status=404) + # Get metadata provider and fetch from unified provider + metadata_provider = await get_default_metadata_provider() + + # Fetch and update metadata + civitai_metadata = await metadata_provider.get_model_by_hash(local_metadata["sha256"]) + if not civitai_metadata: + await ModelRouteUtils.handle_not_found_on_civitai(metadata_path, local_metadata) + return web.json_response({"success": False, "error": "Not found on CivitAI"}, status=404) - await ModelRouteUtils.update_model_metadata(metadata_path, local_metadata, civitai_metadata, client) - - # Update the cache - await scanner.update_single_model_cache(data['file_path'], data['file_path'], local_metadata) - - # Return the updated metadata along with success status - return web.json_response({"success": True, "metadata": local_metadata}) - finally: - await client.close() + await ModelRouteUtils.update_model_metadata(metadata_path, local_metadata, civitai_metadata, metadata_provider) + + # Update the cache + await scanner.update_single_model_cache(data['file_path'], data['file_path'], local_metadata) + + # Return the updated metadata along with success status + return web.json_response({"success": True, "metadata": local_metadata}) except Exception as e: logger.error(f"Error fetching from CivitAI: {e}", exc_info=True) @@ -778,43 +780,38 @@ class ModelRouteUtils: # Check if model metadata exists local_metadata = await ModelRouteUtils.load_local_metadata(metadata_path) - # Create a client for fetching from Civitai - client = await CivitaiClient.get_instance() - try: - # Fetch metadata using get_model_version which includes more comprehensive data - civitai_metadata = await client.get_model_version(model_id, model_version_id) - if not civitai_metadata: - error_msg = f"Model version not found on CivitAI for ID: {model_id}" - if model_version_id: - error_msg += f" with version: {model_version_id}" - return web.json_response({"success": False, "error": error_msg}, status=404) - - # Try to find the primary model file to get the SHA256 hash - primary_model_file = None - for file in civitai_metadata.get('files', []): - if file.get('primary', False) and file.get('type') == 'Model': - primary_model_file = file - break - - # Update the SHA256 hash in local metadata if available - if primary_model_file and primary_model_file.get('hashes', {}).get('SHA256'): - local_metadata['sha256'] = primary_model_file['hashes']['SHA256'].lower() - - # Update metadata with CivitAI information - await ModelRouteUtils.update_model_metadata(metadata_path, local_metadata, civitai_metadata, client) - - # Update the cache - await scanner.update_single_model_cache(file_path, file_path, local_metadata) - - return web.json_response({ - "success": True, - "message": f"Model successfully re-linked to Civitai model {model_id}" + - (f" version {model_version_id}" if model_version_id else ""), - "hash": local_metadata.get('sha256', '') - }) - - finally: - await client.close() + # Get metadata provider and fetch metadata using get_model_version which includes more comprehensive data + metadata_provider = await get_default_metadata_provider() + civitai_metadata = await metadata_provider.get_model_version(model_id, model_version_id) + if not civitai_metadata: + error_msg = f"Model version not found on CivitAI for ID: {model_id}" + if model_version_id: + error_msg += f" with version: {model_version_id}" + return web.json_response({"success": False, "error": error_msg}, status=404) + + # Try to find the primary model file to get the SHA256 hash + primary_model_file = None + for file in civitai_metadata.get('files', []): + if file.get('primary', False) and file.get('type') == 'Model': + primary_model_file = file + break + + # Update the SHA256 hash in local metadata if available + if primary_model_file and primary_model_file.get('hashes', {}).get('SHA256'): + local_metadata['sha256'] = primary_model_file['hashes']['SHA256'].lower() + + # Update metadata with CivitAI information + await ModelRouteUtils.update_model_metadata(metadata_path, local_metadata, civitai_metadata, metadata_provider) + + # Update the cache + await scanner.update_single_model_cache(file_path, file_path, local_metadata) + + return web.json_response({ + "success": True, + "message": f"Model successfully re-linked to Civitai model {model_id}" + + (f" version {model_version_id}" if model_version_id else ""), + "hash": local_metadata.get('sha256', '') + }) except Exception as e: logger.error(f"Error re-linking to CivitAI: {e}", exc_info=True) From 68f887140331cf71ba71d4fb43a319cffbba7865 Mon Sep 17 00:00:00 2001 From: Will Miao <13051207myq@gmail.com> Date: Wed, 10 Sep 2025 11:20:58 +0800 Subject: [PATCH 07/13] feat(metadata): add source tracking for SQLite metadata and implement Civitai API metadata validation --- py/services/model_metadata_provider.py | 6 ++-- py/utils/routes_common.py | 42 ++++++++++++++++++-------- 2 files changed, 33 insertions(+), 15 deletions(-) diff --git a/py/services/model_metadata_provider.py b/py/services/model_metadata_provider.py index 2a30df11..9f54a2e7 100644 --- a/py/services/model_metadata_provider.py +++ b/py/services/model_metadata_provider.py @@ -127,7 +127,8 @@ class SQLiteModelMetadataProvider(ModelMetadataProvider): 'model': { 'name': model_row['name'], 'type': model_type, - } + }, + 'source': 'archive_db' } # Update with any additional data version_entry.update(version_data) @@ -266,7 +267,8 @@ class SQLiteModelMetadataProvider(ModelMetadataProvider): "creator": { "username": model_row['username'] or model_data.get("creator", {}).get("username"), "image": model_data.get("creator", {}).get("image") - } + }, + "source": "archive_db" } # Add any additional fields from version data diff --git a/py/utils/routes_common.py b/py/utils/routes_common.py index eba348ea..a82dbd2f 100644 --- a/py/utils/routes_common.py +++ b/py/utils/routes_common.py @@ -37,6 +37,18 @@ class ModelRouteUtils: local_metadata['from_civitai'] = False await MetadataManager.save_metadata(metadata_path, local_metadata) + @staticmethod + def is_civitai_api_metadata(meta: dict) -> bool: + """ + Determine if the given civitai metadata is from the civitai API. + Returns True if both 'files' and 'images' exist and are non-empty. + """ + if not isinstance(meta, dict): + return False + files = meta.get('files') + images = meta.get('images') + return bool(files) and bool(images) + @staticmethod async def update_model_metadata(metadata_path: str, local_metadata: Dict, civitai_metadata: Dict, metadata_provider=None) -> None: @@ -44,21 +56,25 @@ class ModelRouteUtils: # Save existing trainedWords and customImages if they exist existing_civitai = local_metadata.get('civitai') or {} # Use empty dict if None - # Create a new civitai metadata by updating existing with new - merged_civitai = existing_civitai.copy() - merged_civitai.update(civitai_metadata) + # Check if we should skip the update to avoid overwriting richer data + if civitai_metadata.get('source') == 'archive_db' and ModelRouteUtils.is_civitai_api_metadata(existing_civitai): + logger.info(f"Skip civitai update for {local_metadata.get('model_name', '')}: {existing_civitai.get('name', '')}") + else: + # Create a new civitai metadata by updating existing with new + merged_civitai = existing_civitai.copy() + merged_civitai.update(civitai_metadata) - # Special handling for trainedWords - ensure we don't lose any existing trained words - if 'trainedWords' in existing_civitai: - existing_trained_words = existing_civitai.get('trainedWords', []) - new_trained_words = civitai_metadata.get('trainedWords', []) - # Use a set to combine words without duplicates, then convert back to list - merged_trained_words = list(set(existing_trained_words + new_trained_words)) - merged_civitai['trainedWords'] = merged_trained_words + # Special handling for trainedWords - ensure we don't lose any existing trained words + if 'trainedWords' in existing_civitai: + existing_trained_words = existing_civitai.get('trainedWords', []) + new_trained_words = civitai_metadata.get('trainedWords', []) + # Use a set to combine words without duplicates, then convert back to list + merged_trained_words = list(set(existing_trained_words + new_trained_words)) + merged_civitai['trainedWords'] = merged_trained_words - # Update local metadata with merged civitai data - local_metadata['civitai'] = merged_civitai - local_metadata['from_civitai'] = True + # Update local metadata with merged civitai data + local_metadata['civitai'] = merged_civitai + local_metadata['from_civitai'] = True # Update model name if available if 'model' in civitai_metadata: From a4fbeb62952e66498feffca0111746efb2705c89 Mon Sep 17 00:00:00 2001 From: Will Miao <13051207myq@gmail.com> Date: Wed, 10 Sep 2025 15:55:29 +0800 Subject: [PATCH 08/13] feat(metadata): update metadata archive management and remove provider priority settings --- locales/de.json | 31 +++++++++++++++++-- locales/en.json | 7 +---- locales/es.json | 31 +++++++++++++++++-- locales/fr.json | 31 +++++++++++++++++-- locales/ja.json | 31 +++++++++++++++++-- locales/ko.json | 31 +++++++++++++++++-- locales/ru.json | 31 +++++++++++++++++-- locales/zh-CN.json | 31 +++++++++++++++++-- locales/zh-TW.json | 31 +++++++++++++++++-- py/routes/base_model_routes.py | 4 +-- py/routes/misc_routes.py | 12 +++---- py/services/metadata_service.py | 26 ++++++---------- py/services/settings_manager.py | 3 +- py/utils/routes_common.py | 23 ++++++++++---- static/js/managers/SettingsManager.js | 27 ++-------------- .../components/modals/settings_modal.html | 17 ---------- 16 files changed, 269 insertions(+), 98 deletions(-) diff --git a/locales/de.json b/locales/de.json index a9855674..aea0c592 100644 --- a/locales/de.json +++ b/locales/de.json @@ -16,7 +16,9 @@ "loading": "Wird geladen...", "unknown": "Unbekannt", "date": "Datum", - "version": "Version" + "version": "Version", + "enabled": "Aktiviert", + "disabled": "Deaktiviert" }, "language": { "select": "Sprache", @@ -178,7 +180,8 @@ "folderSettings": "Ordner-Einstellungen", "downloadPathTemplates": "Download-Pfad-Vorlagen", "exampleImages": "Beispielbilder", - "misc": "Verschiedenes" + "misc": "Verschiedenes", + "metadataArchive": "Metadaten-Archiv-Datenbank" }, "contentFiltering": { "blurNsfwContent": "NSFW-Inhalte unscharf stellen", @@ -273,6 +276,30 @@ "misc": { "includeTriggerWords": "Trigger Words in LoRA-Syntax einschließen", "includeTriggerWordsHelp": "Trainierte Trigger Words beim Kopieren der LoRA-Syntax in die Zwischenablage einschließen" + }, + "metadataArchive": { + "enableArchiveDb": "Metadaten-Archiv-Datenbank aktivieren", + "enableArchiveDbHelp": "Verwenden Sie eine lokale Datenbank, um auf Metadaten von Modellen zuzugreifen, die von Civitai gelöscht wurden.", + "status": "Status", + "statusAvailable": "Verfügbar", + "statusUnavailable": "Nicht verfügbar", + "enabled": "Aktiviert", + "management": "Datenbankverwaltung", + "managementHelp": "Laden Sie die Metadaten-Archiv-Datenbank herunter oder entfernen Sie sie", + "downloadButton": "Datenbank herunterladen", + "downloadingButton": "Wird heruntergeladen...", + "downloadedButton": "Heruntergeladen", + "removeButton": "Datenbank entfernen", + "removingButton": "Wird entfernt...", + "downloadSuccess": "Metadaten-Archiv-Datenbank erfolgreich heruntergeladen", + "downloadError": "Fehler beim Herunterladen der Metadaten-Archiv-Datenbank", + "removeSuccess": "Metadaten-Archiv-Datenbank erfolgreich entfernt", + "removeError": "Fehler beim Entfernen der Metadaten-Archiv-Datenbank", + "removeConfirm": "Sind Sie sicher, dass Sie die Metadaten-Archiv-Datenbank entfernen möchten? Dadurch wird die lokale Datenbankdatei gelöscht und Sie müssen sie erneut herunterladen, um diese Funktion zu nutzen.", + "preparing": "Download wird vorbereitet...", + "connecting": "Verbindung zum Download-Server wird hergestellt...", + "completed": "Abgeschlossen", + "downloadComplete": "Download erfolgreich abgeschlossen" } }, "loras": { diff --git a/locales/en.json b/locales/en.json index 343b813a..251688f6 100644 --- a/locales/en.json +++ b/locales/en.json @@ -279,16 +279,11 @@ }, "metadataArchive": { "enableArchiveDb": "Enable Metadata Archive Database", - "enableArchiveDbHelp": "Use local database for faster metadata retrieval and access to deleted models. Recommended for better performance.", - "providerPriority": "Metadata Provider Priority", - "providerPriorityHelp": "Choose which metadata source to try first when loading model information", - "priorityArchiveDb": "Archive Database (Recommended)", - "priorityCivitaiApi": "Civitai API", + "enableArchiveDbHelp": "Use a local database to access metadata for models that have been deleted from Civitai.", "status": "Status", "statusAvailable": "Available", "statusUnavailable": "Not Available", "enabled": "Enabled", - "currentPriority": "Current Priority", "management": "Database Management", "managementHelp": "Download or remove the metadata archive database", "downloadButton": "Download Database", diff --git a/locales/es.json b/locales/es.json index 274c497f..691e4651 100644 --- a/locales/es.json +++ b/locales/es.json @@ -16,7 +16,9 @@ "loading": "Cargando...", "unknown": "Desconocido", "date": "Fecha", - "version": "Versión" + "version": "Versión", + "enabled": "Habilitado", + "disabled": "Deshabilitado" }, "language": { "select": "Idioma", @@ -178,7 +180,8 @@ "folderSettings": "Configuración de carpetas", "downloadPathTemplates": "Plantillas de rutas de descarga", "exampleImages": "Imágenes de ejemplo", - "misc": "Varios" + "misc": "Varios", + "metadataArchive": "Base de datos de archivo de metadatos" }, "contentFiltering": { "blurNsfwContent": "Difuminar contenido NSFW", @@ -273,6 +276,30 @@ "misc": { "includeTriggerWords": "Incluir palabras clave en la sintaxis de LoRA", "includeTriggerWordsHelp": "Incluir palabras clave entrenadas al copiar la sintaxis de LoRA al portapapeles" + }, + "metadataArchive": { + "enableArchiveDb": "Habilitar base de datos de archivo de metadatos", + "enableArchiveDbHelp": "Utiliza una base de datos local para acceder a metadatos de modelos que han sido eliminados de Civitai.", + "status": "Estado", + "statusAvailable": "Disponible", + "statusUnavailable": "No disponible", + "enabled": "Habilitado", + "management": "Gestión de base de datos", + "managementHelp": "Descargar o eliminar la base de datos de archivo de metadatos", + "downloadButton": "Descargar base de datos", + "downloadingButton": "Descargando...", + "downloadedButton": "Descargado", + "removeButton": "Eliminar base de datos", + "removingButton": "Eliminando...", + "downloadSuccess": "Base de datos de archivo de metadatos descargada exitosamente", + "downloadError": "Error al descargar la base de datos de archivo de metadatos", + "removeSuccess": "Base de datos de archivo de metadatos eliminada exitosamente", + "removeError": "Error al eliminar la base de datos de archivo de metadatos", + "removeConfirm": "¿Estás seguro de que quieres eliminar la base de datos de archivo de metadatos? Esto eliminará el archivo de base de datos local y tendrás que descargarlo de nuevo para usar esta función.", + "preparing": "Preparando descarga...", + "connecting": "Conectando al servidor de descarga...", + "completed": "Completado", + "downloadComplete": "Descarga completada exitosamente" } }, "loras": { diff --git a/locales/fr.json b/locales/fr.json index a86b464c..0424a93a 100644 --- a/locales/fr.json +++ b/locales/fr.json @@ -16,7 +16,9 @@ "loading": "Chargement...", "unknown": "Inconnu", "date": "Date", - "version": "Version" + "version": "Version", + "enabled": "Activé", + "disabled": "Désactivé" }, "language": { "select": "Langue", @@ -178,7 +180,8 @@ "folderSettings": "Paramètres des dossiers", "downloadPathTemplates": "Modèles de chemin de téléchargement", "exampleImages": "Images d'exemple", - "misc": "Divers" + "misc": "Divers", + "metadataArchive": "Base de données d'archive des métadonnées" }, "contentFiltering": { "blurNsfwContent": "Flouter le contenu NSFW", @@ -273,6 +276,30 @@ "misc": { "includeTriggerWords": "Inclure les mots-clés dans la syntaxe LoRA", "includeTriggerWordsHelp": "Inclure les mots-clés d'entraînement lors de la copie de la syntaxe LoRA dans le presse-papiers" + }, + "metadataArchive": { + "enableArchiveDb": "Activer la base de données d'archive des métadonnées", + "enableArchiveDbHelp": "Utiliser une base de données locale pour accéder aux métadonnées des modèles supprimés de Civitai.", + "status": "Statut", + "statusAvailable": "Disponible", + "statusUnavailable": "Non disponible", + "enabled": "Activé", + "management": "Gestion de la base de données", + "managementHelp": "Télécharger ou supprimer la base de données d'archive des métadonnées", + "downloadButton": "Télécharger la base de données", + "downloadingButton": "Téléchargement...", + "downloadedButton": "Téléchargé", + "removeButton": "Supprimer la base de données", + "removingButton": "Suppression...", + "downloadSuccess": "Base de données d'archive des métadonnées téléchargée avec succès", + "downloadError": "Échec du téléchargement de la base de données d'archive des métadonnées", + "removeSuccess": "Base de données d'archive des métadonnées supprimée avec succès", + "removeError": "Échec de la suppression de la base de données d'archive des métadonnées", + "removeConfirm": "Êtes-vous sûr de vouloir supprimer la base de données d'archive des métadonnées ? Cela supprimera le fichier local et vous devrez la télécharger à nouveau pour utiliser cette fonctionnalité.", + "preparing": "Préparation du téléchargement...", + "connecting": "Connexion au serveur de téléchargement...", + "completed": "Terminé", + "downloadComplete": "Téléchargement terminé avec succès" } }, "loras": { diff --git a/locales/ja.json b/locales/ja.json index cef7074b..f4bad7a6 100644 --- a/locales/ja.json +++ b/locales/ja.json @@ -16,7 +16,9 @@ "loading": "読み込み中...", "unknown": "不明", "date": "日付", - "version": "バージョン" + "version": "バージョン", + "enabled": "有効", + "disabled": "無効" }, "language": { "select": "言語", @@ -178,7 +180,8 @@ "folderSettings": "フォルダ設定", "downloadPathTemplates": "ダウンロードパステンプレート", "exampleImages": "例画像", - "misc": "その他" + "misc": "その他", + "metadataArchive": "メタデータアーカイブデータベース" }, "contentFiltering": { "blurNsfwContent": "NSFWコンテンツをぼかす", @@ -273,6 +276,30 @@ "misc": { "includeTriggerWords": "LoRA構文にトリガーワードを含める", "includeTriggerWordsHelp": "LoRA構文をクリップボードにコピーする際、学習済みトリガーワードを含めます" + }, + "metadataArchive": { + "enableArchiveDb": "メタデータアーカイブデータベースを有効化", + "enableArchiveDbHelp": "Civitaiから削除されたモデルのメタデータにアクセスするためにローカルデータベースを使用します。", + "status": "ステータス", + "statusAvailable": "利用可能", + "statusUnavailable": "利用不可", + "enabled": "有効", + "management": "データベース管理", + "managementHelp": "メタデータアーカイブデータベースのダウンロードまたは削除", + "downloadButton": "データベースをダウンロード", + "downloadingButton": "ダウンロード中...", + "downloadedButton": "ダウンロード済み", + "removeButton": "データベースを削除", + "removingButton": "削除中...", + "downloadSuccess": "メタデータアーカイブデータベースのダウンロードが完了しました", + "downloadError": "メタデータアーカイブデータベースのダウンロードに失敗しました", + "removeSuccess": "メタデータアーカイブデータベースが削除されました", + "removeError": "メタデータアーカイブデータベースの削除に失敗しました", + "removeConfirm": "本当にメタデータアーカイブデータベースを削除しますか?ローカルのデータベースファイルが削除され、この機能を再度利用するには再ダウンロードが必要です。", + "preparing": "ダウンロードを準備中...", + "connecting": "ダウンロードサーバーに接続中...", + "completed": "完了", + "downloadComplete": "ダウンロードが正常に完了しました" } }, "loras": { diff --git a/locales/ko.json b/locales/ko.json index 09765802..5eccaead 100644 --- a/locales/ko.json +++ b/locales/ko.json @@ -16,7 +16,9 @@ "loading": "로딩 중...", "unknown": "알 수 없음", "date": "날짜", - "version": "버전" + "version": "버전", + "enabled": "활성화됨", + "disabled": "비활성화됨" }, "language": { "select": "언어", @@ -178,7 +180,8 @@ "folderSettings": "폴더 설정", "downloadPathTemplates": "다운로드 경로 템플릿", "exampleImages": "예시 이미지", - "misc": "기타" + "misc": "기타", + "metadataArchive": "메타데이터 아카이브 데이터베이스" }, "contentFiltering": { "blurNsfwContent": "NSFW 콘텐츠 블러 처리", @@ -273,6 +276,30 @@ "misc": { "includeTriggerWords": "LoRA 문법에 트리거 단어 포함", "includeTriggerWordsHelp": "LoRA 문법을 클립보드에 복사할 때 학습된 트리거 단어를 포함합니다" + }, + "metadataArchive": { + "enableArchiveDb": "메타데이터 아카이브 데이터베이스 활성화", + "enableArchiveDbHelp": "Civitai에서 삭제된 모델의 메타데이터에 접근하기 위해 로컬 데이터베이스를 사용합니다.", + "status": "상태", + "statusAvailable": "사용 가능", + "statusUnavailable": "사용 불가", + "enabled": "활성화됨", + "management": "데이터베이스 관리", + "managementHelp": "메타데이터 아카이브 데이터베이스를 다운로드하거나 제거합니다", + "downloadButton": "데이터베이스 다운로드", + "downloadingButton": "다운로드 중...", + "downloadedButton": "다운로드 완료", + "removeButton": "데이터베이스 제거", + "removingButton": "제거 중...", + "downloadSuccess": "메타데이터 아카이브 데이터베이스가 성공적으로 다운로드되었습니다", + "downloadError": "메타데이터 아카이브 데이터베이스 다운로드 실패", + "removeSuccess": "메타데이터 아카이브 데이터베이스가 성공적으로 제거되었습니다", + "removeError": "메타데이터 아카이브 데이터베이스 제거 실패", + "removeConfirm": "메타데이터 아카이브 데이터베이스를 제거하시겠습니까? 이 작업은 로컬 데이터베이스 파일을 삭제하며, 이 기능을 사용하려면 다시 다운로드해야 합니다.", + "preparing": "다운로드 준비 중...", + "connecting": "다운로드 서버에 연결 중...", + "completed": "완료됨", + "downloadComplete": "다운로드가 성공적으로 완료되었습니다" } }, "loras": { diff --git a/locales/ru.json b/locales/ru.json index c32f9357..212c2c6f 100644 --- a/locales/ru.json +++ b/locales/ru.json @@ -16,7 +16,9 @@ "loading": "Загрузка...", "unknown": "Неизвестно", "date": "Дата", - "version": "Версия" + "version": "Версия", + "enabled": "Включено", + "disabled": "Отключено" }, "language": { "select": "Язык", @@ -178,7 +180,8 @@ "folderSettings": "Настройки папок", "downloadPathTemplates": "Шаблоны путей загрузки", "exampleImages": "Примеры изображений", - "misc": "Разное" + "misc": "Разное", + "metadataArchive": "Архив метаданных" }, "contentFiltering": { "blurNsfwContent": "Размывать NSFW контент", @@ -273,6 +276,30 @@ "misc": { "includeTriggerWords": "Включать триггерные слова в синтаксис LoRA", "includeTriggerWordsHelp": "Включать обученные триггерные слова при копировании синтаксиса LoRA в буфер обмена" + }, + "metadataArchive": { + "enableArchiveDb": "Включить архив метаданных", + "enableArchiveDbHelp": "Использовать локальную базу данных для доступа к метаданным моделей, удалённых с Civitai.", + "status": "Статус", + "statusAvailable": "Доступно", + "statusUnavailable": "Недоступно", + "enabled": "Включено", + "management": "Управление базой данных", + "managementHelp": "Скачать или удалить базу данных архива метаданных", + "downloadButton": "Скачать базу данных", + "downloadingButton": "Скачивание...", + "downloadedButton": "Скачано", + "removeButton": "Удалить базу данных", + "removingButton": "Удаление...", + "downloadSuccess": "База данных архива метаданных успешно загружена", + "downloadError": "Не удалось загрузить базу данных архива метаданных", + "removeSuccess": "База данных архива метаданных успешно удалена", + "removeError": "Не удалось удалить базу данных архива метаданных", + "removeConfirm": "Вы уверены, что хотите удалить базу данных архива метаданных? Это удалит локальный файл базы данных, и для использования этой функции потребуется повторная загрузка.", + "preparing": "Подготовка к загрузке...", + "connecting": "Подключение к серверу загрузки...", + "completed": "Завершено", + "downloadComplete": "Загрузка успешно завершена" } }, "loras": { diff --git a/locales/zh-CN.json b/locales/zh-CN.json index a9b6a1e6..b419ae0d 100644 --- a/locales/zh-CN.json +++ b/locales/zh-CN.json @@ -16,7 +16,9 @@ "loading": "加载中...", "unknown": "未知", "date": "日期", - "version": "版本" + "version": "版本", + "enabled": "已启用", + "disabled": "已禁用" }, "language": { "select": "语言", @@ -178,7 +180,8 @@ "folderSettings": "文件夹设置", "downloadPathTemplates": "下载路径模板", "exampleImages": "示例图片", - "misc": "其他" + "misc": "其他", + "metadataArchive": "元数据归档数据库" }, "contentFiltering": { "blurNsfwContent": "模糊 NSFW 内容", @@ -273,6 +276,30 @@ "misc": { "includeTriggerWords": "复制 LoRA 语法时包含触发词", "includeTriggerWordsHelp": "复制 LoRA 语法到剪贴板时包含训练触发词" + }, + "metadataArchive": { + "enableArchiveDb": "启用元数据归档数据库", + "enableArchiveDbHelp": "使用本地数据库访问已从 Civitai 删除的模型元数据。", + "status": "状态", + "statusAvailable": "可用", + "statusUnavailable": "不可用", + "enabled": "已启用", + "management": "数据库管理", + "managementHelp": "下载或移除元数据归档数据库", + "downloadButton": "下载数据库", + "downloadingButton": "正在下载...", + "downloadedButton": "已下载", + "removeButton": "移除数据库", + "removingButton": "正在移除...", + "downloadSuccess": "元数据归档数据库下载成功", + "downloadError": "元数据归档数据库下载失败", + "removeSuccess": "元数据归档数据库移除成功", + "removeError": "元数据归档数据库移除失败", + "removeConfirm": "你确定要移除元数据归档数据库吗?这将删除本地数据库文件,如需使用此功能需重新下载。", + "preparing": "正在准备下载...", + "connecting": "正在连接下载服务器...", + "completed": "已完成", + "downloadComplete": "下载成功完成" } }, "loras": { diff --git a/locales/zh-TW.json b/locales/zh-TW.json index 6c498946..7265fadf 100644 --- a/locales/zh-TW.json +++ b/locales/zh-TW.json @@ -16,7 +16,9 @@ "loading": "載入中...", "unknown": "未知", "date": "日期", - "version": "版本" + "version": "版本", + "enabled": "已啟用", + "disabled": "已停用" }, "language": { "select": "語言", @@ -178,7 +180,8 @@ "folderSettings": "資料夾設定", "downloadPathTemplates": "下載路徑範本", "exampleImages": "範例圖片", - "misc": "其他" + "misc": "其他", + "metadataArchive": "中繼資料封存資料庫" }, "contentFiltering": { "blurNsfwContent": "模糊 NSFW 內容", @@ -273,6 +276,30 @@ "misc": { "includeTriggerWords": "在 LoRA 語法中包含觸發詞", "includeTriggerWordsHelp": "複製 LoRA 語法到剪貼簿時包含訓練觸發詞" + }, + "metadataArchive": { + "enableArchiveDb": "啟用中繼資料封存資料庫", + "enableArchiveDbHelp": "使用本機資料庫以存取已從 Civitai 刪除模型的中繼資料。", + "status": "狀態", + "statusAvailable": "可用", + "statusUnavailable": "不可用", + "enabled": "已啟用", + "management": "資料庫管理", + "managementHelp": "下載或移除中繼資料封存資料庫", + "downloadButton": "下載資料庫", + "downloadingButton": "下載中...", + "downloadedButton": "已下載", + "removeButton": "移除資料庫", + "removingButton": "移除中...", + "downloadSuccess": "中繼資料封存資料庫下載成功", + "downloadError": "下載中繼資料封存資料庫失敗", + "removeSuccess": "中繼資料封存資料庫移除成功", + "removeError": "移除中繼資料封存資料庫失敗", + "removeConfirm": "您確定要移除中繼資料封存資料庫嗎?這將刪除本機資料庫檔案,若要再次使用此功能需重新下載。", + "preparing": "準備下載中...", + "connecting": "正在連接下載伺服器...", + "completed": "已完成", + "downloadComplete": "下載成功完成" } }, "loras": { diff --git a/py/routes/base_model_routes.py b/py/routes/base_model_routes.py index 099c2f58..b0a8055f 100644 --- a/py/routes/base_model_routes.py +++ b/py/routes/base_model_routes.py @@ -611,10 +611,10 @@ class BaseModelRoutes(ABC): success = 0 needs_resort = False - # Prepare models to process + # Prepare models to process, only those without CivitAI data to_process = [ model for model in cache.raw_data - if model.get('sha256') and (not model.get('civitai') or 'id' not in model.get('civitai')) and model.get('from_civitai', True) + if model.get('sha256') and (not model.get('civitai') or 'id' not in model.get('civitai')) ] total_to_process = len(to_process) diff --git a/py/routes/misc_routes.py b/py/routes/misc_routes.py index 118afea6..591cc15a 100644 --- a/py/routes/misc_routes.py +++ b/py/routes/misc_routes.py @@ -11,7 +11,7 @@ from ..utils.lora_metadata import extract_trained_words from ..config import config from ..utils.constants import SUPPORTED_MEDIA_EXTENSIONS, NODE_TYPES, DEFAULT_NODE_COLOR from ..services.service_registry import ServiceRegistry -from ..services.metadata_service import get_metadata_archive_manager, update_metadata_provider_priority +from ..services.metadata_service import get_metadata_archive_manager, update_metadata_providers from ..services.websocket_manager import ws_manager import re @@ -736,8 +736,8 @@ class MiscRoutes: # Update settings to enable metadata archive settings.set('enable_metadata_archive_db', True) - # Update provider priority - await update_metadata_provider_priority() + # Update metadata providers + await update_metadata_providers() return web.json_response({ 'success': True, @@ -768,8 +768,8 @@ class MiscRoutes: # Update settings to disable metadata archive settings.set('enable_metadata_archive_db', False) - # Update provider priority - await update_metadata_provider_priority() + # Update metadata providers + await update_metadata_providers() return web.json_response({ 'success': True, @@ -796,7 +796,6 @@ class MiscRoutes: is_available = archive_manager.is_database_available() is_enabled = settings.get('enable_metadata_archive_db', False) - priority = settings.get('metadata_provider_priority', 'archive_db') db_size = 0 if is_available: @@ -808,7 +807,6 @@ class MiscRoutes: 'success': True, 'isAvailable': is_available, 'isEnabled': is_enabled, - 'priority': priority, 'databaseSize': db_size, 'databasePath': archive_manager.get_database_path() if is_available else None }) diff --git a/py/services/metadata_service.py b/py/services/metadata_service.py index 7823a1f7..86a94eaf 100644 --- a/py/services/metadata_service.py +++ b/py/services/metadata_service.py @@ -22,7 +22,6 @@ async def initialize_metadata_providers(): # Get settings enable_archive_db = settings.get('enable_metadata_archive_db', False) - priority = settings.get('metadata_provider_priority', 'archive_db') providers = [] @@ -54,23 +53,17 @@ async def initialize_metadata_providers(): except Exception as e: logger.error(f"Failed to initialize Civitai API metadata provider: {e}") - # Set up fallback provider based on priority and available providers + # Set up fallback provider based on available providers if len(providers) > 1: - # Order providers based on priority setting + # Always use Civitai API first, then Archive DB ordered_providers = [] - if priority == 'archive_db': - # Archive DB first, then Civitai API - ordered_providers = [p[1] for p in providers if p[0] == 'sqlite'] - ordered_providers.extend([p[1] for p in providers if p[0] == 'civitai_api']) - else: - # Civitai API first, then Archive DB - ordered_providers = [p[1] for p in providers if p[0] == 'civitai_api'] - ordered_providers.extend([p[1] for p in providers if p[0] == 'sqlite']) + ordered_providers.extend([p[1] for p in providers if p[0] == 'civitai_api']) + ordered_providers.extend([p[1] for p in providers if p[0] == 'sqlite']) if ordered_providers: fallback_provider = FallbackMetadataProvider(ordered_providers) provider_manager.register_provider('fallback', fallback_provider, is_default=True) - logger.info(f"Fallback metadata provider registered with {len(ordered_providers)} providers, priority: {priority}") + logger.info(f"Fallback metadata provider registered with {len(ordered_providers)} providers, Civitai API first") elif len(providers) == 1: # Only one provider available, set it as default provider_name, provider = providers[0] @@ -81,20 +74,19 @@ async def initialize_metadata_providers(): return provider_manager -async def update_metadata_provider_priority(): - """Update metadata provider priority based on current settings""" +async def update_metadata_providers(): + """Update metadata providers based on current settings""" try: # Get current settings enable_archive_db = settings.get('enable_metadata_archive_db', False) - priority = settings.get('metadata_provider_priority', 'archive_db') # Reinitialize all providers with new settings provider_manager = await initialize_metadata_providers() - logger.info(f"Updated metadata provider priority to: {priority}, archive_db enabled: {enable_archive_db}") + logger.info(f"Updated metadata providers, archive_db enabled: {enable_archive_db}") return provider_manager except Exception as e: - logger.error(f"Failed to update metadata provider priority: {e}") + logger.error(f"Failed to update metadata providers: {e}") return await ModelMetadataProviderManager.get_instance() async def get_metadata_archive_manager(): diff --git a/py/services/settings_manager.py b/py/services/settings_manager.py index 058d9944..7d99da48 100644 --- a/py/services/settings_manager.py +++ b/py/services/settings_manager.py @@ -82,8 +82,7 @@ class SettingsManager: "civitai_api_key": "", "show_only_sfw": False, "language": "en", - "enable_metadata_archive_db": False, # Enable metadata archive database - "metadata_provider_priority": "archive_db" # Default priority: 'archive_db' or 'civitai_api' + "enable_metadata_archive_db": False # Enable metadata archive database } def get(self, key: str, default: Any = None) -> Any: diff --git a/py/utils/routes_common.py b/py/utils/routes_common.py index a82dbd2f..2c7e13ca 100644 --- a/py/utils/routes_common.py +++ b/py/utils/routes_common.py @@ -12,7 +12,7 @@ from ..services.downloader import get_downloader from ..utils.exif_utils import ExifUtils from ..utils.metadata_manager import MetadataManager from ..services.websocket_manager import ws_manager -from ..services.metadata_service import get_default_metadata_provider +from ..services.metadata_service import get_default_metadata_provider, get_metadata_provider logger = logging.getLogger(__name__) @@ -41,13 +41,15 @@ class ModelRouteUtils: def is_civitai_api_metadata(meta: dict) -> bool: """ Determine if the given civitai metadata is from the civitai API. - Returns True if both 'files' and 'images' exist and are non-empty. + Returns True if both 'files' and 'images' exist and are non-empty, + and the 'source' is not 'archive_db'. """ if not isinstance(meta, dict): return False files = meta.get('files') images = meta.get('images') - return bool(files) and bool(images) + source = meta.get('source') + return bool(files) and bool(images) and source != 'archive_db' @staticmethod async def update_model_metadata(metadata_path: str, local_metadata: Dict, @@ -58,12 +60,17 @@ class ModelRouteUtils: # Check if we should skip the update to avoid overwriting richer data if civitai_metadata.get('source') == 'archive_db' and ModelRouteUtils.is_civitai_api_metadata(existing_civitai): - logger.info(f"Skip civitai update for {local_metadata.get('model_name', '')}: {existing_civitai.get('name', '')}") + logger.info(f"Skip civitai update for {local_metadata.get('model_name', '')} ({existing_civitai.get('name', '')})") else: # Create a new civitai metadata by updating existing with new merged_civitai = existing_civitai.copy() merged_civitai.update(civitai_metadata) + if civitai_metadata.get('source') == 'archive_db': + model_name = civitai_metadata.get('model', {}).get('name', '') + version_name = civitai_metadata.get('name', '') + logger.info(f"Recovered metadata from archive_db for deleted model: {model_name} ({version_name})") + # Special handling for trainedWords - ensure we don't lose any existing trained words if 'trainedWords' in existing_civitai: existing_trained_words = existing_civitai.get('trainedWords', []) @@ -210,8 +217,12 @@ class ModelRouteUtils: # Check if model metadata exists local_metadata = await ModelRouteUtils.load_local_metadata(metadata_path) - # Get metadata provider and fetch metadata from unified provider - metadata_provider = await get_default_metadata_provider() + if model_data.get('from_civitai') is False: + # Likely deleted from CivitAI, use archive_db if available + metadata_provider = await get_metadata_provider('sqlite') + else: + metadata_provider = await get_default_metadata_provider() + civitai_metadata = await metadata_provider.get_model_by_hash(sha256) if not civitai_metadata: # Mark as not from CivitAI if not found diff --git a/static/js/managers/SettingsManager.js b/static/js/managers/SettingsManager.js index ac795903..6ef0e404 100644 --- a/static/js/managers/SettingsManager.js +++ b/static/js/managers/SettingsManager.js @@ -850,7 +850,7 @@ export class SettingsManager { } // Special handling for metadata archive settings - if (settingKey === 'enable_metadata_archive_db' || settingKey === 'metadata_provider_priority') { + if (settingKey === 'enable_metadata_archive_db') { await this.updateMetadataArchiveStatus(); } @@ -879,8 +879,6 @@ export class SettingsManager { state.global.settings.compactMode = (value !== 'default'); } else if (settingKey === 'card_info_display') { state.global.settings.cardInfoDisplay = value; - } else if (settingKey === 'metadata_provider_priority') { - state.global.settings.metadata_provider_priority = value; } else { // For any other settings that might be added in the future state.global.settings[settingKey] = value; @@ -891,7 +889,7 @@ export class SettingsManager { try { // For backend settings, make API call - if (settingKey === 'default_lora_root' || settingKey === 'default_checkpoint_root' || settingKey === 'default_embedding_root' || settingKey === 'download_path_templates' || settingKey === 'metadata_provider_priority') { + if (settingKey === 'default_lora_root' || settingKey === 'default_checkpoint_root' || settingKey === 'default_embedding_root' || settingKey === 'download_path_templates') { const payload = {}; if (settingKey === 'download_path_templates') { payload[settingKey] = state.global.settings.download_path_templates; @@ -912,11 +910,6 @@ export class SettingsManager { } showToast('toast.settings.settingsUpdated', { setting: settingKey.replace(/_/g, ' ') }, 'success'); - - // Refresh metadata archive status when provider priority changes - if (settingKey === 'metadata_provider_priority') { - await this.updateMetadataArchiveStatus(); - } } // Apply frontend settings immediately @@ -932,11 +925,6 @@ export class SettingsManager { showToast('toast.settings.displayDensitySet', { density: densityName }, 'success'); } - - // Special handling for metadata archive settings - if (settingKey === 'metadata_provider_priority') { - await this.updateMetadataArchiveStatus(); - } } catch (error) { showToast('toast.settings.settingSaveFailed', { message: error.message }, 'error'); @@ -951,11 +939,6 @@ export class SettingsManager { enableMetadataArchiveCheckbox.checked = state.global.settings.enable_metadata_archive_db || false; } - const metadataProviderPrioritySelect = document.getElementById('metadataProviderPriority'); - if (metadataProviderPrioritySelect) { - metadataProviderPrioritySelect.value = state.global.settings.metadata_provider_priority || 'archive_db'; - } - // Load status await this.updateMetadataArchiveStatus(); } catch (error) { @@ -987,12 +970,6 @@ export class SettingsManager { ${status.isEnabled ? translate('common.status.enabled') : translate('common.status.disabled')} -
- ${translate('settings.metadataArchive.currentPriority')}: - - ${status.priority === 'archive_db' ? translate('settings.metadataArchive.priorityArchiveDb') : translate('settings.metadataArchive.priorityCivitaiApi')} - -
`; // Update button states diff --git a/templates/components/modals/settings_modal.html b/templates/components/modals/settings_modal.html index d150c81d..3b1b3a8a 100644 --- a/templates/components/modals/settings_modal.html +++ b/templates/components/modals/settings_modal.html @@ -419,23 +419,6 @@ -
-
-
- -
-
- -
-
-
- {{ t('settings.metadataArchive.providerPriorityHelp') }} -
-
-
From ba1ac587214325802c77392b899e41c07e31b5b7 Mon Sep 17 00:00:00 2001 From: Will Miao <13051207myq@gmail.com> Date: Wed, 10 Sep 2025 16:18:04 +0800 Subject: [PATCH 09/13] feat(metadata): trigger metadata provider update when enabling metadata archive database --- py/routes/misc_routes.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/py/routes/misc_routes.py b/py/routes/misc_routes.py index 591cc15a..5a53fea9 100644 --- a/py/routes/misc_routes.py +++ b/py/routes/misc_routes.py @@ -188,10 +188,13 @@ class MiscRoutes: old_path = settings.get('example_images_path') if old_path != value: logger.info(f"Example images path changed to {value} - server restart required") - + # Save to settings settings.set(key, value) + if key == 'enable_metadata_archive_db': + await update_metadata_providers() + return web.json_response({'success': True}) except Exception as e: logger.error(f"Error updating settings: {e}", exc_info=True) From 4ee5b7481cc2a12379a2937daa926973f5585189 Mon Sep 17 00:00:00 2001 From: Will Miao <13051207myq@gmail.com> Date: Wed, 10 Sep 2025 18:49:35 +0800 Subject: [PATCH 10/13] fix(downloader): set socket read timeout to 5 minutes for improved stability during large downloads --- py/services/downloader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/py/services/downloader.py b/py/services/downloader.py index 9efba53b..dc38c0d1 100644 --- a/py/services/downloader.py +++ b/py/services/downloader.py @@ -97,7 +97,7 @@ class Downloader: timeout = aiohttp.ClientTimeout( total=None, # No total timeout for large downloads connect=60, # Connection timeout - sock_read=None # No socket read timeout + sock_read=300 # 5 minute socket read timeout ) self._session = aiohttp.ClientSession( From 3e5cb223f3bac621b2bbff67b8499c3105362d92 Mon Sep 17 00:00:00 2001 From: Will Miao <13051207myq@gmail.com> Date: Wed, 10 Sep 2025 20:09:05 +0800 Subject: [PATCH 11/13] refactor(metadata): remove outdated metadata provider summary documentation --- METADATA_PROVIDER_REFACTOR_SUMMARY.md | 119 -------------------------- 1 file changed, 119 deletions(-) delete mode 100644 METADATA_PROVIDER_REFACTOR_SUMMARY.md diff --git a/METADATA_PROVIDER_REFACTOR_SUMMARY.md b/METADATA_PROVIDER_REFACTOR_SUMMARY.md deleted file mode 100644 index 8e28a60d..00000000 --- a/METADATA_PROVIDER_REFACTOR_SUMMARY.md +++ /dev/null @@ -1,119 +0,0 @@ -# Metadata Provider Refactor Summary - -## Overview -This refactor improves the metadata provider initialization logic and replaces direct Civitai client usage with the unified FallbackMetadataProvider system throughout the codebase. - -## Key Changes - -### 1. Enhanced Metadata Service (`py/services/metadata_service.py`) - -#### Improved `initialize_metadata_providers()`: -- Added provider clearing for proper reinitialization -- Enhanced error handling and validation -- Better logging for debugging -- Improved provider ordering logic based on priority settings -- More robust database path validation - -#### Enhanced `update_metadata_provider_priority()`: -- More robust error handling -- Proper reinitalization of all providers -- Better logging for setting changes - -#### New helper function: -- Added `get_default_metadata_provider()` for easier access to the default provider - -### 2. Updated Recipe Parsers -All recipe parsers now use the unified metadata provider instead of direct civitai_client: - -#### Files Updated: -- `py/recipes/parsers/civitai_image.py` -- `py/recipes/parsers/comfy.py` -- `py/recipes/parsers/automatic.py` -- `py/recipes/parsers/recipe_format.py` -- `py/recipes/parsers/meta_format.py` - -#### Changes Made: -- Added import for `get_default_metadata_provider` -- Replaced `civitai_client.get_model_by_hash()` with `metadata_provider.get_model_by_hash()` -- Replaced `civitai_client.get_model_version_info()` with `metadata_provider.get_model_version_info()` -- Updated method signatures to indicate civitai_client parameter is deprecated - -### 3. Download Manager Updates (`py/services/download_manager.py`) - -#### Metadata Operations: -- Replaced direct civitai_client usage with metadata_provider for: - - `get_model_version()` calls for version info - -#### Download Operations: -- Replaced `civitai_client.download_file()` with direct `downloader.download_file()` calls -- Replaced `civitai_client.download_preview_image()` with `downloader.download_to_memory()` for images -- Added proper authentication flags (`use_auth=True` for model files, `use_auth=False` for preview images) - -### 4. Recipe Scanner Updates (`py/services/recipe_scanner.py`) -- Added import for `get_default_metadata_provider` -- Replaced `civitai_client.get_model_version_info()` with `metadata_provider.get_model_version_info()` - -### 5. Utility Functions Updates (`py/utils/routes_common.py`) -- Added import for `get_downloader` -- Replaced preview image downloads with direct downloader usage -- Improved image optimization logic to work with in-memory downloads -- Better error handling for download and image processing operations - -## Benefits - -### 1. Unified Metadata Access -- All metadata requests now go through the fallback provider system -- Automatic failover between SQLite archive database and Civitai API -- Consistent metadata access patterns across all components - -### 2. Improved Download Performance -- Direct use of the optimized downloader service -- Better connection pooling and retry logic -- Proper authentication handling -- Support for resumable downloads - -### 3. Better Configuration Management -- Settings changes now properly update provider priority -- Clear separation between metadata and download operations -- Improved error handling and logging - -### 4. Enhanced Reliability -- Fallback mechanisms ensure metadata is always available when possible -- Better error handling and recovery -- Consistent behavior across all parsers and services - -## Usage - -### Settings Changes -When users change metadata provider settings: -1. The `update_metadata_provider_priority()` function is automatically called -2. All providers are reinitialized with the new settings -3. The fallback provider is updated with the correct priority order - -### Metadata Access -All components now use: -```python -from ...services.metadata_service import get_default_metadata_provider - -metadata_provider = await get_default_metadata_provider() -result = await metadata_provider.get_model_by_hash(hash_value) -``` - -### Downloads -All downloads now use the unified downloader: -```python -from ...services.downloader import get_downloader - -downloader = await get_downloader() -success, result = await downloader.download_file(url, path, use_auth=True) -``` - -## Compatibility -- All existing APIs and interfaces remain unchanged -- Backward compatibility maintained for existing workflows -- No changes required for external integrations - -## Testing -- All updated files pass syntax validation -- Existing functionality preserved -- Enhanced error handling and logging for better debugging From 62f06302f0fb91cc9c96241dbd28ce8c1572e0de Mon Sep 17 00:00:00 2001 From: Will Miao <13051207myq@gmail.com> Date: Wed, 10 Sep 2025 20:29:26 +0800 Subject: [PATCH 12/13] refactor(routes): replace ModelMetadataProviderManager with get_default_metadata_provider in checkpoint, embedding, and lora routes --- py/routes/checkpoint_routes.py | 7 +++---- py/routes/embedding_routes.py | 7 +++---- py/routes/lora_routes.py | 16 ++++++++-------- 3 files changed, 14 insertions(+), 16 deletions(-) diff --git a/py/routes/checkpoint_routes.py b/py/routes/checkpoint_routes.py index b93700cf..a0f6a027 100644 --- a/py/routes/checkpoint_routes.py +++ b/py/routes/checkpoint_routes.py @@ -4,7 +4,7 @@ from aiohttp import web from .base_model_routes import BaseModelRoutes from ..services.checkpoint_service import CheckpointService from ..services.service_registry import ServiceRegistry -from ..services.model_metadata_provider import ModelMetadataProviderManager +from ..services.metadata_service import get_default_metadata_provider from ..config import config logger = logging.getLogger(__name__) @@ -16,14 +16,12 @@ class CheckpointRoutes(BaseModelRoutes): """Initialize Checkpoint routes with Checkpoint service""" # Service will be initialized later via setup_routes self.service = None - self.metadata_provider = None self.template_name = "checkpoints.html" async def initialize_services(self): """Initialize services from ServiceRegistry""" checkpoint_scanner = await ServiceRegistry.get_checkpoint_scanner() self.service = CheckpointService(checkpoint_scanner) - self.metadata_provider = await ModelMetadataProviderManager.get_instance() # Initialize parent with the service super().__init__(self.service) @@ -67,7 +65,8 @@ class CheckpointRoutes(BaseModelRoutes): """Get available versions for a Civitai checkpoint model with local availability info""" try: model_id = request.match_info['model_id'] - response = await self.metadata_provider.get_model_versions(model_id) + metadata_provider = await get_default_metadata_provider() + response = await metadata_provider.get_model_versions(model_id) if not response or not response.get('modelVersions'): return web.Response(status=404, text="Model not found") diff --git a/py/routes/embedding_routes.py b/py/routes/embedding_routes.py index 65a66824..ab028666 100644 --- a/py/routes/embedding_routes.py +++ b/py/routes/embedding_routes.py @@ -4,7 +4,7 @@ from aiohttp import web from .base_model_routes import BaseModelRoutes from ..services.embedding_service import EmbeddingService from ..services.service_registry import ServiceRegistry -from ..services.model_metadata_provider import ModelMetadataProviderManager +from ..services.metadata_service import get_default_metadata_provider logger = logging.getLogger(__name__) @@ -15,14 +15,12 @@ class EmbeddingRoutes(BaseModelRoutes): """Initialize Embedding routes with Embedding service""" # Service will be initialized later via setup_routes self.service = None - self.metadata_provider = None self.template_name = "embeddings.html" async def initialize_services(self): """Initialize services from ServiceRegistry""" embedding_scanner = await ServiceRegistry.get_embedding_scanner() self.service = EmbeddingService(embedding_scanner) - self.metadata_provider = await ModelMetadataProviderManager.get_instance() # Initialize parent with the service super().__init__(self.service) @@ -62,7 +60,8 @@ class EmbeddingRoutes(BaseModelRoutes): """Get available versions for a Civitai embedding model with local availability info""" try: model_id = request.match_info['model_id'] - response = await self.metadata_provider.get_model_versions(model_id) + metadata_provider = await get_default_metadata_provider() + response = await metadata_provider.get_model_versions(model_id) if not response or not response.get('modelVersions'): return web.Response(status=404, text="Model not found") diff --git a/py/routes/lora_routes.py b/py/routes/lora_routes.py index 4c1c0467..4e261004 100644 --- a/py/routes/lora_routes.py +++ b/py/routes/lora_routes.py @@ -7,8 +7,7 @@ from server import PromptServer # type: ignore from .base_model_routes import BaseModelRoutes from ..services.lora_service import LoraService from ..services.service_registry import ServiceRegistry -from ..services.model_metadata_provider import ModelMetadataProviderManager -from ..utils.routes_common import ModelRouteUtils +from ..services.metadata_service import get_default_metadata_provider from ..utils.utils import get_lora_info logger = logging.getLogger(__name__) @@ -20,14 +19,12 @@ class LoraRoutes(BaseModelRoutes): """Initialize LoRA routes with LoRA service""" # Service will be initialized later via setup_routes self.service = None - self.metadata_provider = None self.template_name = "loras.html" async def initialize_services(self): """Initialize services from ServiceRegistry""" lora_scanner = await ServiceRegistry.get_lora_scanner() self.service = LoraService(lora_scanner) - self.metadata_provider = await ModelMetadataProviderManager.get_instance() # Initialize parent with the service super().__init__(self.service) @@ -218,7 +215,8 @@ class LoraRoutes(BaseModelRoutes): """Get available versions for a Civitai LoRA model with local availability info""" try: model_id = request.match_info['model_id'] - response = await self.metadata_provider.get_model_versions(model_id) + metadata_provider = await get_default_metadata_provider() + response = await metadata_provider.get_model_versions(model_id) if not response or not response.get('modelVersions'): return web.Response(status=404, text="Model not found") @@ -263,8 +261,9 @@ class LoraRoutes(BaseModelRoutes): model_version_id = request.match_info.get('modelVersionId') # Get model details from metadata provider - model, error_msg = await self.metadata_provider.get_model_version_info(model_version_id) - + metadata_provider = await get_default_metadata_provider() + model, error_msg = await metadata_provider.get_model_version_info(model_version_id) + if not model: # Log warning for failed model retrieval logger.warning(f"Failed to fetch model version {model_version_id}: {error_msg}") @@ -289,7 +288,8 @@ class LoraRoutes(BaseModelRoutes): """Get CivitAI model details by hash""" try: hash = request.match_info.get('hash') - model = await self.metadata_provider.get_model_by_hash(hash) + metadata_provider = await get_default_metadata_provider() + model = await metadata_provider.get_model_by_hash(hash) return web.json_response(model) except Exception as e: logger.error(f"Error fetching model details by hash: {e}") From 1fc8b45b685fd324cae4b82b4b128900ef83a33c Mon Sep 17 00:00:00 2001 From: Will Miao <13051207myq@gmail.com> Date: Wed, 10 Sep 2025 20:33:45 +0800 Subject: [PATCH 13/13] feat(dependencies): add GitPython and aiosqlite to project dependencies --- pyproject.toml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index fc66e57d..202fb961 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,7 +12,8 @@ dependencies = [ "olefile", # for getting rid of warning message "toml", "natsort", - "GitPython" + "GitPython", + "aiosqlite" ] [project.urls]