From a2b81ea099cbe4f5cc7ce881b13891dcaa2feba7 Mon Sep 17 00:00:00 2001 From: Will Miao <13051207myq@gmail.com> Date: Wed, 23 Jul 2025 14:39:02 +0800 Subject: [PATCH 01/18] refactor: Implement base model routes and services for LoRA and Checkpoint - Added BaseModelRoutes class to handle common routes and logic for model types. - Created CheckpointRoutes class inheriting from BaseModelRoutes for checkpoint-specific routes. - Implemented CheckpointService class for handling checkpoint-related data and operations. - Developed LoraService class for managing LoRA-specific functionalities. - Introduced ModelServiceFactory to manage service and route registrations for different model types. - Established methods for fetching, filtering, and formatting model data across services. - Integrated CivitAI metadata handling within model routes and services. - Added pagination and filtering capabilities for model data retrieval. --- py/lora_manager.py | 36 +- py/routes/api_routes.py | 1173 +------------------------- py/routes/base_model_routes.py | 431 ++++++++++ py/routes/checkpoint_routes.py | 170 ++++ py/routes/lora_routes.py | 763 ++++++++++++++--- py/routes/misc_routes.py | 5 +- py/routes/recipe_routes.py | 49 +- py/services/base_model_service.py | 248 ++++++ py/services/checkpoint_scanner.py | 33 +- py/services/checkpoint_service.py | 51 ++ py/services/lora_service.py | 172 ++++ py/services/model_service_factory.py | 137 +++ py/services/service_registry.py | 236 ++++-- py/utils/routes_common.py | 53 ++ standalone.py | 13 +- 15 files changed, 2185 insertions(+), 1385 deletions(-) create mode 100644 py/routes/base_model_routes.py create mode 100644 py/routes/checkpoint_routes.py create mode 100644 py/services/base_model_service.py create mode 100644 py/services/checkpoint_service.py create mode 100644 py/services/lora_service.py create mode 100644 py/services/model_service_factory.py diff --git a/py/lora_manager.py b/py/lora_manager.py index ed2ff8cb..dc54d8f4 100644 --- a/py/lora_manager.py +++ b/py/lora_manager.py @@ -6,10 +6,8 @@ from pathlib import Path from server import PromptServer # type: ignore from .config import config -from .routes.lora_routes import LoraRoutes -from .routes.api_routes import ApiRoutes +from .services.model_service_factory import ModelServiceFactory, register_default_model_types from .routes.recipe_routes import RecipeRoutes -from .routes.checkpoints_routes import CheckpointsRoutes from .routes.stats_routes import StatsRoutes from .routes.update_routes import UpdateRoutes from .routes.misc_routes import MiscRoutes @@ -17,6 +15,7 @@ from .routes.example_images_routes import ExampleImagesRoutes from .services.service_registry import ServiceRegistry from .services.settings_manager import settings from .utils.example_images_migration import ExampleImagesMigration +from .services.websocket_manager import ws_manager logger = logging.getLogger(__name__) @@ -28,7 +27,7 @@ class LoraManager: @classmethod def add_routes(cls): - """Initialize and register all routes""" + """Initialize and register all routes using the new refactored architecture""" app = PromptServer.instance.app # Configure aiohttp access logger to be less verbose @@ -110,27 +109,32 @@ class LoraManager: # Add static route for plugin assets app.router.add_static('/loras_static', config.static_path) - # Setup feature routes - lora_routes = LoraRoutes() - checkpoints_routes = CheckpointsRoutes() - stats_routes = StatsRoutes() + # Register default model types with the factory + register_default_model_types() - # Initialize routes - lora_routes.setup_routes(app) - checkpoints_routes.setup_routes(app) - stats_routes.setup_routes(app) # Add statistics routes - ApiRoutes.setup_routes(app) + # Setup all model routes using the factory + ModelServiceFactory.setup_all_routes(app) + + # Setup non-model-specific routes + stats_routes = StatsRoutes() + stats_routes.setup_routes(app) RecipeRoutes.setup_routes(app) UpdateRoutes.setup_routes(app) - MiscRoutes.setup_routes(app) # Register miscellaneous routes - ExampleImagesRoutes.setup_routes(app) # Register example images routes + MiscRoutes.setup_routes(app) + ExampleImagesRoutes.setup_routes(app) + + # Setup WebSocket routes that are shared across all model types + app.router.add_get('/ws/fetch-progress', ws_manager.handle_connection) + app.router.add_get('/ws/download-progress', ws_manager.handle_download_connection) + app.router.add_get('/ws/init-progress', ws_manager.handle_init_connection) # Schedule service initialization app.on_startup.append(lambda app: cls._initialize_services()) # Add cleanup app.on_shutdown.append(cls._cleanup) - app.on_shutdown.append(ApiRoutes.cleanup) + + logger.info(f"LoRA Manager: Set up routes for {len(ModelServiceFactory.get_registered_types())} model types: {', '.join(ModelServiceFactory.get_registered_types())}") @classmethod async def _initialize_services(cls): diff --git a/py/routes/api_routes.py b/py/routes/api_routes.py index f5457935..5c9c93fe 100644 --- a/py/routes/api_routes.py +++ b/py/routes/api_routes.py @@ -1,1184 +1,37 @@ -import os -import json import logging from aiohttp import web -from typing import Dict -from server import PromptServer # type: ignore -from ..utils.routes_common import ModelRouteUtils -from ..utils.utils import get_lora_info - -from ..config import config from ..services.websocket_manager import ws_manager -import asyncio from .update_routes import UpdateRoutes -from ..utils.constants import PREVIEW_EXTENSIONS, CARD_PREVIEW_WIDTH, VALID_LORA_TYPES -from ..utils.exif_utils import ExifUtils -from ..utils.metadata_manager import MetadataManager -from ..services.service_registry import ServiceRegistry +from .lora_routes import LoraRoutes logger = logging.getLogger(__name__) class ApiRoutes: - """API route handlers for LoRA management""" + """Legacy API route handlers for backward compatibility""" def __init__(self): - self.scanner = None # Will be initialized in setup_routes - self.civitai_client = None # Will be initialized in setup_routes - self.download_manager = None # Will be initialized in setup_routes - self._download_lock = asyncio.Lock() - - async def initialize_services(self): - """Initialize services from ServiceRegistry""" - self.scanner = await ServiceRegistry.get_lora_scanner() - self.civitai_client = await ServiceRegistry.get_civitai_client() - self.download_manager = await ServiceRegistry.get_download_manager() + # Initialize the new LoRA routes + self.lora_routes = LoraRoutes() @classmethod def setup_routes(cls, app: web.Application): - """Register API routes""" + """Register API routes using the new refactored architecture""" routes = cls() - # Schedule service initialization on app startup - app.on_startup.append(lambda _: routes.initialize_services()) + # Setup the refactored LoRA routes + routes.lora_routes.setup_routes(app) - app.router.add_post('/api/delete_model', routes.delete_model) - app.router.add_post('/api/loras/exclude', routes.exclude_model) # Add new exclude endpoint - app.router.add_post('/api/fetch-civitai', routes.fetch_civitai) - app.router.add_post('/api/relink-civitai', routes.relink_civitai) # Add new relink endpoint - app.router.add_post('/api/replace_preview', routes.replace_preview) - app.router.add_get('/api/loras', routes.get_loras) - app.router.add_post('/api/fetch-all-civitai', routes.fetch_all_civitai) + # Setup WebSocket routes that are still shared app.router.add_get('/ws/fetch-progress', ws_manager.handle_connection) - app.router.add_get('/ws/download-progress', ws_manager.handle_download_connection) # Add new WebSocket route for download progress - app.router.add_get('/ws/init-progress', ws_manager.handle_init_connection) # Add new WebSocket route - app.router.add_get('/api/lora-roots', routes.get_lora_roots) - app.router.add_get('/api/folders', routes.get_folders) - app.router.add_get('/api/civitai/versions/{model_id}', routes.get_civitai_versions) - app.router.add_get('/api/civitai/model/version/{modelVersionId}', routes.get_civitai_model_by_version) - app.router.add_get('/api/civitai/model/hash/{hash}', routes.get_civitai_model_by_hash) - app.router.add_post('/api/download-model', routes.download_model) - app.router.add_get('/api/download-model-get', routes.download_model_get) # Add new GET endpoint - app.router.add_get('/api/cancel-download-get', routes.cancel_download_get) - app.router.add_get('/api/download-progress/{download_id}', routes.get_download_progress) # Add new endpoint for download progress - app.router.add_post('/api/move_model', routes.move_model) - app.router.add_get('/api/lora-model-description', routes.get_lora_model_description) # Add new route - app.router.add_post('/api/loras/save-metadata', routes.save_metadata) - app.router.add_get('/api/lora-preview-url', routes.get_lora_preview_url) # Add new route - app.router.add_post('/api/move_models_bulk', routes.move_models_bulk) - app.router.add_get('/api/loras/top-tags', routes.get_top_tags) # Add new route for top tags - app.router.add_get('/api/loras/base-models', routes.get_base_models) # Add new route for base models - app.router.add_get('/api/lora-civitai-url', routes.get_lora_civitai_url) # Add new route for Civitai URL - app.router.add_post('/api/loras/rename', routes.rename_lora) # Add new route for renaming LoRA files - app.router.add_get('/api/loras/scan', routes.scan_loras) # Add new route for scanning LoRA files + app.router.add_get('/ws/download-progress', ws_manager.handle_download_connection) + app.router.add_get('/ws/init-progress', ws_manager.handle_init_connection) - # Add the new trigger words route - app.router.add_post('/loramanager/get_trigger_words', routes.get_trigger_words) - - # Add new endpoint for letter counts - app.router.add_get('/api/loras/letter-counts', routes.get_letter_counts) - - # Add new endpoints for copying lora data - app.router.add_get('/api/loras/get-notes', routes.get_lora_notes) - app.router.add_get('/api/loras/get-trigger-words', routes.get_lora_trigger_words) - - # Add update check routes + # Setup update routes that are not model-specific UpdateRoutes.setup_routes(app) - # Add new endpoints for finding duplicates - app.router.add_get('/api/loras/find-duplicates', routes.find_duplicate_loras) - app.router.add_get('/api/loras/find-filename-conflicts', routes.find_filename_conflicts) - - # Add new endpoint for bulk deleting loras - app.router.add_post('/api/loras/bulk-delete', routes.bulk_delete_loras) - - # Add new endpoint for verifying duplicates - app.router.add_post('/api/loras/verify-duplicates', routes.verify_duplicates) - - async def delete_model(self, request: web.Request) -> web.Response: - """Handle model deletion request""" - if self.scanner is None: - self.scanner = await ServiceRegistry.get_lora_scanner() - return await ModelRouteUtils.handle_delete_model(request, self.scanner) - - async def exclude_model(self, request: web.Request) -> web.Response: - """Handle model exclusion request""" - if self.scanner is None: - self.scanner = await ServiceRegistry.get_lora_scanner() - return await ModelRouteUtils.handle_exclude_model(request, self.scanner) - - async def fetch_civitai(self, request: web.Request) -> web.Response: - """Handle CivitAI metadata fetch request""" - if self.scanner is None: - self.scanner = await ServiceRegistry.get_lora_scanner() - - response = await ModelRouteUtils.handle_fetch_civitai(request, self.scanner) - - # If successful, format the metadata before returning - if response.status == 200: - data = json.loads(response.body.decode('utf-8')) - if data.get("success") and data.get("metadata"): - formatted_metadata = self._format_lora_response(data["metadata"]) - return web.json_response({ - "success": True, - "metadata": formatted_metadata - }) - - # Otherwise, return the original response - return response - - async def replace_preview(self, request: web.Request) -> web.Response: - """Handle preview image replacement request""" - if self.scanner is None: - self.scanner = await ServiceRegistry.get_lora_scanner() - return await ModelRouteUtils.handle_replace_preview(request, self.scanner) - - async def scan_loras(self, request: web.Request) -> web.Response: - """Force a rescan of LoRA files""" - try: - # Get full_rebuild parameter from query string, default to false - full_rebuild = request.query.get('full_rebuild', 'false').lower() == 'true' - - await self.scanner.get_cached_data(force_refresh=True, rebuild_cache=full_rebuild) - return web.json_response({"status": "success", "message": "LoRA scan completed"}) - except Exception as e: - logger.error(f"Error in scan_loras: {e}", exc_info=True) - return web.json_response({"error": str(e)}, status=500) - - async def get_loras(self, request: web.Request) -> web.Response: - """Handle paginated LoRA data request""" - try: - if self.scanner is None: - self.scanner = await ServiceRegistry.get_lora_scanner() - - # Parse query parameters - page = int(request.query.get('page', '1')) - page_size = int(request.query.get('page_size', '20')) - sort_by = request.query.get('sort_by', 'name') - folder = request.query.get('folder', None) - search = request.query.get('search', None) - fuzzy_search = request.query.get('fuzzy', 'false').lower() == 'true' - - # Parse search options - search_options = { - 'filename': request.query.get('search_filename', 'true').lower() == 'true', - 'modelname': request.query.get('search_modelname', 'true').lower() == 'true', - 'tags': request.query.get('search_tags', 'false').lower() == 'true', - 'recursive': request.query.get('recursive', 'false').lower() == 'true' - } - - # Get filter parameters - base_models = request.query.get('base_models', None) - tags = request.query.get('tags', None) - favorites_only = request.query.get('favorites_only', 'false').lower() == 'true' # New parameter - - # New parameter for alphabet filtering - first_letter = request.query.get('first_letter', None) - - # New parameters for recipe filtering - lora_hash = request.query.get('lora_hash', None) - lora_hashes = request.query.get('lora_hashes', None) - - # Parse filter parameters - filters = {} - if base_models: - filters['base_model'] = base_models.split(',') - if tags: - filters['tags'] = tags.split(',') - - # Add lora hash filtering options - hash_filters = {} - if lora_hash: - hash_filters['single_hash'] = lora_hash.lower() - elif lora_hashes: - hash_filters['multiple_hashes'] = [h.lower() for h in lora_hashes.split(',')] - - # Get file data - data = await self.scanner.get_paginated_data( - page, - page_size, - sort_by=sort_by, - folder=folder, - search=search, - fuzzy_search=fuzzy_search, - base_models=filters.get('base_model', None), - tags=filters.get('tags', None), - search_options=search_options, - hash_filters=hash_filters, - favorites_only=favorites_only, # Pass favorites_only parameter - first_letter=first_letter # Pass the new first_letter parameter - ) - - # Get all available folders from cache - cache = await self.scanner.get_cached_data() - - # Convert output to match expected format - result = { - 'items': [self._format_lora_response(lora) for lora in data['items']], - 'folders': cache.folders, - 'total': data['total'], - 'page': data['page'], - 'page_size': data['page_size'], - 'total_pages': data['total_pages'] - } - - return web.json_response(result) - - except Exception as e: - logger.error(f"Error retrieving loras: {e}", exc_info=True) - return web.json_response({"error": str(e)}, status=500) - - def _format_lora_response(self, lora: Dict) -> Dict: - """Format LoRA data for API response""" - return { - "model_name": lora["model_name"], - "file_name": lora["file_name"], - "preview_url": config.get_preview_static_url(lora["preview_url"]), - "preview_nsfw_level": lora.get("preview_nsfw_level", 0), - "base_model": lora["base_model"], - "folder": lora["folder"], - "sha256": lora["sha256"], - "file_path": lora["file_path"].replace(os.sep, "/"), - "file_size": lora["size"], - "modified": lora["modified"], - "tags": lora["tags"], - "modelDescription": lora["modelDescription"], - "from_civitai": lora.get("from_civitai", True), - "usage_tips": lora.get("usage_tips", ""), - "notes": lora.get("notes", ""), - "favorite": lora.get("favorite", False), # Include favorite status in response - "civitai": ModelRouteUtils.filter_civitai_data(lora.get("civitai", {})) - } - - async def fetch_all_civitai(self, request: web.Request) -> web.Response: - """Fetch CivitAI metadata for all loras in the background""" - try: - if self.scanner is None: - self.scanner = await ServiceRegistry.get_lora_scanner() - - cache = await self.scanner.get_cached_data() - total = len(cache.raw_data) - processed = 0 - success = 0 - needs_resort = False - - # Prepare loras to process - to_process = [ - lora for lora in cache.raw_data - if lora.get('sha256') and (not lora.get('civitai') or 'id' not in lora.get('civitai')) and lora.get('from_civitai', True) # TODO: for lora not from CivitAI but added traineWords - ] - total_to_process = len(to_process) - - # Send initial progress - await ws_manager.broadcast({ - 'status': 'started', - 'total': total_to_process, - 'processed': 0, - 'success': 0 - }) - - for lora in to_process: - try: - original_name = lora.get('model_name') - if await ModelRouteUtils.fetch_and_update_model( - sha256=lora['sha256'], - file_path=lora['file_path'], - model_data=lora, - update_cache_func=self.scanner.update_single_model_cache - ): - success += 1 - if original_name != lora.get('model_name'): - needs_resort = True - - processed += 1 - - # Send progress update - await ws_manager.broadcast({ - 'status': 'processing', - 'total': total_to_process, - 'processed': processed, - 'success': success, - 'current_name': lora.get('model_name', 'Unknown') - }) - - except Exception as e: - logger.error(f"Error fetching CivitAI data for {lora['file_path']}: {e}") - - if needs_resort: - await cache.resort(name_only=True) - - # Send completion message - await ws_manager.broadcast({ - 'status': 'completed', - 'total': total_to_process, - 'processed': processed, - 'success': success - }) - - return web.json_response({ - "success": True, - "message": f"Successfully updated {success} of {processed} processed loras (total: {total})" - }) - - except Exception as e: - # Send error message - await ws_manager.broadcast({ - 'status': 'error', - 'error': str(e) - }) - logger.error(f"Error in fetch_all_civitai: {e}") - return web.Response(text=str(e), status=500) - - async def get_lora_roots(self, request: web.Request) -> web.Response: - """Get all configured LoRA root directories""" - return web.json_response({ - 'roots': config.loras_roots - }) - - async def get_folders(self, request: web.Request) -> web.Response: - """Get all folders in the cache""" - if self.scanner is None: - self.scanner = await ServiceRegistry.get_lora_scanner() - - cache = await self.scanner.get_cached_data() - return web.json_response({ - 'folders': cache.folders - }) - - async def get_civitai_versions(self, request: web.Request) -> web.Response: - """Get available versions for a Civitai model with local availability info""" - try: - if self.scanner is None: - self.scanner = await ServiceRegistry.get_lora_scanner() - - if self.civitai_client is None: - self.civitai_client = await ServiceRegistry.get_civitai_client() - - model_id = request.match_info['model_id'] - response = await self.civitai_client.get_model_versions(model_id) - if not response or not response.get('modelVersions'): - return web.Response(status=404, text="Model not found") - - versions = response.get('modelVersions', []) - model_type = response.get('type', '') - - # Check model type - should be LORA, LoCon, or DORA - if model_type.lower() not in VALID_LORA_TYPES: - return web.json_response({ - 'error': f"Model type mismatch. Expected LORA or LoCon, got {model_type}" - }, status=400) - - # Check local availability for each version - for version in versions: - # Find the model file (type="Model") in the files list - model_file = next((file for file in version.get('files', []) - if file.get('type') == 'Model'), None) - - if model_file: - sha256 = model_file.get('hashes', {}).get('SHA256') - if sha256: - # Set existsLocally and localPath at the version level - version['existsLocally'] = self.scanner.has_hash(sha256) - if version['existsLocally']: - version['localPath'] = self.scanner.get_path_by_hash(sha256) - - # Also set the model file size at the version level for easier access - version['modelSizeKB'] = model_file.get('sizeKB') - else: - # No model file found in this version - version['existsLocally'] = False - - return web.json_response(versions) - except Exception as e: - logger.error(f"Error fetching model versions: {e}") - return web.Response(status=500, text=str(e)) - - async def get_civitai_model_by_version(self, request: web.Request) -> web.Response: - """Get CivitAI model details by model version ID""" - try: - if self.civitai_client is None: - self.civitai_client = await ServiceRegistry.get_civitai_client() - - model_version_id = request.match_info.get('modelVersionId') - - # Get model details from Civitai API - model, error_msg = await self.civitai_client.get_model_version_info(model_version_id) - - if not model: - # Log warning for failed model retrieval - logger.warning(f"Failed to fetch model version {model_version_id}: {error_msg}") - - # Determine status code based on error message - status_code = 404 if error_msg and "not found" in error_msg.lower() else 500 - - return web.json_response({ - "success": False, - "error": error_msg or "Failed to fetch model information" - }, status=status_code) - - return web.json_response(model) - except Exception as e: - logger.error(f"Error fetching model details: {e}") - return web.json_response({ - "success": False, - "error": str(e) - }, status=500) - - async def get_civitai_model_by_hash(self, request: web.Request) -> web.Response: - """Get CivitAI model details by hash""" - try: - if self.civitai_client is None: - self.civitai_client = await ServiceRegistry.get_civitai_client() - - hash = request.match_info.get('hash') - model = await self.civitai_client.get_model_by_hash(hash) - return web.json_response(model) - except Exception as e: - logger.error(f"Error fetching model details by hash: {e}") - return web.json_response({ - "success": False, - "error": str(e) - }, status=500) - - async def download_model(self, request: web.Request) -> web.Response: - return await ModelRouteUtils.handle_download_model(request, self.download_manager) - - async def download_model_get(self, request: web.Request) -> web.Response: - """Handle model download request via GET method - - Converts GET parameters to POST format and calls the existing download handler - - Args: - request: The aiohttp request with query parameters - - Returns: - web.Response: The HTTP response - """ - try: - # Extract query parameters - model_id = request.query.get('model_id') - if not model_id: - return web.Response( - status=400, - text="Missing required parameter: Please provide 'model_id'" - ) - - # Get optional parameters - model_version_id = request.query.get('model_version_id') - download_id = request.query.get('download_id') - use_default_paths = request.query.get('use_default_paths', 'false').lower() == 'true' - - # Create a data dictionary that mimics what would be received from a POST request - data = { - 'model_id': model_id - } - - # Add optional parameters only if they are provided - if model_version_id: - data['model_version_id'] = model_version_id - - if download_id: - data['download_id'] = download_id - - data['use_default_paths'] = use_default_paths - - # Create a mock request object with the data - # Fix: Create a proper Future object and set its result - future = asyncio.get_event_loop().create_future() - future.set_result(data) - - mock_request = type('MockRequest', (), { - 'json': lambda self=None: future - })() - - # Call the existing download handler - if self.download_manager is None: - self.download_manager = await ServiceRegistry.get_download_manager() - - return await ModelRouteUtils.handle_download_model(mock_request, self.download_manager) - - except Exception as e: - error_message = str(e) - logger.error(f"Error downloading model via GET: {error_message}", exc_info=True) - return web.Response(status=500, text=error_message) - - async def cancel_download_get(self, request: web.Request) -> web.Response: - """Handle GET request for cancelling a download by download_id""" - try: - download_id = request.query.get('download_id') - if not download_id: - return web.json_response({ - 'success': False, - 'error': 'Download ID is required' - }, status=400) - if self.download_manager is None: - self.download_manager = await ServiceRegistry.get_download_manager() - # Create a mock request with match_info for compatibility - mock_request = type('MockRequest', (), { - 'match_info': {'download_id': download_id} - })() - return await ModelRouteUtils.handle_cancel_download(mock_request, self.download_manager) - except Exception as e: - logger.error(f"Error cancelling download via GET: {e}", exc_info=True) - return web.json_response({ - 'success': False, - 'error': str(e) - }, status=500) - - async def get_download_progress(self, request: web.Request) -> web.Response: - """Handle request for download progress by download_id""" - try: - # Get download_id from URL path - download_id = request.match_info.get('download_id') - if not download_id: - return web.json_response({ - 'success': False, - 'error': 'Download ID is required' - }, status=400) - - # Get progress information from websocket manager - progress_data = ws_manager.get_download_progress(download_id) - - if progress_data is None: - return web.json_response({ - 'success': False, - 'error': 'Download ID not found' - }, status=404) - - return web.json_response({ - 'success': True, - 'progress': progress_data.get('progress', 0) - }) - except Exception as e: - logger.error(f"Error getting download progress: {e}", exc_info=True) - return web.json_response({ - 'success': False, - 'error': str(e) - }, status=500) - - async def move_model(self, request: web.Request) -> web.Response: - """Handle model move request""" - try: - if self.scanner is None: - self.scanner = await ServiceRegistry.get_lora_scanner() - - data = await request.json() - file_path = data.get('file_path') # full path of the model file, e.g. /path/to/model.safetensors - target_path = data.get('target_path') # folder path to move the model to, e.g. /path/to/target_folder - - if not file_path or not target_path: - return web.Response(text='File path and target path are required', status=400) - - # Check if source and destination are the same - source_dir = os.path.dirname(file_path) - if os.path.normpath(source_dir) == os.path.normpath(target_path): - logger.info(f"Source and target directories are the same: {source_dir}") - return web.json_response({'success': True, 'message': 'Source and target directories are the same'}) - - # Check if target file already exists - file_name = os.path.basename(file_path) - target_file_path = os.path.join(target_path, file_name).replace(os.sep, '/') - - if os.path.exists(target_file_path): - return web.json_response({ - 'success': False, - 'error': f"Target file already exists: {target_file_path}" - }, status=409) # 409 Conflict - - # Call scanner to handle the move operation - success = await self.scanner.move_model(file_path, target_path) - - if success: - return web.json_response({'success': True}) - else: - return web.Response(text='Failed to move model', status=500) - - except Exception as e: - logger.error(f"Error moving model: {e}", exc_info=True) - return web.Response(text=str(e), status=500) - @classmethod async def cleanup(cls): """Add cleanup method for application shutdown""" - # Now we don't need to store an instance, as services are managed by ServiceRegistry - civitai_client = await ServiceRegistry.get_civitai_client() - if civitai_client: - await civitai_client.close() - - async def save_metadata(self, request: web.Request) -> web.Response: - """Handle saving metadata updates""" - try: - if self.scanner is None: - self.scanner = await ServiceRegistry.get_lora_scanner() - - data = await request.json() - file_path = data.get('file_path') - if not file_path: - return web.Response(text='File path is required', status=400) - - # Remove file path from data to avoid saving it - metadata_updates = {k: v for k, v in data.items() if k != 'file_path'} - - # Get metadata file path - metadata_path = os.path.splitext(file_path)[0] + '.metadata.json' - - # Load existing metadata - metadata = await ModelRouteUtils.load_local_metadata(metadata_path) - - # Handle nested updates (for civitai.trainedWords) - for key, value in metadata_updates.items(): - if isinstance(value, dict) and key in metadata and isinstance(metadata[key], dict): - # Deep update for nested dictionaries - for nested_key, nested_value in value.items(): - metadata[key][nested_key] = nested_value - else: - # Regular update for top-level keys - metadata[key] = value - - # Save updated metadata - await MetadataManager.save_metadata(file_path, metadata) - - # Update cache - await self.scanner.update_single_model_cache(file_path, file_path, metadata) - - # If model_name was updated, resort the cache - if 'model_name' in metadata_updates: - cache = await self.scanner.get_cached_data() - await cache.resort(name_only=True) - - return web.json_response({'success': True}) - - except Exception as e: - logger.error(f"Error saving metadata: {e}", exc_info=True) - return web.Response(text=str(e), status=500) - - async def get_lora_preview_url(self, request: web.Request) -> web.Response: - """Get the static preview URL for a LoRA file""" - try: - if self.scanner is None: - self.scanner = await ServiceRegistry.get_lora_scanner() - - # Get lora file name from query parameters - lora_name = request.query.get('name') - if not lora_name: - return web.Response(text='Lora file name is required', status=400) - - # Get cache data - cache = await self.scanner.get_cached_data() - - # Search for the lora in cache data - for lora in cache.raw_data: - file_name = lora['file_name'] - if file_name == lora_name: - if preview_url := lora.get('preview_url'): - # Convert preview path to static URL - static_url = config.get_preview_static_url(preview_url) - if static_url: - return web.json_response({ - 'success': True, - 'preview_url': static_url - }) - break - - # If no preview URL found - return web.json_response({ - 'success': False, - 'error': 'No preview URL found for the specified lora' - }, status=404) - - except Exception as e: - logger.error(f"Error getting lora preview URL: {e}", exc_info=True) - return web.json_response({ - 'success': False, - 'error': str(e) - }, status=500) - - async def get_lora_civitai_url(self, request: web.Request) -> web.Response: - """Get the Civitai URL for a LoRA file""" - try: - if self.scanner is None: - self.scanner = await ServiceRegistry.get_lora_scanner() - - # Get lora file name from query parameters - lora_name = request.query.get('name') - if not lora_name: - return web.Response(text='Lora file name is required', status=400) - - # Get cache data - cache = await self.scanner.get_cached_data() - - # Search for the lora in cache data - for lora in cache.raw_data: - file_name = lora['file_name'] - if file_name == lora_name: - civitai_data = lora.get('civitai', {}) - model_id = civitai_data.get('modelId') - version_id = civitai_data.get('id') - - if model_id: - civitai_url = f"https://civitai.com/models/{model_id}" - if version_id: - civitai_url += f"?modelVersionId={version_id}" - - return web.json_response({ - 'success': True, - 'civitai_url': civitai_url, - 'model_id': model_id, - 'version_id': version_id - }) - break - - # If no Civitai data found - return web.json_response({ - 'success': False, - 'error': 'No Civitai data found for the specified lora' - }, status=404) - - except Exception as e: - logger.error(f"Error getting lora Civitai URL: {e}", exc_info=True) - return web.json_response({ - 'success': False, - 'error': str(e) - }, status=500) - - async def move_models_bulk(self, request: web.Request) -> web.Response: - """Handle bulk model move request""" - try: - if self.scanner is None: - self.scanner = await ServiceRegistry.get_lora_scanner() - - data = await request.json() - file_paths = data.get('file_paths', []) # list of full paths of the model files, e.g. ["/path/to/model1.safetensors", "/path/to/model2.safetensors"] - target_path = data.get('target_path') # folder path to move the models to, e.g. "/path/to/target_folder" - - if not file_paths or not target_path: - return web.Response(text='File paths and target path are required', status=400) - - results = [] - for file_path in file_paths: - # Check if source and destination are the same - source_dir = os.path.dirname(file_path) - if os.path.normpath(source_dir) == os.path.normpath(target_path): - results.append({ - "path": file_path, - "success": True, - "message": "Source and target directories are the same" - }) - continue - - # Check if target file already exists - file_name = os.path.basename(file_path) - target_file_path = os.path.join(target_path, file_name).replace(os.sep, '/') - - if os.path.exists(target_file_path): - results.append({ - "path": file_path, - "success": False, - "message": f"Target file already exists: {target_file_path}" - }) - continue - - # Try to move the model - success = await self.scanner.move_model(file_path, target_path) - results.append({ - "path": file_path, - "success": success, - "message": "Success" if success else "Failed to move model" - }) - - # Count successes and failures - success_count = sum(1 for r in results if r["success"]) - failure_count = len(results) - success_count - - return web.json_response({ - 'success': True, - 'message': f'Moved {success_count} of {len(file_paths)} models', - 'results': results, - 'success_count': success_count, - 'failure_count': failure_count - }) - - except Exception as e: - logger.error(f"Error moving models in bulk: {e}", exc_info=True) - return web.Response(text=str(e), status=500) - - async def get_lora_model_description(self, request: web.Request) -> web.Response: - """Get model description for a Lora model""" - try: - if self.civitai_client is None: - self.civitai_client = await ServiceRegistry.get_civitai_client() - - # Get parameters - model_id = request.query.get('model_id') - file_path = request.query.get('file_path') - - if not model_id: - return web.json_response({ - 'success': False, - 'error': 'Model ID is required' - }, status=400) - - # Check if we already have the description stored in metadata - description = None - tags = [] - creator = {} - if file_path: - metadata_path = os.path.splitext(file_path)[0] + '.metadata.json' - metadata = await ModelRouteUtils.load_local_metadata(metadata_path) - description = metadata.get('modelDescription') - tags = metadata.get('tags', []) - creator = metadata.get('creator', {}) - - # If description is not in metadata, fetch from CivitAI - if not description: - logger.info(f"Fetching model metadata for model ID: {model_id}") - model_metadata, _ = await self.civitai_client.get_model_metadata(model_id) - - if (model_metadata): - description = model_metadata.get('description') - tags = model_metadata.get('tags', []) - creator = model_metadata.get('creator', {}) - - # Save the metadata to file if we have a file path and got metadata - if file_path: - try: - metadata_path = os.path.splitext(file_path)[0] + '.metadata.json' - metadata = await ModelRouteUtils.load_local_metadata(metadata_path) - - metadata['modelDescription'] = description - metadata['tags'] = tags - # Ensure the civitai dict exists - if 'civitai' not in metadata: - metadata['civitai'] = {} - # Store creator in the civitai nested structure - metadata['civitai']['creator'] = creator - - await MetadataManager.save_metadata(file_path, metadata, True) - except Exception as e: - logger.error(f"Error saving model metadata: {e}") - - return web.json_response({ - 'success': True, - 'description': description or "
No model description available.
", - 'tags': tags, - 'creator': creator - }) - - except Exception as e: - logger.error(f"Error getting model metadata: {e}") - return web.json_response({ - 'success': False, - 'error': str(e) - }, status=500) - - async def get_top_tags(self, request: web.Request) -> web.Response: - """Handle request for top tags sorted by frequency""" - try: - if self.scanner is None: - self.scanner = await ServiceRegistry.get_lora_scanner() - - # Parse query parameters - limit = int(request.query.get('limit', '20')) - - # Validate limit - if limit < 1 or limit > 100: - limit = 20 # Default to a reasonable limit - - # Get top tags - top_tags = await self.scanner.get_top_tags(limit) - - return web.json_response({ - 'success': True, - 'tags': top_tags - }) - - except Exception as e: - logger.error(f"Error getting top tags: {str(e)}", exc_info=True) - return web.json_response({ - 'success': False, - 'error': 'Internal server error' - }, status=500) - - async def get_base_models(self, request: web.Request) -> web.Response: - """Get base models used in loras""" - try: - if self.scanner is None: - self.scanner = await ServiceRegistry.get_lora_scanner() - - # Parse query parameters - limit = int(request.query.get('limit', '20')) - - # Validate limit - if limit < 1 or limit > 100: - limit = 20 # Default to a reasonable limit - - # Get base models - base_models = await self.scanner.get_base_models(limit) - - return web.json_response({ - 'success': True, - 'base_models': base_models - }) - except Exception as e: - logger.error(f"Error retrieving base models: {e}") - return web.json_response({ - 'success': False, - 'error': str(e) - }, status=500) - - async def rename_lora(self, request: web.Request) -> web.Response: - """Handle renaming a LoRA file and its associated files""" - if self.scanner is None: - self.scanner = await ServiceRegistry.get_lora_scanner() - - return await ModelRouteUtils.handle_rename_model(request, self.scanner) - - async def get_trigger_words(self, request: web.Request) -> web.Response: - """Get trigger words for specified LoRA models""" - try: - json_data = await request.json() - lora_names = json_data.get("lora_names", []) - node_ids = json_data.get("node_ids", []) - - all_trigger_words = [] - for lora_name in lora_names: - _, trigger_words = get_lora_info(lora_name) - all_trigger_words.extend(trigger_words) - - # Format the trigger words - trigger_words_text = ",, ".join(all_trigger_words) if all_trigger_words else "" - - # Send update to all connected trigger word toggle nodes - for node_id in node_ids: - PromptServer.instance.send_sync("trigger_word_update", { - "id": node_id, - "message": trigger_words_text - }) - - return web.json_response({"success": True}) - - except Exception as e: - logger.error(f"Error getting trigger words: {e}") - return web.json_response({ - "success": False, - "error": str(e) - }, status=500) - - async def get_letter_counts(self, request: web.Request) -> web.Response: - """Get count of loras for each letter of the alphabet""" - try: - if self.scanner is None: - self.scanner = await ServiceRegistry.get_lora_scanner() - - # Get letter counts - letter_counts = await self.scanner.get_letter_counts() - - return web.json_response({ - 'success': True, - 'letter_counts': letter_counts - }) - except Exception as e: - logger.error(f"Error getting letter counts: {e}") - return web.json_response({ - 'success': False, - 'error': str(e) - }, status=500) - - async def get_lora_notes(self, request: web.Request) -> web.Response: - """Get notes for a specific LoRA file""" - try: - if self.scanner is None: - self.scanner = await ServiceRegistry.get_lora_scanner() - - # Get lora file name from query parameters - lora_name = request.query.get('name') - if not lora_name: - return web.Response(text='Lora file name is required', status=400) - - # Get cache data - cache = await self.scanner.get_cached_data() - - # Search for the lora in cache data - for lora in cache.raw_data: - file_name = lora['file_name'] - if file_name == lora_name: - notes = lora.get('notes', '') - - return web.json_response({ - 'success': True, - 'notes': notes - }) - - # If lora not found - return web.json_response({ - 'success': False, - 'error': 'LoRA not found in cache' - }, status=404) - - except Exception as e: - logger.error(f"Error getting lora notes: {e}", exc_info=True) - return web.json_response({ - 'success': False, - 'error': str(e) - }, status=500) - - async def get_lora_trigger_words(self, request: web.Request) -> web.Response: - """Get trigger words for a specific LoRA file""" - try: - if self.scanner is None: - self.scanner = await ServiceRegistry.get_lora_scanner() - - # Get lora file name from query parameters - lora_name = request.query.get('name') - if not lora_name: - return web.Response(text='Lora file name is required', status=400) - - # Get cache data - cache = await self.scanner.get_cached_data() - - # Search for the lora in cache data - for lora in cache.raw_data: - file_name = lora['file_name'] - if file_name == lora_name: - # Get trigger words from civitai data - civitai_data = lora.get('civitai', {}) - trigger_words = civitai_data.get('trainedWords', []) - - return web.json_response({ - 'success': True, - 'trigger_words': trigger_words - }) - - # If lora not found - return web.json_response({ - 'success': False, - 'error': 'LoRA not found in cache' - }, status=404) - - except Exception as e: - logger.error(f"Error getting lora trigger words: {e}", exc_info=True) - return web.json_response({ - 'success': False, - 'error': str(e) - }, status=500) - - async def find_duplicate_loras(self, request: web.Request) -> web.Response: - """Find loras with duplicate SHA256 hashes""" - try: - if self.scanner is None: - self.scanner = await ServiceRegistry.get_lora_scanner() - - # Get duplicate hashes from hash index - duplicates = self.scanner._hash_index.get_duplicate_hashes() - - # Format the response - result = [] - cache = await self.scanner.get_cached_data() - - for sha256, paths in duplicates.items(): - group = { - "hash": sha256, - "models": [] - } - # Find matching models for each duplicate path - for path in paths: - model = next((m for m in cache.raw_data if m['file_path'] == path), None) - if model: - group["models"].append(self._format_lora_response(model)) - - # Add the primary model too - primary_path = self.scanner._hash_index.get_path(sha256) - if primary_path and primary_path not in paths: - primary_model = next((m for m in cache.raw_data if m['file_path'] == primary_path), None) - if primary_model: - group["models"].insert(0, self._format_lora_response(primary_model)) - - if len(group["models"]) > 1: # Only include if we found multiple models - result.append(group) - - return web.json_response({ - "success": True, - "duplicates": result, - "count": len(result) - }) - except Exception as e: - logger.error(f"Error finding duplicate loras: {e}", exc_info=True) - return web.json_response({ - "success": False, - "error": str(e) - }, status=500) - - async def find_filename_conflicts(self, request: web.Request) -> web.Response: - """Find loras with conflicting filenames""" - try: - if self.scanner is None: - self.scanner = await ServiceRegistry.get_lora_scanner() - - # Get duplicate filenames from hash index - duplicates = self.scanner._hash_index.get_duplicate_filenames() - - # Format the response - result = [] - cache = await self.scanner.get_cached_data() - - for filename, paths in duplicates.items(): - group = { - "filename": filename, - "models": [] - } - # Find matching models for each path - for path in paths: - model = next((m for m in cache.raw_data if m['file_path'] == path), None) - if model: - group["models"].append(self._format_lora_response(model)) - - # Find the model from the main index too - hash_val = self.scanner._hash_index.get_hash_by_filename(filename) - if hash_val: - main_path = self.scanner._hash_index.get_path(hash_val) - if main_path and main_path not in paths: - main_model = next((m for m in cache.raw_data if m['file_path'] == main_path), None) - if main_model: - group["models"].insert(0, self._format_lora_response(main_model)) - - if group["models"]: # Only include if we found models - result.append(group) - - return web.json_response({ - "success": True, - "conflicts": result, - "count": len(result) - }) - except Exception as e: - logger.error(f"Error finding filename conflicts: {e}", exc_info=True) - return web.json_response({ - "success": False, - "error": str(e) - }, status=500) - - async def bulk_delete_loras(self, request: web.Request) -> web.Response: - """Handle bulk deletion of lora models""" - try: - if self.scanner is None: - self.scanner = await ServiceRegistry.get_lora_scanner() - - return await ModelRouteUtils.handle_bulk_delete_models(request, self.scanner) - - except Exception as e: - logger.error(f"Error in bulk delete loras: {e}", exc_info=True) - return web.json_response({ - 'success': False, - 'error': str(e) - }, status=500) - - async def relink_civitai(self, request: web.Request) -> web.Response: - """Handle CivitAI metadata re-linking request by model version ID for LoRAs""" - if self.scanner is None: - self.scanner = await ServiceRegistry.get_lora_scanner() - return await ModelRouteUtils.handle_relink_civitai(request, self.scanner) - - async def verify_duplicates(self, request: web.Request) -> web.Response: - """Handle verification of duplicate lora hashes""" - if self.scanner is None: - self.scanner = await ServiceRegistry.get_lora_scanner() - return await ModelRouteUtils.handle_verify_duplicates(request, self.scanner) + # Cleanup is now handled by ServiceRegistry and individual services + pass diff --git a/py/routes/base_model_routes.py b/py/routes/base_model_routes.py new file mode 100644 index 00000000..7a6e4aac --- /dev/null +++ b/py/routes/base_model_routes.py @@ -0,0 +1,431 @@ +from abc import ABC, abstractmethod +import json +import logging +from aiohttp import web +from typing import Dict + +from ..utils.routes_common import ModelRouteUtils +from ..services.websocket_manager import ws_manager + +logger = logging.getLogger(__name__) + +class BaseModelRoutes(ABC): + """Base route controller for all model types""" + + def __init__(self, service): + """Initialize the route controller + + Args: + service: Model service instance (LoraService, CheckpointService, etc.) + """ + self.service = service + self.model_type = service.model_type + + def setup_routes(self, app: web.Application, prefix: str): + """Setup common routes for the model type + + Args: + app: aiohttp application + prefix: URL prefix (e.g., 'loras', 'checkpoints') + """ + # Common model management routes + app.router.add_get(f'/api/{prefix}', self.get_models) + app.router.add_post(f'/api/{prefix}/delete', self.delete_model) + app.router.add_post(f'/api/{prefix}/exclude', self.exclude_model) + app.router.add_post(f'/api/{prefix}/fetch-civitai', self.fetch_civitai) + app.router.add_post(f'/api/{prefix}/relink-civitai', self.relink_civitai) + app.router.add_post(f'/api/{prefix}/replace-preview', self.replace_preview) + app.router.add_post(f'/api/{prefix}/save-metadata', self.save_metadata) + app.router.add_post(f'/api/{prefix}/rename', self.rename_model) + app.router.add_post(f'/api/{prefix}/bulk-delete', self.bulk_delete_models) + app.router.add_post(f'/api/{prefix}/verify-duplicates', self.verify_duplicates) + + # Common query routes + app.router.add_get(f'/api/{prefix}/top-tags', self.get_top_tags) + app.router.add_get(f'/api/{prefix}/base-models', self.get_base_models) + app.router.add_get(f'/api/{prefix}/scan', self.scan_models) + app.router.add_get(f'/api/{prefix}/roots', self.get_model_roots) + app.router.add_get(f'/api/{prefix}/find-duplicates', self.find_duplicate_models) + app.router.add_get(f'/api/{prefix}/find-filename-conflicts', self.find_filename_conflicts) + + # CivitAI integration routes + app.router.add_post(f'/api/{prefix}/fetch-all-civitai', self.fetch_all_civitai) + app.router.add_get(f'/api/civitai/versions/{{model_id}}', self.get_civitai_versions) + + # Setup model-specific routes + self.setup_specific_routes(app, prefix) + + @abstractmethod + def setup_specific_routes(self, app: web.Application, prefix: str): + """Setup model-specific routes - to be implemented by subclasses""" + pass + + async def get_models(self, request: web.Request) -> web.Response: + """Get paginated model data""" + try: + # Parse common query parameters + params = self._parse_common_params(request) + + # Get data from service + result = await self.service.get_paginated_data(**params) + + # Format response items + formatted_result = { + 'items': [await self.service.format_response(item) for item in result['items']], + 'total': result['total'], + 'page': result['page'], + 'page_size': result['page_size'], + 'total_pages': result['total_pages'] + } + + return web.json_response(formatted_result) + + except Exception as e: + logger.error(f"Error in get_{self.model_type}s: {e}", exc_info=True) + return web.json_response({"error": str(e)}, status=500) + + def _parse_common_params(self, request: web.Request) -> Dict: + """Parse common query parameters""" + # Parse basic pagination and sorting + page = int(request.query.get('page', '1')) + page_size = min(int(request.query.get('page_size', '20')), 100) + sort_by = request.query.get('sort_by', 'name') + folder = request.query.get('folder', None) + search = request.query.get('search', None) + fuzzy_search = request.query.get('fuzzy_search', 'false').lower() == 'true' + + # Parse filter arrays + base_models = request.query.getall('base_model', []) + tags = request.query.getall('tag', []) + favorites_only = request.query.get('favorites_only', 'false').lower() == 'true' + + # Parse search options + search_options = { + 'filename': request.query.get('search_filename', 'true').lower() == 'true', + 'modelname': request.query.get('search_modelname', 'true').lower() == 'true', + 'tags': request.query.get('search_tags', 'false').lower() == 'true', + 'recursive': request.query.get('recursive', 'false').lower() == 'true', + } + + # Parse hash filters if provided + hash_filters = {} + if 'hash' in request.query: + hash_filters['single_hash'] = request.query['hash'] + elif 'hashes' in request.query: + try: + hash_list = json.loads(request.query['hashes']) + if isinstance(hash_list, list): + hash_filters['multiple_hashes'] = hash_list + except (json.JSONDecodeError, TypeError): + pass + + return { + 'page': page, + 'page_size': page_size, + 'sort_by': sort_by, + 'folder': folder, + 'search': search, + 'fuzzy_search': fuzzy_search, + 'base_models': base_models, + 'tags': tags, + 'search_options': search_options, + 'hash_filters': hash_filters, + 'favorites_only': favorites_only, + # Add model-specific parameters + **self._parse_specific_params(request) + } + + def _parse_specific_params(self, request: web.Request) -> Dict: + """Parse model-specific parameters - to be overridden by subclasses""" + return {} + + # Common route handlers + async def delete_model(self, request: web.Request) -> web.Response: + """Handle model deletion request""" + return await ModelRouteUtils.handle_delete_model(request, self.service.scanner) + + async def exclude_model(self, request: web.Request) -> web.Response: + """Handle model exclusion request""" + return await ModelRouteUtils.handle_exclude_model(request, self.service.scanner) + + async def fetch_civitai(self, request: web.Request) -> web.Response: + """Handle CivitAI metadata fetch request""" + response = await ModelRouteUtils.handle_fetch_civitai(request, self.service.scanner) + + # If successful, format the metadata before returning + if response.status == 200: + data = json.loads(response.body.decode('utf-8')) + if data.get("success") and data.get("metadata"): + formatted_metadata = await self.service.format_response(data["metadata"]) + return web.json_response({ + "success": True, + "metadata": formatted_metadata + }) + + return response + + async def relink_civitai(self, request: web.Request) -> web.Response: + """Handle CivitAI metadata re-linking request""" + return await ModelRouteUtils.handle_relink_civitai(request, self.service.scanner) + + async def replace_preview(self, request: web.Request) -> web.Response: + """Handle preview image replacement""" + return await ModelRouteUtils.handle_replace_preview(request, self.service.scanner) + + async def save_metadata(self, request: web.Request) -> web.Response: + """Handle saving metadata updates""" + return await ModelRouteUtils.handle_save_metadata(request, self.service.scanner) + + async def rename_model(self, request: web.Request) -> web.Response: + """Handle renaming a model file and its associated files""" + return await ModelRouteUtils.handle_rename_model(request, self.service.scanner) + + async def bulk_delete_models(self, request: web.Request) -> web.Response: + """Handle bulk deletion of models""" + return await ModelRouteUtils.handle_bulk_delete_models(request, self.service.scanner) + + async def verify_duplicates(self, request: web.Request) -> web.Response: + """Handle verification of duplicate model hashes""" + return await ModelRouteUtils.handle_verify_duplicates(request, self.service.scanner) + + async def get_top_tags(self, request: web.Request) -> web.Response: + """Handle request for top tags sorted by frequency""" + try: + limit = int(request.query.get('limit', '20')) + if limit < 1 or limit > 100: + limit = 20 + + top_tags = await self.service.get_top_tags(limit) + + return web.json_response({ + 'success': True, + 'tags': top_tags + }) + + except Exception as e: + logger.error(f"Error getting top tags: {str(e)}", exc_info=True) + return web.json_response({ + 'success': False, + 'error': 'Internal server error' + }, status=500) + + async def get_base_models(self, request: web.Request) -> web.Response: + """Get base models used in models""" + try: + limit = int(request.query.get('limit', '20')) + if limit < 1 or limit > 100: + limit = 20 + + base_models = await self.service.get_base_models(limit) + + return web.json_response({ + 'success': True, + 'base_models': base_models + }) + except Exception as e: + logger.error(f"Error retrieving base models: {e}") + return web.json_response({ + 'success': False, + 'error': str(e) + }, status=500) + + async def scan_models(self, request: web.Request) -> web.Response: + """Force a rescan of model files""" + try: + full_rebuild = request.query.get('full_rebuild', 'false').lower() == 'true' + + await self.service.scan_models(force_refresh=True, rebuild_cache=full_rebuild) + return web.json_response({ + "status": "success", + "message": f"{self.model_type.capitalize()} scan completed" + }) + except Exception as e: + logger.error(f"Error in scan_{self.model_type}s: {e}", exc_info=True) + return web.json_response({"error": str(e)}, status=500) + + async def get_model_roots(self, request: web.Request) -> web.Response: + """Return the model root directories""" + try: + roots = self.service.get_model_roots() + return web.json_response({ + "success": True, + "roots": roots + }) + except Exception as e: + logger.error(f"Error getting {self.model_type} roots: {e}", exc_info=True) + return web.json_response({ + "success": False, + "error": str(e) + }, status=500) + + async def find_duplicate_models(self, request: web.Request) -> web.Response: + """Find models with duplicate SHA256 hashes""" + try: + # Get duplicate hashes from service + duplicates = self.service.find_duplicate_hashes() + + # Format the response + result = [] + cache = await self.service.scanner.get_cached_data() + + for sha256, paths in duplicates.items(): + group = { + "hash": sha256, + "models": [] + } + # Find matching models for each path + for path in paths: + model = next((m for m in cache.raw_data if m['file_path'] == path), None) + if model: + group["models"].append(await self.service.format_response(model)) + + # Add the primary model too + primary_path = self.service.get_path_by_hash(sha256) + if primary_path and primary_path not in paths: + primary_model = next((m for m in cache.raw_data if m['file_path'] == primary_path), None) + if primary_model: + group["models"].insert(0, await self.service.format_response(primary_model)) + + if len(group["models"]) > 1: # Only include if we found multiple models + result.append(group) + + return web.json_response({ + "success": True, + "duplicates": result, + "count": len(result) + }) + except Exception as e: + logger.error(f"Error finding duplicate {self.model_type}s: {e}", exc_info=True) + return web.json_response({ + "success": False, + "error": str(e) + }, status=500) + + async def find_filename_conflicts(self, request: web.Request) -> web.Response: + """Find models with conflicting filenames""" + try: + # Get duplicate filenames from service + duplicates = self.service.find_duplicate_filenames() + + # Format the response + result = [] + cache = await self.service.scanner.get_cached_data() + + for filename, paths in duplicates.items(): + group = { + "filename": filename, + "models": [] + } + # Find matching models for each path + for path in paths: + model = next((m for m in cache.raw_data if m['file_path'] == path), None) + if model: + group["models"].append(await self.service.format_response(model)) + + # Find the model from the main index too + hash_val = self.service.scanner._hash_index.get_hash_by_filename(filename) + if hash_val: + main_path = self.service.get_path_by_hash(hash_val) + if main_path and main_path not in paths: + main_model = next((m for m in cache.raw_data if m['file_path'] == main_path), None) + if main_model: + group["models"].insert(0, await self.service.format_response(main_model)) + + if group["models"]: + result.append(group) + + return web.json_response({ + "success": True, + "conflicts": result, + "count": len(result) + }) + except Exception as e: + logger.error(f"Error finding filename conflicts for {self.model_type}s: {e}", exc_info=True) + return web.json_response({ + "success": False, + "error": str(e) + }, status=500) + + async def fetch_all_civitai(self, request: web.Request) -> web.Response: + """Fetch CivitAI metadata for all models in the background""" + try: + cache = await self.service.scanner.get_cached_data() + total = len(cache.raw_data) + processed = 0 + success = 0 + needs_resort = False + + # Prepare models to process + to_process = [ + model for model in cache.raw_data + if model.get('sha256') and (not model.get('civitai') or 'id' not in model.get('civitai')) and model.get('from_civitai', True) + ] + total_to_process = len(to_process) + + # Send initial progress + await ws_manager.broadcast({ + 'status': 'started', + 'total': total_to_process, + 'processed': 0, + 'success': 0 + }) + + # Process each model + for model in to_process: + try: + original_name = model.get('model_name') + if await ModelRouteUtils.fetch_and_update_model( + sha256=model['sha256'], + file_path=model['file_path'], + model_data=model, + update_cache_func=self.service.scanner.update_single_model_cache + ): + success += 1 + if original_name != model.get('model_name'): + needs_resort = True + + processed += 1 + + # Send progress update + await ws_manager.broadcast({ + 'status': 'processing', + 'total': total_to_process, + 'processed': processed, + 'success': success, + 'current_name': model.get('model_name', 'Unknown') + }) + + except Exception as e: + logger.error(f"Error fetching CivitAI data for {model['file_path']}: {e}") + + if needs_resort: + await cache.resort(name_only=True) + + # Send completion message + await ws_manager.broadcast({ + 'status': 'completed', + 'total': total_to_process, + 'processed': processed, + 'success': success + }) + + return web.json_response({ + "success": True, + "message": f"Successfully updated {success} of {processed} processed {self.model_type}s (total: {total})" + }) + + except Exception as e: + # Send error message + await ws_manager.broadcast({ + 'status': 'error', + 'error': str(e) + }) + logger.error(f"Error in fetch_all_civitai for {self.model_type}s: {e}") + return web.Response(text=str(e), status=500) + + async def get_civitai_versions(self, request: web.Request) -> web.Response: + """Get available versions for a Civitai model with local availability info""" + # This will be implemented by subclasses as they need CivitAI client access + return web.json_response({ + "error": "Not implemented in base class" + }, status=501) \ No newline at end of file diff --git a/py/routes/checkpoint_routes.py b/py/routes/checkpoint_routes.py new file mode 100644 index 00000000..5fdb670e --- /dev/null +++ b/py/routes/checkpoint_routes.py @@ -0,0 +1,170 @@ +import jinja2 +import logging +from aiohttp import web + +from .base_model_routes import BaseModelRoutes +from ..services.checkpoint_service import CheckpointService +from ..services.service_registry import ServiceRegistry +from ..config import config +from ..services.settings_manager import settings + +logger = logging.getLogger(__name__) + +class CheckpointRoutes(BaseModelRoutes): + """Checkpoint-specific route controller""" + + def __init__(self): + """Initialize Checkpoint routes with Checkpoint service""" + # Service will be initialized later via setup_routes + self.service = None + self.civitai_client = None + self.template_env = jinja2.Environment( + loader=jinja2.FileSystemLoader(config.templates_path), + autoescape=True + ) + + async def initialize_services(self): + """Initialize services from ServiceRegistry""" + checkpoint_scanner = await ServiceRegistry.get_checkpoint_scanner() + self.service = CheckpointService(checkpoint_scanner) + self.civitai_client = await ServiceRegistry.get_civitai_client() + + # Initialize parent with the service + super().__init__(self.service) + + def setup_routes(self, app: web.Application): + """Setup Checkpoint routes""" + # Schedule service initialization on app startup + app.on_startup.append(lambda _: self.initialize_services()) + + # Setup common routes with 'checkpoints' prefix + super().setup_routes(app, 'checkpoints') + + def setup_specific_routes(self, app: web.Application, prefix: str): + """Setup Checkpoint-specific routes""" + # Checkpoint page route + app.router.add_get('/checkpoints', self.handle_checkpoints_page) + + # Checkpoint-specific CivitAI integration + app.router.add_get(f'/api/civitai/versions/{{model_id}}', self.get_civitai_versions_checkpoint) + + # Checkpoint info by name + app.router.add_get(f'/api/{prefix}/info/{{name}}', self.get_checkpoint_info) + + async def handle_checkpoints_page(self, request: web.Request) -> web.Response: + """Handle GET /checkpoints request""" + try: + # Check if the CheckpointScanner is initializing + # It's initializing if the cache object doesn't exist yet, + # OR if the scanner explicitly says it's initializing (background task running). + is_initializing = ( + self.service.scanner._cache is None or + (hasattr(self.service.scanner, '_is_initializing') and self.service.scanner._is_initializing) + ) + + if is_initializing: + # If still initializing, return loading page + template = self.template_env.get_template('checkpoints.html') + rendered = template.render( + folders=[], # Empty folder list + is_initializing=True, # New flag + settings=settings, # Pass settings to template + request=request # Pass the request object to the template + ) + + logger.info("Checkpoints page is initializing, returning loading page") + else: + # Normal flow - get initialized cache data + try: + cache = await self.service.scanner.get_cached_data(force_refresh=False) + template = self.template_env.get_template('checkpoints.html') + rendered = template.render( + folders=cache.folders, + is_initializing=False, + settings=settings, # Pass settings to template + request=request # Pass the request object to the template + ) + except Exception as cache_error: + logger.error(f"Error loading checkpoints cache data: {cache_error}") + # If getting cache fails, also show initialization page + template = self.template_env.get_template('checkpoints.html') + rendered = template.render( + folders=[], + is_initializing=True, + settings=settings, + request=request + ) + logger.info("Checkpoints cache error, returning initialization page") + + return web.Response( + text=rendered, + content_type='text/html' + ) + except Exception as e: + logger.error(f"Error handling checkpoints request: {e}", exc_info=True) + return web.Response( + text="Error loading checkpoints page", + status=500 + ) + + async def get_checkpoint_info(self, request: web.Request) -> web.Response: + """Get detailed information for a specific checkpoint by name""" + try: + name = request.match_info.get('name', '') + checkpoint_info = await self.service.get_model_info_by_name(name) + + if checkpoint_info: + return web.json_response(checkpoint_info) + else: + return web.json_response({"error": "Checkpoint not found"}, status=404) + + except Exception as e: + logger.error(f"Error in get_checkpoint_info: {e}", exc_info=True) + return web.json_response({"error": str(e)}, status=500) + + async def get_civitai_versions_checkpoint(self, request: web.Request) -> web.Response: + """Get available versions for a Civitai checkpoint model with local availability info""" + try: + model_id = request.match_info['model_id'] + response = await self.civitai_client.get_model_versions(model_id) + if not response or not response.get('modelVersions'): + return web.Response(status=404, text="Model not found") + + versions = response.get('modelVersions', []) + model_type = response.get('type', '') + + # Check model type - should be Checkpoint + if model_type.lower() != 'checkpoint': + return web.json_response({ + 'error': f"Model type mismatch. Expected Checkpoint, got {model_type}" + }, status=400) + + # Check local availability for each version + for version in versions: + # Find the primary model file (type="Model" and primary=true) in the files list + model_file = next((file for file in version.get('files', []) + if file.get('type') == 'Model' and file.get('primary') == True), None) + + # If no primary file found, try to find any model file + if not model_file: + model_file = next((file for file in version.get('files', []) + if file.get('type') == 'Model'), None) + + if model_file: + sha256 = model_file.get('hashes', {}).get('SHA256') + if sha256: + # Set existsLocally and localPath at the version level + version['existsLocally'] = self.service.has_hash(sha256) + if version['existsLocally']: + version['localPath'] = self.service.get_path_by_hash(sha256) + + # Also set the model file size at the version level for easier access + version['modelSizeKB'] = model_file.get('sizeKB') + else: + # No model file found in this version + version['existsLocally'] = False + + return web.json_response(versions) + except Exception as e: + logger.error(f"Error fetching checkpoint model versions: {e}") + return web.Response(status=500, text=str(e)) \ No newline at end of file diff --git a/py/routes/lora_routes.py b/py/routes/lora_routes.py index 1d00b66e..5b0674ba 100644 --- a/py/routes/lora_routes.py +++ b/py/routes/lora_routes.py @@ -1,76 +1,126 @@ +import jinja2 +import asyncio +import logging import os from aiohttp import web -import jinja2 from typing import Dict -import logging -from ..config import config +from server import PromptServer # type: ignore + +from .base_model_routes import BaseModelRoutes +from ..services.lora_service import LoraService +from ..services.service_registry import ServiceRegistry from ..services.settings_manager import settings -from ..services.service_registry import ServiceRegistry # Add ServiceRegistry import +from ..config import config +from ..utils.routes_common import ModelRouteUtils +from ..utils.utils import get_lora_info logger = logging.getLogger(__name__) -logging.getLogger('asyncio').setLevel(logging.CRITICAL) -class LoraRoutes: - """Route handlers for LoRA management endpoints""" +class LoraRoutes(BaseModelRoutes): + """LoRA-specific route controller""" def __init__(self): - # Initialize service references as None, will be set during async init - self.scanner = None - self.recipe_scanner = None + """Initialize LoRA routes with LoRA service""" + # Service will be initialized later via setup_routes + self.service = None + self.civitai_client = None + self.download_manager = None + self._download_lock = asyncio.Lock() self.template_env = jinja2.Environment( loader=jinja2.FileSystemLoader(config.templates_path), autoescape=True ) - - async def init_services(self): - """Initialize services from ServiceRegistry""" - self.scanner = await ServiceRegistry.get_lora_scanner() - self.recipe_scanner = await ServiceRegistry.get_recipe_scanner() - def format_lora_data(self, lora: Dict) -> Dict: - """Format LoRA data for template rendering""" - return { - "model_name": lora["model_name"], - "file_name": lora["file_name"], - "preview_url": config.get_preview_static_url(lora["preview_url"]), - "preview_nsfw_level": lora.get("preview_nsfw_level", 0), - "base_model": lora["base_model"], - "folder": lora["folder"], - "sha256": lora["sha256"], - "file_path": lora["file_path"].replace(os.sep, "/"), - "size": lora["size"], - "tags": lora["tags"], - "modelDescription": lora["modelDescription"], - "usage_tips": lora["usage_tips"], - "notes": lora["notes"], - "modified": lora["modified"], - "from_civitai": lora.get("from_civitai", True), - "civitai": self._filter_civitai_data(lora.get("civitai", {})) - } - - def _filter_civitai_data(self, data: Dict) -> Dict: - """Filter relevant fields from CivitAI data""" - if not data: - return {} - - fields = [ - "id", "modelId", "name", "createdAt", "updatedAt", - "publishedAt", "trainedWords", "baseModel", "description", - "model", "images" - ] - return {k: data[k] for k in fields if k in data} + async def initialize_services(self): + """Initialize services from ServiceRegistry""" + lora_scanner = await ServiceRegistry.get_lora_scanner() + self.service = LoraService(lora_scanner) + self.civitai_client = await ServiceRegistry.get_civitai_client() + self.download_manager = await ServiceRegistry.get_download_manager() + + # Initialize parent with the service + super().__init__(self.service) + + def setup_routes(self, app: web.Application): + """Setup LoRA routes""" + # Schedule service initialization on app startup + app.on_startup.append(lambda _: self.initialize_services()) + + # Setup common routes with 'loras' prefix + super().setup_routes(app, 'loras') + + def setup_specific_routes(self, app: web.Application, prefix: str): + """Setup LoRA-specific routes""" + # Lora page route + app.router.add_get('/loras', self.handle_loras_page) + # LoRA-specific query routes + app.router.add_get(f'/api/{prefix}/letter-counts', self.get_letter_counts) + app.router.add_get(f'/api/{prefix}/get-notes', self.get_lora_notes) + app.router.add_get(f'/api/{prefix}/get-trigger-words', self.get_lora_trigger_words) + app.router.add_get(f'/api/lora-preview-url', self.get_lora_preview_url) + app.router.add_get(f'/api/lora-civitai-url', self.get_lora_civitai_url) + app.router.add_get(f'/api/lora-model-description', self.get_lora_model_description) + app.router.add_get(f'/api/folders', self.get_folders) + app.router.add_get(f'/api/lora-roots', self.get_lora_roots) + + # LoRA-specific management routes + app.router.add_post(f'/api/move_model', self.move_model) + app.router.add_post(f'/api/move_models_bulk', self.move_models_bulk) + + # CivitAI integration with LoRA-specific validation + app.router.add_get(f'/api/civitai/versions/{{model_id}}', self.get_civitai_versions_lora) + app.router.add_get(f'/api/civitai/model/version/{{modelVersionId}}', self.get_civitai_model_by_version) + app.router.add_get(f'/api/civitai/model/hash/{{hash}}', self.get_civitai_model_by_hash) + + # Download management + app.router.add_post(f'/api/download-model', self.download_model) + app.router.add_get(f'/api/download-model-get', self.download_model_get) + app.router.add_get(f'/api/cancel-download-get', self.cancel_download_get) + app.router.add_get(f'/api/download-progress/{{download_id}}', self.get_download_progress) + + # ComfyUI integration + app.router.add_post(f'/loramanager/get_trigger_words', self.get_trigger_words) + + # Legacy API compatibility + app.router.add_post(f'/api/delete_model', self.delete_model) + app.router.add_post(f'/api/fetch-civitai', self.fetch_civitai) + app.router.add_post(f'/api/relink-civitai', self.relink_civitai) + app.router.add_post(f'/api/replace_preview', self.replace_preview) + app.router.add_post(f'/api/fetch-all-civitai', self.fetch_all_civitai) + + def _parse_specific_params(self, request: web.Request) -> Dict: + """Parse LoRA-specific parameters""" + params = {} + + # LoRA-specific parameters + if 'first_letter' in request.query: + params['first_letter'] = request.query.get('first_letter') + + # Handle fuzzy search parameter name variation + if request.query.get('fuzzy') == 'true': + params['fuzzy_search'] = True + + # Handle additional filter parameters for LoRAs + if 'lora_hash' in request.query: + if not params.get('hash_filters'): + params['hash_filters'] = {} + params['hash_filters']['single_hash'] = request.query['lora_hash'].lower() + elif 'lora_hashes' in request.query: + if not params.get('hash_filters'): + params['hash_filters'] = {} + params['hash_filters']['multiple_hashes'] = [h.lower() for h in request.query['lora_hashes'].split(',')] + + return params + async def handle_loras_page(self, request: web.Request) -> web.Response: """Handle GET /loras request""" try: - # Ensure services are initialized - await self.init_services() - # Check if the LoraScanner is initializing # It's initializing if the cache object doesn't exist yet, # OR if the scanner explicitly says it's initializing (background task running). is_initializing = ( - self.scanner._cache is None or self.scanner.is_initializing() + self.service.scanner._cache is None or self.service.scanner.is_initializing() ) if is_initializing: @@ -87,7 +137,7 @@ class LoraRoutes: else: # Normal flow - get data from initialized cache try: - cache = await self.scanner.get_cached_data(force_refresh=False) + cache = await self.service.scanner.get_cached_data(force_refresh=False) template = self.template_env.get_template('loras.html') rendered = template.render( folders=cache.folders, @@ -117,72 +167,561 @@ class LoraRoutes: text="Error loading loras page", status=500 ) - - async def handle_recipes_page(self, request: web.Request) -> web.Response: - """Handle GET /loras/recipes request""" + + # LoRA-specific route handlers + async def get_letter_counts(self, request: web.Request) -> web.Response: + """Get count of LoRAs for each letter of the alphabet""" try: - # Ensure services are initialized - await self.init_services() + letter_counts = await self.service.get_letter_counts() + return web.json_response({ + 'success': True, + 'letter_counts': letter_counts + }) + except Exception as e: + logger.error(f"Error getting letter counts: {e}") + return web.json_response({ + 'success': False, + 'error': str(e) + }, status=500) + + async def get_lora_notes(self, request: web.Request) -> web.Response: + """Get notes for a specific LoRA file""" + try: + lora_name = request.query.get('name') + if not lora_name: + return web.Response(text='Lora file name is required', status=400) - # Skip initialization check and directly try to get cached data - try: - # Recipe scanner will initialize cache if needed - await self.recipe_scanner.get_cached_data(force_refresh=False) - template = self.template_env.get_template('recipes.html') - rendered = template.render( - recipes=[], # Frontend will load recipes via API - is_initializing=False, - settings=settings, - request=request - ) - except Exception as cache_error: - logger.error(f"Error loading recipe cache data: {cache_error}") - # Still keep error handling - show initializing page on error - template = self.template_env.get_template('recipes.html') - rendered = template.render( - is_initializing=True, - settings=settings, - request=request - ) - logger.info("Recipe cache error, returning initialization page") + notes = await self.service.get_lora_notes(lora_name) + if notes is not None: + return web.json_response({ + 'success': True, + 'notes': notes + }) + else: + return web.json_response({ + 'success': False, + 'error': 'LoRA not found in cache' + }, status=404) + + except Exception as e: + logger.error(f"Error getting lora notes: {e}", exc_info=True) + return web.json_response({ + 'success': False, + 'error': str(e) + }, status=500) + + async def get_lora_trigger_words(self, request: web.Request) -> web.Response: + """Get trigger words for a specific LoRA file""" + try: + lora_name = request.query.get('name') + if not lora_name: + return web.Response(text='Lora file name is required', status=400) - return web.Response( - text=rendered, - content_type='text/html' - ) + trigger_words = await self.service.get_lora_trigger_words(lora_name) + return web.json_response({ + 'success': True, + 'trigger_words': trigger_words + }) except Exception as e: - logger.error(f"Error handling recipes request: {e}", exc_info=True) - return web.Response( - text="Error loading recipes page", - status=500 - ) - - def _format_recipe_file_url(self, file_path: str) -> str: - """Format file path for recipe image as a URL - same as in recipe_routes""" + logger.error(f"Error getting lora trigger words: {e}", exc_info=True) + return web.json_response({ + 'success': False, + 'error': str(e) + }, status=500) + + async def get_lora_preview_url(self, request: web.Request) -> web.Response: + """Get the static preview URL for a LoRA file""" try: - # Return the file URL directly for the first lora root's preview - recipes_dir = os.path.join(config.loras_roots[0], "recipes").replace(os.sep, '/') - if file_path.replace(os.sep, '/').startswith(recipes_dir): - relative_path = os.path.relpath(file_path, config.loras_roots[0]).replace(os.sep, '/') - return f"/loras_static/root1/preview/{relative_path}" + lora_name = request.query.get('name') + if not lora_name: + return web.Response(text='Lora file name is required', status=400) - # If not in recipes dir, try to create a valid URL from the file path + preview_url = await self.service.get_lora_preview_url(lora_name) + if preview_url: + return web.json_response({ + 'success': True, + 'preview_url': preview_url + }) + else: + return web.json_response({ + 'success': False, + 'error': 'No preview URL found for the specified lora' + }, status=404) + + except Exception as e: + logger.error(f"Error getting lora preview URL: {e}", exc_info=True) + return web.json_response({ + 'success': False, + 'error': str(e) + }, status=500) + + async def get_lora_civitai_url(self, request: web.Request) -> web.Response: + """Get the Civitai URL for a LoRA file""" + try: + lora_name = request.query.get('name') + if not lora_name: + return web.Response(text='Lora file name is required', status=400) + + result = await self.service.get_lora_civitai_url(lora_name) + if result['civitai_url']: + return web.json_response({ + 'success': True, + **result + }) + else: + return web.json_response({ + 'success': False, + 'error': 'No Civitai data found for the specified lora' + }, status=404) + + except Exception as e: + logger.error(f"Error getting lora Civitai URL: {e}", exc_info=True) + return web.json_response({ + 'success': False, + 'error': str(e) + }, status=500) + + async def get_folders(self, request: web.Request) -> web.Response: + """Get all folders in the cache""" + try: + cache = await self.service.scanner.get_cached_data() + return web.json_response({ + 'folders': cache.folders + }) + except Exception as e: + logger.error(f"Error getting folders: {e}") + return web.json_response({ + 'success': False, + 'error': str(e) + }, status=500) + + async def get_lora_roots(self, request: web.Request) -> web.Response: + """Get all configured LoRA root directories""" + try: + return web.json_response({ + 'roots': self.service.get_model_roots() + }) + except Exception as e: + logger.error(f"Error getting LoRA roots: {e}") + return web.json_response({ + 'success': False, + 'error': str(e) + }, status=500) + + # Override get_models to add LoRA-specific response data + async def get_models(self, request: web.Request) -> web.Response: + """Get paginated LoRA data with LoRA-specific fields""" + try: + # Parse common query parameters + params = self._parse_common_params(request) + + # Get data from service + result = await self.service.get_paginated_data(**params) + + # Get all available folders from cache for LoRA-specific response + cache = await self.service.scanner.get_cached_data() + + # Format response items with LoRA-specific structure + formatted_result = { + 'items': [await self.service.format_response(item) for item in result['items']], + 'folders': cache.folders, # LoRA-specific: include folders in response + 'total': result['total'], + 'page': result['page'], + 'page_size': result['page_size'], + 'total_pages': result['total_pages'] + } + + return web.json_response(formatted_result) + + except Exception as e: + logger.error(f"Error in get_loras: {e}", exc_info=True) + return web.json_response({"error": str(e)}, status=500) + + # CivitAI integration methods + async def get_civitai_versions_lora(self, request: web.Request) -> web.Response: + """Get available versions for a Civitai LoRA model with local availability info""" + try: + model_id = request.match_info['model_id'] + response = await self.civitai_client.get_model_versions(model_id) + if not response or not response.get('modelVersions'): + return web.Response(status=404, text="Model not found") + + versions = response.get('modelVersions', []) + model_type = response.get('type', '') + + # Check model type - should be LORA, LoCon, or DORA + from ..utils.constants import VALID_LORA_TYPES + if model_type.lower() not in VALID_LORA_TYPES: + return web.json_response({ + 'error': f"Model type mismatch. Expected LORA or LoCon, got {model_type}" + }, status=400) + + # Check local availability for each version + for version in versions: + # Find the model file (type="Model") in the files list + model_file = next((file for file in version.get('files', []) + if file.get('type') == 'Model'), None) + + if model_file: + sha256 = model_file.get('hashes', {}).get('SHA256') + if sha256: + # Set existsLocally and localPath at the version level + version['existsLocally'] = self.service.has_hash(sha256) + if version['existsLocally']: + version['localPath'] = self.service.get_path_by_hash(sha256) + + # Also set the model file size at the version level for easier access + version['modelSizeKB'] = model_file.get('sizeKB') + else: + # No model file found in this version + version['existsLocally'] = False + + return web.json_response(versions) + except Exception as e: + logger.error(f"Error fetching LoRA model versions: {e}") + return web.Response(status=500, text=str(e)) + + async def get_civitai_model_by_version(self, request: web.Request) -> web.Response: + """Get CivitAI model details by model version ID""" + try: + model_version_id = request.match_info.get('modelVersionId') + + # Get model details from Civitai API + model, error_msg = await self.civitai_client.get_model_version_info(model_version_id) + + if not model: + # Log warning for failed model retrieval + logger.warning(f"Failed to fetch model version {model_version_id}: {error_msg}") + + # Determine status code based on error message + status_code = 404 if error_msg and "not found" in error_msg.lower() else 500 + + return web.json_response({ + "success": False, + "error": error_msg or "Failed to fetch model information" + }, status=status_code) + + return web.json_response(model) + except Exception as e: + logger.error(f"Error fetching model details: {e}") + return web.json_response({ + "success": False, + "error": str(e) + }, status=500) + + async def get_civitai_model_by_hash(self, request: web.Request) -> web.Response: + """Get CivitAI model details by hash""" + try: + hash = request.match_info.get('hash') + model = await self.civitai_client.get_model_by_hash(hash) + return web.json_response(model) + except Exception as e: + logger.error(f"Error fetching model details by hash: {e}") + return web.json_response({ + "success": False, + "error": str(e) + }, status=500) + + # Download management methods + async def download_model(self, request: web.Request) -> web.Response: + """Handle model download request""" + return await ModelRouteUtils.handle_download_model(request, self.download_manager) + + async def download_model_get(self, request: web.Request) -> web.Response: + """Handle model download request via GET method""" + try: + # Extract query parameters + model_id = request.query.get('model_id') + if not model_id: + return web.Response( + status=400, + text="Missing required parameter: Please provide 'model_id'" + ) + + # Get optional parameters + model_version_id = request.query.get('model_version_id') + download_id = request.query.get('download_id') + use_default_paths = request.query.get('use_default_paths', 'false').lower() == 'true' + + # Create a data dictionary that mimics what would be received from a POST request + data = { + 'model_id': model_id + } + + # Add optional parameters only if they are provided + if model_version_id: + data['model_version_id'] = model_version_id + + if download_id: + data['download_id'] = download_id + + data['use_default_paths'] = use_default_paths + + # Create a mock request object with the data + future = asyncio.get_event_loop().create_future() + future.set_result(data) + + mock_request = type('MockRequest', (), { + 'json': lambda self=None: future + })() + + # Call the existing download handler + return await ModelRouteUtils.handle_download_model(mock_request, self.download_manager) + + except Exception as e: + error_message = str(e) + logger.error(f"Error downloading model via GET: {error_message}", exc_info=True) + return web.Response(status=500, text=error_message) + + async def cancel_download_get(self, request: web.Request) -> web.Response: + """Handle GET request for cancelling a download by download_id""" + try: + download_id = request.query.get('download_id') + if not download_id: + return web.json_response({ + 'success': False, + 'error': 'Download ID is required' + }, status=400) + + # Create a mock request with match_info for compatibility + mock_request = type('MockRequest', (), { + 'match_info': {'download_id': download_id} + })() + return await ModelRouteUtils.handle_cancel_download(mock_request, self.download_manager) + except Exception as e: + logger.error(f"Error cancelling download via GET: {e}", exc_info=True) + return web.json_response({ + 'success': False, + 'error': str(e) + }, status=500) + + async def get_download_progress(self, request: web.Request) -> web.Response: + """Handle request for download progress by download_id""" + try: + # Get download_id from URL path + download_id = request.match_info.get('download_id') + if not download_id: + return web.json_response({ + 'success': False, + 'error': 'Download ID is required' + }, status=400) + + # Get progress information from websocket manager + from ..services.websocket_manager import ws_manager + progress_data = ws_manager.get_download_progress(download_id) + + if progress_data is None: + return web.json_response({ + 'success': False, + 'error': 'Download ID not found' + }, status=404) + + return web.json_response({ + 'success': True, + 'progress': progress_data.get('progress', 0) + }) + except Exception as e: + logger.error(f"Error getting download progress: {e}", exc_info=True) + return web.json_response({ + 'success': False, + 'error': str(e) + }, status=500) + + # Model management methods + async def move_model(self, request: web.Request) -> web.Response: + """Handle model move request""" + try: + data = await request.json() + file_path = data.get('file_path') # full path of the model file + target_path = data.get('target_path') # folder path to move the model to + + if not file_path or not target_path: + return web.Response(text='File path and target path are required', status=400) + + # Check if source and destination are the same + import os + source_dir = os.path.dirname(file_path) + if os.path.normpath(source_dir) == os.path.normpath(target_path): + logger.info(f"Source and target directories are the same: {source_dir}") + return web.json_response({'success': True, 'message': 'Source and target directories are the same'}) + + # Check if target file already exists file_name = os.path.basename(file_path) - return f"/loras_static/root1/preview/recipes/{file_name}" - except Exception as e: - logger.error(f"Error formatting recipe file URL: {e}", exc_info=True) - return '/loras_static/images/no-preview.png' # Return default image on error + target_file_path = os.path.join(target_path, file_name).replace(os.sep, '/') + + if os.path.exists(target_file_path): + return web.json_response({ + 'success': False, + 'error': f"Target file already exists: {target_file_path}" + }, status=409) # 409 Conflict - def setup_routes(self, app: web.Application): - """Register routes with the application""" - # Add an app startup handler to initialize services - app.on_startup.append(self._on_startup) - - # Register routes - app.router.add_get('/loras', self.handle_loras_page) - app.router.add_get('/loras/recipes', self.handle_recipes_page) - - async def _on_startup(self, app): - """Initialize services when the app starts""" - await self.init_services() + # Call scanner to handle the move operation + success = await self.service.scanner.move_model(file_path, target_path) + + if success: + return web.json_response({'success': True}) + else: + return web.Response(text='Failed to move model', status=500) + + except Exception as e: + logger.error(f"Error moving model: {e}", exc_info=True) + return web.Response(text=str(e), status=500) + + async def move_models_bulk(self, request: web.Request) -> web.Response: + """Handle bulk model move request""" + try: + data = await request.json() + file_paths = data.get('file_paths', []) # list of full paths of the model files + target_path = data.get('target_path') # folder path to move the models to + + if not file_paths or not target_path: + return web.Response(text='File paths and target path are required', status=400) + + results = [] + import os + for file_path in file_paths: + # Check if source and destination are the same + source_dir = os.path.dirname(file_path) + if os.path.normpath(source_dir) == os.path.normpath(target_path): + results.append({ + "path": file_path, + "success": True, + "message": "Source and target directories are the same" + }) + continue + + # Check if target file already exists + file_name = os.path.basename(file_path) + target_file_path = os.path.join(target_path, file_name).replace(os.sep, '/') + + if os.path.exists(target_file_path): + results.append({ + "path": file_path, + "success": False, + "message": f"Target file already exists: {target_file_path}" + }) + continue + + # Try to move the model + success = await self.service.scanner.move_model(file_path, target_path) + results.append({ + "path": file_path, + "success": success, + "message": "Success" if success else "Failed to move model" + }) + + # Count successes and failures + success_count = sum(1 for r in results if r["success"]) + failure_count = len(results) - success_count + + return web.json_response({ + 'success': True, + 'message': f'Moved {success_count} of {len(file_paths)} models', + 'results': results, + 'success_count': success_count, + 'failure_count': failure_count + }) + + except Exception as e: + logger.error(f"Error moving models in bulk: {e}", exc_info=True) + return web.Response(text=str(e), status=500) + + async def get_lora_model_description(self, request: web.Request) -> web.Response: + """Get model description for a Lora model""" + try: + # Get parameters + model_id = request.query.get('model_id') + file_path = request.query.get('file_path') + + if not model_id: + return web.json_response({ + 'success': False, + 'error': 'Model ID is required' + }, status=400) + + # Check if we already have the description stored in metadata + description = None + tags = [] + creator = {} + if file_path: + import os + from ..utils.metadata_manager import MetadataManager + metadata_path = os.path.splitext(file_path)[0] + '.metadata.json' + metadata = await ModelRouteUtils.load_local_metadata(metadata_path) + description = metadata.get('modelDescription') + tags = metadata.get('tags', []) + creator = metadata.get('creator', {}) + + # If description is not in metadata, fetch from CivitAI + if not description: + logger.info(f"Fetching model metadata for model ID: {model_id}") + model_metadata, _ = await self.civitai_client.get_model_metadata(model_id) + + if model_metadata: + description = model_metadata.get('description') + tags = model_metadata.get('tags', []) + creator = model_metadata.get('creator', {}) + + # Save the metadata to file if we have a file path and got metadata + if file_path: + try: + metadata_path = os.path.splitext(file_path)[0] + '.metadata.json' + metadata = await ModelRouteUtils.load_local_metadata(metadata_path) + + metadata['modelDescription'] = description + metadata['tags'] = tags + # Ensure the civitai dict exists + if 'civitai' not in metadata: + metadata['civitai'] = {} + # Store creator in the civitai nested structure + metadata['civitai']['creator'] = creator + + await MetadataManager.save_metadata(file_path, metadata, True) + except Exception as e: + logger.error(f"Error saving model metadata: {e}") + + return web.json_response({ + 'success': True, + 'description': description or "No model description available.
", + 'tags': tags, + 'creator': creator + }) + + except Exception as e: + logger.error(f"Error getting model metadata: {e}") + return web.json_response({ + 'success': False, + 'error': str(e) + }, status=500) + + async def get_trigger_words(self, request: web.Request) -> web.Response: + """Get trigger words for specified LoRA models""" + try: + json_data = await request.json() + lora_names = json_data.get("lora_names", []) + node_ids = json_data.get("node_ids", []) + + all_trigger_words = [] + for lora_name in lora_names: + _, trigger_words = get_lora_info(lora_name) + all_trigger_words.extend(trigger_words) + + # Format the trigger words + trigger_words_text = ",, ".join(all_trigger_words) if all_trigger_words else "" + + # Send update to all connected trigger word toggle nodes + for node_id in node_ids: + PromptServer.instance.send_sync("trigger_word_update", { + "id": node_id, + "message": trigger_words_text + }) + + return web.json_response({"success": True}) + + except Exception as e: + logger.error(f"Error getting trigger words: {e}") + return web.json_response({ + "success": False, + "error": str(e) + }, status=500) diff --git a/py/routes/misc_routes.py b/py/routes/misc_routes.py index 019aa82e..596e0323 100644 --- a/py/routes/misc_routes.py +++ b/py/routes/misc_routes.py @@ -633,9 +633,8 @@ class MiscRoutes: }, status=400) # Get both lora and checkpoint scanners - registry = ServiceRegistry.get_instance() - lora_scanner = await registry.get_lora_scanner() - checkpoint_scanner = await registry.get_checkpoint_scanner() + lora_scanner = await ServiceRegistry.get_lora_scanner() + checkpoint_scanner = await ServiceRegistry.get_checkpoint_scanner() # If modelVersionId is provided, check for specific version if model_version_id_str: diff --git a/py/routes/recipe_routes.py b/py/routes/recipe_routes.py index 393a5ae3..d181ed65 100644 --- a/py/routes/recipe_routes.py +++ b/py/routes/recipe_routes.py @@ -1,6 +1,7 @@ import os import time import base64 +import jinja2 import numpy as np from PIL import Image import io @@ -15,6 +16,7 @@ from ..utils.exif_utils import ExifUtils from ..recipes import RecipeParserFactory from ..utils.constants import CARD_PREVIEW_WIDTH +from ..services.settings_manager import settings from ..config import config # Check if running in standalone mode @@ -39,7 +41,10 @@ class RecipeRoutes: # Initialize service references as None, will be set during async init self.recipe_scanner = None self.civitai_client = None - # Remove WorkflowParser instance + self.template_env = jinja2.Environment( + loader=jinja2.FileSystemLoader(config.templates_path), + autoescape=True + ) # Pre-warm the cache self._init_cache_task = None @@ -53,6 +58,8 @@ class RecipeRoutes: def setup_routes(cls, app: web.Application): """Register API routes""" routes = cls() + app.router.add_get('/loras/recipes', routes.handle_recipes_page) + app.router.add_get('/api/recipes', routes.get_recipes) app.router.add_get('/api/recipe/{recipe_id}', routes.get_recipe_detail) app.router.add_post('/api/recipes/analyze-image', routes.analyze_recipe_image) @@ -114,6 +121,46 @@ class RecipeRoutes: await self.recipe_scanner.get_cached_data(force_refresh=True) except Exception as e: logger.error(f"Error pre-warming recipe cache: {e}", exc_info=True) + + async def handle_recipes_page(self, request: web.Request) -> web.Response: + """Handle GET /loras/recipes request""" + try: + # Ensure services are initialized + await self.init_services() + + # Skip initialization check and directly try to get cached data + try: + # Recipe scanner will initialize cache if needed + await self.recipe_scanner.get_cached_data(force_refresh=False) + template = self.template_env.get_template('recipes.html') + rendered = template.render( + recipes=[], # Frontend will load recipes via API + is_initializing=False, + settings=settings, + request=request + ) + except Exception as cache_error: + logger.error(f"Error loading recipe cache data: {cache_error}") + # Still keep error handling - show initializing page on error + template = self.template_env.get_template('recipes.html') + rendered = template.render( + is_initializing=True, + settings=settings, + request=request + ) + logger.info("Recipe cache error, returning initialization page") + + return web.Response( + text=rendered, + content_type='text/html' + ) + + except Exception as e: + logger.error(f"Error handling recipes request: {e}", exc_info=True) + return web.Response( + text="Error loading recipes page", + status=500 + ) async def get_recipes(self, request: web.Request) -> web.Response: """API endpoint for getting paginated recipes""" diff --git a/py/services/base_model_service.py b/py/services/base_model_service.py new file mode 100644 index 00000000..7ecc994b --- /dev/null +++ b/py/services/base_model_service.py @@ -0,0 +1,248 @@ +from abc import ABC, abstractmethod +from typing import Dict, List, Optional, Type, Set +import logging + +from ..utils.models import BaseModelMetadata +from ..utils.constants import NSFW_LEVELS +from .settings_manager import settings +from ..utils.utils import fuzzy_match + +logger = logging.getLogger(__name__) + +class BaseModelService(ABC): + """Base service class for all model types""" + + def __init__(self, model_type: str, scanner, metadata_class: Type[BaseModelMetadata]): + """Initialize the service + + Args: + model_type: Type of model (lora, checkpoint, etc.) + scanner: Model scanner instance + metadata_class: Metadata class for this model type + """ + self.model_type = model_type + self.scanner = scanner + self.metadata_class = metadata_class + + async def get_paginated_data(self, page: int, page_size: int, sort_by: str = 'name', + folder: str = None, search: str = None, fuzzy_search: bool = False, + base_models: list = None, tags: list = None, + search_options: dict = None, hash_filters: dict = None, + favorites_only: bool = False, **kwargs) -> Dict: + """Get paginated and filtered model data + + Args: + page: Page number (1-based) + page_size: Number of items per page + sort_by: Sort criteria ('name' or 'date') + folder: Folder filter + search: Search term + fuzzy_search: Whether to use fuzzy search + base_models: List of base models to filter by + tags: List of tags to filter by + search_options: Search options dict + hash_filters: Hash filtering options + favorites_only: Filter for favorites only + **kwargs: Additional model-specific filters + + Returns: + Dict containing paginated results + """ + cache = await self.scanner.get_cached_data() + + # Get default search options if not provided + if search_options is None: + search_options = { + 'filename': True, + 'modelname': True, + 'tags': False, + 'recursive': False, + } + + # Get the base data set + filtered_data = cache.sorted_by_date if sort_by == 'date' else cache.sorted_by_name + + # Apply hash filtering if provided (highest priority) + if hash_filters: + filtered_data = await self._apply_hash_filters(filtered_data, hash_filters) + + # Jump to pagination for hash filters + return self._paginate(filtered_data, page, page_size) + + # Apply common filters + filtered_data = await self._apply_common_filters( + filtered_data, folder, base_models, tags, favorites_only, search_options + ) + + # Apply search filtering + if search: + filtered_data = await self._apply_search_filters( + filtered_data, search, fuzzy_search, search_options + ) + + # Apply model-specific filters + filtered_data = await self._apply_specific_filters(filtered_data, **kwargs) + + return self._paginate(filtered_data, page, page_size) + + async def _apply_hash_filters(self, data: List[Dict], hash_filters: Dict) -> List[Dict]: + """Apply hash-based filtering""" + single_hash = hash_filters.get('single_hash') + multiple_hashes = hash_filters.get('multiple_hashes') + + if single_hash: + # Filter by single hash + single_hash = single_hash.lower() + return [ + item for item in data + if item.get('sha256', '').lower() == single_hash + ] + elif multiple_hashes: + # Filter by multiple hashes + hash_set = set(hash.lower() for hash in multiple_hashes) + return [ + item for item in data + if item.get('sha256', '').lower() in hash_set + ] + + return data + + async def _apply_common_filters(self, data: List[Dict], folder: str = None, + base_models: list = None, tags: list = None, + favorites_only: bool = False, search_options: dict = None) -> List[Dict]: + """Apply common filters that work across all model types""" + # Apply SFW filtering if enabled in settings + if settings.get('show_only_sfw', False): + data = [ + item for item in data + if not item.get('preview_nsfw_level') or item.get('preview_nsfw_level') < NSFW_LEVELS['R'] + ] + + # Apply favorites filtering if enabled + if favorites_only: + data = [ + item for item in data + if item.get('favorite', False) is True + ] + + # Apply folder filtering + if folder is not None: + if search_options and search_options.get('recursive', False): + # Recursive folder filtering - include all subfolders + data = [ + item for item in data + if item['folder'].startswith(folder) + ] + else: + # Exact folder filtering + data = [ + item for item in data + if item['folder'] == folder + ] + + # Apply base model filtering + if base_models and len(base_models) > 0: + data = [ + item for item in data + if item.get('base_model') in base_models + ] + + # Apply tag filtering + if tags and len(tags) > 0: + data = [ + item for item in data + if any(tag in item.get('tags', []) for tag in tags) + ] + + return data + + async def _apply_search_filters(self, data: List[Dict], search: str, + fuzzy_search: bool, search_options: dict) -> List[Dict]: + """Apply search filtering""" + search_results = [] + + for item in data: + # Search by file name + if search_options.get('filename', True): + if fuzzy_search: + if fuzzy_match(item.get('file_name', ''), search): + search_results.append(item) + continue + elif search.lower() in item.get('file_name', '').lower(): + search_results.append(item) + continue + + # Search by model name + if search_options.get('modelname', True): + if fuzzy_search: + if fuzzy_match(item.get('model_name', ''), search): + search_results.append(item) + continue + elif search.lower() in item.get('model_name', '').lower(): + search_results.append(item) + continue + + # Search by tags + if search_options.get('tags', False) and 'tags' in item: + if any((fuzzy_match(tag, search) if fuzzy_search else search.lower() in tag.lower()) + for tag in item['tags']): + search_results.append(item) + continue + + return search_results + + async def _apply_specific_filters(self, data: List[Dict], **kwargs) -> List[Dict]: + """Apply model-specific filters - to be overridden by subclasses if needed""" + return data + + def _paginate(self, data: List[Dict], page: int, page_size: int) -> Dict: + """Apply pagination to filtered data""" + total_items = len(data) + start_idx = (page - 1) * page_size + end_idx = min(start_idx + page_size, total_items) + + return { + 'items': data[start_idx:end_idx], + 'total': total_items, + 'page': page, + 'page_size': page_size, + 'total_pages': (total_items + page_size - 1) // page_size + } + + @abstractmethod + async def format_response(self, model_data: Dict) -> Dict: + """Format model data for API response - must be implemented by subclasses""" + pass + + # Common service methods that delegate to scanner + async def get_top_tags(self, limit: int = 20) -> List[Dict]: + """Get top tags sorted by frequency""" + return await self.scanner.get_top_tags(limit) + + async def get_base_models(self, limit: int = 20) -> List[Dict]: + """Get base models sorted by frequency""" + return await self.scanner.get_base_models(limit) + + def has_hash(self, sha256: str) -> bool: + """Check if a model with given hash exists""" + return self.scanner.has_hash(sha256) + + def get_path_by_hash(self, sha256: str) -> Optional[str]: + """Get file path for a model by its hash""" + return self.scanner.get_path_by_hash(sha256) + + def get_hash_by_path(self, file_path: str) -> Optional[str]: + """Get hash for a model by its file path""" + return self.scanner.get_hash_by_path(file_path) + + async def scan_models(self, force_refresh: bool = False, rebuild_cache: bool = False): + """Trigger model scanning""" + return await self.scanner.get_cached_data(force_refresh=force_refresh, rebuild_cache=rebuild_cache) + + async def get_model_info_by_name(self, name: str): + """Get model information by name""" + return await self.scanner.get_model_info_by_name(name) + + def get_model_roots(self) -> List[str]: + """Get model root directories""" + return self.scanner.get_model_roots() \ No newline at end of file diff --git a/py/services/checkpoint_scanner.py b/py/services/checkpoint_scanner.py index 26733ab5..32da3dbf 100644 --- a/py/services/checkpoint_scanner.py +++ b/py/services/checkpoint_scanner.py @@ -1,14 +1,12 @@ import os import logging import asyncio -from typing import List, Dict, Optional, Set -import folder_paths # type: ignore +from typing import List, Dict from ..utils.models import CheckpointMetadata from ..config import config from .model_scanner import ModelScanner from .model_hash_index import ModelHashIndex -from .service_registry import ServiceRegistry logger = logging.getLogger(__name__) @@ -109,4 +107,31 @@ class CheckpointScanner(ModelScanner): if result: checkpoints.append(result) except Exception as e: - logger.error(f"Error processing {file_path}: {e}") \ No newline at end of file + logger.error(f"Error processing {file_path}: {e}") + + # Checkpoint-specific hash index functionality + def has_checkpoint_hash(self, sha256: str) -> bool: + """Check if a checkpoint with given hash exists""" + return self.has_hash(sha256) + + def get_checkpoint_path_by_hash(self, sha256: str) -> str: + """Get file path for a checkpoint by its hash""" + return self.get_path_by_hash(sha256) + + def get_checkpoint_hash_by_path(self, file_path: str) -> str: + """Get hash for a checkpoint by its file path""" + return self.get_hash_by_path(file_path) + + async def get_checkpoint_info_by_name(self, name): + """Get checkpoint information by name""" + try: + cache = await self.get_cached_data() + + for checkpoint in cache.raw_data: + if checkpoint.get("file_name") == name: + return checkpoint + + return None + except Exception as e: + logger.error(f"Error getting checkpoint info by name: {e}", exc_info=True) + return None \ No newline at end of file diff --git a/py/services/checkpoint_service.py b/py/services/checkpoint_service.py new file mode 100644 index 00000000..cfc4b118 --- /dev/null +++ b/py/services/checkpoint_service.py @@ -0,0 +1,51 @@ +import os +import logging +from typing import Dict, List, Optional + +from .base_model_service import BaseModelService +from ..utils.models import CheckpointMetadata +from ..config import config +from ..utils.routes_common import ModelRouteUtils + +logger = logging.getLogger(__name__) + +class CheckpointService(BaseModelService): + """Checkpoint-specific service implementation""" + + def __init__(self, scanner): + """Initialize Checkpoint service + + Args: + scanner: Checkpoint scanner instance + """ + super().__init__("checkpoint", scanner, CheckpointMetadata) + + async def format_response(self, checkpoint_data: Dict) -> Dict: + """Format Checkpoint data for API response""" + return { + "model_name": checkpoint_data["model_name"], + "file_name": checkpoint_data["file_name"], + "preview_url": config.get_preview_static_url(checkpoint_data.get("preview_url", "")), + "preview_nsfw_level": checkpoint_data.get("preview_nsfw_level", 0), + "base_model": checkpoint_data.get("base_model", ""), + "folder": checkpoint_data["folder"], + "sha256": checkpoint_data.get("sha256", ""), + "file_path": checkpoint_data["file_path"].replace(os.sep, "/"), + "file_size": checkpoint_data.get("size", 0), + "modified": checkpoint_data.get("modified", ""), + "tags": checkpoint_data.get("tags", []), + "modelDescription": checkpoint_data.get("modelDescription", ""), + "from_civitai": checkpoint_data.get("from_civitai", True), + "notes": checkpoint_data.get("notes", ""), + "model_type": checkpoint_data.get("model_type", "checkpoint"), + "favorite": checkpoint_data.get("favorite", False), + "civitai": ModelRouteUtils.filter_civitai_data(checkpoint_data.get("civitai", {})) + } + + def find_duplicate_hashes(self) -> Dict: + """Find Checkpoints with duplicate SHA256 hashes""" + return self.scanner._hash_index.get_duplicate_hashes() + + def find_duplicate_filenames(self) -> Dict: + """Find Checkpoints with conflicting filenames""" + return self.scanner._hash_index.get_duplicate_filenames() \ No newline at end of file diff --git a/py/services/lora_service.py b/py/services/lora_service.py new file mode 100644 index 00000000..bcfa84c5 --- /dev/null +++ b/py/services/lora_service.py @@ -0,0 +1,172 @@ +import os +import logging +from typing import Dict, List, Optional + +from .base_model_service import BaseModelService +from ..utils.models import LoraMetadata +from ..config import config +from ..utils.routes_common import ModelRouteUtils + +logger = logging.getLogger(__name__) + +class LoraService(BaseModelService): + """LoRA-specific service implementation""" + + def __init__(self, scanner): + """Initialize LoRA service + + Args: + scanner: LoRA scanner instance + """ + super().__init__("lora", scanner, LoraMetadata) + + async def format_response(self, lora_data: Dict) -> Dict: + """Format LoRA data for API response""" + return { + "model_name": lora_data["model_name"], + "file_name": lora_data["file_name"], + "preview_url": config.get_preview_static_url(lora_data.get("preview_url", "")), + "preview_nsfw_level": lora_data.get("preview_nsfw_level", 0), + "base_model": lora_data.get("base_model", ""), + "folder": lora_data["folder"], + "sha256": lora_data.get("sha256", ""), + "file_path": lora_data["file_path"].replace(os.sep, "/"), + "file_size": lora_data.get("size", 0), + "modified": lora_data.get("modified", ""), + "tags": lora_data.get("tags", []), + "modelDescription": lora_data.get("modelDescription", ""), + "from_civitai": lora_data.get("from_civitai", True), + "usage_tips": lora_data.get("usage_tips", ""), + "notes": lora_data.get("notes", ""), + "favorite": lora_data.get("favorite", False), + "civitai": ModelRouteUtils.filter_civitai_data(lora_data.get("civitai", {})) + } + + async def _apply_specific_filters(self, data: List[Dict], **kwargs) -> List[Dict]: + """Apply LoRA-specific filters""" + # Handle first_letter filter for LoRAs + first_letter = kwargs.get('first_letter') + if first_letter: + data = self._filter_by_first_letter(data, first_letter) + + return data + + def _filter_by_first_letter(self, data: List[Dict], letter: str) -> List[Dict]: + """Filter LoRAs by first letter""" + if letter == '#': + # Filter for non-alphabetic characters + return [ + item for item in data + if not item.get('model_name', '')[0].isalpha() + ] + elif letter == 'CJK': + # Filter for CJK characters + return [ + item for item in data + if item.get('model_name', '') and self._is_cjk_character(item['model_name'][0]) + ] + else: + # Filter for specific letter + return [ + item for item in data + if item.get('model_name', '').lower().startswith(letter.lower()) + ] + + def _is_cjk_character(self, char: str) -> bool: + """Check if character is CJK (Chinese, Japanese, Korean)""" + cjk_ranges = [ + (0x4E00, 0x9FFF), # CJK Unified Ideographs + (0x3400, 0x4DBF), # CJK Extension A + (0x20000, 0x2A6DF), # CJK Extension B + (0x2A700, 0x2B73F), # CJK Extension C + (0x2B740, 0x2B81F), # CJK Extension D + (0x3040, 0x309F), # Hiragana + (0x30A0, 0x30FF), # Katakana + (0xAC00, 0xD7AF), # Hangul Syllables + ] + + char_code = ord(char) + return any(start <= char_code <= end for start, end in cjk_ranges) + + # LoRA-specific methods + async def get_letter_counts(self) -> Dict[str, int]: + """Get count of LoRAs for each letter of the alphabet""" + cache = await self.scanner.get_cached_data() + letter_counts = {} + + for lora in cache.raw_data: + model_name = lora.get('model_name', '') + if model_name: + first_char = model_name[0].upper() + if first_char.isalpha(): + letter_counts[first_char] = letter_counts.get(first_char, 0) + 1 + elif self._is_cjk_character(first_char): + letter_counts['CJK'] = letter_counts.get('CJK', 0) + 1 + else: + letter_counts['#'] = letter_counts.get('#', 0) + 1 + + return letter_counts + + async def get_lora_notes(self, lora_name: str) -> Optional[str]: + """Get notes for a specific LoRA file""" + cache = await self.scanner.get_cached_data() + + for lora in cache.raw_data: + if lora['file_name'] == lora_name: + return lora.get('notes', '') + + return None + + async def get_lora_trigger_words(self, lora_name: str) -> List[str]: + """Get trigger words for a specific LoRA file""" + cache = await self.scanner.get_cached_data() + + for lora in cache.raw_data: + if lora['file_name'] == lora_name: + civitai_data = lora.get('civitai', {}) + return civitai_data.get('trainedWords', []) + + return [] + + async def get_lora_preview_url(self, lora_name: str) -> Optional[str]: + """Get the static preview URL for a LoRA file""" + cache = await self.scanner.get_cached_data() + + for lora in cache.raw_data: + if lora['file_name'] == lora_name: + preview_url = lora.get('preview_url') + if preview_url: + return config.get_preview_static_url(preview_url) + + return None + + async def get_lora_civitai_url(self, lora_name: str) -> Dict[str, Optional[str]]: + """Get the Civitai URL for a LoRA file""" + cache = await self.scanner.get_cached_data() + + for lora in cache.raw_data: + if lora['file_name'] == lora_name: + civitai_data = lora.get('civitai', {}) + model_id = civitai_data.get('modelId') + version_id = civitai_data.get('id') + + if model_id: + civitai_url = f"https://civitai.com/models/{model_id}" + if version_id: + civitai_url += f"?modelVersionId={version_id}" + + return { + 'civitai_url': civitai_url, + 'model_id': str(model_id), + 'version_id': str(version_id) if version_id else None + } + + return {'civitai_url': None, 'model_id': None, 'version_id': None} + + def find_duplicate_hashes(self) -> Dict: + """Find LoRAs with duplicate SHA256 hashes""" + return self.scanner._hash_index.get_duplicate_hashes() + + def find_duplicate_filenames(self) -> Dict: + """Find LoRAs with conflicting filenames""" + return self.scanner._hash_index.get_duplicate_filenames() \ No newline at end of file diff --git a/py/services/model_service_factory.py b/py/services/model_service_factory.py new file mode 100644 index 00000000..6cc8a3a3 --- /dev/null +++ b/py/services/model_service_factory.py @@ -0,0 +1,137 @@ +from typing import Dict, Type, Any +import logging + +logger = logging.getLogger(__name__) + +class ModelServiceFactory: + """Factory for managing model services and routes""" + + _services: Dict[str, Type] = {} + _routes: Dict[str, Type] = {} + _initialized_services: Dict[str, Any] = {} + _initialized_routes: Dict[str, Any] = {} + + @classmethod + def register_model_type(cls, model_type: str, service_class: Type, route_class: Type): + """Register a new model type with its service and route classes + + Args: + model_type: The model type identifier (e.g., 'lora', 'checkpoint') + service_class: The service class for this model type + route_class: The route class for this model type + """ + cls._services[model_type] = service_class + cls._routes[model_type] = route_class + logger.info(f"Registered model type '{model_type}' with service {service_class.__name__} and routes {route_class.__name__}") + + @classmethod + def get_service_class(cls, model_type: str) -> Type: + """Get service class for a model type + + Args: + model_type: The model type identifier + + Returns: + The service class for the model type + + Raises: + ValueError: If model type is not registered + """ + if model_type not in cls._services: + raise ValueError(f"Unknown model type: {model_type}") + return cls._services[model_type] + + @classmethod + def get_route_class(cls, model_type: str) -> Type: + """Get route class for a model type + + Args: + model_type: The model type identifier + + Returns: + The route class for the model type + + Raises: + ValueError: If model type is not registered + """ + if model_type not in cls._routes: + raise ValueError(f"Unknown model type: {model_type}") + return cls._routes[model_type] + + @classmethod + def get_route_instance(cls, model_type: str): + """Get or create route instance for a model type + + Args: + model_type: The model type identifier + + Returns: + The route instance for the model type + """ + if model_type not in cls._initialized_routes: + route_class = cls.get_route_class(model_type) + cls._initialized_routes[model_type] = route_class() + return cls._initialized_routes[model_type] + + @classmethod + def setup_all_routes(cls, app): + """Setup routes for all registered model types + + Args: + app: The aiohttp application instance + """ + logger.info(f"Setting up routes for {len(cls._services)} registered model types") + + for model_type in cls._services.keys(): + try: + routes_instance = cls.get_route_instance(model_type) + routes_instance.setup_routes(app) + logger.info(f"Successfully set up routes for {model_type}") + except Exception as e: + logger.error(f"Failed to setup routes for {model_type}: {e}", exc_info=True) + + @classmethod + def get_registered_types(cls) -> list: + """Get list of all registered model types + + Returns: + List of registered model type identifiers + """ + return list(cls._services.keys()) + + @classmethod + def is_registered(cls, model_type: str) -> bool: + """Check if a model type is registered + + Args: + model_type: The model type identifier + + Returns: + True if the model type is registered, False otherwise + """ + return model_type in cls._services + + @classmethod + def clear_registrations(cls): + """Clear all registrations - mainly for testing purposes""" + cls._services.clear() + cls._routes.clear() + cls._initialized_services.clear() + cls._initialized_routes.clear() + logger.info("Cleared all model type registrations") + + +def register_default_model_types(): + """Register the default model types (LoRA and Checkpoint)""" + from ..services.lora_service import LoraService + from ..services.checkpoint_service import CheckpointService + from ..routes.lora_routes import LoraRoutes + from ..routes.checkpoint_routes import CheckpointRoutes + + # Register LoRA model type + ModelServiceFactory.register_model_type('lora', LoraService, LoraRoutes) + + # Register Checkpoint model type + ModelServiceFactory.register_model_type('checkpoint', CheckpointService, CheckpointRoutes) + + logger.info("Registered default model types: lora, checkpoint") \ No newline at end of file diff --git a/py/services/service_registry.py b/py/services/service_registry.py index 15c00a3f..9589b984 100644 --- a/py/services/service_registry.py +++ b/py/services/service_registry.py @@ -7,106 +7,176 @@ logger = logging.getLogger(__name__) T = TypeVar('T') # Define a type variable for service types class ServiceRegistry: - """Centralized registry for service singletons""" + """Central registry for managing singleton services""" - _instance = None _services: Dict[str, Any] = {} - _lock = asyncio.Lock() + _locks: Dict[str, asyncio.Lock] = {} @classmethod - def get_instance(cls): - """Get singleton instance of the registry""" - if cls._instance is None: - cls._instance = cls() - return cls._instance + async def register_service(cls, name: str, service: Any) -> None: + """Register a service instance with the registry + + Args: + name: Service name identifier + service: Service instance to register + """ + cls._services[name] = service + logger.debug(f"Registered service: {name}") @classmethod - async def register_service(cls, service_name: str, service_instance: Any) -> None: - """Register a service instance with the registry""" - registry = cls.get_instance() - async with cls._lock: - registry._services[service_name] = service_instance - logger.debug(f"Registered service: {service_name}") + async def get_service(cls, name: str) -> Optional[Any]: + """Get a service instance by name + + Args: + name: Service name identifier + + Returns: + Service instance or None if not found + """ + return cls._services.get(name) @classmethod - async def get_service(cls, service_name: str) -> Any: - """Get a service instance by name""" - registry = cls.get_instance() - async with cls._lock: - if service_name not in registry._services: - logger.debug(f"Service {service_name} not found in registry") - return None - return registry._services[service_name] + def _get_lock(cls, name: str) -> asyncio.Lock: + """Get or create a lock for a service + + Args: + name: Service name identifier + + Returns: + AsyncIO lock for the service + """ + if name not in cls._locks: + cls._locks[name] = asyncio.Lock() + return cls._locks[name] - @classmethod - def get_service_sync(cls, service_name: str) -> Any: - """Get a service instance by name (synchronous version)""" - registry = cls.get_instance() - if service_name not in registry._services: - logger.debug(f"Service {service_name} not found in registry") - return None - return registry._services[service_name] - - # Convenience methods for common services @classmethod async def get_lora_scanner(cls): - """Get the LoraScanner instance""" - from .lora_scanner import LoraScanner - scanner = await cls.get_service("lora_scanner") - if scanner is None: - scanner = await LoraScanner.get_instance() - await cls.register_service("lora_scanner", scanner) - return scanner + """Get or create LoRA scanner instance""" + service_name = "lora_scanner" + if service_name in cls._services: + return cls._services[service_name] + + async with cls._get_lock(service_name): + # Double-check after acquiring lock + if service_name in cls._services: + return cls._services[service_name] + + # Import here to avoid circular imports + from .lora_scanner import LoraScanner + + scanner = await LoraScanner.get_instance() + cls._services[service_name] = scanner + logger.info(f"Created and registered {service_name}") + return scanner + @classmethod async def get_checkpoint_scanner(cls): - """Get the CheckpointScanner instance""" - from .checkpoint_scanner import CheckpointScanner - scanner = await cls.get_service("checkpoint_scanner") - if scanner is None: + """Get or create Checkpoint scanner instance""" + service_name = "checkpoint_scanner" + + if service_name in cls._services: + return cls._services[service_name] + + async with cls._get_lock(service_name): + # Double-check after acquiring lock + if service_name in cls._services: + return cls._services[service_name] + + # Import here to avoid circular imports + from .checkpoint_scanner import CheckpointScanner + scanner = await CheckpointScanner.get_instance() - await cls.register_service("checkpoint_scanner", scanner) - return scanner - - @classmethod - async def get_civitai_client(cls): - """Get the CivitaiClient instance""" - from .civitai_client import CivitaiClient - client = await cls.get_service("civitai_client") - if client is None: - client = await CivitaiClient.get_instance() - await cls.register_service("civitai_client", client) - return client - - @classmethod - async def get_download_manager(cls): - """Get the DownloadManager instance""" - from .download_manager import DownloadManager - manager = await cls.get_service("download_manager") - if manager is None: - manager = await DownloadManager.get_instance() - await cls.register_service("download_manager", manager) - return manager - + cls._services[service_name] = scanner + logger.info(f"Created and registered {service_name}") + return scanner + @classmethod async def get_recipe_scanner(cls): - """Get the RecipeScanner instance""" - from .recipe_scanner import RecipeScanner - scanner = await cls.get_service("recipe_scanner") - if scanner is None: - lora_scanner = await cls.get_lora_scanner() - scanner = RecipeScanner(lora_scanner) - await cls.register_service("recipe_scanner", scanner) - return scanner - + """Get or create Recipe scanner instance""" + service_name = "recipe_scanner" + + if service_name in cls._services: + return cls._services[service_name] + + async with cls._get_lock(service_name): + # Double-check after acquiring lock + if service_name in cls._services: + return cls._services[service_name] + + # Import here to avoid circular imports + from .recipe_scanner import RecipeScanner + + scanner = await RecipeScanner.get_instance() + cls._services[service_name] = scanner + logger.info(f"Created and registered {service_name}") + return scanner + + @classmethod + async def get_civitai_client(cls): + """Get or create CivitAI client instance""" + service_name = "civitai_client" + + if service_name in cls._services: + return cls._services[service_name] + + async with cls._get_lock(service_name): + # Double-check after acquiring lock + if service_name in cls._services: + return cls._services[service_name] + + # Import here to avoid circular imports + from .civitai_client import CivitaiClient + + client = await CivitaiClient.get_instance() + cls._services[service_name] = client + logger.info(f"Created and registered {service_name}") + return client + + @classmethod + async def get_download_manager(cls): + """Get or create Download manager instance""" + service_name = "download_manager" + + if service_name in cls._services: + return cls._services[service_name] + + async with cls._get_lock(service_name): + # Double-check after acquiring lock + if service_name in cls._services: + return cls._services[service_name] + + # Import here to avoid circular imports + from .download_manager import DownloadManager + + manager = DownloadManager() + cls._services[service_name] = manager + logger.info(f"Created and registered {service_name}") + return manager + @classmethod async def get_websocket_manager(cls): - """Get the WebSocketManager instance""" - from .websocket_manager import ws_manager - manager = await cls.get_service("websocket_manager") - if manager is None: - # ws_manager is already a global instance in websocket_manager.py + """Get or create WebSocket manager instance""" + service_name = "websocket_manager" + + if service_name in cls._services: + return cls._services[service_name] + + async with cls._get_lock(service_name): + # Double-check after acquiring lock + if service_name in cls._services: + return cls._services[service_name] + + # Import here to avoid circular imports from .websocket_manager import ws_manager - await cls.register_service("websocket_manager", ws_manager) - manager = ws_manager - return manager \ No newline at end of file + + cls._services[service_name] = ws_manager + logger.info(f"Registered {service_name}") + return ws_manager + + @classmethod + def clear_services(cls): + """Clear all registered services - mainly for testing""" + cls._services.clear() + cls._locks.clear() + logger.info("Cleared all registered services") \ No newline at end of file diff --git a/py/utils/routes_common.py b/py/utils/routes_common.py index e151d216..8e5df544 100644 --- a/py/utils/routes_common.py +++ b/py/utils/routes_common.py @@ -1047,3 +1047,56 @@ class ModelRouteUtils: 'success': False, 'error': str(e) }, status=500) + + @staticmethod + async def handle_save_metadata(request: web.Request, scanner) -> web.Response: + """Handle saving metadata updates + + Args: + request: The aiohttp request + scanner: The model scanner instance + + Returns: + web.Response: The HTTP response + """ + try: + data = await request.json() + file_path = data.get('file_path') + if not file_path: + return web.Response(text='File path is required', status=400) + + # Remove file path from data to avoid saving it + metadata_updates = {k: v for k, v in data.items() if k != 'file_path'} + + # Get metadata file path + metadata_path = os.path.splitext(file_path)[0] + '.metadata.json' + + # Load existing metadata + metadata = await ModelRouteUtils.load_local_metadata(metadata_path) + + # Handle nested updates (for civitai.trainedWords) + for key, value in metadata_updates.items(): + if isinstance(value, dict) and key in metadata and isinstance(metadata[key], dict): + # Deep update for nested dictionaries + for nested_key, nested_value in value.items(): + metadata[key][nested_key] = nested_value + else: + # Regular update for top-level keys + metadata[key] = value + + # Save updated metadata + await MetadataManager.save_metadata(file_path, metadata) + + # Update cache + await scanner.update_single_model_cache(file_path, file_path, metadata) + + # If model_name was updated, resort the cache + if 'model_name' in metadata_updates: + cache = await scanner.get_cached_data() + await cache.resort(name_only=True) + + return web.json_response({'success': True}) + + except Exception as e: + logger.error(f"Error saving metadata: {e}", exc_info=True) + return web.Response(text=str(e), status=500) diff --git a/standalone.py b/standalone.py index 0120ab53..0e491c14 100644 --- a/standalone.py +++ b/standalone.py @@ -314,22 +314,23 @@ class StandaloneLoraManager(LoraManager): app.router.add_static('/loras_static', config.static_path) # Setup feature routes - from py.routes.lora_routes import LoraRoutes + from py.services.model_service_factory import ModelServiceFactory, register_default_model_types from py.routes.api_routes import ApiRoutes from py.routes.recipe_routes import RecipeRoutes - from py.routes.checkpoints_routes import CheckpointsRoutes from py.routes.update_routes import UpdateRoutes from py.routes.misc_routes import MiscRoutes from py.routes.example_images_routes import ExampleImagesRoutes from py.routes.stats_routes import StatsRoutes - lora_routes = LoraRoutes() - checkpoints_routes = CheckpointsRoutes() + + register_default_model_types() + + # Setup all model routes using the factory + ModelServiceFactory.setup_all_routes(app) + stats_routes = StatsRoutes() # Initialize routes - lora_routes.setup_routes(app) - checkpoints_routes.setup_routes(app) stats_routes.setup_routes(app) ApiRoutes.setup_routes(app) RecipeRoutes.setup_routes(app) From c2e00b240e138d2a06f18f1393db7b6c6473376d Mon Sep 17 00:00:00 2001 From: Will Miao <13051207myq@gmail.com> Date: Wed, 23 Jul 2025 15:30:39 +0800 Subject: [PATCH 02/18] feat: Enhance model routes with generic page handling and template integration --- py/routes/base_model_routes.py | 63 +++++++++++++++++++++++++++++++ py/routes/checkpoint_routes.py | 69 +--------------------------------- py/routes/lora_routes.py | 69 +--------------------------------- 3 files changed, 67 insertions(+), 134 deletions(-) diff --git a/py/routes/base_model_routes.py b/py/routes/base_model_routes.py index 7a6e4aac..8f5bc3f6 100644 --- a/py/routes/base_model_routes.py +++ b/py/routes/base_model_routes.py @@ -4,8 +4,12 @@ import logging from aiohttp import web from typing import Dict +import jinja2 + from ..utils.routes_common import ModelRouteUtils from ..services.websocket_manager import ws_manager +from ..services.settings_manager import settings +from ..config import config logger = logging.getLogger(__name__) @@ -20,6 +24,10 @@ class BaseModelRoutes(ABC): """ self.service = service self.model_type = service.model_type + self.template_env = jinja2.Environment( + loader=jinja2.FileSystemLoader(config.templates_path), + autoescape=True + ) def setup_routes(self, app: web.Application, prefix: str): """Setup common routes for the model type @@ -52,6 +60,9 @@ class BaseModelRoutes(ABC): app.router.add_post(f'/api/{prefix}/fetch-all-civitai', self.fetch_all_civitai) app.router.add_get(f'/api/civitai/versions/{{model_id}}', self.get_civitai_versions) + # Add generic page route + app.router.add_get(f'/{prefix}', self.handle_models_page) + # Setup model-specific routes self.setup_specific_routes(app, prefix) @@ -60,6 +71,58 @@ class BaseModelRoutes(ABC): """Setup model-specific routes - to be implemented by subclasses""" pass + async def handle_models_page(self, request: web.Request) -> web.Response: + """ + Generic handler for model pages (e.g., /loras, /checkpoints). + Subclasses should set self.template_env and template_name. + """ + try: + # Check if the scanner is initializing + is_initializing = ( + self.service.scanner._cache is None or + (hasattr(self.service.scanner, 'is_initializing') and callable(self.service.scanner.is_initializing) and self.service.scanner.is_initializing()) or + (hasattr(self.service.scanner, '_is_initializing') and self.service.scanner._is_initializing) + ) + + template_name = getattr(self, "template_name", None) + if not self.template_env or not template_name: + return web.Response(text="Template environment or template name not set", status=500) + + if is_initializing: + rendered = self.template_env.get_template(template_name).render( + folders=[], + is_initializing=True, + settings=settings, + request=request + ) + else: + try: + cache = await self.service.scanner.get_cached_data(force_refresh=False) + rendered = self.template_env.get_template(template_name).render( + folders=getattr(cache, "folders", []), + is_initializing=False, + settings=settings, + request=request + ) + except Exception as cache_error: + logger.error(f"Error loading cache data: {cache_error}") + rendered = self.template_env.get_template(template_name).render( + folders=[], + is_initializing=True, + settings=settings, + request=request + ) + return web.Response( + text=rendered, + content_type='text/html' + ) + except Exception as e: + logger.error(f"Error handling models page: {e}", exc_info=True) + return web.Response( + text="Error loading models page", + status=500 + ) + async def get_models(self, request: web.Request) -> web.Response: """Get paginated model data""" try: diff --git a/py/routes/checkpoint_routes.py b/py/routes/checkpoint_routes.py index 5fdb670e..6ba550a6 100644 --- a/py/routes/checkpoint_routes.py +++ b/py/routes/checkpoint_routes.py @@ -1,12 +1,9 @@ -import jinja2 import logging from aiohttp import web from .base_model_routes import BaseModelRoutes from ..services.checkpoint_service import CheckpointService from ..services.service_registry import ServiceRegistry -from ..config import config -from ..services.settings_manager import settings logger = logging.getLogger(__name__) @@ -18,10 +15,7 @@ class CheckpointRoutes(BaseModelRoutes): # Service will be initialized later via setup_routes self.service = None self.civitai_client = None - self.template_env = jinja2.Environment( - loader=jinja2.FileSystemLoader(config.templates_path), - autoescape=True - ) + self.template_name = "checkpoints.html" async def initialize_services(self): """Initialize services from ServiceRegistry""" @@ -37,76 +31,17 @@ class CheckpointRoutes(BaseModelRoutes): # Schedule service initialization on app startup app.on_startup.append(lambda _: self.initialize_services()) - # Setup common routes with 'checkpoints' prefix + # Setup common routes with 'checkpoints' prefix (includes page route) super().setup_routes(app, 'checkpoints') def setup_specific_routes(self, app: web.Application, prefix: str): """Setup Checkpoint-specific routes""" - # Checkpoint page route - app.router.add_get('/checkpoints', self.handle_checkpoints_page) - # Checkpoint-specific CivitAI integration app.router.add_get(f'/api/civitai/versions/{{model_id}}', self.get_civitai_versions_checkpoint) # Checkpoint info by name app.router.add_get(f'/api/{prefix}/info/{{name}}', self.get_checkpoint_info) - async def handle_checkpoints_page(self, request: web.Request) -> web.Response: - """Handle GET /checkpoints request""" - try: - # Check if the CheckpointScanner is initializing - # It's initializing if the cache object doesn't exist yet, - # OR if the scanner explicitly says it's initializing (background task running). - is_initializing = ( - self.service.scanner._cache is None or - (hasattr(self.service.scanner, '_is_initializing') and self.service.scanner._is_initializing) - ) - - if is_initializing: - # If still initializing, return loading page - template = self.template_env.get_template('checkpoints.html') - rendered = template.render( - folders=[], # Empty folder list - is_initializing=True, # New flag - settings=settings, # Pass settings to template - request=request # Pass the request object to the template - ) - - logger.info("Checkpoints page is initializing, returning loading page") - else: - # Normal flow - get initialized cache data - try: - cache = await self.service.scanner.get_cached_data(force_refresh=False) - template = self.template_env.get_template('checkpoints.html') - rendered = template.render( - folders=cache.folders, - is_initializing=False, - settings=settings, # Pass settings to template - request=request # Pass the request object to the template - ) - except Exception as cache_error: - logger.error(f"Error loading checkpoints cache data: {cache_error}") - # If getting cache fails, also show initialization page - template = self.template_env.get_template('checkpoints.html') - rendered = template.render( - folders=[], - is_initializing=True, - settings=settings, - request=request - ) - logger.info("Checkpoints cache error, returning initialization page") - - return web.Response( - text=rendered, - content_type='text/html' - ) - except Exception as e: - logger.error(f"Error handling checkpoints request: {e}", exc_info=True) - return web.Response( - text="Error loading checkpoints page", - status=500 - ) - async def get_checkpoint_info(self, request: web.Request) -> web.Response: """Get detailed information for a specific checkpoint by name""" try: diff --git a/py/routes/lora_routes.py b/py/routes/lora_routes.py index 5b0674ba..0a9d7dff 100644 --- a/py/routes/lora_routes.py +++ b/py/routes/lora_routes.py @@ -1,7 +1,5 @@ -import jinja2 import asyncio import logging -import os from aiohttp import web from typing import Dict from server import PromptServer # type: ignore @@ -9,8 +7,6 @@ from server import PromptServer # type: ignore from .base_model_routes import BaseModelRoutes from ..services.lora_service import LoraService from ..services.service_registry import ServiceRegistry -from ..services.settings_manager import settings -from ..config import config from ..utils.routes_common import ModelRouteUtils from ..utils.utils import get_lora_info @@ -26,10 +22,7 @@ class LoraRoutes(BaseModelRoutes): self.civitai_client = None self.download_manager = None self._download_lock = asyncio.Lock() - self.template_env = jinja2.Environment( - loader=jinja2.FileSystemLoader(config.templates_path), - autoescape=True - ) + self.template_name = "loras.html" async def initialize_services(self): """Initialize services from ServiceRegistry""" @@ -46,14 +39,11 @@ class LoraRoutes(BaseModelRoutes): # Schedule service initialization on app startup app.on_startup.append(lambda _: self.initialize_services()) - # Setup common routes with 'loras' prefix + # Setup common routes with 'loras' prefix (includes page route) super().setup_routes(app, 'loras') def setup_specific_routes(self, app: web.Application, prefix: str): """Setup LoRA-specific routes""" - # Lora page route - app.router.add_get('/loras', self.handle_loras_page) - # LoRA-specific query routes app.router.add_get(f'/api/{prefix}/letter-counts', self.get_letter_counts) app.router.add_get(f'/api/{prefix}/get-notes', self.get_lora_notes) @@ -113,61 +103,6 @@ class LoraRoutes(BaseModelRoutes): return params - async def handle_loras_page(self, request: web.Request) -> web.Response: - """Handle GET /loras request""" - try: - # Check if the LoraScanner is initializing - # It's initializing if the cache object doesn't exist yet, - # OR if the scanner explicitly says it's initializing (background task running). - is_initializing = ( - self.service.scanner._cache is None or self.service.scanner.is_initializing() - ) - - if is_initializing: - # If still initializing, return loading page - template = self.template_env.get_template('loras.html') - rendered = template.render( - folders=[], - is_initializing=True, - settings=settings, - request=request - ) - - logger.info("Loras page is initializing, returning loading page") - else: - # Normal flow - get data from initialized cache - try: - cache = await self.service.scanner.get_cached_data(force_refresh=False) - template = self.template_env.get_template('loras.html') - rendered = template.render( - folders=cache.folders, - is_initializing=False, - settings=settings, - request=request - ) - except Exception as cache_error: - logger.error(f"Error loading cache data: {cache_error}") - template = self.template_env.get_template('loras.html') - rendered = template.render( - folders=[], - is_initializing=True, - settings=settings, - request=request - ) - logger.info("Cache error, returning initialization page") - - return web.Response( - text=rendered, - content_type='text/html' - ) - - except Exception as e: - logger.error(f"Error handling loras request: {e}", exc_info=True) - return web.Response( - text="Error loading loras page", - status=500 - ) - # LoRA-specific route handlers async def get_letter_counts(self, request: web.Request) -> web.Response: """Get count of LoRAs for each letter of the alphabet""" From ea9370443d2de82138258bc8b5bed9fbef84bf40 Mon Sep 17 00:00:00 2001 From: Will Miao <13051207myq@gmail.com> Date: Wed, 23 Jul 2025 16:11:02 +0800 Subject: [PATCH 03/18] refactor: Implement download management routes and update API endpoints for LoRA --- py/routes/base_model_routes.py | 113 ++++++++++++++++ py/routes/lora_routes.py | 121 ------------------ py/utils/routes_common.py | 11 +- static/js/api/baseModelApi.js | 6 +- static/js/api/loraApi.js | 2 +- .../ContextMenu/ModelContextMenuMixin.js | 2 +- 6 files changed, 124 insertions(+), 131 deletions(-) diff --git a/py/routes/base_model_routes.py b/py/routes/base_model_routes.py index 8f5bc3f6..fc512d0c 100644 --- a/py/routes/base_model_routes.py +++ b/py/routes/base_model_routes.py @@ -1,4 +1,5 @@ from abc import ABC, abstractmethod +import asyncio import json import logging from aiohttp import web @@ -7,6 +8,7 @@ from typing import Dict import jinja2 from ..utils.routes_common import ModelRouteUtils +from ..services.service_registry import ServiceRegistry from ..services.websocket_manager import ws_manager from ..services.settings_manager import settings from ..config import config @@ -55,6 +57,12 @@ class BaseModelRoutes(ABC): app.router.add_get(f'/api/{prefix}/roots', self.get_model_roots) app.router.add_get(f'/api/{prefix}/find-duplicates', self.find_duplicate_models) app.router.add_get(f'/api/{prefix}/find-filename-conflicts', self.find_filename_conflicts) + + # Common Download management + app.router.add_post(f'/api/download-model', self.download_model) + app.router.add_get(f'/api/download-model-get', self.download_model_get) + app.router.add_get(f'/api/cancel-download-get', self.cancel_download_get) + app.router.add_get(f'/api/download-progress/{{download_id}}', self.get_download_progress) # CivitAI integration routes app.router.add_post(f'/api/{prefix}/fetch-all-civitai', self.fetch_all_civitai) @@ -408,6 +416,111 @@ class BaseModelRoutes(ABC): "success": False, "error": str(e) }, status=500) + + # Download management methods + async def download_model(self, request: web.Request) -> web.Response: + """Handle model download request""" + return await ModelRouteUtils.handle_download_model(request) + + async def download_model_get(self, request: web.Request) -> web.Response: + """Handle model download request via GET method""" + try: + # Extract query parameters + model_id = request.query.get('model_id') + if not model_id: + return web.Response( + status=400, + text="Missing required parameter: Please provide 'model_id'" + ) + + # Get optional parameters + model_version_id = request.query.get('model_version_id') + download_id = request.query.get('download_id') + use_default_paths = request.query.get('use_default_paths', 'false').lower() == 'true' + + # Create a data dictionary that mimics what would be received from a POST request + data = { + 'model_id': model_id + } + + # Add optional parameters only if they are provided + if model_version_id: + data['model_version_id'] = model_version_id + + if download_id: + data['download_id'] = download_id + + data['use_default_paths'] = use_default_paths + + # Create a mock request object with the data + future = asyncio.get_event_loop().create_future() + future.set_result(data) + + mock_request = type('MockRequest', (), { + 'json': lambda self=None: future + })() + + # Call the existing download handler + return await ModelRouteUtils.handle_download_model(mock_request) + + except Exception as e: + error_message = str(e) + logger.error(f"Error downloading model via GET: {error_message}", exc_info=True) + return web.Response(status=500, text=error_message) + + async def cancel_download_get(self, request: web.Request) -> web.Response: + """Handle GET request for cancelling a download by download_id""" + try: + download_id = request.query.get('download_id') + if not download_id: + return web.json_response({ + 'success': False, + 'error': 'Download ID is required' + }, status=400) + + # Create a mock request with match_info for compatibility + mock_request = type('MockRequest', (), { + 'match_info': {'download_id': download_id} + })() + return await ModelRouteUtils.handle_cancel_download(mock_request) + except Exception as e: + logger.error(f"Error cancelling download via GET: {e}", exc_info=True) + return web.json_response({ + 'success': False, + 'error': str(e) + }, status=500) + + async def get_download_progress(self, request: web.Request) -> web.Response: + """Handle request for download progress by download_id""" + try: + # Get download_id from URL path + download_id = request.match_info.get('download_id') + if not download_id: + return web.json_response({ + 'success': False, + 'error': 'Download ID is required' + }, status=400) + + # Get progress information from websocket manager + from ..services.websocket_manager import ws_manager + progress_data = ws_manager.get_download_progress(download_id) + + if progress_data is None: + return web.json_response({ + 'success': False, + 'error': 'Download ID not found' + }, status=404) + + return web.json_response({ + 'success': True, + 'progress': progress_data.get('progress', 0) + }) + except Exception as e: + logger.error(f"Error getting download progress: {e}", exc_info=True) + return web.json_response({ + 'success': False, + 'error': str(e) + }, status=500) async def fetch_all_civitai(self, request: web.Request) -> web.Response: """Fetch CivitAI metadata for all models in the background""" diff --git a/py/routes/lora_routes.py b/py/routes/lora_routes.py index 0a9d7dff..fde5586b 100644 --- a/py/routes/lora_routes.py +++ b/py/routes/lora_routes.py @@ -20,8 +20,6 @@ class LoraRoutes(BaseModelRoutes): # Service will be initialized later via setup_routes self.service = None self.civitai_client = None - self.download_manager = None - self._download_lock = asyncio.Lock() self.template_name = "loras.html" async def initialize_services(self): @@ -29,7 +27,6 @@ class LoraRoutes(BaseModelRoutes): lora_scanner = await ServiceRegistry.get_lora_scanner() self.service = LoraService(lora_scanner) self.civitai_client = await ServiceRegistry.get_civitai_client() - self.download_manager = await ServiceRegistry.get_download_manager() # Initialize parent with the service super().__init__(self.service) @@ -63,21 +60,8 @@ class LoraRoutes(BaseModelRoutes): app.router.add_get(f'/api/civitai/model/version/{{modelVersionId}}', self.get_civitai_model_by_version) app.router.add_get(f'/api/civitai/model/hash/{{hash}}', self.get_civitai_model_by_hash) - # Download management - app.router.add_post(f'/api/download-model', self.download_model) - app.router.add_get(f'/api/download-model-get', self.download_model_get) - app.router.add_get(f'/api/cancel-download-get', self.cancel_download_get) - app.router.add_get(f'/api/download-progress/{{download_id}}', self.get_download_progress) - # ComfyUI integration app.router.add_post(f'/loramanager/get_trigger_words', self.get_trigger_words) - - # Legacy API compatibility - app.router.add_post(f'/api/delete_model', self.delete_model) - app.router.add_post(f'/api/fetch-civitai', self.fetch_civitai) - app.router.add_post(f'/api/relink-civitai', self.relink_civitai) - app.router.add_post(f'/api/replace_preview', self.replace_preview) - app.router.add_post(f'/api/fetch-all-civitai', self.fetch_all_civitai) def _parse_specific_params(self, request: web.Request) -> Dict: """Parse LoRA-specific parameters""" @@ -358,111 +342,6 @@ class LoraRoutes(BaseModelRoutes): "error": str(e) }, status=500) - # Download management methods - async def download_model(self, request: web.Request) -> web.Response: - """Handle model download request""" - return await ModelRouteUtils.handle_download_model(request, self.download_manager) - - async def download_model_get(self, request: web.Request) -> web.Response: - """Handle model download request via GET method""" - try: - # Extract query parameters - model_id = request.query.get('model_id') - if not model_id: - return web.Response( - status=400, - text="Missing required parameter: Please provide 'model_id'" - ) - - # Get optional parameters - model_version_id = request.query.get('model_version_id') - download_id = request.query.get('download_id') - use_default_paths = request.query.get('use_default_paths', 'false').lower() == 'true' - - # Create a data dictionary that mimics what would be received from a POST request - data = { - 'model_id': model_id - } - - # Add optional parameters only if they are provided - if model_version_id: - data['model_version_id'] = model_version_id - - if download_id: - data['download_id'] = download_id - - data['use_default_paths'] = use_default_paths - - # Create a mock request object with the data - future = asyncio.get_event_loop().create_future() - future.set_result(data) - - mock_request = type('MockRequest', (), { - 'json': lambda self=None: future - })() - - # Call the existing download handler - return await ModelRouteUtils.handle_download_model(mock_request, self.download_manager) - - except Exception as e: - error_message = str(e) - logger.error(f"Error downloading model via GET: {error_message}", exc_info=True) - return web.Response(status=500, text=error_message) - - async def cancel_download_get(self, request: web.Request) -> web.Response: - """Handle GET request for cancelling a download by download_id""" - try: - download_id = request.query.get('download_id') - if not download_id: - return web.json_response({ - 'success': False, - 'error': 'Download ID is required' - }, status=400) - - # Create a mock request with match_info for compatibility - mock_request = type('MockRequest', (), { - 'match_info': {'download_id': download_id} - })() - return await ModelRouteUtils.handle_cancel_download(mock_request, self.download_manager) - except Exception as e: - logger.error(f"Error cancelling download via GET: {e}", exc_info=True) - return web.json_response({ - 'success': False, - 'error': str(e) - }, status=500) - - async def get_download_progress(self, request: web.Request) -> web.Response: - """Handle request for download progress by download_id""" - try: - # Get download_id from URL path - download_id = request.match_info.get('download_id') - if not download_id: - return web.json_response({ - 'success': False, - 'error': 'Download ID is required' - }, status=400) - - # Get progress information from websocket manager - from ..services.websocket_manager import ws_manager - progress_data = ws_manager.get_download_progress(download_id) - - if progress_data is None: - return web.json_response({ - 'success': False, - 'error': 'Download ID not found' - }, status=404) - - return web.json_response({ - 'success': True, - 'progress': progress_data.get('progress', 0) - }) - except Exception as e: - logger.error(f"Error getting download progress: {e}", exc_info=True) - return web.json_response({ - 'success': False, - 'error': str(e) - }, status=500) - # Model management methods async def move_model(self, request: web.Request) -> web.Response: """Handle model move request""" diff --git a/py/utils/routes_common.py b/py/utils/routes_common.py index 8e5df544..2704b20f 100644 --- a/py/utils/routes_common.py +++ b/py/utils/routes_common.py @@ -566,9 +566,10 @@ class ModelRouteUtils: return web.Response(text=str(e), status=500) @staticmethod - async def handle_download_model(request: web.Request, download_manager: DownloadManager) -> web.Response: + async def handle_download_model(request: web.Request) -> web.Response: """Handle model download request""" try: + download_manager = await ServiceRegistry.get_download_manager() data = await request.json() # Get or generate a download ID @@ -663,17 +664,17 @@ class ModelRouteUtils: }, status=500) @staticmethod - async def handle_cancel_download(request: web.Request, download_manager: DownloadManager) -> web.Response: + async def handle_cancel_download(request: web.Request) -> web.Response: """Handle cancellation of a download task Args: request: The aiohttp request - download_manager: The download manager instance Returns: web.Response: The HTTP response """ try: + download_manager = await ServiceRegistry.get_download_manager() download_id = request.match_info.get('download_id') if not download_id: return web.json_response({ @@ -701,17 +702,17 @@ class ModelRouteUtils: }, status=500) @staticmethod - async def handle_list_downloads(request: web.Request, download_manager: DownloadManager) -> web.Response: + async def handle_list_downloads(request: web.Request) -> web.Response: """Get list of active downloads Args: request: The aiohttp request - download_manager: The download manager instance Returns: web.Response: The HTTP response with list of downloads """ try: + download_manager = await ServiceRegistry.get_download_manager() result = await download_manager.get_active_downloads() return web.json_response(result) except Exception as e: diff --git a/static/js/api/baseModelApi.js b/static/js/api/baseModelApi.js index 43abbedf..e15898a5 100644 --- a/static/js/api/baseModelApi.js +++ b/static/js/api/baseModelApi.js @@ -277,7 +277,7 @@ export async function deleteModel(filePath, modelType = 'lora') { const endpoint = modelType === 'checkpoint' ? '/api/checkpoints/delete' - : '/api/delete_model'; + : '/api/loras/delete'; const response = await fetch(endpoint, { method: 'POST', @@ -454,7 +454,7 @@ export async function refreshSingleModelMetadata(filePath, modelType = 'lora') { const endpoint = modelType === 'checkpoint' ? '/api/checkpoints/fetch-civitai' - : '/api/fetch-civitai'; + : '/api/loras/fetch-civitai'; const response = await fetch(endpoint, { method: 'POST', @@ -557,7 +557,7 @@ export async function uploadPreview(filePath, file, modelType = 'lora', nsfwLeve // Set endpoint based on model type const endpoint = modelType === 'checkpoint' ? '/api/checkpoints/replace-preview' - : '/api/replace_preview'; + : '/api/loras/replace_preview'; const response = await fetch(endpoint, { method: 'POST', diff --git a/static/js/api/loraApi.js b/static/js/api/loraApi.js index 3c6f0c86..9d5dd1bc 100644 --- a/static/js/api/loraApi.js +++ b/static/js/api/loraApi.js @@ -89,7 +89,7 @@ export async function fetchLorasPage(page = 1, pageSize = 100) { export async function fetchCivitai() { return fetchCivitaiMetadata({ modelType: 'lora', - fetchEndpoint: '/api/fetch-all-civitai', + fetchEndpoint: '/api/loras/fetch-all-civitai', resetAndReloadFunction: resetAndReload }); } diff --git a/static/js/components/ContextMenu/ModelContextMenuMixin.js b/static/js/components/ContextMenu/ModelContextMenuMixin.js index 7c0bc0c3..cd58bc0f 100644 --- a/static/js/components/ContextMenu/ModelContextMenuMixin.js +++ b/static/js/components/ContextMenu/ModelContextMenuMixin.js @@ -125,7 +125,7 @@ export const ModelContextMenuMixin = { const endpoint = this.modelType === 'checkpoint' ? '/api/checkpoints/relink-civitai' : - '/api/relink-civitai'; + '/api/loras/relink-civitai'; const response = await fetch(endpoint, { method: 'POST', From a9a7f4c8ecf2c68371d59ee6e100d078571a3996 Mon Sep 17 00:00:00 2001 From: Will Miao <13051207myq@gmail.com> Date: Wed, 23 Jul 2025 16:30:00 +0800 Subject: [PATCH 04/18] refactor: Remove legacy API route handlers from standalone manager --- py/routes/api_routes.py | 37 ------------------------------------- standalone.py | 3 --- 2 files changed, 40 deletions(-) delete mode 100644 py/routes/api_routes.py diff --git a/py/routes/api_routes.py b/py/routes/api_routes.py deleted file mode 100644 index 5c9c93fe..00000000 --- a/py/routes/api_routes.py +++ /dev/null @@ -1,37 +0,0 @@ -import logging -from aiohttp import web - -from ..services.websocket_manager import ws_manager -from .update_routes import UpdateRoutes -from .lora_routes import LoraRoutes - -logger = logging.getLogger(__name__) - -class ApiRoutes: - """Legacy API route handlers for backward compatibility""" - - def __init__(self): - # Initialize the new LoRA routes - self.lora_routes = LoraRoutes() - - @classmethod - def setup_routes(cls, app: web.Application): - """Register API routes using the new refactored architecture""" - routes = cls() - - # Setup the refactored LoRA routes - routes.lora_routes.setup_routes(app) - - # Setup WebSocket routes that are still shared - app.router.add_get('/ws/fetch-progress', ws_manager.handle_connection) - app.router.add_get('/ws/download-progress', ws_manager.handle_download_connection) - app.router.add_get('/ws/init-progress', ws_manager.handle_init_connection) - - # Setup update routes that are not model-specific - UpdateRoutes.setup_routes(app) - - @classmethod - async def cleanup(cls): - """Add cleanup method for application shutdown""" - # Cleanup is now handled by ServiceRegistry and individual services - pass diff --git a/standalone.py b/standalone.py index 0e491c14..a809aaad 100644 --- a/standalone.py +++ b/standalone.py @@ -315,7 +315,6 @@ class StandaloneLoraManager(LoraManager): # Setup feature routes from py.services.model_service_factory import ModelServiceFactory, register_default_model_types - from py.routes.api_routes import ApiRoutes from py.routes.recipe_routes import RecipeRoutes from py.routes.update_routes import UpdateRoutes from py.routes.misc_routes import MiscRoutes @@ -332,7 +331,6 @@ class StandaloneLoraManager(LoraManager): # Initialize routes stats_routes.setup_routes(app) - ApiRoutes.setup_routes(app) RecipeRoutes.setup_routes(app) UpdateRoutes.setup_routes(app) MiscRoutes.setup_routes(app) @@ -343,7 +341,6 @@ class StandaloneLoraManager(LoraManager): # Add cleanup app.on_shutdown.append(cls._cleanup) - app.on_shutdown.append(ApiRoutes.cleanup) def parse_args(): """Parse command line arguments""" From 2c6c9542dd12a9ab23e4efac34a637b6e91e675c Mon Sep 17 00:00:00 2001 From: Will Miao <13051207myq@gmail.com> Date: Wed, 23 Jul 2025 16:59:16 +0800 Subject: [PATCH 05/18] refactor: Change logging level from info to debug for service registration --- py/services/service_registry.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/py/services/service_registry.py b/py/services/service_registry.py index 9589b984..6cefb4d4 100644 --- a/py/services/service_registry.py +++ b/py/services/service_registry.py @@ -67,7 +67,7 @@ class ServiceRegistry: scanner = await LoraScanner.get_instance() cls._services[service_name] = scanner - logger.info(f"Created and registered {service_name}") + logger.debug(f"Created and registered {service_name}") return scanner @classmethod @@ -88,7 +88,7 @@ class ServiceRegistry: scanner = await CheckpointScanner.get_instance() cls._services[service_name] = scanner - logger.info(f"Created and registered {service_name}") + logger.debug(f"Created and registered {service_name}") return scanner @classmethod @@ -109,7 +109,7 @@ class ServiceRegistry: scanner = await RecipeScanner.get_instance() cls._services[service_name] = scanner - logger.info(f"Created and registered {service_name}") + logger.debug(f"Created and registered {service_name}") return scanner @classmethod @@ -130,7 +130,7 @@ class ServiceRegistry: client = await CivitaiClient.get_instance() cls._services[service_name] = client - logger.info(f"Created and registered {service_name}") + logger.debug(f"Created and registered {service_name}") return client @classmethod @@ -151,7 +151,7 @@ class ServiceRegistry: manager = DownloadManager() cls._services[service_name] = manager - logger.info(f"Created and registered {service_name}") + logger.debug(f"Created and registered {service_name}") return manager @classmethod @@ -171,7 +171,7 @@ class ServiceRegistry: from .websocket_manager import ws_manager cls._services[service_name] = ws_manager - logger.info(f"Registered {service_name}") + logger.debug(f"Registered {service_name}") return ws_manager @classmethod From a834fc4b3063eaf3426f0daed6e4cd34c1ad152a Mon Sep 17 00:00:00 2001 From: Will Miao <13051207myq@gmail.com> Date: Wed, 23 Jul 2025 17:26:06 +0800 Subject: [PATCH 06/18] feat: Update API routes for LoRA management and enhance folder handling --- py/routes/base_model_routes.py | 17 +++++++++++- py/routes/checkpoint_routes.py | 2 +- py/routes/lora_routes.py | 31 +--------------------- static/js/managers/DownloadManager.js | 6 ++--- static/js/managers/MoveManager.js | 6 ++--- static/js/managers/SettingsManager.js | 2 +- static/js/managers/import/FolderBrowser.js | 4 +-- 7 files changed, 27 insertions(+), 41 deletions(-) diff --git a/py/routes/base_model_routes.py b/py/routes/base_model_routes.py index fc512d0c..d438d69e 100644 --- a/py/routes/base_model_routes.py +++ b/py/routes/base_model_routes.py @@ -55,6 +55,7 @@ class BaseModelRoutes(ABC): app.router.add_get(f'/api/{prefix}/base-models', self.get_base_models) app.router.add_get(f'/api/{prefix}/scan', self.scan_models) app.router.add_get(f'/api/{prefix}/roots', self.get_model_roots) + app.router.add_get(f'/api/{prefix}/folders', self.get_folders) app.router.add_get(f'/api/{prefix}/find-duplicates', self.find_duplicate_models) app.router.add_get(f'/api/{prefix}/find-filename-conflicts', self.find_filename_conflicts) @@ -66,7 +67,7 @@ class BaseModelRoutes(ABC): # CivitAI integration routes app.router.add_post(f'/api/{prefix}/fetch-all-civitai', self.fetch_all_civitai) - app.router.add_get(f'/api/civitai/versions/{{model_id}}', self.get_civitai_versions) + # app.router.add_get(f'/api/civitai/versions/{{model_id}}', self.get_civitai_versions) # Add generic page route app.router.add_get(f'/{prefix}', self.handle_models_page) @@ -328,6 +329,20 @@ class BaseModelRoutes(ABC): "success": False, "error": str(e) }, status=500) + + async def get_folders(self, request: web.Request) -> web.Response: + """Get all folders in the cache""" + try: + cache = await self.service.scanner.get_cached_data() + return web.json_response({ + 'folders': cache.folders + }) + except Exception as e: + logger.error(f"Error getting folders: {e}") + return web.json_response({ + 'success': False, + 'error': str(e) + }, status=500) async def find_duplicate_models(self, request: web.Request) -> web.Response: """Find models with duplicate SHA256 hashes""" diff --git a/py/routes/checkpoint_routes.py b/py/routes/checkpoint_routes.py index 6ba550a6..4f27115e 100644 --- a/py/routes/checkpoint_routes.py +++ b/py/routes/checkpoint_routes.py @@ -37,7 +37,7 @@ class CheckpointRoutes(BaseModelRoutes): def setup_specific_routes(self, app: web.Application, prefix: str): """Setup Checkpoint-specific routes""" # Checkpoint-specific CivitAI integration - app.router.add_get(f'/api/civitai/versions/{{model_id}}', self.get_civitai_versions_checkpoint) + app.router.add_get(f'/api/{prefix}/civitai/versions/{{model_id}}', self.get_civitai_versions_checkpoint) # Checkpoint info by name app.router.add_get(f'/api/{prefix}/info/{{name}}', self.get_checkpoint_info) diff --git a/py/routes/lora_routes.py b/py/routes/lora_routes.py index fde5586b..ac7dc4ed 100644 --- a/py/routes/lora_routes.py +++ b/py/routes/lora_routes.py @@ -48,15 +48,13 @@ class LoraRoutes(BaseModelRoutes): app.router.add_get(f'/api/lora-preview-url', self.get_lora_preview_url) app.router.add_get(f'/api/lora-civitai-url', self.get_lora_civitai_url) app.router.add_get(f'/api/lora-model-description', self.get_lora_model_description) - app.router.add_get(f'/api/folders', self.get_folders) - app.router.add_get(f'/api/lora-roots', self.get_lora_roots) # LoRA-specific management routes app.router.add_post(f'/api/move_model', self.move_model) app.router.add_post(f'/api/move_models_bulk', self.move_models_bulk) # CivitAI integration with LoRA-specific validation - app.router.add_get(f'/api/civitai/versions/{{model_id}}', self.get_civitai_versions_lora) + app.router.add_get(f'/api/{prefix}/civitai/versions/{{model_id}}', self.get_civitai_versions_lora) app.router.add_get(f'/api/civitai/model/version/{{modelVersionId}}', self.get_civitai_model_by_version) app.router.add_get(f'/api/civitai/model/hash/{{hash}}', self.get_civitai_model_by_hash) @@ -201,33 +199,6 @@ class LoraRoutes(BaseModelRoutes): 'error': str(e) }, status=500) - async def get_folders(self, request: web.Request) -> web.Response: - """Get all folders in the cache""" - try: - cache = await self.service.scanner.get_cached_data() - return web.json_response({ - 'folders': cache.folders - }) - except Exception as e: - logger.error(f"Error getting folders: {e}") - return web.json_response({ - 'success': False, - 'error': str(e) - }, status=500) - - async def get_lora_roots(self, request: web.Request) -> web.Response: - """Get all configured LoRA root directories""" - try: - return web.json_response({ - 'roots': self.service.get_model_roots() - }) - except Exception as e: - logger.error(f"Error getting LoRA roots: {e}") - return web.json_response({ - 'success': False, - 'error': str(e) - }, status=500) - # Override get_models to add LoRA-specific response data async def get_models(self, request: web.Request) -> web.Response: """Get paginated LoRA data with LoRA-specific fields""" diff --git a/static/js/managers/DownloadManager.js b/static/js/managers/DownloadManager.js index 7584bf0f..b3d0433e 100644 --- a/static/js/managers/DownloadManager.js +++ b/static/js/managers/DownloadManager.js @@ -87,7 +87,7 @@ export class DownloadManager { throw new Error('Invalid Civitai URL format'); } - const response = await fetch(`/api/civitai/versions/${this.modelId}`); + const response = await fetch(`/api/loras/civitai/versions/${this.modelId}`); if (!response.ok) { const errorData = await response.json().catch(() => ({})); if (errorData && errorData.error && errorData.error.includes('Model type mismatch')) { @@ -254,7 +254,7 @@ export class DownloadManager { try { // Fetch LoRA roots - const rootsResponse = await fetch('/api/lora-roots'); + const rootsResponse = await fetch('/api/loras/roots'); if (!rootsResponse.ok) { throw new Error('Failed to fetch LoRA roots'); } @@ -272,7 +272,7 @@ export class DownloadManager { } // Fetch folders dynamically - const foldersResponse = await fetch('/api/folders'); + const foldersResponse = await fetch('/api/loras/folders'); if (!foldersResponse.ok) { throw new Error('Failed to fetch folders'); } diff --git a/static/js/managers/MoveManager.js b/static/js/managers/MoveManager.js index 45e6a011..e72c8274 100644 --- a/static/js/managers/MoveManager.js +++ b/static/js/managers/MoveManager.js @@ -74,7 +74,7 @@ class MoveManager { try { // Fetch LoRA roots - const rootsResponse = await fetch('/api/lora-roots'); + const rootsResponse = await fetch('/api/loras/roots'); if (!rootsResponse.ok) { throw new Error('Failed to fetch LoRA roots'); } @@ -96,7 +96,7 @@ class MoveManager { } // Fetch folders dynamically - const foldersResponse = await fetch('/api/folders'); + const foldersResponse = await fetch('/api/loras/folders'); if (!foldersResponse.ok) { throw new Error('Failed to fetch folders'); } @@ -190,7 +190,7 @@ class MoveManager { // Refresh folder tags after successful move try { - const foldersResponse = await fetch('/api/folders'); + const foldersResponse = await fetch('/api/loras/folders'); if (foldersResponse.ok) { const foldersData = await foldersResponse.json(); updateFolderTags(foldersData.folders); diff --git a/static/js/managers/SettingsManager.js b/static/js/managers/SettingsManager.js index 828cb461..c0f0f662 100644 --- a/static/js/managers/SettingsManager.js +++ b/static/js/managers/SettingsManager.js @@ -161,7 +161,7 @@ export class SettingsManager { if (!defaultLoraRootSelect) return; // Fetch lora roots - const response = await fetch('/api/lora-roots'); + const response = await fetch('/api/loras/roots'); if (!response.ok) { throw new Error('Failed to fetch LoRA roots'); } diff --git a/static/js/managers/import/FolderBrowser.js b/static/js/managers/import/FolderBrowser.js index 33504ee8..edd4d6a4 100644 --- a/static/js/managers/import/FolderBrowser.js +++ b/static/js/managers/import/FolderBrowser.js @@ -99,7 +99,7 @@ export class FolderBrowser { } // Fetch LoRA roots - const rootsResponse = await fetch('/api/lora-roots'); + const rootsResponse = await fetch('/api/loras/roots'); if (!rootsResponse.ok) { throw new Error(`Failed to fetch LoRA roots: ${rootsResponse.status}`); } @@ -119,7 +119,7 @@ export class FolderBrowser { } // Fetch folders - const foldersResponse = await fetch('/api/folders'); + const foldersResponse = await fetch('/api/loras/folders'); if (!foldersResponse.ok) { throw new Error(`Failed to fetch folders: ${foldersResponse.status}`); } From 298a95432da41e98a3975db9708338d656e3ddf4 Mon Sep 17 00:00:00 2001 From: Will Miao <13051207myq@gmail.com> Date: Wed, 23 Jul 2025 18:02:38 +0800 Subject: [PATCH 07/18] feat: Integrate WebSocket routes for download progress tracking in standalone manager --- py/routes/base_model_routes.py | 3 --- standalone.py | 6 ++++++ 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/py/routes/base_model_routes.py b/py/routes/base_model_routes.py index d438d69e..8cee15f3 100644 --- a/py/routes/base_model_routes.py +++ b/py/routes/base_model_routes.py @@ -8,7 +8,6 @@ from typing import Dict import jinja2 from ..utils.routes_common import ModelRouteUtils -from ..services.service_registry import ServiceRegistry from ..services.websocket_manager import ws_manager from ..services.settings_manager import settings from ..config import config @@ -516,8 +515,6 @@ class BaseModelRoutes(ABC): 'error': 'Download ID is required' }, status=400) - # Get progress information from websocket manager - from ..services.websocket_manager import ws_manager progress_data = ws_manager.get_download_progress(download_id) if progress_data is None: diff --git a/standalone.py b/standalone.py index a809aaad..2ad50ade 100644 --- a/standalone.py +++ b/standalone.py @@ -320,6 +320,7 @@ class StandaloneLoraManager(LoraManager): from py.routes.misc_routes import MiscRoutes from py.routes.example_images_routes import ExampleImagesRoutes from py.routes.stats_routes import StatsRoutes + from py.services.websocket_manager import ws_manager register_default_model_types() @@ -335,6 +336,11 @@ class StandaloneLoraManager(LoraManager): UpdateRoutes.setup_routes(app) MiscRoutes.setup_routes(app) ExampleImagesRoutes.setup_routes(app) + + # Setup WebSocket routes that are shared across all model types + app.router.add_get('/ws/fetch-progress', ws_manager.handle_connection) + app.router.add_get('/ws/download-progress', ws_manager.handle_download_connection) + app.router.add_get('/ws/init-progress', ws_manager.handle_init_connection) # Schedule service initialization app.on_startup.append(lambda app: cls._initialize_services()) From 804808da4a7e83c1f9de65f4accee7bc14b4cd7e Mon Sep 17 00:00:00 2001 From: Will Miao <13051207myq@gmail.com> Date: Wed, 23 Jul 2025 22:09:42 +0800 Subject: [PATCH 08/18] refactor: Update logging configuration to use asyncio logger and remove aiohttp access logger references --- py/lora_manager.py | 7 ++----- standalone.py | 12 +++--------- static/js/utils/VirtualScroller.js | 17 ----------------- 3 files changed, 5 insertions(+), 31 deletions(-) diff --git a/py/lora_manager.py b/py/lora_manager.py index dc54d8f4..1cb2f712 100644 --- a/py/lora_manager.py +++ b/py/lora_manager.py @@ -30,8 +30,8 @@ class LoraManager: """Initialize and register all routes using the new refactored architecture""" app = PromptServer.instance.app - # Configure aiohttp access logger to be less verbose - logging.getLogger('aiohttp.access').setLevel(logging.WARNING) + # Configure aiohttp logger to be less verbose + logging.getLogger("asyncio").setLevel(logging.WARNING) added_targets = set() # Track already added target paths @@ -140,9 +140,6 @@ class LoraManager: async def _initialize_services(cls): """Initialize all services using the ServiceRegistry""" try: - # Ensure aiohttp access logger is configured with reduced verbosity - logging.getLogger('aiohttp.access').setLevel(logging.WARNING) - # Initialize CivitaiClient first to ensure it's ready for other services await ServiceRegistry.get_civitai_client() diff --git a/standalone.py b/standalone.py index 2ad50ade..6c7af9ab 100644 --- a/standalone.py +++ b/standalone.py @@ -103,9 +103,6 @@ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') logger = logging.getLogger("lora-manager-standalone") -# Configure aiohttp access logger to be less verbose -logging.getLogger('aiohttp.access').setLevel(logging.WARNING) - # Now we can import the global config from our local modules from py.config import config @@ -124,7 +121,7 @@ class StandaloneServer: async def _configure_access_logger(self, app): """Configure access logger to reduce verbosity""" - logging.getLogger('aiohttp.access').setLevel(logging.WARNING) + logging.getLogger("asyncio").setLevel(logging.WARNING) # If using aiohttp>=3.8.0, configure access logger through app directly if hasattr(app, 'access_logger'): @@ -219,9 +216,6 @@ class StandaloneLoraManager(LoraManager): # Store app in a global-like location for compatibility sys.modules['server'].PromptServer.instance = server_instance - # Configure aiohttp access logger to be less verbose - logging.getLogger('aiohttp.access').setLevel(logging.WARNING) - added_targets = set() # Track already added target paths # Add static routes for each lora root @@ -371,8 +365,8 @@ async def main(): # Set log level logging.getLogger().setLevel(getattr(logging, args.log_level)) - # Explicitly configure aiohttp access logger regardless of selected log level - logging.getLogger('aiohttp.access').setLevel(logging.WARNING) + # Explicitly configure asyncio logger regardless of selected log level + logging.getLogger("asyncio").setLevel(logging.WARNING) # Create the server instance server = StandaloneServer() diff --git a/static/js/utils/VirtualScroller.js b/static/js/utils/VirtualScroller.js index bbdc73b6..0d614211 100644 --- a/static/js/utils/VirtualScroller.js +++ b/static/js/utils/VirtualScroller.js @@ -164,23 +164,6 @@ export class VirtualScroller { // Calculate the left offset to center the grid within the content area this.leftOffset = Math.max(0, (availableContentWidth - actualGridWidth) / 2); - - // Log layout info - console.log('Virtual Scroll Layout:', { - containerWidth, - availableContentWidth, - actualGridWidth, - columnsCount: this.columnsCount, - itemWidth: this.itemWidth, - itemHeight: this.itemHeight, - leftOffset: this.leftOffset, - paddingLeft, - paddingRight, - displayDensity, - maxColumns, - baseCardWidth, - rowGap: this.rowGap - }); // Update grid element max-width to match available width this.gridElement.style.maxWidth = `${actualGridWidth}px`; From 4d38add2910dea7423328470de75fd347f5f00e8 Mon Sep 17 00:00:00 2001 From: Will Miao <13051207myq@gmail.com> Date: Wed, 23 Jul 2025 22:23:48 +0800 Subject: [PATCH 09/18] Revert "refactor: Update logging configuration to use asyncio logger and remove aiohttp access logger references" This reverts commit 804808da4a7e83c1f9de65f4accee7bc14b4cd7e. --- py/lora_manager.py | 7 +++++-- standalone.py | 12 +++++++++--- static/js/utils/VirtualScroller.js | 17 +++++++++++++++++ 3 files changed, 31 insertions(+), 5 deletions(-) diff --git a/py/lora_manager.py b/py/lora_manager.py index 1cb2f712..dc54d8f4 100644 --- a/py/lora_manager.py +++ b/py/lora_manager.py @@ -30,8 +30,8 @@ class LoraManager: """Initialize and register all routes using the new refactored architecture""" app = PromptServer.instance.app - # Configure aiohttp logger to be less verbose - logging.getLogger("asyncio").setLevel(logging.WARNING) + # Configure aiohttp access logger to be less verbose + logging.getLogger('aiohttp.access').setLevel(logging.WARNING) added_targets = set() # Track already added target paths @@ -140,6 +140,9 @@ class LoraManager: async def _initialize_services(cls): """Initialize all services using the ServiceRegistry""" try: + # Ensure aiohttp access logger is configured with reduced verbosity + logging.getLogger('aiohttp.access').setLevel(logging.WARNING) + # Initialize CivitaiClient first to ensure it's ready for other services await ServiceRegistry.get_civitai_client() diff --git a/standalone.py b/standalone.py index 6c7af9ab..2ad50ade 100644 --- a/standalone.py +++ b/standalone.py @@ -103,6 +103,9 @@ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') logger = logging.getLogger("lora-manager-standalone") +# Configure aiohttp access logger to be less verbose +logging.getLogger('aiohttp.access').setLevel(logging.WARNING) + # Now we can import the global config from our local modules from py.config import config @@ -121,7 +124,7 @@ class StandaloneServer: async def _configure_access_logger(self, app): """Configure access logger to reduce verbosity""" - logging.getLogger("asyncio").setLevel(logging.WARNING) + logging.getLogger('aiohttp.access').setLevel(logging.WARNING) # If using aiohttp>=3.8.0, configure access logger through app directly if hasattr(app, 'access_logger'): @@ -216,6 +219,9 @@ class StandaloneLoraManager(LoraManager): # Store app in a global-like location for compatibility sys.modules['server'].PromptServer.instance = server_instance + # Configure aiohttp access logger to be less verbose + logging.getLogger('aiohttp.access').setLevel(logging.WARNING) + added_targets = set() # Track already added target paths # Add static routes for each lora root @@ -365,8 +371,8 @@ async def main(): # Set log level logging.getLogger().setLevel(getattr(logging, args.log_level)) - # Explicitly configure asyncio logger regardless of selected log level - logging.getLogger("asyncio").setLevel(logging.WARNING) + # Explicitly configure aiohttp access logger regardless of selected log level + logging.getLogger('aiohttp.access').setLevel(logging.WARNING) # Create the server instance server = StandaloneServer() diff --git a/static/js/utils/VirtualScroller.js b/static/js/utils/VirtualScroller.js index 0d614211..bbdc73b6 100644 --- a/static/js/utils/VirtualScroller.js +++ b/static/js/utils/VirtualScroller.js @@ -164,6 +164,23 @@ export class VirtualScroller { // Calculate the left offset to center the grid within the content area this.leftOffset = Math.max(0, (availableContentWidth - actualGridWidth) / 2); + + // Log layout info + console.log('Virtual Scroll Layout:', { + containerWidth, + availableContentWidth, + actualGridWidth, + columnsCount: this.columnsCount, + itemWidth: this.itemWidth, + itemHeight: this.itemHeight, + leftOffset: this.leftOffset, + paddingLeft, + paddingRight, + displayDensity, + maxColumns, + baseCardWidth, + rowGap: this.rowGap + }); // Update grid element max-width to match available width this.gridElement.style.maxWidth = `${actualGridWidth}px`; From 5288021e4f60e6205142537a973a1621bc9c6afa Mon Sep 17 00:00:00 2001 From: Will Miao <13051207myq@gmail.com> Date: Wed, 23 Jul 2025 22:55:42 +0800 Subject: [PATCH 10/18] refactor: Simplify filtering methods and enhance CJK character handling in LoraService --- py/services/lora_scanner.py | 291 ------------------------------------ py/services/lora_service.py | 116 +++++++++----- 2 files changed, 78 insertions(+), 329 deletions(-) diff --git a/py/services/lora_scanner.py b/py/services/lora_scanner.py index 59ff58b5..fd2694db 100644 --- a/py/services/lora_scanner.py +++ b/py/services/lora_scanner.py @@ -7,9 +7,6 @@ from ..utils.models import LoraMetadata from ..config import config from .model_scanner import ModelScanner from .model_hash_index import ModelHashIndex # Changed from LoraHashIndex to ModelHashIndex -from .settings_manager import settings -from ..utils.constants import NSFW_LEVELS -from ..utils.utils import fuzzy_match import sys logger = logging.getLogger(__name__) @@ -114,260 +111,6 @@ class LoraScanner(ModelScanner): loras.append(result) except Exception as e: logger.error(f"Error processing {file_path}: {e}") - - async def get_paginated_data(self, page: int, page_size: int, sort_by: str = 'name', - folder: str = None, search: str = None, fuzzy_search: bool = False, - base_models: list = None, tags: list = None, - search_options: dict = None, hash_filters: dict = None, - favorites_only: bool = False, first_letter: str = None) -> Dict: - """Get paginated and filtered lora data - - Args: - page: Current page number (1-based) - page_size: Number of items per page - sort_by: Sort method ('name' or 'date') - folder: Filter by folder path - search: Search term - fuzzy_search: Use fuzzy matching for search - base_models: List of base models to filter by - tags: List of tags to filter by - search_options: Dictionary with search options (filename, modelname, tags, recursive) - hash_filters: Dictionary with hash filtering options (single_hash or multiple_hashes) - favorites_only: Filter for favorite models only - first_letter: Filter by first letter of model name - """ - cache = await self.get_cached_data() - - # Get default search options if not provided - if search_options is None: - search_options = { - 'filename': True, - 'modelname': True, - 'tags': False, - 'recursive': False, - } - - # Get the base data set - filtered_data = cache.sorted_by_date if sort_by == 'date' else cache.sorted_by_name - - # Apply hash filtering if provided (highest priority) - if hash_filters: - single_hash = hash_filters.get('single_hash') - multiple_hashes = hash_filters.get('multiple_hashes') - - if single_hash: - # Filter by single hash - single_hash = single_hash.lower() # Ensure lowercase for matching - filtered_data = [ - lora for lora in filtered_data - if lora.get('sha256', '').lower() == single_hash - ] - elif multiple_hashes: - # Filter by multiple hashes - hash_set = set(hash.lower() for hash in multiple_hashes) # Convert to set for faster lookup - filtered_data = [ - lora for lora in filtered_data - if lora.get('sha256', '').lower() in hash_set - ] - - - # Jump to pagination - total_items = len(filtered_data) - start_idx = (page - 1) * page_size - end_idx = min(start_idx + page_size, total_items) - - result = { - 'items': filtered_data[start_idx:end_idx], - 'total': total_items, - 'page': page, - 'page_size': page_size, - 'total_pages': (total_items + page_size - 1) // page_size - } - - return result - - # Apply SFW filtering if enabled - if settings.get('show_only_sfw', False): - filtered_data = [ - lora for lora in filtered_data - if not lora.get('preview_nsfw_level') or lora.get('preview_nsfw_level') < NSFW_LEVELS['R'] - ] - - # Apply favorites filtering if enabled - if favorites_only: - filtered_data = [ - lora for lora in filtered_data - if lora.get('favorite', False) is True - ] - - # Apply first letter filtering - if first_letter: - filtered_data = self._filter_by_first_letter(filtered_data, first_letter) - - # Apply folder filtering - if folder is not None: - if search_options.get('recursive', False): - # Recursive folder filtering - include all subfolders - filtered_data = [ - lora for lora in filtered_data - if lora['folder'].startswith(folder) - ] - else: - # Exact folder filtering - filtered_data = [ - lora for lora in filtered_data - if lora['folder'] == folder - ] - - # Apply base model filtering - if base_models and len(base_models) > 0: - filtered_data = [ - lora for lora in filtered_data - if lora.get('base_model') in base_models - ] - - # Apply tag filtering - if tags and len(tags) > 0: - filtered_data = [ - lora for lora in filtered_data - if any(tag in lora.get('tags', []) for tag in tags) - ] - - # Apply search filtering - if search: - search_results = [] - search_opts = search_options or {} - - for lora in filtered_data: - # Search by file name - if search_opts.get('filename', True): - if fuzzy_match(lora.get('file_name', ''), search): - search_results.append(lora) - continue - - # Search by model name - if search_opts.get('modelname', True): - if fuzzy_match(lora.get('model_name', ''), search): - search_results.append(lora) - continue - - # Search by tags - if search_opts.get('tags', False) and 'tags' in lora: - if any(fuzzy_match(tag, search) for tag in lora['tags']): - search_results.append(lora) - continue - - filtered_data = search_results - - # Calculate pagination - total_items = len(filtered_data) - start_idx = (page - 1) * page_size - end_idx = min(start_idx + page_size, total_items) - - result = { - 'items': filtered_data[start_idx:end_idx], - 'total': total_items, - 'page': page, - 'page_size': page_size, - 'total_pages': (total_items + page_size - 1) // page_size - } - - return result - - def _filter_by_first_letter(self, data, letter): - """Filter data by first letter of model name - - Special handling: - - '#': Numbers (0-9) - - '@': Special characters (not alphanumeric) - - '漢': CJK characters - """ - filtered_data = [] - - for lora in data: - model_name = lora.get('model_name', '') - if not model_name: - continue - - first_char = model_name[0].upper() - - if letter == '#' and first_char.isdigit(): - filtered_data.append(lora) - elif letter == '@' and not first_char.isalnum(): - # Special characters (not alphanumeric) - filtered_data.append(lora) - elif letter == '漢' and self._is_cjk_character(first_char): - # CJK characters - filtered_data.append(lora) - elif letter.upper() == first_char: - # Regular alphabet matching - filtered_data.append(lora) - - return filtered_data - - def _is_cjk_character(self, char): - """Check if character is a CJK character""" - # Define Unicode ranges for CJK characters - cjk_ranges = [ - (0x4E00, 0x9FFF), # CJK Unified Ideographs - (0x3400, 0x4DBF), # CJK Unified Ideographs Extension A - (0x20000, 0x2A6DF), # CJK Unified Ideographs Extension B - (0x2A700, 0x2B73F), # CJK Unified Ideographs Extension C - (0x2B740, 0x2B81F), # CJK Unified Ideographs Extension D - (0x2B820, 0x2CEAF), # CJK Unified Ideographs Extension E - (0x2CEB0, 0x2EBEF), # CJK Unified Ideographs Extension F - (0x30000, 0x3134F), # CJK Unified Ideographs Extension G - (0xF900, 0xFAFF), # CJK Compatibility Ideographs - (0x3300, 0x33FF), # CJK Compatibility - (0x3200, 0x32FF), # Enclosed CJK Letters and Months - (0x3100, 0x312F), # Bopomofo - (0x31A0, 0x31BF), # Bopomofo Extended - (0x3040, 0x309F), # Hiragana - (0x30A0, 0x30FF), # Katakana - (0x31F0, 0x31FF), # Katakana Phonetic Extensions - (0xAC00, 0xD7AF), # Hangul Syllables - (0x1100, 0x11FF), # Hangul Jamo - (0xA960, 0xA97F), # Hangul Jamo Extended-A - (0xD7B0, 0xD7FF), # Hangul Jamo Extended-B - ] - - code_point = ord(char) - return any(start <= code_point <= end for start, end in cjk_ranges) - - async def get_letter_counts(self): - """Get count of models for each letter of the alphabet""" - cache = await self.get_cached_data() - data = cache.sorted_by_name - - # Define letter categories - letters = { - '#': 0, # Numbers - 'A': 0, 'B': 0, 'C': 0, 'D': 0, 'E': 0, 'F': 0, 'G': 0, 'H': 0, - 'I': 0, 'J': 0, 'K': 0, 'L': 0, 'M': 0, 'N': 0, 'O': 0, 'P': 0, - 'Q': 0, 'R': 0, 'S': 0, 'T': 0, 'U': 0, 'V': 0, 'W': 0, 'X': 0, - 'Y': 0, 'Z': 0, - '@': 0, # Special characters - '漢': 0 # CJK characters - } - - # Count models for each letter - for lora in data: - model_name = lora.get('model_name', '') - if not model_name: - continue - - first_char = model_name[0].upper() - - if first_char.isdigit(): - letters['#'] += 1 - elif first_char in letters: - letters[first_char] += 1 - elif self._is_cjk_character(first_char): - letters['漢'] += 1 - elif not first_char.isalnum(): - letters['@'] += 1 - - return letters # Lora-specific hash index functionality def has_lora_hash(self, sha256: str) -> bool: @@ -382,40 +125,6 @@ class LoraScanner(ModelScanner): """Get hash for a LoRA by its file path""" return self.get_hash_by_path(file_path) - async def get_top_tags(self, limit: int = 20) -> List[Dict[str, any]]: - """Get top tags sorted by count""" - # Make sure cache is initialized - await self.get_cached_data() - - # Sort tags by count in descending order - sorted_tags = sorted( - [{"tag": tag, "count": count} for tag, count in self._tags_count.items()], - key=lambda x: x['count'], - reverse=True - ) - - # Return limited number - return sorted_tags[:limit] - - async def get_base_models(self, limit: int = 20) -> List[Dict[str, any]]: - """Get base models used in loras sorted by frequency""" - # Make sure cache is initialized - cache = await self.get_cached_data() - - # Count base model occurrences - base_model_counts = {} - for lora in cache.raw_data: - if 'base_model' in lora and lora['base_model']: - base_model = lora['base_model'] - base_model_counts[base_model] = base_model_counts.get(base_model, 0) + 1 - - # Sort base models by count - sorted_models = [{'name': model, 'count': count} for model, count in base_model_counts.items()] - sorted_models.sort(key=lambda x: x['count'], reverse=True) - - # Return limited number - return sorted_models[:limit] - async def diagnose_hash_index(self): """Diagnostic method to verify hash index functionality""" print("\n\n*** DIAGNOSING LORA HASH INDEX ***\n\n", file=sys.stderr) diff --git a/py/services/lora_service.py b/py/services/lora_service.py index bcfa84c5..1623f571 100644 --- a/py/services/lora_service.py +++ b/py/services/lora_service.py @@ -52,60 +52,100 @@ class LoraService(BaseModelService): return data def _filter_by_first_letter(self, data: List[Dict], letter: str) -> List[Dict]: - """Filter LoRAs by first letter""" - if letter == '#': - # Filter for non-alphabetic characters - return [ - item for item in data - if not item.get('model_name', '')[0].isalpha() - ] - elif letter == 'CJK': - # Filter for CJK characters - return [ - item for item in data - if item.get('model_name', '') and self._is_cjk_character(item['model_name'][0]) - ] - else: - # Filter for specific letter - return [ - item for item in data - if item.get('model_name', '').lower().startswith(letter.lower()) - ] + """Filter data by first letter of model name + + Special handling: + - '#': Numbers (0-9) + - '@': Special characters (not alphanumeric) + - '漢': CJK characters + """ + filtered_data = [] + + for lora in data: + model_name = lora.get('model_name', '') + if not model_name: + continue + + first_char = model_name[0].upper() + + if letter == '#' and first_char.isdigit(): + filtered_data.append(lora) + elif letter == '@' and not first_char.isalnum(): + # Special characters (not alphanumeric) + filtered_data.append(lora) + elif letter == '漢' and self._is_cjk_character(first_char): + # CJK characters + filtered_data.append(lora) + elif letter.upper() == first_char: + # Regular alphabet matching + filtered_data.append(lora) + + return filtered_data def _is_cjk_character(self, char: str) -> bool: - """Check if character is CJK (Chinese, Japanese, Korean)""" + """Check if character is a CJK character""" + # Define Unicode ranges for CJK characters cjk_ranges = [ (0x4E00, 0x9FFF), # CJK Unified Ideographs - (0x3400, 0x4DBF), # CJK Extension A - (0x20000, 0x2A6DF), # CJK Extension B - (0x2A700, 0x2B73F), # CJK Extension C - (0x2B740, 0x2B81F), # CJK Extension D + (0x3400, 0x4DBF), # CJK Unified Ideographs Extension A + (0x20000, 0x2A6DF), # CJK Unified Ideographs Extension B + (0x2A700, 0x2B73F), # CJK Unified Ideographs Extension C + (0x2B740, 0x2B81F), # CJK Unified Ideographs Extension D + (0x2B820, 0x2CEAF), # CJK Unified Ideographs Extension E + (0x2CEB0, 0x2EBEF), # CJK Unified Ideographs Extension F + (0x30000, 0x3134F), # CJK Unified Ideographs Extension G + (0xF900, 0xFAFF), # CJK Compatibility Ideographs + (0x3300, 0x33FF), # CJK Compatibility + (0x3200, 0x32FF), # Enclosed CJK Letters and Months + (0x3100, 0x312F), # Bopomofo + (0x31A0, 0x31BF), # Bopomofo Extended (0x3040, 0x309F), # Hiragana (0x30A0, 0x30FF), # Katakana + (0x31F0, 0x31FF), # Katakana Phonetic Extensions (0xAC00, 0xD7AF), # Hangul Syllables + (0x1100, 0x11FF), # Hangul Jamo + (0xA960, 0xA97F), # Hangul Jamo Extended-A + (0xD7B0, 0xD7FF), # Hangul Jamo Extended-B ] - char_code = ord(char) - return any(start <= char_code <= end for start, end in cjk_ranges) + code_point = ord(char) + return any(start <= code_point <= end for start, end in cjk_ranges) # LoRA-specific methods async def get_letter_counts(self) -> Dict[str, int]: """Get count of LoRAs for each letter of the alphabet""" cache = await self.scanner.get_cached_data() - letter_counts = {} + data = cache.sorted_by_name - for lora in cache.raw_data: + # Define letter categories + letters = { + '#': 0, # Numbers + 'A': 0, 'B': 0, 'C': 0, 'D': 0, 'E': 0, 'F': 0, 'G': 0, 'H': 0, + 'I': 0, 'J': 0, 'K': 0, 'L': 0, 'M': 0, 'N': 0, 'O': 0, 'P': 0, + 'Q': 0, 'R': 0, 'S': 0, 'T': 0, 'U': 0, 'V': 0, 'W': 0, 'X': 0, + 'Y': 0, 'Z': 0, + '@': 0, # Special characters + '漢': 0 # CJK characters + } + + # Count models for each letter + for lora in data: model_name = lora.get('model_name', '') - if model_name: - first_char = model_name[0].upper() - if first_char.isalpha(): - letter_counts[first_char] = letter_counts.get(first_char, 0) + 1 - elif self._is_cjk_character(first_char): - letter_counts['CJK'] = letter_counts.get('CJK', 0) + 1 - else: - letter_counts['#'] = letter_counts.get('#', 0) + 1 - - return letter_counts + if not model_name: + continue + + first_char = model_name[0].upper() + + if first_char.isdigit(): + letters['#'] += 1 + elif first_char in letters: + letters[first_char] += 1 + elif self._is_cjk_character(first_char): + letters['漢'] += 1 + elif not first_char.isalnum(): + letters['@'] += 1 + + return letters async def get_lora_notes(self, lora_name: str) -> Optional[str]: """Get notes for a specific LoRA file""" From 68d00ce28933bc94a99b8ccb31ea23496049bf45 Mon Sep 17 00:00:00 2001 From: Will Miao <13051207myq@gmail.com> Date: Wed, 23 Jul 2025 22:58:40 +0800 Subject: [PATCH 11/18] refactor: Adjust logging configuration to reduce verbosity for asyncio logger --- py/lora_manager.py | 2 ++ standalone.py | 4 ++++ 2 files changed, 6 insertions(+) diff --git a/py/lora_manager.py b/py/lora_manager.py index dc54d8f4..f2fa3f4b 100644 --- a/py/lora_manager.py +++ b/py/lora_manager.py @@ -32,6 +32,7 @@ class LoraManager: # Configure aiohttp access logger to be less verbose logging.getLogger('aiohttp.access').setLevel(logging.WARNING) + logging.getLogger("asyncio").setLevel(logging.WARNING) added_targets = set() # Track already added target paths @@ -142,6 +143,7 @@ class LoraManager: try: # Ensure aiohttp access logger is configured with reduced verbosity logging.getLogger('aiohttp.access').setLevel(logging.WARNING) + logging.getLogger("asyncio").setLevel(logging.WARNING) # Initialize CivitaiClient first to ensure it's ready for other services await ServiceRegistry.get_civitai_client() diff --git a/standalone.py b/standalone.py index 2ad50ade..9421e00b 100644 --- a/standalone.py +++ b/standalone.py @@ -105,6 +105,7 @@ logger = logging.getLogger("lora-manager-standalone") # Configure aiohttp access logger to be less verbose logging.getLogger('aiohttp.access').setLevel(logging.WARNING) +logging.getLogger("asyncio").setLevel(logging.WARNING) # Now we can import the global config from our local modules from py.config import config @@ -125,6 +126,7 @@ class StandaloneServer: async def _configure_access_logger(self, app): """Configure access logger to reduce verbosity""" logging.getLogger('aiohttp.access').setLevel(logging.WARNING) + logging.getLogger("asyncio").setLevel(logging.WARNING) # If using aiohttp>=3.8.0, configure access logger through app directly if hasattr(app, 'access_logger'): @@ -221,6 +223,7 @@ class StandaloneLoraManager(LoraManager): # Configure aiohttp access logger to be less verbose logging.getLogger('aiohttp.access').setLevel(logging.WARNING) + logging.getLogger("asyncio").setLevel(logging.WARNING) added_targets = set() # Track already added target paths @@ -373,6 +376,7 @@ async def main(): # Explicitly configure aiohttp access logger regardless of selected log level logging.getLogger('aiohttp.access').setLevel(logging.WARNING) + logging.getLogger("asyncio").setLevel(logging.WARNING) # Create the server instance server = StandaloneServer() From bf9aa9356bdd180f02e398dac6d18c809b15f636 Mon Sep 17 00:00:00 2001 From: Will Miao <13051207myq@gmail.com> Date: Wed, 23 Jul 2025 23:27:18 +0800 Subject: [PATCH 12/18] refactor: Update model retrieval methods in RecipeRoutes and streamline CheckpointScanner and LoraScanner initialization --- py/routes/recipe_routes.py | 6 +- py/services/checkpoint_scanner.py | 118 +++-------------------------- py/services/lora_scanner.py | 122 +++--------------------------- py/services/model_scanner.py | 92 +++++++++++++++++++++- 4 files changed, 113 insertions(+), 225 deletions(-) diff --git a/py/routes/recipe_routes.py b/py/routes/recipe_routes.py index d181ed65..07fcb287 100644 --- a/py/routes/recipe_routes.py +++ b/py/routes/recipe_routes.py @@ -1148,7 +1148,7 @@ class RecipeRoutes: for lora_name, lora_strength in lora_matches: try: # Get lora info from scanner - lora_info = await self.recipe_scanner._lora_scanner.get_lora_info_by_name(lora_name) + lora_info = await self.recipe_scanner._lora_scanner.get_model_info_by_name(lora_name) # Create lora entry lora_entry = { @@ -1167,7 +1167,7 @@ class RecipeRoutes: # Get base model from lora scanner for the available loras base_model_counts = {} for lora in loras_data: - lora_info = await self.recipe_scanner._lora_scanner.get_lora_info_by_name(lora.get("file_name", "")) + lora_info = await self.recipe_scanner._lora_scanner.get_model_info_by_name(lora.get("file_name", "")) if lora_info and "base_model" in lora_info: base_model = lora_info["base_model"] base_model_counts[base_model] = base_model_counts.get(base_model, 0) + 1 @@ -1365,7 +1365,7 @@ class RecipeRoutes: return web.json_response({"error": "Recipe not found"}, status=404) # Find target LoRA by name - target_lora = await lora_scanner.get_lora_info_by_name(target_name) + target_lora = await lora_scanner.get_model_info_by_name(target_name) if not target_lora: return web.json_response({"error": f"Local LoRA not found with name: {target_name}"}, status=404) diff --git a/py/services/checkpoint_scanner.py b/py/services/checkpoint_scanner.py index 32da3dbf..95569d4f 100644 --- a/py/services/checkpoint_scanner.py +++ b/py/services/checkpoint_scanner.py @@ -1,7 +1,5 @@ -import os import logging -import asyncio -from typing import List, Dict +from typing import List from ..utils.models import CheckpointMetadata from ..config import config @@ -13,101 +11,19 @@ logger = logging.getLogger(__name__) class CheckpointScanner(ModelScanner): """Service for scanning and managing checkpoint files""" - _instance = None - _lock = asyncio.Lock() - - def __new__(cls): - if cls._instance is None: - cls._instance = super().__new__(cls) - return cls._instance - def __init__(self): - if not hasattr(self, '_initialized'): - # Define supported file extensions - file_extensions = {'.safetensors', '.ckpt', '.pt', '.pth', '.sft', '.gguf'} - super().__init__( - model_type="checkpoint", - model_class=CheckpointMetadata, - file_extensions=file_extensions, - hash_index=ModelHashIndex() - ) - self._initialized = True + # Define supported file extensions + file_extensions = {'.safetensors', '.ckpt', '.pt', '.pth', '.sft', '.gguf'} + super().__init__( + model_type="checkpoint", + model_class=CheckpointMetadata, + file_extensions=file_extensions, + hash_index=ModelHashIndex() + ) - @classmethod - async def get_instance(cls): - """Get singleton instance with async support""" - async with cls._lock: - if cls._instance is None: - cls._instance = cls() - return cls._instance - def get_model_roots(self) -> List[str]: """Get checkpoint root directories""" return config.base_models_roots - - async def scan_all_models(self) -> List[Dict]: - """Scan all checkpoint directories and return metadata""" - all_checkpoints = [] - - # Create scan tasks for each directory - scan_tasks = [] - for root in self.get_model_roots(): - task = asyncio.create_task(self._scan_directory(root)) - scan_tasks.append(task) - - # Wait for all tasks to complete - for task in scan_tasks: - try: - checkpoints = await task - all_checkpoints.extend(checkpoints) - except Exception as e: - logger.error(f"Error scanning checkpoint directory: {e}") - - return all_checkpoints - - async def _scan_directory(self, root_path: str) -> List[Dict]: - """Scan a directory for checkpoint files""" - checkpoints = [] - original_root = root_path - - async def scan_recursive(path: str, visited_paths: set): - try: - real_path = os.path.realpath(path) - if real_path in visited_paths: - logger.debug(f"Skipping already visited path: {path}") - return - visited_paths.add(real_path) - - with os.scandir(path) as it: - entries = list(it) - for entry in entries: - try: - if entry.is_file(follow_symlinks=True): - # Check if file has supported extension - ext = os.path.splitext(entry.name)[1].lower() - if ext in self.file_extensions: - file_path = entry.path.replace(os.sep, "/") - await self._process_single_file(file_path, original_root, checkpoints) - await asyncio.sleep(0) - elif entry.is_dir(follow_symlinks=True): - # For directories, continue scanning with original path - await scan_recursive(entry.path, visited_paths) - except Exception as e: - logger.error(f"Error processing entry {entry.path}: {e}") - except Exception as e: - logger.error(f"Error scanning {path}: {e}") - - await scan_recursive(root_path, set()) - return checkpoints - - async def _process_single_file(self, file_path: str, root_path: str, checkpoints: list): - """Process a single checkpoint file and add to results""" - try: - result = await self._process_model_file(file_path, root_path) - if result: - checkpoints.append(result) - except Exception as e: - logger.error(f"Error processing {file_path}: {e}") # Checkpoint-specific hash index functionality def has_checkpoint_hash(self, sha256: str) -> bool: @@ -120,18 +36,4 @@ class CheckpointScanner(ModelScanner): def get_checkpoint_hash_by_path(self, file_path: str) -> str: """Get hash for a checkpoint by its file path""" - return self.get_hash_by_path(file_path) - - async def get_checkpoint_info_by_name(self, name): - """Get checkpoint information by name""" - try: - cache = await self.get_cached_data() - - for checkpoint in cache.raw_data: - if checkpoint.get("file_name") == name: - return checkpoint - - return None - except Exception as e: - logger.error(f"Error getting checkpoint info by name: {e}", exc_info=True) - return None \ No newline at end of file + return self.get_hash_by_path(file_path) \ No newline at end of file diff --git a/py/services/lora_scanner.py b/py/services/lora_scanner.py index fd2694db..f4066dbe 100644 --- a/py/services/lora_scanner.py +++ b/py/services/lora_scanner.py @@ -1,7 +1,5 @@ -import os import logging -import asyncio -from typing import List, Dict, Optional +from typing import List, Optional from ..utils.models import LoraMetadata from ..config import config @@ -14,103 +12,21 @@ logger = logging.getLogger(__name__) class LoraScanner(ModelScanner): """Service for scanning and managing LoRA files""" - _instance = None - _lock = asyncio.Lock() - - def __new__(cls): - if cls._instance is None: - cls._instance = super().__new__(cls) - return cls._instance - def __init__(self): - # Ensure initialization happens only once - if not hasattr(self, '_initialized'): - # Define supported file extensions - file_extensions = {'.safetensors'} - - # Initialize parent class with ModelHashIndex - super().__init__( - model_type="lora", - model_class=LoraMetadata, - file_extensions=file_extensions, - hash_index=ModelHashIndex() # Changed from LoraHashIndex to ModelHashIndex - ) - self._initialized = True - - @classmethod - async def get_instance(cls): - """Get singleton instance with async support""" - async with cls._lock: - if cls._instance is None: - cls._instance = cls() - return cls._instance + # Define supported file extensions + file_extensions = {'.safetensors'} + + # Initialize parent class with ModelHashIndex + super().__init__( + model_type="lora", + model_class=LoraMetadata, + file_extensions=file_extensions, + hash_index=ModelHashIndex() # Changed from LoraHashIndex to ModelHashIndex + ) def get_model_roots(self) -> List[str]: """Get lora root directories""" return config.loras_roots - - async def scan_all_models(self) -> List[Dict]: - """Scan all LoRA directories and return metadata""" - all_loras = [] - - # Create scan tasks for each directory - scan_tasks = [] - for lora_root in self.get_model_roots(): - task = asyncio.create_task(self._scan_directory(lora_root)) - scan_tasks.append(task) - - # Wait for all tasks to complete - for task in scan_tasks: - try: - loras = await task - all_loras.extend(loras) - except Exception as e: - logger.error(f"Error scanning directory: {e}") - - return all_loras - - async def _scan_directory(self, root_path: str) -> List[Dict]: - """Scan a single directory for LoRA files""" - loras = [] - original_root = root_path # Save original root path - - async def scan_recursive(path: str, visited_paths: set): - """Recursively scan directory, avoiding circular symlinks""" - try: - real_path = os.path.realpath(path) - if real_path in visited_paths: - logger.debug(f"Skipping already visited path: {path}") - return - visited_paths.add(real_path) - - with os.scandir(path) as it: - entries = list(it) - for entry in entries: - try: - if entry.is_file(follow_symlinks=True) and any(entry.name.endswith(ext) for ext in self.file_extensions): - # Use original path instead of real path - file_path = entry.path.replace(os.sep, "/") - await self._process_single_file(file_path, original_root, loras) - await asyncio.sleep(0) - elif entry.is_dir(follow_symlinks=True): - # For directories, continue scanning with original path - await scan_recursive(entry.path, visited_paths) - except Exception as e: - logger.error(f"Error processing entry {entry.path}: {e}") - except Exception as e: - logger.error(f"Error scanning {path}: {e}") - - await scan_recursive(root_path, set()) - return loras - - async def _process_single_file(self, file_path: str, root_path: str, loras: list): - """Process a single file and add to results list""" - try: - result = await self._process_model_file(file_path, root_path) - if result: - loras.append(result) - except Exception as e: - logger.error(f"Error processing {file_path}: {e}") # Lora-specific hash index functionality def has_lora_hash(self, sha256: str) -> bool: @@ -160,19 +76,3 @@ class LoraScanner(ModelScanner): test_hash_result = self._hash_index.get_hash(test_path) print(f"Test reverse lookup: {test_path} -> {test_hash_result[:8]}...\n\n", file=sys.stderr) - async def get_lora_info_by_name(self, name): - """Get LoRA information by name""" - try: - # Get cached data - cache = await self.get_cached_data() - - # Find the LoRA by name - for lora in cache.raw_data: - if lora.get("file_name") == name: - return lora - - return None - except Exception as e: - logger.error(f"Error getting LoRA info by name: {e}", exc_info=True) - return None - diff --git a/py/services/model_scanner.py b/py/services/model_scanner.py index fdd9c020..a31bff42 100644 --- a/py/services/model_scanner.py +++ b/py/services/model_scanner.py @@ -29,7 +29,30 @@ CACHE_VERSION = 3 class ModelScanner: """Base service for scanning and managing model files""" - _lock = asyncio.Lock() + _instances = {} # Dictionary to store instances by class + _locks = {} # Dictionary to store locks by class + + def __new__(cls, *args, **kwargs): + """Implement singleton pattern for each subclass""" + if cls not in cls._instances: + cls._instances[cls] = super().__new__(cls) + return cls._instances[cls] + + @classmethod + def _get_lock(cls): + """Get or create a lock for this class""" + if cls not in cls._locks: + cls._locks[cls] = asyncio.Lock() + return cls._locks[cls] + + @classmethod + async def get_instance(cls): + """Get singleton instance with async support""" + lock = cls._get_lock() + async with lock: + if cls not in cls._instances: + cls._instances[cls] = cls() + return cls._instances[cls] def __init__(self, model_type: str, model_class: Type[BaseModelMetadata], file_extensions: Set[str], hash_index: Optional[ModelHashIndex] = None): """Initialize the scanner @@ -40,6 +63,10 @@ class ModelScanner: file_extensions: Set of supported file extensions including the dot (e.g. {'.safetensors'}) hash_index: Hash index instance (optional) """ + # Ensure initialization happens only once per instance + if hasattr(self, '_initialized'): + return + self.model_type = model_type self.model_class = model_class self.file_extensions = file_extensions @@ -50,6 +77,7 @@ class ModelScanner: self._excluded_models = [] # List to track excluded models self._dirs_last_modified = {} # Track directory modification times self._use_cache_files = False # Flag to control cache file usage, default to disabled + self._initialized = True # Clear cache files if disabled if not self._use_cache_files: @@ -744,10 +772,68 @@ class ModelScanner: finally: self._is_initializing = False # Unset flag - # These methods should be implemented in child classes async def scan_all_models(self) -> List[Dict]: """Scan all model directories and return metadata""" - raise NotImplementedError("Subclasses must implement scan_all_models") + all_models = [] + + # Create scan tasks for each directory + scan_tasks = [] + for model_root in self.get_model_roots(): + task = asyncio.create_task(self._scan_directory(model_root)) + scan_tasks.append(task) + + # Wait for all tasks to complete + for task in scan_tasks: + try: + models = await task + all_models.extend(models) + except Exception as e: + logger.error(f"Error scanning directory: {e}") + + return all_models + + async def _scan_directory(self, root_path: str) -> List[Dict]: + """Scan a single directory for model files""" + models = [] + original_root = root_path # Save original root path + + async def scan_recursive(path: str, visited_paths: set): + """Recursively scan directory, avoiding circular symlinks""" + try: + real_path = os.path.realpath(path) + if real_path in visited_paths: + logger.debug(f"Skipping already visited path: {path}") + return + visited_paths.add(real_path) + + with os.scandir(path) as it: + entries = list(it) + for entry in entries: + try: + if entry.is_file(follow_symlinks=True) and any(entry.name.endswith(ext) for ext in self.file_extensions): + # Use original path instead of real path + file_path = entry.path.replace(os.sep, "/") + await self._process_single_file(file_path, original_root, models) + await asyncio.sleep(0) + elif entry.is_dir(follow_symlinks=True): + # For directories, continue scanning with original path + await scan_recursive(entry.path, visited_paths) + except Exception as e: + logger.error(f"Error processing entry {entry.path}: {e}") + except Exception as e: + logger.error(f"Error scanning {path}: {e}") + + await scan_recursive(root_path, set()) + return models + + async def _process_single_file(self, file_path: str, root_path: str, models: list): + """Process a single file and add to results list""" + try: + result = await self._process_model_file(file_path, root_path) + if result: + models.append(result) + except Exception as e: + logger.error(f"Error processing {file_path}: {e}") def is_initializing(self) -> bool: """Check if the scanner is currently initializing""" From cf9fd2d5c20d3ba6a5afb523ee0e6fd10c286966 Mon Sep 17 00:00:00 2001 From: Will Miao <13051207myq@gmail.com> Date: Thu, 24 Jul 2025 06:25:33 +0800 Subject: [PATCH 13/18] refactor: Rename LoraScanner methods for consistency and remove deprecated checkpoint methods --- py/routes/recipe_routes.py | 6 +++--- py/services/checkpoint_scanner.py | 15 +-------------- py/services/lora_scanner.py | 15 +-------------- py/services/recipe_scanner.py | 14 +++++++------- 4 files changed, 12 insertions(+), 38 deletions(-) diff --git a/py/routes/recipe_routes.py b/py/routes/recipe_routes.py index 07fcb287..cf1c67c2 100644 --- a/py/routes/recipe_routes.py +++ b/py/routes/recipe_routes.py @@ -1257,7 +1257,7 @@ class RecipeRoutes: if lora.get("isDeleted", False): continue - if not self.recipe_scanner._lora_scanner.has_lora_hash(lora.get("hash", "")): + if not self.recipe_scanner._lora_scanner.has_hash(lora.get("hash", "")): continue # Get the strength @@ -1477,9 +1477,9 @@ class RecipeRoutes: if 'loras' in recipe: for lora in recipe['loras']: if 'hash' in lora and lora['hash']: - lora['inLibrary'] = self.recipe_scanner._lora_scanner.has_lora_hash(lora['hash'].lower()) + lora['inLibrary'] = self.recipe_scanner._lora_scanner.has_hash(lora['hash'].lower()) lora['preview_url'] = self.recipe_scanner._lora_scanner.get_preview_url_by_hash(lora['hash'].lower()) - lora['localPath'] = self.recipe_scanner._lora_scanner.get_lora_path_by_hash(lora['hash'].lower()) + lora['localPath'] = self.recipe_scanner._lora_scanner.get_path_by_hash(lora['hash'].lower()) # Ensure file_url is set (needed by frontend) if 'file_path' in recipe: diff --git a/py/services/checkpoint_scanner.py b/py/services/checkpoint_scanner.py index 95569d4f..d4696631 100644 --- a/py/services/checkpoint_scanner.py +++ b/py/services/checkpoint_scanner.py @@ -23,17 +23,4 @@ class CheckpointScanner(ModelScanner): def get_model_roots(self) -> List[str]: """Get checkpoint root directories""" - return config.base_models_roots - - # Checkpoint-specific hash index functionality - def has_checkpoint_hash(self, sha256: str) -> bool: - """Check if a checkpoint with given hash exists""" - return self.has_hash(sha256) - - def get_checkpoint_path_by_hash(self, sha256: str) -> str: - """Get file path for a checkpoint by its hash""" - return self.get_path_by_hash(sha256) - - def get_checkpoint_hash_by_path(self, file_path: str) -> str: - """Get hash for a checkpoint by its file path""" - return self.get_hash_by_path(file_path) \ No newline at end of file + return config.base_models_roots \ No newline at end of file diff --git a/py/services/lora_scanner.py b/py/services/lora_scanner.py index f4066dbe..6feff477 100644 --- a/py/services/lora_scanner.py +++ b/py/services/lora_scanner.py @@ -1,5 +1,5 @@ import logging -from typing import List, Optional +from typing import List from ..utils.models import LoraMetadata from ..config import config @@ -28,19 +28,6 @@ class LoraScanner(ModelScanner): """Get lora root directories""" return config.loras_roots - # Lora-specific hash index functionality - def has_lora_hash(self, sha256: str) -> bool: - """Check if a LoRA with given hash exists""" - return self.has_hash(sha256) - - def get_lora_path_by_hash(self, sha256: str) -> Optional[str]: - """Get file path for a LoRA by its hash""" - return self.get_path_by_hash(sha256) - - def get_lora_hash_by_path(self, file_path: str) -> Optional[str]: - """Get hash for a LoRA by its file path""" - return self.get_hash_by_path(file_path) - async def diagnose_hash_index(self): """Diagnostic method to verify hash index functionality""" print("\n\n*** DIAGNOSING LORA HASH INDEX ***\n\n", file=sys.stderr) diff --git a/py/services/recipe_scanner.py b/py/services/recipe_scanner.py index 6d8491ae..89bbef14 100644 --- a/py/services/recipe_scanner.py +++ b/py/services/recipe_scanner.py @@ -393,8 +393,8 @@ class RecipeScanner: if 'hash' in lora and (not lora.get('file_name') or not lora['file_name']): hash_value = lora['hash'] - if self._lora_scanner.has_lora_hash(hash_value): - lora_path = self._lora_scanner.get_lora_path_by_hash(hash_value) + if self._lora_scanner.has_hash(hash_value): + lora_path = self._lora_scanner.get_path_by_hash(hash_value) if lora_path: file_name = os.path.splitext(os.path.basename(lora_path))[0] lora['file_name'] = file_name @@ -465,7 +465,7 @@ class RecipeScanner: # Count occurrences of each base model for lora in loras: if 'hash' in lora: - lora_path = self._lora_scanner.get_lora_path_by_hash(lora['hash']) + lora_path = self._lora_scanner.get_path_by_hash(lora['hash']) if lora_path: base_model = await self._get_base_model_for_lora(lora_path) if base_model: @@ -603,9 +603,9 @@ class RecipeScanner: if 'loras' in item: for lora in item['loras']: if 'hash' in lora and lora['hash']: - lora['inLibrary'] = self._lora_scanner.has_lora_hash(lora['hash'].lower()) + lora['inLibrary'] = self._lora_scanner.has_hash(lora['hash'].lower()) lora['preview_url'] = self._lora_scanner.get_preview_url_by_hash(lora['hash'].lower()) - lora['localPath'] = self._lora_scanner.get_lora_path_by_hash(lora['hash'].lower()) + lora['localPath'] = self._lora_scanner.get_path_by_hash(lora['hash'].lower()) result = { 'items': paginated_items, @@ -655,9 +655,9 @@ class RecipeScanner: for lora in formatted_recipe['loras']: if 'hash' in lora and lora['hash']: lora_hash = lora['hash'].lower() - lora['inLibrary'] = self._lora_scanner.has_lora_hash(lora_hash) + lora['inLibrary'] = self._lora_scanner.has_hash(lora_hash) lora['preview_url'] = self._lora_scanner.get_preview_url_by_hash(lora_hash) - lora['localPath'] = self._lora_scanner.get_lora_path_by_hash(lora_hash) + lora['localPath'] = self._lora_scanner.get_path_by_hash(lora_hash) return formatted_recipe From e8ccdabe6c313dda6295cf243b48987027bfb948 Mon Sep 17 00:00:00 2001 From: Will Miao <13051207myq@gmail.com> Date: Thu, 24 Jul 2025 09:26:15 +0800 Subject: [PATCH 14/18] refactor: Enhance sorting functionality and UI for model selection, including legacy format conversion --- py/services/base_model_service.py | 19 ++++- py/services/lora_service.py | 2 +- py/services/model_cache.py | 82 +++++++++++++++---- py/services/model_scanner.py | 10 --- static/css/layout.css | 25 ++++++ static/js/components/controls/PageControls.js | 31 ++++++- templates/components/controls.html | 14 +++- 7 files changed, 146 insertions(+), 37 deletions(-) diff --git a/py/services/base_model_service.py b/py/services/base_model_service.py index 7ecc994b..72242daa 100644 --- a/py/services/base_model_service.py +++ b/py/services/base_model_service.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Dict, List, Optional, Type, Set +from typing import Dict, List, Optional, Type import logging from ..utils.models import BaseModelMetadata @@ -34,7 +34,7 @@ class BaseModelService(ABC): Args: page: Page number (1-based) page_size: Number of items per page - sort_by: Sort criteria ('name' or 'date') + sort_by: Sort criteria, e.g. 'name', 'name:asc', 'name:desc', 'date', 'date:asc', 'date:desc' folder: Folder filter search: Search term fuzzy_search: Whether to use fuzzy search @@ -50,6 +50,17 @@ class BaseModelService(ABC): """ cache = await self.scanner.get_cached_data() + # Parse sort_by into sort_key and order + if ':' in sort_by: + sort_key, order = sort_by.split(':', 1) + sort_key = sort_key.strip() + order = order.strip().lower() + if order not in ('asc', 'desc'): + order = 'asc' + else: + sort_key = sort_by.strip() + order = 'asc' + # Get default search options if not provided if search_options is None: search_options = { @@ -59,8 +70,8 @@ class BaseModelService(ABC): 'recursive': False, } - # Get the base data set - filtered_data = cache.sorted_by_date if sort_by == 'date' else cache.sorted_by_name + # Get the base data set using new sort logic + filtered_data = await cache.get_sorted_data(sort_key, order) # Apply hash filtering if provided (highest priority) if hash_filters: diff --git a/py/services/lora_service.py b/py/services/lora_service.py index 1623f571..7649f75b 100644 --- a/py/services/lora_service.py +++ b/py/services/lora_service.py @@ -115,7 +115,7 @@ class LoraService(BaseModelService): async def get_letter_counts(self) -> Dict[str, int]: """Get count of LoRAs for each letter of the alphabet""" cache = await self.scanner.get_cached_data() - data = cache.sorted_by_name + data = cache.raw_data # Define letter categories letters = { diff --git a/py/services/model_cache.py b/py/services/model_cache.py index 8494531e..f67b2444 100644 --- a/py/services/model_cache.py +++ b/py/services/model_cache.py @@ -1,37 +1,85 @@ import asyncio -from typing import List, Dict +from typing import List, Dict, Tuple from dataclasses import dataclass from operator import itemgetter from natsort import natsorted +# Supported sort modes: (sort_key, order) +# order: 'asc' for ascending, 'desc' for descending +SUPPORTED_SORT_MODES = [ + ('name', 'asc'), + ('name', 'desc'), + ('date', 'asc'), + ('date', 'desc'), + ('size', 'asc'), + ('size', 'desc'), +] + @dataclass class ModelCache: - """Cache structure for model data""" + """Cache structure for model data with extensible sorting""" raw_data: List[Dict] - sorted_by_name: List[Dict] - sorted_by_date: List[Dict] folders: List[str] def __post_init__(self): self._lock = asyncio.Lock() + # Cache for last sort: (sort_key, order) -> sorted list + self._last_sort: Tuple[str, str] = (None, None) + self._last_sorted_data: List[Dict] = [] + # Default sort on init + asyncio.create_task(self.resort()) - async def resort(self, name_only: bool = False): - """Resort all cached data views""" + async def resort(self): + """Resort cached data according to last sort mode if set""" async with self._lock: - self.sorted_by_name = natsorted( - self.raw_data, - key=lambda x: x['model_name'].lower() # Case-insensitive sort - ) - if not name_only: - self.sorted_by_date = sorted( - self.raw_data, - key=itemgetter('modified'), - reverse=True - ) - # Update folder list + if self._last_sort != (None, None): + sort_key, order = self._last_sort + sorted_data = self._sort_data(self.raw_data, sort_key, order) + self._last_sorted_data = sorted_data + # Update folder list + # else: do nothing + all_folders = set(l['folder'] for l in self.raw_data) self.folders = sorted(list(all_folders), key=lambda x: x.lower()) + def _sort_data(self, data: List[Dict], sort_key: str, order: str) -> List[Dict]: + """Sort data by sort_key and order""" + reverse = (order == 'desc') + if sort_key == 'name': + # Natural sort by model_name, case-insensitive + return natsorted( + data, + key=lambda x: x['model_name'].lower(), + reverse=reverse + ) + elif sort_key == 'date': + # Sort by modified timestamp + return sorted( + data, + key=itemgetter('modified'), + reverse=reverse + ) + elif sort_key == 'size': + # Sort by file size + return sorted( + data, + key=itemgetter('size'), + reverse=reverse + ) + else: + # Fallback: no sort + return list(data) + + async def get_sorted_data(self, sort_key: str = 'name', order: str = 'asc') -> List[Dict]: + """Get sorted data by sort_key and order, using cache if possible""" + async with self._lock: + if (sort_key, order) == self._last_sort: + return self._last_sorted_data + sorted_data = self._sort_data(self.raw_data, sort_key, order) + self._last_sort = (sort_key, order) + self._last_sorted_data = sorted_data + return sorted_data + async def update_preview_url(self, file_path: str, preview_url: str, preview_nsfw_level: int) -> bool: """Update preview_url for a specific model in all cached data diff --git a/py/services/model_scanner.py b/py/services/model_scanner.py index a31bff42..0d41fe11 100644 --- a/py/services/model_scanner.py +++ b/py/services/model_scanner.py @@ -246,8 +246,6 @@ class ModelScanner: # Load data into memory self._cache = ModelCache( raw_data=cache_data["raw_data"], - sorted_by_name=[], - sorted_by_date=[], folders=[] ) @@ -280,8 +278,6 @@ class ModelScanner: if self._cache is None: self._cache = ModelCache( raw_data=[], - sorted_by_name=[], - sorted_by_date=[], folders=[] ) @@ -544,8 +540,6 @@ class ModelScanner: if self._cache is None and not force_refresh: return ModelCache( raw_data=[], - sorted_by_name=[], - sorted_by_date=[], folders=[] ) @@ -605,8 +599,6 @@ class ModelScanner: # Update cache self._cache = ModelCache( raw_data=raw_data, - sorted_by_name=[], - sorted_by_date=[], folders=[] ) @@ -620,8 +612,6 @@ class ModelScanner: if self._cache is None: self._cache = ModelCache( raw_data=[], - sorted_by_name=[], - sorted_by_date=[], folders=[] ) finally: diff --git a/static/css/layout.css b/static/css/layout.css index b948a5f2..182e50b5 100644 --- a/static/css/layout.css +++ b/static/css/layout.css @@ -182,6 +182,31 @@ box-shadow: 0 1px 2px rgba(0, 0, 0, 0.05); } +/* Style for optgroups */ +.control-group select optgroup { + font-weight: 600; + font-style: normal; + color: var(--text-color); + background-color: var(--card-bg); +} + +.control-group select option { + padding: 4px 8px; + background-color: var(--card-bg); + color: var(--text-color); +} + +/* Dark theme optgroup styling */ +[data-theme="dark"] .control-group select optgroup { + background-color: var(--card-bg); + color: var(--text-color); +} + +[data-theme="dark"] .control-group select option { + background-color: var(--card-bg); + color: var(--text-color); +} + .control-group select:hover { border-color: var(--lora-accent); background-color: var(--bg-color); diff --git a/static/js/components/controls/PageControls.js b/static/js/components/controls/PageControls.js index 42ad0062..53a2d8ef 100644 --- a/static/js/components/controls/PageControls.js +++ b/static/js/components/controls/PageControls.js @@ -1,5 +1,5 @@ // PageControls.js - Manages controls for both LoRAs and Checkpoints pages -import { state, getCurrentPageState, setCurrentPageType } from '../../state/index.js'; +import { getCurrentPageState, setCurrentPageType } from '../../state/index.js'; import { getStorageItem, setStorageItem, getSessionItem, setSessionItem } from '../../utils/storageHelpers.js'; import { showToast } from '../../utils/uiHelpers.js'; @@ -41,6 +41,9 @@ export class PageControls { this.pageState.isLoading = false; this.pageState.hasMore = true; + // Set default sort based on page type + this.pageState.sortBy = this.pageType === 'loras' ? 'name:asc' : 'name:asc'; + // Load sort preference this.loadSortPreference(); } @@ -326,14 +329,36 @@ export class PageControls { loadSortPreference() { const savedSort = getStorageItem(`${this.pageType}_sort`); if (savedSort) { - this.pageState.sortBy = savedSort; + // Handle legacy format conversion + const convertedSort = this.convertLegacySortFormat(savedSort); + this.pageState.sortBy = convertedSort; const sortSelect = document.getElementById('sortSelect'); if (sortSelect) { - sortSelect.value = savedSort; + sortSelect.value = convertedSort; } } } + /** + * Convert legacy sort format to new format + * @param {string} sortValue - The sort value to convert + * @returns {string} - Converted sort value + */ + convertLegacySortFormat(sortValue) { + // Convert old format to new format with direction + switch (sortValue) { + case 'name': + return 'name:asc'; + case 'date': + return 'date:desc'; // Newest first is more intuitive default + case 'size': + return 'size:desc'; // Largest first is more intuitive default + default: + // If it's already in new format or unknown, return as is + return sortValue.includes(':') ? sortValue : 'name:asc'; + } + } + /** * Save sort preference to storage * @param {string} sortValue - The sort value to save diff --git a/templates/components/controls.html b/templates/components/controls.html index 65b53350..3c36b1bd 100644 --- a/templates/components/controls.html +++ b/templates/components/controls.html @@ -11,8 +11,18 @@ - - View on GitHub - + +