feat(metadata): implement metadata providers and initialize metadata service

- Added ModelMetadataProvider and CivitaiModelMetadataProvider for handling model metadata.
- Introduced SQLiteModelMetadataProvider for SQLite database integration.
- Created metadata_service.py to initialize and configure metadata providers.
- Updated CivitaiClient to register as a metadata provider.
- Refactored download_manager to use the new download_file method.
- Added SQL schema for models, model_versions, and model_files.
- Updated requirements.txt to include aiosqlite.
This commit is contained in:
Will Miao
2025-09-08 10:33:59 +08:00
parent d287883671
commit 9ba3e2c204
8 changed files with 473 additions and 23 deletions

View File

@@ -190,6 +190,9 @@ class LoraManager:
# Register DownloadManager with ServiceRegistry # Register DownloadManager with ServiceRegistry
await ServiceRegistry.get_download_manager() await ServiceRegistry.get_download_manager()
from .services.metadata_service import initialize_metadata_providers
await initialize_metadata_providers()
# Initialize WebSocket manager # Initialize WebSocket manager
await ServiceRegistry.get_websocket_manager() await ServiceRegistry.get_websocket_manager()

View File

@@ -3,9 +3,8 @@ import aiohttp
import os import os
import logging import logging
import asyncio import asyncio
from email.parser import Parser
from typing import Optional, Dict, Tuple, List from typing import Optional, Dict, Tuple, List
from urllib.parse import unquote from .model_metadata_provider import CivitaiModelMetadataProvider, ModelMetadataProviderManager
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -19,6 +18,11 @@ class CivitaiClient:
async with cls._lock: async with cls._lock:
if cls._instance is None: if cls._instance is None:
cls._instance = cls() cls._instance = cls()
# Register this client as a metadata provider
provider_manager = await ModelMetadataProviderManager.get_instance()
provider_manager.register_provider('civitai', CivitaiModelMetadataProvider(cls._instance), True)
return cls._instance return cls._instance
def __init__(self): def __init__(self):
@@ -69,24 +73,6 @@ class CivitaiClient:
return await self.session return await self.session
def _parse_content_disposition(self, header: str) -> str:
"""Parse filename from content-disposition header"""
if not header:
return None
# Handle quoted filenames
if 'filename="' in header:
start = header.index('filename="') + 10
end = header.index('"', start)
return unquote(header[start:end])
# Fallback to original parsing
disposition = Parser().parsestr(f'Content-Disposition: {header}')
filename = disposition.get_param('filename')
if filename:
return unquote(filename)
return None
def _get_request_headers(self) -> dict: def _get_request_headers(self) -> dict:
"""Get request headers with optional API key""" """Get request headers with optional API key"""
headers = { headers = {
@@ -101,7 +87,7 @@ class CivitaiClient:
return headers return headers
async def _download_file(self, url: str, save_dir: str, default_filename: str, progress_callback=None) -> Tuple[bool, str]: async def download_file(self, url: str, save_dir: str, default_filename: str, progress_callback=None) -> Tuple[bool, str]:
"""Download file with resumable downloads and retry mechanism """Download file with resumable downloads and retry mechanism
Args: Args:
@@ -129,7 +115,6 @@ class CivitaiClient:
logger.info(f"Resuming download from offset {resume_offset} bytes") logger.info(f"Resuming download from offset {resume_offset} bytes")
total_size = 0 total_size = 0
filename = default_filename
while retry_count <= max_retries: while retry_count <= max_retries:
try: try:

View File

@@ -487,7 +487,7 @@ class DownloadManager:
await progress_callback(3) # 3% progress after preview download await progress_callback(3) # 3% progress after preview download
# Download model file with progress tracking # Download model file with progress tracking
success, result = await civitai_client._download_file( success, result = await civitai_client.download_file(
download_url, download_url,
save_dir, save_dir,
os.path.basename(save_path), os.path.basename(save_path),

View File

@@ -0,0 +1,33 @@
import os
import logging
from .model_metadata_provider import ModelMetadataProviderManager, SQLiteModelMetadataProvider
logger = logging.getLogger(__name__)
async def initialize_metadata_providers(db_path: str = None):
    """Initialize and configure all metadata providers.

    Registers a SQLiteModelMetadataProvider under the name 'sqlite' when the
    database file exists. A missing file is not an error: the manager is
    returned with whatever providers are already registered.

    Args:
        db_path: Location of the SQLite database. When omitted, defaults to
            ``civitai/civitai.sqlite`` under the directory two levels above
            the directory containing this module.

    Returns:
        The shared ModelMetadataProviderManager instance.
    """
    provider_manager = await ModelMetadataProviderManager.get_instance()

    if db_path is None:
        # No settings lookup exists here; fall back to the fixed location.
        base_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
        db_path = os.path.join(base_dir, 'civitai', 'civitai.sqlite')

    if not db_path or not os.path.exists(db_path):
        logger.debug("SQLite metadata database not found, provider not registered: %s", db_path)
        return provider_manager

    try:
        sqlite_provider = SQLiteModelMetadataProvider(db_path)
        provider_manager.register_provider('sqlite', sqlite_provider)
        logger.info(f"SQLite metadata provider registered with database: {db_path}")
    except Exception as e:
        # Best effort: a broken optional provider must not abort start-up,
        # but keep the traceback so the failure can be diagnosed.
        logger.exception(f"Failed to initialize SQLite metadata provider: {e}")

    return provider_manager
async def get_metadata_provider(provider_name: str = None):
    """Return the requested metadata provider, or the default one.

    Args:
        provider_name: Registered provider name. When empty or None the
            manager's default provider is returned; the manager also falls
            back to its default when the name is not registered.

    Raises:
        ValueError: Propagated from the manager when no default provider
            is set and the name does not match a registered provider.
    """
    manager = await ModelMetadataProviderManager.get_instance()
    if not provider_name:
        return manager._get_provider()
    return manager._get_provider(provider_name)

View File

@@ -0,0 +1,389 @@
from abc import ABC, abstractmethod
import json
import aiosqlite
import logging
from typing import Optional, Dict, List, Tuple, Any
logger = logging.getLogger(__name__)
class ModelMetadataProvider(ABC):
    """Abstract interface every model metadata provider must implement.

    All operations are coroutines. Concrete providers decide where the
    data comes from; this class only fixes the method names and shapes.
    """

    @abstractmethod
    async def get_model_by_hash(self, model_hash: str) -> Optional[Dict]:
        """Look a model up by hash; return its data or None."""

    @abstractmethod
    async def get_model_versions(self, model_id: str) -> Optional[Dict]:
        """Return every version of the given model together with its details."""

    @abstractmethod
    async def get_model_version(self, model_id: int = None, version_id: int = None) -> Optional[Dict]:
        """Return one model version, identified by model id and/or version id."""

    @abstractmethod
    async def get_model_version_info(self, version_id: str) -> Tuple[Optional[Dict], Optional[str]]:
        """Return ``(version_data, error_message)`` for a version id."""

    @abstractmethod
    async def get_model_metadata(self, model_id: str) -> Tuple[Optional[Dict], int]:
        """Return ``(metadata, status_code)``: description, tags and creator info."""
class CivitaiModelMetadataProvider(ModelMetadataProvider):
    """Metadata provider backed by the Civitai API.

    Every method forwards to the same-named coroutine on the wrapped client
    and returns its result unchanged.
    """

    def __init__(self, civitai_client):
        # Client object exposing the five lookup coroutines used below.
        self.client = civitai_client

    async def get_model_by_hash(self, model_hash: str) -> Optional[Dict]:
        """Delegate the hash lookup to the client."""
        found = await self.client.get_model_by_hash(model_hash)
        return found

    async def get_model_versions(self, model_id: str) -> Optional[Dict]:
        """Delegate the version listing to the client."""
        versions = await self.client.get_model_versions(model_id)
        return versions

    async def get_model_version(self, model_id: int = None, version_id: int = None) -> Optional[Dict]:
        """Delegate the single-version lookup to the client."""
        version = await self.client.get_model_version(model_id, version_id)
        return version

    async def get_model_version_info(self, version_id: str) -> Tuple[Optional[Dict], Optional[str]]:
        """Delegate the version-info lookup to the client."""
        info = await self.client.get_model_version_info(version_id)
        return info

    async def get_model_metadata(self, model_id: str) -> Tuple[Optional[Dict], int]:
        """Delegate the model-metadata lookup to the client."""
        metadata = await self.client.get_model_metadata(model_id)
        return metadata
class SQLiteModelMetadataProvider(ModelMetadataProvider):
    """Provider that uses SQLite database for metadata.

    Reads the ``models``, ``model_versions`` and ``model_files`` tables. The
    ``data`` column of ``models`` and ``model_versions`` is parsed as JSON
    and merged into the returned dictionaries. Each public method opens its
    own connection for the duration of the call.
    """

    def __init__(self, db_path: str):
        self.db_path = db_path

    @staticmethod
    def _creator_info(model_row, model_data: Dict) -> Dict:
        """Build the creator dict, tolerating a missing or null ``creator``."""
        # `or {}` also covers an explicit JSON null, which a .get() default
        # would not.
        creator = model_data.get("creator") or {}
        return {
            "username": model_row['username'] or creator.get("username"),
            "image": creator.get("image"),
        }

    async def get_model_by_hash(self, model_hash: str) -> Optional[Dict]:
        """Find model by hash value from SQLite database.

        The hash is upper-cased before it is compared with
        ``model_files.sha256``. Returns None for an empty hash or when no
        file row matches.
        """
        if not model_hash:
            return None

        async with aiosqlite.connect(self.db_path) as db:
            db.row_factory = aiosqlite.Row

            # Look up in model_files table to get model_id and version_id
            query = """
                SELECT model_id, version_id
                FROM model_files
                WHERE sha256 = ?
                LIMIT 1
            """
            cursor = await db.execute(query, (model_hash.upper(),))
            file_row = await cursor.fetchone()
            if not file_row:
                return None

            return await self._get_version_with_model_data(
                db, file_row['model_id'], file_row['version_id']
            )

    async def get_model_versions(self, model_id: str) -> Optional[Dict]:
        """Get all versions of a model from SQLite database.

        Returns:
            ``{'modelVersions': [...], 'type': <model type>}`` ordered by
            ``position``, or None when the model does not exist. A version
            row whose JSON cannot be parsed is logged and left out.
        """
        async with aiosqlite.connect(self.db_path) as db:
            db.row_factory = aiosqlite.Row

            # First check if model exists
            cursor = await db.execute(
                "SELECT name, type FROM models WHERE id = ?", (model_id,)
            )
            model_row = await cursor.fetchone()
            if not model_row:
                return None

            model_type = model_row['type']

            versions_query = """
                SELECT id, name, base_model, data
                FROM model_versions
                WHERE model_id = ?
                ORDER BY position ASC
            """
            cursor = await db.execute(versions_query, (model_id,))
            version_rows = await cursor.fetchall()

            model_versions = []
            for row in version_rows:
                try:
                    version_data = json.loads(row['data'])
                except json.JSONDecodeError:
                    logger.warning(
                        "Skipping version %s of model %s: invalid JSON in data column",
                        row['id'], model_id,
                    )
                    continue

                # Column values provide the basic fields; the JSON document
                # then overrides or extends them.
                version_entry = {
                    'id': row['id'],
                    'modelId': int(model_id),
                    'name': row['name'],
                    'baseModel': row['base_model'],
                    'model': {
                        'name': model_row['name'],
                        'type': model_type,
                    },
                }
                version_entry.update(version_data)
                model_versions.append(version_entry)

            return {
                'modelVersions': model_versions,
                'type': model_type,
            }

    async def get_model_version(self, model_id: int = None, version_id: int = None) -> Optional[Dict]:
        """Get specific model version with additional metadata from SQLite database.

        With only ``version_id`` the owning model is looked up; with only
        ``model_id`` the version with the lowest ``position`` is chosen.
        Returns None when neither id is given or nothing matches.
        """
        if not model_id and not version_id:
            return None

        async with aiosqlite.connect(self.db_path) as db:
            db.row_factory = aiosqlite.Row

            if model_id is None and version_id is not None:
                # Only version_id: resolve the model it belongs to.
                cursor = await db.execute(
                    "SELECT model_id FROM model_versions WHERE id = ?", (version_id,)
                )
                version_row = await cursor.fetchone()
                if not version_row:
                    return None
                model_id = version_row['model_id']

            elif model_id is not None and version_id is None:
                # Only model_id: take the first version by position.
                version_query = """
                    SELECT id FROM model_versions
                    WHERE model_id = ?
                    ORDER BY position ASC
                    LIMIT 1
                """
                cursor = await db.execute(version_query, (model_id,))
                version_row = await cursor.fetchone()
                if not version_row:
                    return None
                version_id = version_row['id']

            return await self._get_version_with_model_data(db, model_id, version_id)

    async def get_model_version_info(self, version_id: str) -> Tuple[Optional[Dict], Optional[str]]:
        """Fetch model version metadata from SQLite database.

        Returns:
            ``(version_data, None)`` on success, otherwise
            ``(None, error_message)``.
        """
        async with aiosqlite.connect(self.db_path) as db:
            db.row_factory = aiosqlite.Row

            cursor = await db.execute(
                "SELECT model_id FROM model_versions WHERE id = ?", (version_id,)
            )
            version_row = await cursor.fetchone()
            if not version_row:
                return None, "Model version not found"

            version_data = await self._get_version_with_model_data(
                db, version_row['model_id'], version_id
            )
            if version_data is None:
                # Never hand back (None, None): callers need an error text
                # whenever there is no data.
                return None, "Model version data unavailable"
            return version_data, None

    async def get_model_metadata(self, model_id: str) -> Tuple[Optional[Dict], int]:
        """Fetch model metadata from SQLite database.

        Returns:
            ``(metadata, 200)`` on success, ``(None, 404)`` when the model
            row is missing, ``(None, 500)`` when its JSON cannot be parsed.
        """
        async with aiosqlite.connect(self.db_path) as db:
            db.row_factory = aiosqlite.Row

            cursor = await db.execute(
                "SELECT name, type, data, username FROM models WHERE id = ?", (model_id,)
            )
            model_row = await cursor.fetchone()
            if not model_row:
                return None, 404

            try:
                model_data = json.loads(model_row['data'])
            except json.JSONDecodeError:
                logger.warning("Invalid JSON in data column for model %s", model_id)
                return None, 500

            metadata = {
                "description": model_data.get("description", "No model description available"),
                "tags": model_data.get("tags", []),
                "creator": self._creator_info(model_row, model_data),
            }
            return metadata, 200

    async def _get_version_with_model_data(self, db, model_id, version_id) -> Optional[Dict]:
        """Helper to build version data with model information.

        Args:
            db: Open connection whose rows support access by column name.
            model_id: Id of the owning model.
            version_id: Id of the version; it must belong to ``model_id``.

        Returns:
            The merged version dictionary, or None when either row is
            missing or its JSON cannot be parsed.
        """
        cursor = await db.execute(
            "SELECT name, base_model, data FROM model_versions WHERE id = ? AND model_id = ?",
            (version_id, model_id),
        )
        version_row = await cursor.fetchone()
        if not version_row:
            return None

        cursor = await db.execute(
            "SELECT name, type, data, username FROM models WHERE id = ?", (model_id,)
        )
        model_row = await cursor.fetchone()
        if not model_row:
            return None

        try:
            version_data = json.loads(version_row['data'])
            model_data = json.loads(model_row['data'])
        except json.JSONDecodeError:
            logger.warning(
                "Invalid JSON in data column for model %s / version %s",
                model_id, version_id,
            )
            return None

        result = {
            "id": int(version_id),
            "modelId": int(model_id),
            "name": version_row['name'],
            "baseModel": version_row['base_model'],
            "model": {
                "name": model_row['name'],
                "description": model_data.get("description"),
                "type": model_row['type'],
                "tags": model_data.get("tags", []),
            },
            "creator": self._creator_info(model_row, model_data),
        }
        # Add any additional fields from version data; its keys take
        # precedence over the ones built above.
        result.update(version_data)
        return result
class FallbackMetadataProvider(ModelMetadataProvider):
    """Try providers in order, return first successful result.

    A provider counts as unsuccessful when it raises or returns an empty
    payload. Failures are logged and the next provider is tried.
    """

    def __init__(self, providers: list):
        self.providers = providers

    async def _first_hit(self, method_name: str, *args, paired: bool = False):
        """Call ``method_name`` on each provider until one yields a payload.

        Args:
            method_name: Name of the provider coroutine to call.
            *args: Positional arguments forwarded to that coroutine.
            paired: True when the coroutine returns a ``(payload, extra)``
                pair; success is then judged on ``payload`` alone.

        Returns:
            The first successful return value (a 2-tuple when ``paired``),
            or None when every provider failed or came back empty.
        """
        for provider in self.providers:
            try:
                outcome = await getattr(provider, method_name)(*args)
                if paired:
                    # Unpack inside the try so a malformed return value is
                    # treated like any other provider failure.
                    payload, extra = outcome
                    outcome = (payload, extra)
                else:
                    payload = outcome
            except Exception as exc:
                logger.warning(
                    "%s.%s failed, trying next provider: %s",
                    type(provider).__name__, method_name, exc,
                )
                continue
            if payload:
                return outcome
        return None

    async def get_model_by_hash(self, model_hash: str) -> Optional[Dict]:
        """Return the first non-empty hash lookup result, else None."""
        return await self._first_hit('get_model_by_hash', model_hash)

    async def get_model_versions(self, model_id: str) -> Optional[Dict]:
        """Return the first non-empty version listing, else None."""
        return await self._first_hit('get_model_versions', model_id)

    async def get_model_version(self, model_id: int = None, version_id: int = None) -> Optional[Dict]:
        """Return the first non-empty single-version result, else None."""
        return await self._first_hit('get_model_version', model_id, version_id)

    async def get_model_version_info(self, version_id: str) -> Tuple[Optional[Dict], Optional[str]]:
        """Return the first ``(data, error)`` pair with data, else a not-found pair."""
        hit = await self._first_hit('get_model_version_info', version_id, paired=True)
        if hit is None:
            return None, "Not found in any provider"
        return hit

    async def get_model_metadata(self, model_id: str) -> Tuple[Optional[Dict], int]:
        """Return the first ``(metadata, status)`` pair with metadata, else ``(None, 404)``."""
        hit = await self._first_hit('get_model_metadata', model_id, paired=True)
        if hit is None:
            return None, 404
        return hit
class ModelMetadataProviderManager:
    """Registry of named metadata providers with one marked as default.

    The lookup methods forward to the provider selected by name, or to the
    default provider when no name (or an unregistered name) is given.
    """

    _instance = None

    @classmethod
    async def get_instance(cls):
        """Return the single shared manager, creating it on first use."""
        if cls._instance is not None:
            return cls._instance
        cls._instance = cls()
        return cls._instance

    def __init__(self):
        self.providers = {}
        self.default_provider = None

    def register_provider(self, name: str, provider: "ModelMetadataProvider", is_default: bool = False):
        """Store ``provider`` under ``name``.

        The name becomes the default when ``is_default`` is true or when no
        default has been chosen yet.
        """
        self.providers[name] = provider
        if self.default_provider is None or is_default:
            self.default_provider = name

    async def get_model_by_hash(self, model_hash: str, provider_name: str = None) -> Optional[Dict]:
        """Look a model up by hash through the selected provider."""
        return await self._get_provider(provider_name).get_model_by_hash(model_hash)

    async def get_model_versions(self, model_id: str, provider_name: str = None) -> Optional[Dict]:
        """List a model's versions through the selected provider."""
        return await self._get_provider(provider_name).get_model_versions(model_id)

    async def get_model_version(self, model_id: int = None, version_id: int = None, provider_name: str = None) -> Optional[Dict]:
        """Fetch one model version through the selected provider."""
        return await self._get_provider(provider_name).get_model_version(model_id, version_id)

    async def get_model_version_info(self, version_id: str, provider_name: str = None) -> Tuple[Optional[Dict], Optional[str]]:
        """Fetch version info through the selected provider."""
        return await self._get_provider(provider_name).get_model_version_info(version_id)

    async def get_model_metadata(self, model_id: str, provider_name: str = None) -> Tuple[Optional[Dict], int]:
        """Fetch model metadata through the selected provider."""
        return await self._get_provider(provider_name).get_model_metadata(model_id)

    def _get_provider(self, provider_name: str = None) -> "ModelMetadataProvider":
        """Resolve a provider by name, falling back to the default.

        Raises:
            ValueError: When the name is not registered and no default is set.
        """
        if provider_name and provider_name in self.providers:
            selected = provider_name
        else:
            selected = self.default_provider
        if selected is None:
            raise ValueError("No default provider set and no valid provider specified")
        return self.providers[selected]

View File

@@ -13,6 +13,7 @@ from ..utils.exif_utils import ExifUtils
from ..utils.metadata_manager import MetadataManager from ..utils.metadata_manager import MetadataManager
from ..services.download_manager import DownloadManager from ..services.download_manager import DownloadManager
from ..services.websocket_manager import ws_manager from ..services.websocket_manager import ws_manager
from ..services.metadata_service import get_metadata_provider
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

38
refs/civitai.sql Normal file
View File

@@ -0,0 +1,38 @@
-- One row per model.
-- `data` holds a JSON document; SQLiteModelMetadataProvider parses it with
-- json.loads and reads the keys "description", "tags" and "creator".
-- `username` is preferred over data.creator.username when building creator info.
-- NOTE(review): created_at / updated_at are INTEGER; presumably epoch
-- timestamps, unit not shown here. Confirm against the data source.
CREATE TABLE models (
id INTEGER PRIMARY KEY,
name TEXT NOT NULL,
type TEXT NOT NULL,
username TEXT,
data TEXT NOT NULL,
created_at INTEGER NOT NULL,
updated_at INTEGER NOT NULL
) STRICT;
-- One row per model version, linked to models.id through model_id
-- (no FOREIGN KEY constraint is declared).
-- `position` orders the versions of a model: the provider sorts by
-- position ASC and takes the first row when no version id is requested.
-- `data` holds a JSON document that the provider merges over the column
-- values (name, base_model) when building a version entry.
CREATE TABLE model_versions (
id INTEGER PRIMARY KEY,
model_id INTEGER NOT NULL,
position INTEGER NOT NULL,
name TEXT NOT NULL,
base_model TEXT NOT NULL,
published_at INTEGER,
data TEXT NOT NULL,
created_at INTEGER NOT NULL,
updated_at INTEGER NOT NULL
) STRICT;
-- Supports the per-model version lookups (WHERE model_id = ?).
CREATE INDEX model_versions_model_id_idx ON model_versions (model_id);
-- One row per file, carrying both the model id and the version id it belongs to.
-- `sha256` is the lookup key of get_model_by_hash, which upper-cases the
-- requested hash before comparing.
-- NOTE(review): presumably hashes are stored upper-case; confirm.
-- NOTE(review): no index on sha256 is declared in this schema although the
-- provider filters on it (WHERE sha256 = ?).
CREATE TABLE model_files (
id INTEGER PRIMARY KEY,
model_id INTEGER NOT NULL,
version_id INTEGER NOT NULL,
type TEXT NOT NULL,
sha256 TEXT,
data TEXT NOT NULL,
created_at INTEGER NOT NULL,
updated_at INTEGER NOT NULL
) STRICT;
CREATE INDEX model_files_model_id_idx ON model_files (model_id);
CREATE INDEX model_files_version_id_idx ON model_files (version_id);
-- Maps a file id to its model and version ids.
-- NOTE(review): not queried by the Python providers in this change;
-- presumably marks files that were archived upstream. Confirm the purpose.
CREATE TABLE archived_model_files (
file_id INTEGER PRIMARY KEY,
model_id INTEGER NOT NULL,
version_id INTEGER NOT NULL
) STRICT;

View File

@@ -8,3 +8,4 @@ toml
numpy numpy
natsort natsort
GitPython GitPython
aiosqlite