From 96517cbdef6c7f30879c50fbf16ef84712915284 Mon Sep 17 00:00:00 2001 From: Will Miao <13051207myq@gmail.com> Date: Mon, 11 Aug 2025 15:31:49 +0800 Subject: [PATCH] fix: update model_id and model_version_id handling across various services for improved flexibility --- py/routes/misc_routes.py | 6 +- py/services/civitai_client.py | 95 +++++++++++++------- py/services/download_manager.py | 26 ++++-- py/services/model_scanner.py | 15 ++-- py/utils/routes_common.py | 25 +++--- static/js/api/modelApiFactory.js | 12 +-- static/js/managers/import/DownloadManager.js | 40 ++++----- 7 files changed, 126 insertions(+), 93 deletions(-) diff --git a/py/routes/misc_routes.py b/py/routes/misc_routes.py index be07d105..1ecd5a5a 100644 --- a/py/routes/misc_routes.py +++ b/py/routes/misc_routes.py @@ -654,13 +654,13 @@ class MiscRoutes: exists = False model_type = None - if await lora_scanner.check_model_version_exists(model_id, model_version_id): + if await lora_scanner.check_model_version_exists(model_version_id): exists = True model_type = 'lora' - elif checkpoint_scanner and await checkpoint_scanner.check_model_version_exists(model_id, model_version_id): + elif checkpoint_scanner and await checkpoint_scanner.check_model_version_exists(model_version_id): exists = True model_type = 'checkpoint' - elif embedding_scanner and await embedding_scanner.check_model_version_exists(model_id, model_version_id): + elif embedding_scanner and await embedding_scanner.check_model_version_exists(model_version_id): exists = True model_type = 'embedding' diff --git a/py/services/civitai_client.py b/py/services/civitai_client.py index af6f948a..13b8f96b 100644 --- a/py/services/civitai_client.py +++ b/py/services/civitai_client.py @@ -223,11 +223,11 @@ class CivitaiClient: logger.error(f"Error fetching model versions: {e}") return None - async def get_model_version(self, model_id: int, version_id: int = None) -> Optional[Dict]: + async def get_model_version(self, model_id: int = None, version_id: int = None) -> Optional[Dict]: """Get specific model version with additional metadata Args: - model_id: The Civitai model ID + model_id: The Civitai model ID (optional if version_id is provided) version_id: Optional specific version ID to retrieve Returns: @@ -235,37 +235,72 @@ class CivitaiClient: """ try: session = await self._ensure_fresh_session() - - # Step 1: Get model data to find version_id if not provided and get additional metadata - async with session.get(f"{self.base_url}/models/{model_id}") as response: - if response.status != 200: - return None - - data = await response.json() - model_versions = data.get('modelVersions', []) - - # Step 2: Determine the version_id to use - target_version_id = version_id - if target_version_id is None: - target_version_id = model_versions[0].get('id') - - # Step 3: Get detailed version info using the version_id headers = self._get_request_headers() - async with session.get(f"{self.base_url}/model-versions/{target_version_id}", headers=headers) as response: - if response.status != 200: - return None + + # Case 1: Only version_id is provided + if model_id is None and version_id is not None: + # First get the version info to extract model_id + async with session.get(f"{self.base_url}/model-versions/{version_id}", headers=headers) as response: + if response.status != 200: + return None + + version = await response.json() + model_id = version.get('modelId') + + if not model_id: + logger.error(f"No modelId found in version {version_id}") + return None - version = await response.json() + # Now get the model data for additional metadata + async with session.get(f"{self.base_url}/models/{model_id}") as response: + if response.status != 200: + return version # Return version without additional metadata + + model_data = await response.json() + + # Enrich version with model data + version['model']['description'] = model_data.get("description") + version['model']['tags'] = model_data.get("tags", []) + version['creator'] = model_data.get("creator") + + return version + + # Case 2: model_id is provided (with or without version_id) + elif model_id is not None: + # Step 1: Get model data to find version_id if not provided and get additional metadata + async with session.get(f"{self.base_url}/models/{model_id}") as response: + if response.status != 200: + return None + + data = await response.json() + model_versions = data.get('modelVersions', []) + + # Step 2: Determine the version_id to use + target_version_id = version_id + if target_version_id is None: + target_version_id = model_versions[0].get('id') - # Step 4: Enrich version_info with model data - # Add description and tags from model data - version['model']['description'] = data.get("description") - version['model']['tags'] = data.get("tags", []) - - # Add creator from model data - version['creator'] = data.get("creator") - - return version + # Step 3: Get detailed version info using the version_id + async with session.get(f"{self.base_url}/model-versions/{target_version_id}", headers=headers) as response: + if response.status != 200: + return None + + version = await response.json() + + # Step 4: Enrich version_info with model data + # Add description and tags from model data + version['model']['description'] = data.get("description") + version['model']['tags'] = data.get("tags", []) + + # Add creator from model data + version['creator'] = data.get("creator") + + return version + + # Case 3: Neither model_id nor version_id provided + else: + logger.error("Either model_id or version_id must be provided") + return None except Exception as e: logger.error(f"Error fetching model version: {e}") diff --git a/py/services/download_manager.py b/py/services/download_manager.py index aacdc362..55e8715d 100644 --- a/py/services/download_manager.py +++ b/py/services/download_manager.py @@ -54,15 +54,15 @@ class DownloadManager: """Get the checkpoint scanner from registry""" return await ServiceRegistry.get_checkpoint_scanner() - async def download_from_civitai(self, model_id: int, model_version_id: int, + async def download_from_civitai(self, model_id: int = None, model_version_id: int = None, save_dir: str = None, relative_path: str = '', progress_callback=None, use_default_paths: bool = False, download_id: str = None) -> Dict: """Download model from Civitai with task tracking and concurrency control Args: - model_id: Civitai model ID - model_version_id: Civitai model version ID + model_id: Civitai model ID (optional if model_version_id is provided) + model_version_id: Civitai model version ID (optional if model_id is provided) save_dir: Directory to save the model relative_path: Relative path within save_dir progress_callback: Callback function for progress updates @@ -72,6 +72,10 @@ class DownloadManager: Returns: Dict with download result """ + # Validate that at least one identifier is provided + if not model_id and not model_version_id: + return {'success': False, 'error': 'Either model_id or model_version_id must be provided'} + # Use provided download_id or generate new one task_id = download_id or str(uuid.uuid4()) @@ -181,14 +185,19 @@ class DownloadManager: # Check both scanners lora_scanner = await self._get_lora_scanner() checkpoint_scanner = await self._get_checkpoint_scanner() + embedding_scanner = await ServiceRegistry.get_embedding_scanner() # Check lora scanner first - if await lora_scanner.check_model_version_exists(model_id, model_version_id): + if await lora_scanner.check_model_version_exists(model_version_id): return {'success': False, 'error': 'Model version already exists in lora library'} # Check checkpoint scanner - if await checkpoint_scanner.check_model_version_exists(model_id, model_version_id): + if await checkpoint_scanner.check_model_version_exists(model_version_id): return {'success': False, 'error': 'Model version already exists in checkpoint library'} + + # Check embedding scanner + if await embedding_scanner.check_model_version_exists(model_version_id): + return {'success': False, 'error': 'Model version already exists in embedding library'} # Get civitai client civitai_client = await self._get_civitai_client() @@ -211,23 +220,22 @@ class DownloadManager: # Case 2: model_version_id was None, check after getting version_info if model_version_id is None: - version_model_id = version_info.get('modelId') version_id = version_info.get('id') if model_type == 'lora': # Check lora scanner lora_scanner = await self._get_lora_scanner() - if await lora_scanner.check_model_version_exists(version_model_id, version_id): + if await lora_scanner.check_model_version_exists(version_id): return {'success': False, 'error': 'Model version already exists in lora library'} elif model_type == 'checkpoint': # Check checkpoint scanner checkpoint_scanner = await self._get_checkpoint_scanner() - if await checkpoint_scanner.check_model_version_exists(version_model_id, version_id): + if await checkpoint_scanner.check_model_version_exists(version_id): return {'success': False, 'error': 'Model version already exists in checkpoint library'} elif model_type == 'embedding': # Embeddings are not checked in scanners, but we can still check if it exists embedding_scanner = await ServiceRegistry.get_embedding_scanner() - if await embedding_scanner.check_model_version_exists(version_model_id, version_id): + if await embedding_scanner.check_model_version_exists(version_id): return {'success': False, 'error': 'Model version already exists in embedding library'} # Handle use_default_paths diff --git a/py/services/model_scanner.py b/py/services/model_scanner.py index bf58edde..6fd16a82 100644 --- a/py/services/model_scanner.py +++ b/py/services/model_scanner.py @@ -1149,13 +1149,12 @@ class ModelScanner: if len(self._hash_index._duplicate_filenames[file_name]) <= 1: del self._hash_index._duplicate_filenames[file_name] - async def check_model_version_exists(self, model_id: int, model_version_id: int) -> bool: + async def check_model_version_exists(self, model_version_id: int) -> bool: """Check if a specific model version exists in the cache - + Args: - model_id: Civitai model ID model_version_id: Civitai model version ID - + Returns: bool: True if the model version exists, False otherwise """ @@ -1163,13 +1162,11 @@ class ModelScanner: cache = await self.get_cached_data() if not cache or not cache.raw_data: return False - + for item in cache.raw_data: - if (item.get('civitai') and - item['civitai'].get('modelId') == model_id and - item['civitai'].get('id') == model_version_id): + if item.get('civitai') and item['civitai'].get('id') == model_version_id: return True - + return False except Exception as e: logger.error(f"Error checking model version existence: {e}") diff --git a/py/utils/routes_common.py b/py/utils/routes_common.py index 40b02cb5..1b583d72 100644 --- a/py/utils/routes_common.py +++ b/py/utils/routes_common.py @@ -580,16 +580,19 @@ class ModelRouteUtils: }) # Check which identifier is provided and convert to int - try: - model_id = int(data.get('model_id')) - except (TypeError, ValueError): - return web.json_response({ - 'success': False, - 'error': "Invalid model_id: Must be an integer" - }, status=400) + model_id = None + model_version_id = None + + if data.get('model_id'): + try: + model_id = int(data.get('model_id')) + except (TypeError, ValueError): + return web.json_response({ + 'success': False, + 'error': "Invalid model_id: Must be an integer" + }, status=400) # Convert model_version_id to int if provided - model_version_id = None if data.get('model_version_id'): try: model_version_id = int(data.get('model_version_id')) @@ -599,11 +602,11 @@ class ModelRouteUtils: 'error': "Invalid model_version_id: Must be an integer" }, status=400) - # Only model_id is required, model_version_id is optional - if not model_id: + # At least one identifier is required + if not model_id and not model_version_id: return web.json_response({ 'success': False, - 'error': "Missing required parameter: Please provide 'model_id'" + 'error': "Missing required parameter: Please provide either 'model_id' or 'model_version_id'" }, status=400) use_default_paths = data.get('use_default_paths', False) diff --git a/static/js/api/modelApiFactory.js b/static/js/api/modelApiFactory.js index 5aad0d6e..51c0c055 100644 --- a/static/js/api/modelApiFactory.js +++ b/static/js/api/modelApiFactory.js @@ -17,16 +17,16 @@ export function createModelApiClient(modelType) { } } -let _singletonClient = null; +let _singletonClients = new Map(); -export function getModelApiClient() { - const currentType = state.currentPageType; +export function getModelApiClient(modelType = null) { + const targetType = modelType || state.currentPageType; - if (!_singletonClient || _singletonClient.modelType !== currentType) { - _singletonClient = createModelApiClient(currentType); + if (!_singletonClients.has(targetType)) { + _singletonClients.set(targetType, createModelApiClient(targetType)); } - return _singletonClient; + return _singletonClients.get(targetType); } export function resetAndReload(updateFolders = false) { diff --git a/static/js/managers/import/DownloadManager.js b/static/js/managers/import/DownloadManager.js index 00e5d497..59638ff7 100644 --- a/static/js/managers/import/DownloadManager.js +++ b/static/js/managers/import/DownloadManager.js @@ -1,4 +1,6 @@ import { showToast } from '../../utils/uiHelpers.js'; +import { getModelApiClient } from '../../api/modelApiFactory.js'; +import { MODEL_TYPES } from '../../api/apiConfig.js'; export class DownloadManager { constructor(importManager) { @@ -199,39 +201,27 @@ export class DownloadManager { updateProgress(0, completedDownloads, lora.name); try { + console.log(`lora:`, lora); // Download the LoRA with download ID - const response = await fetch('/api/download-model', { - method: 'POST', - headers: { 'Content-Type': 'application/json' }, - body: JSON.stringify({ - model_id: lora.modelId, - model_version_id: lora.id, - model_root: loraRoot, - relative_path: targetPath.replace(loraRoot + '/', ''), - download_id: batchDownloadId - }) - }); + const response = await getModelApiClient(MODEL_TYPES.LORA).downloadModel( + lora.modelId, + lora.modelVersionId, + loraRoot, + targetPath.replace(loraRoot + '/', ''), + batchDownloadId + ); - if (!response.ok) { - const errorText = await response.text(); - console.error(`Failed to download LoRA ${lora.name}: ${errorText}`); - - // Check if this is an early access error (status 401 is the key indicator) - if (response.status === 401) { - accessFailures++; - this.importManager.loadingManager.setStatus( - `Failed to download ${lora.name}: Access restricted` - ); - } - + if (!response.success) { + console.error(`Failed to download LoRA ${lora.name}: ${response.error}`); + failedDownloads++; // Continue with next download } else { completedDownloads++; - + // Update progress to show completion of current LoRA updateProgress(100, completedDownloads, ''); - + if (completedDownloads + failedDownloads < this.importManager.downloadableLoRAs.length) { this.importManager.loadingManager.setStatus( `Completed ${completedDownloads}/${this.importManager.downloadableLoRAs.length} LoRAs. Starting next download...`