refactor(metadata): update model fetching methods to return error messages alongside results

This commit is contained in:
Will Miao
2025-09-19 16:36:34 +08:00
parent fc6f1bf95b
commit 1610048974
7 changed files with 66 additions and 175 deletions

View File

@@ -55,7 +55,7 @@ class RecipeMetadataParser(ABC):
# Unpack the tuple to get the actual data # Unpack the tuple to get the actual data
civitai_info, error_msg = civitai_info_tuple if isinstance(civitai_info_tuple, tuple) else (civitai_info_tuple, None) civitai_info, error_msg = civitai_info_tuple if isinstance(civitai_info_tuple, tuple) else (civitai_info_tuple, None)
if not civitai_info or civitai_info.get("error") == "Model not found": if not civitai_info or error_msg == "Model not found":
# Model not found or deleted # Model not found or deleted
lora_entry['isDeleted'] = True lora_entry['isDeleted'] = True
lora_entry['thumbnailUrl'] = '/loras_static/images/no-preview.png' lora_entry['thumbnailUrl'] = '/loras_static/images/no-preview.png'

View File

@@ -91,7 +91,7 @@ class CivitaiApiMetadataParser(RecipeMetadataParser):
result["base_model"] = metadata["baseModel"] result["base_model"] = metadata["baseModel"]
elif "Model hash" in metadata and metadata_provider: elif "Model hash" in metadata and metadata_provider:
model_hash = metadata["Model hash"] model_hash = metadata["Model hash"]
model_info = await metadata_provider.get_model_by_hash(model_hash) model_info, error = await metadata_provider.get_model_by_hash(model_hash)
if model_info: if model_info:
result["base_model"] = model_info.get("baseModel", "") result["base_model"] = model_info.get("baseModel", "")
elif "Model" in metadata and isinstance(metadata.get("resources"), list): elif "Model" in metadata and isinstance(metadata.get("resources"), list):
@@ -100,7 +100,7 @@ class CivitaiApiMetadataParser(RecipeMetadataParser):
if resource.get("type") == "model" and resource.get("name") == metadata.get("Model"): if resource.get("type") == "model" and resource.get("name") == metadata.get("Model"):
# This is likely the checkpoint model # This is likely the checkpoint model
if metadata_provider and resource.get("hash"): if metadata_provider and resource.get("hash"):
model_info = await metadata_provider.get_model_by_hash(resource.get("hash")) model_info, error = await metadata_provider.get_model_by_hash(resource.get("hash"))
if model_info: if model_info:
result["base_model"] = model_info.get("baseModel", "") result["base_model"] = model_info.get("baseModel", "")
@@ -201,11 +201,7 @@ class CivitaiApiMetadataParser(RecipeMetadataParser):
if version_id and metadata_provider: if version_id and metadata_provider:
try: try:
# Use get_model_version_info instead of get_model_version # Use get_model_version_info instead of get_model_version
civitai_info, error = await metadata_provider.get_model_version_info(version_id) civitai_info = await metadata_provider.get_model_version_info(version_id)
if error:
logger.warning(f"Error getting model version info: {error}")
continue
populated_entry = await self.populate_lora_from_civitai( populated_entry = await self.populate_lora_from_civitai(
lora_entry, lora_entry,
@@ -267,26 +263,23 @@ class CivitaiApiMetadataParser(RecipeMetadataParser):
if version_id and metadata_provider: if version_id and metadata_provider:
try: try:
# Use get_model_version_info with the version ID # Use get_model_version_info with the version ID
civitai_info, error = await metadata_provider.get_model_version_info(version_id) civitai_info = await metadata_provider.get_model_version_info(version_id)
if error: populated_entry = await self.populate_lora_from_civitai(
logger.warning(f"Error getting model version info: {error}") lora_entry,
else: civitai_info,
populated_entry = await self.populate_lora_from_civitai( recipe_scanner,
lora_entry, base_model_counts
civitai_info, )
recipe_scanner,
base_model_counts
)
if populated_entry is None: if populated_entry is None:
continue # Skip invalid LoRA types continue # Skip invalid LoRA types
lora_entry = populated_entry lora_entry = populated_entry
# Track this LoRA for deduplication # Track this LoRA for deduplication
if version_id: if version_id:
added_loras[version_id] = len(result["loras"]) added_loras[version_id] = len(result["loras"])
except Exception as e: except Exception as e:
logger.error(f"Error fetching Civitai info for model ID {version_id}: {e}") logger.error(f"Error fetching Civitai info for model ID {version_id}: {e}")

View File

@@ -624,7 +624,7 @@ class BaseModelRoutes(ABC):
success = 0 success = 0
needs_resort = False needs_resort = False
# Prepare models to process, only those without CivitAI data or missing tags, description, or creator # Prepare models to process, only those without CivitAI data
enable_metadata_archive_db = settings.get('enable_metadata_archive_db', False) enable_metadata_archive_db = settings.get('enable_metadata_archive_db', False)
to_process = [ to_process = [
model for model in cache.raw_data model for model in cache.raw_data
@@ -633,9 +633,6 @@ class BaseModelRoutes(ABC):
and ( and (
not model.get('civitai') not model.get('civitai')
or not model['civitai'].get('id') or not model['civitai'].get('id')
# or not model.get('tags') # Skipping tag cause it could be empty legitimately
# or not model.get('modelDescription')
# or not (model.get('civitai') and model['civitai'].get('creator'))
) )
and ( and (
(enable_metadata_archive_db) (enable_metadata_archive_db)
@@ -782,7 +779,13 @@ class BaseModelRoutes(ABC):
try: try:
hash = request.match_info.get('hash') hash = request.match_info.get('hash')
metadata_provider = await get_default_metadata_provider() metadata_provider = await get_default_metadata_provider()
model = await metadata_provider.get_model_by_hash(hash) model, error = await metadata_provider.get_model_by_hash(hash)
if error:
logger.warning(f"Error getting model by hash: {error}")
return web.json_response({
"success": False,
"error": error
}, status=404)
return web.json_response(model) return web.json_response(model)
except Exception as e: except Exception as e:
logger.error(f"Error fetching model details by hash: {e}") logger.error(f"Error fetching model details by hash: {e}")

View File

@@ -7,7 +7,7 @@ import shutil
import tempfile import tempfile
from aiohttp import web from aiohttp import web
from typing import Dict, List from typing import Dict, List
from ..services.downloader import get_downloader, Downloader from ..services.downloader import get_downloader
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -265,7 +265,7 @@ class UpdateRoutes:
github_url = f"https://api.github.com/repos/{repo_owner}/{repo_name}/commits/main" github_url = f"https://api.github.com/repos/{repo_owner}/{repo_name}/commits/main"
try: try:
downloader = await Downloader.get_instance() downloader = await get_downloader()
success, data = await downloader.make_request('GET', github_url, custom_headers={'Accept': 'application/vnd.github+json'}) success, data = await downloader.make_request('GET', github_url, custom_headers={'Accept': 'application/vnd.github+json'})
if not success: if not success:
@@ -431,7 +431,7 @@ class UpdateRoutes:
github_url = f"https://api.github.com/repos/{repo_owner}/{repo_name}/releases/latest" github_url = f"https://api.github.com/repos/{repo_owner}/{repo_name}/releases/latest"
try: try:
downloader = await Downloader.get_instance() downloader = await get_downloader()
success, data = await downloader.make_request('GET', github_url, custom_headers={'Accept': 'application/vnd.github+json'}) success, data = await downloader.make_request('GET', github_url, custom_headers={'Accept': 'application/vnd.github+json'})
if not success: if not success:

View File

@@ -1,4 +1,3 @@
from datetime import datetime
import os import os
import logging import logging
import asyncio import asyncio
@@ -59,17 +58,17 @@ class CivitaiClient:
return success, result return success, result
async def get_model_by_hash(self, model_hash: str) -> Optional[Dict]: async def get_model_by_hash(self, model_hash: str) -> Tuple[Optional[Dict], Optional[str]]:
try: try:
downloader = await get_downloader() downloader = await get_downloader()
success, version = await downloader.make_request( success, result = await downloader.make_request(
'GET', 'GET',
f"{self.base_url}/model-versions/by-hash/{model_hash}", f"{self.base_url}/model-versions/by-hash/{model_hash}",
use_auth=True use_auth=True
) )
if success: if success:
# Get model ID from version data # Get model ID from version data
model_id = version.get('modelId') model_id = result.get('modelId')
if model_id: if model_id:
# Fetch additional model metadata # Fetch additional model metadata
success_model, data = await downloader.make_request( success_model, data = await downloader.make_request(
@@ -79,17 +78,24 @@ class CivitaiClient:
) )
if success_model: if success_model:
# Enrich version_info with model data # Enrich version_info with model data
version['model']['description'] = data.get("description") result['model']['description'] = data.get("description")
version['model']['tags'] = data.get("tags", []) result['model']['tags'] = data.get("tags", [])
# Add creator from model data # Add creator from model data
version['creator'] = data.get("creator") result['creator'] = data.get("creator")
return version return result, None
return None
# Handle specific error cases
if "not found" in str(result):
return None, "Model not found"
# Other error cases
logger.error(f"Failed to fetch model info for {model_hash[:10]}: {result}")
return None, str(result)
except Exception as e: except Exception as e:
logger.error(f"API Error: {str(e)}") logger.error(f"API Error: {str(e)}")
return None return None, str(e)
async def download_preview_image(self, image_url: str, save_path: str): async def download_preview_image(self, image_url: str, save_path: str):
try: try:
@@ -246,8 +252,8 @@ class CivitaiClient:
return result, None return result, None
# Handle specific error cases # Handle specific error cases
if "404" in str(result): if "not found" in str(result):
error_msg = f"Model not found (status 404)" error_msg = f"Model not found"
logger.warning(f"Model version not found: {version_id} - {error_msg}") logger.warning(f"Model version not found: {version_id} - {error_msg}")
return None, error_msg return None, error_msg
@@ -259,59 +265,6 @@ class CivitaiClient:
logger.error(error_msg) logger.error(error_msg)
return None, error_msg return None, error_msg
async def get_model_metadata(self, model_id: str) -> Tuple[Optional[Dict], int]:
"""Fetch model metadata (description, tags, and creator info) from Civitai API
Args:
model_id: The Civitai model ID
Returns:
Tuple[Optional[Dict], int]: A tuple containing:
- A dictionary with model metadata or None if not found
- The HTTP status code from the request (0 for exceptions)
"""
try:
downloader = await get_downloader()
url = f"{self.base_url}/models/{model_id}"
success, result = await downloader.make_request(
'GET',
url,
use_auth=True
)
if not success:
# Try to extract status code from error message
status_code = 0
if "404" in str(result):
status_code = 404
elif "401" in str(result):
status_code = 401
elif "403" in str(result):
status_code = 403
logger.warning(f"Failed to fetch model metadata: {result}")
return None, status_code
# Extract relevant metadata
metadata = {
"description": result.get("description") or "No model description available",
"tags": result.get("tags", []),
"creator": {
"username": result.get("creator", {}).get("username"),
"image": result.get("creator", {}).get("image")
}
}
if metadata["description"] or metadata["tags"] or metadata["creator"]["username"]:
return metadata, 200
else:
logger.warning(f"No metadata found for model {model_id}")
return None, 200
except Exception as e:
logger.error(f"Error fetching model metadata: {e}", exc_info=True)
return None, 0
async def get_image_info(self, image_id: str) -> Optional[Dict]: async def get_image_info(self, image_id: str) -> Optional[Dict]:
"""Fetch image information from Civitai API """Fetch image information from Civitai API

View File

@@ -2,7 +2,6 @@ from abc import ABC, abstractmethod
import json import json
import aiosqlite import aiosqlite
import logging import logging
import aiohttp
from bs4 import BeautifulSoup from bs4 import BeautifulSoup
from typing import Optional, Dict, Tuple from typing import Optional, Dict, Tuple
from .downloader import get_downloader from .downloader import get_downloader
@@ -13,7 +12,7 @@ class ModelMetadataProvider(ABC):
"""Base abstract class for all model metadata providers""" """Base abstract class for all model metadata providers"""
@abstractmethod @abstractmethod
async def get_model_by_hash(self, model_hash: str) -> Optional[Dict]: async def get_model_by_hash(self, model_hash: str) -> Tuple[Optional[Dict], Optional[str]]:
"""Find model by hash value""" """Find model by hash value"""
pass pass
@@ -32,18 +31,13 @@ class ModelMetadataProvider(ABC):
"""Fetch model version metadata""" """Fetch model version metadata"""
pass pass
@abstractmethod
async def get_model_metadata(self, model_id: str) -> Tuple[Optional[Dict], int]:
"""Fetch model metadata (description, tags, and creator info)"""
pass
class CivitaiModelMetadataProvider(ModelMetadataProvider): class CivitaiModelMetadataProvider(ModelMetadataProvider):
"""Provider that uses Civitai API for metadata""" """Provider that uses Civitai API for metadata"""
def __init__(self, civitai_client): def __init__(self, civitai_client):
self.client = civitai_client self.client = civitai_client
async def get_model_by_hash(self, model_hash: str) -> Optional[Dict]: async def get_model_by_hash(self, model_hash: str) -> Tuple[Optional[Dict], Optional[str]]:
return await self.client.get_model_by_hash(model_hash) return await self.client.get_model_by_hash(model_hash)
async def get_model_versions(self, model_id: str) -> Optional[Dict]: async def get_model_versions(self, model_id: str) -> Optional[Dict]:
@@ -55,15 +49,12 @@ class CivitaiModelMetadataProvider(ModelMetadataProvider):
async def get_model_version_info(self, version_id: str) -> Tuple[Optional[Dict], Optional[str]]: async def get_model_version_info(self, version_id: str) -> Tuple[Optional[Dict], Optional[str]]:
return await self.client.get_model_version_info(version_id) return await self.client.get_model_version_info(version_id)
async def get_model_metadata(self, model_id: str) -> Tuple[Optional[Dict], int]:
return await self.client.get_model_metadata(model_id)
class CivArchiveModelMetadataProvider(ModelMetadataProvider): class CivArchiveModelMetadataProvider(ModelMetadataProvider):
"""Provider that uses CivArchive HTML page parsing for metadata""" """Provider that uses CivArchive HTML page parsing for metadata"""
async def get_model_by_hash(self, model_hash: str) -> Optional[Dict]: async def get_model_by_hash(self, model_hash: str) -> Tuple[Optional[Dict], Optional[str]]:
"""Not supported by CivArchive provider""" """Not supported by CivArchive provider"""
return None return None, "CivArchive provider does not support hash lookup"
async def get_model_versions(self, model_id: str) -> Optional[Dict]: async def get_model_versions(self, model_id: str) -> Optional[Dict]:
"""Not supported by CivArchive provider""" """Not supported by CivArchive provider"""
@@ -175,17 +166,13 @@ class CivArchiveModelMetadataProvider(ModelMetadataProvider):
"""Not supported by CivArchive provider - requires both model_id and version_id""" """Not supported by CivArchive provider - requires both model_id and version_id"""
return None, "CivArchive provider requires both model_id and version_id" return None, "CivArchive provider requires both model_id and version_id"
async def get_model_metadata(self, model_id: str) -> Tuple[Optional[Dict], int]:
"""Not supported by CivArchive provider"""
return None, 404
class SQLiteModelMetadataProvider(ModelMetadataProvider): class SQLiteModelMetadataProvider(ModelMetadataProvider):
"""Provider that uses SQLite database for metadata""" """Provider that uses SQLite database for metadata"""
def __init__(self, db_path: str): def __init__(self, db_path: str):
self.db_path = db_path self.db_path = db_path
async def get_model_by_hash(self, model_hash: str) -> Optional[Dict]: async def get_model_by_hash(self, model_hash: str) -> Tuple[Optional[Dict], Optional[str]]:
"""Find model by hash value from SQLite database""" """Find model by hash value from SQLite database"""
async with aiosqlite.connect(self.db_path) as db: async with aiosqlite.connect(self.db_path) as db:
# Look up in model_files table to get model_id and version_id # Look up in model_files table to get model_id and version_id
@@ -200,14 +187,15 @@ class SQLiteModelMetadataProvider(ModelMetadataProvider):
file_row = await cursor.fetchone() file_row = await cursor.fetchone()
if not file_row: if not file_row:
return None return None, "Model not found"
# Get version details # Get version details
model_id = file_row['model_id'] model_id = file_row['model_id']
version_id = file_row['version_id'] version_id = file_row['version_id']
# Build response in the same format as Civitai API # Build response in the same format as Civitai API
return await self._get_version_with_model_data(db, model_id, version_id) result = await self._get_version_with_model_data(db, model_id, version_id)
return result, None if result else "Error retrieving model data"
async def get_model_versions(self, model_id: str) -> Optional[Dict]: async def get_model_versions(self, model_id: str) -> Optional[Dict]:
"""Get all versions of a model from SQLite database""" """Get all versions of a model from SQLite database"""
@@ -324,37 +312,6 @@ class SQLiteModelMetadataProvider(ModelMetadataProvider):
version_data = await self._get_version_with_model_data(db, model_id, version_id) version_data = await self._get_version_with_model_data(db, model_id, version_id)
return version_data, None return version_data, None
async def get_model_metadata(self, model_id: str) -> Tuple[Optional[Dict], int]:
"""Fetch model metadata from SQLite database"""
async with aiosqlite.connect(self.db_path) as db:
db.row_factory = aiosqlite.Row
# Get model details
model_query = "SELECT name, type, data, username FROM models WHERE id = ?"
cursor = await db.execute(model_query, (model_id,))
model_row = await cursor.fetchone()
if not model_row:
return None, 404
# Parse data JSON
try:
model_data = json.loads(model_row['data'])
# Extract relevant metadata
metadata = {
"description": model_data.get("description", "No model description available"),
"tags": model_data.get("tags", []),
"creator": {
"username": model_row['username'] or model_data.get("creator", {}).get("username"),
"image": model_data.get("creator", {}).get("image")
}
}
return metadata, 200
except json.JSONDecodeError:
return None, 500
async def _get_version_with_model_data(self, db, model_id, version_id) -> Optional[Dict]: async def _get_version_with_model_data(self, db, model_id, version_id) -> Optional[Dict]:
"""Helper to build version data with model information""" """Helper to build version data with model information"""
# Get version details # Get version details
@@ -409,15 +366,16 @@ class FallbackMetadataProvider(ModelMetadataProvider):
def __init__(self, providers: list): def __init__(self, providers: list):
self.providers = providers self.providers = providers
async def get_model_by_hash(self, model_hash: str) -> Optional[Dict]: async def get_model_by_hash(self, model_hash: str) -> Tuple[Optional[Dict], Optional[str]]:
for provider in self.providers: for provider in self.providers:
try: try:
result = await provider.get_model_by_hash(model_hash) result, error = await provider.get_model_by_hash(model_hash)
if result: if result:
return result return result, error
except Exception: except Exception as e:
logger.debug(f"Provider failed for get_model_by_hash: {e}")
continue continue
return None return None, "Model not found"
async def get_model_versions(self, model_id: str) -> Optional[Dict]: async def get_model_versions(self, model_id: str) -> Optional[Dict]:
for provider in self.providers: for provider in self.providers:
@@ -452,17 +410,6 @@ class FallbackMetadataProvider(ModelMetadataProvider):
continue continue
return None, "No provider could retrieve the data" return None, "No provider could retrieve the data"
async def get_model_metadata(self, model_id: str) -> Tuple[Optional[Dict], int]:
for provider in self.providers:
try:
result, status = await provider.get_model_metadata(model_id)
if result:
return result, status
except Exception as e:
logger.debug(f"Provider failed for get_model_metadata: {e}")
continue
return None, 404
class ModelMetadataProviderManager: class ModelMetadataProviderManager:
"""Manager for selecting and using model metadata providers""" """Manager for selecting and using model metadata providers"""
@@ -485,7 +432,7 @@ class ModelMetadataProviderManager:
if is_default or self.default_provider is None: if is_default or self.default_provider is None:
self.default_provider = name self.default_provider = name
async def get_model_by_hash(self, model_hash: str, provider_name: str = None) -> Optional[Dict]: async def get_model_by_hash(self, model_hash: str, provider_name: str = None) -> Tuple[Optional[Dict], Optional[str]]:
"""Find model by hash using specified or default provider""" """Find model by hash using specified or default provider"""
provider = self._get_provider(provider_name) provider = self._get_provider(provider_name)
return await provider.get_model_by_hash(model_hash) return await provider.get_model_by_hash(model_hash)
@@ -505,11 +452,6 @@ class ModelMetadataProviderManager:
provider = self._get_provider(provider_name) provider = self._get_provider(provider_name)
return await provider.get_model_version_info(version_id) return await provider.get_model_version_info(version_id)
async def get_model_metadata(self, model_id: str, provider_name: str = None) -> Tuple[Optional[Dict], int]:
"""Fetch model metadata using specified or default provider"""
provider = self._get_provider(provider_name)
return await provider.get_model_metadata(model_id)
def _get_provider(self, provider_name: str = None) -> ModelMetadataProvider: def _get_provider(self, provider_name: str = None) -> ModelMetadataProvider:
"""Get provider by name or default provider""" """Get provider by name or default provider"""
if provider_name and provider_name in self.providers: if provider_name and provider_name in self.providers:

View File

@@ -215,7 +215,7 @@ class ModelRouteUtils:
else: else:
metadata_provider = await get_default_metadata_provider() metadata_provider = await get_default_metadata_provider()
civitai_metadata = await metadata_provider.get_model_by_hash(sha256) civitai_metadata, error = await metadata_provider.get_model_by_hash(sha256)
if not civitai_metadata: if not civitai_metadata:
# Mark as not from CivitAI if not found # Mark as not from CivitAI if not found
local_metadata['from_civitai'] = False local_metadata['from_civitai'] = False
@@ -387,10 +387,10 @@ class ModelRouteUtils:
metadata_provider = await get_default_metadata_provider() metadata_provider = await get_default_metadata_provider()
# Fetch and update metadata # Fetch and update metadata
civitai_metadata = await metadata_provider.get_model_by_hash(local_metadata["sha256"]) civitai_metadata, error = await metadata_provider.get_model_by_hash(local_metadata["sha256"])
if not civitai_metadata: if not civitai_metadata:
await ModelRouteUtils.handle_not_found_on_civitai(metadata_path, local_metadata) await ModelRouteUtils.handle_not_found_on_civitai(metadata_path, local_metadata)
return web.json_response({"success": False, "error": "Not found on CivitAI"}, status=404) return web.json_response({"success": False, "error": error}, status=404)
await ModelRouteUtils.update_model_metadata(metadata_path, local_metadata, civitai_metadata, metadata_provider) await ModelRouteUtils.update_model_metadata(metadata_path, local_metadata, civitai_metadata, metadata_provider)