diff --git a/py/routes/example_images_routes.py b/py/routes/example_images_routes.py index 5408808a..c26e2db3 100644 --- a/py/routes/example_images_routes.py +++ b/py/routes/example_images_routes.py @@ -1,7 +1,6 @@ import logging from ..utils.example_images_download_manager import DownloadManager from ..utils.example_images_processor import ExampleImagesProcessor -from ..utils.example_images_metadata import MetadataUpdater from ..utils.example_images_file_manager import ExampleImagesFileManager logger = logging.getLogger(__name__) diff --git a/py/routes/misc_routes.py b/py/routes/misc_routes.py index 1428d5b2..8ba4febe 100644 --- a/py/routes/misc_routes.py +++ b/py/routes/misc_routes.py @@ -10,6 +10,7 @@ from ..utils.usage_stats import UsageStats from ..utils.lora_metadata import extract_trained_words from ..config import config from ..utils.constants import SUPPORTED_MEDIA_EXTENSIONS, NODE_TYPES, DEFAULT_NODE_COLOR +from ..services.service_registry import ServiceRegistry import re logger = logging.getLogger(__name__) @@ -106,6 +107,9 @@ class MiscRoutes: # Node registry endpoints app.router.add_post('/api/register-nodes', MiscRoutes.register_nodes) app.router.add_get('/api/get-registry', MiscRoutes.get_registry) + + # Add new route for checking if a model exists in the library + app.router.add_get('/api/check-model-exists', MiscRoutes.check_model_exists) @staticmethod async def clear_cache(request): @@ -580,3 +584,142 @@ class MiscRoutes: 'error': 'Internal Error', 'message': str(e) }, status=500) + + @staticmethod + async def check_model_exists(request): + """ + Check if a model with specified modelId and optionally modelVersionId exists in the library + + Expects query parameters: + - modelId: int - Civitai model ID (required) + - modelVersionId: int - Civitai model version ID (optional) + + Returns: + - If modelVersionId is provided: JSON with a boolean 'exists' field + - If modelVersionId is not provided: JSON with a list of modelVersionIds that exist in the library + """ + try: + # Get the modelId and modelVersionId from query parameters + model_id_str = request.query.get('modelId') + model_version_id_str = request.query.get('modelVersionId') + + # Validate modelId parameter (required) + if not model_id_str: + return web.json_response({ + 'success': False, + 'error': 'Missing required parameter: modelId' + }, status=400) + + try: + # Convert modelId to integer + model_id = int(model_id_str) + except ValueError: + return web.json_response({ + 'success': False, + 'error': 'Parameter modelId must be an integer' + }, status=400) + + # Get both lora and checkpoint scanners + registry = ServiceRegistry.get_instance() + lora_scanner = await registry.get_lora_scanner() + checkpoint_scanner = await registry.get_checkpoint_scanner() + + # If modelVersionId is provided, check for specific version + if model_version_id_str: + try: + model_version_id = int(model_version_id_str) + except ValueError: + return web.json_response({ + 'success': False, + 'error': 'Parameter modelVersionId must be an integer' + }, status=400) + + # Check if the specific version exists in either scanner's cache + exists = False + model_type = None + + # First check lora cache + lora_cache = await lora_scanner.get_cached_data() + if lora_cache and lora_cache.raw_data: + for item in lora_cache.raw_data: + if (item.get('civitai') and + item['civitai'].get('modelId') == model_id and + item['civitai'].get('id') == model_version_id): + exists = True + model_type = 'lora' + break + + # If not found in lora cache, check checkpoint cache + if not exists and checkpoint_scanner: + checkpoint_cache = await checkpoint_scanner.get_cached_data() + if checkpoint_cache and checkpoint_cache.raw_data: + for item in checkpoint_cache.raw_data: + if (item.get('civitai') and + item['civitai'].get('modelId') == model_id and + item['civitai'].get('id') == model_version_id): + exists = True + model_type = 'checkpoint' + break + + return web.json_response({ + 'success': True, + 'exists': exists, + 'modelType': model_type if exists else None + }) + + # If modelVersionId is not provided, return all version IDs for the model + else: + # Lists to collect version IDs from both scanners + lora_versions = [] + checkpoint_versions = [] + + # Check lora cache + lora_cache = await lora_scanner.get_cached_data() + if lora_cache and lora_cache.raw_data: + for item in lora_cache.raw_data: + if (item.get('civitai') and + item['civitai'].get('modelId') == model_id and + item['civitai'].get('id')): + lora_versions.append({ + 'versionId': item['civitai'].get('id'), + 'name': item['civitai'].get('name'), + 'fileName': item.get('file_name', '') + }) + + # Check checkpoint cache + checkpoint_cache = await checkpoint_scanner.get_cached_data() + if checkpoint_cache and checkpoint_cache.raw_data: + for item in checkpoint_cache.raw_data: + if (item.get('civitai') and + item['civitai'].get('modelId') == model_id and + item['civitai'].get('id')): + checkpoint_versions.append({ + 'versionId': item['civitai'].get('id'), + 'name': item['civitai'].get('name'), + 'fileName': item.get('file_name', '') + }) + + # Determine model type and combine results + model_type = None + versions = [] + + if lora_versions: + model_type = 'lora' + versions = lora_versions + elif checkpoint_versions: + model_type = 'checkpoint' + versions = checkpoint_versions + + return web.json_response({ + 'success': True, + 'modelId': model_id, + 'modelType': model_type, + 'versions': versions + }) + + except Exception as e: + logger.error(f"Failed to check model existence: {e}", exc_info=True) + return web.json_response({ + 'success': False, + 'error': str(e) + }, status=500)