From a2b81ea099cbe4f5cc7ce881b13891dcaa2feba7 Mon Sep 17 00:00:00 2001 From: Will Miao <13051207myq@gmail.com> Date: Wed, 23 Jul 2025 14:39:02 +0800 Subject: [PATCH] refactor: Implement base model routes and services for LoRA and Checkpoint - Added BaseModelRoutes class to handle common routes and logic for model types. - Created CheckpointRoutes class inheriting from BaseModelRoutes for checkpoint-specific routes. - Implemented CheckpointService class for handling checkpoint-related data and operations. - Developed LoraService class for managing LoRA-specific functionalities. - Introduced ModelServiceFactory to manage service and route registrations for different model types. - Established methods for fetching, filtering, and formatting model data across services. - Integrated CivitAI metadata handling within model routes and services. - Added pagination and filtering capabilities for model data retrieval. --- py/lora_manager.py | 36 +- py/routes/api_routes.py | 1173 +------------------------- py/routes/base_model_routes.py | 431 ++++++++++ py/routes/checkpoint_routes.py | 170 ++++ py/routes/lora_routes.py | 763 ++++++++++++++--- py/routes/misc_routes.py | 5 +- py/routes/recipe_routes.py | 49 +- py/services/base_model_service.py | 248 ++++++ py/services/checkpoint_scanner.py | 33 +- py/services/checkpoint_service.py | 51 ++ py/services/lora_service.py | 172 ++++ py/services/model_service_factory.py | 137 +++ py/services/service_registry.py | 236 ++++-- py/utils/routes_common.py | 53 ++ standalone.py | 13 +- 15 files changed, 2185 insertions(+), 1385 deletions(-) create mode 100644 py/routes/base_model_routes.py create mode 100644 py/routes/checkpoint_routes.py create mode 100644 py/services/base_model_service.py create mode 100644 py/services/checkpoint_service.py create mode 100644 py/services/lora_service.py create mode 100644 py/services/model_service_factory.py diff --git a/py/lora_manager.py b/py/lora_manager.py index ed2ff8cb..dc54d8f4 100644 --- a/py/lora_manager.py +++ b/py/lora_manager.py @@ -6,10 +6,8 @@ from pathlib import Path from server import PromptServer # type: ignore from .config import config -from .routes.lora_routes import LoraRoutes -from .routes.api_routes import ApiRoutes +from .services.model_service_factory import ModelServiceFactory, register_default_model_types from .routes.recipe_routes import RecipeRoutes -from .routes.checkpoints_routes import CheckpointsRoutes from .routes.stats_routes import StatsRoutes from .routes.update_routes import UpdateRoutes from .routes.misc_routes import MiscRoutes @@ -17,6 +15,7 @@ from .routes.example_images_routes import ExampleImagesRoutes from .services.service_registry import ServiceRegistry from .services.settings_manager import settings from .utils.example_images_migration import ExampleImagesMigration +from .services.websocket_manager import ws_manager logger = logging.getLogger(__name__) @@ -28,7 +27,7 @@ class LoraManager: @classmethod def add_routes(cls): - """Initialize and register all routes""" + """Initialize and register all routes using the new refactored architecture""" app = PromptServer.instance.app # Configure aiohttp access logger to be less verbose @@ -110,27 +109,32 @@ class LoraManager: # Add static route for plugin assets app.router.add_static('/loras_static', config.static_path) - # Setup feature routes - lora_routes = LoraRoutes() - checkpoints_routes = CheckpointsRoutes() - stats_routes = StatsRoutes() + # Register default model types with the factory + register_default_model_types() - # Initialize routes - lora_routes.setup_routes(app) - checkpoints_routes.setup_routes(app) - stats_routes.setup_routes(app) # Add statistics routes - ApiRoutes.setup_routes(app) + # Setup all model routes using the factory + ModelServiceFactory.setup_all_routes(app) + + # Setup non-model-specific routes + stats_routes = StatsRoutes() + stats_routes.setup_routes(app) RecipeRoutes.setup_routes(app) UpdateRoutes.setup_routes(app) - MiscRoutes.setup_routes(app) # Register miscellaneous routes - ExampleImagesRoutes.setup_routes(app) # Register example images routes + MiscRoutes.setup_routes(app) + ExampleImagesRoutes.setup_routes(app) + + # Setup WebSocket routes that are shared across all model types + app.router.add_get('/ws/fetch-progress', ws_manager.handle_connection) + app.router.add_get('/ws/download-progress', ws_manager.handle_download_connection) + app.router.add_get('/ws/init-progress', ws_manager.handle_init_connection) # Schedule service initialization app.on_startup.append(lambda app: cls._initialize_services()) # Add cleanup app.on_shutdown.append(cls._cleanup) - app.on_shutdown.append(ApiRoutes.cleanup) + + logger.info(f"LoRA Manager: Set up routes for {len(ModelServiceFactory.get_registered_types())} model types: {', '.join(ModelServiceFactory.get_registered_types())}") @classmethod async def _initialize_services(cls): diff --git a/py/routes/api_routes.py b/py/routes/api_routes.py index f5457935..5c9c93fe 100644 --- a/py/routes/api_routes.py +++ b/py/routes/api_routes.py @@ -1,1184 +1,37 @@ -import os -import json import logging from aiohttp import web -from typing import Dict -from server import PromptServer # type: ignore -from ..utils.routes_common import ModelRouteUtils -from ..utils.utils import get_lora_info - -from ..config import config from ..services.websocket_manager import ws_manager -import asyncio from .update_routes import UpdateRoutes -from ..utils.constants import PREVIEW_EXTENSIONS, CARD_PREVIEW_WIDTH, VALID_LORA_TYPES -from ..utils.exif_utils import ExifUtils -from ..utils.metadata_manager import MetadataManager -from ..services.service_registry import ServiceRegistry +from .lora_routes import LoraRoutes logger = logging.getLogger(__name__) class ApiRoutes: - """API route handlers for LoRA management""" + """Legacy API route handlers for backward compatibility""" def __init__(self): - self.scanner = None # Will be initialized in setup_routes - self.civitai_client = None # Will be initialized in setup_routes - self.download_manager = None # Will be initialized in setup_routes - self._download_lock = asyncio.Lock() - - async def initialize_services(self): - """Initialize services from ServiceRegistry""" - self.scanner = await ServiceRegistry.get_lora_scanner() - self.civitai_client = await ServiceRegistry.get_civitai_client() - self.download_manager = await ServiceRegistry.get_download_manager() + # Initialize the new LoRA routes + self.lora_routes = LoraRoutes() @classmethod def setup_routes(cls, app: web.Application): - """Register API routes""" + """Register API routes using the new refactored architecture""" routes = cls() - # Schedule service initialization on app startup - app.on_startup.append(lambda _: routes.initialize_services()) + # Setup the refactored LoRA routes + routes.lora_routes.setup_routes(app) - app.router.add_post('/api/delete_model', routes.delete_model) - app.router.add_post('/api/loras/exclude', routes.exclude_model) # Add new exclude endpoint - app.router.add_post('/api/fetch-civitai', routes.fetch_civitai) - app.router.add_post('/api/relink-civitai', routes.relink_civitai) # Add new relink endpoint - app.router.add_post('/api/replace_preview', routes.replace_preview) - app.router.add_get('/api/loras', routes.get_loras) - app.router.add_post('/api/fetch-all-civitai', routes.fetch_all_civitai) + # Setup WebSocket routes that are still shared app.router.add_get('/ws/fetch-progress', ws_manager.handle_connection) - app.router.add_get('/ws/download-progress', ws_manager.handle_download_connection) # Add new WebSocket route for download progress - app.router.add_get('/ws/init-progress', ws_manager.handle_init_connection) # Add new WebSocket route - app.router.add_get('/api/lora-roots', routes.get_lora_roots) - app.router.add_get('/api/folders', routes.get_folders) - app.router.add_get('/api/civitai/versions/{model_id}', routes.get_civitai_versions) - app.router.add_get('/api/civitai/model/version/{modelVersionId}', routes.get_civitai_model_by_version) - app.router.add_get('/api/civitai/model/hash/{hash}', routes.get_civitai_model_by_hash) - app.router.add_post('/api/download-model', routes.download_model) - app.router.add_get('/api/download-model-get', routes.download_model_get) # Add new GET endpoint - app.router.add_get('/api/cancel-download-get', routes.cancel_download_get) - app.router.add_get('/api/download-progress/{download_id}', routes.get_download_progress) # Add new endpoint for download progress - app.router.add_post('/api/move_model', routes.move_model) - app.router.add_get('/api/lora-model-description', routes.get_lora_model_description) # Add new route - app.router.add_post('/api/loras/save-metadata', routes.save_metadata) - app.router.add_get('/api/lora-preview-url', routes.get_lora_preview_url) # Add new route - app.router.add_post('/api/move_models_bulk', routes.move_models_bulk) - app.router.add_get('/api/loras/top-tags', routes.get_top_tags) # Add new route for top tags - app.router.add_get('/api/loras/base-models', routes.get_base_models) # Add new route for base models - app.router.add_get('/api/lora-civitai-url', routes.get_lora_civitai_url) # Add new route for Civitai URL - app.router.add_post('/api/loras/rename', routes.rename_lora) # Add new route for renaming LoRA files - app.router.add_get('/api/loras/scan', routes.scan_loras) # Add new route for scanning LoRA files + app.router.add_get('/ws/download-progress', ws_manager.handle_download_connection) + app.router.add_get('/ws/init-progress', ws_manager.handle_init_connection) - # Add the new trigger words route - app.router.add_post('/loramanager/get_trigger_words', routes.get_trigger_words) - - # Add new endpoint for letter counts - app.router.add_get('/api/loras/letter-counts', routes.get_letter_counts) - - # Add new endpoints for copying lora data - app.router.add_get('/api/loras/get-notes', routes.get_lora_notes) - app.router.add_get('/api/loras/get-trigger-words', routes.get_lora_trigger_words) - - # Add update check routes + # Setup update routes that are not model-specific UpdateRoutes.setup_routes(app) - # Add new endpoints for finding duplicates - app.router.add_get('/api/loras/find-duplicates', routes.find_duplicate_loras) - app.router.add_get('/api/loras/find-filename-conflicts', routes.find_filename_conflicts) - - # Add new endpoint for bulk deleting loras - app.router.add_post('/api/loras/bulk-delete', routes.bulk_delete_loras) - - # Add new endpoint for verifying duplicates - app.router.add_post('/api/loras/verify-duplicates', routes.verify_duplicates) - - async def delete_model(self, request: web.Request) -> web.Response: - """Handle model deletion request""" - if self.scanner is None: - self.scanner = await ServiceRegistry.get_lora_scanner() - return await ModelRouteUtils.handle_delete_model(request, self.scanner) - - async def exclude_model(self, request: web.Request) -> web.Response: - """Handle model exclusion request""" - if self.scanner is None: - self.scanner = await ServiceRegistry.get_lora_scanner() - return await ModelRouteUtils.handle_exclude_model(request, self.scanner) - - async def fetch_civitai(self, request: web.Request) -> web.Response: - """Handle CivitAI metadata fetch request""" - if self.scanner is None: - self.scanner = await ServiceRegistry.get_lora_scanner() - - response = await ModelRouteUtils.handle_fetch_civitai(request, self.scanner) - - # If successful, format the metadata before returning - if response.status == 200: - data = json.loads(response.body.decode('utf-8')) - if data.get("success") and data.get("metadata"): - formatted_metadata = self._format_lora_response(data["metadata"]) - return web.json_response({ - "success": True, - "metadata": formatted_metadata - }) - - # Otherwise, return the original response - return response - - async def replace_preview(self, request: web.Request) -> web.Response: - """Handle preview image replacement request""" - if self.scanner is None: - self.scanner = await ServiceRegistry.get_lora_scanner() - return await ModelRouteUtils.handle_replace_preview(request, self.scanner) - - async def scan_loras(self, request: web.Request) -> web.Response: - """Force a rescan of LoRA files""" - try: - # Get full_rebuild parameter from query string, default to false - full_rebuild = request.query.get('full_rebuild', 'false').lower() == 'true' - - await self.scanner.get_cached_data(force_refresh=True, rebuild_cache=full_rebuild) - return web.json_response({"status": "success", "message": "LoRA scan completed"}) - except Exception as e: - logger.error(f"Error in scan_loras: {e}", exc_info=True) - return web.json_response({"error": str(e)}, status=500) - - async def get_loras(self, request: web.Request) -> web.Response: - """Handle paginated LoRA data request""" - try: - if self.scanner is None: - self.scanner = await ServiceRegistry.get_lora_scanner() - - # Parse query parameters - page = int(request.query.get('page', '1')) - page_size = int(request.query.get('page_size', '20')) - sort_by = request.query.get('sort_by', 'name') - folder = request.query.get('folder', None) - search = request.query.get('search', None) - fuzzy_search = request.query.get('fuzzy', 'false').lower() == 'true' - - # Parse search options - search_options = { - 'filename': request.query.get('search_filename', 'true').lower() == 'true', - 'modelname': request.query.get('search_modelname', 'true').lower() == 'true', - 'tags': request.query.get('search_tags', 'false').lower() == 'true', - 'recursive': request.query.get('recursive', 'false').lower() == 'true' - } - - # Get filter parameters - base_models = request.query.get('base_models', None) - tags = request.query.get('tags', None) - favorites_only = request.query.get('favorites_only', 'false').lower() == 'true' # New parameter - - # New parameter for alphabet filtering - first_letter = request.query.get('first_letter', None) - - # New parameters for recipe filtering - lora_hash = request.query.get('lora_hash', None) - lora_hashes = request.query.get('lora_hashes', None) - - # Parse filter parameters - filters = {} - if base_models: - filters['base_model'] = base_models.split(',') - if tags: - filters['tags'] = tags.split(',') - - # Add lora hash filtering options - hash_filters = {} - if lora_hash: - hash_filters['single_hash'] = lora_hash.lower() - elif lora_hashes: - hash_filters['multiple_hashes'] = [h.lower() for h in lora_hashes.split(',')] - - # Get file data - data = await self.scanner.get_paginated_data( - page, - page_size, - sort_by=sort_by, - folder=folder, - search=search, - fuzzy_search=fuzzy_search, - base_models=filters.get('base_model', None), - tags=filters.get('tags', None), - search_options=search_options, - hash_filters=hash_filters, - favorites_only=favorites_only, # Pass favorites_only parameter - first_letter=first_letter # Pass the new first_letter parameter - ) - - # Get all available folders from cache - cache = await self.scanner.get_cached_data() - - # Convert output to match expected format - result = { - 'items': [self._format_lora_response(lora) for lora in data['items']], - 'folders': cache.folders, - 'total': data['total'], - 'page': data['page'], - 'page_size': data['page_size'], - 'total_pages': data['total_pages'] - } - - return web.json_response(result) - - except Exception as e: - logger.error(f"Error retrieving loras: {e}", exc_info=True) - return web.json_response({"error": str(e)}, status=500) - - def _format_lora_response(self, lora: Dict) -> Dict: - """Format LoRA data for API response""" - return { - "model_name": lora["model_name"], - "file_name": lora["file_name"], - "preview_url": config.get_preview_static_url(lora["preview_url"]), - "preview_nsfw_level": lora.get("preview_nsfw_level", 0), - "base_model": lora["base_model"], - "folder": lora["folder"], - "sha256": lora["sha256"], - "file_path": lora["file_path"].replace(os.sep, "/"), - "file_size": lora["size"], - "modified": lora["modified"], - "tags": lora["tags"], - "modelDescription": lora["modelDescription"], - "from_civitai": lora.get("from_civitai", True), - "usage_tips": lora.get("usage_tips", ""), - "notes": lora.get("notes", ""), - "favorite": lora.get("favorite", False), # Include favorite status in response - "civitai": ModelRouteUtils.filter_civitai_data(lora.get("civitai", {})) - } - - async def fetch_all_civitai(self, request: web.Request) -> web.Response: - """Fetch CivitAI metadata for all loras in the background""" - try: - if self.scanner is None: - self.scanner = await ServiceRegistry.get_lora_scanner() - - cache = await self.scanner.get_cached_data() - total = len(cache.raw_data) - processed = 0 - success = 0 - needs_resort = False - - # Prepare loras to process - to_process = [ - lora for lora in cache.raw_data - if lora.get('sha256') and (not lora.get('civitai') or 'id' not in lora.get('civitai')) and lora.get('from_civitai', True) # TODO: for lora not from CivitAI but added traineWords - ] - total_to_process = len(to_process) - - # Send initial progress - await ws_manager.broadcast({ - 'status': 'started', - 'total': total_to_process, - 'processed': 0, - 'success': 0 - }) - - for lora in to_process: - try: - original_name = lora.get('model_name') - if await ModelRouteUtils.fetch_and_update_model( - sha256=lora['sha256'], - file_path=lora['file_path'], - model_data=lora, - update_cache_func=self.scanner.update_single_model_cache - ): - success += 1 - if original_name != lora.get('model_name'): - needs_resort = True - - processed += 1 - - # Send progress update - await ws_manager.broadcast({ - 'status': 'processing', - 'total': total_to_process, - 'processed': processed, - 'success': success, - 'current_name': lora.get('model_name', 'Unknown') - }) - - except Exception as e: - logger.error(f"Error fetching CivitAI data for {lora['file_path']}: {e}") - - if needs_resort: - await cache.resort(name_only=True) - - # Send completion message - await ws_manager.broadcast({ - 'status': 'completed', - 'total': total_to_process, - 'processed': processed, - 'success': success - }) - - return web.json_response({ - "success": True, - "message": f"Successfully updated {success} of {processed} processed loras (total: {total})" - }) - - except Exception as e: - # Send error message - await ws_manager.broadcast({ - 'status': 'error', - 'error': str(e) - }) - logger.error(f"Error in fetch_all_civitai: {e}") - return web.Response(text=str(e), status=500) - - async def get_lora_roots(self, request: web.Request) -> web.Response: - """Get all configured LoRA root directories""" - return web.json_response({ - 'roots': config.loras_roots - }) - - async def get_folders(self, request: web.Request) -> web.Response: - """Get all folders in the cache""" - if self.scanner is None: - self.scanner = await ServiceRegistry.get_lora_scanner() - - cache = await self.scanner.get_cached_data() - return web.json_response({ - 'folders': cache.folders - }) - - async def get_civitai_versions(self, request: web.Request) -> web.Response: - """Get available versions for a Civitai model with local availability info""" - try: - if self.scanner is None: - self.scanner = await ServiceRegistry.get_lora_scanner() - - if self.civitai_client is None: - self.civitai_client = await ServiceRegistry.get_civitai_client() - - model_id = request.match_info['model_id'] - response = await self.civitai_client.get_model_versions(model_id) - if not response or not response.get('modelVersions'): - return web.Response(status=404, text="Model not found") - - versions = response.get('modelVersions', []) - model_type = response.get('type', '') - - # Check model type - should be LORA, LoCon, or DORA - if model_type.lower() not in VALID_LORA_TYPES: - return web.json_response({ - 'error': f"Model type mismatch. Expected LORA or LoCon, got {model_type}" - }, status=400) - - # Check local availability for each version - for version in versions: - # Find the model file (type="Model") in the files list - model_file = next((file for file in version.get('files', []) - if file.get('type') == 'Model'), None) - - if model_file: - sha256 = model_file.get('hashes', {}).get('SHA256') - if sha256: - # Set existsLocally and localPath at the version level - version['existsLocally'] = self.scanner.has_hash(sha256) - if version['existsLocally']: - version['localPath'] = self.scanner.get_path_by_hash(sha256) - - # Also set the model file size at the version level for easier access - version['modelSizeKB'] = model_file.get('sizeKB') - else: - # No model file found in this version - version['existsLocally'] = False - - return web.json_response(versions) - except Exception as e: - logger.error(f"Error fetching model versions: {e}") - return web.Response(status=500, text=str(e)) - - async def get_civitai_model_by_version(self, request: web.Request) -> web.Response: - """Get CivitAI model details by model version ID""" - try: - if self.civitai_client is None: - self.civitai_client = await ServiceRegistry.get_civitai_client() - - model_version_id = request.match_info.get('modelVersionId') - - # Get model details from Civitai API - model, error_msg = await self.civitai_client.get_model_version_info(model_version_id) - - if not model: - # Log warning for failed model retrieval - logger.warning(f"Failed to fetch model version {model_version_id}: {error_msg}") - - # Determine status code based on error message - status_code = 404 if error_msg and "not found" in error_msg.lower() else 500 - - return web.json_response({ - "success": False, - "error": error_msg or "Failed to fetch model information" - }, status=status_code) - - return web.json_response(model) - except Exception as e: - logger.error(f"Error fetching model details: {e}") - return web.json_response({ - "success": False, - "error": str(e) - }, status=500) - - async def get_civitai_model_by_hash(self, request: web.Request) -> web.Response: - """Get CivitAI model details by hash""" - try: - if self.civitai_client is None: - self.civitai_client = await ServiceRegistry.get_civitai_client() - - hash = request.match_info.get('hash') - model = await self.civitai_client.get_model_by_hash(hash) - return web.json_response(model) - except Exception as e: - logger.error(f"Error fetching model details by hash: {e}") - return web.json_response({ - "success": False, - "error": str(e) - }, status=500) - - async def download_model(self, request: web.Request) -> web.Response: - return await ModelRouteUtils.handle_download_model(request, self.download_manager) - - async def download_model_get(self, request: web.Request) -> web.Response: - """Handle model download request via GET method - - Converts GET parameters to POST format and calls the existing download handler - - Args: - request: The aiohttp request with query parameters - - Returns: - web.Response: The HTTP response - """ - try: - # Extract query parameters - model_id = request.query.get('model_id') - if not model_id: - return web.Response( - status=400, - text="Missing required parameter: Please provide 'model_id'" - ) - - # Get optional parameters - model_version_id = request.query.get('model_version_id') - download_id = request.query.get('download_id') - use_default_paths = request.query.get('use_default_paths', 'false').lower() == 'true' - - # Create a data dictionary that mimics what would be received from a POST request - data = { - 'model_id': model_id - } - - # Add optional parameters only if they are provided - if model_version_id: - data['model_version_id'] = model_version_id - - if download_id: - data['download_id'] = download_id - - data['use_default_paths'] = use_default_paths - - # Create a mock request object with the data - # Fix: Create a proper Future object and set its result - future = asyncio.get_event_loop().create_future() - future.set_result(data) - - mock_request = type('MockRequest', (), { - 'json': lambda self=None: future - })() - - # Call the existing download handler - if self.download_manager is None: - self.download_manager = await ServiceRegistry.get_download_manager() - - return await ModelRouteUtils.handle_download_model(mock_request, self.download_manager) - - except Exception as e: - error_message = str(e) - logger.error(f"Error downloading model via GET: {error_message}", exc_info=True) - return web.Response(status=500, text=error_message) - - async def cancel_download_get(self, request: web.Request) -> web.Response: - """Handle GET request for cancelling a download by download_id""" - try: - download_id = request.query.get('download_id') - if not download_id: - return web.json_response({ - 'success': False, - 'error': 'Download ID is required' - }, status=400) - if self.download_manager is None: - self.download_manager = await ServiceRegistry.get_download_manager() - # Create a mock request with match_info for compatibility - mock_request = type('MockRequest', (), { - 'match_info': {'download_id': download_id} - })() - return await ModelRouteUtils.handle_cancel_download(mock_request, self.download_manager) - except Exception as e: - logger.error(f"Error cancelling download via GET: {e}", exc_info=True) - return web.json_response({ - 'success': False, - 'error': str(e) - }, status=500) - - async def get_download_progress(self, request: web.Request) -> web.Response: - """Handle request for download progress by download_id""" - try: - # Get download_id from URL path - download_id = request.match_info.get('download_id') - if not download_id: - return web.json_response({ - 'success': False, - 'error': 'Download ID is required' - }, status=400) - - # Get progress information from websocket manager - progress_data = ws_manager.get_download_progress(download_id) - - if progress_data is None: - return web.json_response({ - 'success': False, - 'error': 'Download ID not found' - }, status=404) - - return web.json_response({ - 'success': True, - 'progress': progress_data.get('progress', 0) - }) - except Exception as e: - logger.error(f"Error getting download progress: {e}", exc_info=True) - return web.json_response({ - 'success': False, - 'error': str(e) - }, status=500) - - async def move_model(self, request: web.Request) -> web.Response: - """Handle model move request""" - try: - if self.scanner is None: - self.scanner = await ServiceRegistry.get_lora_scanner() - - data = await request.json() - file_path = data.get('file_path') # full path of the model file, e.g. /path/to/model.safetensors - target_path = data.get('target_path') # folder path to move the model to, e.g. /path/to/target_folder - - if not file_path or not target_path: - return web.Response(text='File path and target path are required', status=400) - - # Check if source and destination are the same - source_dir = os.path.dirname(file_path) - if os.path.normpath(source_dir) == os.path.normpath(target_path): - logger.info(f"Source and target directories are the same: {source_dir}") - return web.json_response({'success': True, 'message': 'Source and target directories are the same'}) - - # Check if target file already exists - file_name = os.path.basename(file_path) - target_file_path = os.path.join(target_path, file_name).replace(os.sep, '/') - - if os.path.exists(target_file_path): - return web.json_response({ - 'success': False, - 'error': f"Target file already exists: {target_file_path}" - }, status=409) # 409 Conflict - - # Call scanner to handle the move operation - success = await self.scanner.move_model(file_path, target_path) - - if success: - return web.json_response({'success': True}) - else: - return web.Response(text='Failed to move model', status=500) - - except Exception as e: - logger.error(f"Error moving model: {e}", exc_info=True) - return web.Response(text=str(e), status=500) - @classmethod async def cleanup(cls): """Add cleanup method for application shutdown""" - # Now we don't need to store an instance, as services are managed by ServiceRegistry - civitai_client = await ServiceRegistry.get_civitai_client() - if civitai_client: - await civitai_client.close() - - async def save_metadata(self, request: web.Request) -> web.Response: - """Handle saving metadata updates""" - try: - if self.scanner is None: - self.scanner = await ServiceRegistry.get_lora_scanner() - - data = await request.json() - file_path = data.get('file_path') - if not file_path: - return web.Response(text='File path is required', status=400) - - # Remove file path from data to avoid saving it - metadata_updates = {k: v for k, v in data.items() if k != 'file_path'} - - # Get metadata file path - metadata_path = os.path.splitext(file_path)[0] + '.metadata.json' - - # Load existing metadata - metadata = await ModelRouteUtils.load_local_metadata(metadata_path) - - # Handle nested updates (for civitai.trainedWords) - for key, value in metadata_updates.items(): - if isinstance(value, dict) and key in metadata and isinstance(metadata[key], dict): - # Deep update for nested dictionaries - for nested_key, nested_value in value.items(): - metadata[key][nested_key] = nested_value - else: - # Regular update for top-level keys - metadata[key] = value - - # Save updated metadata - await MetadataManager.save_metadata(file_path, metadata) - - # Update cache - await self.scanner.update_single_model_cache(file_path, file_path, metadata) - - # If model_name was updated, resort the cache - if 'model_name' in metadata_updates: - cache = await self.scanner.get_cached_data() - await cache.resort(name_only=True) - - return web.json_response({'success': True}) - - except Exception as e: - logger.error(f"Error saving metadata: {e}", exc_info=True) - return web.Response(text=str(e), status=500) - - async def get_lora_preview_url(self, request: web.Request) -> web.Response: - """Get the static preview URL for a LoRA file""" - try: - if self.scanner is None: - self.scanner = await ServiceRegistry.get_lora_scanner() - - # Get lora file name from query parameters - lora_name = request.query.get('name') - if not lora_name: - return web.Response(text='Lora file name is required', status=400) - - # Get cache data - cache = await self.scanner.get_cached_data() - - # Search for the lora in cache data - for lora in cache.raw_data: - file_name = lora['file_name'] - if file_name == lora_name: - if preview_url := lora.get('preview_url'): - # Convert preview path to static URL - static_url = config.get_preview_static_url(preview_url) - if static_url: - return web.json_response({ - 'success': True, - 'preview_url': static_url - }) - break - - # If no preview URL found - return web.json_response({ - 'success': False, - 'error': 'No preview URL found for the specified lora' - }, status=404) - - except Exception as e: - logger.error(f"Error getting lora preview URL: {e}", exc_info=True) - return web.json_response({ - 'success': False, - 'error': str(e) - }, status=500) - - async def get_lora_civitai_url(self, request: web.Request) -> web.Response: - """Get the Civitai URL for a LoRA file""" - try: - if self.scanner is None: - self.scanner = await ServiceRegistry.get_lora_scanner() - - # Get lora file name from query parameters - lora_name = request.query.get('name') - if not lora_name: - return web.Response(text='Lora file name is required', status=400) - - # Get cache data - cache = await self.scanner.get_cached_data() - - # Search for the lora in cache data - for lora in cache.raw_data: - file_name = lora['file_name'] - if file_name == lora_name: - civitai_data = lora.get('civitai', {}) - model_id = civitai_data.get('modelId') - version_id = civitai_data.get('id') - - if model_id: - civitai_url = f"https://civitai.com/models/{model_id}" - if version_id: - civitai_url += f"?modelVersionId={version_id}" - - return web.json_response({ - 'success': True, - 'civitai_url': civitai_url, - 'model_id': model_id, - 'version_id': version_id - }) - break - - # If no Civitai data found - return web.json_response({ - 'success': False, - 'error': 'No Civitai data found for the specified lora' - }, status=404) - - except Exception as e: - logger.error(f"Error getting lora Civitai URL: {e}", exc_info=True) - return web.json_response({ - 'success': False, - 'error': str(e) - }, status=500) - - async def move_models_bulk(self, request: web.Request) -> web.Response: - """Handle bulk model move request""" - try: - if self.scanner is None: - self.scanner = await ServiceRegistry.get_lora_scanner() - - data = await request.json() - file_paths = data.get('file_paths', []) # list of full paths of the model files, e.g. ["/path/to/model1.safetensors", "/path/to/model2.safetensors"] - target_path = data.get('target_path') # folder path to move the models to, e.g. "/path/to/target_folder" - - if not file_paths or not target_path: - return web.Response(text='File paths and target path are required', status=400) - - results = [] - for file_path in file_paths: - # Check if source and destination are the same - source_dir = os.path.dirname(file_path) - if os.path.normpath(source_dir) == os.path.normpath(target_path): - results.append({ - "path": file_path, - "success": True, - "message": "Source and target directories are the same" - }) - continue - - # Check if target file already exists - file_name = os.path.basename(file_path) - target_file_path = os.path.join(target_path, file_name).replace(os.sep, '/') - - if os.path.exists(target_file_path): - results.append({ - "path": file_path, - "success": False, - "message": f"Target file already exists: {target_file_path}" - }) - continue - - # Try to move the model - success = await self.scanner.move_model(file_path, target_path) - results.append({ - "path": file_path, - "success": success, - "message": "Success" if success else "Failed to move model" - }) - - # Count successes and failures - success_count = sum(1 for r in results if r["success"]) - failure_count = len(results) - success_count - - return web.json_response({ - 'success': True, - 'message': f'Moved {success_count} of {len(file_paths)} models', - 'results': results, - 'success_count': success_count, - 'failure_count': failure_count - }) - - except Exception as e: - logger.error(f"Error moving models in bulk: {e}", exc_info=True) - return web.Response(text=str(e), status=500) - - async def get_lora_model_description(self, request: web.Request) -> web.Response: - """Get model description for a Lora model""" - try: - if self.civitai_client is None: - self.civitai_client = await ServiceRegistry.get_civitai_client() - - # Get parameters - model_id = request.query.get('model_id') - file_path = request.query.get('file_path') - - if not model_id: - return web.json_response({ - 'success': False, - 'error': 'Model ID is required' - }, status=400) - - # Check if we already have the description stored in metadata - description = None - tags = [] - creator = {} - if file_path: - metadata_path = os.path.splitext(file_path)[0] + '.metadata.json' - metadata = await ModelRouteUtils.load_local_metadata(metadata_path) - description = metadata.get('modelDescription') - tags = metadata.get('tags', []) - creator = metadata.get('creator', {}) - - # If description is not in metadata, fetch from CivitAI - if not description: - logger.info(f"Fetching model metadata for model ID: {model_id}") - model_metadata, _ = await self.civitai_client.get_model_metadata(model_id) - - if (model_metadata): - description = model_metadata.get('description') - tags = model_metadata.get('tags', []) - creator = model_metadata.get('creator', {}) - - # Save the metadata to file if we have a file path and got metadata - if file_path: - try: - metadata_path = os.path.splitext(file_path)[0] + '.metadata.json' - metadata = await ModelRouteUtils.load_local_metadata(metadata_path) - - metadata['modelDescription'] = description - metadata['tags'] = tags - # Ensure the civitai dict exists - if 'civitai' not in metadata: - metadata['civitai'] = {} - # Store creator in the civitai nested structure - metadata['civitai']['creator'] = creator - - await MetadataManager.save_metadata(file_path, metadata, True) - except Exception as e: - logger.error(f"Error saving model metadata: {e}") - - return web.json_response({ - 'success': True, - 'description': description or "

