diff --git a/py/routes/base_model_routes.py b/py/routes/base_model_routes.py index 6c3f11b4..458a5e87 100644 --- a/py/routes/base_model_routes.py +++ b/py/routes/base_model_routes.py @@ -1,1095 +1,200 @@ -from abc import ABC, abstractmethod -import asyncio -import os -import json +from __future__ import annotations + import logging -from aiohttp import web -from typing import Dict +from abc import ABC, abstractmethod +from typing import Callable, Dict, Mapping import jinja2 +from aiohttp import web -from ..utils.routes_common import ModelRouteUtils -from ..services.websocket_manager import ws_manager -from ..services.settings_manager import settings -from ..services.server_i18n import server_i18n -from ..services.model_file_service import ModelFileService, ModelMoveService -from ..services.websocket_progress_callback import WebSocketProgressCallback -from ..services.metadata_service import get_default_metadata_provider from ..config import config +from ..services.metadata_service import get_default_metadata_provider +from ..services.model_file_service import ModelFileService, ModelMoveService +from ..services.settings_manager import settings as default_settings +from ..services.websocket_manager import ws_manager as default_ws_manager +from ..services.websocket_progress_callback import WebSocketProgressCallback +from ..services.server_i18n import server_i18n as default_server_i18n +from ..utils.routes_common import ModelRouteUtils +from .model_route_registrar import COMMON_ROUTE_DEFINITIONS, ModelRouteRegistrar +from .handlers.model_handlers import ( + ModelAutoOrganizeHandler, + ModelCivitaiHandler, + ModelDownloadHandler, + ModelHandlerSet, + ModelListingHandler, + ModelManagementHandler, + ModelMoveHandler, + ModelPageView, + ModelQueryHandler, +) logger = logging.getLogger(__name__) + class BaseModelRoutes(ABC): - """Base route controller for all model types""" - - def __init__(self, service): - """Initialize the route controller - - Args: - service: Model 
service instance (LoraService, CheckpointService, etc.) - """ - self.service = service - self.model_type = service.model_type + """Base route controller for all model types.""" + + template_name: str | None = None + + def __init__( + self, + service=None, + *, + settings_service=default_settings, + ws_manager=default_ws_manager, + server_i18n=default_server_i18n, + metadata_provider_factory=get_default_metadata_provider, + ) -> None: + self.service = None + self.model_type = "" + self._settings = settings_service + self._ws_manager = ws_manager + self._server_i18n = server_i18n + self._metadata_provider_factory = metadata_provider_factory + self.template_env = jinja2.Environment( loader=jinja2.FileSystemLoader(config.templates_path), - autoescape=True + autoescape=True, ) - - # Initialize file services with dependency injection + + self.model_file_service: ModelFileService | None = None + self.model_move_service: ModelMoveService | None = None + self.websocket_progress_callback = WebSocketProgressCallback() + + self._handler_set: ModelHandlerSet | None = None + self._handler_mapping: Dict[str, Callable[[web.Request], web.StreamResponse]] | None = None + + if service is not None: + self.attach_service(service) + + def attach_service(self, service) -> None: + """Attach a model service and rebuild handler dependencies.""" + self.service = service + self.model_type = service.model_type self.model_file_service = ModelFileService(service.scanner, service.model_type) self.model_move_service = ModelMoveService(service.scanner) - self.websocket_progress_callback = WebSocketProgressCallback() - - def setup_routes(self, app: web.Application, prefix: str): - """Setup common routes for the model type - - Args: - app: aiohttp application - prefix: URL prefix (e.g., 'loras', 'checkpoints') - """ - # Common model management routes - app.router.add_get(f'/api/lm/{prefix}/list', self.get_models) - app.router.add_post(f'/api/lm/{prefix}/delete', self.delete_model) - 
app.router.add_post(f'/api/lm/{prefix}/exclude', self.exclude_model) - app.router.add_post(f'/api/lm/{prefix}/fetch-civitai', self.fetch_civitai) - app.router.add_post(f'/api/lm/{prefix}/fetch-all-civitai', self.fetch_all_civitai) - app.router.add_post(f'/api/lm/{prefix}/relink-civitai', self.relink_civitai) - app.router.add_post(f'/api/lm/{prefix}/replace-preview', self.replace_preview) - app.router.add_post(f'/api/lm/{prefix}/save-metadata', self.save_metadata) - app.router.add_post(f'/api/lm/{prefix}/add-tags', self.add_tags) - app.router.add_post(f'/api/lm/{prefix}/rename', self.rename_model) - app.router.add_post(f'/api/lm/{prefix}/bulk-delete', self.bulk_delete_models) - app.router.add_post(f'/api/lm/{prefix}/verify-duplicates', self.verify_duplicates) - app.router.add_post(f'/api/lm/{prefix}/move_model', self.move_model) - app.router.add_post(f'/api/lm/{prefix}/move_models_bulk', self.move_models_bulk) - app.router.add_get(f'/api/lm/{prefix}/auto-organize', self.auto_organize_models) - app.router.add_post(f'/api/lm/{prefix}/auto-organize', self.auto_organize_models) - app.router.add_get(f'/api/lm/{prefix}/auto-organize-progress', self.get_auto_organize_progress) - - # Common query routes - app.router.add_get(f'/api/lm/{prefix}/top-tags', self.get_top_tags) - app.router.add_get(f'/api/lm/{prefix}/base-models', self.get_base_models) - app.router.add_get(f'/api/lm/{prefix}/scan', self.scan_models) - app.router.add_get(f'/api/lm/{prefix}/roots', self.get_model_roots) - app.router.add_get(f'/api/lm/{prefix}/folders', self.get_folders) - app.router.add_get(f'/api/lm/{prefix}/folder-tree', self.get_folder_tree) - app.router.add_get(f'/api/lm/{prefix}/unified-folder-tree', self.get_unified_folder_tree) - app.router.add_get(f'/api/lm/{prefix}/find-duplicates', self.find_duplicate_models) - app.router.add_get(f'/api/lm/{prefix}/find-filename-conflicts', self.find_filename_conflicts) - app.router.add_get(f'/api/lm/{prefix}/get-notes', self.get_model_notes) - 
app.router.add_get(f'/api/lm/{prefix}/preview-url', self.get_model_preview_url) - app.router.add_get(f'/api/lm/{prefix}/civitai-url', self.get_model_civitai_url) - app.router.add_get(f'/api/lm/{prefix}/metadata', self.get_model_metadata) - app.router.add_get(f'/api/lm/{prefix}/model-description', self.get_model_description) - - # Autocomplete route - app.router.add_get(f'/api/lm/{prefix}/relative-paths', self.get_relative_paths) + self._handler_set = None + self._handler_mapping = None - # Common CivitAI integration - app.router.add_get(f'/api/lm/{prefix}/civitai/versions/{{model_id}}', self.get_civitai_versions) - app.router.add_get(f'/api/lm/{prefix}/civitai/model/version/{{modelVersionId}}', self.get_civitai_model_by_version) - app.router.add_get(f'/api/lm/{prefix}/civitai/model/hash/{{hash}}', self.get_civitai_model_by_hash) + def _ensure_handler_mapping(self) -> Mapping[str, Callable[[web.Request], web.StreamResponse]]: + if self._handler_mapping is None: + handler_set = self._create_handler_set() + self._handler_set = handler_set + self._handler_mapping = handler_set.to_route_mapping() + return self._handler_mapping + + def _create_handler_set(self) -> ModelHandlerSet: + service = self._ensure_service() + page_view = ModelPageView( + template_env=self.template_env, + template_name=self.template_name or "", + service=service, + settings_service=self._settings, + server_i18n=self._server_i18n, + logger=logger, + ) + listing = ModelListingHandler( + service=service, + parse_specific_params=self._parse_specific_params, + logger=logger, + ) + management = ModelManagementHandler(service=service, logger=logger) + query = ModelQueryHandler(service=service, logger=logger) + download = ModelDownloadHandler(ws_manager=self._ws_manager, logger=logger) + civitai = ModelCivitaiHandler( + service=service, + settings_service=self._settings, + ws_manager=self._ws_manager, + logger=logger, + metadata_provider_factory=self._metadata_provider_factory, + 
validate_model_type=self._validate_civitai_model_type, + expected_model_types=self._get_expected_model_types, + find_model_file=self._find_model_file, + ) + move = ModelMoveHandler(move_service=self._ensure_move_service(), logger=logger) + auto_organize = ModelAutoOrganizeHandler( + file_service=self._ensure_file_service(), + progress_callback=self.websocket_progress_callback, + ws_manager=self._ws_manager, + logger=logger, + ) + return ModelHandlerSet( + page_view=page_view, + listing=listing, + management=management, + query=query, + download=download, + civitai=civitai, + move=move, + auto_organize=auto_organize, + ) + + @property + def route_handlers(self) -> Mapping[str, Callable[[web.Request], web.StreamResponse]]: + return self._ensure_handler_mapping() + + def setup_routes(self, app: web.Application, prefix: str) -> None: + registrar = ModelRouteRegistrar(app) + handler_lookup = { + definition.handler_name: self._make_handler_proxy(definition.handler_name) + for definition in COMMON_ROUTE_DEFINITIONS + } + registrar.register_common_routes(prefix, handler_lookup) + self.setup_specific_routes(registrar, prefix) - # Common Download management - app.router.add_post(f'/api/lm/download-model', self.download_model) - app.router.add_get(f'/api/lm/download-model-get', self.download_model_get) - app.router.add_get(f'/api/lm/cancel-download-get', self.cancel_download_get) - app.router.add_get(f'/api/lm/download-progress/{{download_id}}', self.get_download_progress) - - # Add generic page route - app.router.add_get(f'/{prefix}', self.handle_models_page) - - # Setup model-specific routes - self.setup_specific_routes(app, prefix) - @abstractmethod - def setup_specific_routes(self, app: web.Application, prefix: str): - """Setup model-specific routes - to be implemented by subclasses""" - pass - - async def handle_models_page(self, request: web.Request) -> web.Response: - """ - Generic handler for model pages (e.g., /loras, /checkpoints). 
- Subclasses should set self.template_env and template_name. - """ - try: - # Check if the scanner is initializing - is_initializing = ( - self.service.scanner._cache is None or - (hasattr(self.service.scanner, 'is_initializing') and callable(self.service.scanner.is_initializing) and self.service.scanner.is_initializing()) or - (hasattr(self.service.scanner, '_is_initializing') and self.service.scanner._is_initializing) - ) + def setup_specific_routes(self, registrar: ModelRouteRegistrar, prefix: str) -> None: + """Setup model-specific routes.""" + raise NotImplementedError - template_name = getattr(self, "template_name", None) - if not self.template_env or not template_name: - return web.Response(text="Template environment or template name not set", status=500) - - # Get user's language setting - user_language = settings.get('language', 'en') - - # Set server-side i18n locale - server_i18n.set_locale(user_language) - - # Add i18n filter to the template environment if not already added - if not hasattr(self.template_env, '_i18n_filter_added'): - self.template_env.filters['t'] = server_i18n.create_template_filter() - self.template_env._i18n_filter_added = True - - # Prepare template context - template_context = { - 'is_initializing': is_initializing, - 'settings': settings, - 'request': request, - 'folders': [], - 't': server_i18n.get_translation, - } - - if not is_initializing: - try: - cache = await self.service.scanner.get_cached_data(force_refresh=False) - template_context['folders'] = getattr(cache, "folders", []) - except Exception as cache_error: - logger.error(f"Error loading cache data: {cache_error}") - template_context['is_initializing'] = True - - rendered = self.template_env.get_template(template_name).render(**template_context) - - return web.Response( - text=rendered, - content_type='text/html' - ) - except Exception as e: - logger.error(f"Error handling models page: {e}", exc_info=True) - return web.Response( - text="Error loading models page", - 
status=500 - ) - - async def get_models(self, request: web.Request) -> web.Response: - """Get paginated model data""" - try: - # Parse common query parameters - params = self._parse_common_params(request) - - # Get data from service - result = await self.service.get_paginated_data(**params) - - # Format response items - formatted_result = { - 'items': [await self.service.format_response(item) for item in result['items']], - 'total': result['total'], - 'page': result['page'], - 'page_size': result['page_size'], - 'total_pages': result['total_pages'] - } - - return web.json_response(formatted_result) - - except Exception as e: - logger.error(f"Error in get_{self.model_type}s: {e}", exc_info=True) - return web.json_response({"error": str(e)}, status=500) - - def _parse_common_params(self, request: web.Request) -> Dict: - """Parse common query parameters""" - # Parse basic pagination and sorting - page = int(request.query.get('page', '1')) - page_size = min(int(request.query.get('page_size', '20')), 100) - sort_by = request.query.get('sort_by', 'name') - folder = request.query.get('folder', None) - search = request.query.get('search', None) - fuzzy_search = request.query.get('fuzzy_search', 'false').lower() == 'true' - - # Parse filter arrays - base_models = request.query.getall('base_model', []) - tags = request.query.getall('tag', []) - favorites_only = request.query.get('favorites_only', 'false').lower() == 'true' - - # Parse search options - search_options = { - 'filename': request.query.get('search_filename', 'true').lower() == 'true', - 'modelname': request.query.get('search_modelname', 'true').lower() == 'true', - 'tags': request.query.get('search_tags', 'false').lower() == 'true', - 'creator': request.query.get('search_creator', 'false').lower() == 'true', - 'recursive': request.query.get('recursive', 'true').lower() == 'true', - } - - # Parse hash filters if provided - hash_filters = {} - if 'hash' in request.query: - hash_filters['single_hash'] = 
request.query['hash'] - elif 'hashes' in request.query: - try: - hash_list = json.loads(request.query['hashes']) - if isinstance(hash_list, list): - hash_filters['multiple_hashes'] = hash_list - except (json.JSONDecodeError, TypeError): - pass - - return { - 'page': page, - 'page_size': page_size, - 'sort_by': sort_by, - 'folder': folder, - 'search': search, - 'fuzzy_search': fuzzy_search, - 'base_models': base_models, - 'tags': tags, - 'search_options': search_options, - 'hash_filters': hash_filters, - 'favorites_only': favorites_only, - # Add model-specific parameters - **self._parse_specific_params(request) - } - def _parse_specific_params(self, request: web.Request) -> Dict: - """Parse model-specific parameters - to be overridden by subclasses""" + """Parse model-specific parameters - to be overridden by subclasses.""" return {} - - # Common route handlers - async def delete_model(self, request: web.Request) -> web.Response: - """Handle model deletion request""" - return await ModelRouteUtils.handle_delete_model(request, self.service.scanner) - - async def exclude_model(self, request: web.Request) -> web.Response: - """Handle model exclusion request""" - return await ModelRouteUtils.handle_exclude_model(request, self.service.scanner) - - async def fetch_civitai(self, request: web.Request) -> web.Response: - """Handle CivitAI metadata fetch request - force refresh model metadata""" - try: - data = await request.json() - file_path = data.get('file_path') - if not file_path: - return web.json_response({"success": False, "error": "File path is required"}, status=400) - # Get model data from cache - cache = await self.service.scanner.get_cached_data() - model_data = next((item for item in cache.raw_data if item['file_path'] == file_path), None) - - if not model_data: - return web.json_response({"success": False, "error": "Model not found in cache"}, status=404) - - # Check if model has SHA256 hash - if not model_data.get('sha256'): - return 
web.json_response({"success": False, "error": "No SHA256 hash found"}, status=400) - - # Use fetch_and_update_model to get and update metadata - success, error = await ModelRouteUtils.fetch_and_update_model( - sha256=model_data['sha256'], - file_path=file_path, - model_data=model_data, - update_cache_func=self.service.scanner.update_single_model_cache - ) - - if not success: - return web.json_response({"success": False, "error": error}) - - # Format the updated metadata for response - formatted_metadata = await self.service.format_response(model_data) - return web.json_response({ - "success": True, - "metadata": formatted_metadata - }) - - except Exception as e: - logger.error(f"Error fetching from CivitAI: {e}", exc_info=True) - return web.json_response({"success": False, "error": str(e)}, status=500) - - async def relink_civitai(self, request: web.Request) -> web.Response: - """Handle CivitAI metadata re-linking request""" - return await ModelRouteUtils.handle_relink_civitai(request, self.service.scanner) - - async def replace_preview(self, request: web.Request) -> web.Response: - """Handle preview image replacement""" - return await ModelRouteUtils.handle_replace_preview(request, self.service.scanner) - - async def save_metadata(self, request: web.Request) -> web.Response: - """Handle saving metadata updates""" - return await ModelRouteUtils.handle_save_metadata(request, self.service.scanner) - - async def add_tags(self, request: web.Request) -> web.Response: - """Handle adding tags to model metadata""" - return await ModelRouteUtils.handle_add_tags(request, self.service.scanner) - - async def rename_model(self, request: web.Request) -> web.Response: - """Handle renaming a model file and its associated files""" - return await ModelRouteUtils.handle_rename_model(request, self.service.scanner) - - async def bulk_delete_models(self, request: web.Request) -> web.Response: - """Handle bulk deletion of models""" - return await 
ModelRouteUtils.handle_bulk_delete_models(request, self.service.scanner) - - async def verify_duplicates(self, request: web.Request) -> web.Response: - """Handle verification of duplicate model hashes""" - return await ModelRouteUtils.handle_verify_duplicates(request, self.service.scanner) - - async def get_top_tags(self, request: web.Request) -> web.Response: - """Handle request for top tags sorted by frequency""" - try: - limit = int(request.query.get('limit', '20')) - if limit < 1 or limit > 100: - limit = 20 - - top_tags = await self.service.get_top_tags(limit) - - return web.json_response({ - 'success': True, - 'tags': top_tags - }) - - except Exception as e: - logger.error(f"Error getting top tags: {str(e)}", exc_info=True) - return web.json_response({ - 'success': False, - 'error': 'Internal server error' - }, status=500) - - async def get_base_models(self, request: web.Request) -> web.Response: - """Get base models used in models""" - try: - limit = int(request.query.get('limit', '20')) - if limit < 1 or limit > 100: - limit = 20 - - base_models = await self.service.get_base_models(limit) - - return web.json_response({ - 'success': True, - 'base_models': base_models - }) - except Exception as e: - logger.error(f"Error retrieving base models: {e}") - return web.json_response({ - 'success': False, - 'error': str(e) - }, status=500) - - async def scan_models(self, request: web.Request) -> web.Response: - """Force a rescan of model files""" - try: - full_rebuild = request.query.get('full_rebuild', 'false').lower() == 'true' - - await self.service.scan_models(force_refresh=True, rebuild_cache=full_rebuild) - return web.json_response({ - "status": "success", - "message": f"{self.model_type.capitalize()} scan completed" - }) - except Exception as e: - logger.error(f"Error in scan_{self.model_type}s: {e}", exc_info=True) - return web.json_response({"error": str(e)}, status=500) - - async def get_model_roots(self, request: web.Request) -> web.Response: - """Return 
the model root directories""" - try: - roots = self.service.get_model_roots() - return web.json_response({ - "success": True, - "roots": roots - }) - except Exception as e: - logger.error(f"Error getting {self.model_type} roots: {e}", exc_info=True) - return web.json_response({ - "success": False, - "error": str(e) - }, status=500) - - async def get_folders(self, request: web.Request) -> web.Response: - """Get all folders in the cache""" - try: - cache = await self.service.scanner.get_cached_data() - return web.json_response({ - 'folders': cache.folders - }) - except Exception as e: - logger.error(f"Error getting folders: {e}") - return web.json_response({ - 'success': False, - 'error': str(e) - }, status=500) - - async def get_folder_tree(self, request: web.Request) -> web.Response: - """Get hierarchical folder tree structure for download modal""" - try: - model_root = request.query.get('model_root') - if not model_root: - return web.json_response({ - 'success': False, - 'error': 'model_root parameter is required' - }, status=400) - - folder_tree = await self.service.get_folder_tree(model_root) - return web.json_response({ - 'success': True, - 'tree': folder_tree - }) - except Exception as e: - logger.error(f"Error getting folder tree: {e}") - return web.json_response({ - 'success': False, - 'error': str(e) - }, status=500) - - async def get_unified_folder_tree(self, request: web.Request) -> web.Response: - """Get unified folder tree across all model roots""" - try: - unified_tree = await self.service.get_unified_folder_tree() - return web.json_response({ - 'success': True, - 'tree': unified_tree - }) - except Exception as e: - logger.error(f"Error getting unified folder tree: {e}") - return web.json_response({ - 'success': False, - 'error': str(e) - }, status=500) - - async def find_duplicate_models(self, request: web.Request) -> web.Response: - """Find models with duplicate SHA256 hashes""" - try: - # Get duplicate hashes from service - duplicates = 
self.service.find_duplicate_hashes() - - # Format the response - result = [] - cache = await self.service.scanner.get_cached_data() - - for sha256, paths in duplicates.items(): - group = { - "hash": sha256, - "models": [] - } - # Find matching models for each path - for path in paths: - model = next((m for m in cache.raw_data if m['file_path'] == path), None) - if model: - group["models"].append(await self.service.format_response(model)) - - # Add the primary model too - primary_path = self.service.get_path_by_hash(sha256) - if primary_path and primary_path not in paths: - primary_model = next((m for m in cache.raw_data if m['file_path'] == primary_path), None) - if primary_model: - group["models"].insert(0, await self.service.format_response(primary_model)) - - if len(group["models"]) > 1: # Only include if we found multiple models - result.append(group) - - return web.json_response({ - "success": True, - "duplicates": result, - "count": len(result) - }) - except Exception as e: - logger.error(f"Error finding duplicate {self.model_type}s: {e}", exc_info=True) - return web.json_response({ - "success": False, - "error": str(e) - }, status=500) - - async def find_filename_conflicts(self, request: web.Request) -> web.Response: - """Find models with conflicting filenames""" - try: - # Get duplicate filenames from service - duplicates = self.service.find_duplicate_filenames() - - # Format the response - result = [] - cache = await self.service.scanner.get_cached_data() - - for filename, paths in duplicates.items(): - group = { - "filename": filename, - "models": [] - } - # Find matching models for each path - for path in paths: - model = next((m for m in cache.raw_data if m['file_path'] == path), None) - if model: - group["models"].append(await self.service.format_response(model)) - - # Find the model from the main index too - hash_val = self.service.scanner.get_hash_by_filename(filename) - if hash_val: - main_path = self.service.get_path_by_hash(hash_val) - if 
main_path and main_path not in paths: - main_model = next((m for m in cache.raw_data if m['file_path'] == main_path), None) - if main_model: - group["models"].insert(0, await self.service.format_response(main_model)) - - if group["models"]: - result.append(group) - - return web.json_response({ - "success": True, - "conflicts": result, - "count": len(result) - }) - except Exception as e: - logger.error(f"Error finding filename conflicts for {self.model_type}s: {e}", exc_info=True) - return web.json_response({ - "success": False, - "error": str(e) - }, status=500) - - # Download management methods - async def download_model(self, request: web.Request) -> web.Response: - """Handle model download request""" - return await ModelRouteUtils.handle_download_model(request) - - async def download_model_get(self, request: web.Request) -> web.Response: - """Handle model download request via GET method""" - try: - # Extract query parameters - model_id = request.query.get('model_id') - if not model_id: - return web.Response( - status=400, - text="Missing required parameter: Please provide 'model_id'" - ) - - # Get optional parameters - model_version_id = request.query.get('model_version_id') - download_id = request.query.get('download_id') - use_default_paths = request.query.get('use_default_paths', 'false').lower() == 'true' - source = request.query.get('source') # Optional source parameter - - # Create a data dictionary that mimics what would be received from a POST request - data = { - 'model_id': model_id - } - - # Add optional parameters only if they are provided - if model_version_id: - data['model_version_id'] = model_version_id - - if download_id: - data['download_id'] = download_id - - data['use_default_paths'] = use_default_paths - - # Add source parameter if provided - if source: - data['source'] = source - - # Create a mock request object with the data - future = asyncio.get_event_loop().create_future() - future.set_result(data) - - mock_request = type('MockRequest', 
(), { - 'json': lambda self=None: future - })() - - # Call the existing download handler - return await ModelRouteUtils.handle_download_model(mock_request) - - except Exception as e: - error_message = str(e) - logger.error(f"Error downloading model via GET: {error_message}", exc_info=True) - return web.Response(status=500, text=error_message) - - async def cancel_download_get(self, request: web.Request) -> web.Response: - """Handle GET request for cancelling a download by download_id""" - try: - download_id = request.query.get('download_id') - if not download_id: - return web.json_response({ - 'success': False, - 'error': 'Download ID is required' - }, status=400) - - # Create a mock request with match_info for compatibility - mock_request = type('MockRequest', (), { - 'match_info': {'download_id': download_id} - })() - return await ModelRouteUtils.handle_cancel_download(mock_request) - except Exception as e: - logger.error(f"Error cancelling download via GET: {e}", exc_info=True) - return web.json_response({ - 'success': False, - 'error': str(e) - }, status=500) - - async def get_download_progress(self, request: web.Request) -> web.Response: - """Handle request for download progress by download_id""" - try: - # Get download_id from URL path - download_id = request.match_info.get('download_id') - if not download_id: - return web.json_response({ - 'success': False, - 'error': 'Download ID is required' - }, status=400) - - progress_data = ws_manager.get_download_progress(download_id) - - if progress_data is None: - return web.json_response({ - 'success': False, - 'error': 'Download ID not found' - }, status=404) - - return web.json_response({ - 'success': True, - 'progress': progress_data.get('progress', 0) - }) - except Exception as e: - logger.error(f"Error getting download progress: {e}", exc_info=True) - return web.json_response({ - 'success': False, - 'error': str(e) - }, status=500) - - async def fetch_all_civitai(self, request: web.Request) -> web.Response: - 
"""Fetch CivitAI metadata for all models in the background""" - try: - cache = await self.service.scanner.get_cached_data() - total = len(cache.raw_data) - processed = 0 - success = 0 - needs_resort = False - - # Prepare models to process, only those without CivitAI data - enable_metadata_archive_db = settings.get('enable_metadata_archive_db', False) - # Filter models that need CivitAI metadata update - to_process = [ - model for model in cache.raw_data - if model.get('sha256') - and ( - not model.get('civitai') or not model['civitai'].get('id') - ) - and ( - (enable_metadata_archive_db and not model.get('db_checked', False)) - or (not enable_metadata_archive_db and model.get('from_civitai') is True) - ) - ] - total_to_process = len(to_process) - - # Send initial progress - await ws_manager.broadcast({ - 'status': 'started', - 'total': total_to_process, - 'processed': 0, - 'success': 0 - }) - - # Process each model - for model in to_process: - try: - original_name = model.get('model_name') - result, error = await ModelRouteUtils.fetch_and_update_model( - sha256=model['sha256'], - file_path=model['file_path'], - model_data=model, - update_cache_func=self.service.scanner.update_single_model_cache - ) - if result: - success += 1 - if original_name != model.get('model_name'): - needs_resort = True - - processed += 1 - - # Send progress update - await ws_manager.broadcast({ - 'status': 'processing', - 'total': total_to_process, - 'processed': processed, - 'success': success, - 'current_name': model.get('model_name', 'Unknown') - }) - - except Exception as e: - logger.error(f"Error fetching CivitAI data for {model['file_path']}: {e}") - - if needs_resort: - await cache.resort() - - # Send completion message - await ws_manager.broadcast({ - 'status': 'completed', - 'total': total_to_process, - 'processed': processed, - 'success': success - }) - - return web.json_response({ - "success": True, - "message": f"Successfully updated {success} of {processed} processed 
{self.model_type}s (total: {total})" - }) - - except Exception as e: - # Send error message - await ws_manager.broadcast({ - 'status': 'error', - 'error': str(e) - }) - logger.error(f"Error in fetch_all_civitai for {self.model_type}s: {e}") - return web.Response(text=str(e), status=500) - - async def get_civitai_versions(self, request: web.Request) -> web.Response: - """Get available versions for a Civitai model with local availability info""" - try: - model_id = request.match_info['model_id'] - metadata_provider = await get_default_metadata_provider() - response = await metadata_provider.get_model_versions(model_id) - if not response or not response.get('modelVersions'): - return web.Response(status=404, text="Model not found") - - versions = response.get('modelVersions', []) - model_type = response.get('type', '') - - # Check model type - allow subclasses to override validation - if not self._validate_civitai_model_type(model_type): - return web.json_response({ - 'error': f"Model type mismatch. 
Expected {self._get_expected_model_types()}, got {model_type}" - }, status=400) - - # Check local availability for each version - for version in versions: - # Find the model file (type="Model" and primary=true) in the files list - model_file = self._find_model_file(version.get('files', [])) - - if model_file: - sha256 = model_file.get('hashes', {}).get('SHA256') - if sha256: - # Set existsLocally and localPath at the version level - version['existsLocally'] = self.service.has_hash(sha256) - if version['existsLocally']: - version['localPath'] = self.service.get_path_by_hash(sha256) - - # Also set the model file size at the version level for easier access - version['modelSizeKB'] = model_file.get('sizeKB') - else: - # No model file found in this version - version['existsLocally'] = False - - return web.json_response(versions) - except Exception as e: - logger.error(f"Error fetching {self.model_type} model versions: {e}") - return web.Response(status=500, text=str(e)) - - async def get_civitai_model_by_version(self, request: web.Request) -> web.Response: - """Get CivitAI model details by model version ID""" - try: - model_version_id = request.match_info.get('modelVersionId') - - # Get model details from metadata provider - metadata_provider = await get_default_metadata_provider() - model, error_msg = await metadata_provider.get_model_version_info(model_version_id) - - if not model: - # Log warning for failed model retrieval - logger.warning(f"Failed to fetch model version {model_version_id}: {error_msg}") - - # Determine status code based on error message - status_code = 404 if error_msg and "not found" in error_msg.lower() else 500 - - return web.json_response({ - "success": False, - "error": error_msg or "Failed to fetch model information" - }, status=status_code) - - return web.json_response(model) - except Exception as e: - logger.error(f"Error fetching model details: {e}") - return web.json_response({ - "success": False, - "error": str(e) - }, status=500) - - 
async def get_civitai_model_by_hash(self, request: web.Request) -> web.Response: - """Get CivitAI model details by hash""" - try: - hash = request.match_info.get('hash') - metadata_provider = await get_default_metadata_provider() - model, error = await metadata_provider.get_model_by_hash(hash) - if error: - logger.warning(f"Error getting model by hash: {error}") - return web.json_response({ - "success": False, - "error": error - }, status=404) - return web.json_response(model) - except Exception as e: - logger.error(f"Error fetching model details by hash: {e}") - return web.json_response({ - "success": False, - "error": str(e) - }, status=500) - def _validate_civitai_model_type(self, model_type: str) -> bool: - """Validate CivitAI model type - to be overridden by subclasses""" - return True # Default: accept all types - - def _get_expected_model_types(self) -> str: - """Get expected model types string for error messages - to be overridden by subclasses""" - return "any model type" - - def _find_model_file(self, files: list) -> dict: - """Find the appropriate model file from the files list - can be overridden by subclasses""" - # Find the primary model file (type="Model" and primary=true) in the files list - return next((file for file in files if file.get('type') == 'Model' and file.get('primary') == True), None) - - # Common model move handlers - async def move_model(self, request: web.Request) -> web.Response: - """Handle model move request""" - try: - data = await request.json() - file_path = data.get('file_path') - target_path = data.get('target_path') - - if not file_path or not target_path: - return web.Response(text='File path and target path are required', status=400) - - result = await self.model_move_service.move_model(file_path, target_path) - - if result['success']: - return web.json_response(result) - else: - return web.json_response(result, status=500) - - except Exception as e: - logger.error(f"Error moving model: {e}", exc_info=True) - return 
web.Response(text=str(e), status=500) + """Validate CivitAI model type - to be overridden by subclasses.""" + return True + + def _get_expected_model_types(self) -> str: + """Get expected model types string for error messages - to be overridden by subclasses.""" + return "any model type" + + def _find_model_file(self, files): + """Find the appropriate model file from the files list - can be overridden by subclasses.""" + return next((file for file in files if file.get("type") == "Model" and file.get("primary") is True), None) + + def get_handler(self, name: str) -> Callable[[web.Request], web.StreamResponse]: + """Expose handlers for subclasses or tests.""" + return self._ensure_handler_mapping()[name] + + @property + def utils(self) -> ModelRouteUtils: # pragma: no cover - compatibility shim + return ModelRouteUtils + + def _ensure_service(self): + if self.service is None: + raise RuntimeError("Model service has not been attached") + return self.service + + def _ensure_file_service(self) -> ModelFileService: + if self.model_file_service is None: + service = self._ensure_service() + self.model_file_service = ModelFileService(service.scanner, service.model_type) + return self.model_file_service + + def _ensure_move_service(self) -> ModelMoveService: + if self.model_move_service is None: + service = self._ensure_service() + self.model_move_service = ModelMoveService(service.scanner) + return self.model_move_service + + def _make_handler_proxy(self, name: str) -> Callable[[web.Request], web.StreamResponse]: + async def proxy(request: web.Request) -> web.StreamResponse: + try: + handler = self.get_handler(name) + except RuntimeError: + return web.json_response({"success": False, "error": "Service not ready"}, status=503) + return await handler(request) + + return proxy - async def move_models_bulk(self, request: web.Request) -> web.Response: - """Handle bulk model move request""" - try: - data = await request.json() - file_paths = data.get('file_paths', []) - 
target_path = data.get('target_path') - - if not file_paths or not target_path: - return web.Response(text='File paths and target path are required', status=400) - - result = await self.model_move_service.move_models_bulk(file_paths, target_path) - return web.json_response(result) - - except Exception as e: - logger.error(f"Error moving models in bulk: {e}", exc_info=True) - return web.Response(text=str(e), status=500) - - async def auto_organize_models(self, request: web.Request) -> web.Response: - """Auto-organize all models or a specific set of models based on current settings""" - try: - # Check if auto-organize is already running - if ws_manager.is_auto_organize_running(): - return web.json_response({ - 'success': False, - 'error': 'Auto-organize is already running. Please wait for it to complete.' - }, status=409) - - # Acquire lock to prevent concurrent auto-organize operations - auto_organize_lock = await ws_manager.get_auto_organize_lock() - - if auto_organize_lock.locked(): - return web.json_response({ - 'success': False, - 'error': 'Auto-organize is already running. Please wait for it to complete.' 
- }, status=409) - - # Get specific file paths from request if this is a POST with selected models - file_paths = None - if request.method == 'POST': - try: - data = await request.json() - file_paths = data.get('file_paths') - except Exception: - pass # Continue with all models if no valid JSON - - async with auto_organize_lock: - # Use the service layer for business logic - result = await self.model_file_service.auto_organize_models( - file_paths=file_paths, - progress_callback=self.websocket_progress_callback - ) - - return web.json_response(result.to_dict()) - - except Exception as e: - logger.error(f"Error in auto_organize_models: {e}", exc_info=True) - - # Send error message via WebSocket - await ws_manager.broadcast_auto_organize_progress({ - 'type': 'auto_organize_progress', - 'status': 'error', - 'error': str(e) - }) - - return web.json_response({ - 'success': False, - 'error': str(e) - }, status=500) - - async def get_auto_organize_progress(self, request: web.Request) -> web.Response: - """Get current auto-organize progress for polling""" - try: - progress_data = ws_manager.get_auto_organize_progress() - - if progress_data is None: - return web.json_response({ - 'success': False, - 'error': 'No auto-organize operation in progress' - }, status=404) - - return web.json_response({ - 'success': True, - 'progress': progress_data - }) - except Exception as e: - logger.error(f"Error getting auto-organize progress: {e}", exc_info=True) - return web.json_response({ - 'success': False, - 'error': str(e) - }, status=500) - - async def get_model_notes(self, request: web.Request) -> web.Response: - """Get notes for a specific model file""" - try: - model_name = request.query.get('name') - if not model_name: - return web.Response(text=f'{self.model_type.capitalize()} file name is required', status=400) - - notes = await self.service.get_model_notes(model_name) - if notes is not None: - return web.json_response({ - 'success': True, - 'notes': notes - }) - else: - return 
web.json_response({ - 'success': False, - 'error': f'{self.model_type.capitalize()} not found in cache' - }, status=404) - - except Exception as e: - logger.error(f"Error getting {self.model_type} notes: {e}", exc_info=True) - return web.json_response({ - 'success': False, - 'error': str(e) - }, status=500) - - async def get_model_preview_url(self, request: web.Request) -> web.Response: - """Get the static preview URL for a model file""" - try: - model_name = request.query.get('name') - if not model_name: - return web.Response(text=f'{self.model_type.capitalize()} file name is required', status=400) - - preview_url = await self.service.get_model_preview_url(model_name) - if preview_url: - return web.json_response({ - 'success': True, - 'preview_url': preview_url - }) - else: - return web.json_response({ - 'success': False, - 'error': f'No preview URL found for the specified {self.model_type}' - }, status=404) - - except Exception as e: - logger.error(f"Error getting {self.model_type} preview URL: {e}", exc_info=True) - return web.json_response({ - 'success': False, - 'error': str(e) - }, status=500) - - async def get_model_civitai_url(self, request: web.Request) -> web.Response: - """Get the Civitai URL for a model file""" - try: - model_name = request.query.get('name') - if not model_name: - return web.Response(text=f'{self.model_type.capitalize()} file name is required', status=400) - - result = await self.service.get_model_civitai_url(model_name) - if result['civitai_url']: - return web.json_response({ - 'success': True, - **result - }) - else: - return web.json_response({ - 'success': False, - 'error': f'No Civitai data found for the specified {self.model_type}' - }, status=404) - - except Exception as e: - logger.error(f"Error getting {self.model_type} Civitai URL: {e}", exc_info=True) - return web.json_response({ - 'success': False, - 'error': str(e) - }, status=500) - - async def get_model_metadata(self, request: web.Request) -> web.Response: - """Get 
filtered CivitAI metadata for a model by file path""" - try: - file_path = request.query.get('file_path') - if not file_path: - return web.Response(text='File path is required', status=400) - - metadata = await self.service.get_model_metadata(file_path) - if metadata is not None: - return web.json_response({ - 'success': True, - 'metadata': metadata - }) - else: - return web.json_response({ - 'success': False, - 'error': f'{self.model_type.capitalize()} not found or no CivitAI metadata available' - }, status=404) - - except Exception as e: - logger.error(f"Error getting {self.model_type} metadata: {e}", exc_info=True) - return web.json_response({ - 'success': False, - 'error': str(e) - }, status=500) - - async def get_model_description(self, request: web.Request) -> web.Response: - """Get model description by file path""" - try: - file_path = request.query.get('file_path') - if not file_path: - return web.Response(text='File path is required', status=400) - - description = await self.service.get_model_description(file_path) - if description is not None: - return web.json_response({ - 'success': True, - 'description': description - }) - else: - return web.json_response({ - 'success': False, - 'error': f'{self.model_type.capitalize()} not found or no description available' - }, status=404) - - except Exception as e: - logger.error(f"Error getting {self.model_type} description: {e}", exc_info=True) - return web.json_response({ - 'success': False, - 'error': str(e) - }, status=500) - - async def get_relative_paths(self, request: web.Request) -> web.Response: - """Get model relative file paths for autocomplete functionality""" - try: - search = request.query.get('search', '').strip() - limit = min(int(request.query.get('limit', '15')), 50) # Max 50 items - - matching_paths = await self.service.search_relative_paths(search, limit) - - return web.json_response({ - 'success': True, - 'relative_paths': matching_paths - }) - - except Exception as e: - logger.error(f"Error 
getting relative paths for autocomplete: {e}", exc_info=True) - return web.json_response({ - 'success': False, - 'error': str(e) - }, status=500) \ No newline at end of file diff --git a/py/routes/checkpoint_routes.py b/py/routes/checkpoint_routes.py index 95c747e5..ad4c538a 100644 --- a/py/routes/checkpoint_routes.py +++ b/py/routes/checkpoint_routes.py @@ -2,9 +2,9 @@ import logging from aiohttp import web from .base_model_routes import BaseModelRoutes +from .model_route_registrar import ModelRouteRegistrar from ..services.checkpoint_service import CheckpointService from ..services.service_registry import ServiceRegistry -from ..services.metadata_service import get_default_metadata_provider from ..config import config logger = logging.getLogger(__name__) @@ -14,8 +14,7 @@ class CheckpointRoutes(BaseModelRoutes): def __init__(self): """Initialize Checkpoint routes with Checkpoint service""" - # Service will be initialized later via setup_routes - self.service = None + super().__init__() self.template_name = "checkpoints.html" async def initialize_services(self): @@ -23,8 +22,8 @@ class CheckpointRoutes(BaseModelRoutes): checkpoint_scanner = await ServiceRegistry.get_checkpoint_scanner() self.service = CheckpointService(checkpoint_scanner) - # Initialize parent with the service - super().__init__(self.service) + # Attach service dependencies + self.attach_service(self.service) def setup_routes(self, app: web.Application): """Setup Checkpoint routes""" @@ -34,14 +33,14 @@ class CheckpointRoutes(BaseModelRoutes): # Setup common routes with 'checkpoints' prefix (includes page route) super().setup_routes(app, 'checkpoints') - def setup_specific_routes(self, app: web.Application, prefix: str): + def setup_specific_routes(self, registrar: ModelRouteRegistrar, prefix: str): """Setup Checkpoint-specific routes""" # Checkpoint info by name - app.router.add_get(f'/api/lm/{prefix}/info/{{name}}', self.get_checkpoint_info) - + registrar.add_prefixed_route('GET', 
'/api/lm/{prefix}/info/{name}', prefix, self.get_checkpoint_info) + # Checkpoint roots and Unet roots - app.router.add_get(f'/api/lm/{prefix}/checkpoints_roots', self.get_checkpoints_roots) - app.router.add_get(f'/api/lm/{prefix}/unet_roots', self.get_unet_roots) + registrar.add_prefixed_route('GET', '/api/lm/{prefix}/checkpoints_roots', prefix, self.get_checkpoints_roots) + registrar.add_prefixed_route('GET', '/api/lm/{prefix}/unet_roots', prefix, self.get_unet_roots) def _validate_civitai_model_type(self, model_type: str) -> bool: """Validate CivitAI model type for Checkpoint""" diff --git a/py/routes/embedding_routes.py b/py/routes/embedding_routes.py index 29b2f9fd..d7d361ce 100644 --- a/py/routes/embedding_routes.py +++ b/py/routes/embedding_routes.py @@ -2,9 +2,9 @@ import logging from aiohttp import web from .base_model_routes import BaseModelRoutes +from .model_route_registrar import ModelRouteRegistrar from ..services.embedding_service import EmbeddingService from ..services.service_registry import ServiceRegistry -from ..services.metadata_service import get_default_metadata_provider logger = logging.getLogger(__name__) @@ -13,8 +13,7 @@ class EmbeddingRoutes(BaseModelRoutes): def __init__(self): """Initialize Embedding routes with Embedding service""" - # Service will be initialized later via setup_routes - self.service = None + super().__init__() self.template_name = "embeddings.html" async def initialize_services(self): @@ -22,8 +21,8 @@ class EmbeddingRoutes(BaseModelRoutes): embedding_scanner = await ServiceRegistry.get_embedding_scanner() self.service = EmbeddingService(embedding_scanner) - # Initialize parent with the service - super().__init__(self.service) + # Attach service dependencies + self.attach_service(self.service) def setup_routes(self, app: web.Application): """Setup Embedding routes""" @@ -33,10 +32,10 @@ class EmbeddingRoutes(BaseModelRoutes): # Setup common routes with 'embeddings' prefix (includes page route) 
super().setup_routes(app, 'embeddings') - def setup_specific_routes(self, app: web.Application, prefix: str): + def setup_specific_routes(self, registrar: ModelRouteRegistrar, prefix: str): """Setup Embedding-specific routes""" # Embedding info by name - app.router.add_get(f'/api/lm/{prefix}/info/{{name}}', self.get_embedding_info) + registrar.add_prefixed_route('GET', '/api/lm/{prefix}/info/{name}', prefix, self.get_embedding_info) def _validate_civitai_model_type(self, model_type: str) -> bool: """Validate CivitAI model type for Embedding""" diff --git a/py/routes/handlers/model_handlers.py b/py/routes/handlers/model_handlers.py new file mode 100644 index 00000000..66a7123a --- /dev/null +++ b/py/routes/handlers/model_handlers.py @@ -0,0 +1,802 @@ +"""Handlers for base model routes.""" +from __future__ import annotations + +import asyncio +import json +import logging +from dataclasses import dataclass +from typing import Awaitable, Callable, Dict, Iterable, Mapping, Optional + +from aiohttp import web +import jinja2 + +from ...services.model_file_service import ModelFileService, ModelMoveService +from ...services.websocket_progress_callback import WebSocketProgressCallback +from ...services.websocket_manager import WebSocketManager +from ...services.settings_manager import SettingsManager +from ...utils.routes_common import ModelRouteUtils + + +class ModelPageView: + """Render the HTML view for model listings.""" + + def __init__( + self, + *, + template_env: jinja2.Environment, + template_name: str, + service, + settings_service: SettingsManager, + server_i18n, + logger: logging.Logger, + ) -> None: + self._template_env = template_env + self._template_name = template_name + self._service = service + self._settings = settings_service + self._server_i18n = server_i18n + self._logger = logger + + async def handle(self, request: web.Request) -> web.Response: + try: + is_initializing = ( + self._service.scanner._cache is None + or ( + hasattr(self._service.scanner, 
"is_initializing") + and callable(self._service.scanner.is_initializing) + and self._service.scanner.is_initializing() + ) + or ( + hasattr(self._service.scanner, "_is_initializing") + and self._service.scanner._is_initializing + ) + ) + + if not self._template_env or not self._template_name: + return web.Response( + text="Template environment or template name not set", + status=500, + ) + + user_language = self._settings.get("language", "en") + self._server_i18n.set_locale(user_language) + + if not hasattr(self._template_env, "_i18n_filter_added"): + self._template_env.filters["t"] = self._server_i18n.create_template_filter() + self._template_env._i18n_filter_added = True # type: ignore[attr-defined] + + template_context = { + "is_initializing": is_initializing, + "settings": self._settings, + "request": request, + "folders": [], + "t": self._server_i18n.get_translation, + } + + if not is_initializing: + try: + cache = await self._service.scanner.get_cached_data(force_refresh=False) + template_context["folders"] = getattr(cache, "folders", []) + except Exception as cache_error: # pragma: no cover - logging path + self._logger.error("Error loading cache data: %s", cache_error) + template_context["is_initializing"] = True + + rendered = self._template_env.get_template(self._template_name).render(**template_context) + return web.Response(text=rendered, content_type="text/html") + except Exception as exc: # pragma: no cover - logging path + self._logger.error("Error handling models page: %s", exc, exc_info=True) + return web.Response(text="Error loading models page", status=500) + + +class ModelListingHandler: + """Provide paginated model listings.""" + + def __init__( + self, + *, + service, + parse_specific_params: Callable[[web.Request], Dict], + logger: logging.Logger, + ) -> None: + self._service = service + self._parse_specific_params = parse_specific_params + self._logger = logger + + async def get_models(self, request: web.Request) -> web.Response: + try: + 
params = self._parse_common_params(request) + result = await self._service.get_paginated_data(**params) + formatted_result = { + "items": [await self._service.format_response(item) for item in result["items"]], + "total": result["total"], + "page": result["page"], + "page_size": result["page_size"], + "total_pages": result["total_pages"], + } + return web.json_response(formatted_result) + except Exception as exc: + self._logger.error("Error retrieving %ss: %s", self._service.model_type, exc, exc_info=True) + return web.json_response({"error": str(exc)}, status=500) + + def _parse_common_params(self, request: web.Request) -> Dict: + page = int(request.query.get("page", "1")) + page_size = min(int(request.query.get("page_size", "20")), 100) + sort_by = request.query.get("sort_by", "name") + folder = request.query.get("folder") + search = request.query.get("search") + fuzzy_search = request.query.get("fuzzy_search", "false").lower() == "true" + + base_models = request.query.getall("base_model", []) + tags = request.query.getall("tag", []) + favorites_only = request.query.get("favorites_only", "false").lower() == "true" + + search_options = { + "filename": request.query.get("search_filename", "true").lower() == "true", + "modelname": request.query.get("search_modelname", "true").lower() == "true", + "tags": request.query.get("search_tags", "false").lower() == "true", + "creator": request.query.get("search_creator", "false").lower() == "true", + "recursive": request.query.get("recursive", "true").lower() == "true", + } + + hash_filters: Dict[str, object] = {} + if "hash" in request.query: + hash_filters["single_hash"] = request.query["hash"] + elif "hashes" in request.query: + try: + hash_list = json.loads(request.query["hashes"]) + if isinstance(hash_list, list): + hash_filters["multiple_hashes"] = hash_list + except (json.JSONDecodeError, TypeError): + pass + + return { + "page": page, + "page_size": page_size, + "sort_by": sort_by, + "folder": folder, + "search": 
search, + "fuzzy_search": fuzzy_search, + "base_models": base_models, + "tags": tags, + "search_options": search_options, + "hash_filters": hash_filters, + "favorites_only": favorites_only, + **self._parse_specific_params(request), + } + + +class ModelManagementHandler: + """Handle mutation operations on models.""" + + def __init__(self, *, service, logger: logging.Logger) -> None: + self._service = service + self._logger = logger + + async def delete_model(self, request: web.Request) -> web.Response: + return await ModelRouteUtils.handle_delete_model(request, self._service.scanner) + + async def exclude_model(self, request: web.Request) -> web.Response: + return await ModelRouteUtils.handle_exclude_model(request, self._service.scanner) + + async def fetch_civitai(self, request: web.Request) -> web.Response: + try: + data = await request.json() + file_path = data.get("file_path") + if not file_path: + return web.json_response({"success": False, "error": "File path is required"}, status=400) + + cache = await self._service.scanner.get_cached_data() + model_data = next((item for item in cache.raw_data if item["file_path"] == file_path), None) + if not model_data: + return web.json_response({"success": False, "error": "Model not found in cache"}, status=404) + if not model_data.get("sha256"): + return web.json_response({"success": False, "error": "No SHA256 hash found"}, status=400) + + success, error = await ModelRouteUtils.fetch_and_update_model( + sha256=model_data["sha256"], + file_path=file_path, + model_data=model_data, + update_cache_func=self._service.scanner.update_single_model_cache, + ) + if not success: + return web.json_response({"success": False, "error": error}) + + formatted_metadata = await self._service.format_response(model_data) + return web.json_response({"success": True, "metadata": formatted_metadata}) + except Exception as exc: + self._logger.error("Error fetching from CivitAI: %s", exc, exc_info=True) + return web.json_response({"success": 
False, "error": str(exc)}, status=500) + + async def relink_civitai(self, request: web.Request) -> web.Response: + return await ModelRouteUtils.handle_relink_civitai(request, self._service.scanner) + + async def replace_preview(self, request: web.Request) -> web.Response: + return await ModelRouteUtils.handle_replace_preview(request, self._service.scanner) + + async def save_metadata(self, request: web.Request) -> web.Response: + return await ModelRouteUtils.handle_save_metadata(request, self._service.scanner) + + async def add_tags(self, request: web.Request) -> web.Response: + return await ModelRouteUtils.handle_add_tags(request, self._service.scanner) + + async def rename_model(self, request: web.Request) -> web.Response: + return await ModelRouteUtils.handle_rename_model(request, self._service.scanner) + + async def bulk_delete_models(self, request: web.Request) -> web.Response: + return await ModelRouteUtils.handle_bulk_delete_models(request, self._service.scanner) + + async def verify_duplicates(self, request: web.Request) -> web.Response: + return await ModelRouteUtils.handle_verify_duplicates(request, self._service.scanner) + + +class ModelQueryHandler: + """Serve read-only model queries.""" + + def __init__(self, *, service, logger: logging.Logger) -> None: + self._service = service + self._logger = logger + + async def get_top_tags(self, request: web.Request) -> web.Response: + try: + limit = int(request.query.get("limit", "20")) + if limit < 1 or limit > 100: + limit = 20 + top_tags = await self._service.get_top_tags(limit) + return web.json_response({"success": True, "tags": top_tags}) + except Exception as exc: + self._logger.error("Error getting top tags: %s", exc, exc_info=True) + return web.json_response({"success": False, "error": "Internal server error"}, status=500) + + async def get_base_models(self, request: web.Request) -> web.Response: + try: + limit = int(request.query.get("limit", "20")) + if limit < 1 or limit > 100: + limit = 20 + 
base_models = await self._service.get_base_models(limit) + return web.json_response({"success": True, "base_models": base_models}) + except Exception as exc: + self._logger.error("Error retrieving base models: %s", exc) + return web.json_response({"success": False, "error": str(exc)}, status=500) + + async def scan_models(self, request: web.Request) -> web.Response: + try: + full_rebuild = request.query.get("full_rebuild", "false").lower() == "true" + await self._service.scan_models(force_refresh=True, rebuild_cache=full_rebuild) + return web.json_response({"status": "success", "message": f"{self._service.model_type.capitalize()} scan completed"}) + except Exception as exc: + self._logger.error("Error scanning %ss: %s", self._service.model_type, exc, exc_info=True) + return web.json_response({"error": str(exc)}, status=500) + + async def get_model_roots(self, request: web.Request) -> web.Response: + try: + roots = self._service.get_model_roots() + return web.json_response({"success": True, "roots": roots}) + except Exception as exc: + self._logger.error("Error getting %s roots: %s", self._service.model_type, exc, exc_info=True) + return web.json_response({"success": False, "error": str(exc)}, status=500) + + async def get_folders(self, request: web.Request) -> web.Response: + try: + cache = await self._service.scanner.get_cached_data() + return web.json_response({"folders": cache.folders}) + except Exception as exc: + self._logger.error("Error getting folders: %s", exc) + return web.json_response({"success": False, "error": str(exc)}, status=500) + + async def get_folder_tree(self, request: web.Request) -> web.Response: + try: + model_root = request.query.get("model_root") + if not model_root: + return web.json_response({"success": False, "error": "model_root parameter is required"}, status=400) + folder_tree = await self._service.get_folder_tree(model_root) + return web.json_response({"success": True, "tree": folder_tree}) + except Exception as exc: + 
self._logger.error("Error getting folder tree: %s", exc) + return web.json_response({"success": False, "error": str(exc)}, status=500) + + async def get_unified_folder_tree(self, request: web.Request) -> web.Response: + try: + unified_tree = await self._service.get_unified_folder_tree() + return web.json_response({"success": True, "tree": unified_tree}) + except Exception as exc: + self._logger.error("Error getting unified folder tree: %s", exc) + return web.json_response({"success": False, "error": str(exc)}, status=500) + + async def find_duplicate_models(self, request: web.Request) -> web.Response: + try: + duplicates = self._service.find_duplicate_hashes() + result = [] + cache = await self._service.scanner.get_cached_data() + for sha256, paths in duplicates.items(): + group = {"hash": sha256, "models": []} + for path in paths: + model = next((m for m in cache.raw_data if m["file_path"] == path), None) + if model: + group["models"].append(await self._service.format_response(model)) + primary_path = self._service.get_path_by_hash(sha256) + if primary_path and primary_path not in paths: + primary_model = next((m for m in cache.raw_data if m["file_path"] == primary_path), None) + if primary_model: + group["models"].insert(0, await self._service.format_response(primary_model)) + if len(group["models"]) > 1: + result.append(group) + return web.json_response({"success": True, "duplicates": result, "count": len(result)}) + except Exception as exc: + self._logger.error("Error finding duplicate %ss: %s", self._service.model_type, exc, exc_info=True) + return web.json_response({"success": False, "error": str(exc)}, status=500) + + async def find_filename_conflicts(self, request: web.Request) -> web.Response: + try: + duplicates = self._service.find_duplicate_filenames() + result = [] + cache = await self._service.scanner.get_cached_data() + for filename, paths in duplicates.items(): + group = {"filename": filename, "models": []} + for path in paths: + model = next((m for 
m in cache.raw_data if m["file_path"] == path), None) + if model: + group["models"].append(await self._service.format_response(model)) + hash_val = self._service.scanner.get_hash_by_filename(filename) + if hash_val: + main_path = self._service.get_path_by_hash(hash_val) + if main_path and main_path not in paths: + main_model = next((m for m in cache.raw_data if m["file_path"] == main_path), None) + if main_model: + group["models"].insert(0, await self._service.format_response(main_model)) + if group["models"]: + result.append(group) + return web.json_response({"success": True, "conflicts": result, "count": len(result)}) + except Exception as exc: + self._logger.error("Error finding filename conflicts for %ss: %s", self._service.model_type, exc, exc_info=True) + return web.json_response({"success": False, "error": str(exc)}, status=500) + + async def get_model_notes(self, request: web.Request) -> web.Response: + try: + model_name = request.query.get("name") + if not model_name: + return web.Response(text=f"{self._service.model_type.capitalize()} file name is required", status=400) + notes = await self._service.get_model_notes(model_name) + if notes is not None: + return web.json_response({"success": True, "notes": notes}) + return web.json_response({"success": False, "error": f"{self._service.model_type.capitalize()} not found in cache"}, status=404) + except Exception as exc: + self._logger.error("Error getting %s notes: %s", self._service.model_type, exc, exc_info=True) + return web.json_response({"success": False, "error": str(exc)}, status=500) + + async def get_model_preview_url(self, request: web.Request) -> web.Response: + try: + model_name = request.query.get("name") + if not model_name: + return web.Response(text=f"{self._service.model_type.capitalize()} file name is required", status=400) + preview_url = await self._service.get_model_preview_url(model_name) + if preview_url: + return web.json_response({"success": True, "preview_url": preview_url}) + return 
web.json_response({"success": False, "error": f"No preview URL found for the specified {self._service.model_type}"}, status=404) + except Exception as exc: + self._logger.error("Error getting %s preview URL: %s", self._service.model_type, exc, exc_info=True) + return web.json_response({"success": False, "error": str(exc)}, status=500) + + async def get_model_civitai_url(self, request: web.Request) -> web.Response: + try: + model_name = request.query.get("name") + if not model_name: + return web.Response(text=f"{self._service.model_type.capitalize()} file name is required", status=400) + result = await self._service.get_model_civitai_url(model_name) + if result["civitai_url"]: + return web.json_response({"success": True, **result}) + return web.json_response({"success": False, "error": f"No Civitai data found for the specified {self._service.model_type}"}, status=404) + except Exception as exc: + self._logger.error("Error getting %s Civitai URL: %s", self._service.model_type, exc, exc_info=True) + return web.json_response({"success": False, "error": str(exc)}, status=500) + + async def get_model_metadata(self, request: web.Request) -> web.Response: + try: + file_path = request.query.get("file_path") + if not file_path: + return web.Response(text="File path is required", status=400) + metadata = await self._service.get_model_metadata(file_path) + if metadata is not None: + return web.json_response({"success": True, "metadata": metadata}) + return web.json_response({"success": False, "error": f"{self._service.model_type.capitalize()} not found or no CivitAI metadata available"}, status=404) + except Exception as exc: + self._logger.error("Error getting %s metadata: %s", self._service.model_type, exc, exc_info=True) + return web.json_response({"success": False, "error": str(exc)}, status=500) + + async def get_model_description(self, request: web.Request) -> web.Response: + try: + file_path = request.query.get("file_path") + if not file_path: + return 
web.Response(text="File path is required", status=400) + description = await self._service.get_model_description(file_path) + if description is not None: + return web.json_response({"success": True, "description": description}) + return web.json_response({"success": False, "error": f"{self._service.model_type.capitalize()} not found or no description available"}, status=404) + except Exception as exc: + self._logger.error("Error getting %s description: %s", self._service.model_type, exc, exc_info=True) + return web.json_response({"success": False, "error": str(exc)}, status=500) + + async def get_relative_paths(self, request: web.Request) -> web.Response: + try: + search = request.query.get("search", "").strip() + limit = min(int(request.query.get("limit", "15")), 50) + matching_paths = await self._service.search_relative_paths(search, limit) + return web.json_response({"success": True, "relative_paths": matching_paths}) + except Exception as exc: + self._logger.error("Error getting relative paths for autocomplete: %s", exc, exc_info=True) + return web.json_response({"success": False, "error": str(exc)}, status=500) + + +class ModelDownloadHandler: + """Coordinate downloads and progress reporting.""" + + def __init__(self, *, ws_manager: WebSocketManager, logger: logging.Logger) -> None: + self._ws_manager = ws_manager + self._logger = logger + + async def download_model(self, request: web.Request) -> web.Response: + return await ModelRouteUtils.handle_download_model(request) + + async def download_model_get(self, request: web.Request) -> web.Response: + try: + model_id = request.query.get("model_id") + if not model_id: + return web.Response(status=400, text="Missing required parameter: Please provide 'model_id'") + + model_version_id = request.query.get("model_version_id") + download_id = request.query.get("download_id") + use_default_paths = request.query.get("use_default_paths", "false").lower() == "true" + source = request.query.get("source") + + data = 
{"model_id": model_id, "use_default_paths": use_default_paths}
+                if model_version_id:
+                    data["model_version_id"] = model_version_id
+                if download_id:
+                    data["download_id"] = download_id
+                if source:
+                    data["source"] = source
+
+                loop = asyncio.get_running_loop()
+                future = loop.create_future()
+                future.set_result(data)
+
+                mock_request = type("MockRequest", (), {"json": lambda self=None: future})()
+                return await ModelRouteUtils.handle_download_model(mock_request)
+            except Exception as exc:
+                self._logger.error("Error downloading model via GET: %s", exc, exc_info=True)
+                return web.Response(status=500, text=str(exc))
+
+    async def cancel_download_get(self, request: web.Request) -> web.Response:
+        try:
+            download_id = request.query.get("download_id")
+            if not download_id:
+                return web.json_response({"success": False, "error": "Download ID is required"}, status=400)
+            mock_request = type("MockRequest", (), {"match_info": {"download_id": download_id}})()
+            return await ModelRouteUtils.handle_cancel_download(mock_request)
+        except Exception as exc:
+            self._logger.error("Error cancelling download via GET: %s", exc, exc_info=True)
+            return web.json_response({"success": False, "error": str(exc)}, status=500)
+
+    async def get_download_progress(self, request: web.Request) -> web.Response:
+        try:
+            download_id = request.match_info.get("download_id")
+            if not download_id:
+                return web.json_response({"success": False, "error": "Download ID is required"}, status=400)
+            progress_data = self._ws_manager.get_download_progress(download_id)
+            if progress_data is None:
+                return web.json_response({"success": False, "error": "Download ID not found"}, status=404)
+            return web.json_response({"success": True, "progress": progress_data.get("progress", 0)})
+        except Exception as exc:
+            self._logger.error("Error getting download progress: %s", exc, exc_info=True)
+            return web.json_response({"success": False, "error": str(exc)}, status=500)
+
+
+class ModelCivitaiHandler:
+    """CivitAI 
integration endpoints.""" + + def __init__( + self, + *, + service, + settings_service: SettingsManager, + ws_manager: WebSocketManager, + logger: logging.Logger, + metadata_provider_factory: Callable[[], Awaitable], + validate_model_type: Callable[[str], bool], + expected_model_types: Callable[[], str], + find_model_file: Callable[[Iterable[Mapping[str, object]]], Optional[Mapping[str, object]]], + ) -> None: + self._service = service + self._settings = settings_service + self._ws_manager = ws_manager + self._logger = logger + self._metadata_provider_factory = metadata_provider_factory + self._validate_model_type = validate_model_type + self._expected_model_types = expected_model_types + self._find_model_file = find_model_file + + async def fetch_all_civitai(self, request: web.Request) -> web.Response: + try: + cache = await self._service.scanner.get_cached_data() + total = len(cache.raw_data) + processed = 0 + success = 0 + needs_resort = False + + enable_metadata_archive_db = self._settings.get("enable_metadata_archive_db", False) + to_process = [ + model + for model in cache.raw_data + if model.get("sha256") + and (not model.get("civitai") or not model["civitai"].get("id")) + and ( + (enable_metadata_archive_db and not model.get("db_checked", False)) + or (not enable_metadata_archive_db and model.get("from_civitai") is True) + ) + ] + total_to_process = len(to_process) + + await self._ws_manager.broadcast({ + "status": "started", + "total": total_to_process, + "processed": 0, + "success": 0, + }) + + for model in to_process: + try: + original_name = model.get("model_name") + result, error = await ModelRouteUtils.fetch_and_update_model( + sha256=model["sha256"], + file_path=model["file_path"], + model_data=model, + update_cache_func=self._service.scanner.update_single_model_cache, + ) + if result: + success += 1 + if original_name != model.get("model_name"): + needs_resort = True + processed += 1 + await self._ws_manager.broadcast({ + "status": "processing", + 
"total": total_to_process, + "processed": processed, + "success": success, + "current_name": model.get("model_name", "Unknown"), + }) + except Exception as exc: # pragma: no cover - logging path + self._logger.error("Error fetching CivitAI data for %s: %s", model["file_path"], exc) + + if needs_resort: + await cache.resort() + + await self._ws_manager.broadcast({ + "status": "completed", + "total": total_to_process, + "processed": processed, + "success": success, + }) + + return web.json_response({ + "success": True, + "message": f"Successfully updated {success} of {processed} processed {self._service.model_type}s (total: {total})", + }) + except Exception as exc: + await self._ws_manager.broadcast({"status": "error", "error": str(exc)}) + self._logger.error("Error in fetch_all_civitai for %ss: %s", self._service.model_type, exc) + return web.Response(text=str(exc), status=500) + + async def get_civitai_versions(self, request: web.Request) -> web.Response: + try: + model_id = request.match_info["model_id"] + metadata_provider = await self._metadata_provider_factory() + response = await metadata_provider.get_model_versions(model_id) + if not response or not response.get("modelVersions"): + return web.Response(status=404, text="Model not found") + + versions = response.get("modelVersions", []) + model_type = response.get("type", "") + if not self._validate_model_type(model_type): + return web.json_response( + {"error": f"Model type mismatch. 
Expected {self._expected_model_types()}, got {model_type}"}, + status=400, + ) + + for version in versions: + model_file = self._find_model_file(version.get("files", [])) if isinstance(version.get("files"), Iterable) else None + if model_file: + hashes = model_file.get("hashes", {}) if isinstance(model_file, Mapping) else {} + sha256 = hashes.get("SHA256") if isinstance(hashes, Mapping) else None + if sha256: + version["existsLocally"] = self._service.has_hash(sha256) + if version["existsLocally"]: + version["localPath"] = self._service.get_path_by_hash(sha256) + version["modelSizeKB"] = model_file.get("sizeKB") if isinstance(model_file, Mapping) else None + else: + version["existsLocally"] = False + return web.json_response(versions) + except Exception as exc: + self._logger.error("Error fetching %s model versions: %s", self._service.model_type, exc) + return web.Response(status=500, text=str(exc)) + + async def get_civitai_model_by_version(self, request: web.Request) -> web.Response: + try: + model_version_id = request.match_info.get("modelVersionId") + metadata_provider = await self._metadata_provider_factory() + model, error_msg = await metadata_provider.get_model_version_info(model_version_id) + if not model: + self._logger.warning("Failed to fetch model version %s: %s", model_version_id, error_msg) + status_code = 404 if error_msg and "not found" in error_msg.lower() else 500 + return web.json_response({"success": False, "error": error_msg or "Failed to fetch model information"}, status=status_code) + return web.json_response(model) + except Exception as exc: + self._logger.error("Error fetching model details: %s", exc) + return web.json_response({"success": False, "error": str(exc)}, status=500) + + async def get_civitai_model_by_hash(self, request: web.Request) -> web.Response: + try: + hash_value = request.match_info.get("hash") + metadata_provider = await self._metadata_provider_factory() + model, error = await 
metadata_provider.get_model_by_hash(hash_value) + if error: + self._logger.warning("Error getting model by hash: %s", error) + return web.json_response({"success": False, "error": error}, status=404) + return web.json_response(model) + except Exception as exc: + self._logger.error("Error fetching model details by hash: %s", exc) + return web.json_response({"success": False, "error": str(exc)}, status=500) + + +class ModelMoveHandler: + """Move model files between folders.""" + + def __init__(self, *, move_service: ModelMoveService, logger: logging.Logger) -> None: + self._move_service = move_service + self._logger = logger + + async def move_model(self, request: web.Request) -> web.Response: + try: + data = await request.json() + file_path = data.get("file_path") + target_path = data.get("target_path") + if not file_path or not target_path: + return web.Response(text="File path and target path are required", status=400) + result = await self._move_service.move_model(file_path, target_path) + status = 200 if result.get("success") else 500 + return web.json_response(result, status=status) + except Exception as exc: + self._logger.error("Error moving model: %s", exc, exc_info=True) + return web.Response(text=str(exc), status=500) + + async def move_models_bulk(self, request: web.Request) -> web.Response: + try: + data = await request.json() + file_paths = data.get("file_paths", []) + target_path = data.get("target_path") + if not file_paths or not target_path: + return web.Response(text="File paths and target path are required", status=400) + result = await self._move_service.move_models_bulk(file_paths, target_path) + return web.json_response(result) + except Exception as exc: + self._logger.error("Error moving models in bulk: %s", exc, exc_info=True) + return web.Response(text=str(exc), status=500) + + +class ModelAutoOrganizeHandler: + """Manage auto-organize operations.""" + + def __init__( + self, + *, + file_service: ModelFileService, + progress_callback: 
WebSocketProgressCallback, + ws_manager: WebSocketManager, + logger: logging.Logger, + ) -> None: + self._file_service = file_service + self._progress_callback = progress_callback + self._ws_manager = ws_manager + self._logger = logger + + async def auto_organize_models(self, request: web.Request) -> web.Response: + try: + if self._ws_manager.is_auto_organize_running(): + return web.json_response( + {"success": False, "error": "Auto-organize is already running. Please wait for it to complete."}, + status=409, + ) + + auto_organize_lock = await self._ws_manager.get_auto_organize_lock() + if auto_organize_lock.locked(): + return web.json_response( + {"success": False, "error": "Auto-organize is already running. Please wait for it to complete."}, + status=409, + ) + + file_paths = None + if request.method == "POST": + try: + data = await request.json() + file_paths = data.get("file_paths") + except Exception: # pragma: no cover - permissive path + pass + + async with auto_organize_lock: + result = await self._file_service.auto_organize_models( + file_paths=file_paths, + progress_callback=self._progress_callback, + ) + return web.json_response(result.to_dict()) + except Exception as exc: + self._logger.error("Error in auto_organize_models: %s", exc, exc_info=True) + await self._ws_manager.broadcast_auto_organize_progress( + {"type": "auto_organize_progress", "status": "error", "error": str(exc)} + ) + return web.json_response({"success": False, "error": str(exc)}, status=500) + + async def get_auto_organize_progress(self, request: web.Request) -> web.Response: + try: + progress_data = self._ws_manager.get_auto_organize_progress() + if progress_data is None: + return web.json_response({"success": False, "error": "No auto-organize operation in progress"}, status=404) + return web.json_response({"success": True, "progress": progress_data}) + except Exception as exc: + self._logger.error("Error getting auto-organize progress: %s", exc, exc_info=True) + return 
web.json_response({"success": False, "error": str(exc)}, status=500) + + +@dataclass +class ModelHandlerSet: + """Aggregate concrete handlers into a flat mapping.""" + + page_view: ModelPageView + listing: ModelListingHandler + management: ModelManagementHandler + query: ModelQueryHandler + download: ModelDownloadHandler + civitai: ModelCivitaiHandler + move: ModelMoveHandler + auto_organize: ModelAutoOrganizeHandler + + def to_route_mapping(self) -> Dict[str, Callable[[web.Request], Awaitable[web.Response]]]: + return { + "handle_models_page": self.page_view.handle, + "get_models": self.listing.get_models, + "delete_model": self.management.delete_model, + "exclude_model": self.management.exclude_model, + "fetch_civitai": self.management.fetch_civitai, + "fetch_all_civitai": self.civitai.fetch_all_civitai, + "relink_civitai": self.management.relink_civitai, + "replace_preview": self.management.replace_preview, + "save_metadata": self.management.save_metadata, + "add_tags": self.management.add_tags, + "rename_model": self.management.rename_model, + "bulk_delete_models": self.management.bulk_delete_models, + "verify_duplicates": self.management.verify_duplicates, + "get_top_tags": self.query.get_top_tags, + "get_base_models": self.query.get_base_models, + "scan_models": self.query.scan_models, + "get_model_roots": self.query.get_model_roots, + "get_folders": self.query.get_folders, + "get_folder_tree": self.query.get_folder_tree, + "get_unified_folder_tree": self.query.get_unified_folder_tree, + "find_duplicate_models": self.query.find_duplicate_models, + "find_filename_conflicts": self.query.find_filename_conflicts, + "download_model": self.download.download_model, + "download_model_get": self.download.download_model_get, + "cancel_download_get": self.download.cancel_download_get, + "get_download_progress": self.download.get_download_progress, + "get_civitai_versions": self.civitai.get_civitai_versions, + "get_civitai_model_by_version": 
self.civitai.get_civitai_model_by_version, + "get_civitai_model_by_hash": self.civitai.get_civitai_model_by_hash, + "move_model": self.move.move_model, + "move_models_bulk": self.move.move_models_bulk, + "auto_organize_models": self.auto_organize.auto_organize_models, + "get_auto_organize_progress": self.auto_organize.get_auto_organize_progress, + "get_model_notes": self.query.get_model_notes, + "get_model_preview_url": self.query.get_model_preview_url, + "get_model_civitai_url": self.query.get_model_civitai_url, + "get_model_metadata": self.query.get_model_metadata, + "get_model_description": self.query.get_model_description, + "get_relative_paths": self.query.get_relative_paths, + } + diff --git a/py/routes/lora_routes.py b/py/routes/lora_routes.py index 0ddb41ab..ee6bc151 100644 --- a/py/routes/lora_routes.py +++ b/py/routes/lora_routes.py @@ -5,9 +5,9 @@ from typing import Dict from server import PromptServer # type: ignore from .base_model_routes import BaseModelRoutes +from .model_route_registrar import ModelRouteRegistrar from ..services.lora_service import LoraService from ..services.service_registry import ServiceRegistry -from ..services.metadata_service import get_default_metadata_provider from ..utils.utils import get_lora_info logger = logging.getLogger(__name__) @@ -17,8 +17,7 @@ class LoraRoutes(BaseModelRoutes): def __init__(self): """Initialize LoRA routes with LoRA service""" - # Service will be initialized later via setup_routes - self.service = None + super().__init__() self.template_name = "loras.html" async def initialize_services(self): @@ -26,26 +25,26 @@ class LoraRoutes(BaseModelRoutes): lora_scanner = await ServiceRegistry.get_lora_scanner() self.service = LoraService(lora_scanner) - # Initialize parent with the service - super().__init__(self.service) + # Attach service dependencies + self.attach_service(self.service) def setup_routes(self, app: web.Application): """Setup LoRA routes""" # Schedule service initialization on app startup 
app.on_startup.append(lambda _: self.initialize_services()) - + # Setup common routes with 'loras' prefix (includes page route) super().setup_routes(app, 'loras') - - def setup_specific_routes(self, app: web.Application, prefix: str): + + def setup_specific_routes(self, registrar: ModelRouteRegistrar, prefix: str): """Setup LoRA-specific routes""" # LoRA-specific query routes - app.router.add_get(f'/api/lm/{prefix}/letter-counts', self.get_letter_counts) - app.router.add_get(f'/api/lm/{prefix}/get-trigger-words', self.get_lora_trigger_words) - app.router.add_get(f'/api/lm/{prefix}/usage-tips-by-path', self.get_lora_usage_tips_by_path) - + registrar.add_prefixed_route('GET', '/api/lm/{prefix}/letter-counts', prefix, self.get_letter_counts) + registrar.add_prefixed_route('GET', '/api/lm/{prefix}/get-trigger-words', prefix, self.get_lora_trigger_words) + registrar.add_prefixed_route('GET', '/api/lm/{prefix}/usage-tips-by-path', prefix, self.get_lora_usage_tips_by_path) + # ComfyUI integration - app.router.add_post(f'/api/lm/{prefix}/get_trigger_words', self.get_trigger_words) + registrar.add_prefixed_route('POST', '/api/lm/{prefix}/get_trigger_words', prefix, self.get_trigger_words) def _parse_specific_params(self, request: web.Request) -> Dict: """Parse LoRA-specific parameters""" diff --git a/py/routes/model_route_registrar.py b/py/routes/model_route_registrar.py new file mode 100644 index 00000000..96f65fc5 --- /dev/null +++ b/py/routes/model_route_registrar.py @@ -0,0 +1,99 @@ +"""Route registrar for model endpoints.""" +from __future__ import annotations + +from dataclasses import dataclass +from typing import Callable, Iterable, Mapping + +from aiohttp import web + + +@dataclass(frozen=True) +class RouteDefinition: + """Declarative definition for a HTTP route.""" + + method: str + path_template: str + handler_name: str + + def build_path(self, prefix: str) -> str: + return self.path_template.replace("{prefix}", prefix) + + +COMMON_ROUTE_DEFINITIONS: 
tuple[RouteDefinition, ...] = ( + RouteDefinition("GET", "/api/lm/{prefix}/list", "get_models"), + RouteDefinition("POST", "/api/lm/{prefix}/delete", "delete_model"), + RouteDefinition("POST", "/api/lm/{prefix}/exclude", "exclude_model"), + RouteDefinition("POST", "/api/lm/{prefix}/fetch-civitai", "fetch_civitai"), + RouteDefinition("POST", "/api/lm/{prefix}/fetch-all-civitai", "fetch_all_civitai"), + RouteDefinition("POST", "/api/lm/{prefix}/relink-civitai", "relink_civitai"), + RouteDefinition("POST", "/api/lm/{prefix}/replace-preview", "replace_preview"), + RouteDefinition("POST", "/api/lm/{prefix}/save-metadata", "save_metadata"), + RouteDefinition("POST", "/api/lm/{prefix}/add-tags", "add_tags"), + RouteDefinition("POST", "/api/lm/{prefix}/rename", "rename_model"), + RouteDefinition("POST", "/api/lm/{prefix}/bulk-delete", "bulk_delete_models"), + RouteDefinition("POST", "/api/lm/{prefix}/verify-duplicates", "verify_duplicates"), + RouteDefinition("POST", "/api/lm/{prefix}/move_model", "move_model"), + RouteDefinition("POST", "/api/lm/{prefix}/move_models_bulk", "move_models_bulk"), + RouteDefinition("GET", "/api/lm/{prefix}/auto-organize", "auto_organize_models"), + RouteDefinition("POST", "/api/lm/{prefix}/auto-organize", "auto_organize_models"), + RouteDefinition("GET", "/api/lm/{prefix}/auto-organize-progress", "get_auto_organize_progress"), + RouteDefinition("GET", "/api/lm/{prefix}/top-tags", "get_top_tags"), + RouteDefinition("GET", "/api/lm/{prefix}/base-models", "get_base_models"), + RouteDefinition("GET", "/api/lm/{prefix}/scan", "scan_models"), + RouteDefinition("GET", "/api/lm/{prefix}/roots", "get_model_roots"), + RouteDefinition("GET", "/api/lm/{prefix}/folders", "get_folders"), + RouteDefinition("GET", "/api/lm/{prefix}/folder-tree", "get_folder_tree"), + RouteDefinition("GET", "/api/lm/{prefix}/unified-folder-tree", "get_unified_folder_tree"), + RouteDefinition("GET", "/api/lm/{prefix}/find-duplicates", "find_duplicate_models"), + 
RouteDefinition("GET", "/api/lm/{prefix}/find-filename-conflicts", "find_filename_conflicts"), + RouteDefinition("GET", "/api/lm/{prefix}/get-notes", "get_model_notes"), + RouteDefinition("GET", "/api/lm/{prefix}/preview-url", "get_model_preview_url"), + RouteDefinition("GET", "/api/lm/{prefix}/civitai-url", "get_model_civitai_url"), + RouteDefinition("GET", "/api/lm/{prefix}/metadata", "get_model_metadata"), + RouteDefinition("GET", "/api/lm/{prefix}/model-description", "get_model_description"), + RouteDefinition("GET", "/api/lm/{prefix}/relative-paths", "get_relative_paths"), + RouteDefinition("GET", "/api/lm/{prefix}/civitai/versions/{model_id}", "get_civitai_versions"), + RouteDefinition("GET", "/api/lm/{prefix}/civitai/model/version/{modelVersionId}", "get_civitai_model_by_version"), + RouteDefinition("GET", "/api/lm/{prefix}/civitai/model/hash/{hash}", "get_civitai_model_by_hash"), + RouteDefinition("POST", "/api/lm/download-model", "download_model"), + RouteDefinition("GET", "/api/lm/download-model-get", "download_model_get"), + RouteDefinition("GET", "/api/lm/cancel-download-get", "cancel_download_get"), + RouteDefinition("GET", "/api/lm/download-progress/{download_id}", "get_download_progress"), + RouteDefinition("GET", "/{prefix}", "handle_models_page"), +) + + +class ModelRouteRegistrar: + """Bind declarative definitions to an aiohttp router.""" + + _METHOD_MAP = { + "GET": "add_get", + "POST": "add_post", + "PUT": "add_put", + "DELETE": "add_delete", + } + + def __init__(self, app: web.Application) -> None: + self._app = app + + def register_common_routes( + self, + prefix: str, + handler_lookup: Mapping[str, Callable[[web.Request], object]], + *, + definitions: Iterable[RouteDefinition] = COMMON_ROUTE_DEFINITIONS, + ) -> None: + for definition in definitions: + self._bind_route(definition.method, definition.build_path(prefix), handler_lookup[definition.handler_name]) + + def add_route(self, method: str, path: str, handler: Callable) -> None: + 
self._bind_route(method, path, handler) + + def add_prefixed_route(self, method: str, path_template: str, prefix: str, handler: Callable) -> None: + self._bind_route(method, path_template.replace("{prefix}", prefix), handler) + + def _bind_route(self, method: str, path: str, handler: Callable) -> None: + add_method_name = self._METHOD_MAP[method.upper()] + add_method = getattr(self._app.router, add_method_name) + add_method(path, handler) + diff --git a/tests/routes/test_base_model_routes_smoke.py b/tests/routes/test_base_model_routes_smoke.py index bccd83a9..2b9ed805 100644 --- a/tests/routes/test_base_model_routes_smoke.py +++ b/tests/routes/test_base_model_routes_smoke.py @@ -37,7 +37,7 @@ from py_local.config import config class DummyRoutes(BaseModelRoutes): template_name = "dummy.html" - def setup_specific_routes(self, app: web.Application, prefix: str) -> None: # pragma: no cover - no extra routes in smoke tests + def setup_specific_routes(self, registrar, prefix: str) -> None: # pragma: no cover - no extra routes in smoke tests return None