refactor(metadata): update model fetching methods to return error messages alongside results

This commit is contained in:
Will Miao
2025-09-19 16:36:34 +08:00
parent fc6f1bf95b
commit 1610048974
7 changed files with 66 additions and 175 deletions

View File

@@ -55,7 +55,7 @@ class RecipeMetadataParser(ABC):
# Unpack the tuple to get the actual data
civitai_info, error_msg = civitai_info_tuple if isinstance(civitai_info_tuple, tuple) else (civitai_info_tuple, None)
if not civitai_info or civitai_info.get("error") == "Model not found":
if not civitai_info or error_msg == "Model not found":
# Model not found or deleted
lora_entry['isDeleted'] = True
lora_entry['thumbnailUrl'] = '/loras_static/images/no-preview.png'

View File

@@ -91,7 +91,7 @@ class CivitaiApiMetadataParser(RecipeMetadataParser):
result["base_model"] = metadata["baseModel"]
elif "Model hash" in metadata and metadata_provider:
model_hash = metadata["Model hash"]
model_info = await metadata_provider.get_model_by_hash(model_hash)
model_info, error = await metadata_provider.get_model_by_hash(model_hash)
if model_info:
result["base_model"] = model_info.get("baseModel", "")
elif "Model" in metadata and isinstance(metadata.get("resources"), list):
@@ -100,7 +100,7 @@ class CivitaiApiMetadataParser(RecipeMetadataParser):
if resource.get("type") == "model" and resource.get("name") == metadata.get("Model"):
# This is likely the checkpoint model
if metadata_provider and resource.get("hash"):
model_info = await metadata_provider.get_model_by_hash(resource.get("hash"))
model_info, error = await metadata_provider.get_model_by_hash(resource.get("hash"))
if model_info:
result["base_model"] = model_info.get("baseModel", "")
@@ -201,11 +201,7 @@ class CivitaiApiMetadataParser(RecipeMetadataParser):
if version_id and metadata_provider:
try:
# Use get_model_version_info instead of get_model_version
civitai_info, error = await metadata_provider.get_model_version_info(version_id)
if error:
logger.warning(f"Error getting model version info: {error}")
continue
civitai_info = await metadata_provider.get_model_version_info(version_id)
populated_entry = await self.populate_lora_from_civitai(
lora_entry,
@@ -267,26 +263,23 @@ class CivitaiApiMetadataParser(RecipeMetadataParser):
if version_id and metadata_provider:
try:
# Use get_model_version_info with the version ID
civitai_info, error = await metadata_provider.get_model_version_info(version_id)
civitai_info = await metadata_provider.get_model_version_info(version_id)
if error:
logger.warning(f"Error getting model version info: {error}")
else:
populated_entry = await self.populate_lora_from_civitai(
lora_entry,
civitai_info,
recipe_scanner,
base_model_counts
)
populated_entry = await self.populate_lora_from_civitai(
lora_entry,
civitai_info,
recipe_scanner,
base_model_counts
)
if populated_entry is None:
continue # Skip invalid LoRA types
if populated_entry is None:
continue # Skip invalid LoRA types
lora_entry = populated_entry
# Track this LoRA for deduplication
if version_id:
added_loras[version_id] = len(result["loras"])
lora_entry = populated_entry
# Track this LoRA for deduplication
if version_id:
added_loras[version_id] = len(result["loras"])
except Exception as e:
logger.error(f"Error fetching Civitai info for model ID {version_id}: {e}")

View File

@@ -624,7 +624,7 @@ class BaseModelRoutes(ABC):
success = 0
needs_resort = False
# Prepare models to process, only those without CivitAI data or missing tags, description, or creator
# Prepare models to process, only those without CivitAI data
enable_metadata_archive_db = settings.get('enable_metadata_archive_db', False)
to_process = [
model for model in cache.raw_data
@@ -633,9 +633,6 @@ class BaseModelRoutes(ABC):
and (
not model.get('civitai')
or not model['civitai'].get('id')
# or not model.get('tags') # Skipping tag cause it could be empty legitimately
# or not model.get('modelDescription')
# or not (model.get('civitai') and model['civitai'].get('creator'))
)
and (
(enable_metadata_archive_db)
@@ -782,7 +779,13 @@ class BaseModelRoutes(ABC):
try:
hash = request.match_info.get('hash')
metadata_provider = await get_default_metadata_provider()
model = await metadata_provider.get_model_by_hash(hash)
model, error = await metadata_provider.get_model_by_hash(hash)
if error:
logger.warning(f"Error getting model by hash: {error}")
return web.json_response({
"success": False,
"error": error
}, status=404)
return web.json_response(model)
except Exception as e:
logger.error(f"Error fetching model details by hash: {e}")

View File

@@ -7,7 +7,7 @@ import shutil
import tempfile
from aiohttp import web
from typing import Dict, List
from ..services.downloader import get_downloader, Downloader
from ..services.downloader import get_downloader
logger = logging.getLogger(__name__)
@@ -265,7 +265,7 @@ class UpdateRoutes:
github_url = f"https://api.github.com/repos/{repo_owner}/{repo_name}/commits/main"
try:
downloader = await Downloader.get_instance()
downloader = await get_downloader()
success, data = await downloader.make_request('GET', github_url, custom_headers={'Accept': 'application/vnd.github+json'})
if not success:
@@ -431,7 +431,7 @@ class UpdateRoutes:
github_url = f"https://api.github.com/repos/{repo_owner}/{repo_name}/releases/latest"
try:
downloader = await Downloader.get_instance()
downloader = await get_downloader()
success, data = await downloader.make_request('GET', github_url, custom_headers={'Accept': 'application/vnd.github+json'})
if not success:

View File

@@ -1,4 +1,3 @@
from datetime import datetime
import os
import logging
import asyncio
@@ -59,17 +58,17 @@ class CivitaiClient:
return success, result
async def get_model_by_hash(self, model_hash: str) -> Optional[Dict]:
async def get_model_by_hash(self, model_hash: str) -> Tuple[Optional[Dict], Optional[str]]:
try:
downloader = await get_downloader()
success, version = await downloader.make_request(
success, result = await downloader.make_request(
'GET',
f"{self.base_url}/model-versions/by-hash/{model_hash}",
use_auth=True
)
if success:
# Get model ID from version data
model_id = version.get('modelId')
model_id = result.get('modelId')
if model_id:
# Fetch additional model metadata
success_model, data = await downloader.make_request(
@@ -79,17 +78,24 @@ class CivitaiClient:
)
if success_model:
# Enrich version_info with model data
version['model']['description'] = data.get("description")
version['model']['tags'] = data.get("tags", [])
result['model']['description'] = data.get("description")
result['model']['tags'] = data.get("tags", [])
# Add creator from model data
version['creator'] = data.get("creator")
result['creator'] = data.get("creator")
return version
return None
return result, None
# Handle specific error cases
if "not found" in str(result):
return None, "Model not found"
# Other error cases
logger.error(f"Failed to fetch model info for {model_hash[:10]}: {result}")
return None, str(result)
except Exception as e:
logger.error(f"API Error: {str(e)}")
return None
return None, str(e)
async def download_preview_image(self, image_url: str, save_path: str):
try:
@@ -246,8 +252,8 @@ class CivitaiClient:
return result, None
# Handle specific error cases
if "404" in str(result):
error_msg = f"Model not found (status 404)"
if "not found" in str(result):
error_msg = f"Model not found"
logger.warning(f"Model version not found: {version_id} - {error_msg}")
return None, error_msg
@@ -259,59 +265,6 @@ class CivitaiClient:
logger.error(error_msg)
return None, error_msg
async def get_model_metadata(self, model_id: str) -> Tuple[Optional[Dict], int]:
"""Fetch model metadata (description, tags, and creator info) from Civitai API
Args:
model_id: The Civitai model ID
Returns:
Tuple[Optional[Dict], int]: A tuple containing:
- A dictionary with model metadata or None if not found
- The HTTP status code from the request (0 for exceptions)
"""
try:
downloader = await get_downloader()
url = f"{self.base_url}/models/{model_id}"
success, result = await downloader.make_request(
'GET',
url,
use_auth=True
)
if not success:
# Try to extract status code from error message
status_code = 0
if "404" in str(result):
status_code = 404
elif "401" in str(result):
status_code = 401
elif "403" in str(result):
status_code = 403
logger.warning(f"Failed to fetch model metadata: {result}")
return None, status_code
# Extract relevant metadata
metadata = {
"description": result.get("description") or "No model description available",
"tags": result.get("tags", []),
"creator": {
"username": result.get("creator", {}).get("username"),
"image": result.get("creator", {}).get("image")
}
}
if metadata["description"] or metadata["tags"] or metadata["creator"]["username"]:
return metadata, 200
else:
logger.warning(f"No metadata found for model {model_id}")
return None, 200
except Exception as e:
logger.error(f"Error fetching model metadata: {e}", exc_info=True)
return None, 0
async def get_image_info(self, image_id: str) -> Optional[Dict]:
"""Fetch image information from Civitai API

View File

@@ -2,7 +2,6 @@ from abc import ABC, abstractmethod
import json
import aiosqlite
import logging
import aiohttp
from bs4 import BeautifulSoup
from typing import Optional, Dict, Tuple
from .downloader import get_downloader
@@ -13,7 +12,7 @@ class ModelMetadataProvider(ABC):
"""Base abstract class for all model metadata providers"""
@abstractmethod
async def get_model_by_hash(self, model_hash: str) -> Optional[Dict]:
async def get_model_by_hash(self, model_hash: str) -> Tuple[Optional[Dict], Optional[str]]:
"""Find model by hash value"""
pass
@@ -31,11 +30,6 @@ class ModelMetadataProvider(ABC):
async def get_model_version_info(self, version_id: str) -> Tuple[Optional[Dict], Optional[str]]:
"""Fetch model version metadata"""
pass
@abstractmethod
async def get_model_metadata(self, model_id: str) -> Tuple[Optional[Dict], int]:
"""Fetch model metadata (description, tags, and creator info)"""
pass
class CivitaiModelMetadataProvider(ModelMetadataProvider):
"""Provider that uses Civitai API for metadata"""
@@ -43,7 +37,7 @@ class CivitaiModelMetadataProvider(ModelMetadataProvider):
def __init__(self, civitai_client):
self.client = civitai_client
async def get_model_by_hash(self, model_hash: str) -> Optional[Dict]:
async def get_model_by_hash(self, model_hash: str) -> Tuple[Optional[Dict], Optional[str]]:
return await self.client.get_model_by_hash(model_hash)
async def get_model_versions(self, model_id: str) -> Optional[Dict]:
@@ -54,16 +48,13 @@ class CivitaiModelMetadataProvider(ModelMetadataProvider):
async def get_model_version_info(self, version_id: str) -> Tuple[Optional[Dict], Optional[str]]:
return await self.client.get_model_version_info(version_id)
async def get_model_metadata(self, model_id: str) -> Tuple[Optional[Dict], int]:
return await self.client.get_model_metadata(model_id)
class CivArchiveModelMetadataProvider(ModelMetadataProvider):
"""Provider that uses CivArchive HTML page parsing for metadata"""
async def get_model_by_hash(self, model_hash: str) -> Optional[Dict]:
async def get_model_by_hash(self, model_hash: str) -> Tuple[Optional[Dict], Optional[str]]:
"""Not supported by CivArchive provider"""
return None
return None, "CivArchive provider does not support hash lookup"
async def get_model_versions(self, model_id: str) -> Optional[Dict]:
"""Not supported by CivArchive provider"""
@@ -174,10 +165,6 @@ class CivArchiveModelMetadataProvider(ModelMetadataProvider):
async def get_model_version_info(self, version_id: str) -> Tuple[Optional[Dict], Optional[str]]:
"""Not supported by CivArchive provider - requires both model_id and version_id"""
return None, "CivArchive provider requires both model_id and version_id"
async def get_model_metadata(self, model_id: str) -> Tuple[Optional[Dict], int]:
"""Not supported by CivArchive provider"""
return None, 404
class SQLiteModelMetadataProvider(ModelMetadataProvider):
"""Provider that uses SQLite database for metadata"""
@@ -185,7 +172,7 @@ class SQLiteModelMetadataProvider(ModelMetadataProvider):
def __init__(self, db_path: str):
self.db_path = db_path
async def get_model_by_hash(self, model_hash: str) -> Optional[Dict]:
async def get_model_by_hash(self, model_hash: str) -> Tuple[Optional[Dict], Optional[str]]:
"""Find model by hash value from SQLite database"""
async with aiosqlite.connect(self.db_path) as db:
# Look up in model_files table to get model_id and version_id
@@ -200,14 +187,15 @@ class SQLiteModelMetadataProvider(ModelMetadataProvider):
file_row = await cursor.fetchone()
if not file_row:
return None
return None, "Model not found"
# Get version details
model_id = file_row['model_id']
version_id = file_row['version_id']
# Build response in the same format as Civitai API
return await self._get_version_with_model_data(db, model_id, version_id)
result = await self._get_version_with_model_data(db, model_id, version_id)
return result, None if result else "Error retrieving model data"
async def get_model_versions(self, model_id: str) -> Optional[Dict]:
"""Get all versions of a model from SQLite database"""
@@ -324,37 +312,6 @@ class SQLiteModelMetadataProvider(ModelMetadataProvider):
version_data = await self._get_version_with_model_data(db, model_id, version_id)
return version_data, None
async def get_model_metadata(self, model_id: str) -> Tuple[Optional[Dict], int]:
"""Fetch model metadata from SQLite database"""
async with aiosqlite.connect(self.db_path) as db:
db.row_factory = aiosqlite.Row
# Get model details
model_query = "SELECT name, type, data, username FROM models WHERE id = ?"
cursor = await db.execute(model_query, (model_id,))
model_row = await cursor.fetchone()
if not model_row:
return None, 404
# Parse data JSON
try:
model_data = json.loads(model_row['data'])
# Extract relevant metadata
metadata = {
"description": model_data.get("description", "No model description available"),
"tags": model_data.get("tags", []),
"creator": {
"username": model_row['username'] or model_data.get("creator", {}).get("username"),
"image": model_data.get("creator", {}).get("image")
}
}
return metadata, 200
except json.JSONDecodeError:
return None, 500
async def _get_version_with_model_data(self, db, model_id, version_id) -> Optional[Dict]:
"""Helper to build version data with model information"""
# Get version details
@@ -409,15 +366,16 @@ class FallbackMetadataProvider(ModelMetadataProvider):
def __init__(self, providers: list):
self.providers = providers
async def get_model_by_hash(self, model_hash: str) -> Optional[Dict]:
async def get_model_by_hash(self, model_hash: str) -> Tuple[Optional[Dict], Optional[str]]:
for provider in self.providers:
try:
result = await provider.get_model_by_hash(model_hash)
result, error = await provider.get_model_by_hash(model_hash)
if result:
return result
except Exception:
return result, error
except Exception as e:
logger.debug(f"Provider failed for get_model_by_hash: {e}")
continue
return None
return None, "Model not found"
async def get_model_versions(self, model_id: str) -> Optional[Dict]:
for provider in self.providers:
@@ -452,17 +410,6 @@ class FallbackMetadataProvider(ModelMetadataProvider):
continue
return None, "No provider could retrieve the data"
async def get_model_metadata(self, model_id: str) -> Tuple[Optional[Dict], int]:
for provider in self.providers:
try:
result, status = await provider.get_model_metadata(model_id)
if result:
return result, status
except Exception as e:
logger.debug(f"Provider failed for get_model_metadata: {e}")
continue
return None, 404
class ModelMetadataProviderManager:
"""Manager for selecting and using model metadata providers"""
@@ -485,7 +432,7 @@ class ModelMetadataProviderManager:
if is_default or self.default_provider is None:
self.default_provider = name
async def get_model_by_hash(self, model_hash: str, provider_name: str = None) -> Optional[Dict]:
async def get_model_by_hash(self, model_hash: str, provider_name: str = None) -> Tuple[Optional[Dict], Optional[str]]:
"""Find model by hash using specified or default provider"""
provider = self._get_provider(provider_name)
return await provider.get_model_by_hash(model_hash)
@@ -505,11 +452,6 @@ class ModelMetadataProviderManager:
provider = self._get_provider(provider_name)
return await provider.get_model_version_info(version_id)
async def get_model_metadata(self, model_id: str, provider_name: str = None) -> Tuple[Optional[Dict], int]:
"""Fetch model metadata using specified or default provider"""
provider = self._get_provider(provider_name)
return await provider.get_model_metadata(model_id)
def _get_provider(self, provider_name: str = None) -> ModelMetadataProvider:
"""Get provider by name or default provider"""
if provider_name and provider_name in self.providers:

View File

@@ -215,7 +215,7 @@ class ModelRouteUtils:
else:
metadata_provider = await get_default_metadata_provider()
civitai_metadata = await metadata_provider.get_model_by_hash(sha256)
civitai_metadata, error = await metadata_provider.get_model_by_hash(sha256)
if not civitai_metadata:
# Mark as not from CivitAI if not found
local_metadata['from_civitai'] = False
@@ -387,10 +387,10 @@ class ModelRouteUtils:
metadata_provider = await get_default_metadata_provider()
# Fetch and update metadata
civitai_metadata = await metadata_provider.get_model_by_hash(local_metadata["sha256"])
civitai_metadata, error = await metadata_provider.get_model_by_hash(local_metadata["sha256"])
if not civitai_metadata:
await ModelRouteUtils.handle_not_found_on_civitai(metadata_path, local_metadata)
return web.json_response({"success": False, "error": "Not found on CivitAI"}, status=404)
return web.json_response({"success": False, "error": error}, status=404)
await ModelRouteUtils.update_model_metadata(metadata_path, local_metadata, civitai_metadata, metadata_provider)