From 79011bd257bb5c4212f76ebf3db4d602386e02eb Mon Sep 17 00:00:00 2001 From: Will Miao <13051207myq@gmail.com> Date: Wed, 9 Jul 2025 14:21:49 +0800 Subject: [PATCH] refactor: Update model_id and model_version_id types to integers and add validation in routes --- py/services/civitai_client.py | 6 +++--- py/services/download_manager.py | 37 ++++++++++++++++++++++++++------ py/services/websocket_manager.py | 3 --- py/utils/routes_common.py | 28 +++++++++++++++++++----- 4 files changed, 57 insertions(+), 17 deletions(-) diff --git a/py/services/civitai_client.py b/py/services/civitai_client.py index 929f13d1..0299988f 100644 --- a/py/services/civitai_client.py +++ b/py/services/civitai_client.py @@ -225,7 +225,7 @@ class CivitaiClient: logger.error(f"Error fetching model versions: {e}") return None - async def get_model_version(self, model_id: str, version_id: str = "") -> Optional[Dict]: + async def get_model_version(self, model_id: int, version_id: int) -> Optional[Dict]: """Get specific model version with additional metadata Args: @@ -250,7 +250,7 @@ class CivitaiClient: if version_id: # If version_id provided, find exact match for version in model_versions: - if str(version.get('id')) == str(version_id): + if version.get('id') == version_id: matched_version = version break else: @@ -267,7 +267,7 @@ class CivitaiClient: # Replace index with modelId if 'index' in result: del result['index'] - result['modelId'] = int(model_id) + result['modelId'] = model_id # Add model field with metadata from top level result['model'] = { diff --git a/py/services/download_manager.py b/py/services/download_manager.py index f0520ac7..f9248e45 100644 --- a/py/services/download_manager.py +++ b/py/services/download_manager.py @@ -48,8 +48,8 @@ class DownloadManager: """Get the checkpoint scanner from registry""" return await ServiceRegistry.get_checkpoint_scanner() - async def download_from_civitai(self, model_id: str = None, - model_version_id: str = None, save_dir: str = None, + async def download_from_civitai(self, model_id: int, + model_version_id: int, save_dir: str = None, relative_path: str = '', progress_callback=None, use_default_paths: bool = False) -> Dict: """Download model from Civitai @@ -65,6 +65,20 @@ class DownloadManager: Dict with download result """ try: + # Check if model version already exists in library + if model_version_id is not None: + # Case 1: model_version_id is provided, check both scanners + lora_scanner = await self._get_lora_scanner() + checkpoint_scanner = await self._get_checkpoint_scanner() + + # Check lora scanner first + if await lora_scanner.check_model_version_exists(model_id, model_version_id): + return {'success': False, 'error': 'Model version already exists in lora library'} + + # Check checkpoint scanner + if await checkpoint_scanner.check_model_version_exists(model_id, model_version_id): + return {'success': False, 'error': 'Model version already exists in checkpoint library'} + # Get civitai client civitai_client = await self._get_civitai_client() @@ -82,10 +96,21 @@ class DownloadManager: else: return {'success': False, 'error': f'Model type "{model_type_from_info}" is not supported for download'} - scanner = model_type == 'checkpoint' and await self._get_checkpoint_scanner() or await self._get_lora_scanner() - - if scanner.check_model_version_exists(model_id, model_version_id): - return {'success': False, 'error': 'Model version already exists in library'} + # Case 2: model_version_id was None, check after getting version_info + if model_version_id is None: + version_model_id = version_info.get('modelId') + version_id = version_info.get('id') + + if model_type == 'lora': + # Check lora scanner + lora_scanner = await self._get_lora_scanner() + if await lora_scanner.check_model_version_exists(version_model_id, version_id): + return {'success': False, 'error': 'Model version already exists in lora library'} + elif model_type == 'checkpoint': + # Check checkpoint scanner + checkpoint_scanner = await self._get_checkpoint_scanner() + if await checkpoint_scanner.check_model_version_exists(version_model_id, version_id): + return {'success': False, 'error': 'Model version already exists in checkpoint library'} # Handle use_default_paths if use_default_paths: diff --git a/py/services/websocket_manager.py b/py/services/websocket_manager.py index 958f0e38..1692fe54 100644 --- a/py/services/websocket_manager.py +++ b/py/services/websocket_manager.py @@ -52,9 +52,6 @@ class WebSocketManager: if not download_id: # Generate a new download ID if not provided download_id = str(uuid4()) - logger.info(f"Created new download ID: {download_id}") - else: - logger.info(f"Using provided download ID: {download_id}") # Store the websocket with its download ID self._download_websockets[download_id] = ws diff --git a/py/utils/routes_common.py b/py/utils/routes_common.py index e13b5cff..9f962c84 100644 --- a/py/utils/routes_common.py +++ b/py/utils/routes_common.py @@ -590,9 +590,25 @@ class ModelRouteUtils: 'download_id': download_id }) - # Check which identifier is provided - model_id = data.get('model_id') - model_version_id = data.get('model_version_id') + # Check which identifier is provided and convert to int + try: + model_id = int(data.get('model_id')) + except (TypeError, ValueError): + return web.Response( + status=400, + text="Invalid model_id: Must be an integer" + ) + + # Convert model_version_id to int if provided + model_version_id = None + if data.get('model_version_id'): + try: + model_version_id = int(data.get('model_version_id')) + except (TypeError, ValueError): + return web.Response( + status=400, + text="Invalid model_version_id: Must be an integer" + ) # Only model_id is required, model_version_id is optional if not model_id: @@ -696,8 +712,10 @@ class ModelRouteUtils: try: data = await request.json() file_path = data.get('file_path') - model_id = data.get('model_id') - model_version_id = data.get('model_version_id') + model_id = int(data.get('model_id')) + model_version_id = None + if data.get('model_version_id'): + model_version_id = int(data.get('model_version_id')) if not file_path or not model_id: return web.json_response({"success": False, "error": "Both file_path and model_id are required"}, status=400)