diff --git a/py/lora_manager.py b/py/lora_manager.py index df082d1b..f6598bca 100644 --- a/py/lora_manager.py +++ b/py/lora_manager.py @@ -190,6 +190,9 @@ class LoraManager: # Register DownloadManager with ServiceRegistry await ServiceRegistry.get_download_manager() + + from .services.metadata_service import initialize_metadata_providers + await initialize_metadata_providers() # Initialize WebSocket manager await ServiceRegistry.get_websocket_manager() diff --git a/py/services/civitai_client.py b/py/services/civitai_client.py index ff870f6d..64a6f2f5 100644 --- a/py/services/civitai_client.py +++ b/py/services/civitai_client.py @@ -3,9 +3,8 @@ import aiohttp import os import logging import asyncio -from email.parser import Parser from typing import Optional, Dict, Tuple, List -from urllib.parse import unquote +from .model_metadata_provider import CivitaiModelMetadataProvider, ModelMetadataProviderManager logger = logging.getLogger(__name__) @@ -19,6 +18,11 @@ class CivitaiClient: async with cls._lock: if cls._instance is None: cls._instance = cls() + + # Register this client as a metadata provider + provider_manager = await ModelMetadataProviderManager.get_instance() + provider_manager.register_provider('civitai', CivitaiModelMetadataProvider(cls._instance), True) + return cls._instance def __init__(self): @@ -69,24 +73,6 @@ class CivitaiClient: return await self.session - def _parse_content_disposition(self, header: str) -> str: - """Parse filename from content-disposition header""" - if not header: - return None - - # Handle quoted filenames - if 'filename="' in header: - start = header.index('filename="') + 10 - end = header.index('"', start) - return unquote(header[start:end]) - - # Fallback to original parsing - disposition = Parser().parsestr(f'Content-Disposition: {header}') - filename = disposition.get_param('filename') - if filename: - return unquote(filename) - return None - def _get_request_headers(self) -> dict: """Get request headers with optional API key""" headers = { @@ -101,7 +87,7 @@ class CivitaiClient: return headers - async def _download_file(self, url: str, save_dir: str, default_filename: str, progress_callback=None) -> Tuple[bool, str]: + async def download_file(self, url: str, save_dir: str, default_filename: str, progress_callback=None) -> Tuple[bool, str]: """Download file with resumable downloads and retry mechanism Args: @@ -129,7 +115,6 @@ class CivitaiClient: logger.info(f"Resuming download from offset {resume_offset} bytes") total_size = 0 - filename = default_filename while retry_count <= max_retries: try: diff --git a/py/services/download_manager.py b/py/services/download_manager.py index 8f491fa8..08295985 100644 --- a/py/services/download_manager.py +++ b/py/services/download_manager.py @@ -487,7 +487,7 @@ class DownloadManager: await progress_callback(3) # 3% progress after preview download # Download model file with progress tracking - success, result = await civitai_client._download_file( + success, result = await civitai_client.download_file( download_url, save_dir, os.path.basename(save_path), diff --git a/py/services/metadata_service.py b/py/services/metadata_service.py new file mode 100644 index 00000000..beb8912d --- /dev/null +++ b/py/services/metadata_service.py @@ -0,0 +1,33 @@ +import os +import logging +from .model_metadata_provider import ModelMetadataProviderManager, SQLiteModelMetadataProvider + +logger = logging.getLogger(__name__) + +async def initialize_metadata_providers(): + """Initialize and configure all metadata providers""" + provider_manager = await ModelMetadataProviderManager.get_instance() + + # Use hardcoded SQLite DB path if not set in settings + db_path = os.path.join( + os.path.dirname(os.path.dirname(os.path.dirname(__file__))), + 'civitai', 'civitai.sqlite' + ) + if db_path and os.path.exists(db_path): + try: + sqlite_provider = SQLiteModelMetadataProvider(db_path) + provider_manager.register_provider('sqlite', sqlite_provider) + logger.info(f"SQLite metadata provider registered with database: {db_path}") + except Exception as e: + logger.error(f"Failed to initialize SQLite metadata provider: {e}") + + return provider_manager + +async def get_metadata_provider(provider_name: str = None): + """Get a specific metadata provider or default provider""" + provider_manager = await ModelMetadataProviderManager.get_instance() + + if provider_name: + return provider_manager._get_provider(provider_name) + + return provider_manager._get_provider() diff --git a/py/services/model_metadata_provider.py b/py/services/model_metadata_provider.py new file mode 100644 index 00000000..95dfc1da --- /dev/null +++ b/py/services/model_metadata_provider.py @@ -0,0 +1,389 @@ +from abc import ABC, abstractmethod +import json +import aiosqlite +import logging +from typing import Optional, Dict, List, Tuple, Any + +logger = logging.getLogger(__name__) + +class ModelMetadataProvider(ABC): + """Base abstract class for all model metadata providers""" + + @abstractmethod + async def get_model_by_hash(self, model_hash: str) -> Optional[Dict]: + """Find model by hash value""" + pass + + @abstractmethod + async def get_model_versions(self, model_id: str) -> Optional[Dict]: + """Get all versions of a model with their details""" + pass + + @abstractmethod + async def get_model_version(self, model_id: int = None, version_id: int = None) -> Optional[Dict]: + """Get specific model version with additional metadata""" + pass + + @abstractmethod + async def get_model_version_info(self, version_id: str) -> Tuple[Optional[Dict], Optional[str]]: + """Fetch model version metadata""" + pass + + @abstractmethod + async def get_model_metadata(self, model_id: str) -> Tuple[Optional[Dict], int]: + """Fetch model metadata (description, tags, and creator info)""" + pass + +class CivitaiModelMetadataProvider(ModelMetadataProvider): + """Provider that uses Civitai API for metadata""" + + def __init__(self, civitai_client): + self.client = civitai_client + + async def get_model_by_hash(self, model_hash: str) -> Optional[Dict]: + return await self.client.get_model_by_hash(model_hash) + + async def get_model_versions(self, model_id: str) -> Optional[Dict]: + return await self.client.get_model_versions(model_id) + + async def get_model_version(self, model_id: int = None, version_id: int = None) -> Optional[Dict]: + return await self.client.get_model_version(model_id, version_id) + + async def get_model_version_info(self, version_id: str) -> Tuple[Optional[Dict], Optional[str]]: + return await self.client.get_model_version_info(version_id) + + async def get_model_metadata(self, model_id: str) -> Tuple[Optional[Dict], int]: + return await self.client.get_model_metadata(model_id) + +class SQLiteModelMetadataProvider(ModelMetadataProvider): + """Provider that uses SQLite database for metadata""" + + def __init__(self, db_path: str): + self.db_path = db_path + + async def get_model_by_hash(self, model_hash: str) -> Optional[Dict]: + """Find model by hash value from SQLite database""" + async with aiosqlite.connect(self.db_path) as db: + # Look up in model_files table to get model_id and version_id + query = """ + SELECT model_id, version_id + FROM model_files + WHERE sha256 = ? + LIMIT 1 + """ + db.row_factory = aiosqlite.Row + cursor = await db.execute(query, (model_hash.upper(),)) + file_row = await cursor.fetchone() + + if not file_row: + return None + + # Get version details + model_id = file_row['model_id'] + version_id = file_row['version_id'] + + # Build response in the same format as Civitai API + return await self._get_version_with_model_data(db, model_id, version_id) + + async def get_model_versions(self, model_id: str) -> Optional[Dict]: + """Get all versions of a model from SQLite database""" + async with aiosqlite.connect(self.db_path) as db: + db.row_factory = aiosqlite.Row + + # First check if model exists + model_query = "SELECT * FROM models WHERE id = ?" + cursor = await db.execute(model_query, (model_id,)) + model_row = await cursor.fetchone() + + if not model_row: + return None + + model_data = json.loads(model_row['data']) + model_type = model_row['type'] + + # Get all versions for this model + versions_query = """ + SELECT id, name, base_model, data, position, published_at + FROM model_versions + WHERE model_id = ? + ORDER BY position ASC + """ + cursor = await db.execute(versions_query, (model_id,)) + version_rows = await cursor.fetchall() + + if not version_rows: + return {'modelVersions': [], 'type': model_type} + + # Format versions similar to Civitai API + model_versions = [] + for row in version_rows: + version_data = json.loads(row['data']) + # Add fields from the row to ensure we have the basic fields + version_entry = { + 'id': row['id'], + 'modelId': int(model_id), + 'name': row['name'], + 'baseModel': row['base_model'], + 'model': { + 'name': model_row['name'], + 'type': model_type, + } + } + # Update with any additional data + version_entry.update(version_data) + model_versions.append(version_entry) + + return { + 'modelVersions': model_versions, + 'type': model_type + } + + async def get_model_version(self, model_id: int = None, version_id: int = None) -> Optional[Dict]: + """Get specific model version with additional metadata from SQLite database""" + if not model_id and not version_id: + return None + + async with aiosqlite.connect(self.db_path) as db: + db.row_factory = aiosqlite.Row + + # Case 1: Only version_id is provided + if model_id is None and version_id is not None: + # First get the version info to extract model_id + version_query = "SELECT model_id FROM model_versions WHERE id = ?" + cursor = await db.execute(version_query, (version_id,)) + version_row = await cursor.fetchone() + + if not version_row: + return None + + model_id = version_row['model_id'] + + # Case 2: model_id is provided but version_id is not + elif model_id is not None and version_id is None: + # Find the latest version + version_query = """ + SELECT id FROM model_versions + WHERE model_id = ? + ORDER BY position ASC + LIMIT 1 + """ + cursor = await db.execute(version_query, (model_id,)) + version_row = await cursor.fetchone() + + if not version_row: + return None + + version_id = version_row['id'] + + # Now we have both model_id and version_id, get the full data + return await self._get_version_with_model_data(db, model_id, version_id) + + async def get_model_version_info(self, version_id: str) -> Tuple[Optional[Dict], Optional[str]]: + """Fetch model version metadata from SQLite database""" + async with aiosqlite.connect(self.db_path) as db: + db.row_factory = aiosqlite.Row + + # Get version details + version_query = "SELECT model_id FROM model_versions WHERE id = ?" + cursor = await db.execute(version_query, (version_id,)) + version_row = await cursor.fetchone() + + if not version_row: + return None, "Model version not found" + + model_id = version_row['model_id'] + + # Build complete version data with model info + version_data = await self._get_version_with_model_data(db, model_id, version_id) + return version_data, None + + async def get_model_metadata(self, model_id: str) -> Tuple[Optional[Dict], int]: + """Fetch model metadata from SQLite database""" + async with aiosqlite.connect(self.db_path) as db: + db.row_factory = aiosqlite.Row + + # Get model details + model_query = "SELECT name, type, data, username FROM models WHERE id = ?" + cursor = await db.execute(model_query, (model_id,)) + model_row = await cursor.fetchone() + + if not model_row: + return None, 404 + + # Parse data JSON + try: + model_data = json.loads(model_row['data']) + + # Extract relevant metadata + metadata = { + "description": model_data.get("description", "No model description available"), + "tags": model_data.get("tags", []), + "creator": { + "username": model_row['username'] or model_data.get("creator", {}).get("username"), + "image": model_data.get("creator", {}).get("image") + } + } + + return metadata, 200 + except json.JSONDecodeError: + return None, 500 + + async def _get_version_with_model_data(self, db, model_id, version_id) -> Optional[Dict]: + """Helper to build version data with model information""" + # Get version details + version_query = "SELECT name, base_model, data FROM model_versions WHERE id = ? AND model_id = ?" + cursor = await db.execute(version_query, (version_id, model_id)) + version_row = await cursor.fetchone() + + if not version_row: + return None + + # Get model details + model_query = "SELECT name, type, data, username FROM models WHERE id = ?" + cursor = await db.execute(model_query, (model_id,)) + model_row = await cursor.fetchone() + + if not model_row: + return None + + # Parse JSON data + try: + version_data = json.loads(version_row['data']) + model_data = json.loads(model_row['data']) + + # Build response + result = { + "id": int(version_id), + "modelId": int(model_id), + "name": version_row['name'], + "baseModel": version_row['base_model'], + "model": { + "name": model_row['name'], + "description": model_data.get("description"), + "type": model_row['type'], + "tags": model_data.get("tags", []) + }, + "creator": { + "username": model_row['username'] or model_data.get("creator", {}).get("username"), + "image": model_data.get("creator", {}).get("image") + } + } + + # Add any additional fields from version data + result.update(version_data) + + return result + except json.JSONDecodeError: + return None + +class FallbackMetadataProvider(ModelMetadataProvider): + """Try providers in order, return first successful result.""" + def __init__(self, providers: list): + self.providers = providers + + async def get_model_by_hash(self, model_hash: str) -> Optional[Dict]: + for provider in self.providers: + try: + result = await provider.get_model_by_hash(model_hash) + if result: + return result + except Exception: + continue + return None + + async def get_model_versions(self, model_id: str) -> Optional[Dict]: + for provider in self.providers: + try: + result = await provider.get_model_versions(model_id) + if result: + return result + except Exception: + continue + return None + + async def get_model_version(self, model_id: int = None, version_id: int = None) -> Optional[Dict]: + for provider in self.providers: + try: + result = await provider.get_model_version(model_id, version_id) + if result: + return result + except Exception: + continue + return None + + async def get_model_version_info(self, version_id: str) -> Tuple[Optional[Dict], Optional[str]]: + for provider in self.providers: + try: + result, err = await provider.get_model_version_info(version_id) + if result: + return result, err + except Exception: + continue + return None, "Not found in any provider" + + async def get_model_metadata(self, model_id: str) -> Tuple[Optional[Dict], int]: + for provider in self.providers: + try: + result, code = await provider.get_model_metadata(model_id) + if result: + return result, code + except Exception: + continue + return None, 404 + +class ModelMetadataProviderManager: + """Manager for selecting and using model metadata providers""" + + _instance = None + + @classmethod + async def get_instance(cls): + """Get singleton instance of ModelMetadataProviderManager""" + if cls._instance is None: + cls._instance = cls() + return cls._instance + + def __init__(self): + self.providers = {} + self.default_provider = None + + def register_provider(self, name: str, provider: ModelMetadataProvider, is_default: bool = False): + """Register a metadata provider""" + self.providers[name] = provider + if is_default or self.default_provider is None: + self.default_provider = name + + async def get_model_by_hash(self, model_hash: str, provider_name: str = None) -> Optional[Dict]: + """Find model by hash using specified or default provider""" + provider = self._get_provider(provider_name) + return await provider.get_model_by_hash(model_hash) + + async def get_model_versions(self, model_id: str, provider_name: str = None) -> Optional[Dict]: + """Get model versions using specified or default provider""" + provider = self._get_provider(provider_name) + return await provider.get_model_versions(model_id) + + async def get_model_version(self, model_id: int = None, version_id: int = None, provider_name: str = None) -> Optional[Dict]: + """Get specific model version using specified or default provider""" + provider = self._get_provider(provider_name) + return await provider.get_model_version(model_id, version_id) + + async def get_model_version_info(self, version_id: str, provider_name: str = None) -> Tuple[Optional[Dict], Optional[str]]: + """Fetch model version info using specified or default provider""" + provider = self._get_provider(provider_name) + return await provider.get_model_version_info(version_id) + + async def get_model_metadata(self, model_id: str, provider_name: str = None) -> Tuple[Optional[Dict], int]: + """Fetch model metadata using specified or default provider""" + provider = self._get_provider(provider_name) + return await provider.get_model_metadata(model_id) + + def _get_provider(self, provider_name: str = None) -> ModelMetadataProvider: + """Get provider by name or default provider""" + if provider_name and provider_name in self.providers: + return self.providers[provider_name] + + if self.default_provider is None: + raise ValueError("No default provider set and no valid provider specified") + + return self.providers[self.default_provider] diff --git a/py/utils/routes_common.py b/py/utils/routes_common.py index 8f9cfe7f..778725a1 100644 --- a/py/utils/routes_common.py +++ b/py/utils/routes_common.py @@ -13,6 +13,7 @@ from ..utils.exif_utils import ExifUtils from ..utils.metadata_manager import MetadataManager from ..services.download_manager import DownloadManager from ..services.websocket_manager import ws_manager +from ..services.metadata_service import get_metadata_provider logger = logging.getLogger(__name__) diff --git a/refs/civitai.sql b/refs/civitai.sql new file mode 100644 index 00000000..99814e19 --- /dev/null +++ b/refs/civitai.sql @@ -0,0 +1,38 @@ +CREATE TABLE models ( + id INTEGER PRIMARY KEY, + name TEXT NOT NULL, + type TEXT NOT NULL, + username TEXT, + data TEXT NOT NULL, + created_at INTEGER NOT NULL, + updated_at INTEGER NOT NULL + ) STRICT; +CREATE TABLE model_versions ( + id INTEGER PRIMARY KEY, + model_id INTEGER NOT NULL, + position INTEGER NOT NULL, + name TEXT NOT NULL, + base_model TEXT NOT NULL, + published_at INTEGER, + data TEXT NOT NULL, + created_at INTEGER NOT NULL, + updated_at INTEGER NOT NULL + ) STRICT; +CREATE INDEX model_versions_model_id_idx ON model_versions (model_id); +CREATE TABLE model_files ( + id INTEGER PRIMARY KEY, + model_id INTEGER NOT NULL, + version_id INTEGER NOT NULL, + type TEXT NOT NULL, + sha256 TEXT, + data TEXT NOT NULL, + created_at INTEGER NOT NULL, + updated_at INTEGER NOT NULL + ) STRICT; +CREATE INDEX model_files_model_id_idx ON model_files (model_id); +CREATE INDEX model_files_version_id_idx ON model_files (version_id); +CREATE TABLE archived_model_files ( + file_id INTEGER PRIMARY KEY, + model_id INTEGER NOT NULL, + version_id INTEGER NOT NULL + ) STRICT; \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index bb65770d..9e280c64 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,3 +8,4 @@ toml numpy natsort GitPython +aiosqlite