feat(misc): add civitai user model lookup

This commit is contained in:
pixelpaws
2025-10-09 11:49:41 +08:00
parent d997eaa429
commit 8cf762ffd3
6 changed files with 372 additions and 7 deletions

View File

@@ -1,7 +1,7 @@
from abc import ABC, abstractmethod
import json
import logging
from typing import Optional, Dict, Tuple, Any
from typing import Optional, Dict, Tuple, Any, List
from .downloader import get_downloader
try:
@@ -61,6 +61,11 @@ class ModelMetadataProvider(ABC):
"""Fetch model version metadata"""
pass
@abstractmethod
async def get_user_models(self, username: str) -> Optional[List[Dict]]:
"""Fetch models owned by the specified user"""
pass
class CivitaiModelMetadataProvider(ModelMetadataProvider):
"""Provider that uses Civitai API for metadata"""
@@ -79,6 +84,9 @@ class CivitaiModelMetadataProvider(ModelMetadataProvider):
async def get_model_version_info(self, version_id: str) -> Tuple[Optional[Dict], Optional[str]]:
return await self.client.get_model_version_info(version_id)
async def get_user_models(self, username: str) -> Optional[List[Dict]]:
return await self.client.get_user_models(username)
class CivArchiveModelMetadataProvider(ModelMetadataProvider):
"""Provider that uses CivArchive HTML page parsing for metadata"""
@@ -197,6 +205,10 @@ class CivArchiveModelMetadataProvider(ModelMetadataProvider):
"""Not supported by CivArchive provider - requires both model_id and version_id"""
return None, "CivArchive provider requires both model_id and version_id"
async def get_user_models(self, username: str) -> Optional[List[Dict]]:
"""Not supported by CivArchive provider"""
return None
class SQLiteModelMetadataProvider(ModelMetadataProvider):
"""Provider that uses SQLite database for metadata"""
@@ -329,20 +341,24 @@ class SQLiteModelMetadataProvider(ModelMetadataProvider):
"""Fetch model version metadata from SQLite database"""
async with self._aiosqlite.connect(self.db_path) as db:
db.row_factory = self._aiosqlite.Row
# Get version details
version_query = "SELECT model_id FROM model_versions WHERE id = ?"
cursor = await db.execute(version_query, (version_id,))
version_row = await cursor.fetchone()
if not version_row:
return None, "Model version not found"
model_id = version_row['model_id']
# Build complete version data with model info
version_data = await self._get_version_with_model_data(db, model_id, version_id)
return version_data, None
async def get_user_models(self, username: str) -> Optional[List[Dict]]:
"""Listing models by username is not supported for archive database"""
return None
async def _get_version_with_model_data(self, db, model_id, version_id) -> Optional[Dict]:
"""Helper to build version data with model information"""
@@ -481,6 +497,17 @@ class FallbackMetadataProvider(ModelMetadataProvider):
continue
return None, "No provider could retrieve the data"
async def get_user_models(self, username: str) -> Optional[List[Dict]]:
for provider in self.providers:
try:
result = await provider.get_user_models(username)
if result is not None:
return result
except Exception as e:
logger.debug(f"Provider failed for get_user_models: {e}")
continue
return None
class ModelMetadataProviderManager:
"""Manager for selecting and using model metadata providers"""
@@ -522,6 +549,11 @@ class ModelMetadataProviderManager:
"""Fetch model version info using specified or default provider"""
provider = self._get_provider(provider_name)
return await provider.get_model_version_info(version_id)
async def get_user_models(self, username: str, provider_name: str = None) -> Optional[List[Dict]]:
"""Fetch models owned by the specified user"""
provider = self._get_provider(provider_name)
return await provider.get_user_models(username)
def _get_provider(self, provider_name: str = None) -> ModelMetadataProvider:
"""Get provider by name or default provider"""