mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-21 21:22:11 -03:00
396 lines
16 KiB
Python
396 lines
16 KiB
Python
from abc import ABC, abstractmethod
|
|
import json
|
|
import aiosqlite
|
|
import logging
|
|
from typing import Optional, Dict, List, Tuple, Any
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
class ModelMetadataProvider(ABC):
|
|
"""Base abstract class for all model metadata providers"""
|
|
|
|
@abstractmethod
|
|
async def get_model_by_hash(self, model_hash: str) -> Optional[Dict]:
|
|
"""Find model by hash value"""
|
|
pass
|
|
|
|
@abstractmethod
|
|
async def get_model_versions(self, model_id: str) -> Optional[Dict]:
|
|
"""Get all versions of a model with their details"""
|
|
pass
|
|
|
|
@abstractmethod
|
|
async def get_model_version(self, model_id: int = None, version_id: int = None) -> Optional[Dict]:
|
|
"""Get specific model version with additional metadata"""
|
|
pass
|
|
|
|
@abstractmethod
|
|
async def get_model_version_info(self, version_id: str) -> Tuple[Optional[Dict], Optional[str]]:
|
|
"""Fetch model version metadata"""
|
|
pass
|
|
|
|
@abstractmethod
|
|
async def get_model_metadata(self, model_id: str) -> Tuple[Optional[Dict], int]:
|
|
"""Fetch model metadata (description, tags, and creator info)"""
|
|
pass
|
|
|
|
class CivitaiModelMetadataProvider(ModelMetadataProvider):
|
|
"""Provider that uses Civitai API for metadata"""
|
|
|
|
def __init__(self, civitai_client):
|
|
self.client = civitai_client
|
|
|
|
async def get_model_by_hash(self, model_hash: str) -> Optional[Dict]:
|
|
return await self.client.get_model_by_hash(model_hash)
|
|
|
|
async def get_model_versions(self, model_id: str) -> Optional[Dict]:
|
|
return await self.client.get_model_versions(model_id)
|
|
|
|
async def get_model_version(self, model_id: int = None, version_id: int = None) -> Optional[Dict]:
|
|
return await self.client.get_model_version(model_id, version_id)
|
|
|
|
async def get_model_version_info(self, version_id: str) -> Tuple[Optional[Dict], Optional[str]]:
|
|
return await self.client.get_model_version_info(version_id)
|
|
|
|
async def get_model_metadata(self, model_id: str) -> Tuple[Optional[Dict], int]:
|
|
return await self.client.get_model_metadata(model_id)
|
|
|
|
class SQLiteModelMetadataProvider(ModelMetadataProvider):
|
|
"""Provider that uses SQLite database for metadata"""
|
|
|
|
def __init__(self, db_path: str):
|
|
self.db_path = db_path
|
|
|
|
async def get_model_by_hash(self, model_hash: str) -> Optional[Dict]:
|
|
"""Find model by hash value from SQLite database"""
|
|
async with aiosqlite.connect(self.db_path) as db:
|
|
# Look up in model_files table to get model_id and version_id
|
|
query = """
|
|
SELECT model_id, version_id
|
|
FROM model_files
|
|
WHERE sha256 = ?
|
|
LIMIT 1
|
|
"""
|
|
db.row_factory = aiosqlite.Row
|
|
cursor = await db.execute(query, (model_hash.upper(),))
|
|
file_row = await cursor.fetchone()
|
|
|
|
if not file_row:
|
|
return None
|
|
|
|
# Get version details
|
|
model_id = file_row['model_id']
|
|
version_id = file_row['version_id']
|
|
|
|
# Build response in the same format as Civitai API
|
|
return await self._get_version_with_model_data(db, model_id, version_id)
|
|
|
|
async def get_model_versions(self, model_id: str) -> Optional[Dict]:
|
|
"""Get all versions of a model from SQLite database"""
|
|
async with aiosqlite.connect(self.db_path) as db:
|
|
db.row_factory = aiosqlite.Row
|
|
|
|
# First check if model exists
|
|
model_query = "SELECT * FROM models WHERE id = ?"
|
|
cursor = await db.execute(model_query, (model_id,))
|
|
model_row = await cursor.fetchone()
|
|
|
|
if not model_row:
|
|
return None
|
|
|
|
model_data = json.loads(model_row['data'])
|
|
model_type = model_row['type']
|
|
|
|
# Get all versions for this model
|
|
versions_query = """
|
|
SELECT id, name, base_model, data, position, published_at
|
|
FROM model_versions
|
|
WHERE model_id = ?
|
|
ORDER BY position ASC
|
|
"""
|
|
cursor = await db.execute(versions_query, (model_id,))
|
|
version_rows = await cursor.fetchall()
|
|
|
|
if not version_rows:
|
|
return {'modelVersions': [], 'type': model_type}
|
|
|
|
# Format versions similar to Civitai API
|
|
model_versions = []
|
|
for row in version_rows:
|
|
version_data = json.loads(row['data'])
|
|
# Add fields from the row to ensure we have the basic fields
|
|
version_entry = {
|
|
'id': row['id'],
|
|
'modelId': int(model_id),
|
|
'name': row['name'],
|
|
'baseModel': row['base_model'],
|
|
'model': {
|
|
'name': model_row['name'],
|
|
'type': model_type,
|
|
},
|
|
'source': 'archive_db'
|
|
}
|
|
# Update with any additional data
|
|
version_entry.update(version_data)
|
|
model_versions.append(version_entry)
|
|
|
|
return {
|
|
'modelVersions': model_versions,
|
|
'type': model_type
|
|
}
|
|
|
|
async def get_model_version(self, model_id: int = None, version_id: int = None) -> Optional[Dict]:
|
|
"""Get specific model version with additional metadata from SQLite database"""
|
|
if not model_id and not version_id:
|
|
return None
|
|
|
|
async with aiosqlite.connect(self.db_path) as db:
|
|
db.row_factory = aiosqlite.Row
|
|
|
|
# Case 1: Only version_id is provided
|
|
if model_id is None and version_id is not None:
|
|
# First get the version info to extract model_id
|
|
version_query = "SELECT model_id FROM model_versions WHERE id = ?"
|
|
cursor = await db.execute(version_query, (version_id,))
|
|
version_row = await cursor.fetchone()
|
|
|
|
if not version_row:
|
|
return None
|
|
|
|
model_id = version_row['model_id']
|
|
|
|
# Case 2: model_id is provided but version_id is not
|
|
elif model_id is not None and version_id is None:
|
|
# Find the latest version
|
|
version_query = """
|
|
SELECT id FROM model_versions
|
|
WHERE model_id = ?
|
|
ORDER BY position ASC
|
|
LIMIT 1
|
|
"""
|
|
cursor = await db.execute(version_query, (model_id,))
|
|
version_row = await cursor.fetchone()
|
|
|
|
if not version_row:
|
|
return None
|
|
|
|
version_id = version_row['id']
|
|
|
|
# Now we have both model_id and version_id, get the full data
|
|
return await self._get_version_with_model_data(db, model_id, version_id)
|
|
|
|
async def get_model_version_info(self, version_id: str) -> Tuple[Optional[Dict], Optional[str]]:
|
|
"""Fetch model version metadata from SQLite database"""
|
|
async with aiosqlite.connect(self.db_path) as db:
|
|
db.row_factory = aiosqlite.Row
|
|
|
|
# Get version details
|
|
version_query = "SELECT model_id FROM model_versions WHERE id = ?"
|
|
cursor = await db.execute(version_query, (version_id,))
|
|
version_row = await cursor.fetchone()
|
|
|
|
if not version_row:
|
|
return None, "Model version not found"
|
|
|
|
model_id = version_row['model_id']
|
|
|
|
# Build complete version data with model info
|
|
version_data = await self._get_version_with_model_data(db, model_id, version_id)
|
|
return version_data, None
|
|
|
|
async def get_model_metadata(self, model_id: str) -> Tuple[Optional[Dict], int]:
|
|
"""Fetch model metadata from SQLite database"""
|
|
async with aiosqlite.connect(self.db_path) as db:
|
|
db.row_factory = aiosqlite.Row
|
|
|
|
# Get model details
|
|
model_query = "SELECT name, type, data, username FROM models WHERE id = ?"
|
|
cursor = await db.execute(model_query, (model_id,))
|
|
model_row = await cursor.fetchone()
|
|
|
|
if not model_row:
|
|
return None, 404
|
|
|
|
# Parse data JSON
|
|
try:
|
|
model_data = json.loads(model_row['data'])
|
|
|
|
# Extract relevant metadata
|
|
metadata = {
|
|
"description": model_data.get("description", "No model description available"),
|
|
"tags": model_data.get("tags", []),
|
|
"creator": {
|
|
"username": model_row['username'] or model_data.get("creator", {}).get("username"),
|
|
"image": model_data.get("creator", {}).get("image")
|
|
}
|
|
}
|
|
|
|
return metadata, 200
|
|
except json.JSONDecodeError:
|
|
return None, 500
|
|
|
|
async def _get_version_with_model_data(self, db, model_id, version_id) -> Optional[Dict]:
|
|
"""Helper to build version data with model information"""
|
|
# Get version details
|
|
version_query = "SELECT name, base_model, data FROM model_versions WHERE id = ? AND model_id = ?"
|
|
cursor = await db.execute(version_query, (version_id, model_id))
|
|
version_row = await cursor.fetchone()
|
|
|
|
if not version_row:
|
|
return None
|
|
|
|
# Get model details
|
|
model_query = "SELECT name, type, data, username FROM models WHERE id = ?"
|
|
cursor = await db.execute(model_query, (model_id,))
|
|
model_row = await cursor.fetchone()
|
|
|
|
if not model_row:
|
|
return None
|
|
|
|
# Parse JSON data
|
|
try:
|
|
version_data = json.loads(version_row['data'])
|
|
model_data = json.loads(model_row['data'])
|
|
|
|
# Build response
|
|
result = {
|
|
"id": int(version_id),
|
|
"modelId": int(model_id),
|
|
"name": version_row['name'],
|
|
"baseModel": version_row['base_model'],
|
|
"model": {
|
|
"name": model_row['name'],
|
|
"description": model_data.get("description"),
|
|
"type": model_row['type'],
|
|
"tags": model_data.get("tags", [])
|
|
},
|
|
"creator": {
|
|
"username": model_row['username'] or model_data.get("creator", {}).get("username"),
|
|
"image": model_data.get("creator", {}).get("image")
|
|
},
|
|
"source": "archive_db"
|
|
}
|
|
|
|
# Add any additional fields from version data
|
|
result.update(version_data)
|
|
|
|
return result
|
|
except json.JSONDecodeError:
|
|
return None
|
|
|
|
class FallbackMetadataProvider(ModelMetadataProvider):
|
|
"""Try providers in order, return first successful result."""
|
|
def __init__(self, providers: list):
|
|
self.providers = providers
|
|
|
|
async def get_model_by_hash(self, model_hash: str) -> Optional[Dict]:
|
|
for provider in self.providers:
|
|
try:
|
|
result = await provider.get_model_by_hash(model_hash)
|
|
if result:
|
|
return result
|
|
except Exception:
|
|
continue
|
|
return None
|
|
|
|
async def get_model_versions(self, model_id: str) -> Optional[Dict]:
|
|
for provider in self.providers:
|
|
try:
|
|
result = await provider.get_model_versions(model_id)
|
|
if result:
|
|
return result
|
|
except Exception as e:
|
|
logger.debug(f"Provider failed for get_model_versions: {e}")
|
|
continue
|
|
return None
|
|
|
|
async def get_model_version(self, model_id: int = None, version_id: int = None) -> Optional[Dict]:
|
|
for provider in self.providers:
|
|
try:
|
|
result = await provider.get_model_version(model_id, version_id)
|
|
if result:
|
|
return result
|
|
except Exception as e:
|
|
logger.debug(f"Provider failed for get_model_version: {e}")
|
|
continue
|
|
return None
|
|
|
|
async def get_model_version_info(self, version_id: str) -> Tuple[Optional[Dict], Optional[str]]:
|
|
for provider in self.providers:
|
|
try:
|
|
result, error = await provider.get_model_version_info(version_id)
|
|
if result:
|
|
return result, error
|
|
except Exception as e:
|
|
logger.debug(f"Provider failed for get_model_version_info: {e}")
|
|
continue
|
|
return None, "No provider could retrieve the data"
|
|
|
|
async def get_model_metadata(self, model_id: str) -> Tuple[Optional[Dict], int]:
|
|
for provider in self.providers:
|
|
try:
|
|
result, status = await provider.get_model_metadata(model_id)
|
|
if result:
|
|
return result, status
|
|
except Exception as e:
|
|
logger.debug(f"Provider failed for get_model_metadata: {e}")
|
|
continue
|
|
return None, 404
|
|
|
|
class ModelMetadataProviderManager:
|
|
"""Manager for selecting and using model metadata providers"""
|
|
|
|
_instance = None
|
|
|
|
@classmethod
|
|
async def get_instance(cls):
|
|
"""Get singleton instance of ModelMetadataProviderManager"""
|
|
if cls._instance is None:
|
|
cls._instance = cls()
|
|
return cls._instance
|
|
|
|
def __init__(self):
|
|
self.providers = {}
|
|
self.default_provider = None
|
|
|
|
def register_provider(self, name: str, provider: ModelMetadataProvider, is_default: bool = False):
|
|
"""Register a metadata provider"""
|
|
self.providers[name] = provider
|
|
if is_default or self.default_provider is None:
|
|
self.default_provider = name
|
|
|
|
async def get_model_by_hash(self, model_hash: str, provider_name: str = None) -> Optional[Dict]:
|
|
"""Find model by hash using specified or default provider"""
|
|
provider = self._get_provider(provider_name)
|
|
return await provider.get_model_by_hash(model_hash)
|
|
|
|
async def get_model_versions(self, model_id: str, provider_name: str = None) -> Optional[Dict]:
|
|
"""Get model versions using specified or default provider"""
|
|
provider = self._get_provider(provider_name)
|
|
return await provider.get_model_versions(model_id)
|
|
|
|
async def get_model_version(self, model_id: int = None, version_id: int = None, provider_name: str = None) -> Optional[Dict]:
|
|
"""Get specific model version using specified or default provider"""
|
|
provider = self._get_provider(provider_name)
|
|
return await provider.get_model_version(model_id, version_id)
|
|
|
|
async def get_model_version_info(self, version_id: str, provider_name: str = None) -> Tuple[Optional[Dict], Optional[str]]:
|
|
"""Fetch model version info using specified or default provider"""
|
|
provider = self._get_provider(provider_name)
|
|
return await provider.get_model_version_info(version_id)
|
|
|
|
async def get_model_metadata(self, model_id: str, provider_name: str = None) -> Tuple[Optional[Dict], int]:
|
|
"""Fetch model metadata using specified or default provider"""
|
|
provider = self._get_provider(provider_name)
|
|
return await provider.get_model_metadata(model_id)
|
|
|
|
def _get_provider(self, provider_name: str = None) -> ModelMetadataProvider:
|
|
"""Get provider by name or default provider"""
|
|
if provider_name and provider_name in self.providers:
|
|
return self.providers[provider_name]
|
|
|
|
if self.default_provider is None:
|
|
raise ValueError("No default provider set and no valid provider specified")
|
|
|
|
return self.providers[self.default_provider]
|