No model description available.

", - 'tags': tags, - 'creator': creator - }) - - except Exception as e: - logger.error(f"Error getting model metadata: {e}") - return web.json_response({ - 'success': False, - 'error': str(e) - }, status=500) - - async def get_top_tags(self, request: web.Request) -> web.Response: - """Handle request for top tags sorted by frequency""" - try: - if self.scanner is None: - self.scanner = await ServiceRegistry.get_lora_scanner() - - # Parse query parameters - limit = int(request.query.get('limit', '20')) - - # Validate limit - if limit < 1 or limit > 100: - limit = 20 # Default to a reasonable limit - - # Get top tags - top_tags = await self.scanner.get_top_tags(limit) - - return web.json_response({ - 'success': True, - 'tags': top_tags - }) - - except Exception as e: - logger.error(f"Error getting top tags: {str(e)}", exc_info=True) - return web.json_response({ - 'success': False, - 'error': 'Internal server error' - }, status=500) - - async def get_base_models(self, request: web.Request) -> web.Response: - """Get base models used in loras""" - try: - if self.scanner is None: - self.scanner = await ServiceRegistry.get_lora_scanner() - - # Parse query parameters - limit = int(request.query.get('limit', '20')) - - # Validate limit - if limit < 1 or limit > 100: - limit = 20 # Default to a reasonable limit - - # Get base models - base_models = await self.scanner.get_base_models(limit) - - return web.json_response({ - 'success': True, - 'base_models': base_models - }) - except Exception as e: - logger.error(f"Error retrieving base models: {e}") - return web.json_response({ - 'success': False, - 'error': str(e) - }, status=500) - - async def rename_lora(self, request: web.Request) -> web.Response: - """Handle renaming a LoRA file and its associated files""" - if self.scanner is None: - self.scanner = await ServiceRegistry.get_lora_scanner() - - return await ModelRouteUtils.handle_rename_model(request, self.scanner) - - async def get_trigger_words(self, request: web.Request) -> web.Response: - """Get trigger words for specified LoRA models""" - try: - json_data = await request.json() - lora_names = json_data.get("lora_names", []) - node_ids = json_data.get("node_ids", []) - - all_trigger_words = [] - for lora_name in lora_names: - _, trigger_words = get_lora_info(lora_name) - all_trigger_words.extend(trigger_words) - - # Format the trigger words - trigger_words_text = ",, ".join(all_trigger_words) if all_trigger_words else "" - - # Send update to all connected trigger word toggle nodes - for node_id in node_ids: - PromptServer.instance.send_sync("trigger_word_update", { - "id": node_id, - "message": trigger_words_text - }) - - return web.json_response({"success": True}) - - except Exception as e: - logger.error(f"Error getting trigger words: {e}") - return web.json_response({ - "success": False, - "error": str(e) - }, status=500) - - async def get_letter_counts(self, request: web.Request) -> web.Response: - """Get count of loras for each letter of the alphabet""" - try: - if self.scanner is None: - self.scanner = await ServiceRegistry.get_lora_scanner() - - # Get letter counts - letter_counts = await self.scanner.get_letter_counts() - - return web.json_response({ - 'success': True, - 'letter_counts': letter_counts - }) - except Exception as e: - logger.error(f"Error getting letter counts: {e}") - return web.json_response({ - 'success': False, - 'error': str(e) - }, status=500) - - async def get_lora_notes(self, request: web.Request) -> web.Response: - """Get notes for a specific LoRA file""" - try: - if self.scanner is None: - self.scanner = await ServiceRegistry.get_lora_scanner() - - # Get lora file name from query parameters - lora_name = request.query.get('name') - if not lora_name: - return web.Response(text='Lora file name is required', status=400) - - # Get cache data - cache = await self.scanner.get_cached_data() - - # Search for the lora in cache data - for lora in cache.raw_data: - file_name = lora['file_name'] - if file_name == lora_name: - notes = lora.get('notes', '') - - return web.json_response({ - 'success': True, - 'notes': notes - }) - - # If lora not found - return web.json_response({ - 'success': False, - 'error': 'LoRA not found in cache' - }, status=404) - - except Exception as e: - logger.error(f"Error getting lora notes: {e}", exc_info=True) - return web.json_response({ - 'success': False, - 'error': str(e) - }, status=500) - - async def get_lora_trigger_words(self, request: web.Request) -> web.Response: - """Get trigger words for a specific LoRA file""" - try: - if self.scanner is None: - self.scanner = await ServiceRegistry.get_lora_scanner() - - # Get lora file name from query parameters - lora_name = request.query.get('name') - if not lora_name: - return web.Response(text='Lora file name is required', status=400) - - # Get cache data - cache = await self.scanner.get_cached_data() - - # Search for the lora in cache data - for lora in cache.raw_data: - file_name = lora['file_name'] - if file_name == lora_name: - # Get trigger words from civitai data - civitai_data = lora.get('civitai', {}) - trigger_words = civitai_data.get('trainedWords', []) - - return web.json_response({ - 'success': True, - 'trigger_words': trigger_words - }) - - # If lora not found - return web.json_response({ - 'success': False, - 'error': 'LoRA not found in cache' - }, status=404) - - except Exception as e: - logger.error(f"Error getting lora trigger words: {e}", exc_info=True) - return web.json_response({ - 'success': False, - 'error': str(e) - }, status=500) - - async def find_duplicate_loras(self, request: web.Request) -> web.Response: - """Find loras with duplicate SHA256 hashes""" - try: - if self.scanner is None: - self.scanner = await ServiceRegistry.get_lora_scanner() - - # Get duplicate hashes from hash index - duplicates = self.scanner._hash_index.get_duplicate_hashes() - - # Format the response - result = [] - cache = await self.scanner.get_cached_data() - - for sha256, paths in duplicates.items(): - group = { - "hash": sha256, - "models": [] - } - # Find matching models for each duplicate path - for path in paths: - model = next((m for m in cache.raw_data if m['file_path'] == path), None) - if model: - group["models"].append(self._format_lora_response(model)) - - # Add the primary model too - primary_path = self.scanner._hash_index.get_path(sha256) - if primary_path and primary_path not in paths: - primary_model = next((m for m in cache.raw_data if m['file_path'] == primary_path), None) - if primary_model: - group["models"].insert(0, self._format_lora_response(primary_model)) - - if len(group["models"]) > 1: # Only include if we found multiple models - result.append(group) - - return web.json_response({ - "success": True, - "duplicates": result, - "count": len(result) - }) - except Exception as e: - logger.error(f"Error finding duplicate loras: {e}", exc_info=True) - return web.json_response({ - "success": False, - "error": str(e) - }, status=500) - - async def find_filename_conflicts(self, request: web.Request) -> web.Response: - """Find loras with conflicting filenames""" - try: - if self.scanner is None: - self.scanner = await ServiceRegistry.get_lora_scanner() - - # Get duplicate filenames from hash index - duplicates = self.scanner._hash_index.get_duplicate_filenames() - - # Format the response - result = [] - cache = await self.scanner.get_cached_data() - - for filename, paths in duplicates.items(): - group = { - "filename": filename, - "models": [] - } - # Find matching models for each path - for path in paths: - model = next((m for m in cache.raw_data if m['file_path'] == path), None) - if model: - group["models"].append(self._format_lora_response(model)) - - # Find the model from the main index too - hash_val = self.scanner._hash_index.get_hash_by_filename(filename) - if hash_val: - main_path = self.scanner._hash_index.get_path(hash_val) - if main_path and main_path not in paths: - main_model = next((m for m in cache.raw_data if m['file_path'] == main_path), None) - if main_model: - group["models"].insert(0, self._format_lora_response(main_model)) - - if group["models"]: # Only include if we found models - result.append(group) - - return web.json_response({ - "success": True, - "conflicts": result, - "count": len(result) - }) - except Exception as e: - logger.error(f"Error finding filename conflicts: {e}", exc_info=True) - return web.json_response({ - "success": False, - "error": str(e) - }, status=500) - - async def bulk_delete_loras(self, request: web.Request) -> web.Response: - """Handle bulk deletion of lora models""" - try: - if self.scanner is None: - self.scanner = await ServiceRegistry.get_lora_scanner() - - return await ModelRouteUtils.handle_bulk_delete_models(request, self.scanner) - - except Exception as e: - logger.error(f"Error in bulk delete loras: {e}", exc_info=True) - return web.json_response({ - 'success': False, - 'error': str(e) - }, status=500) - - async def relink_civitai(self, request: web.Request) -> web.Response: - """Handle CivitAI metadata re-linking request by model version ID for LoRAs""" - if self.scanner is None: - self.scanner = await ServiceRegistry.get_lora_scanner() - return await ModelRouteUtils.handle_relink_civitai(request, self.scanner) - - async def verify_duplicates(self, request: web.Request) -> web.Response: - """Handle verification of duplicate lora hashes""" - if self.scanner is None: - self.scanner = await ServiceRegistry.get_lora_scanner() - return await ModelRouteUtils.handle_verify_duplicates(request, self.scanner) + # Cleanup is now handled by ServiceRegistry and individual services + pass diff --git a/py/routes/base_model_routes.py b/py/routes/base_model_routes.py new file mode 100644 index 00000000..7a6e4aac --- /dev/null +++ b/py/routes/base_model_routes.py @@ -0,0 +1,431 @@ +from abc import ABC, abstractmethod +import json +import logging +from aiohttp import web +from typing import Dict + +from ..utils.routes_common import ModelRouteUtils +from ..services.websocket_manager import ws_manager + +logger = logging.getLogger(__name__) + +class BaseModelRoutes(ABC): + """Base route controller for all model types""" + + def __init__(self, service): + """Initialize the route controller + + Args: + service: Model service instance (LoraService, CheckpointService, etc.) + """ + self.service = service + self.model_type = service.model_type + + def setup_routes(self, app: web.Application, prefix: str): + """Setup common routes for the model type + + Args: + app: aiohttp application + prefix: URL prefix (e.g., 'loras', 'checkpoints') + """ + # Common model management routes + app.router.add_get(f'/api/{prefix}', self.get_models) + app.router.add_post(f'/api/{prefix}/delete', self.delete_model) + app.router.add_post(f'/api/{prefix}/exclude', self.exclude_model) + app.router.add_post(f'/api/{prefix}/fetch-civitai', self.fetch_civitai) + app.router.add_post(f'/api/{prefix}/relink-civitai', self.relink_civitai) + app.router.add_post(f'/api/{prefix}/replace-preview', self.replace_preview) + app.router.add_post(f'/api/{prefix}/save-metadata', self.save_metadata) + app.router.add_post(f'/api/{prefix}/rename', self.rename_model) + app.router.add_post(f'/api/{prefix}/bulk-delete', self.bulk_delete_models) + app.router.add_post(f'/api/{prefix}/verify-duplicates', self.verify_duplicates) + + # Common query routes + app.router.add_get(f'/api/{prefix}/top-tags', self.get_top_tags) + app.router.add_get(f'/api/{prefix}/base-models', self.get_base_models) + app.router.add_get(f'/api/{prefix}/scan', self.scan_models) + app.router.add_get(f'/api/{prefix}/roots', self.get_model_roots) + app.router.add_get(f'/api/{prefix}/find-duplicates', self.find_duplicate_models) + app.router.add_get(f'/api/{prefix}/find-filename-conflicts', self.find_filename_conflicts) + + # CivitAI integration routes + app.router.add_post(f'/api/{prefix}/fetch-all-civitai', self.fetch_all_civitai) + app.router.add_get(f'/api/civitai/versions/{{model_id}}', self.get_civitai_versions) + + # Setup model-specific routes + self.setup_specific_routes(app, prefix) + + @abstractmethod + def setup_specific_routes(self, app: web.Application, prefix: str): + """Setup model-specific routes - to be implemented by subclasses""" + pass + + async def get_models(self, request: web.Request) -> web.Response: + """Get paginated model data""" + try: + # Parse common query parameters + params = self._parse_common_params(request) + + # Get data from service + result = await self.service.get_paginated_data(**params) + + # Format response items + formatted_result = { + 'items': [await self.service.format_response(item) for item in result['items']], + 'total': result['total'], + 'page': result['page'], + 'page_size': result['page_size'], + 'total_pages': result['total_pages'] + } + + return web.json_response(formatted_result) + + except Exception as e: + logger.error(f"Error in get_{self.model_type}s: {e}", exc_info=True) + return web.json_response({"error": str(e)}, status=500) + + def _parse_common_params(self, request: web.Request) -> Dict: + """Parse common query parameters""" + # Parse basic pagination and sorting + page = int(request.query.get('page', '1')) + page_size = min(int(request.query.get('page_size', '20')), 100) + sort_by = request.query.get('sort_by', 'name') + folder = request.query.get('folder', None) + search = request.query.get('search', None) + fuzzy_search = request.query.get('fuzzy_search', 'false').lower() == 'true' + + # Parse filter arrays + base_models = request.query.getall('base_model', []) + tags = request.query.getall('tag', []) + favorites_only = request.query.get('favorites_only', 'false').lower() == 'true' + + # Parse search options + search_options = { + 'filename': request.query.get('search_filename', 'true').lower() == 'true', + 'modelname': request.query.get('search_modelname', 'true').lower() == 'true', + 'tags': request.query.get('search_tags', 'false').lower() == 'true', + 'recursive': request.query.get('recursive', 'false').lower() == 'true', + } + + # Parse hash filters if provided + hash_filters = {} + if 'hash' in request.query: + hash_filters['single_hash'] = request.query['hash'] + elif 'hashes' in request.query: + try: + hash_list = json.loads(request.query['hashes']) + if isinstance(hash_list, list): + hash_filters['multiple_hashes'] = hash_list + except (json.JSONDecodeError, TypeError): + pass + + return { + 'page': page, + 'page_size': page_size, + 'sort_by': sort_by, + 'folder': folder, + 'search': search, + 'fuzzy_search': fuzzy_search, + 'base_models': base_models, + 'tags': tags, + 'search_options': search_options, + 'hash_filters': hash_filters, + 'favorites_only': favorites_only, + # Add model-specific parameters + **self._parse_specific_params(request) + } + + def _parse_specific_params(self, request: web.Request) -> Dict: + """Parse model-specific parameters - to be overridden by subclasses""" + return {} + + # Common route handlers + async def delete_model(self, request: web.Request) -> web.Response: + """Handle model deletion request""" + return await ModelRouteUtils.handle_delete_model(request, self.service.scanner) + + async def exclude_model(self, request: web.Request) -> web.Response: + """Handle model exclusion request""" + return await ModelRouteUtils.handle_exclude_model(request, self.service.scanner) + + async def fetch_civitai(self, request: web.Request) -> web.Response: + """Handle CivitAI metadata fetch request""" + response = await ModelRouteUtils.handle_fetch_civitai(request, self.service.scanner) + + # If successful, format the metadata before returning + if response.status == 200: + data = json.loads(response.body.decode('utf-8')) + if data.get("success") and data.get("metadata"): + formatted_metadata = await self.service.format_response(data["metadata"]) + return web.json_response({ + "success": True, + "metadata": formatted_metadata + }) + + return response + + async def relink_civitai(self, request: web.Request) -> web.Response: + """Handle CivitAI metadata re-linking request""" + return await ModelRouteUtils.handle_relink_civitai(request, self.service.scanner) + + async def replace_preview(self, request: web.Request) -> web.Response: + """Handle preview image replacement""" + return await ModelRouteUtils.handle_replace_preview(request, self.service.scanner) + + async def save_metadata(self, request: web.Request) -> web.Response: + """Handle saving metadata updates""" + return await ModelRouteUtils.handle_save_metadata(request, self.service.scanner) + + async def rename_model(self, request: web.Request) -> web.Response: + """Handle renaming a model file and its associated files""" + return await ModelRouteUtils.handle_rename_model(request, self.service.scanner) + + async def bulk_delete_models(self, request: web.Request) -> web.Response: + """Handle bulk deletion of models""" + return await ModelRouteUtils.handle_bulk_delete_models(request, self.service.scanner) + + async def verify_duplicates(self, request: web.Request) -> web.Response: + """Handle verification of duplicate model hashes""" + return await ModelRouteUtils.handle_verify_duplicates(request, self.service.scanner) + + async def get_top_tags(self, request: web.Request) -> web.Response: + """Handle request for top tags sorted by frequency""" + try: + limit = int(request.query.get('limit', '20')) + if limit < 1 or limit > 100: + limit = 20 + + top_tags = await self.service.get_top_tags(limit) + + return web.json_response({ + 'success': True, + 'tags': top_tags + }) + + except Exception as e: + logger.error(f"Error getting top tags: {str(e)}", exc_info=True) + return web.json_response({ + 'success': False, + 'error': 'Internal server error' + }, status=500) + + async def get_base_models(self, request: web.Request) -> web.Response: + """Get base models used in models""" + try: + limit = int(request.query.get('limit', '20')) + if limit < 1 or limit > 100: + limit = 20 + + base_models = await self.service.get_base_models(limit) + + return web.json_response({ + 'success': True, + 'base_models': base_models + }) + except Exception as e: + logger.error(f"Error retrieving base models: {e}") + return web.json_response({ + 'success': False, + 'error': str(e) + }, status=500) + + async def scan_models(self, request: web.Request) -> web.Response: + """Force a rescan of model files""" + try: + full_rebuild = request.query.get('full_rebuild', 'false').lower() == 'true' + + await self.service.scan_models(force_refresh=True, rebuild_cache=full_rebuild) + return web.json_response({ + "status": "success", + "message": f"{self.model_type.capitalize()} scan completed" + }) + except Exception as e: + logger.error(f"Error in scan_{self.model_type}s: {e}", exc_info=True) + return web.json_response({"error": str(e)}, status=500) + + async def get_model_roots(self, request: web.Request) -> web.Response: + """Return the model root directories""" + try: + roots = self.service.get_model_roots() + return web.json_response({ + "success": True, + "roots": roots + }) + except Exception as e: + logger.error(f"Error getting {self.model_type} roots: {e}", exc_info=True) + return web.json_response({ + "success": False, + "error": str(e) + }, status=500) + + async def find_duplicate_models(self, request: web.Request) -> web.Response: + """Find models with duplicate SHA256 hashes""" + try: + # Get duplicate hashes from service + duplicates = self.service.find_duplicate_hashes() + + # Format the response + result = [] + cache = await self.service.scanner.get_cached_data() + + for sha256, paths in duplicates.items(): + group = { + "hash": sha256, + "models": [] + } + # Find matching models for each path + for path in paths: + model = next((m for m in cache.raw_data if m['file_path'] == path), None) + if model: + group["models"].append(await self.service.format_response(model)) + + # Add the primary model too + primary_path = self.service.get_path_by_hash(sha256) + if primary_path and primary_path not in paths: + primary_model = next((m for m in cache.raw_data if m['file_path'] == primary_path), None) + if primary_model: + group["models"].insert(0, await self.service.format_response(primary_model)) + + if len(group["models"]) > 1: # Only include if we found multiple models + result.append(group) + + return web.json_response({ + "success": True, + "duplicates": result, + "count": len(result) + }) + except Exception as e: + logger.error(f"Error finding duplicate {self.model_type}s: {e}", exc_info=True) + return web.json_response({ + "success": False, + "error": str(e) + }, status=500) + + async def find_filename_conflicts(self, request: web.Request) -> web.Response: + """Find models with conflicting filenames""" + try: + # Get duplicate filenames from service + duplicates = self.service.find_duplicate_filenames() + + # Format the response + result = [] + cache = await self.service.scanner.get_cached_data() + + for filename, paths in duplicates.items(): + group = { + "filename": filename, + "models": [] + } + # Find matching models for each path + for path in paths: + model = next((m for m in cache.raw_data if m['file_path'] == path), None) + if model: + group["models"].append(await self.service.format_response(model)) + + # Find the model from the main index too + hash_val = self.service.scanner._hash_index.get_hash_by_filename(filename) + if hash_val: + main_path = self.service.get_path_by_hash(hash_val) + if main_path and main_path not in paths: + main_model = next((m for m in cache.raw_data if m['file_path'] == main_path), None) + if main_model: + group["models"].insert(0, await self.service.format_response(main_model)) + + if group["models"]: + result.append(group) + + return web.json_response({ + "success": True, + "conflicts": result, + "count": len(result) + }) + except Exception as e: + logger.error(f"Error finding filename conflicts for {self.model_type}s: {e}", exc_info=True) + return web.json_response({ + "success": False, + "error": str(e) + }, status=500) + + async def fetch_all_civitai(self, request: web.Request) -> web.Response: + """Fetch CivitAI metadata for all models in the background""" + try: + cache = await self.service.scanner.get_cached_data() + total = len(cache.raw_data) + processed = 0 + success = 0 + needs_resort = False + + # Prepare models to process + to_process = [ + model for model in cache.raw_data + if model.get('sha256') and (not model.get('civitai') or 'id' not in model.get('civitai')) and model.get('from_civitai', True) + ] + total_to_process = len(to_process) + + # Send initial progress + await ws_manager.broadcast({ + 'status': 'started', + 'total': total_to_process, + 'processed': 0, + 'success': 0 + }) + + # Process each model + for model in to_process: + try: + original_name = model.get('model_name') + if await ModelRouteUtils.fetch_and_update_model( + sha256=model['sha256'], + file_path=model['file_path'], + model_data=model, + update_cache_func=self.service.scanner.update_single_model_cache + ): + success += 1 + if original_name != model.get('model_name'): + needs_resort = True + + processed += 1 + + # Send progress update + await ws_manager.broadcast({ + 'status': 'processing', + 'total': total_to_process, + 'processed': processed, + 'success': success, + 'current_name': model.get('model_name', 'Unknown') + }) + + except Exception as e: + logger.error(f"Error fetching CivitAI data for {model['file_path']}: {e}") + + if needs_resort: + await cache.resort(name_only=True) + + # Send completion message + await ws_manager.broadcast({ + 'status': 'completed', + 'total': total_to_process, + 'processed': processed, + 'success': success + }) + + return web.json_response({ + "success": True, + "message": f"Successfully updated {success} of {processed} processed {self.model_type}s (total: {total})" + }) + + except Exception as e: + # Send error message + await ws_manager.broadcast({ + 'status': 'error', + 'error': str(e) + }) + logger.error(f"Error in fetch_all_civitai for {self.model_type}s: {e}") + return web.Response(text=str(e), status=500) + + async def get_civitai_versions(self, request: web.Request) -> web.Response: + """Get available versions for a Civitai model with local availability info""" + # This will be implemented by subclasses as they need CivitAI client access + return web.json_response({ + "error": "Not implemented in base class" + }, status=501) \ No newline at end of file diff --git a/py/routes/checkpoint_routes.py b/py/routes/checkpoint_routes.py new file mode 100644 index 00000000..5fdb670e --- /dev/null +++ b/py/routes/checkpoint_routes.py @@ -0,0 +1,170 @@ +import jinja2 +import logging +from aiohttp import web + +from .base_model_routes import BaseModelRoutes +from ..services.checkpoint_service import CheckpointService +from ..services.service_registry import ServiceRegistry +from ..config import config +from ..services.settings_manager import settings + +logger = logging.getLogger(__name__) + +class CheckpointRoutes(BaseModelRoutes): + """Checkpoint-specific route controller""" + + def __init__(self): + """Initialize Checkpoint routes with Checkpoint service""" + # Service will be initialized later via setup_routes + self.service = None + self.civitai_client = None + self.template_env = jinja2.Environment( + loader=jinja2.FileSystemLoader(config.templates_path), + autoescape=True + ) + + async def initialize_services(self): + """Initialize services from ServiceRegistry""" + checkpoint_scanner = await ServiceRegistry.get_checkpoint_scanner() + self.service = CheckpointService(checkpoint_scanner) + self.civitai_client = await ServiceRegistry.get_civitai_client() + + # Initialize parent with the service + super().__init__(self.service) + + def setup_routes(self, app: web.Application): + """Setup Checkpoint routes""" + # Schedule service initialization on app startup + app.on_startup.append(lambda _: self.initialize_services()) + + # Setup common routes with 'checkpoints' prefix + super().setup_routes(app, 'checkpoints') + + def setup_specific_routes(self, app: web.Application, prefix: str): + """Setup Checkpoint-specific routes""" + # Checkpoint page route + app.router.add_get('/checkpoints', self.handle_checkpoints_page) + + # Checkpoint-specific CivitAI integration + app.router.add_get(f'/api/civitai/versions/{{model_id}}', self.get_civitai_versions_checkpoint) + + # Checkpoint info by name + app.router.add_get(f'/api/{prefix}/info/{{name}}', self.get_checkpoint_info) + + async def handle_checkpoints_page(self, request: web.Request) -> web.Response: + """Handle GET /checkpoints request""" + try: + # Check if the CheckpointScanner is initializing + # It's initializing if the cache object doesn't exist yet, + # OR if the scanner explicitly says it's initializing (background task running). + is_initializing = ( + self.service.scanner._cache is None or + (hasattr(self.service.scanner, '_is_initializing') and self.service.scanner._is_initializing) + ) + + if is_initializing: + # If still initializing, return loading page + template = self.template_env.get_template('checkpoints.html') + rendered = template.render( + folders=[], # Empty folder list + is_initializing=True, # New flag + settings=settings, # Pass settings to template + request=request # Pass the request object to the template + ) + + logger.info("Checkpoints page is initializing, returning loading page") + else: + # Normal flow - get initialized cache data + try: + cache = await self.service.scanner.get_cached_data(force_refresh=False) + template = self.template_env.get_template('checkpoints.html') + rendered = template.render( + folders=cache.folders, + is_initializing=False, + settings=settings, # Pass settings to template + request=request # Pass the request object to the template + ) + except Exception as cache_error: + logger.error(f"Error loading checkpoints cache data: {cache_error}") + # If getting cache fails, also show initialization page + template = self.template_env.get_template('checkpoints.html') + rendered = template.render( + folders=[], + is_initializing=True, + settings=settings, + request=request + ) + logger.info("Checkpoints cache error, returning initialization page") + + return web.Response( + text=rendered, + content_type='text/html' + ) + except Exception as e: + logger.error(f"Error handling checkpoints request: {e}", exc_info=True) + return web.Response( + text="Error loading checkpoints page", + status=500 + ) + + async def get_checkpoint_info(self, request: web.Request) -> web.Response: + """Get detailed information for a specific checkpoint by name""" + try: + name = request.match_info.get('name', '') + checkpoint_info = await self.service.get_model_info_by_name(name) + + if checkpoint_info: + return web.json_response(checkpoint_info) + else: + return web.json_response({"error": "Checkpoint not found"}, status=404) + + except Exception as e: + logger.error(f"Error in get_checkpoint_info: {e}", exc_info=True) + return web.json_response({"error": str(e)}, status=500) + + async def get_civitai_versions_checkpoint(self, request: web.Request) -> web.Response: + """Get available versions for a Civitai checkpoint model with local availability info""" + try: + model_id = request.match_info['model_id'] + response = await self.civitai_client.get_model_versions(model_id) + if not response or not response.get('modelVersions'): + return web.Response(status=404, text="Model not found") + + versions = response.get('modelVersions', []) + model_type = response.get('type', '') + + # Check model type - should be Checkpoint + if model_type.lower() != 'checkpoint': + return web.json_response({ + 'error': f"Model type mismatch. Expected Checkpoint, got {model_type}" + }, status=400) + + # Check local availability for each version + for version in versions: + # Find the primary model file (type="Model" and primary=true) in the files list + model_file = next((file for file in version.get('files', []) + if file.get('type') == 'Model' and file.get('primary') == True), None) + + # If no primary file found, try to find any model file + if not model_file: + model_file = next((file for file in version.get('files', []) + if file.get('type') == 'Model'), None) + + if model_file: + sha256 = model_file.get('hashes', {}).get('SHA256') + if sha256: + # Set existsLocally and localPath at the version level + version['existsLocally'] = self.service.has_hash(sha256) + if version['existsLocally']: + version['localPath'] = self.service.get_path_by_hash(sha256) + + # Also set the model file size at the version level for easier access + version['modelSizeKB'] = model_file.get('sizeKB') + else: + # No model file found in this version + version['existsLocally'] = False + + return web.json_response(versions) + except Exception as e: + logger.error(f"Error fetching checkpoint model versions: {e}") + return web.Response(status=500, text=str(e)) \ No newline at end of file diff --git a/py/routes/lora_routes.py b/py/routes/lora_routes.py index 1d00b66e..5b0674ba 100644 --- a/py/routes/lora_routes.py +++ b/py/routes/lora_routes.py @@ -1,76 +1,126 @@ +import jinja2 +import asyncio +import logging import os from aiohttp import web -import jinja2 from typing import Dict -import logging -from ..config import config +from server import PromptServer # type: ignore + +from .base_model_routes import BaseModelRoutes +from ..services.lora_service import LoraService +from ..services.service_registry import ServiceRegistry from ..services.settings_manager import settings -from ..services.service_registry import ServiceRegistry # Add ServiceRegistry import +from ..config import config +from ..utils.routes_common import ModelRouteUtils +from ..utils.utils import get_lora_info logger = logging.getLogger(__name__) -logging.getLogger('asyncio').setLevel(logging.CRITICAL) -class LoraRoutes: - """Route handlers for LoRA management endpoints""" +class LoraRoutes(BaseModelRoutes): + """LoRA-specific route controller""" def __init__(self): - # Initialize service references as None, will be set during async init - self.scanner = None - self.recipe_scanner = None + """Initialize LoRA routes with LoRA service""" + # Service will be initialized later via setup_routes + self.service = None + self.civitai_client = None + self.download_manager = None + self._download_lock = asyncio.Lock() self.template_env = jinja2.Environment( loader=jinja2.FileSystemLoader(config.templates_path), autoescape=True ) - - async def init_services(self): - """Initialize services from ServiceRegistry""" - self.scanner = await ServiceRegistry.get_lora_scanner() - self.recipe_scanner = await ServiceRegistry.get_recipe_scanner() - def format_lora_data(self, lora: Dict) -> Dict: - """Format LoRA data for template rendering""" - return { - "model_name": lora["model_name"], - "file_name": lora["file_name"], - "preview_url": config.get_preview_static_url(lora["preview_url"]), - "preview_nsfw_level": lora.get("preview_nsfw_level", 0), - "base_model": lora["base_model"], - "folder": lora["folder"], - "sha256": lora["sha256"], - "file_path": lora["file_path"].replace(os.sep, "/"), - "size": lora["size"], - "tags": lora["tags"], - "modelDescription": lora["modelDescription"], - "usage_tips": lora["usage_tips"], - "notes": lora["notes"], - "modified": lora["modified"], - "from_civitai": lora.get("from_civitai", True), - "civitai": self._filter_civitai_data(lora.get("civitai", {})) - } - - def _filter_civitai_data(self, data: Dict) -> Dict: - """Filter relevant fields from CivitAI data""" - if not data: - return {} - - fields = [ - "id", "modelId", "name", "createdAt", "updatedAt", - "publishedAt", "trainedWords", "baseModel", "description", - "model", "images" - ] - return {k: data[k] for k in fields if k in data} + async def initialize_services(self): + """Initialize services from ServiceRegistry""" + lora_scanner = await ServiceRegistry.get_lora_scanner() + self.service = LoraService(lora_scanner) + self.civitai_client = await ServiceRegistry.get_civitai_client() + self.download_manager = await ServiceRegistry.get_download_manager() + + # Initialize parent with the service + super().__init__(self.service) + + def setup_routes(self, app: web.Application): + """Setup LoRA routes""" + # Schedule service initialization on app startup + app.on_startup.append(lambda _: self.initialize_services()) + + # Setup common routes with 'loras' prefix + super().setup_routes(app, 'loras') + + def setup_specific_routes(self, app: web.Application, prefix: str): + """Setup LoRA-specific routes""" + # Lora page route + app.router.add_get('/loras', self.handle_loras_page) + # LoRA-specific query routes + app.router.add_get(f'/api/{prefix}/letter-counts', self.get_letter_counts) + app.router.add_get(f'/api/{prefix}/get-notes', self.get_lora_notes) + app.router.add_get(f'/api/{prefix}/get-trigger-words', self.get_lora_trigger_words) + app.router.add_get(f'/api/lora-preview-url', self.get_lora_preview_url) + app.router.add_get(f'/api/lora-civitai-url', self.get_lora_civitai_url) + app.router.add_get(f'/api/lora-model-description', self.get_lora_model_description) + app.router.add_get(f'/api/folders', self.get_folders) + app.router.add_get(f'/api/lora-roots', self.get_lora_roots) + + # LoRA-specific management routes + app.router.add_post(f'/api/move_model', self.move_model) + app.router.add_post(f'/api/move_models_bulk', self.move_models_bulk) + + # CivitAI integration with LoRA-specific validation + app.router.add_get(f'/api/civitai/versions/{{model_id}}', self.get_civitai_versions_lora) + app.router.add_get(f'/api/civitai/model/version/{{modelVersionId}}', self.get_civitai_model_by_version) + app.router.add_get(f'/api/civitai/model/hash/{{hash}}', self.get_civitai_model_by_hash) + + # Download management + app.router.add_post(f'/api/download-model', self.download_model) + app.router.add_get(f'/api/download-model-get', self.download_model_get) + app.router.add_get(f'/api/cancel-download-get', self.cancel_download_get) + app.router.add_get(f'/api/download-progress/{{download_id}}', self.get_download_progress) + + # ComfyUI integration + app.router.add_post(f'/loramanager/get_trigger_words', self.get_trigger_words) + + # Legacy API compatibility + app.router.add_post(f'/api/delete_model', self.delete_model) + app.router.add_post(f'/api/fetch-civitai', self.fetch_civitai) + app.router.add_post(f'/api/relink-civitai', self.relink_civitai) + app.router.add_post(f'/api/replace_preview', self.replace_preview) + app.router.add_post(f'/api/fetch-all-civitai', self.fetch_all_civitai) + + def _parse_specific_params(self, request: web.Request) -> Dict: + """Parse LoRA-specific parameters""" + params = {} + + # LoRA-specific parameters + if 'first_letter' in request.query: + params['first_letter'] = request.query.get('first_letter') + + # Handle fuzzy search parameter name variation + if request.query.get('fuzzy') == 'true': + params['fuzzy_search'] = True + + # Handle additional filter parameters for LoRAs + if 'lora_hash' in request.query: + if not params.get('hash_filters'): + params['hash_filters'] = {} + params['hash_filters']['single_hash'] = request.query['lora_hash'].lower() + elif 'lora_hashes' in request.query: + if not params.get('hash_filters'): + params['hash_filters'] = {} + params['hash_filters']['multiple_hashes'] = [h.lower() for h in request.query['lora_hashes'].split(',')] + + return params + async def handle_loras_page(self, request: web.Request) -> web.Response: """Handle GET /loras request""" try: - # Ensure services are initialized - await self.init_services() - # Check if the LoraScanner is initializing # It's initializing if the cache object doesn't exist yet, # OR if the scanner explicitly says it's initializing (background task running). is_initializing = ( - self.scanner._cache is None or self.scanner.is_initializing() + self.service.scanner._cache is None or self.service.scanner.is_initializing() ) if is_initializing: @@ -87,7 +137,7 @@ class LoraRoutes: else: # Normal flow - get data from initialized cache try: - cache = await self.scanner.get_cached_data(force_refresh=False) + cache = await self.service.scanner.get_cached_data(force_refresh=False) template = self.template_env.get_template('loras.html') rendered = template.render( folders=cache.folders, @@ -117,72 +167,561 @@ class LoraRoutes: text="Error loading loras page", status=500 ) - - async def handle_recipes_page(self, request: web.Request) -> web.Response: - """Handle GET /loras/recipes request""" + + # LoRA-specific route handlers + async def get_letter_counts(self, request: web.Request) -> web.Response: + """Get count of LoRAs for each letter of the alphabet""" try: - # Ensure services are initialized - await self.init_services() + letter_counts = await self.service.get_letter_counts() + return web.json_response({ + 'success': True, + 'letter_counts': letter_counts + }) + except Exception as e: + logger.error(f"Error getting letter counts: {e}") + return web.json_response({ + 'success': False, + 'error': str(e) + }, status=500) + + async def get_lora_notes(self, request: web.Request) -> web.Response: + """Get notes for a specific LoRA file""" + try: + lora_name = request.query.get('name') + if not lora_name: + return web.Response(text='Lora file name is required', status=400) - # Skip initialization check and directly try to get cached data - try: - # Recipe scanner will initialize cache if needed - await self.recipe_scanner.get_cached_data(force_refresh=False) - template = self.template_env.get_template('recipes.html') - rendered = template.render( - recipes=[], # Frontend will load recipes via API - is_initializing=False, - settings=settings, - request=request - ) - except Exception as cache_error: - logger.error(f"Error loading recipe cache data: {cache_error}") - # Still keep error handling - show initializing page on error - template = self.template_env.get_template('recipes.html') - rendered = template.render( - is_initializing=True, - settings=settings, - request=request - ) - logger.info("Recipe cache error, returning initialization page") + notes = await self.service.get_lora_notes(lora_name) + if notes is not None: + return web.json_response({ + 'success': True, + 'notes': notes + }) + else: + return web.json_response({ + 'success': False, + 'error': 'LoRA not found in cache' + }, status=404) + + except Exception as e: + logger.error(f"Error getting lora notes: {e}", exc_info=True) + return web.json_response({ + 'success': False, + 'error': str(e) + }, status=500) + + async def get_lora_trigger_words(self, request: web.Request) -> web.Response: + """Get trigger words for a specific LoRA file""" + try: + lora_name = request.query.get('name') + if not lora_name: + return web.Response(text='Lora file name is required', status=400) - return web.Response( - text=rendered, - content_type='text/html' - ) + trigger_words = await self.service.get_lora_trigger_words(lora_name) + return web.json_response({ + 'success': True, + 'trigger_words': trigger_words + }) except Exception as e: - logger.error(f"Error handling recipes request: {e}", exc_info=True) - return web.Response( - text="Error loading recipes page", - status=500 - ) - - def _format_recipe_file_url(self, file_path: str) -> str: - """Format file path for recipe image as a URL - same as in recipe_routes""" + logger.error(f"Error getting lora trigger words: {e}", exc_info=True) + return web.json_response({ + 'success': False, + 'error': str(e) + }, status=500) + + async def get_lora_preview_url(self, request: web.Request) -> web.Response: + """Get the static preview URL for a LoRA file""" try: - # Return the file URL directly for the first lora root's preview - recipes_dir = os.path.join(config.loras_roots[0], "recipes").replace(os.sep, '/') - if file_path.replace(os.sep, '/').startswith(recipes_dir): - relative_path = os.path.relpath(file_path, config.loras_roots[0]).replace(os.sep, '/') - return f"/loras_static/root1/preview/{relative_path}" + lora_name = request.query.get('name') + if not lora_name: + return web.Response(text='Lora file name is required', status=400) - # If not in recipes dir, try to create a valid URL from the file path + preview_url = await self.service.get_lora_preview_url(lora_name) + if preview_url: + return web.json_response({ + 'success': True, + 'preview_url': preview_url + }) + else: + return web.json_response({ + 'success': False, + 'error': 'No preview URL found for the specified lora' + }, status=404) + + except Exception as e: + logger.error(f"Error getting lora preview URL: {e}", exc_info=True) + return web.json_response({ + 'success': False, + 'error': str(e) + }, status=500) + + async def get_lora_civitai_url(self, request: web.Request) -> web.Response: + """Get the Civitai URL for a LoRA file""" + try: + lora_name = request.query.get('name') + if not lora_name: + return web.Response(text='Lora file name is required', status=400) + + result = await self.service.get_lora_civitai_url(lora_name) + if result['civitai_url']: + return web.json_response({ + 'success': True, + **result + }) + else: + return web.json_response({ + 'success': False, + 'error': 'No Civitai data found for the specified lora' + }, status=404) + + except Exception as e: + logger.error(f"Error getting lora Civitai URL: {e}", exc_info=True) + return web.json_response({ + 'success': False, + 'error': str(e) + }, status=500) + + async def get_folders(self, request: web.Request) -> web.Response: + """Get all folders in the cache""" + try: + cache = await self.service.scanner.get_cached_data() + return web.json_response({ + 'folders': cache.folders + }) + except Exception as e: + logger.error(f"Error getting folders: {e}") + return web.json_response({ + 'success': False, + 'error': str(e) + }, status=500) + + async def get_lora_roots(self, request: web.Request) -> web.Response: + """Get all configured LoRA root directories""" + try: + return web.json_response({ + 'roots': self.service.get_model_roots() + }) + except Exception as e: + logger.error(f"Error getting LoRA roots: {e}") + return web.json_response({ + 'success': False, + 'error': str(e) + }, status=500) + + # Override get_models to add LoRA-specific response data + async def get_models(self, request: web.Request) -> web.Response: + """Get paginated LoRA data with LoRA-specific fields""" + try: + # Parse common query parameters + params = self._parse_common_params(request) + + # Get data from service + result = await self.service.get_paginated_data(**params) + + # Get all available folders from cache for LoRA-specific response + cache = await self.service.scanner.get_cached_data() + + # Format response items with LoRA-specific structure + formatted_result = { + 'items': [await self.service.format_response(item) for item in result['items']], + 'folders': cache.folders, # LoRA-specific: include folders in response + 'total': result['total'], + 'page': result['page'], + 'page_size': result['page_size'], + 'total_pages': result['total_pages'] + } + + return web.json_response(formatted_result) + + except Exception as e: + logger.error(f"Error in get_loras: {e}", exc_info=True) + return web.json_response({"error": str(e)}, status=500) + + # CivitAI integration methods + async def get_civitai_versions_lora(self, request: web.Request) -> web.Response: + """Get available versions for a Civitai LoRA model with local availability info""" + try: + model_id = request.match_info['model_id'] + response = await self.civitai_client.get_model_versions(model_id) + if not response or not response.get('modelVersions'): + return web.Response(status=404, text="Model not found") + + versions = response.get('modelVersions', []) + model_type = response.get('type', '') + + # Check model type - should be LORA, LoCon, or DORA + from ..utils.constants import VALID_LORA_TYPES + if model_type.lower() not in VALID_LORA_TYPES: + return web.json_response({ + 'error': f"Model type mismatch. Expected LORA or LoCon, got {model_type}" + }, status=400) + + # Check local availability for each version + for version in versions: + # Find the model file (type="Model") in the files list + model_file = next((file for file in version.get('files', []) + if file.get('type') == 'Model'), None) + + if model_file: + sha256 = model_file.get('hashes', {}).get('SHA256') + if sha256: + # Set existsLocally and localPath at the version level + version['existsLocally'] = self.service.has_hash(sha256) + if version['existsLocally']: + version['localPath'] = self.service.get_path_by_hash(sha256) + + # Also set the model file size at the version level for easier access + version['modelSizeKB'] = model_file.get('sizeKB') + else: + # No model file found in this version + version['existsLocally'] = False + + return web.json_response(versions) + except Exception as e: + logger.error(f"Error fetching LoRA model versions: {e}") + return web.Response(status=500, text=str(e)) + + async def get_civitai_model_by_version(self, request: web.Request) -> web.Response: + """Get CivitAI model details by model version ID""" + try: + model_version_id = request.match_info.get('modelVersionId') + + # Get model details from Civitai API + model, error_msg = await self.civitai_client.get_model_version_info(model_version_id) + + if not model: + # Log warning for failed model retrieval + logger.warning(f"Failed to fetch model version {model_version_id}: {error_msg}") + + # Determine status code based on error message + status_code = 404 if error_msg and "not found" in error_msg.lower() else 500 + + return web.json_response({ + "success": False, + "error": error_msg or "Failed to fetch model information" + }, status=status_code) + + return web.json_response(model) + except Exception as e: + logger.error(f"Error fetching model details: {e}") + return web.json_response({ + "success": False, + "error": str(e) + }, status=500) + + async def get_civitai_model_by_hash(self, request: web.Request) -> web.Response: + """Get CivitAI model details by hash""" + try: + hash = request.match_info.get('hash') + model = await self.civitai_client.get_model_by_hash(hash) + return web.json_response(model) + except Exception as e: + logger.error(f"Error fetching model details by hash: {e}") + return web.json_response({ + "success": False, + "error": str(e) + }, status=500) + + # Download management methods + async def download_model(self, request: web.Request) -> web.Response: + """Handle model download request""" + return await ModelRouteUtils.handle_download_model(request, self.download_manager) + + async def download_model_get(self, request: web.Request) -> web.Response: + """Handle model download request via GET method""" + try: + # Extract query parameters + model_id = request.query.get('model_id') + if not model_id: + return web.Response( + status=400, + text="Missing required parameter: Please provide 'model_id'" + ) + + # Get optional parameters + model_version_id = request.query.get('model_version_id') + download_id = request.query.get('download_id') + use_default_paths = request.query.get('use_default_paths', 'false').lower() == 'true' + + # Create a data dictionary that mimics what would be received from a POST request + data = { + 'model_id': model_id + } + + # Add optional parameters only if they are provided + if model_version_id: + data['model_version_id'] = model_version_id + + if download_id: + data['download_id'] = download_id + + data['use_default_paths'] = use_default_paths + + # Create a mock request object with the data + future = asyncio.get_event_loop().create_future() + future.set_result(data) + + mock_request = type('MockRequest', (), { + 'json': lambda self=None: future + })() + + # Call the existing download handler + return await ModelRouteUtils.handle_download_model(mock_request, self.download_manager) + + except Exception as e: + error_message = str(e) + logger.error(f"Error downloading model via GET: {error_message}", exc_info=True) + return web.Response(status=500, text=error_message) + + async def cancel_download_get(self, request: web.Request) -> web.Response: + """Handle GET request for cancelling a download by download_id""" + try: + download_id = request.query.get('download_id') + if not download_id: + return web.json_response({ + 'success': False, + 'error': 'Download ID is required' + }, status=400) + + # Create a mock request with match_info for compatibility + mock_request = type('MockRequest', (), { + 'match_info': {'download_id': download_id} + })() + return await ModelRouteUtils.handle_cancel_download(mock_request, self.download_manager) + except Exception as e: + logger.error(f"Error cancelling download via GET: {e}", exc_info=True) + return web.json_response({ + 'success': False, + 'error': str(e) + }, status=500) + + async def get_download_progress(self, request: web.Request) -> web.Response: + """Handle request for download progress by download_id""" + try: + # Get download_id from URL path + download_id = request.match_info.get('download_id') + if not download_id: + return web.json_response({ + 'success': False, + 'error': 'Download ID is required' + }, status=400) + + # Get progress information from websocket manager + from ..services.websocket_manager import ws_manager + progress_data = ws_manager.get_download_progress(download_id) + + if progress_data is None: + return web.json_response({ + 'success': False, + 'error': 'Download ID not found' + }, status=404) + + return web.json_response({ + 'success': True, + 'progress': progress_data.get('progress', 0) + }) + except Exception as e: + logger.error(f"Error getting download progress: {e}", exc_info=True) + return web.json_response({ + 'success': False, + 'error': str(e) + }, status=500) + + # Model management methods + async def move_model(self, request: web.Request) -> web.Response: + """Handle model move request""" + try: + data = await request.json() + file_path = data.get('file_path') # full path of the model file + target_path = data.get('target_path') # folder path to move the model to + + if not file_path or not target_path: + return web.Response(text='File path and target path are required', status=400) + + # Check if source and destination are the same + import os + source_dir = os.path.dirname(file_path) + if os.path.normpath(source_dir) == os.path.normpath(target_path): + logger.info(f"Source and target directories are the same: {source_dir}") + return web.json_response({'success': True, 'message': 'Source and target directories are the same'}) + + # Check if target file already exists file_name = os.path.basename(file_path) - return f"/loras_static/root1/preview/recipes/{file_name}" - except Exception as e: - logger.error(f"Error formatting recipe file URL: {e}", exc_info=True) - return '/loras_static/images/no-preview.png' # Return default image on error + target_file_path = os.path.join(target_path, file_name).replace(os.sep, '/') + + if os.path.exists(target_file_path): + return web.json_response({ + 'success': False, + 'error': f"Target file already exists: {target_file_path}" + }, status=409) # 409 Conflict - def setup_routes(self, app: web.Application): - """Register routes with the application""" - # Add an app startup handler to initialize services - app.on_startup.append(self._on_startup) - - # Register routes - app.router.add_get('/loras', self.handle_loras_page) - app.router.add_get('/loras/recipes', self.handle_recipes_page) - - async def _on_startup(self, app): - """Initialize services when the app starts""" - await self.init_services() + # Call scanner to handle the move operation + success = await self.service.scanner.move_model(file_path, target_path) + + if success: + return web.json_response({'success': True}) + else: + return web.Response(text='Failed to move model', status=500) + + except Exception as e: + logger.error(f"Error moving model: {e}", exc_info=True) + return web.Response(text=str(e), status=500) + + async def move_models_bulk(self, request: web.Request) -> web.Response: + """Handle bulk model move request""" + try: + data = await request.json() + file_paths = data.get('file_paths', []) # list of full paths of the model files + target_path = data.get('target_path') # folder path to move the models to + + if not file_paths or not target_path: + return web.Response(text='File paths and target path are required', status=400) + + results = [] + import os + for file_path in file_paths: + # Check if source and destination are the same + source_dir = os.path.dirname(file_path) + if os.path.normpath(source_dir) == os.path.normpath(target_path): + results.append({ + "path": file_path, + "success": True, + "message": "Source and target directories are the same" + }) + continue + + # Check if target file already exists + file_name = os.path.basename(file_path) + target_file_path = os.path.join(target_path, file_name).replace(os.sep, '/') + + if os.path.exists(target_file_path): + results.append({ + "path": file_path, + "success": False, + "message": f"Target file already exists: {target_file_path}" + }) + continue + + # Try to move the model + success = await self.service.scanner.move_model(file_path, target_path) + results.append({ + "path": file_path, + "success": success, + "message": "Success" if success else "Failed to move model" + }) + + # Count successes and failures + success_count = sum(1 for r in results if r["success"]) + failure_count = len(results) - success_count + + return web.json_response({ + 'success': True, + 'message': f'Moved {success_count} of {len(file_paths)} models', + 'results': results, + 'success_count': success_count, + 'failure_count': failure_count + }) + + except Exception as e: + logger.error(f"Error moving models in bulk: {e}", exc_info=True) + return web.Response(text=str(e), status=500) + + async def get_lora_model_description(self, request: web.Request) -> web.Response: + """Get model description for a Lora model""" + try: + # Get parameters + model_id = request.query.get('model_id') + file_path = request.query.get('file_path') + + if not model_id: + return web.json_response({ + 'success': False, + 'error': 'Model ID is required' + }, status=400) + + # Check if we already have the description stored in metadata + description = None + tags = [] + creator = {} + if file_path: + import os + from ..utils.metadata_manager import MetadataManager + metadata_path = os.path.splitext(file_path)[0] + '.metadata.json' + metadata = await ModelRouteUtils.load_local_metadata(metadata_path) + description = metadata.get('modelDescription') + tags = metadata.get('tags', []) + creator = metadata.get('creator', {}) + + # If description is not in metadata, fetch from CivitAI + if not description: + logger.info(f"Fetching model metadata for model ID: {model_id}") + model_metadata, _ = await self.civitai_client.get_model_metadata(model_id) + + if model_metadata: + description = model_metadata.get('description') + tags = model_metadata.get('tags', []) + creator = model_metadata.get('creator', {}) + + # Save the metadata to file if we have a file path and got metadata + if file_path: + try: + metadata_path = os.path.splitext(file_path)[0] + '.metadata.json' + metadata = await ModelRouteUtils.load_local_metadata(metadata_path) + + metadata['modelDescription'] = description + metadata['tags'] = tags + # Ensure the civitai dict exists + if 'civitai' not in metadata: + metadata['civitai'] = {} + # Store creator in the civitai nested structure + metadata['civitai']['creator'] = creator + + await MetadataManager.save_metadata(file_path, metadata, True) + except Exception as e: + logger.error(f"Error saving model metadata: {e}") + + return web.json_response({ + 'success': True, + 'description': description or "

No model description available.

", + 'tags': tags, + 'creator': creator + }) + + except Exception as e: + logger.error(f"Error getting model metadata: {e}") + return web.json_response({ + 'success': False, + 'error': str(e) + }, status=500) + + async def get_trigger_words(self, request: web.Request) -> web.Response: + """Get trigger words for specified LoRA models""" + try: + json_data = await request.json() + lora_names = json_data.get("lora_names", []) + node_ids = json_data.get("node_ids", []) + + all_trigger_words = [] + for lora_name in lora_names: + _, trigger_words = get_lora_info(lora_name) + all_trigger_words.extend(trigger_words) + + # Format the trigger words + trigger_words_text = ",, ".join(all_trigger_words) if all_trigger_words else "" + + # Send update to all connected trigger word toggle nodes + for node_id in node_ids: + PromptServer.instance.send_sync("trigger_word_update", { + "id": node_id, + "message": trigger_words_text + }) + + return web.json_response({"success": True}) + + except Exception as e: + logger.error(f"Error getting trigger words: {e}") + return web.json_response({ + "success": False, + "error": str(e) + }, status=500) diff --git a/py/routes/misc_routes.py b/py/routes/misc_routes.py index 019aa82e..596e0323 100644 --- a/py/routes/misc_routes.py +++ b/py/routes/misc_routes.py @@ -633,9 +633,8 @@ class MiscRoutes: }, status=400) # Get both lora and checkpoint scanners - registry = ServiceRegistry.get_instance() - lora_scanner = await registry.get_lora_scanner() - checkpoint_scanner = await registry.get_checkpoint_scanner() + lora_scanner = await ServiceRegistry.get_lora_scanner() + checkpoint_scanner = await ServiceRegistry.get_checkpoint_scanner() # If modelVersionId is provided, check for specific version if model_version_id_str: diff --git a/py/routes/recipe_routes.py b/py/routes/recipe_routes.py index 393a5ae3..d181ed65 100644 --- a/py/routes/recipe_routes.py +++ b/py/routes/recipe_routes.py @@ -1,6 +1,7 @@ import os import time import base64 +import jinja2 import numpy as np from PIL import Image import io @@ -15,6 +16,7 @@ from ..utils.exif_utils import ExifUtils from ..recipes import RecipeParserFactory from ..utils.constants import CARD_PREVIEW_WIDTH +from ..services.settings_manager import settings from ..config import config # Check if running in standalone mode @@ -39,7 +41,10 @@ class RecipeRoutes: # Initialize service references as None, will be set during async init self.recipe_scanner = None self.civitai_client = None - # Remove WorkflowParser instance + self.template_env = jinja2.Environment( + loader=jinja2.FileSystemLoader(config.templates_path), + autoescape=True + ) # Pre-warm the cache self._init_cache_task = None @@ -53,6 +58,8 @@ class RecipeRoutes: def setup_routes(cls, app: web.Application): """Register API routes""" routes = cls() + app.router.add_get('/loras/recipes', routes.handle_recipes_page) + app.router.add_get('/api/recipes', routes.get_recipes) app.router.add_get('/api/recipe/{recipe_id}', routes.get_recipe_detail) app.router.add_post('/api/recipes/analyze-image', routes.analyze_recipe_image) @@ -114,6 +121,46 @@ class RecipeRoutes: await self.recipe_scanner.get_cached_data(force_refresh=True) except Exception as e: logger.error(f"Error pre-warming recipe cache: {e}", exc_info=True) + + async def handle_recipes_page(self, request: web.Request) -> web.Response: + """Handle GET /loras/recipes request""" + try: + # Ensure services are initialized + await self.init_services() + + # Skip initialization check and directly try to get cached data + try: + # Recipe scanner will initialize cache if needed + await self.recipe_scanner.get_cached_data(force_refresh=False) + template = self.template_env.get_template('recipes.html') + rendered = template.render( + recipes=[], # Frontend will load recipes via API + is_initializing=False, + settings=settings, + request=request + ) + except Exception as cache_error: + logger.error(f"Error loading recipe cache data: {cache_error}") + # Still keep error handling - show initializing page on error + template = self.template_env.get_template('recipes.html') + rendered = template.render( + is_initializing=True, + settings=settings, + request=request + ) + logger.info("Recipe cache error, returning initialization page") + + return web.Response( + text=rendered, + content_type='text/html' + ) + + except Exception as e: + logger.error(f"Error handling recipes request: {e}", exc_info=True) + return web.Response( + text="Error loading recipes page", + status=500 + ) async def get_recipes(self, request: web.Request) -> web.Response: """API endpoint for getting paginated recipes""" diff --git a/py/services/base_model_service.py b/py/services/base_model_service.py new file mode 100644 index 00000000..7ecc994b --- /dev/null +++ b/py/services/base_model_service.py @@ -0,0 +1,248 @@ +from abc import ABC, abstractmethod +from typing import Dict, List, Optional, Type, Set +import logging + +from ..utils.models import BaseModelMetadata +from ..utils.constants import NSFW_LEVELS +from .settings_manager import settings +from ..utils.utils import fuzzy_match + +logger = logging.getLogger(__name__) + +class BaseModelService(ABC): + """Base service class for all model types""" + + def __init__(self, model_type: str, scanner, metadata_class: Type[BaseModelMetadata]): + """Initialize the service + + Args: + model_type: Type of model (lora, checkpoint, etc.) + scanner: Model scanner instance + metadata_class: Metadata class for this model type + """ + self.model_type = model_type + self.scanner = scanner + self.metadata_class = metadata_class + + async def get_paginated_data(self, page: int, page_size: int, sort_by: str = 'name', + folder: str = None, search: str = None, fuzzy_search: bool = False, + base_models: list = None, tags: list = None, + search_options: dict = None, hash_filters: dict = None, + favorites_only: bool = False, **kwargs) -> Dict: + """Get paginated and filtered model data + + Args: + page: Page number (1-based) + page_size: Number of items per page + sort_by: Sort criteria ('name' or 'date') + folder: Folder filter + search: Search term + fuzzy_search: Whether to use fuzzy search + base_models: List of base models to filter by + tags: List of tags to filter by + search_options: Search options dict + hash_filters: Hash filtering options + favorites_only: Filter for favorites only + **kwargs: Additional model-specific filters + + Returns: + Dict containing paginated results + """ + cache = await self.scanner.get_cached_data() + + # Get default search options if not provided + if search_options is None: + search_options = { + 'filename': True, + 'modelname': True, + 'tags': False, + 'recursive': False, + } + + # Get the base data set + filtered_data = cache.sorted_by_date if sort_by == 'date' else cache.sorted_by_name + + # Apply hash filtering if provided (highest priority) + if hash_filters: + filtered_data = await self._apply_hash_filters(filtered_data, hash_filters) + + # Jump to pagination for hash filters + return self._paginate(filtered_data, page, page_size) + + # Apply common filters + filtered_data = await self._apply_common_filters( + filtered_data, folder, base_models, tags, favorites_only, search_options + ) + + # Apply search filtering + if search: + filtered_data = await self._apply_search_filters( + filtered_data, search, fuzzy_search, search_options + ) + + # Apply model-specific filters + filtered_data = await self._apply_specific_filters(filtered_data, **kwargs) + + return self._paginate(filtered_data, page, page_size) + + async def _apply_hash_filters(self, data: List[Dict], hash_filters: Dict) -> List[Dict]: + """Apply hash-based filtering""" + single_hash = hash_filters.get('single_hash') + multiple_hashes = hash_filters.get('multiple_hashes') + + if single_hash: + # Filter by single hash + single_hash = single_hash.lower() + return [ + item for item in data + if item.get('sha256', '').lower() == single_hash + ] + elif multiple_hashes: + # Filter by multiple hashes + hash_set = set(hash.lower() for hash in multiple_hashes) + return [ + item for item in data + if item.get('sha256', '').lower() in hash_set + ] + + return data + + async def _apply_common_filters(self, data: List[Dict], folder: str = None, + base_models: list = None, tags: list = None, + favorites_only: bool = False, search_options: dict = None) -> List[Dict]: + """Apply common filters that work across all model types""" + # Apply SFW filtering if enabled in settings + if settings.get('show_only_sfw', False): + data = [ + item for item in data + if not item.get('preview_nsfw_level') or item.get('preview_nsfw_level') < NSFW_LEVELS['R'] + ] + + # Apply favorites filtering if enabled + if favorites_only: + data = [ + item for item in data + if item.get('favorite', False) is True + ] + + # Apply folder filtering + if folder is not None: + if search_options and search_options.get('recursive', False): + # Recursive folder filtering - include all subfolders + data = [ + item for item in data + if item['folder'].startswith(folder) + ] + else: + # Exact folder filtering + data = [ + item for item in data + if item['folder'] == folder + ] + + # Apply base model filtering + if base_models and len(base_models) > 0: + data = [ + item for item in data + if item.get('base_model') in base_models + ] + + # Apply tag filtering + if tags and len(tags) > 0: + data = [ + item for item in data + if any(tag in item.get('tags', []) for tag in tags) + ] + + return data + + async def _apply_search_filters(self, data: List[Dict], search: str, + fuzzy_search: bool, search_options: dict) -> List[Dict]: + """Apply search filtering""" + search_results = [] + + for item in data: + # Search by file name + if search_options.get('filename', True): + if fuzzy_search: + if fuzzy_match(item.get('file_name', ''), search): + search_results.append(item) + continue + elif search.lower() in item.get('file_name', '').lower(): + search_results.append(item) + continue + + # Search by model name + if search_options.get('modelname', True): + if fuzzy_search: + if fuzzy_match(item.get('model_name', ''), search): + search_results.append(item) + continue + elif search.lower() in item.get('model_name', '').lower(): + search_results.append(item) + continue + + # Search by tags + if search_options.get('tags', False) and 'tags' in item: + if any((fuzzy_match(tag, search) if fuzzy_search else search.lower() in tag.lower()) + for tag in item['tags']): + search_results.append(item) + continue + + return search_results + + async def _apply_specific_filters(self, data: List[Dict], **kwargs) -> List[Dict]: + """Apply model-specific filters - to be overridden by subclasses if needed""" + return data + + def _paginate(self, data: List[Dict], page: int, page_size: int) -> Dict: + """Apply pagination to filtered data""" + total_items = len(data) + start_idx = (page - 1) * page_size + end_idx = min(start_idx + page_size, total_items) + + return { + 'items': data[start_idx:end_idx], + 'total': total_items, + 'page': page, + 'page_size': page_size, + 'total_pages': (total_items + page_size - 1) // page_size + } + + @abstractmethod + async def format_response(self, model_data: Dict) -> Dict: + """Format model data for API response - must be implemented by subclasses""" + pass + + # Common service methods that delegate to scanner + async def get_top_tags(self, limit: int = 20) -> List[Dict]: + """Get top tags sorted by frequency""" + return await self.scanner.get_top_tags(limit) + + async def get_base_models(self, limit: int = 20) -> List[Dict]: + """Get base models sorted by frequency""" + return await self.scanner.get_base_models(limit) + + def has_hash(self, sha256: str) -> bool: + """Check if a model with given hash exists""" + return self.scanner.has_hash(sha256) + + def get_path_by_hash(self, sha256: str) -> Optional[str]: + """Get file path for a model by its hash""" + return self.scanner.get_path_by_hash(sha256) + + def get_hash_by_path(self, file_path: str) -> Optional[str]: + """Get hash for a model by its file path""" + return self.scanner.get_hash_by_path(file_path) + + async def scan_models(self, force_refresh: bool = False, rebuild_cache: bool = False): + """Trigger model scanning""" + return await self.scanner.get_cached_data(force_refresh=force_refresh, rebuild_cache=rebuild_cache) + + async def get_model_info_by_name(self, name: str): + """Get model information by name""" + return await self.scanner.get_model_info_by_name(name) + + def get_model_roots(self) -> List[str]: + """Get model root directories""" + return self.scanner.get_model_roots() \ No newline at end of file diff --git a/py/services/checkpoint_scanner.py b/py/services/checkpoint_scanner.py index 26733ab5..32da3dbf 100644 --- a/py/services/checkpoint_scanner.py +++ b/py/services/checkpoint_scanner.py @@ -1,14 +1,12 @@ import os import logging import asyncio -from typing import List, Dict, Optional, Set -import folder_paths # type: ignore +from typing import List, Dict from ..utils.models import CheckpointMetadata from ..config import config from .model_scanner import ModelScanner from .model_hash_index import ModelHashIndex -from .service_registry import ServiceRegistry logger = logging.getLogger(__name__) @@ -109,4 +107,31 @@ class CheckpointScanner(ModelScanner): if result: checkpoints.append(result) except Exception as e: - logger.error(f"Error processing {file_path}: {e}") \ No newline at end of file + logger.error(f"Error processing {file_path}: {e}") + + # Checkpoint-specific hash index functionality + def has_checkpoint_hash(self, sha256: str) -> bool: + """Check if a checkpoint with given hash exists""" + return self.has_hash(sha256) + + def get_checkpoint_path_by_hash(self, sha256: str) -> str: + """Get file path for a checkpoint by its hash""" + return self.get_path_by_hash(sha256) + + def get_checkpoint_hash_by_path(self, file_path: str) -> str: + """Get hash for a checkpoint by its file path""" + return self.get_hash_by_path(file_path) + + async def get_checkpoint_info_by_name(self, name): + """Get checkpoint information by name""" + try: + cache = await self.get_cached_data() + + for checkpoint in cache.raw_data: + if checkpoint.get("file_name") == name: + return checkpoint + + return None + except Exception as e: + logger.error(f"Error getting checkpoint info by name: {e}", exc_info=True) + return None \ No newline at end of file diff --git a/py/services/checkpoint_service.py b/py/services/checkpoint_service.py new file mode 100644 index 00000000..cfc4b118 --- /dev/null +++ b/py/services/checkpoint_service.py @@ -0,0 +1,51 @@ +import os +import logging +from typing import Dict, List, Optional + +from .base_model_service import BaseModelService +from ..utils.models import CheckpointMetadata +from ..config import config +from ..utils.routes_common import ModelRouteUtils + +logger = logging.getLogger(__name__) + +class CheckpointService(BaseModelService): + """Checkpoint-specific service implementation""" + + def __init__(self, scanner): + """Initialize Checkpoint service + + Args: + scanner: Checkpoint scanner instance + """ + super().__init__("checkpoint", scanner, CheckpointMetadata) + + async def format_response(self, checkpoint_data: Dict) -> Dict: + """Format Checkpoint data for API response""" + return { + "model_name": checkpoint_data["model_name"], + "file_name": checkpoint_data["file_name"], + "preview_url": config.get_preview_static_url(checkpoint_data.get("preview_url", "")), + "preview_nsfw_level": checkpoint_data.get("preview_nsfw_level", 0), + "base_model": checkpoint_data.get("base_model", ""), + "folder": checkpoint_data["folder"], + "sha256": checkpoint_data.get("sha256", ""), + "file_path": checkpoint_data["file_path"].replace(os.sep, "/"), + "file_size": checkpoint_data.get("size", 0), + "modified": checkpoint_data.get("modified", ""), + "tags": checkpoint_data.get("tags", []), + "modelDescription": checkpoint_data.get("modelDescription", ""), + "from_civitai": checkpoint_data.get("from_civitai", True), + "notes": checkpoint_data.get("notes", ""), + "model_type": checkpoint_data.get("model_type", "checkpoint"), + "favorite": checkpoint_data.get("favorite", False), + "civitai": ModelRouteUtils.filter_civitai_data(checkpoint_data.get("civitai", {})) + } + + def find_duplicate_hashes(self) -> Dict: + """Find Checkpoints with duplicate SHA256 hashes""" + return self.scanner._hash_index.get_duplicate_hashes() + + def find_duplicate_filenames(self) -> Dict: + """Find Checkpoints with conflicting filenames""" + return self.scanner._hash_index.get_duplicate_filenames() \ No newline at end of file diff --git a/py/services/lora_service.py b/py/services/lora_service.py new file mode 100644 index 00000000..bcfa84c5 --- /dev/null +++ b/py/services/lora_service.py @@ -0,0 +1,172 @@ +import os +import logging +from typing import Dict, List, Optional + +from .base_model_service import BaseModelService +from ..utils.models import LoraMetadata +from ..config import config +from ..utils.routes_common import ModelRouteUtils + +logger = logging.getLogger(__name__) + +class LoraService(BaseModelService): + """LoRA-specific service implementation""" + + def __init__(self, scanner): + """Initialize LoRA service + + Args: + scanner: LoRA scanner instance + """ + super().__init__("lora", scanner, LoraMetadata) + + async def format_response(self, lora_data: Dict) -> Dict: + """Format LoRA data for API response""" + return { + "model_name": lora_data["model_name"], + "file_name": lora_data["file_name"], + "preview_url": config.get_preview_static_url(lora_data.get("preview_url", "")), + "preview_nsfw_level": lora_data.get("preview_nsfw_level", 0), + "base_model": lora_data.get("base_model", ""), + "folder": lora_data["folder"], + "sha256": lora_data.get("sha256", ""), + "file_path": lora_data["file_path"].replace(os.sep, "/"), + "file_size": lora_data.get("size", 0), + "modified": lora_data.get("modified", ""), + "tags": lora_data.get("tags", []), + "modelDescription": lora_data.get("modelDescription", ""), + "from_civitai": lora_data.get("from_civitai", True), + "usage_tips": lora_data.get("usage_tips", ""), + "notes": lora_data.get("notes", ""), + "favorite": lora_data.get("favorite", False), + "civitai": ModelRouteUtils.filter_civitai_data(lora_data.get("civitai", {})) + } + + async def _apply_specific_filters(self, data: List[Dict], **kwargs) -> List[Dict]: + """Apply LoRA-specific filters""" + # Handle first_letter filter for LoRAs + first_letter = kwargs.get('first_letter') + if first_letter: + data = self._filter_by_first_letter(data, first_letter) + + return data + + def _filter_by_first_letter(self, data: List[Dict], letter: str) -> List[Dict]: + """Filter LoRAs by first letter""" + if letter == '#': + # Filter for non-alphabetic characters + return [ + item for item in data + if not item.get('model_name', '')[0].isalpha() + ] + elif letter == 'CJK': + # Filter for CJK characters + return [ + item for item in data + if item.get('model_name', '') and self._is_cjk_character(item['model_name'][0]) + ] + else: + # Filter for specific letter + return [ + item for item in data + if item.get('model_name', '').lower().startswith(letter.lower()) + ] + + def _is_cjk_character(self, char: str) -> bool: + """Check if character is CJK (Chinese, Japanese, Korean)""" + cjk_ranges = [ + (0x4E00, 0x9FFF), # CJK Unified Ideographs + (0x3400, 0x4DBF), # CJK Extension A + (0x20000, 0x2A6DF), # CJK Extension B + (0x2A700, 0x2B73F), # CJK Extension C + (0x2B740, 0x2B81F), # CJK Extension D + (0x3040, 0x309F), # Hiragana + (0x30A0, 0x30FF), # Katakana + (0xAC00, 0xD7AF), # Hangul Syllables + ] + + char_code = ord(char) + return any(start <= char_code <= end for start, end in cjk_ranges) + + # LoRA-specific methods + async def get_letter_counts(self) -> Dict[str, int]: + """Get count of LoRAs for each letter of the alphabet""" + cache = await self.scanner.get_cached_data() + letter_counts = {} + + for lora in cache.raw_data: + model_name = lora.get('model_name', '') + if model_name: + first_char = model_name[0].upper() + if first_char.isalpha(): + letter_counts[first_char] = letter_counts.get(first_char, 0) + 1 + elif self._is_cjk_character(first_char): + letter_counts['CJK'] = letter_counts.get('CJK', 0) + 1 + else: + letter_counts['#'] = letter_counts.get('#', 0) + 1 + + return letter_counts + + async def get_lora_notes(self, lora_name: str) -> Optional[str]: + """Get notes for a specific LoRA file""" + cache = await self.scanner.get_cached_data() + + for lora in cache.raw_data: + if lora['file_name'] == lora_name: + return lora.get('notes', '') + + return None + + async def get_lora_trigger_words(self, lora_name: str) -> List[str]: + """Get trigger words for a specific LoRA file""" + cache = await self.scanner.get_cached_data() + + for lora in cache.raw_data: + if lora['file_name'] == lora_name: + civitai_data = lora.get('civitai', {}) + return civitai_data.get('trainedWords', []) + + return [] + + async def get_lora_preview_url(self, lora_name: str) -> Optional[str]: + """Get the static preview URL for a LoRA file""" + cache = await self.scanner.get_cached_data() + + for lora in cache.raw_data: + if lora['file_name'] == lora_name: + preview_url = lora.get('preview_url') + if preview_url: + return config.get_preview_static_url(preview_url) + + return None + + async def get_lora_civitai_url(self, lora_name: str) -> Dict[str, Optional[str]]: + """Get the Civitai URL for a LoRA file""" + cache = await self.scanner.get_cached_data() + + for lora in cache.raw_data: + if lora['file_name'] == lora_name: + civitai_data = lora.get('civitai', {}) + model_id = civitai_data.get('modelId') + version_id = civitai_data.get('id') + + if model_id: + civitai_url = f"https://civitai.com/models/{model_id}" + if version_id: + civitai_url += f"?modelVersionId={version_id}" + + return { + 'civitai_url': civitai_url, + 'model_id': str(model_id), + 'version_id': str(version_id) if version_id else None + } + + return {'civitai_url': None, 'model_id': None, 'version_id': None} + + def find_duplicate_hashes(self) -> Dict: + """Find LoRAs with duplicate SHA256 hashes""" + return self.scanner._hash_index.get_duplicate_hashes() + + def find_duplicate_filenames(self) -> Dict: + """Find LoRAs with conflicting filenames""" + return self.scanner._hash_index.get_duplicate_filenames() \ No newline at end of file diff --git a/py/services/model_service_factory.py b/py/services/model_service_factory.py new file mode 100644 index 00000000..6cc8a3a3 --- /dev/null +++ b/py/services/model_service_factory.py @@ -0,0 +1,137 @@ +from typing import Dict, Type, Any +import logging + +logger = logging.getLogger(__name__) + +class ModelServiceFactory: + """Factory for managing model services and routes""" + + _services: Dict[str, Type] = {} + _routes: Dict[str, Type] = {} + _initialized_services: Dict[str, Any] = {} + _initialized_routes: Dict[str, Any] = {} + + @classmethod + def register_model_type(cls, model_type: str, service_class: Type, route_class: Type): + """Register a new model type with its service and route classes + + Args: + model_type: The model type identifier (e.g., 'lora', 'checkpoint') + service_class: The service class for this model type + route_class: The route class for this model type + """ + cls._services[model_type] = service_class + cls._routes[model_type] = route_class + logger.info(f"Registered model type '{model_type}' with service {service_class.__name__} and routes {route_class.__name__}") + + @classmethod + def get_service_class(cls, model_type: str) -> Type: + """Get service class for a model type + + Args: + model_type: The model type identifier + + Returns: + The service class for the model type + + Raises: + ValueError: If model type is not registered + """ + if model_type not in cls._services: + raise ValueError(f"Unknown model type: {model_type}") + return cls._services[model_type] + + @classmethod + def get_route_class(cls, model_type: str) -> Type: + """Get route class for a model type + + Args: + model_type: The model type identifier + + Returns: + The route class for the model type + + Raises: + ValueError: If model type is not registered + """ + if model_type not in cls._routes: + raise ValueError(f"Unknown model type: {model_type}") + return cls._routes[model_type] + + @classmethod + def get_route_instance(cls, model_type: str): + """Get or create route instance for a model type + + Args: + model_type: The model type identifier + + Returns: + The route instance for the model type + """ + if model_type not in cls._initialized_routes: + route_class = cls.get_route_class(model_type) + cls._initialized_routes[model_type] = route_class() + return cls._initialized_routes[model_type] + + @classmethod + def setup_all_routes(cls, app): + """Setup routes for all registered model types + + Args: + app: The aiohttp application instance + """ + logger.info(f"Setting up routes for {len(cls._services)} registered model types") + + for model_type in cls._services.keys(): + try: + routes_instance = cls.get_route_instance(model_type) + routes_instance.setup_routes(app) + logger.info(f"Successfully set up routes for {model_type}") + except Exception as e: + logger.error(f"Failed to setup routes for {model_type}: {e}", exc_info=True) + + @classmethod + def get_registered_types(cls) -> list: + """Get list of all registered model types + + Returns: + List of registered model type identifiers + """ + return list(cls._services.keys()) + + @classmethod + def is_registered(cls, model_type: str) -> bool: + """Check if a model type is registered + + Args: + model_type: The model type identifier + + Returns: + True if the model type is registered, False otherwise + """ + return model_type in cls._services + + @classmethod + def clear_registrations(cls): + """Clear all registrations - mainly for testing purposes""" + cls._services.clear() + cls._routes.clear() + cls._initialized_services.clear() + cls._initialized_routes.clear() + logger.info("Cleared all model type registrations") + + +def register_default_model_types(): + """Register the default model types (LoRA and Checkpoint)""" + from ..services.lora_service import LoraService + from ..services.checkpoint_service import CheckpointService + from ..routes.lora_routes import LoraRoutes + from ..routes.checkpoint_routes import CheckpointRoutes + + # Register LoRA model type + ModelServiceFactory.register_model_type('lora', LoraService, LoraRoutes) + + # Register Checkpoint model type + ModelServiceFactory.register_model_type('checkpoint', CheckpointService, CheckpointRoutes) + + logger.info("Registered default model types: lora, checkpoint") \ No newline at end of file diff --git a/py/services/service_registry.py b/py/services/service_registry.py index 15c00a3f..9589b984 100644 --- a/py/services/service_registry.py +++ b/py/services/service_registry.py @@ -7,106 +7,176 @@ logger = logging.getLogger(__name__) T = TypeVar('T') # Define a type variable for service types class ServiceRegistry: - """Centralized registry for service singletons""" + """Central registry for managing singleton services""" - _instance = None _services: Dict[str, Any] = {} - _lock = asyncio.Lock() + _locks: Dict[str, asyncio.Lock] = {} @classmethod - def get_instance(cls): - """Get singleton instance of the registry""" - if cls._instance is None: - cls._instance = cls() - return cls._instance + async def register_service(cls, name: str, service: Any) -> None: + """Register a service instance with the registry + + Args: + name: Service name identifier + service: Service instance to register + """ + cls._services[name] = service + logger.debug(f"Registered service: {name}") @classmethod - async def register_service(cls, service_name: str, service_instance: Any) -> None: - """Register a service instance with the registry""" - registry = cls.get_instance() - async with cls._lock: - registry._services[service_name] = service_instance - logger.debug(f"Registered service: {service_name}") + async def get_service(cls, name: str) -> Optional[Any]: + """Get a service instance by name + + Args: + name: Service name identifier + + Returns: + Service instance or None if not found + """ + return cls._services.get(name) @classmethod - async def get_service(cls, service_name: str) -> Any: - """Get a service instance by name""" - registry = cls.get_instance() - async with cls._lock: - if service_name not in registry._services: - logger.debug(f"Service {service_name} not found in registry") - return None - return registry._services[service_name] + def _get_lock(cls, name: str) -> asyncio.Lock: + """Get or create a lock for a service + + Args: + name: Service name identifier + + Returns: + AsyncIO lock for the service + """ + if name not in cls._locks: + cls._locks[name] = asyncio.Lock() + return cls._locks[name] - @classmethod - def get_service_sync(cls, service_name: str) -> Any: - """Get a service instance by name (synchronous version)""" - registry = cls.get_instance() - if service_name not in registry._services: - logger.debug(f"Service {service_name} not found in registry") - return None - return registry._services[service_name] - - # Convenience methods for common services @classmethod async def get_lora_scanner(cls): - """Get the LoraScanner instance""" - from .lora_scanner import LoraScanner - scanner = await cls.get_service("lora_scanner") - if scanner is None: - scanner = await LoraScanner.get_instance() - await cls.register_service("lora_scanner", scanner) - return scanner + """Get or create LoRA scanner instance""" + service_name = "lora_scanner" + if service_name in cls._services: + return cls._services[service_name] + + async with cls._get_lock(service_name): + # Double-check after acquiring lock + if service_name in cls._services: + return cls._services[service_name] + + # Import here to avoid circular imports + from .lora_scanner import LoraScanner + + scanner = await LoraScanner.get_instance() + cls._services[service_name] = scanner + logger.info(f"Created and registered {service_name}") + return scanner + @classmethod async def get_checkpoint_scanner(cls): - """Get the CheckpointScanner instance""" - from .checkpoint_scanner import CheckpointScanner - scanner = await cls.get_service("checkpoint_scanner") - if scanner is None: + """Get or create Checkpoint scanner instance""" + service_name = "checkpoint_scanner" + + if service_name in cls._services: + return cls._services[service_name] + + async with cls._get_lock(service_name): + # Double-check after acquiring lock + if service_name in cls._services: + return cls._services[service_name] + + # Import here to avoid circular imports + from .checkpoint_scanner import CheckpointScanner + scanner = await CheckpointScanner.get_instance() - await cls.register_service("checkpoint_scanner", scanner) - return scanner - - @classmethod - async def get_civitai_client(cls): - """Get the CivitaiClient instance""" - from .civitai_client import CivitaiClient - client = await cls.get_service("civitai_client") - if client is None: - client = await CivitaiClient.get_instance() - await cls.register_service("civitai_client", client) - return client - - @classmethod - async def get_download_manager(cls): - """Get the DownloadManager instance""" - from .download_manager import DownloadManager - manager = await cls.get_service("download_manager") - if manager is None: - manager = await DownloadManager.get_instance() - await cls.register_service("download_manager", manager) - return manager - + cls._services[service_name] = scanner + logger.info(f"Created and registered {service_name}") + return scanner + @classmethod async def get_recipe_scanner(cls): - """Get the RecipeScanner instance""" - from .recipe_scanner import RecipeScanner - scanner = await cls.get_service("recipe_scanner") - if scanner is None: - lora_scanner = await cls.get_lora_scanner() - scanner = RecipeScanner(lora_scanner) - await cls.register_service("recipe_scanner", scanner) - return scanner - + """Get or create Recipe scanner instance""" + service_name = "recipe_scanner" + + if service_name in cls._services: + return cls._services[service_name] + + async with cls._get_lock(service_name): + # Double-check after acquiring lock + if service_name in cls._services: + return cls._services[service_name] + + # Import here to avoid circular imports + from .recipe_scanner import RecipeScanner + + scanner = await RecipeScanner.get_instance() + cls._services[service_name] = scanner + logger.info(f"Created and registered {service_name}") + return scanner + + @classmethod + async def get_civitai_client(cls): + """Get or create CivitAI client instance""" + service_name = "civitai_client" + + if service_name in cls._services: + return cls._services[service_name] + + async with cls._get_lock(service_name): + # Double-check after acquiring lock + if service_name in cls._services: + return cls._services[service_name] + + # Import here to avoid circular imports + from .civitai_client import CivitaiClient + + client = await CivitaiClient.get_instance() + cls._services[service_name] = client + logger.info(f"Created and registered {service_name}") + return client + + @classmethod + async def get_download_manager(cls): + """Get or create Download manager instance""" + service_name = "download_manager" + + if service_name in cls._services: + return cls._services[service_name] + + async with cls._get_lock(service_name): + # Double-check after acquiring lock + if service_name in cls._services: + return cls._services[service_name] + + # Import here to avoid circular imports + from .download_manager import DownloadManager + + manager = DownloadManager() + cls._services[service_name] = manager + logger.info(f"Created and registered {service_name}") + return manager + @classmethod async def get_websocket_manager(cls): - """Get the WebSocketManager instance""" - from .websocket_manager import ws_manager - manager = await cls.get_service("websocket_manager") - if manager is None: - # ws_manager is already a global instance in websocket_manager.py + """Get or create WebSocket manager instance""" + service_name = "websocket_manager" + + if service_name in cls._services: + return cls._services[service_name] + + async with cls._get_lock(service_name): + # Double-check after acquiring lock + if service_name in cls._services: + return cls._services[service_name] + + # Import here to avoid circular imports from .websocket_manager import ws_manager - await cls.register_service("websocket_manager", ws_manager) - manager = ws_manager - return manager \ No newline at end of file + + cls._services[service_name] = ws_manager + logger.info(f"Registered {service_name}") + return ws_manager + + @classmethod + def clear_services(cls): + """Clear all registered services - mainly for testing""" + cls._services.clear() + cls._locks.clear() + logger.info("Cleared all registered services") \ No newline at end of file diff --git a/py/utils/routes_common.py b/py/utils/routes_common.py index e151d216..8e5df544 100644 --- a/py/utils/routes_common.py +++ b/py/utils/routes_common.py @@ -1047,3 +1047,56 @@ class ModelRouteUtils: 'success': False, 'error': str(e) }, status=500) + + @staticmethod + async def handle_save_metadata(request: web.Request, scanner) -> web.Response: + """Handle saving metadata updates + + Args: + request: The aiohttp request + scanner: The model scanner instance + + Returns: + web.Response: The HTTP response + """ + try: + data = await request.json() + file_path = data.get('file_path') + if not file_path: + return web.Response(text='File path is required', status=400) + + # Remove file path from data to avoid saving it + metadata_updates = {k: v for k, v in data.items() if k != 'file_path'} + + # Get metadata file path + metadata_path = os.path.splitext(file_path)[0] + '.metadata.json' + + # Load existing metadata + metadata = await ModelRouteUtils.load_local_metadata(metadata_path) + + # Handle nested updates (for civitai.trainedWords) + for key, value in metadata_updates.items(): + if isinstance(value, dict) and key in metadata and isinstance(metadata[key], dict): + # Deep update for nested dictionaries + for nested_key, nested_value in value.items(): + metadata[key][nested_key] = nested_value + else: + # Regular update for top-level keys + metadata[key] = value + + # Save updated metadata + await MetadataManager.save_metadata(file_path, metadata) + + # Update cache + await scanner.update_single_model_cache(file_path, file_path, metadata) + + # If model_name was updated, resort the cache + if 'model_name' in metadata_updates: + cache = await scanner.get_cached_data() + await cache.resort(name_only=True) + + return web.json_response({'success': True}) + + except Exception as e: + logger.error(f"Error saving metadata: {e}", exc_info=True) + return web.Response(text=str(e), status=500) diff --git a/standalone.py b/standalone.py index 0120ab53..0e491c14 100644 --- a/standalone.py +++ b/standalone.py @@ -314,22 +314,23 @@ class StandaloneLoraManager(LoraManager): app.router.add_static('/loras_static', config.static_path) # Setup feature routes - from py.routes.lora_routes import LoraRoutes + from py.services.model_service_factory import ModelServiceFactory, register_default_model_types from py.routes.api_routes import ApiRoutes from py.routes.recipe_routes import RecipeRoutes - from py.routes.checkpoints_routes import CheckpointsRoutes from py.routes.update_routes import UpdateRoutes from py.routes.misc_routes import MiscRoutes from py.routes.example_images_routes import ExampleImagesRoutes from py.routes.stats_routes import StatsRoutes - lora_routes = LoraRoutes() - checkpoints_routes = CheckpointsRoutes() + + register_default_model_types() + + # Setup all model routes using the factory + ModelServiceFactory.setup_all_routes(app) + stats_routes = StatsRoutes() # Initialize routes - lora_routes.setup_routes(app) - checkpoints_routes.setup_routes(app) stats_routes.setup_routes(app) ApiRoutes.setup_routes(app) RecipeRoutes.setup_routes(app)