From 933e2fc01d14f047f9aaa7e024627410bbfd2ee2 Mon Sep 17 00:00:00 2001 From: Will Miao <13051207myq@gmail.com> Date: Wed, 17 Sep 2025 15:47:30 +0800 Subject: [PATCH] feat(routes): integrate CivitAI model version retrieval for various model types --- py/routes/base_model_routes.py | 107 +++++++++++++++++++++++++++++++-- py/routes/checkpoint_routes.py | 59 +++--------------- py/routes/embedding_routes.py | 59 +++--------------- py/routes/lora_routes.py | 102 +++---------------------------- 4 files changed, 126 insertions(+), 201 deletions(-) diff --git a/py/routes/base_model_routes.py b/py/routes/base_model_routes.py index aea644e3..499d4617 100644 --- a/py/routes/base_model_routes.py +++ b/py/routes/base_model_routes.py @@ -14,6 +14,7 @@ from ..services.settings_manager import settings from ..services.server_i18n import server_i18n from ..services.model_file_service import ModelFileService, ModelMoveService from ..services.websocket_progress_callback import WebSocketProgressCallback +from ..services.metadata_service import get_default_metadata_provider from ..config import config logger = logging.getLogger(__name__) @@ -84,14 +85,17 @@ class BaseModelRoutes(ABC): # Autocomplete route app.router.add_get(f'/api/{prefix}/relative-paths', self.get_relative_paths) + # Common CivitAI integration + app.router.add_get(f'/api/{prefix}/civitai/versions/{{model_id}}', self.get_civitai_versions) + app.router.add_get(f'/api/{prefix}/civitai/model/version/{{modelVersionId}}', self.get_civitai_model_by_version) + app.router.add_get(f'/api/{prefix}/civitai/model/hash/{{hash}}', self.get_civitai_model_by_hash) + # Common Download management app.router.add_post(f'/api/download-model', self.download_model) app.router.add_get(f'/api/download-model-get', self.download_model_get) app.router.add_get(f'/api/cancel-download-get', self.cancel_download_get) app.router.add_get(f'/api/download-progress/{{download_id}}', self.get_download_progress) - # app.router.add_get(f'/api/civitai/versions/{{model_id}}', self.get_civitai_versions) - # Add generic page route app.router.add_get(f'/{prefix}', self.handle_models_page) @@ -704,10 +708,101 @@ class BaseModelRoutes(ABC): async def get_civitai_versions(self, request: web.Request) -> web.Response: """Get available versions for a Civitai model with local availability info""" - # This will be implemented by subclasses as they need CivitAI client access - return web.json_response({ - "error": "Not implemented in base class" - }, status=501) + try: + model_id = request.match_info['model_id'] + metadata_provider = await get_default_metadata_provider() + response = await metadata_provider.get_model_versions(model_id) + if not response or not response.get('modelVersions'): + return web.Response(status=404, text="Model not found") + + versions = response.get('modelVersions', []) + model_type = response.get('type', '') + + # Check model type - allow subclasses to override validation + if not self._validate_civitai_model_type(model_type): + return web.json_response({ + 'error': f"Model type mismatch. Expected {self._get_expected_model_types()}, got {model_type}" + }, status=400) + + # Check local availability for each version + for version in versions: + # Find the model file (type="Model" and primary=true) in the files list + model_file = self._find_model_file(version.get('files', [])) + + if model_file: + sha256 = model_file.get('hashes', {}).get('SHA256') + if sha256: + # Set existsLocally and localPath at the version level + version['existsLocally'] = self.service.has_hash(sha256) + if version['existsLocally']: + version['localPath'] = self.service.get_path_by_hash(sha256) + + # Also set the model file size at the version level for easier access + version['modelSizeKB'] = model_file.get('sizeKB') + else: + # No model file found in this version + version['existsLocally'] = False + + return web.json_response(versions) + except Exception as e: + logger.error(f"Error fetching {self.model_type} model versions: {e}") + return web.Response(status=500, text=str(e)) + + async def get_civitai_model_by_version(self, request: web.Request) -> web.Response: + """Get CivitAI model details by model version ID""" + try: + model_version_id = request.match_info.get('modelVersionId') + + # Get model details from metadata provider + metadata_provider = await get_default_metadata_provider() + model, error_msg = await metadata_provider.get_model_version_info(model_version_id) + + if not model: + # Log warning for failed model retrieval + logger.warning(f"Failed to fetch model version {model_version_id}: {error_msg}") + + # Determine status code based on error message + status_code = 404 if error_msg and "not found" in error_msg.lower() else 500 + + return web.json_response({ + "success": False, + "error": error_msg or "Failed to fetch model information" + }, status=status_code) + + return web.json_response(model) + except Exception as e: + logger.error(f"Error fetching model details: {e}") + return web.json_response({ + "success": False, + "error": str(e) + }, status=500) + + async def get_civitai_model_by_hash(self, request: web.Request) -> web.Response: + """Get CivitAI model details by hash""" + try: + hash = request.match_info.get('hash') + metadata_provider = await get_default_metadata_provider() + model = await metadata_provider.get_model_by_hash(hash) + return web.json_response(model) + except Exception as e: + logger.error(f"Error fetching model details by hash: {e}") + return web.json_response({ + "success": False, + "error": str(e) + }, status=500) + + def _validate_civitai_model_type(self, model_type: str) -> bool: + """Validate CivitAI model type - to be overridden by subclasses""" + return True # Default: accept all types + + def _get_expected_model_types(self) -> str: + """Get expected model types string for error messages - to be overridden by subclasses""" + return "any model type" + + def _find_model_file(self, files: list) -> dict: + """Find the appropriate model file from the files list - can be overridden by subclasses""" + # Find the primary model file (type="Model" and primary=true) in the files list + return next((file for file in files if file.get('type') == 'Model' and file.get('primary') == True), None) # Common model move handlers async def move_model(self, request: web.Request) -> web.Response: diff --git a/py/routes/checkpoint_routes.py b/py/routes/checkpoint_routes.py index a0f6a027..712eaafc 100644 --- a/py/routes/checkpoint_routes.py +++ b/py/routes/checkpoint_routes.py @@ -36,9 +36,6 @@ class CheckpointRoutes(BaseModelRoutes): def setup_specific_routes(self, app: web.Application, prefix: str): """Setup Checkpoint-specific routes""" - # Checkpoint-specific CivitAI integration - app.router.add_get(f'/api/{prefix}/civitai/versions/{{model_id}}', self.get_civitai_versions_checkpoint) - # Checkpoint info by name app.router.add_get(f'/api/{prefix}/info/{{name}}', self.get_checkpoint_info) @@ -46,6 +43,14 @@ class CheckpointRoutes(BaseModelRoutes): app.router.add_get(f'/api/{prefix}/checkpoints_roots', self.get_checkpoints_roots) app.router.add_get(f'/api/{prefix}/unet_roots', self.get_unet_roots) + def _validate_civitai_model_type(self, model_type: str) -> bool: + """Validate CivitAI model type for Checkpoint""" + return model_type.lower() == 'checkpoint' + + def _get_expected_model_types(self) -> str: + """Get expected model types string for error messages""" + return "Checkpoint" + async def get_checkpoint_info(self, request: web.Request) -> web.Response: """Get detailed information for a specific checkpoint by name""" try: @@ -61,54 +66,6 @@ class CheckpointRoutes(BaseModelRoutes): logger.error(f"Error in get_checkpoint_info: {e}", exc_info=True) return web.json_response({"error": str(e)}, status=500) - async def get_civitai_versions_checkpoint(self, request: web.Request) -> web.Response: - """Get available versions for a Civitai checkpoint model with local availability info""" - try: - model_id = request.match_info['model_id'] - metadata_provider = await get_default_metadata_provider() - response = await metadata_provider.get_model_versions(model_id) - if not response or not response.get('modelVersions'): - return web.Response(status=404, text="Model not found") - - versions = response.get('modelVersions', []) - model_type = response.get('type', '') - - # Check model type - should be Checkpoint - if model_type.lower() != 'checkpoint': - return web.json_response({ - 'error': f"Model type mismatch. Expected Checkpoint, got {model_type}" - }, status=400) - - # Check local availability for each version - for version in versions: - # Find the primary model file (type="Model" and primary=true) in the files list - model_file = next((file for file in version.get('files', []) - if file.get('type') == 'Model' and file.get('primary') == True), None) - - # If no primary file found, try to find any model file - if not model_file: - model_file = next((file for file in version.get('files', []) - if file.get('type') == 'Model'), None) - - if model_file: - sha256 = model_file.get('hashes', {}).get('SHA256') - if sha256: - # Set existsLocally and localPath at the version level - version['existsLocally'] = self.service.has_hash(sha256) - if version['existsLocally']: - version['localPath'] = self.service.get_path_by_hash(sha256) - - # Also set the model file size at the version level for easier access - version['modelSizeKB'] = model_file.get('sizeKB') - else: - # No model file found in this version - version['existsLocally'] = False - - return web.json_response(versions) - except Exception as e: - logger.error(f"Error fetching checkpoint model versions: {e}") - return web.Response(status=500, text=str(e)) - async def get_checkpoints_roots(self, request: web.Request) -> web.Response: """Return the list of checkpoint roots from config""" try: diff --git a/py/routes/embedding_routes.py b/py/routes/embedding_routes.py index ab028666..70b5b26b 100644 --- a/py/routes/embedding_routes.py +++ b/py/routes/embedding_routes.py @@ -35,12 +35,17 @@ class EmbeddingRoutes(BaseModelRoutes): def setup_specific_routes(self, app: web.Application, prefix: str): """Setup Embedding-specific routes""" - # Embedding-specific CivitAI integration - app.router.add_get(f'/api/{prefix}/civitai/versions/{{model_id}}', self.get_civitai_versions_embedding) - # Embedding info by name app.router.add_get(f'/api/{prefix}/info/{{name}}', self.get_embedding_info) + def _validate_civitai_model_type(self, model_type: str) -> bool: + """Validate CivitAI model type for Embedding""" + return model_type.lower() in ['textualinversion', 'embedding'] + + def _get_expected_model_types(self) -> str: + """Get expected model types string for error messages""" + return "TextualInversion/Embedding" + async def get_embedding_info(self, request: web.Request) -> web.Response: """Get detailed information for a specific embedding by name""" try: @@ -55,51 +60,3 @@ class EmbeddingRoutes(BaseModelRoutes): except Exception as e: logger.error(f"Error in get_embedding_info: {e}", exc_info=True) return web.json_response({"error": str(e)}, status=500) - - async def get_civitai_versions_embedding(self, request: web.Request) -> web.Response: - """Get available versions for a Civitai embedding model with local availability info""" - try: - model_id = request.match_info['model_id'] - metadata_provider = await get_default_metadata_provider() - response = await metadata_provider.get_model_versions(model_id) - if not response or not response.get('modelVersions'): - return web.Response(status=404, text="Model not found") - - versions = response.get('modelVersions', []) - model_type = response.get('type', '') - - # Check model type - should be TextualInversion (Embedding) - if model_type.lower() not in ['textualinversion', 'embedding']: - return web.json_response({ - 'error': f"Model type mismatch. Expected TextualInversion/Embedding, got {model_type}" - }, status=400) - - # Check local availability for each version - for version in versions: - # Find the primary model file (type="Model" and primary=true) in the files list - model_file = next((file for file in version.get('files', []) - if file.get('type') == 'Model' and file.get('primary') == True), None) - - # If no primary file found, try to find any model file - if not model_file: - model_file = next((file for file in version.get('files', []) - if file.get('type') == 'Model'), None) - - if model_file: - sha256 = model_file.get('hashes', {}).get('SHA256') - if sha256: - # Set existsLocally and localPath at the version level - version['existsLocally'] = self.service.has_hash(sha256) - if version['existsLocally']: - version['localPath'] = self.service.get_path_by_hash(sha256) - - # Also set the model file size at the version level for easier access - version['modelSizeKB'] = model_file.get('sizeKB') - else: - # No model file found in this version - version['existsLocally'] = False - - return web.json_response(versions) - except Exception as e: - logger.error(f"Error fetching embedding model versions: {e}") - return web.Response(status=500, text=str(e)) diff --git a/py/routes/lora_routes.py b/py/routes/lora_routes.py index 4e261004..d70a2801 100644 --- a/py/routes/lora_routes.py +++ b/py/routes/lora_routes.py @@ -44,11 +44,6 @@ class LoraRoutes(BaseModelRoutes): app.router.add_get(f'/api/{prefix}/get-trigger-words', self.get_lora_trigger_words) app.router.add_get(f'/api/{prefix}/usage-tips-by-path', self.get_lora_usage_tips_by_path) - # CivitAI integration with LoRA-specific validation - app.router.add_get(f'/api/{prefix}/civitai/versions/{{model_id}}', self.get_civitai_versions_lora) - app.router.add_get(f'/api/{prefix}/civitai/model/version/{{modelVersionId}}', self.get_civitai_model_by_version) - app.router.add_get(f'/api/{prefix}/civitai/model/hash/{{hash}}', self.get_civitai_model_by_hash) - # ComfyUI integration app.router.add_post(f'/api/{prefix}/get_trigger_words', self.get_trigger_words) @@ -76,6 +71,15 @@ class LoraRoutes(BaseModelRoutes): return params + def _validate_civitai_model_type(self, model_type: str) -> bool: + """Validate CivitAI model type for LoRA""" + from ..utils.constants import VALID_LORA_TYPES + return model_type.lower() in VALID_LORA_TYPES + + def _get_expected_model_types(self) -> str: + """Get expected model types string for error messages""" + return "LORA, LoCon, or DORA" + # LoRA-specific route handlers async def get_letter_counts(self, request: web.Request) -> web.Response: """Get count of LoRAs for each letter of the alphabet""" @@ -210,94 +214,6 @@ class LoraRoutes(BaseModelRoutes): 'error': str(e) }, status=500) - # CivitAI integration methods - async def get_civitai_versions_lora(self, request: web.Request) -> web.Response: - """Get available versions for a Civitai LoRA model with local availability info""" - try: - model_id = request.match_info['model_id'] - metadata_provider = await get_default_metadata_provider() - response = await metadata_provider.get_model_versions(model_id) - if not response or not response.get('modelVersions'): - return web.Response(status=404, text="Model not found") - - versions = response.get('modelVersions', []) - model_type = response.get('type', '') - - # Check model type - should be LORA, LoCon, or DORA - from ..utils.constants import VALID_LORA_TYPES - if model_type.lower() not in VALID_LORA_TYPES: - return web.json_response({ - 'error': f"Model type mismatch. Expected LORA or LoCon, got {model_type}" - }, status=400) - - # Check local availability for each version - for version in versions: - # Find the model file (type="Model") in the files list - model_file = next((file for file in version.get('files', []) - if file.get('type') == 'Model'), None) - - if model_file: - sha256 = model_file.get('hashes', {}).get('SHA256') - if sha256: - # Set existsLocally and localPath at the version level - version['existsLocally'] = self.service.has_hash(sha256) - if version['existsLocally']: - version['localPath'] = self.service.get_path_by_hash(sha256) - - # Also set the model file size at the version level for easier access - version['modelSizeKB'] = model_file.get('sizeKB') - else: - # No model file found in this version - version['existsLocally'] = False - - return web.json_response(versions) - except Exception as e: - logger.error(f"Error fetching LoRA model versions: {e}") - return web.Response(status=500, text=str(e)) - - async def get_civitai_model_by_version(self, request: web.Request) -> web.Response: - """Get CivitAI model details by model version ID""" - try: - model_version_id = request.match_info.get('modelVersionId') - - # Get model details from metadata provider - metadata_provider = await get_default_metadata_provider() - model, error_msg = await metadata_provider.get_model_version_info(model_version_id) - - if not model: - # Log warning for failed model retrieval - logger.warning(f"Failed to fetch model version {model_version_id}: {error_msg}") - - # Determine status code based on error message - status_code = 404 if error_msg and "not found" in error_msg.lower() else 500 - - return web.json_response({ - "success": False, - "error": error_msg or "Failed to fetch model information" - }, status=status_code) - - return web.json_response(model) - except Exception as e: - logger.error(f"Error fetching model details: {e}") - return web.json_response({ - "success": False, - "error": str(e) - }, status=500) - - async def get_civitai_model_by_hash(self, request: web.Request) -> web.Response: - """Get CivitAI model details by hash""" - try: - hash = request.match_info.get('hash') - metadata_provider = await get_default_metadata_provider() - model = await metadata_provider.get_model_by_hash(hash) - return web.json_response(model) - except Exception as e: - logger.error(f"Error fetching model details by hash: {e}") - return web.json_response({ - "success": False, - "error": str(e) - }, status=500) - async def get_trigger_words(self, request: web.Request) -> web.Response: """Get trigger words for specified LoRA models""" try: