mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-21 21:22:11 -03:00
refactor(metadata): update model fetching methods to return error messages alongside results
This commit is contained in:
@@ -55,7 +55,7 @@ class RecipeMetadataParser(ABC):
|
||||
# Unpack the tuple to get the actual data
|
||||
civitai_info, error_msg = civitai_info_tuple if isinstance(civitai_info_tuple, tuple) else (civitai_info_tuple, None)
|
||||
|
||||
if not civitai_info or civitai_info.get("error") == "Model not found":
|
||||
if not civitai_info or error_msg == "Model not found":
|
||||
# Model not found or deleted
|
||||
lora_entry['isDeleted'] = True
|
||||
lora_entry['thumbnailUrl'] = '/loras_static/images/no-preview.png'
|
||||
|
||||
@@ -91,7 +91,7 @@ class CivitaiApiMetadataParser(RecipeMetadataParser):
|
||||
result["base_model"] = metadata["baseModel"]
|
||||
elif "Model hash" in metadata and metadata_provider:
|
||||
model_hash = metadata["Model hash"]
|
||||
model_info = await metadata_provider.get_model_by_hash(model_hash)
|
||||
model_info, error = await metadata_provider.get_model_by_hash(model_hash)
|
||||
if model_info:
|
||||
result["base_model"] = model_info.get("baseModel", "")
|
||||
elif "Model" in metadata and isinstance(metadata.get("resources"), list):
|
||||
@@ -100,7 +100,7 @@ class CivitaiApiMetadataParser(RecipeMetadataParser):
|
||||
if resource.get("type") == "model" and resource.get("name") == metadata.get("Model"):
|
||||
# This is likely the checkpoint model
|
||||
if metadata_provider and resource.get("hash"):
|
||||
model_info = await metadata_provider.get_model_by_hash(resource.get("hash"))
|
||||
model_info, error = await metadata_provider.get_model_by_hash(resource.get("hash"))
|
||||
if model_info:
|
||||
result["base_model"] = model_info.get("baseModel", "")
|
||||
|
||||
@@ -201,11 +201,7 @@ class CivitaiApiMetadataParser(RecipeMetadataParser):
|
||||
if version_id and metadata_provider:
|
||||
try:
|
||||
# Use get_model_version_info instead of get_model_version
|
||||
civitai_info, error = await metadata_provider.get_model_version_info(version_id)
|
||||
|
||||
if error:
|
||||
logger.warning(f"Error getting model version info: {error}")
|
||||
continue
|
||||
civitai_info = await metadata_provider.get_model_version_info(version_id)
|
||||
|
||||
populated_entry = await self.populate_lora_from_civitai(
|
||||
lora_entry,
|
||||
@@ -267,26 +263,23 @@ class CivitaiApiMetadataParser(RecipeMetadataParser):
|
||||
if version_id and metadata_provider:
|
||||
try:
|
||||
# Use get_model_version_info with the version ID
|
||||
civitai_info, error = await metadata_provider.get_model_version_info(version_id)
|
||||
civitai_info = await metadata_provider.get_model_version_info(version_id)
|
||||
|
||||
if error:
|
||||
logger.warning(f"Error getting model version info: {error}")
|
||||
else:
|
||||
populated_entry = await self.populate_lora_from_civitai(
|
||||
lora_entry,
|
||||
civitai_info,
|
||||
recipe_scanner,
|
||||
base_model_counts
|
||||
)
|
||||
populated_entry = await self.populate_lora_from_civitai(
|
||||
lora_entry,
|
||||
civitai_info,
|
||||
recipe_scanner,
|
||||
base_model_counts
|
||||
)
|
||||
|
||||
if populated_entry is None:
|
||||
continue # Skip invalid LoRA types
|
||||
|
||||
if populated_entry is None:
|
||||
continue # Skip invalid LoRA types
|
||||
|
||||
lora_entry = populated_entry
|
||||
|
||||
# Track this LoRA for deduplication
|
||||
if version_id:
|
||||
added_loras[version_id] = len(result["loras"])
|
||||
lora_entry = populated_entry
|
||||
|
||||
# Track this LoRA for deduplication
|
||||
if version_id:
|
||||
added_loras[version_id] = len(result["loras"])
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching Civitai info for model ID {version_id}: {e}")
|
||||
|
||||
|
||||
@@ -624,7 +624,7 @@ class BaseModelRoutes(ABC):
|
||||
success = 0
|
||||
needs_resort = False
|
||||
|
||||
# Prepare models to process, only those without CivitAI data or missing tags, description, or creator
|
||||
# Prepare models to process, only those without CivitAI data
|
||||
enable_metadata_archive_db = settings.get('enable_metadata_archive_db', False)
|
||||
to_process = [
|
||||
model for model in cache.raw_data
|
||||
@@ -633,9 +633,6 @@ class BaseModelRoutes(ABC):
|
||||
and (
|
||||
not model.get('civitai')
|
||||
or not model['civitai'].get('id')
|
||||
# or not model.get('tags') # Skipping tag cause it could be empty legitimately
|
||||
# or not model.get('modelDescription')
|
||||
# or not (model.get('civitai') and model['civitai'].get('creator'))
|
||||
)
|
||||
and (
|
||||
(enable_metadata_archive_db)
|
||||
@@ -782,7 +779,13 @@ class BaseModelRoutes(ABC):
|
||||
try:
|
||||
hash = request.match_info.get('hash')
|
||||
metadata_provider = await get_default_metadata_provider()
|
||||
model = await metadata_provider.get_model_by_hash(hash)
|
||||
model, error = await metadata_provider.get_model_by_hash(hash)
|
||||
if error:
|
||||
logger.warning(f"Error getting model by hash: {error}")
|
||||
return web.json_response({
|
||||
"success": False,
|
||||
"error": error
|
||||
}, status=404)
|
||||
return web.json_response(model)
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching model details by hash: {e}")
|
||||
|
||||
@@ -7,7 +7,7 @@ import shutil
|
||||
import tempfile
|
||||
from aiohttp import web
|
||||
from typing import Dict, List
|
||||
from ..services.downloader import get_downloader, Downloader
|
||||
from ..services.downloader import get_downloader
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -265,7 +265,7 @@ class UpdateRoutes:
|
||||
github_url = f"https://api.github.com/repos/{repo_owner}/{repo_name}/commits/main"
|
||||
|
||||
try:
|
||||
downloader = await Downloader.get_instance()
|
||||
downloader = await get_downloader()
|
||||
success, data = await downloader.make_request('GET', github_url, custom_headers={'Accept': 'application/vnd.github+json'})
|
||||
|
||||
if not success:
|
||||
@@ -431,7 +431,7 @@ class UpdateRoutes:
|
||||
github_url = f"https://api.github.com/repos/{repo_owner}/{repo_name}/releases/latest"
|
||||
|
||||
try:
|
||||
downloader = await Downloader.get_instance()
|
||||
downloader = await get_downloader()
|
||||
success, data = await downloader.make_request('GET', github_url, custom_headers={'Accept': 'application/vnd.github+json'})
|
||||
|
||||
if not success:
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
from datetime import datetime
|
||||
import os
|
||||
import logging
|
||||
import asyncio
|
||||
@@ -59,17 +58,17 @@ class CivitaiClient:
|
||||
|
||||
return success, result
|
||||
|
||||
async def get_model_by_hash(self, model_hash: str) -> Optional[Dict]:
|
||||
async def get_model_by_hash(self, model_hash: str) -> Tuple[Optional[Dict], Optional[str]]:
|
||||
try:
|
||||
downloader = await get_downloader()
|
||||
success, version = await downloader.make_request(
|
||||
success, result = await downloader.make_request(
|
||||
'GET',
|
||||
f"{self.base_url}/model-versions/by-hash/{model_hash}",
|
||||
use_auth=True
|
||||
)
|
||||
if success:
|
||||
# Get model ID from version data
|
||||
model_id = version.get('modelId')
|
||||
model_id = result.get('modelId')
|
||||
if model_id:
|
||||
# Fetch additional model metadata
|
||||
success_model, data = await downloader.make_request(
|
||||
@@ -79,17 +78,24 @@ class CivitaiClient:
|
||||
)
|
||||
if success_model:
|
||||
# Enrich version_info with model data
|
||||
version['model']['description'] = data.get("description")
|
||||
version['model']['tags'] = data.get("tags", [])
|
||||
result['model']['description'] = data.get("description")
|
||||
result['model']['tags'] = data.get("tags", [])
|
||||
|
||||
# Add creator from model data
|
||||
version['creator'] = data.get("creator")
|
||||
result['creator'] = data.get("creator")
|
||||
|
||||
return version
|
||||
return None
|
||||
return result, None
|
||||
|
||||
# Handle specific error cases
|
||||
if "not found" in str(result):
|
||||
return None, "Model not found"
|
||||
|
||||
# Other error cases
|
||||
logger.error(f"Failed to fetch model info for {model_hash[:10]}: {result}")
|
||||
return None, str(result)
|
||||
except Exception as e:
|
||||
logger.error(f"API Error: {str(e)}")
|
||||
return None
|
||||
return None, str(e)
|
||||
|
||||
async def download_preview_image(self, image_url: str, save_path: str):
|
||||
try:
|
||||
@@ -246,8 +252,8 @@ class CivitaiClient:
|
||||
return result, None
|
||||
|
||||
# Handle specific error cases
|
||||
if "404" in str(result):
|
||||
error_msg = f"Model not found (status 404)"
|
||||
if "not found" in str(result):
|
||||
error_msg = f"Model not found"
|
||||
logger.warning(f"Model version not found: {version_id} - {error_msg}")
|
||||
return None, error_msg
|
||||
|
||||
@@ -259,59 +265,6 @@ class CivitaiClient:
|
||||
logger.error(error_msg)
|
||||
return None, error_msg
|
||||
|
||||
async def get_model_metadata(self, model_id: str) -> Tuple[Optional[Dict], int]:
|
||||
"""Fetch model metadata (description, tags, and creator info) from Civitai API
|
||||
|
||||
Args:
|
||||
model_id: The Civitai model ID
|
||||
|
||||
Returns:
|
||||
Tuple[Optional[Dict], int]: A tuple containing:
|
||||
- A dictionary with model metadata or None if not found
|
||||
- The HTTP status code from the request (0 for exceptions)
|
||||
"""
|
||||
try:
|
||||
downloader = await get_downloader()
|
||||
url = f"{self.base_url}/models/{model_id}"
|
||||
|
||||
success, result = await downloader.make_request(
|
||||
'GET',
|
||||
url,
|
||||
use_auth=True
|
||||
)
|
||||
|
||||
if not success:
|
||||
# Try to extract status code from error message
|
||||
status_code = 0
|
||||
if "404" in str(result):
|
||||
status_code = 404
|
||||
elif "401" in str(result):
|
||||
status_code = 401
|
||||
elif "403" in str(result):
|
||||
status_code = 403
|
||||
logger.warning(f"Failed to fetch model metadata: {result}")
|
||||
return None, status_code
|
||||
|
||||
# Extract relevant metadata
|
||||
metadata = {
|
||||
"description": result.get("description") or "No model description available",
|
||||
"tags": result.get("tags", []),
|
||||
"creator": {
|
||||
"username": result.get("creator", {}).get("username"),
|
||||
"image": result.get("creator", {}).get("image")
|
||||
}
|
||||
}
|
||||
|
||||
if metadata["description"] or metadata["tags"] or metadata["creator"]["username"]:
|
||||
return metadata, 200
|
||||
else:
|
||||
logger.warning(f"No metadata found for model {model_id}")
|
||||
return None, 200
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching model metadata: {e}", exc_info=True)
|
||||
return None, 0
|
||||
|
||||
async def get_image_info(self, image_id: str) -> Optional[Dict]:
|
||||
"""Fetch image information from Civitai API
|
||||
|
||||
|
||||
@@ -2,7 +2,6 @@ from abc import ABC, abstractmethod
|
||||
import json
|
||||
import aiosqlite
|
||||
import logging
|
||||
import aiohttp
|
||||
from bs4 import BeautifulSoup
|
||||
from typing import Optional, Dict, Tuple
|
||||
from .downloader import get_downloader
|
||||
@@ -13,7 +12,7 @@ class ModelMetadataProvider(ABC):
|
||||
"""Base abstract class for all model metadata providers"""
|
||||
|
||||
@abstractmethod
|
||||
async def get_model_by_hash(self, model_hash: str) -> Optional[Dict]:
|
||||
async def get_model_by_hash(self, model_hash: str) -> Tuple[Optional[Dict], Optional[str]]:
|
||||
"""Find model by hash value"""
|
||||
pass
|
||||
|
||||
@@ -31,11 +30,6 @@ class ModelMetadataProvider(ABC):
|
||||
async def get_model_version_info(self, version_id: str) -> Tuple[Optional[Dict], Optional[str]]:
|
||||
"""Fetch model version metadata"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def get_model_metadata(self, model_id: str) -> Tuple[Optional[Dict], int]:
|
||||
"""Fetch model metadata (description, tags, and creator info)"""
|
||||
pass
|
||||
|
||||
class CivitaiModelMetadataProvider(ModelMetadataProvider):
|
||||
"""Provider that uses Civitai API for metadata"""
|
||||
@@ -43,7 +37,7 @@ class CivitaiModelMetadataProvider(ModelMetadataProvider):
|
||||
def __init__(self, civitai_client):
|
||||
self.client = civitai_client
|
||||
|
||||
async def get_model_by_hash(self, model_hash: str) -> Optional[Dict]:
|
||||
async def get_model_by_hash(self, model_hash: str) -> Tuple[Optional[Dict], Optional[str]]:
|
||||
return await self.client.get_model_by_hash(model_hash)
|
||||
|
||||
async def get_model_versions(self, model_id: str) -> Optional[Dict]:
|
||||
@@ -54,16 +48,13 @@ class CivitaiModelMetadataProvider(ModelMetadataProvider):
|
||||
|
||||
async def get_model_version_info(self, version_id: str) -> Tuple[Optional[Dict], Optional[str]]:
|
||||
return await self.client.get_model_version_info(version_id)
|
||||
|
||||
async def get_model_metadata(self, model_id: str) -> Tuple[Optional[Dict], int]:
|
||||
return await self.client.get_model_metadata(model_id)
|
||||
|
||||
class CivArchiveModelMetadataProvider(ModelMetadataProvider):
|
||||
"""Provider that uses CivArchive HTML page parsing for metadata"""
|
||||
|
||||
async def get_model_by_hash(self, model_hash: str) -> Optional[Dict]:
|
||||
async def get_model_by_hash(self, model_hash: str) -> Tuple[Optional[Dict], Optional[str]]:
|
||||
"""Not supported by CivArchive provider"""
|
||||
return None
|
||||
return None, "CivArchive provider does not support hash lookup"
|
||||
|
||||
async def get_model_versions(self, model_id: str) -> Optional[Dict]:
|
||||
"""Not supported by CivArchive provider"""
|
||||
@@ -174,10 +165,6 @@ class CivArchiveModelMetadataProvider(ModelMetadataProvider):
|
||||
async def get_model_version_info(self, version_id: str) -> Tuple[Optional[Dict], Optional[str]]:
|
||||
"""Not supported by CivArchive provider - requires both model_id and version_id"""
|
||||
return None, "CivArchive provider requires both model_id and version_id"
|
||||
|
||||
async def get_model_metadata(self, model_id: str) -> Tuple[Optional[Dict], int]:
|
||||
"""Not supported by CivArchive provider"""
|
||||
return None, 404
|
||||
|
||||
class SQLiteModelMetadataProvider(ModelMetadataProvider):
|
||||
"""Provider that uses SQLite database for metadata"""
|
||||
@@ -185,7 +172,7 @@ class SQLiteModelMetadataProvider(ModelMetadataProvider):
|
||||
def __init__(self, db_path: str):
|
||||
self.db_path = db_path
|
||||
|
||||
async def get_model_by_hash(self, model_hash: str) -> Optional[Dict]:
|
||||
async def get_model_by_hash(self, model_hash: str) -> Tuple[Optional[Dict], Optional[str]]:
|
||||
"""Find model by hash value from SQLite database"""
|
||||
async with aiosqlite.connect(self.db_path) as db:
|
||||
# Look up in model_files table to get model_id and version_id
|
||||
@@ -200,14 +187,15 @@ class SQLiteModelMetadataProvider(ModelMetadataProvider):
|
||||
file_row = await cursor.fetchone()
|
||||
|
||||
if not file_row:
|
||||
return None
|
||||
return None, "Model not found"
|
||||
|
||||
# Get version details
|
||||
model_id = file_row['model_id']
|
||||
version_id = file_row['version_id']
|
||||
|
||||
# Build response in the same format as Civitai API
|
||||
return await self._get_version_with_model_data(db, model_id, version_id)
|
||||
result = await self._get_version_with_model_data(db, model_id, version_id)
|
||||
return result, None if result else "Error retrieving model data"
|
||||
|
||||
async def get_model_versions(self, model_id: str) -> Optional[Dict]:
|
||||
"""Get all versions of a model from SQLite database"""
|
||||
@@ -324,37 +312,6 @@ class SQLiteModelMetadataProvider(ModelMetadataProvider):
|
||||
version_data = await self._get_version_with_model_data(db, model_id, version_id)
|
||||
return version_data, None
|
||||
|
||||
async def get_model_metadata(self, model_id: str) -> Tuple[Optional[Dict], int]:
|
||||
"""Fetch model metadata from SQLite database"""
|
||||
async with aiosqlite.connect(self.db_path) as db:
|
||||
db.row_factory = aiosqlite.Row
|
||||
|
||||
# Get model details
|
||||
model_query = "SELECT name, type, data, username FROM models WHERE id = ?"
|
||||
cursor = await db.execute(model_query, (model_id,))
|
||||
model_row = await cursor.fetchone()
|
||||
|
||||
if not model_row:
|
||||
return None, 404
|
||||
|
||||
# Parse data JSON
|
||||
try:
|
||||
model_data = json.loads(model_row['data'])
|
||||
|
||||
# Extract relevant metadata
|
||||
metadata = {
|
||||
"description": model_data.get("description", "No model description available"),
|
||||
"tags": model_data.get("tags", []),
|
||||
"creator": {
|
||||
"username": model_row['username'] or model_data.get("creator", {}).get("username"),
|
||||
"image": model_data.get("creator", {}).get("image")
|
||||
}
|
||||
}
|
||||
|
||||
return metadata, 200
|
||||
except json.JSONDecodeError:
|
||||
return None, 500
|
||||
|
||||
async def _get_version_with_model_data(self, db, model_id, version_id) -> Optional[Dict]:
|
||||
"""Helper to build version data with model information"""
|
||||
# Get version details
|
||||
@@ -409,15 +366,16 @@ class FallbackMetadataProvider(ModelMetadataProvider):
|
||||
def __init__(self, providers: list):
|
||||
self.providers = providers
|
||||
|
||||
async def get_model_by_hash(self, model_hash: str) -> Optional[Dict]:
|
||||
async def get_model_by_hash(self, model_hash: str) -> Tuple[Optional[Dict], Optional[str]]:
|
||||
for provider in self.providers:
|
||||
try:
|
||||
result = await provider.get_model_by_hash(model_hash)
|
||||
result, error = await provider.get_model_by_hash(model_hash)
|
||||
if result:
|
||||
return result
|
||||
except Exception:
|
||||
return result, error
|
||||
except Exception as e:
|
||||
logger.debug(f"Provider failed for get_model_by_hash: {e}")
|
||||
continue
|
||||
return None
|
||||
return None, "Model not found"
|
||||
|
||||
async def get_model_versions(self, model_id: str) -> Optional[Dict]:
|
||||
for provider in self.providers:
|
||||
@@ -452,17 +410,6 @@ class FallbackMetadataProvider(ModelMetadataProvider):
|
||||
continue
|
||||
return None, "No provider could retrieve the data"
|
||||
|
||||
async def get_model_metadata(self, model_id: str) -> Tuple[Optional[Dict], int]:
|
||||
for provider in self.providers:
|
||||
try:
|
||||
result, status = await provider.get_model_metadata(model_id)
|
||||
if result:
|
||||
return result, status
|
||||
except Exception as e:
|
||||
logger.debug(f"Provider failed for get_model_metadata: {e}")
|
||||
continue
|
||||
return None, 404
|
||||
|
||||
class ModelMetadataProviderManager:
|
||||
"""Manager for selecting and using model metadata providers"""
|
||||
|
||||
@@ -485,7 +432,7 @@ class ModelMetadataProviderManager:
|
||||
if is_default or self.default_provider is None:
|
||||
self.default_provider = name
|
||||
|
||||
async def get_model_by_hash(self, model_hash: str, provider_name: str = None) -> Optional[Dict]:
|
||||
async def get_model_by_hash(self, model_hash: str, provider_name: str = None) -> Tuple[Optional[Dict], Optional[str]]:
|
||||
"""Find model by hash using specified or default provider"""
|
||||
provider = self._get_provider(provider_name)
|
||||
return await provider.get_model_by_hash(model_hash)
|
||||
@@ -505,11 +452,6 @@ class ModelMetadataProviderManager:
|
||||
provider = self._get_provider(provider_name)
|
||||
return await provider.get_model_version_info(version_id)
|
||||
|
||||
async def get_model_metadata(self, model_id: str, provider_name: str = None) -> Tuple[Optional[Dict], int]:
|
||||
"""Fetch model metadata using specified or default provider"""
|
||||
provider = self._get_provider(provider_name)
|
||||
return await provider.get_model_metadata(model_id)
|
||||
|
||||
def _get_provider(self, provider_name: str = None) -> ModelMetadataProvider:
|
||||
"""Get provider by name or default provider"""
|
||||
if provider_name and provider_name in self.providers:
|
||||
|
||||
@@ -215,7 +215,7 @@ class ModelRouteUtils:
|
||||
else:
|
||||
metadata_provider = await get_default_metadata_provider()
|
||||
|
||||
civitai_metadata = await metadata_provider.get_model_by_hash(sha256)
|
||||
civitai_metadata, error = await metadata_provider.get_model_by_hash(sha256)
|
||||
if not civitai_metadata:
|
||||
# Mark as not from CivitAI if not found
|
||||
local_metadata['from_civitai'] = False
|
||||
@@ -387,10 +387,10 @@ class ModelRouteUtils:
|
||||
metadata_provider = await get_default_metadata_provider()
|
||||
|
||||
# Fetch and update metadata
|
||||
civitai_metadata = await metadata_provider.get_model_by_hash(local_metadata["sha256"])
|
||||
civitai_metadata, error = await metadata_provider.get_model_by_hash(local_metadata["sha256"])
|
||||
if not civitai_metadata:
|
||||
await ModelRouteUtils.handle_not_found_on_civitai(metadata_path, local_metadata)
|
||||
return web.json_response({"success": False, "error": "Not found on CivitAI"}, status=404)
|
||||
return web.json_response({"success": False, "error": error}, status=404)
|
||||
|
||||
await ModelRouteUtils.update_model_metadata(metadata_path, local_metadata, civitai_metadata, metadata_provider)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user