mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-24 22:52:12 -03:00
refactor(metadata): update model fetching methods to return error messages alongside results
This commit is contained in:
@@ -1,4 +1,3 @@
|
||||
from datetime import datetime
|
||||
import os
|
||||
import logging
|
||||
import asyncio
|
||||
@@ -59,17 +58,17 @@ class CivitaiClient:
|
||||
|
||||
return success, result
|
||||
|
||||
async def get_model_by_hash(self, model_hash: str) -> Optional[Dict]:
|
||||
async def get_model_by_hash(self, model_hash: str) -> Tuple[Optional[Dict], Optional[str]]:
|
||||
try:
|
||||
downloader = await get_downloader()
|
||||
success, version = await downloader.make_request(
|
||||
success, result = await downloader.make_request(
|
||||
'GET',
|
||||
f"{self.base_url}/model-versions/by-hash/{model_hash}",
|
||||
use_auth=True
|
||||
)
|
||||
if success:
|
||||
# Get model ID from version data
|
||||
model_id = version.get('modelId')
|
||||
model_id = result.get('modelId')
|
||||
if model_id:
|
||||
# Fetch additional model metadata
|
||||
success_model, data = await downloader.make_request(
|
||||
@@ -79,17 +78,24 @@ class CivitaiClient:
|
||||
)
|
||||
if success_model:
|
||||
# Enrich version_info with model data
|
||||
version['model']['description'] = data.get("description")
|
||||
version['model']['tags'] = data.get("tags", [])
|
||||
result['model']['description'] = data.get("description")
|
||||
result['model']['tags'] = data.get("tags", [])
|
||||
|
||||
# Add creator from model data
|
||||
version['creator'] = data.get("creator")
|
||||
result['creator'] = data.get("creator")
|
||||
|
||||
return version
|
||||
return None
|
||||
return result, None
|
||||
|
||||
# Handle specific error cases
|
||||
if "not found" in str(result):
|
||||
return None, "Model not found"
|
||||
|
||||
# Other error cases
|
||||
logger.error(f"Failed to fetch model info for {model_hash[:10]}: {result}")
|
||||
return None, str(result)
|
||||
except Exception as e:
|
||||
logger.error(f"API Error: {str(e)}")
|
||||
return None
|
||||
return None, str(e)
|
||||
|
||||
async def download_preview_image(self, image_url: str, save_path: str):
|
||||
try:
|
||||
@@ -246,8 +252,8 @@ class CivitaiClient:
|
||||
return result, None
|
||||
|
||||
# Handle specific error cases
|
||||
if "404" in str(result):
|
||||
error_msg = f"Model not found (status 404)"
|
||||
if "not found" in str(result):
|
||||
error_msg = f"Model not found"
|
||||
logger.warning(f"Model version not found: {version_id} - {error_msg}")
|
||||
return None, error_msg
|
||||
|
||||
@@ -259,59 +265,6 @@ class CivitaiClient:
|
||||
logger.error(error_msg)
|
||||
return None, error_msg
|
||||
|
||||
async def get_model_metadata(self, model_id: str) -> Tuple[Optional[Dict], int]:
|
||||
"""Fetch model metadata (description, tags, and creator info) from Civitai API
|
||||
|
||||
Args:
|
||||
model_id: The Civitai model ID
|
||||
|
||||
Returns:
|
||||
Tuple[Optional[Dict], int]: A tuple containing:
|
||||
- A dictionary with model metadata or None if not found
|
||||
- The HTTP status code from the request (0 for exceptions)
|
||||
"""
|
||||
try:
|
||||
downloader = await get_downloader()
|
||||
url = f"{self.base_url}/models/{model_id}"
|
||||
|
||||
success, result = await downloader.make_request(
|
||||
'GET',
|
||||
url,
|
||||
use_auth=True
|
||||
)
|
||||
|
||||
if not success:
|
||||
# Try to extract status code from error message
|
||||
status_code = 0
|
||||
if "404" in str(result):
|
||||
status_code = 404
|
||||
elif "401" in str(result):
|
||||
status_code = 401
|
||||
elif "403" in str(result):
|
||||
status_code = 403
|
||||
logger.warning(f"Failed to fetch model metadata: {result}")
|
||||
return None, status_code
|
||||
|
||||
# Extract relevant metadata
|
||||
metadata = {
|
||||
"description": result.get("description") or "No model description available",
|
||||
"tags": result.get("tags", []),
|
||||
"creator": {
|
||||
"username": result.get("creator", {}).get("username"),
|
||||
"image": result.get("creator", {}).get("image")
|
||||
}
|
||||
}
|
||||
|
||||
if metadata["description"] or metadata["tags"] or metadata["creator"]["username"]:
|
||||
return metadata, 200
|
||||
else:
|
||||
logger.warning(f"No metadata found for model {model_id}")
|
||||
return None, 200
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching model metadata: {e}", exc_info=True)
|
||||
return None, 0
|
||||
|
||||
async def get_image_info(self, image_id: str) -> Optional[Dict]:
|
||||
"""Fetch image information from Civitai API
|
||||
|
||||
|
||||
@@ -2,7 +2,6 @@ from abc import ABC, abstractmethod
|
||||
import json
|
||||
import aiosqlite
|
||||
import logging
|
||||
import aiohttp
|
||||
from bs4 import BeautifulSoup
|
||||
from typing import Optional, Dict, Tuple
|
||||
from .downloader import get_downloader
|
||||
@@ -13,7 +12,7 @@ class ModelMetadataProvider(ABC):
|
||||
"""Base abstract class for all model metadata providers"""
|
||||
|
||||
@abstractmethod
|
||||
async def get_model_by_hash(self, model_hash: str) -> Optional[Dict]:
|
||||
async def get_model_by_hash(self, model_hash: str) -> Tuple[Optional[Dict], Optional[str]]:
|
||||
"""Find model by hash value"""
|
||||
pass
|
||||
|
||||
@@ -31,11 +30,6 @@ class ModelMetadataProvider(ABC):
|
||||
async def get_model_version_info(self, version_id: str) -> Tuple[Optional[Dict], Optional[str]]:
|
||||
"""Fetch model version metadata"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def get_model_metadata(self, model_id: str) -> Tuple[Optional[Dict], int]:
|
||||
"""Fetch model metadata (description, tags, and creator info)"""
|
||||
pass
|
||||
|
||||
class CivitaiModelMetadataProvider(ModelMetadataProvider):
|
||||
"""Provider that uses Civitai API for metadata"""
|
||||
@@ -43,7 +37,7 @@ class CivitaiModelMetadataProvider(ModelMetadataProvider):
|
||||
def __init__(self, civitai_client):
|
||||
self.client = civitai_client
|
||||
|
||||
async def get_model_by_hash(self, model_hash: str) -> Optional[Dict]:
|
||||
async def get_model_by_hash(self, model_hash: str) -> Tuple[Optional[Dict], Optional[str]]:
|
||||
return await self.client.get_model_by_hash(model_hash)
|
||||
|
||||
async def get_model_versions(self, model_id: str) -> Optional[Dict]:
|
||||
@@ -54,16 +48,13 @@ class CivitaiModelMetadataProvider(ModelMetadataProvider):
|
||||
|
||||
async def get_model_version_info(self, version_id: str) -> Tuple[Optional[Dict], Optional[str]]:
|
||||
return await self.client.get_model_version_info(version_id)
|
||||
|
||||
async def get_model_metadata(self, model_id: str) -> Tuple[Optional[Dict], int]:
|
||||
return await self.client.get_model_metadata(model_id)
|
||||
|
||||
class CivArchiveModelMetadataProvider(ModelMetadataProvider):
|
||||
"""Provider that uses CivArchive HTML page parsing for metadata"""
|
||||
|
||||
async def get_model_by_hash(self, model_hash: str) -> Optional[Dict]:
|
||||
async def get_model_by_hash(self, model_hash: str) -> Tuple[Optional[Dict], Optional[str]]:
|
||||
"""Not supported by CivArchive provider"""
|
||||
return None
|
||||
return None, "CivArchive provider does not support hash lookup"
|
||||
|
||||
async def get_model_versions(self, model_id: str) -> Optional[Dict]:
|
||||
"""Not supported by CivArchive provider"""
|
||||
@@ -174,10 +165,6 @@ class CivArchiveModelMetadataProvider(ModelMetadataProvider):
|
||||
async def get_model_version_info(self, version_id: str) -> Tuple[Optional[Dict], Optional[str]]:
|
||||
"""Not supported by CivArchive provider - requires both model_id and version_id"""
|
||||
return None, "CivArchive provider requires both model_id and version_id"
|
||||
|
||||
async def get_model_metadata(self, model_id: str) -> Tuple[Optional[Dict], int]:
|
||||
"""Not supported by CivArchive provider"""
|
||||
return None, 404
|
||||
|
||||
class SQLiteModelMetadataProvider(ModelMetadataProvider):
|
||||
"""Provider that uses SQLite database for metadata"""
|
||||
@@ -185,7 +172,7 @@ class SQLiteModelMetadataProvider(ModelMetadataProvider):
|
||||
def __init__(self, db_path: str):
|
||||
self.db_path = db_path
|
||||
|
||||
async def get_model_by_hash(self, model_hash: str) -> Optional[Dict]:
|
||||
async def get_model_by_hash(self, model_hash: str) -> Tuple[Optional[Dict], Optional[str]]:
|
||||
"""Find model by hash value from SQLite database"""
|
||||
async with aiosqlite.connect(self.db_path) as db:
|
||||
# Look up in model_files table to get model_id and version_id
|
||||
@@ -200,14 +187,15 @@ class SQLiteModelMetadataProvider(ModelMetadataProvider):
|
||||
file_row = await cursor.fetchone()
|
||||
|
||||
if not file_row:
|
||||
return None
|
||||
return None, "Model not found"
|
||||
|
||||
# Get version details
|
||||
model_id = file_row['model_id']
|
||||
version_id = file_row['version_id']
|
||||
|
||||
# Build response in the same format as Civitai API
|
||||
return await self._get_version_with_model_data(db, model_id, version_id)
|
||||
result = await self._get_version_with_model_data(db, model_id, version_id)
|
||||
return result, None if result else "Error retrieving model data"
|
||||
|
||||
async def get_model_versions(self, model_id: str) -> Optional[Dict]:
|
||||
"""Get all versions of a model from SQLite database"""
|
||||
@@ -324,37 +312,6 @@ class SQLiteModelMetadataProvider(ModelMetadataProvider):
|
||||
version_data = await self._get_version_with_model_data(db, model_id, version_id)
|
||||
return version_data, None
|
||||
|
||||
async def get_model_metadata(self, model_id: str) -> Tuple[Optional[Dict], int]:
|
||||
"""Fetch model metadata from SQLite database"""
|
||||
async with aiosqlite.connect(self.db_path) as db:
|
||||
db.row_factory = aiosqlite.Row
|
||||
|
||||
# Get model details
|
||||
model_query = "SELECT name, type, data, username FROM models WHERE id = ?"
|
||||
cursor = await db.execute(model_query, (model_id,))
|
||||
model_row = await cursor.fetchone()
|
||||
|
||||
if not model_row:
|
||||
return None, 404
|
||||
|
||||
# Parse data JSON
|
||||
try:
|
||||
model_data = json.loads(model_row['data'])
|
||||
|
||||
# Extract relevant metadata
|
||||
metadata = {
|
||||
"description": model_data.get("description", "No model description available"),
|
||||
"tags": model_data.get("tags", []),
|
||||
"creator": {
|
||||
"username": model_row['username'] or model_data.get("creator", {}).get("username"),
|
||||
"image": model_data.get("creator", {}).get("image")
|
||||
}
|
||||
}
|
||||
|
||||
return metadata, 200
|
||||
except json.JSONDecodeError:
|
||||
return None, 500
|
||||
|
||||
async def _get_version_with_model_data(self, db, model_id, version_id) -> Optional[Dict]:
|
||||
"""Helper to build version data with model information"""
|
||||
# Get version details
|
||||
@@ -409,15 +366,16 @@ class FallbackMetadataProvider(ModelMetadataProvider):
|
||||
def __init__(self, providers: list):
|
||||
self.providers = providers
|
||||
|
||||
async def get_model_by_hash(self, model_hash: str) -> Optional[Dict]:
|
||||
async def get_model_by_hash(self, model_hash: str) -> Tuple[Optional[Dict], Optional[str]]:
|
||||
for provider in self.providers:
|
||||
try:
|
||||
result = await provider.get_model_by_hash(model_hash)
|
||||
result, error = await provider.get_model_by_hash(model_hash)
|
||||
if result:
|
||||
return result
|
||||
except Exception:
|
||||
return result, error
|
||||
except Exception as e:
|
||||
logger.debug(f"Provider failed for get_model_by_hash: {e}")
|
||||
continue
|
||||
return None
|
||||
return None, "Model not found"
|
||||
|
||||
async def get_model_versions(self, model_id: str) -> Optional[Dict]:
|
||||
for provider in self.providers:
|
||||
@@ -452,17 +410,6 @@ class FallbackMetadataProvider(ModelMetadataProvider):
|
||||
continue
|
||||
return None, "No provider could retrieve the data"
|
||||
|
||||
async def get_model_metadata(self, model_id: str) -> Tuple[Optional[Dict], int]:
|
||||
for provider in self.providers:
|
||||
try:
|
||||
result, status = await provider.get_model_metadata(model_id)
|
||||
if result:
|
||||
return result, status
|
||||
except Exception as e:
|
||||
logger.debug(f"Provider failed for get_model_metadata: {e}")
|
||||
continue
|
||||
return None, 404
|
||||
|
||||
class ModelMetadataProviderManager:
|
||||
"""Manager for selecting and using model metadata providers"""
|
||||
|
||||
@@ -485,7 +432,7 @@ class ModelMetadataProviderManager:
|
||||
if is_default or self.default_provider is None:
|
||||
self.default_provider = name
|
||||
|
||||
async def get_model_by_hash(self, model_hash: str, provider_name: str = None) -> Optional[Dict]:
|
||||
async def get_model_by_hash(self, model_hash: str, provider_name: str = None) -> Tuple[Optional[Dict], Optional[str]]:
|
||||
"""Find model by hash using specified or default provider"""
|
||||
provider = self._get_provider(provider_name)
|
||||
return await provider.get_model_by_hash(model_hash)
|
||||
@@ -505,11 +452,6 @@ class ModelMetadataProviderManager:
|
||||
provider = self._get_provider(provider_name)
|
||||
return await provider.get_model_version_info(version_id)
|
||||
|
||||
async def get_model_metadata(self, model_id: str, provider_name: str = None) -> Tuple[Optional[Dict], int]:
|
||||
"""Fetch model metadata using specified or default provider"""
|
||||
provider = self._get_provider(provider_name)
|
||||
return await provider.get_model_metadata(model_id)
|
||||
|
||||
def _get_provider(self, provider_name: str = None) -> ModelMetadataProvider:
|
||||
"""Get provider by name or default provider"""
|
||||
if provider_name and provider_name in self.providers:
|
||||
|
||||
Reference in New Issue
Block a user