From d7cb546c5f10d406def9069c924983df2f780b6d Mon Sep 17 00:00:00 2001 From: Will Miao <13051207myq@gmail.com> Date: Wed, 2 Jul 2025 18:25:42 +0800 Subject: [PATCH] refactor: Simplify model download handling by consolidating download logic and updating parameter usage --- py/routes/api_routes.py | 63 +----------------- py/routes/checkpoints_routes.py | 66 +------------------ py/services/civitai_client.py | 2 +- py/services/download_manager.py | 38 ++--------- py/utils/routes_common.py | 12 ++-- .../js/managers/CheckpointDownloadManager.js | 15 ++--- static/js/managers/DownloadManager.js | 15 ++--- 7 files changed, 25 insertions(+), 186 deletions(-) diff --git a/py/routes/api_routes.py b/py/routes/api_routes.py index 26a15ea1..37562ad2 100644 --- a/py/routes/api_routes.py +++ b/py/routes/api_routes.py @@ -437,68 +437,7 @@ class ApiRoutes: }, status=500) async def download_lora(self, request: web.Request) -> web.Response: - async with self._download_lock: - try: - if self.download_manager is None: - self.download_manager = await ServiceRegistry.get_download_manager() - - data = await request.json() - - # Create progress callback - async def progress_callback(progress): - await ws_manager.broadcast({ - 'status': 'progress', - 'progress': progress - }) - - # Check which identifier is provided - download_url = data.get('download_url') - model_hash = data.get('model_hash') - model_version_id = data.get('model_version_id') - - # Validate that at least one identifier is provided - if not any([download_url, model_hash, model_version_id]): - return web.Response( - status=400, - text="Missing required parameter: Please provide either 'download_url', 'hash', or 'modelVersionId'" - ) - - result = await self.download_manager.download_from_civitai( - download_url=download_url, - model_hash=model_hash, - model_version_id=model_version_id, - save_dir=data.get('lora_root'), - relative_path=data.get('relative_path'), - progress_callback=progress_callback - ) - - if not result.get('success', False): - error_message = result.get('error', 'Unknown error') - - # Return 401 for early access errors - if 'early access' in error_message.lower(): - logger.warning(f"Early access download failed: {error_message}") - return web.Response( - status=401, # Use 401 status code to match Civitai's response - text=error_message - ) - - return web.Response(status=500, text=error_message) - - return web.json_response(result) - except Exception as e: - error_message = str(e) - - # Check if this might be an early access error - if '401' in error_message: - logger.warning(f"Early access error (401): {error_message}") - return web.Response( - status=401, - text="Early Access Restriction: This LoRA requires purchase. Please buy early access on Civitai.com." - ) - - logger.error(f"Error downloading LoRA: {error_message}") - return web.Response(status=500, text=error_message) + return await ModelRouteUtils.handle_download_model(request, self.download_manager, model_type="lora") async def move_model(self, request: web.Request) -> web.Response: diff --git a/py/routes/checkpoints_routes.py b/py/routes/checkpoints_routes.py index 83cdf611..321f82d6 100644 --- a/py/routes/checkpoints_routes.py +++ b/py/routes/checkpoints_routes.py @@ -544,71 +544,7 @@ class CheckpointsRoutes: async def download_checkpoint(self, request: web.Request) -> web.Response: """Handle checkpoint download request""" - async with self._download_lock: - # Get the download manager from service registry if not already initialized - if self.download_manager is None: - self.download_manager = await ServiceRegistry.get_download_manager() - - try: - data = await request.json() - - # Create progress callback that uses checkpoint-specific WebSocket - async def progress_callback(progress): - await ws_manager.broadcast_checkpoint_progress({ - 'status': 'progress', - 'progress': progress - }) - - # Check which identifier is provided - download_url = data.get('download_url') - model_hash = data.get('model_hash') - model_version_id = data.get('model_version_id') - - # Validate that at least one identifier is provided - if not any([download_url, model_hash, model_version_id]): - return web.Response( - status=400, - text="Missing required parameter: Please provide either 'download_url', 'hash', or 'modelVersionId'" - ) - - result = await self.download_manager.download_from_civitai( - download_url=download_url, - model_hash=model_hash, - model_version_id=model_version_id, - save_dir=data.get('checkpoint_root'), - relative_path=data.get('relative_path', ''), - progress_callback=progress_callback, - model_type="checkpoint" - ) - - if not result.get('success', False): - error_message = result.get('error', 'Unknown error') - - # Return 401 for early access errors - if 'early access' in error_message.lower(): - logger.warning(f"Early access download failed: {error_message}") - return web.Response( - status=401, - text=f"Early Access Restriction: {error_message}" - ) - - return web.Response(status=500, text=error_message) - - return web.json_response(result) - - except Exception as e: - error_message = str(e) - - # Check if this might be an early access error - if '401' in error_message: - logger.warning(f"Early access error (401): {error_message}") - return web.Response( - status=401, - text="Early Access Restriction: This model requires purchase. Please ensure you have purchased early access and are logged in to Civitai." - ) - - logger.error(f"Error downloading checkpoint: {error_message}") - return web.Response(status=500, text=error_message) + return await ModelRouteUtils.handle_download_model(request, self.download_manager, model_type="checkpoint") async def get_checkpoint_roots(self, request): """Return the checkpoint root directories""" diff --git a/py/services/civitai_client.py b/py/services/civitai_client.py index bec4f5d4..929f13d1 100644 --- a/py/services/civitai_client.py +++ b/py/services/civitai_client.py @@ -267,7 +267,7 @@ class CivitaiClient: # Replace index with modelId if 'index' in result: del result['index'] - result['modelId'] = model_id + result['modelId'] = int(model_id) # Add model field with metadata from top level result['model'] = { diff --git a/py/services/download_manager.py b/py/services/download_manager.py index 2a0ae048..85b8bf52 100644 --- a/py/services/download_manager.py +++ b/py/services/download_manager.py @@ -48,16 +48,15 @@ class DownloadManager: """Get the checkpoint scanner from registry""" return await ServiceRegistry.get_checkpoint_scanner() - async def download_from_civitai(self, download_url: str = None, model_hash: str = None, + async def download_from_civitai(self, model_id: str = None, model_version_id: str = None, save_dir: str = None, relative_path: str = '', progress_callback=None, model_type: str = "lora") -> Dict: """Download model from Civitai Args: - download_url: Direct download URL for the model - model_hash: SHA256 hash of the model - model_version_id: Civitai model version ID + model_id: Civitai model ID + model_version_id: Civitai model version ID (optional, if not provided, will download the latest version) save_dir: Directory to save the model to relative_path: Relative path within save_dir progress_callback: Callback function for progress updates @@ -77,25 +76,10 @@ class DownloadManager: civitai_client = await self._get_civitai_client() # Get version info based on the provided identifier - version_info = None - error_msg = None - - if model_hash: - # Get model by hash - version_info = await civitai_client.get_model_by_hash(model_hash) - elif model_version_id: - # Use model version ID directly - version_info, error_msg = await civitai_client.get_model_version_info(model_version_id) - elif download_url: - # Extract version ID from download URL - version_id = download_url.split('/')[-1] - version_info, error_msg = await civitai_client.get_model_version_info(version_id) - + version_info = await civitai_client.get_model_version(model_id, model_version_id) if not version_info: - if error_msg and "model not found" in error_msg.lower(): - return {'success': False, 'error': f'Model not found on Civitai: {error_msg}'} - return {'success': False, 'error': error_msg or 'Failed to fetch model metadata'} + return {'success': False, 'error': 'Failed to fetch model metadata'} # Check if this is an early access model if version_info.get('earlyAccessEndsAt'): @@ -137,18 +121,6 @@ class DownloadManager: metadata = LoraMetadata.from_civitai_info(version_info, file_info, save_path) logger.info(f"Creating LoraMetadata for {file_name}") - # 5.1 Get and update model tags, description and creator info - model_id = version_info.get('modelId') - if model_id: - model_metadata, _ = await civitai_client.get_model_metadata(str(model_id)) - if model_metadata: - if model_metadata.get("tags"): - metadata.tags = model_metadata.get("tags", []) - if model_metadata.get("description"): - metadata.modelDescription = model_metadata.get("description", "") - if model_metadata.get("creator"): - metadata.civitai["creator"] = model_metadata.get("creator") - # 6. Start download process result = await self._execute_download( download_url=file_info.get('downloadUrl', ''), diff --git a/py/utils/routes_common.py b/py/utils/routes_common.py index 9a533859..992ad812 100644 --- a/py/utils/routes_common.py +++ b/py/utils/routes_common.py @@ -587,15 +587,14 @@ class ModelRouteUtils: }) # Check which identifier is provided - download_url = data.get('download_url') - model_hash = data.get('model_hash') + model_id = data.get('model_id') model_version_id = data.get('model_version_id') - # Validate that at least one identifier is provided - if not any([download_url, model_hash, model_version_id]): + # Only model_id is required, model_version_id is optional + if not model_id: return web.Response( status=400, - text="Missing required parameter: Please provide either 'download_url', 'hash', or 'modelVersionId'" + text="Missing required parameter: Please provide 'model_id'" ) # Use the correct root directory based on model type @@ -603,8 +602,7 @@ class ModelRouteUtils: save_dir = data.get(root_key) result = await download_manager.download_from_civitai( - download_url=download_url, - model_hash=model_hash, + model_id=model_id, model_version_id=model_version_id, save_dir=save_dir, relative_path=data.get('relative_path', ''), diff --git a/static/js/managers/CheckpointDownloadManager.js b/static/js/managers/CheckpointDownloadManager.js index 37f4081f..76ead14e 100644 --- a/static/js/managers/CheckpointDownloadManager.js +++ b/static/js/managers/CheckpointDownloadManager.js @@ -61,6 +61,7 @@ export class CheckpointDownloadManager { this.currentVersion = null; this.versions = []; this.modelInfo = null; + this.modelId = null; this.modelVersionId = null; // Clear selected folder and remove selection from UI @@ -79,12 +80,12 @@ export class CheckpointDownloadManager { try { this.loadingManager.showSimpleLoading('Fetching model versions...'); - const modelId = this.extractModelId(url); - if (!modelId) { + this.modelId = this.extractModelId(url); + if (!this.modelId) { throw new Error('Invalid Civitai URL format'); } - const response = await fetch(`/api/checkpoints/civitai/versions/${modelId}`); + const response = await fetch(`/api/checkpoints/civitai/versions/${this.modelId}`); if (!response.ok) { const errorData = await response.json().catch(() => ({})); if (errorData && errorData.error && errorData.error.includes('Model type mismatch')) { @@ -296,11 +297,6 @@ export class CheckpointDownloadManager { } try { - const downloadUrl = this.currentVersion.downloadUrl; - if (!downloadUrl) { - throw new Error('No download URL available'); - } - // Show enhanced loading with progress details const updateProgress = this.loadingManager.showDownloadProgress(1); updateProgress(0, 0, this.currentVersion.name); @@ -338,7 +334,8 @@ export class CheckpointDownloadManager { method: 'POST', headers: { 'Content-Type': 'application/json' }, body: JSON.stringify({ - download_url: downloadUrl, + model_id: this.modelId, + model_version_id: this.currentVersion.id, checkpoint_root: checkpointRoot, relative_path: targetFolder }) diff --git a/static/js/managers/DownloadManager.js b/static/js/managers/DownloadManager.js index 24fa7528..17f58f3d 100644 --- a/static/js/managers/DownloadManager.js +++ b/static/js/managers/DownloadManager.js @@ -63,6 +63,7 @@ export class DownloadManager { this.currentVersion = null; this.versions = []; this.modelInfo = null; + this.modelId = null; this.modelVersionId = null; // Clear selected folder and remove selection from UI @@ -81,12 +82,12 @@ export class DownloadManager { try { this.loadingManager.showSimpleLoading('Fetching model versions...'); - const modelId = this.extractModelId(url); - if (!modelId) { + this.modelId = this.extractModelId(url); + if (!this.modelId) { throw new Error('Invalid Civitai URL format'); } - const response = await fetch(`/api/civitai/versions/${modelId}`); + const response = await fetch(`/api/civitai/versions/${this.modelId}`); if (!response.ok) { const errorData = await response.json().catch(() => ({})); if (errorData && errorData.error && errorData.error.includes('Model type mismatch')) { @@ -306,11 +307,6 @@ export class DownloadManager { } try { - const downloadUrl = this.currentVersion.downloadUrl; - if (!downloadUrl) { - throw new Error('No download URL available'); - } - // Show enhanced loading with progress details const updateProgress = this.loadingManager.showDownloadProgress(1); updateProgress(0, 0, this.currentVersion.name); @@ -348,7 +344,8 @@ export class DownloadManager { method: 'POST', headers: { 'Content-Type': 'application/json' }, body: JSON.stringify({ - download_url: downloadUrl, + model_id: this.modelId, + model_version_id: this.currentVersion.id, lora_root: loraRoot, relative_path: targetFolder })