From c692713ffb0ecfcc2b450ca3ed97d6da7d7c7b89 Mon Sep 17 00:00:00 2001 From: Will Miao <13051207myq@gmail.com> Date: Wed, 9 Jul 2025 10:26:03 +0800 Subject: [PATCH] refactor: Simplify model version existence checks and enhance version retrieval methods in scanners --- py/routes/misc_routes.py | 64 ++++++++------------------------- py/services/download_manager.py | 5 +++ py/services/model_scanner.py | 56 +++++++++++++++++++++++++++++ 3 files changed, 75 insertions(+), 50 deletions(-) diff --git a/py/routes/misc_routes.py b/py/routes/misc_routes.py index 92c048ee..54146ce4 100644 --- a/py/routes/misc_routes.py +++ b/py/routes/misc_routes.py @@ -636,32 +636,18 @@ class MiscRoutes: 'error': 'Parameter modelVersionId must be an integer' }, status=400) - # Check if the specific version exists in either scanner's cache + # Check if the specific version exists in either scanner exists = False model_type = None - # First check lora cache - lora_cache = await lora_scanner.get_cached_data() - if lora_cache and lora_cache.raw_data: - for item in lora_cache.raw_data: - if (item.get('civitai') and - item['civitai'].get('modelId') == model_id and - item['civitai'].get('id') == model_version_id): - exists = True - model_type = 'lora' - break - - # If not found in lora cache, check checkpoint cache - if not exists and checkpoint_scanner: - checkpoint_cache = await checkpoint_scanner.get_cached_data() - if checkpoint_cache and checkpoint_cache.raw_data: - for item in checkpoint_cache.raw_data: - if (item.get('civitai') and - item['civitai'].get('modelId') == model_id and - item['civitai'].get('id') == model_version_id): - exists = True - model_type = 'checkpoint' - break + # Check lora scanner first + if await lora_scanner.check_model_version_exists(model_id, model_version_id): + exists = True + model_type = 'lora' + # If not found in lora, check checkpoint scanner + elif checkpoint_scanner and await checkpoint_scanner.check_model_version_exists(model_id, model_version_id): + exists = True + model_type = 'checkpoint' return web.json_response({ 'success': True, @@ -671,35 +657,13 @@ class MiscRoutes: # If modelVersionId is not provided, return all version IDs for the model else: - # Lists to collect version IDs from both scanners - lora_versions = [] + # Get versions from lora scanner first + lora_versions = await lora_scanner.get_model_versions_by_id(model_id) checkpoint_versions = [] - # Check lora cache - lora_cache = await lora_scanner.get_cached_data() - if lora_cache and lora_cache.raw_data: - for item in lora_cache.raw_data: - if (item.get('civitai') and - item['civitai'].get('modelId') == model_id and - item['civitai'].get('id')): - lora_versions.append({ - 'versionId': item['civitai'].get('id'), - 'name': item['civitai'].get('name'), - 'fileName': item.get('file_name', '') - }) - - # Check checkpoint cache - checkpoint_cache = await checkpoint_scanner.get_cached_data() - if checkpoint_cache and checkpoint_cache.raw_data: - for item in checkpoint_cache.raw_data: - if (item.get('civitai') and - item['civitai'].get('modelId') == model_id and - item['civitai'].get('id')): - checkpoint_versions.append({ - 'versionId': item['civitai'].get('id'), - 'name': item['civitai'].get('name'), - 'fileName': item.get('file_name', '') - }) + # Only check checkpoint scanner if no lora versions found + if not lora_versions: + checkpoint_versions = await checkpoint_scanner.get_model_versions_by_id(model_id) # Determine model type and combine results model_type = None diff --git a/py/services/download_manager.py b/py/services/download_manager.py index ca11e297..f0520ac7 100644 --- a/py/services/download_manager.py +++ b/py/services/download_manager.py @@ -82,6 +82,11 @@ class DownloadManager: else: return {'success': False, 'error': f'Model type "{model_type_from_info}" is not supported for download'} + scanner = model_type == 'checkpoint' and await self._get_checkpoint_scanner() or await self._get_lora_scanner() + + if scanner.check_model_version_exists(model_id, model_version_id): + return {'success': False, 'error': 'Model version already exists in library'} + # Handle use_default_paths if use_default_paths: # Set save_dir based on model type diff --git a/py/services/model_scanner.py b/py/services/model_scanner.py index f9008013..fdd9c020 100644 --- a/py/services/model_scanner.py +++ b/py/services/model_scanner.py @@ -1362,3 +1362,59 @@ class ModelScanner: if file_name in self._hash_index._duplicate_filenames: if len(self._hash_index._duplicate_filenames[file_name]) <= 1: del self._hash_index._duplicate_filenames[file_name] + + async def check_model_version_exists(self, model_id: int, model_version_id: int) -> bool: + """Check if a specific model version exists in the cache + + Args: + model_id: Civitai model ID + model_version_id: Civitai model version ID + + Returns: + bool: True if the model version exists, False otherwise + """ + try: + cache = await self.get_cached_data() + if not cache or not cache.raw_data: + return False + + for item in cache.raw_data: + if (item.get('civitai') and + item['civitai'].get('modelId') == model_id and + item['civitai'].get('id') == model_version_id): + return True + + return False + except Exception as e: + logger.error(f"Error checking model version existence: {e}") + return False + + async def get_model_versions_by_id(self, model_id: int) -> List[Dict]: + """Get all versions of a model by its ID + + Args: + model_id: Civitai model ID + + Returns: + List[Dict]: List of version information dictionaries + """ + try: + cache = await self.get_cached_data() + if not cache or not cache.raw_data: + return [] + + versions = [] + for item in cache.raw_data: + if (item.get('civitai') and + item['civitai'].get('modelId') == model_id and + item['civitai'].get('id')): + versions.append({ + 'versionId': item['civitai'].get('id'), + 'name': item['civitai'].get('name'), + 'fileName': item.get('file_name', '') + }) + + return versions + except Exception as e: + logger.error(f"Error getting model versions: {e}") + return []