mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-24 22:52:12 -03:00
feat(routes): integrate CivitAI model version retrieval for various model types
This commit is contained in:
@@ -14,6 +14,7 @@ from ..services.settings_manager import settings
|
||||
from ..services.server_i18n import server_i18n
|
||||
from ..services.model_file_service import ModelFileService, ModelMoveService
|
||||
from ..services.websocket_progress_callback import WebSocketProgressCallback
|
||||
from ..services.metadata_service import get_default_metadata_provider
|
||||
from ..config import config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -84,14 +85,17 @@ class BaseModelRoutes(ABC):
|
||||
# Autocomplete route
|
||||
app.router.add_get(f'/api/{prefix}/relative-paths', self.get_relative_paths)
|
||||
|
||||
# Common CivitAI integration
|
||||
app.router.add_get(f'/api/{prefix}/civitai/versions/{{model_id}}', self.get_civitai_versions)
|
||||
app.router.add_get(f'/api/{prefix}/civitai/model/version/{{modelVersionId}}', self.get_civitai_model_by_version)
|
||||
app.router.add_get(f'/api/{prefix}/civitai/model/hash/{{hash}}', self.get_civitai_model_by_hash)
|
||||
|
||||
# Common Download management
|
||||
app.router.add_post(f'/api/download-model', self.download_model)
|
||||
app.router.add_get(f'/api/download-model-get', self.download_model_get)
|
||||
app.router.add_get(f'/api/cancel-download-get', self.cancel_download_get)
|
||||
app.router.add_get(f'/api/download-progress/{{download_id}}', self.get_download_progress)
|
||||
|
||||
# app.router.add_get(f'/api/civitai/versions/{{model_id}}', self.get_civitai_versions)
|
||||
|
||||
# Add generic page route
|
||||
app.router.add_get(f'/{prefix}', self.handle_models_page)
|
||||
|
||||
@@ -704,10 +708,101 @@ class BaseModelRoutes(ABC):
|
||||
|
||||
async def get_civitai_versions(self, request: web.Request) -> web.Response:
|
||||
"""Get available versions for a Civitai model with local availability info"""
|
||||
# This will be implemented by subclasses as they need CivitAI client access
|
||||
return web.json_response({
|
||||
"error": "Not implemented in base class"
|
||||
}, status=501)
|
||||
try:
|
||||
model_id = request.match_info['model_id']
|
||||
metadata_provider = await get_default_metadata_provider()
|
||||
response = await metadata_provider.get_model_versions(model_id)
|
||||
if not response or not response.get('modelVersions'):
|
||||
return web.Response(status=404, text="Model not found")
|
||||
|
||||
versions = response.get('modelVersions', [])
|
||||
model_type = response.get('type', '')
|
||||
|
||||
# Check model type - allow subclasses to override validation
|
||||
if not self._validate_civitai_model_type(model_type):
|
||||
return web.json_response({
|
||||
'error': f"Model type mismatch. Expected {self._get_expected_model_types()}, got {model_type}"
|
||||
}, status=400)
|
||||
|
||||
# Check local availability for each version
|
||||
for version in versions:
|
||||
# Find the model file (type="Model" and primary=true) in the files list
|
||||
model_file = self._find_model_file(version.get('files', []))
|
||||
|
||||
if model_file:
|
||||
sha256 = model_file.get('hashes', {}).get('SHA256')
|
||||
if sha256:
|
||||
# Set existsLocally and localPath at the version level
|
||||
version['existsLocally'] = self.service.has_hash(sha256)
|
||||
if version['existsLocally']:
|
||||
version['localPath'] = self.service.get_path_by_hash(sha256)
|
||||
|
||||
# Also set the model file size at the version level for easier access
|
||||
version['modelSizeKB'] = model_file.get('sizeKB')
|
||||
else:
|
||||
# No model file found in this version
|
||||
version['existsLocally'] = False
|
||||
|
||||
return web.json_response(versions)
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching {self.model_type} model versions: {e}")
|
||||
return web.Response(status=500, text=str(e))
|
||||
|
||||
async def get_civitai_model_by_version(self, request: web.Request) -> web.Response:
|
||||
"""Get CivitAI model details by model version ID"""
|
||||
try:
|
||||
model_version_id = request.match_info.get('modelVersionId')
|
||||
|
||||
# Get model details from metadata provider
|
||||
metadata_provider = await get_default_metadata_provider()
|
||||
model, error_msg = await metadata_provider.get_model_version_info(model_version_id)
|
||||
|
||||
if not model:
|
||||
# Log warning for failed model retrieval
|
||||
logger.warning(f"Failed to fetch model version {model_version_id}: {error_msg}")
|
||||
|
||||
# Determine status code based on error message
|
||||
status_code = 404 if error_msg and "not found" in error_msg.lower() else 500
|
||||
|
||||
return web.json_response({
|
||||
"success": False,
|
||||
"error": error_msg or "Failed to fetch model information"
|
||||
}, status=status_code)
|
||||
|
||||
return web.json_response(model)
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching model details: {e}")
|
||||
return web.json_response({
|
||||
"success": False,
|
||||
"error": str(e)
|
||||
}, status=500)
|
||||
|
||||
async def get_civitai_model_by_hash(self, request: web.Request) -> web.Response:
|
||||
"""Get CivitAI model details by hash"""
|
||||
try:
|
||||
hash = request.match_info.get('hash')
|
||||
metadata_provider = await get_default_metadata_provider()
|
||||
model = await metadata_provider.get_model_by_hash(hash)
|
||||
return web.json_response(model)
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching model details by hash: {e}")
|
||||
return web.json_response({
|
||||
"success": False,
|
||||
"error": str(e)
|
||||
}, status=500)
|
||||
|
||||
def _validate_civitai_model_type(self, model_type: str) -> bool:
|
||||
"""Validate CivitAI model type - to be overridden by subclasses"""
|
||||
return True # Default: accept all types
|
||||
|
||||
def _get_expected_model_types(self) -> str:
|
||||
"""Get expected model types string for error messages - to be overridden by subclasses"""
|
||||
return "any model type"
|
||||
|
||||
def _find_model_file(self, files: list) -> dict:
|
||||
"""Find the appropriate model file from the files list - can be overridden by subclasses"""
|
||||
# Find the primary model file (type="Model" and primary=true) in the files list
|
||||
return next((file for file in files if file.get('type') == 'Model' and file.get('primary') == True), None)
|
||||
|
||||
# Common model move handlers
|
||||
async def move_model(self, request: web.Request) -> web.Response:
|
||||
|
||||
Reference in New Issue
Block a user