mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-25 15:15:44 -03:00
refactor(metadata): update model fetching methods to return error messages alongside results
This commit is contained in:
@@ -55,7 +55,7 @@ class RecipeMetadataParser(ABC):
|
|||||||
# Unpack the tuple to get the actual data
|
# Unpack the tuple to get the actual data
|
||||||
civitai_info, error_msg = civitai_info_tuple if isinstance(civitai_info_tuple, tuple) else (civitai_info_tuple, None)
|
civitai_info, error_msg = civitai_info_tuple if isinstance(civitai_info_tuple, tuple) else (civitai_info_tuple, None)
|
||||||
|
|
||||||
if not civitai_info or civitai_info.get("error") == "Model not found":
|
if not civitai_info or error_msg == "Model not found":
|
||||||
# Model not found or deleted
|
# Model not found or deleted
|
||||||
lora_entry['isDeleted'] = True
|
lora_entry['isDeleted'] = True
|
||||||
lora_entry['thumbnailUrl'] = '/loras_static/images/no-preview.png'
|
lora_entry['thumbnailUrl'] = '/loras_static/images/no-preview.png'
|
||||||
|
|||||||
@@ -91,7 +91,7 @@ class CivitaiApiMetadataParser(RecipeMetadataParser):
|
|||||||
result["base_model"] = metadata["baseModel"]
|
result["base_model"] = metadata["baseModel"]
|
||||||
elif "Model hash" in metadata and metadata_provider:
|
elif "Model hash" in metadata and metadata_provider:
|
||||||
model_hash = metadata["Model hash"]
|
model_hash = metadata["Model hash"]
|
||||||
model_info = await metadata_provider.get_model_by_hash(model_hash)
|
model_info, error = await metadata_provider.get_model_by_hash(model_hash)
|
||||||
if model_info:
|
if model_info:
|
||||||
result["base_model"] = model_info.get("baseModel", "")
|
result["base_model"] = model_info.get("baseModel", "")
|
||||||
elif "Model" in metadata and isinstance(metadata.get("resources"), list):
|
elif "Model" in metadata and isinstance(metadata.get("resources"), list):
|
||||||
@@ -100,7 +100,7 @@ class CivitaiApiMetadataParser(RecipeMetadataParser):
|
|||||||
if resource.get("type") == "model" and resource.get("name") == metadata.get("Model"):
|
if resource.get("type") == "model" and resource.get("name") == metadata.get("Model"):
|
||||||
# This is likely the checkpoint model
|
# This is likely the checkpoint model
|
||||||
if metadata_provider and resource.get("hash"):
|
if metadata_provider and resource.get("hash"):
|
||||||
model_info = await metadata_provider.get_model_by_hash(resource.get("hash"))
|
model_info, error = await metadata_provider.get_model_by_hash(resource.get("hash"))
|
||||||
if model_info:
|
if model_info:
|
||||||
result["base_model"] = model_info.get("baseModel", "")
|
result["base_model"] = model_info.get("baseModel", "")
|
||||||
|
|
||||||
@@ -201,11 +201,7 @@ class CivitaiApiMetadataParser(RecipeMetadataParser):
|
|||||||
if version_id and metadata_provider:
|
if version_id and metadata_provider:
|
||||||
try:
|
try:
|
||||||
# Use get_model_version_info instead of get_model_version
|
# Use get_model_version_info instead of get_model_version
|
||||||
civitai_info, error = await metadata_provider.get_model_version_info(version_id)
|
civitai_info = await metadata_provider.get_model_version_info(version_id)
|
||||||
|
|
||||||
if error:
|
|
||||||
logger.warning(f"Error getting model version info: {error}")
|
|
||||||
continue
|
|
||||||
|
|
||||||
populated_entry = await self.populate_lora_from_civitai(
|
populated_entry = await self.populate_lora_from_civitai(
|
||||||
lora_entry,
|
lora_entry,
|
||||||
@@ -267,26 +263,23 @@ class CivitaiApiMetadataParser(RecipeMetadataParser):
|
|||||||
if version_id and metadata_provider:
|
if version_id and metadata_provider:
|
||||||
try:
|
try:
|
||||||
# Use get_model_version_info with the version ID
|
# Use get_model_version_info with the version ID
|
||||||
civitai_info, error = await metadata_provider.get_model_version_info(version_id)
|
civitai_info = await metadata_provider.get_model_version_info(version_id)
|
||||||
|
|
||||||
if error:
|
populated_entry = await self.populate_lora_from_civitai(
|
||||||
logger.warning(f"Error getting model version info: {error}")
|
lora_entry,
|
||||||
else:
|
civitai_info,
|
||||||
populated_entry = await self.populate_lora_from_civitai(
|
recipe_scanner,
|
||||||
lora_entry,
|
base_model_counts
|
||||||
civitai_info,
|
)
|
||||||
recipe_scanner,
|
|
||||||
base_model_counts
|
|
||||||
)
|
|
||||||
|
|
||||||
if populated_entry is None:
|
if populated_entry is None:
|
||||||
continue # Skip invalid LoRA types
|
continue # Skip invalid LoRA types
|
||||||
|
|
||||||
lora_entry = populated_entry
|
lora_entry = populated_entry
|
||||||
|
|
||||||
# Track this LoRA for deduplication
|
# Track this LoRA for deduplication
|
||||||
if version_id:
|
if version_id:
|
||||||
added_loras[version_id] = len(result["loras"])
|
added_loras[version_id] = len(result["loras"])
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error fetching Civitai info for model ID {version_id}: {e}")
|
logger.error(f"Error fetching Civitai info for model ID {version_id}: {e}")
|
||||||
|
|
||||||
|
|||||||
@@ -624,7 +624,7 @@ class BaseModelRoutes(ABC):
|
|||||||
success = 0
|
success = 0
|
||||||
needs_resort = False
|
needs_resort = False
|
||||||
|
|
||||||
# Prepare models to process, only those without CivitAI data or missing tags, description, or creator
|
# Prepare models to process, only those without CivitAI data
|
||||||
enable_metadata_archive_db = settings.get('enable_metadata_archive_db', False)
|
enable_metadata_archive_db = settings.get('enable_metadata_archive_db', False)
|
||||||
to_process = [
|
to_process = [
|
||||||
model for model in cache.raw_data
|
model for model in cache.raw_data
|
||||||
@@ -633,9 +633,6 @@ class BaseModelRoutes(ABC):
|
|||||||
and (
|
and (
|
||||||
not model.get('civitai')
|
not model.get('civitai')
|
||||||
or not model['civitai'].get('id')
|
or not model['civitai'].get('id')
|
||||||
# or not model.get('tags') # Skipping tag cause it could be empty legitimately
|
|
||||||
# or not model.get('modelDescription')
|
|
||||||
# or not (model.get('civitai') and model['civitai'].get('creator'))
|
|
||||||
)
|
)
|
||||||
and (
|
and (
|
||||||
(enable_metadata_archive_db)
|
(enable_metadata_archive_db)
|
||||||
@@ -782,7 +779,13 @@ class BaseModelRoutes(ABC):
|
|||||||
try:
|
try:
|
||||||
hash = request.match_info.get('hash')
|
hash = request.match_info.get('hash')
|
||||||
metadata_provider = await get_default_metadata_provider()
|
metadata_provider = await get_default_metadata_provider()
|
||||||
model = await metadata_provider.get_model_by_hash(hash)
|
model, error = await metadata_provider.get_model_by_hash(hash)
|
||||||
|
if error:
|
||||||
|
logger.warning(f"Error getting model by hash: {error}")
|
||||||
|
return web.json_response({
|
||||||
|
"success": False,
|
||||||
|
"error": error
|
||||||
|
}, status=404)
|
||||||
return web.json_response(model)
|
return web.json_response(model)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error fetching model details by hash: {e}")
|
logger.error(f"Error fetching model details by hash: {e}")
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ import shutil
|
|||||||
import tempfile
|
import tempfile
|
||||||
from aiohttp import web
|
from aiohttp import web
|
||||||
from typing import Dict, List
|
from typing import Dict, List
|
||||||
from ..services.downloader import get_downloader, Downloader
|
from ..services.downloader import get_downloader
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -265,7 +265,7 @@ class UpdateRoutes:
|
|||||||
github_url = f"https://api.github.com/repos/{repo_owner}/{repo_name}/commits/main"
|
github_url = f"https://api.github.com/repos/{repo_owner}/{repo_name}/commits/main"
|
||||||
|
|
||||||
try:
|
try:
|
||||||
downloader = await Downloader.get_instance()
|
downloader = await get_downloader()
|
||||||
success, data = await downloader.make_request('GET', github_url, custom_headers={'Accept': 'application/vnd.github+json'})
|
success, data = await downloader.make_request('GET', github_url, custom_headers={'Accept': 'application/vnd.github+json'})
|
||||||
|
|
||||||
if not success:
|
if not success:
|
||||||
@@ -431,7 +431,7 @@ class UpdateRoutes:
|
|||||||
github_url = f"https://api.github.com/repos/{repo_owner}/{repo_name}/releases/latest"
|
github_url = f"https://api.github.com/repos/{repo_owner}/{repo_name}/releases/latest"
|
||||||
|
|
||||||
try:
|
try:
|
||||||
downloader = await Downloader.get_instance()
|
downloader = await get_downloader()
|
||||||
success, data = await downloader.make_request('GET', github_url, custom_headers={'Accept': 'application/vnd.github+json'})
|
success, data = await downloader.make_request('GET', github_url, custom_headers={'Accept': 'application/vnd.github+json'})
|
||||||
|
|
||||||
if not success:
|
if not success:
|
||||||
|
|||||||
@@ -1,4 +1,3 @@
|
|||||||
from datetime import datetime
|
|
||||||
import os
|
import os
|
||||||
import logging
|
import logging
|
||||||
import asyncio
|
import asyncio
|
||||||
@@ -59,17 +58,17 @@ class CivitaiClient:
|
|||||||
|
|
||||||
return success, result
|
return success, result
|
||||||
|
|
||||||
async def get_model_by_hash(self, model_hash: str) -> Optional[Dict]:
|
async def get_model_by_hash(self, model_hash: str) -> Tuple[Optional[Dict], Optional[str]]:
|
||||||
try:
|
try:
|
||||||
downloader = await get_downloader()
|
downloader = await get_downloader()
|
||||||
success, version = await downloader.make_request(
|
success, result = await downloader.make_request(
|
||||||
'GET',
|
'GET',
|
||||||
f"{self.base_url}/model-versions/by-hash/{model_hash}",
|
f"{self.base_url}/model-versions/by-hash/{model_hash}",
|
||||||
use_auth=True
|
use_auth=True
|
||||||
)
|
)
|
||||||
if success:
|
if success:
|
||||||
# Get model ID from version data
|
# Get model ID from version data
|
||||||
model_id = version.get('modelId')
|
model_id = result.get('modelId')
|
||||||
if model_id:
|
if model_id:
|
||||||
# Fetch additional model metadata
|
# Fetch additional model metadata
|
||||||
success_model, data = await downloader.make_request(
|
success_model, data = await downloader.make_request(
|
||||||
@@ -79,17 +78,24 @@ class CivitaiClient:
|
|||||||
)
|
)
|
||||||
if success_model:
|
if success_model:
|
||||||
# Enrich version_info with model data
|
# Enrich version_info with model data
|
||||||
version['model']['description'] = data.get("description")
|
result['model']['description'] = data.get("description")
|
||||||
version['model']['tags'] = data.get("tags", [])
|
result['model']['tags'] = data.get("tags", [])
|
||||||
|
|
||||||
# Add creator from model data
|
# Add creator from model data
|
||||||
version['creator'] = data.get("creator")
|
result['creator'] = data.get("creator")
|
||||||
|
|
||||||
return version
|
return result, None
|
||||||
return None
|
|
||||||
|
# Handle specific error cases
|
||||||
|
if "not found" in str(result):
|
||||||
|
return None, "Model not found"
|
||||||
|
|
||||||
|
# Other error cases
|
||||||
|
logger.error(f"Failed to fetch model info for {model_hash[:10]}: {result}")
|
||||||
|
return None, str(result)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"API Error: {str(e)}")
|
logger.error(f"API Error: {str(e)}")
|
||||||
return None
|
return None, str(e)
|
||||||
|
|
||||||
async def download_preview_image(self, image_url: str, save_path: str):
|
async def download_preview_image(self, image_url: str, save_path: str):
|
||||||
try:
|
try:
|
||||||
@@ -246,8 +252,8 @@ class CivitaiClient:
|
|||||||
return result, None
|
return result, None
|
||||||
|
|
||||||
# Handle specific error cases
|
# Handle specific error cases
|
||||||
if "404" in str(result):
|
if "not found" in str(result):
|
||||||
error_msg = f"Model not found (status 404)"
|
error_msg = f"Model not found"
|
||||||
logger.warning(f"Model version not found: {version_id} - {error_msg}")
|
logger.warning(f"Model version not found: {version_id} - {error_msg}")
|
||||||
return None, error_msg
|
return None, error_msg
|
||||||
|
|
||||||
@@ -259,59 +265,6 @@ class CivitaiClient:
|
|||||||
logger.error(error_msg)
|
logger.error(error_msg)
|
||||||
return None, error_msg
|
return None, error_msg
|
||||||
|
|
||||||
async def get_model_metadata(self, model_id: str) -> Tuple[Optional[Dict], int]:
|
|
||||||
"""Fetch model metadata (description, tags, and creator info) from Civitai API
|
|
||||||
|
|
||||||
Args:
|
|
||||||
model_id: The Civitai model ID
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tuple[Optional[Dict], int]: A tuple containing:
|
|
||||||
- A dictionary with model metadata or None if not found
|
|
||||||
- The HTTP status code from the request (0 for exceptions)
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
downloader = await get_downloader()
|
|
||||||
url = f"{self.base_url}/models/{model_id}"
|
|
||||||
|
|
||||||
success, result = await downloader.make_request(
|
|
||||||
'GET',
|
|
||||||
url,
|
|
||||||
use_auth=True
|
|
||||||
)
|
|
||||||
|
|
||||||
if not success:
|
|
||||||
# Try to extract status code from error message
|
|
||||||
status_code = 0
|
|
||||||
if "404" in str(result):
|
|
||||||
status_code = 404
|
|
||||||
elif "401" in str(result):
|
|
||||||
status_code = 401
|
|
||||||
elif "403" in str(result):
|
|
||||||
status_code = 403
|
|
||||||
logger.warning(f"Failed to fetch model metadata: {result}")
|
|
||||||
return None, status_code
|
|
||||||
|
|
||||||
# Extract relevant metadata
|
|
||||||
metadata = {
|
|
||||||
"description": result.get("description") or "No model description available",
|
|
||||||
"tags": result.get("tags", []),
|
|
||||||
"creator": {
|
|
||||||
"username": result.get("creator", {}).get("username"),
|
|
||||||
"image": result.get("creator", {}).get("image")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if metadata["description"] or metadata["tags"] or metadata["creator"]["username"]:
|
|
||||||
return metadata, 200
|
|
||||||
else:
|
|
||||||
logger.warning(f"No metadata found for model {model_id}")
|
|
||||||
return None, 200
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error fetching model metadata: {e}", exc_info=True)
|
|
||||||
return None, 0
|
|
||||||
|
|
||||||
async def get_image_info(self, image_id: str) -> Optional[Dict]:
|
async def get_image_info(self, image_id: str) -> Optional[Dict]:
|
||||||
"""Fetch image information from Civitai API
|
"""Fetch image information from Civitai API
|
||||||
|
|
||||||
|
|||||||
@@ -2,7 +2,6 @@ from abc import ABC, abstractmethod
|
|||||||
import json
|
import json
|
||||||
import aiosqlite
|
import aiosqlite
|
||||||
import logging
|
import logging
|
||||||
import aiohttp
|
|
||||||
from bs4 import BeautifulSoup
|
from bs4 import BeautifulSoup
|
||||||
from typing import Optional, Dict, Tuple
|
from typing import Optional, Dict, Tuple
|
||||||
from .downloader import get_downloader
|
from .downloader import get_downloader
|
||||||
@@ -13,7 +12,7 @@ class ModelMetadataProvider(ABC):
|
|||||||
"""Base abstract class for all model metadata providers"""
|
"""Base abstract class for all model metadata providers"""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def get_model_by_hash(self, model_hash: str) -> Optional[Dict]:
|
async def get_model_by_hash(self, model_hash: str) -> Tuple[Optional[Dict], Optional[str]]:
|
||||||
"""Find model by hash value"""
|
"""Find model by hash value"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@@ -32,18 +31,13 @@ class ModelMetadataProvider(ABC):
|
|||||||
"""Fetch model version metadata"""
|
"""Fetch model version metadata"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
async def get_model_metadata(self, model_id: str) -> Tuple[Optional[Dict], int]:
|
|
||||||
"""Fetch model metadata (description, tags, and creator info)"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
class CivitaiModelMetadataProvider(ModelMetadataProvider):
|
class CivitaiModelMetadataProvider(ModelMetadataProvider):
|
||||||
"""Provider that uses Civitai API for metadata"""
|
"""Provider that uses Civitai API for metadata"""
|
||||||
|
|
||||||
def __init__(self, civitai_client):
|
def __init__(self, civitai_client):
|
||||||
self.client = civitai_client
|
self.client = civitai_client
|
||||||
|
|
||||||
async def get_model_by_hash(self, model_hash: str) -> Optional[Dict]:
|
async def get_model_by_hash(self, model_hash: str) -> Tuple[Optional[Dict], Optional[str]]:
|
||||||
return await self.client.get_model_by_hash(model_hash)
|
return await self.client.get_model_by_hash(model_hash)
|
||||||
|
|
||||||
async def get_model_versions(self, model_id: str) -> Optional[Dict]:
|
async def get_model_versions(self, model_id: str) -> Optional[Dict]:
|
||||||
@@ -55,15 +49,12 @@ class CivitaiModelMetadataProvider(ModelMetadataProvider):
|
|||||||
async def get_model_version_info(self, version_id: str) -> Tuple[Optional[Dict], Optional[str]]:
|
async def get_model_version_info(self, version_id: str) -> Tuple[Optional[Dict], Optional[str]]:
|
||||||
return await self.client.get_model_version_info(version_id)
|
return await self.client.get_model_version_info(version_id)
|
||||||
|
|
||||||
async def get_model_metadata(self, model_id: str) -> Tuple[Optional[Dict], int]:
|
|
||||||
return await self.client.get_model_metadata(model_id)
|
|
||||||
|
|
||||||
class CivArchiveModelMetadataProvider(ModelMetadataProvider):
|
class CivArchiveModelMetadataProvider(ModelMetadataProvider):
|
||||||
"""Provider that uses CivArchive HTML page parsing for metadata"""
|
"""Provider that uses CivArchive HTML page parsing for metadata"""
|
||||||
|
|
||||||
async def get_model_by_hash(self, model_hash: str) -> Optional[Dict]:
|
async def get_model_by_hash(self, model_hash: str) -> Tuple[Optional[Dict], Optional[str]]:
|
||||||
"""Not supported by CivArchive provider"""
|
"""Not supported by CivArchive provider"""
|
||||||
return None
|
return None, "CivArchive provider does not support hash lookup"
|
||||||
|
|
||||||
async def get_model_versions(self, model_id: str) -> Optional[Dict]:
|
async def get_model_versions(self, model_id: str) -> Optional[Dict]:
|
||||||
"""Not supported by CivArchive provider"""
|
"""Not supported by CivArchive provider"""
|
||||||
@@ -175,17 +166,13 @@ class CivArchiveModelMetadataProvider(ModelMetadataProvider):
|
|||||||
"""Not supported by CivArchive provider - requires both model_id and version_id"""
|
"""Not supported by CivArchive provider - requires both model_id and version_id"""
|
||||||
return None, "CivArchive provider requires both model_id and version_id"
|
return None, "CivArchive provider requires both model_id and version_id"
|
||||||
|
|
||||||
async def get_model_metadata(self, model_id: str) -> Tuple[Optional[Dict], int]:
|
|
||||||
"""Not supported by CivArchive provider"""
|
|
||||||
return None, 404
|
|
||||||
|
|
||||||
class SQLiteModelMetadataProvider(ModelMetadataProvider):
|
class SQLiteModelMetadataProvider(ModelMetadataProvider):
|
||||||
"""Provider that uses SQLite database for metadata"""
|
"""Provider that uses SQLite database for metadata"""
|
||||||
|
|
||||||
def __init__(self, db_path: str):
|
def __init__(self, db_path: str):
|
||||||
self.db_path = db_path
|
self.db_path = db_path
|
||||||
|
|
||||||
async def get_model_by_hash(self, model_hash: str) -> Optional[Dict]:
|
async def get_model_by_hash(self, model_hash: str) -> Tuple[Optional[Dict], Optional[str]]:
|
||||||
"""Find model by hash value from SQLite database"""
|
"""Find model by hash value from SQLite database"""
|
||||||
async with aiosqlite.connect(self.db_path) as db:
|
async with aiosqlite.connect(self.db_path) as db:
|
||||||
# Look up in model_files table to get model_id and version_id
|
# Look up in model_files table to get model_id and version_id
|
||||||
@@ -200,14 +187,15 @@ class SQLiteModelMetadataProvider(ModelMetadataProvider):
|
|||||||
file_row = await cursor.fetchone()
|
file_row = await cursor.fetchone()
|
||||||
|
|
||||||
if not file_row:
|
if not file_row:
|
||||||
return None
|
return None, "Model not found"
|
||||||
|
|
||||||
# Get version details
|
# Get version details
|
||||||
model_id = file_row['model_id']
|
model_id = file_row['model_id']
|
||||||
version_id = file_row['version_id']
|
version_id = file_row['version_id']
|
||||||
|
|
||||||
# Build response in the same format as Civitai API
|
# Build response in the same format as Civitai API
|
||||||
return await self._get_version_with_model_data(db, model_id, version_id)
|
result = await self._get_version_with_model_data(db, model_id, version_id)
|
||||||
|
return result, None if result else "Error retrieving model data"
|
||||||
|
|
||||||
async def get_model_versions(self, model_id: str) -> Optional[Dict]:
|
async def get_model_versions(self, model_id: str) -> Optional[Dict]:
|
||||||
"""Get all versions of a model from SQLite database"""
|
"""Get all versions of a model from SQLite database"""
|
||||||
@@ -324,37 +312,6 @@ class SQLiteModelMetadataProvider(ModelMetadataProvider):
|
|||||||
version_data = await self._get_version_with_model_data(db, model_id, version_id)
|
version_data = await self._get_version_with_model_data(db, model_id, version_id)
|
||||||
return version_data, None
|
return version_data, None
|
||||||
|
|
||||||
async def get_model_metadata(self, model_id: str) -> Tuple[Optional[Dict], int]:
|
|
||||||
"""Fetch model metadata from SQLite database"""
|
|
||||||
async with aiosqlite.connect(self.db_path) as db:
|
|
||||||
db.row_factory = aiosqlite.Row
|
|
||||||
|
|
||||||
# Get model details
|
|
||||||
model_query = "SELECT name, type, data, username FROM models WHERE id = ?"
|
|
||||||
cursor = await db.execute(model_query, (model_id,))
|
|
||||||
model_row = await cursor.fetchone()
|
|
||||||
|
|
||||||
if not model_row:
|
|
||||||
return None, 404
|
|
||||||
|
|
||||||
# Parse data JSON
|
|
||||||
try:
|
|
||||||
model_data = json.loads(model_row['data'])
|
|
||||||
|
|
||||||
# Extract relevant metadata
|
|
||||||
metadata = {
|
|
||||||
"description": model_data.get("description", "No model description available"),
|
|
||||||
"tags": model_data.get("tags", []),
|
|
||||||
"creator": {
|
|
||||||
"username": model_row['username'] or model_data.get("creator", {}).get("username"),
|
|
||||||
"image": model_data.get("creator", {}).get("image")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return metadata, 200
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
return None, 500
|
|
||||||
|
|
||||||
async def _get_version_with_model_data(self, db, model_id, version_id) -> Optional[Dict]:
|
async def _get_version_with_model_data(self, db, model_id, version_id) -> Optional[Dict]:
|
||||||
"""Helper to build version data with model information"""
|
"""Helper to build version data with model information"""
|
||||||
# Get version details
|
# Get version details
|
||||||
@@ -409,15 +366,16 @@ class FallbackMetadataProvider(ModelMetadataProvider):
|
|||||||
def __init__(self, providers: list):
|
def __init__(self, providers: list):
|
||||||
self.providers = providers
|
self.providers = providers
|
||||||
|
|
||||||
async def get_model_by_hash(self, model_hash: str) -> Optional[Dict]:
|
async def get_model_by_hash(self, model_hash: str) -> Tuple[Optional[Dict], Optional[str]]:
|
||||||
for provider in self.providers:
|
for provider in self.providers:
|
||||||
try:
|
try:
|
||||||
result = await provider.get_model_by_hash(model_hash)
|
result, error = await provider.get_model_by_hash(model_hash)
|
||||||
if result:
|
if result:
|
||||||
return result
|
return result, error
|
||||||
except Exception:
|
except Exception as e:
|
||||||
|
logger.debug(f"Provider failed for get_model_by_hash: {e}")
|
||||||
continue
|
continue
|
||||||
return None
|
return None, "Model not found"
|
||||||
|
|
||||||
async def get_model_versions(self, model_id: str) -> Optional[Dict]:
|
async def get_model_versions(self, model_id: str) -> Optional[Dict]:
|
||||||
for provider in self.providers:
|
for provider in self.providers:
|
||||||
@@ -452,17 +410,6 @@ class FallbackMetadataProvider(ModelMetadataProvider):
|
|||||||
continue
|
continue
|
||||||
return None, "No provider could retrieve the data"
|
return None, "No provider could retrieve the data"
|
||||||
|
|
||||||
async def get_model_metadata(self, model_id: str) -> Tuple[Optional[Dict], int]:
|
|
||||||
for provider in self.providers:
|
|
||||||
try:
|
|
||||||
result, status = await provider.get_model_metadata(model_id)
|
|
||||||
if result:
|
|
||||||
return result, status
|
|
||||||
except Exception as e:
|
|
||||||
logger.debug(f"Provider failed for get_model_metadata: {e}")
|
|
||||||
continue
|
|
||||||
return None, 404
|
|
||||||
|
|
||||||
class ModelMetadataProviderManager:
|
class ModelMetadataProviderManager:
|
||||||
"""Manager for selecting and using model metadata providers"""
|
"""Manager for selecting and using model metadata providers"""
|
||||||
|
|
||||||
@@ -485,7 +432,7 @@ class ModelMetadataProviderManager:
|
|||||||
if is_default or self.default_provider is None:
|
if is_default or self.default_provider is None:
|
||||||
self.default_provider = name
|
self.default_provider = name
|
||||||
|
|
||||||
async def get_model_by_hash(self, model_hash: str, provider_name: str = None) -> Optional[Dict]:
|
async def get_model_by_hash(self, model_hash: str, provider_name: str = None) -> Tuple[Optional[Dict], Optional[str]]:
|
||||||
"""Find model by hash using specified or default provider"""
|
"""Find model by hash using specified or default provider"""
|
||||||
provider = self._get_provider(provider_name)
|
provider = self._get_provider(provider_name)
|
||||||
return await provider.get_model_by_hash(model_hash)
|
return await provider.get_model_by_hash(model_hash)
|
||||||
@@ -505,11 +452,6 @@ class ModelMetadataProviderManager:
|
|||||||
provider = self._get_provider(provider_name)
|
provider = self._get_provider(provider_name)
|
||||||
return await provider.get_model_version_info(version_id)
|
return await provider.get_model_version_info(version_id)
|
||||||
|
|
||||||
async def get_model_metadata(self, model_id: str, provider_name: str = None) -> Tuple[Optional[Dict], int]:
|
|
||||||
"""Fetch model metadata using specified or default provider"""
|
|
||||||
provider = self._get_provider(provider_name)
|
|
||||||
return await provider.get_model_metadata(model_id)
|
|
||||||
|
|
||||||
def _get_provider(self, provider_name: str = None) -> ModelMetadataProvider:
|
def _get_provider(self, provider_name: str = None) -> ModelMetadataProvider:
|
||||||
"""Get provider by name or default provider"""
|
"""Get provider by name or default provider"""
|
||||||
if provider_name and provider_name in self.providers:
|
if provider_name and provider_name in self.providers:
|
||||||
|
|||||||
@@ -215,7 +215,7 @@ class ModelRouteUtils:
|
|||||||
else:
|
else:
|
||||||
metadata_provider = await get_default_metadata_provider()
|
metadata_provider = await get_default_metadata_provider()
|
||||||
|
|
||||||
civitai_metadata = await metadata_provider.get_model_by_hash(sha256)
|
civitai_metadata, error = await metadata_provider.get_model_by_hash(sha256)
|
||||||
if not civitai_metadata:
|
if not civitai_metadata:
|
||||||
# Mark as not from CivitAI if not found
|
# Mark as not from CivitAI if not found
|
||||||
local_metadata['from_civitai'] = False
|
local_metadata['from_civitai'] = False
|
||||||
@@ -387,10 +387,10 @@ class ModelRouteUtils:
|
|||||||
metadata_provider = await get_default_metadata_provider()
|
metadata_provider = await get_default_metadata_provider()
|
||||||
|
|
||||||
# Fetch and update metadata
|
# Fetch and update metadata
|
||||||
civitai_metadata = await metadata_provider.get_model_by_hash(local_metadata["sha256"])
|
civitai_metadata, error = await metadata_provider.get_model_by_hash(local_metadata["sha256"])
|
||||||
if not civitai_metadata:
|
if not civitai_metadata:
|
||||||
await ModelRouteUtils.handle_not_found_on_civitai(metadata_path, local_metadata)
|
await ModelRouteUtils.handle_not_found_on_civitai(metadata_path, local_metadata)
|
||||||
return web.json_response({"success": False, "error": "Not found on CivitAI"}, status=404)
|
return web.json_response({"success": False, "error": error}, status=404)
|
||||||
|
|
||||||
await ModelRouteUtils.update_model_metadata(metadata_path, local_metadata, civitai_metadata, metadata_provider)
|
await ModelRouteUtils.update_model_metadata(metadata_path, local_metadata, civitai_metadata, metadata_provider)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user