diff --git a/path_mappings.yaml.example b/path_mappings.yaml.example deleted file mode 100644 index 60f434a8..00000000 --- a/path_mappings.yaml.example +++ /dev/null @@ -1,71 +0,0 @@ -# Path mappings configuration for ComfyUI-Lora-Manager -# This file allows you to customize how base models and Civitai tags map to directories when downloading models - -# Base model mappings -# Format: "Original Base Model": "Custom Directory Name" -# -# Example: If you change "Flux.1 D": "flux" -# Then models with base model "Flux.1 D" will be stored in a directory named "flux" -# So the final path would be: /flux//model_file.safetensors -base_models: - "SD 1.4": "SD 1.4" - "SD 1.5": "SD 1.5" - "SD 1.5 LCM": "SD 1.5 LCM" - "SD 1.5 Hyper": "SD 1.5 Hyper" - "SD 2.0": "SD 2.0" - "SD 2.1": "SD 2.1" - "SDXL 1.0": "SDXL 1.0" - "SD 3": "SD 3" - "SD 3.5": "SD 3.5" - "SD 3.5 Medium": "SD 3.5 Medium" - "SD 3.5 Large": "SD 3.5 Large" - "SD 3.5 Large Turbo": "SD 3.5 Large Turbo" - "Pony": "Pony" - "Flux.1 S": "Flux.1 S" - "Flux.1 D": "Flux.1 D" - "Flux.1 Kontext": "Flux.1 Kontext" - "AuraFlow": "AuraFlow" - "SDXL Lightning": "SDXL Lightning" - "SDXL Hyper": "SDXL Hyper" - "Stable Cascade": "Stable Cascade" - "SVD": "SVD" - "PixArt a": "PixArt a" - "PixArt E": "PixArt E" - "Hunyuan 1": "Hunyuan 1" - "Hunyuan Video": "Hunyuan Video" - "Lumina": "Lumina" - "Kolors": "Kolors" - "Illustrious": "Illustrious" - "Mochi": "Mochi" - "LTXV": "LTXV" - "CogVideoX": "CogVideoX" - "NoobAI": "NoobAI" - "Wan Video": "Wan Video" - "Wan Video 1.3B t2v": "Wan Video 1.3B t2v" - "Wan Video 14B t2v": "Wan Video 14B t2v" - "Wan Video 14B i2v 480p": "Wan Video 14B i2v 480p" - "Wan Video 14B i2v 720p": "Wan Video 14B i2v 720p" - "HiDream": "HiDream" - "Other": "Other" - -# Civitai model tag mappings -# Format: "Original Tag": "Custom Directory Name" -# -# Example: If you change "character": "characters" -# Then models with tag "character" will be stored in a directory named "characters" -# So the final path would be: //characters/model_file.safetensors -model_tags: - "character": "character" - "style": "style" - "concept": "concept" - "clothing": "clothing" - "base model": "base model" - "poses": "poses" - "background": "background" - "tool": "tool" - "vehicle": "vehicle" - "buildings": "buildings" - "objects": "objects" - "assets": "assets" - "animal": "animal" - "action": "action" diff --git a/py/lora_manager.py b/py/lora_manager.py index ed2ff8cb..e211d75e 100644 --- a/py/lora_manager.py +++ b/py/lora_manager.py @@ -6,10 +6,8 @@ from pathlib import Path from server import PromptServer # type: ignore from .config import config -from .routes.lora_routes import LoraRoutes -from .routes.api_routes import ApiRoutes +from .services.model_service_factory import ModelServiceFactory, register_default_model_types from .routes.recipe_routes import RecipeRoutes -from .routes.checkpoints_routes import CheckpointsRoutes from .routes.stats_routes import StatsRoutes from .routes.update_routes import UpdateRoutes from .routes.misc_routes import MiscRoutes @@ -17,6 +15,7 @@ from .routes.example_images_routes import ExampleImagesRoutes from .services.service_registry import ServiceRegistry from .services.settings_manager import settings from .utils.example_images_migration import ExampleImagesMigration +from .services.websocket_manager import ws_manager logger = logging.getLogger(__name__) @@ -28,12 +27,28 @@ class LoraManager: @classmethod def add_routes(cls): - """Initialize and register all routes""" + """Initialize and register all routes using the new refactored architecture""" app = PromptServer.instance.app # Configure aiohttp access logger to be less verbose logging.getLogger('aiohttp.access').setLevel(logging.WARNING) + # Add specific suppression for connection reset errors + class ConnectionResetFilter(logging.Filter): + def filter(self, record): + # Filter out connection reset errors that are not critical + if "ConnectionResetError" in str(record.getMessage()): + return False + if "_call_connection_lost" in str(record.getMessage()): + return False + if "WinError 10054" in str(record.getMessage()): + return False + return True + + # Apply the filter to asyncio logger + asyncio_logger = logging.getLogger("asyncio") + asyncio_logger.addFilter(ConnectionResetFilter()) + added_targets = set() # Track already added target paths # Add static route for example images if the path exists in settings @@ -110,35 +125,37 @@ class LoraManager: # Add static route for plugin assets app.router.add_static('/loras_static', config.static_path) - # Setup feature routes - lora_routes = LoraRoutes() - checkpoints_routes = CheckpointsRoutes() - stats_routes = StatsRoutes() + # Register default model types with the factory + register_default_model_types() - # Initialize routes - lora_routes.setup_routes(app) - checkpoints_routes.setup_routes(app) - stats_routes.setup_routes(app) # Add statistics routes - ApiRoutes.setup_routes(app) + # Setup all model routes using the factory + ModelServiceFactory.setup_all_routes(app) + + # Setup non-model-specific routes + stats_routes = StatsRoutes() + stats_routes.setup_routes(app) RecipeRoutes.setup_routes(app) UpdateRoutes.setup_routes(app) - MiscRoutes.setup_routes(app) # Register miscellaneous routes - ExampleImagesRoutes.setup_routes(app) # Register example images routes + MiscRoutes.setup_routes(app) + ExampleImagesRoutes.setup_routes(app) + + # Setup WebSocket routes that are shared across all model types + app.router.add_get('/ws/fetch-progress', ws_manager.handle_connection) + app.router.add_get('/ws/download-progress', ws_manager.handle_download_connection) + app.router.add_get('/ws/init-progress', ws_manager.handle_init_connection) # Schedule service initialization app.on_startup.append(lambda app: cls._initialize_services()) # Add cleanup app.on_shutdown.append(cls._cleanup) - app.on_shutdown.append(ApiRoutes.cleanup) + + logger.info(f"LoRA Manager: Set up routes for {len(ModelServiceFactory.get_registered_types())} model types: {', '.join(ModelServiceFactory.get_registered_types())}") @classmethod async def _initialize_services(cls): """Initialize all services using the ServiceRegistry""" try: - # Ensure aiohttp access logger is configured with reduced verbosity - logging.getLogger('aiohttp.access').setLevel(logging.WARNING) - # Initialize CivitaiClient first to ensure it's ready for other services await ServiceRegistry.get_civitai_client() diff --git a/py/routes/api_routes.py b/py/routes/api_routes.py deleted file mode 100644 index f5457935..00000000 --- a/py/routes/api_routes.py +++ /dev/null @@ -1,1184 +0,0 @@ -import os -import json -import logging -from aiohttp import web -from typing import Dict -from server import PromptServer # type: ignore - -from ..utils.routes_common import ModelRouteUtils -from ..utils.utils import get_lora_info - -from ..config import config -from ..services.websocket_manager import ws_manager -import asyncio -from .update_routes import UpdateRoutes -from ..utils.constants import PREVIEW_EXTENSIONS, CARD_PREVIEW_WIDTH, VALID_LORA_TYPES -from ..utils.exif_utils import ExifUtils -from ..utils.metadata_manager import MetadataManager -from ..services.service_registry import ServiceRegistry - -logger = logging.getLogger(__name__) - -class ApiRoutes: - """API route handlers for LoRA management""" - - def __init__(self): - self.scanner = None # Will be initialized in setup_routes - self.civitai_client = None # Will be initialized in setup_routes - self.download_manager = None # Will be initialized in setup_routes - self._download_lock = asyncio.Lock() - - async def initialize_services(self): - """Initialize services from ServiceRegistry""" - self.scanner = await ServiceRegistry.get_lora_scanner() - self.civitai_client = await ServiceRegistry.get_civitai_client() - self.download_manager = await ServiceRegistry.get_download_manager() - - @classmethod - def setup_routes(cls, app: web.Application): - """Register API routes""" - routes = cls() - - # Schedule service initialization on app startup - app.on_startup.append(lambda _: routes.initialize_services()) - - app.router.add_post('/api/delete_model', routes.delete_model) - app.router.add_post('/api/loras/exclude', routes.exclude_model) # Add new exclude endpoint - app.router.add_post('/api/fetch-civitai', routes.fetch_civitai) - app.router.add_post('/api/relink-civitai', routes.relink_civitai) # Add new relink endpoint - app.router.add_post('/api/replace_preview', routes.replace_preview) - app.router.add_get('/api/loras', routes.get_loras) - app.router.add_post('/api/fetch-all-civitai', routes.fetch_all_civitai) - app.router.add_get('/ws/fetch-progress', ws_manager.handle_connection) - app.router.add_get('/ws/download-progress', ws_manager.handle_download_connection) # Add new WebSocket route for download progress - app.router.add_get('/ws/init-progress', ws_manager.handle_init_connection) # Add new WebSocket route - app.router.add_get('/api/lora-roots', routes.get_lora_roots) - app.router.add_get('/api/folders', routes.get_folders) - app.router.add_get('/api/civitai/versions/{model_id}', routes.get_civitai_versions) - app.router.add_get('/api/civitai/model/version/{modelVersionId}', routes.get_civitai_model_by_version) - app.router.add_get('/api/civitai/model/hash/{hash}', routes.get_civitai_model_by_hash) - app.router.add_post('/api/download-model', routes.download_model) - app.router.add_get('/api/download-model-get', routes.download_model_get) # Add new GET endpoint - app.router.add_get('/api/cancel-download-get', routes.cancel_download_get) - app.router.add_get('/api/download-progress/{download_id}', routes.get_download_progress) # Add new endpoint for download progress - app.router.add_post('/api/move_model', routes.move_model) - app.router.add_get('/api/lora-model-description', routes.get_lora_model_description) # Add new route - app.router.add_post('/api/loras/save-metadata', routes.save_metadata) - app.router.add_get('/api/lora-preview-url', routes.get_lora_preview_url) # Add new route - app.router.add_post('/api/move_models_bulk', routes.move_models_bulk) - app.router.add_get('/api/loras/top-tags', routes.get_top_tags) # Add new route for top tags - app.router.add_get('/api/loras/base-models', routes.get_base_models) # Add new route for base models - app.router.add_get('/api/lora-civitai-url', routes.get_lora_civitai_url) # Add new route for Civitai URL - app.router.add_post('/api/loras/rename', routes.rename_lora) # Add new route for renaming LoRA files - app.router.add_get('/api/loras/scan', routes.scan_loras) # Add new route for scanning LoRA files - - # Add the new trigger words route - app.router.add_post('/loramanager/get_trigger_words', routes.get_trigger_words) - - # Add new endpoint for letter counts - app.router.add_get('/api/loras/letter-counts', routes.get_letter_counts) - - # Add new endpoints for copying lora data - app.router.add_get('/api/loras/get-notes', routes.get_lora_notes) - app.router.add_get('/api/loras/get-trigger-words', routes.get_lora_trigger_words) - - # Add update check routes - UpdateRoutes.setup_routes(app) - - # Add new endpoints for finding duplicates - app.router.add_get('/api/loras/find-duplicates', routes.find_duplicate_loras) - app.router.add_get('/api/loras/find-filename-conflicts', routes.find_filename_conflicts) - - # Add new endpoint for bulk deleting loras - app.router.add_post('/api/loras/bulk-delete', routes.bulk_delete_loras) - - # Add new endpoint for verifying duplicates - app.router.add_post('/api/loras/verify-duplicates', routes.verify_duplicates) - - async def delete_model(self, request: web.Request) -> web.Response: - """Handle model deletion request""" - if self.scanner is None: - self.scanner = await ServiceRegistry.get_lora_scanner() - return await ModelRouteUtils.handle_delete_model(request, self.scanner) - - async def exclude_model(self, request: web.Request) -> web.Response: - """Handle model exclusion request""" - if self.scanner is None: - self.scanner = await ServiceRegistry.get_lora_scanner() - return await ModelRouteUtils.handle_exclude_model(request, self.scanner) - - async def fetch_civitai(self, request: web.Request) -> web.Response: - """Handle CivitAI metadata fetch request""" - if self.scanner is None: - self.scanner = await ServiceRegistry.get_lora_scanner() - - response = await ModelRouteUtils.handle_fetch_civitai(request, self.scanner) - - # If successful, format the metadata before returning - if response.status == 200: - data = json.loads(response.body.decode('utf-8')) - if data.get("success") and data.get("metadata"): - formatted_metadata = self._format_lora_response(data["metadata"]) - return web.json_response({ - "success": True, - "metadata": formatted_metadata - }) - - # Otherwise, return the original response - return response - - async def replace_preview(self, request: web.Request) -> web.Response: - """Handle preview image replacement request""" - if self.scanner is None: - self.scanner = await ServiceRegistry.get_lora_scanner() - return await ModelRouteUtils.handle_replace_preview(request, self.scanner) - - async def scan_loras(self, request: web.Request) -> web.Response: - """Force a rescan of LoRA files""" - try: - # Get full_rebuild parameter from query string, default to false - full_rebuild = request.query.get('full_rebuild', 'false').lower() == 'true' - - await self.scanner.get_cached_data(force_refresh=True, rebuild_cache=full_rebuild) - return web.json_response({"status": "success", "message": "LoRA scan completed"}) - except Exception as e: - logger.error(f"Error in scan_loras: {e}", exc_info=True) - return web.json_response({"error": str(e)}, status=500) - - async def get_loras(self, request: web.Request) -> web.Response: - """Handle paginated LoRA data request""" - try: - if self.scanner is None: - self.scanner = await ServiceRegistry.get_lora_scanner() - - # Parse query parameters - page = int(request.query.get('page', '1')) - page_size = int(request.query.get('page_size', '20')) - sort_by = request.query.get('sort_by', 'name') - folder = request.query.get('folder', None) - search = request.query.get('search', None) - fuzzy_search = request.query.get('fuzzy', 'false').lower() == 'true' - - # Parse search options - search_options = { - 'filename': request.query.get('search_filename', 'true').lower() == 'true', - 'modelname': request.query.get('search_modelname', 'true').lower() == 'true', - 'tags': request.query.get('search_tags', 'false').lower() == 'true', - 'recursive': request.query.get('recursive', 'false').lower() == 'true' - } - - # Get filter parameters - base_models = request.query.get('base_models', None) - tags = request.query.get('tags', None) - favorites_only = request.query.get('favorites_only', 'false').lower() == 'true' # New parameter - - # New parameter for alphabet filtering - first_letter = request.query.get('first_letter', None) - - # New parameters for recipe filtering - lora_hash = request.query.get('lora_hash', None) - lora_hashes = request.query.get('lora_hashes', None) - - # Parse filter parameters - filters = {} - if base_models: - filters['base_model'] = base_models.split(',') - if tags: - filters['tags'] = tags.split(',') - - # Add lora hash filtering options - hash_filters = {} - if lora_hash: - hash_filters['single_hash'] = lora_hash.lower() - elif lora_hashes: - hash_filters['multiple_hashes'] = [h.lower() for h in lora_hashes.split(',')] - - # Get file data - data = await self.scanner.get_paginated_data( - page, - page_size, - sort_by=sort_by, - folder=folder, - search=search, - fuzzy_search=fuzzy_search, - base_models=filters.get('base_model', None), - tags=filters.get('tags', None), - search_options=search_options, - hash_filters=hash_filters, - favorites_only=favorites_only, # Pass favorites_only parameter - first_letter=first_letter # Pass the new first_letter parameter - ) - - # Get all available folders from cache - cache = await self.scanner.get_cached_data() - - # Convert output to match expected format - result = { - 'items': [self._format_lora_response(lora) for lora in data['items']], - 'folders': cache.folders, - 'total': data['total'], - 'page': data['page'], - 'page_size': data['page_size'], - 'total_pages': data['total_pages'] - } - - return web.json_response(result) - - except Exception as e: - logger.error(f"Error retrieving loras: {e}", exc_info=True) - return web.json_response({"error": str(e)}, status=500) - - def _format_lora_response(self, lora: Dict) -> Dict: - """Format LoRA data for API response""" - return { - "model_name": lora["model_name"], - "file_name": lora["file_name"], - "preview_url": config.get_preview_static_url(lora["preview_url"]), - "preview_nsfw_level": lora.get("preview_nsfw_level", 0), - "base_model": lora["base_model"], - "folder": lora["folder"], - "sha256": lora["sha256"], - "file_path": lora["file_path"].replace(os.sep, "/"), - "file_size": lora["size"], - "modified": lora["modified"], - "tags": lora["tags"], - "modelDescription": lora["modelDescription"], - "from_civitai": lora.get("from_civitai", True), - "usage_tips": lora.get("usage_tips", ""), - "notes": lora.get("notes", ""), - "favorite": lora.get("favorite", False), # Include favorite status in response - "civitai": ModelRouteUtils.filter_civitai_data(lora.get("civitai", {})) - } - - async def fetch_all_civitai(self, request: web.Request) -> web.Response: - """Fetch CivitAI metadata for all loras in the background""" - try: - if self.scanner is None: - self.scanner = await ServiceRegistry.get_lora_scanner() - - cache = await self.scanner.get_cached_data() - total = len(cache.raw_data) - processed = 0 - success = 0 - needs_resort = False - - # Prepare loras to process - to_process = [ - lora for lora in cache.raw_data - if lora.get('sha256') and (not lora.get('civitai') or 'id' not in lora.get('civitai')) and lora.get('from_civitai', True) # TODO: for lora not from CivitAI but added traineWords - ] - total_to_process = len(to_process) - - # Send initial progress - await ws_manager.broadcast({ - 'status': 'started', - 'total': total_to_process, - 'processed': 0, - 'success': 0 - }) - - for lora in to_process: - try: - original_name = lora.get('model_name') - if await ModelRouteUtils.fetch_and_update_model( - sha256=lora['sha256'], - file_path=lora['file_path'], - model_data=lora, - update_cache_func=self.scanner.update_single_model_cache - ): - success += 1 - if original_name != lora.get('model_name'): - needs_resort = True - - processed += 1 - - # Send progress update - await ws_manager.broadcast({ - 'status': 'processing', - 'total': total_to_process, - 'processed': processed, - 'success': success, - 'current_name': lora.get('model_name', 'Unknown') - }) - - except Exception as e: - logger.error(f"Error fetching CivitAI data for {lora['file_path']}: {e}") - - if needs_resort: - await cache.resort(name_only=True) - - # Send completion message - await ws_manager.broadcast({ - 'status': 'completed', - 'total': total_to_process, - 'processed': processed, - 'success': success - }) - - return web.json_response({ - "success": True, - "message": f"Successfully updated {success} of {processed} processed loras (total: {total})" - }) - - except Exception as e: - # Send error message - await ws_manager.broadcast({ - 'status': 'error', - 'error': str(e) - }) - logger.error(f"Error in fetch_all_civitai: {e}") - return web.Response(text=str(e), status=500) - - async def get_lora_roots(self, request: web.Request) -> web.Response: - """Get all configured LoRA root directories""" - return web.json_response({ - 'roots': config.loras_roots - }) - - async def get_folders(self, request: web.Request) -> web.Response: - """Get all folders in the cache""" - if self.scanner is None: - self.scanner = await ServiceRegistry.get_lora_scanner() - - cache = await self.scanner.get_cached_data() - return web.json_response({ - 'folders': cache.folders - }) - - async def get_civitai_versions(self, request: web.Request) -> web.Response: - """Get available versions for a Civitai model with local availability info""" - try: - if self.scanner is None: - self.scanner = await ServiceRegistry.get_lora_scanner() - - if self.civitai_client is None: - self.civitai_client = await ServiceRegistry.get_civitai_client() - - model_id = request.match_info['model_id'] - response = await self.civitai_client.get_model_versions(model_id) - if not response or not response.get('modelVersions'): - return web.Response(status=404, text="Model not found") - - versions = response.get('modelVersions', []) - model_type = response.get('type', '') - - # Check model type - should be LORA, LoCon, or DORA - if model_type.lower() not in VALID_LORA_TYPES: - return web.json_response({ - 'error': f"Model type mismatch. Expected LORA or LoCon, got {model_type}" - }, status=400) - - # Check local availability for each version - for version in versions: - # Find the model file (type="Model") in the files list - model_file = next((file for file in version.get('files', []) - if file.get('type') == 'Model'), None) - - if model_file: - sha256 = model_file.get('hashes', {}).get('SHA256') - if sha256: - # Set existsLocally and localPath at the version level - version['existsLocally'] = self.scanner.has_hash(sha256) - if version['existsLocally']: - version['localPath'] = self.scanner.get_path_by_hash(sha256) - - # Also set the model file size at the version level for easier access - version['modelSizeKB'] = model_file.get('sizeKB') - else: - # No model file found in this version - version['existsLocally'] = False - - return web.json_response(versions) - except Exception as e: - logger.error(f"Error fetching model versions: {e}") - return web.Response(status=500, text=str(e)) - - async def get_civitai_model_by_version(self, request: web.Request) -> web.Response: - """Get CivitAI model details by model version ID""" - try: - if self.civitai_client is None: - self.civitai_client = await ServiceRegistry.get_civitai_client() - - model_version_id = request.match_info.get('modelVersionId') - - # Get model details from Civitai API - model, error_msg = await self.civitai_client.get_model_version_info(model_version_id) - - if not model: - # Log warning for failed model retrieval - logger.warning(f"Failed to fetch model version {model_version_id}: {error_msg}") - - # Determine status code based on error message - status_code = 404 if error_msg and "not found" in error_msg.lower() else 500 - - return web.json_response({ - "success": False, - "error": error_msg or "Failed to fetch model information" - }, status=status_code) - - return web.json_response(model) - except Exception as e: - logger.error(f"Error fetching model details: {e}") - return web.json_response({ - "success": False, - "error": str(e) - }, status=500) - - async def get_civitai_model_by_hash(self, request: web.Request) -> web.Response: - """Get CivitAI model details by hash""" - try: - if self.civitai_client is None: - self.civitai_client = await ServiceRegistry.get_civitai_client() - - hash = request.match_info.get('hash') - model = await self.civitai_client.get_model_by_hash(hash) - return web.json_response(model) - except Exception as e: - logger.error(f"Error fetching model details by hash: {e}") - return web.json_response({ - "success": False, - "error": str(e) - }, status=500) - - async def download_model(self, request: web.Request) -> web.Response: - return await ModelRouteUtils.handle_download_model(request, self.download_manager) - - async def download_model_get(self, request: web.Request) -> web.Response: - """Handle model download request via GET method - - Converts GET parameters to POST format and calls the existing download handler - - Args: - request: The aiohttp request with query parameters - - Returns: - web.Response: The HTTP response - """ - try: - # Extract query parameters - model_id = request.query.get('model_id') - if not model_id: - return web.Response( - status=400, - text="Missing required parameter: Please provide 'model_id'" - ) - - # Get optional parameters - model_version_id = request.query.get('model_version_id') - download_id = request.query.get('download_id') - use_default_paths = request.query.get('use_default_paths', 'false').lower() == 'true' - - # Create a data dictionary that mimics what would be received from a POST request - data = { - 'model_id': model_id - } - - # Add optional parameters only if they are provided - if model_version_id: - data['model_version_id'] = model_version_id - - if download_id: - data['download_id'] = download_id - - data['use_default_paths'] = use_default_paths - - # Create a mock request object with the data - # Fix: Create a proper Future object and set its result - future = asyncio.get_event_loop().create_future() - future.set_result(data) - - mock_request = type('MockRequest', (), { - 'json': lambda self=None: future - })() - - # Call the existing download handler - if self.download_manager is None: - self.download_manager = await ServiceRegistry.get_download_manager() - - return await ModelRouteUtils.handle_download_model(mock_request, self.download_manager) - - except Exception as e: - error_message = str(e) - logger.error(f"Error downloading model via GET: {error_message}", exc_info=True) - return web.Response(status=500, text=error_message) - - async def cancel_download_get(self, request: web.Request) -> web.Response: - """Handle GET request for cancelling a download by download_id""" - try: - download_id = request.query.get('download_id') - if not download_id: - return web.json_response({ - 'success': False, - 'error': 'Download ID is required' - }, status=400) - if self.download_manager is None: - self.download_manager = await ServiceRegistry.get_download_manager() - # Create a mock request with match_info for compatibility - mock_request = type('MockRequest', (), { - 'match_info': {'download_id': download_id} - })() - return await ModelRouteUtils.handle_cancel_download(mock_request, self.download_manager) - except Exception as e: - logger.error(f"Error cancelling download via GET: {e}", exc_info=True) - return web.json_response({ - 'success': False, - 'error': str(e) - }, status=500) - - async def get_download_progress(self, request: web.Request) -> web.Response: - """Handle request for download progress by download_id""" - try: - # Get download_id from URL path - download_id = request.match_info.get('download_id') - if not download_id: - return web.json_response({ - 'success': False, - 'error': 'Download ID is required' - }, status=400) - - # Get progress information from websocket manager - progress_data = ws_manager.get_download_progress(download_id) - - if progress_data is None: - return web.json_response({ - 'success': False, - 'error': 'Download ID not found' - }, status=404) - - return web.json_response({ - 'success': True, - 'progress': progress_data.get('progress', 0) - }) - except Exception as e: - logger.error(f"Error getting download progress: {e}", exc_info=True) - return web.json_response({ - 'success': False, - 'error': str(e) - }, status=500) - - async def move_model(self, request: web.Request) -> web.Response: - """Handle model move request""" - try: - if self.scanner is None: - self.scanner = await ServiceRegistry.get_lora_scanner() - - data = await request.json() - file_path = data.get('file_path') # full path of the model file, e.g. /path/to/model.safetensors - target_path = data.get('target_path') # folder path to move the model to, e.g. /path/to/target_folder - - if not file_path or not target_path: - return web.Response(text='File path and target path are required', status=400) - - # Check if source and destination are the same - source_dir = os.path.dirname(file_path) - if os.path.normpath(source_dir) == os.path.normpath(target_path): - logger.info(f"Source and target directories are the same: {source_dir}") - return web.json_response({'success': True, 'message': 'Source and target directories are the same'}) - - # Check if target file already exists - file_name = os.path.basename(file_path) - target_file_path = os.path.join(target_path, file_name).replace(os.sep, '/') - - if os.path.exists(target_file_path): - return web.json_response({ - 'success': False, - 'error': f"Target file already exists: {target_file_path}" - }, status=409) # 409 Conflict - - # Call scanner to handle the move operation - success = await self.scanner.move_model(file_path, target_path) - - if success: - return web.json_response({'success': True}) - else: - return web.Response(text='Failed to move model', status=500) - - except Exception as e: - logger.error(f"Error moving model: {e}", exc_info=True) - return web.Response(text=str(e), status=500) - - @classmethod - async def cleanup(cls): - """Add cleanup method for application shutdown""" - # Now we don't need to store an instance, as services are managed by ServiceRegistry - civitai_client = await ServiceRegistry.get_civitai_client() - if civitai_client: - await civitai_client.close() - - async def save_metadata(self, request: web.Request) -> web.Response: - """Handle saving metadata updates""" - try: - if self.scanner is None: - self.scanner = await ServiceRegistry.get_lora_scanner() - - data = await request.json() - file_path = data.get('file_path') - if not file_path: - return web.Response(text='File path is required', status=400) - - # Remove file path from data to avoid saving it - metadata_updates = {k: v for k, v in data.items() if k != 'file_path'} - - # Get metadata file path - metadata_path = os.path.splitext(file_path)[0] + '.metadata.json' - - # Load existing metadata - metadata = await ModelRouteUtils.load_local_metadata(metadata_path) - - # Handle nested updates (for civitai.trainedWords) - for key, value in metadata_updates.items(): - if isinstance(value, dict) and key in metadata and isinstance(metadata[key], dict): - # Deep update for nested dictionaries - for nested_key, nested_value in value.items(): - metadata[key][nested_key] = nested_value - else: - # Regular update for top-level keys - metadata[key] = value - - # Save updated metadata - await MetadataManager.save_metadata(file_path, metadata) - - # Update cache - await self.scanner.update_single_model_cache(file_path, file_path, metadata) - - # If model_name was updated, resort the cache - if 'model_name' in metadata_updates: - cache = await self.scanner.get_cached_data() - await cache.resort(name_only=True) - - return web.json_response({'success': True}) - - except Exception as e: - logger.error(f"Error saving metadata: {e}", exc_info=True) - return web.Response(text=str(e), status=500) - - async def get_lora_preview_url(self, request: web.Request) -> web.Response: - """Get the static preview URL for a LoRA file""" - try: - if self.scanner is None: - self.scanner = await ServiceRegistry.get_lora_scanner() - - # Get lora file name from query parameters - lora_name = request.query.get('name') - if not lora_name: - return web.Response(text='Lora file name is required', status=400) - - # Get cache data - cache = await self.scanner.get_cached_data() - - # Search for the lora in cache data - for lora in cache.raw_data: - file_name = lora['file_name'] - if file_name == lora_name: - if preview_url := lora.get('preview_url'): - # Convert preview path to static URL - static_url = config.get_preview_static_url(preview_url) - if static_url: - return web.json_response({ - 'success': True, - 'preview_url': static_url - }) - break - - # If no preview URL found - return web.json_response({ - 'success': False, - 'error': 'No preview URL found for the specified lora' - }, status=404) - - except Exception as e: - logger.error(f"Error getting lora preview URL: {e}", exc_info=True) - return web.json_response({ - 'success': False, - 'error': str(e) - }, status=500) - - async def get_lora_civitai_url(self, request: web.Request) -> web.Response: - """Get the Civitai URL for a LoRA file""" - try: - if self.scanner is None: - self.scanner = await ServiceRegistry.get_lora_scanner() - - # Get lora file name from query parameters - lora_name = request.query.get('name') - if not lora_name: - return web.Response(text='Lora file name is required', status=400) - - # Get cache data - cache = await self.scanner.get_cached_data() - - # Search for the lora in cache data - for lora in cache.raw_data: - file_name = lora['file_name'] - if file_name == lora_name: - civitai_data = lora.get('civitai', {}) - model_id = civitai_data.get('modelId') - version_id = civitai_data.get('id') - - if model_id: - civitai_url = f"https://civitai.com/models/{model_id}" - if version_id: - civitai_url += f"?modelVersionId={version_id}" - - return web.json_response({ - 'success': True, - 'civitai_url': civitai_url, - 'model_id': model_id, - 'version_id': version_id - }) - break - - # If no Civitai data found - return web.json_response({ - 'success': False, - 'error': 'No Civitai data found for the specified lora' - }, status=404) - - except Exception as e: - logger.error(f"Error getting lora Civitai URL: {e}", exc_info=True) - return web.json_response({ - 'success': False, - 'error': str(e) - }, status=500) - - async def move_models_bulk(self, request: web.Request) -> web.Response: - """Handle bulk model move request""" - try: - if self.scanner is None: - self.scanner = await ServiceRegistry.get_lora_scanner() - - data = await request.json() - file_paths = data.get('file_paths', []) # list of full paths of the model files, e.g. ["/path/to/model1.safetensors", "/path/to/model2.safetensors"] - target_path = data.get('target_path') # folder path to move the models to, e.g. "/path/to/target_folder" - - if not file_paths or not target_path: - return web.Response(text='File paths and target path are required', status=400) - - results = [] - for file_path in file_paths: - # Check if source and destination are the same - source_dir = os.path.dirname(file_path) - if os.path.normpath(source_dir) == os.path.normpath(target_path): - results.append({ - "path": file_path, - "success": True, - "message": "Source and target directories are the same" - }) - continue - - # Check if target file already exists - file_name = os.path.basename(file_path) - target_file_path = os.path.join(target_path, file_name).replace(os.sep, '/') - - if os.path.exists(target_file_path): - results.append({ - "path": file_path, - "success": False, - "message": f"Target file already exists: {target_file_path}" - }) - continue - - # Try to move the model - success = await self.scanner.move_model(file_path, target_path) - results.append({ - "path": file_path, - "success": success, - "message": "Success" if success else "Failed to move model" - }) - - # Count successes and failures - success_count = sum(1 for r in results if r["success"]) - failure_count = len(results) - success_count - - return web.json_response({ - 'success': True, - 'message': f'Moved {success_count} of {len(file_paths)} models', - 'results': results, - 'success_count': success_count, - 'failure_count': failure_count - }) - - except Exception as e: - logger.error(f"Error moving models in bulk: {e}", exc_info=True) - return web.Response(text=str(e), status=500) - - async def get_lora_model_description(self, request: web.Request) -> web.Response: - """Get model description for a Lora model""" - try: - if self.civitai_client is None: - self.civitai_client = await ServiceRegistry.get_civitai_client() - - # Get parameters - model_id = request.query.get('model_id') - file_path = request.query.get('file_path') - - if not model_id: - return web.json_response({ - 'success': False, - 'error': 'Model ID is required' - }, status=400) - - # Check if we already have the description stored in metadata - description = None - tags = [] - creator = {} - if file_path: - metadata_path = os.path.splitext(file_path)[0] + '.metadata.json' - metadata = await ModelRouteUtils.load_local_metadata(metadata_path) - description = metadata.get('modelDescription') - tags = metadata.get('tags', []) - creator = metadata.get('creator', {}) - - # If description is not in metadata, fetch from CivitAI - if not description: - logger.info(f"Fetching model metadata for model ID: {model_id}") - model_metadata, _ = await self.civitai_client.get_model_metadata(model_id) - - if (model_metadata): - description = model_metadata.get('description') - tags = model_metadata.get('tags', []) - creator = model_metadata.get('creator', {}) - - # Save the metadata to file if we have a file path and got metadata - if file_path: - try: - metadata_path = os.path.splitext(file_path)[0] + '.metadata.json' - metadata = await ModelRouteUtils.load_local_metadata(metadata_path) - - metadata['modelDescription'] = description - metadata['tags'] = tags - # Ensure the civitai dict exists - if 'civitai' not in metadata: - metadata['civitai'] = {} - # Store creator in the civitai nested structure - metadata['civitai']['creator'] = creator - - await MetadataManager.save_metadata(file_path, metadata, True) - except Exception as e: - logger.error(f"Error saving model metadata: {e}") - - return web.json_response({ - 'success': True, - 'description': description or "

No model description available.

", - 'tags': tags, - 'creator': creator - }) - - except Exception as e: - logger.error(f"Error getting model metadata: {e}") - return web.json_response({ - 'success': False, - 'error': str(e) - }, status=500) - - async def get_top_tags(self, request: web.Request) -> web.Response: - """Handle request for top tags sorted by frequency""" - try: - if self.scanner is None: - self.scanner = await ServiceRegistry.get_lora_scanner() - - # Parse query parameters - limit = int(request.query.get('limit', '20')) - - # Validate limit - if limit < 1 or limit > 100: - limit = 20 # Default to a reasonable limit - - # Get top tags - top_tags = await self.scanner.get_top_tags(limit) - - return web.json_response({ - 'success': True, - 'tags': top_tags - }) - - except Exception as e: - logger.error(f"Error getting top tags: {str(e)}", exc_info=True) - return web.json_response({ - 'success': False, - 'error': 'Internal server error' - }, status=500) - - async def get_base_models(self, request: web.Request) -> web.Response: - """Get base models used in loras""" - try: - if self.scanner is None: - self.scanner = await ServiceRegistry.get_lora_scanner() - - # Parse query parameters - limit = int(request.query.get('limit', '20')) - - # Validate limit - if limit < 1 or limit > 100: - limit = 20 # Default to a reasonable limit - - # Get base models - base_models = await self.scanner.get_base_models(limit) - - return web.json_response({ - 'success': True, - 'base_models': base_models - }) - except Exception as e: - logger.error(f"Error retrieving base models: {e}") - return web.json_response({ - 'success': False, - 'error': str(e) - }, status=500) - - async def rename_lora(self, request: web.Request) -> web.Response: - """Handle renaming a LoRA file and its associated files""" - if self.scanner is None: - self.scanner = await ServiceRegistry.get_lora_scanner() - - return await ModelRouteUtils.handle_rename_model(request, self.scanner) - - async def get_trigger_words(self, request: web.Request) -> web.Response: - """Get trigger words for specified LoRA models""" - try: - json_data = await request.json() - lora_names = json_data.get("lora_names", []) - node_ids = json_data.get("node_ids", []) - - all_trigger_words = [] - for lora_name in lora_names: - _, trigger_words = get_lora_info(lora_name) - all_trigger_words.extend(trigger_words) - - # Format the trigger words - trigger_words_text = ",, ".join(all_trigger_words) if all_trigger_words else "" - - # Send update to all connected trigger word toggle nodes - for node_id in node_ids: - PromptServer.instance.send_sync("trigger_word_update", { - "id": node_id, - "message": trigger_words_text - }) - - return web.json_response({"success": True}) - - except Exception as e: - logger.error(f"Error getting trigger words: {e}") - return web.json_response({ - "success": False, - "error": str(e) - }, status=500) - - async def get_letter_counts(self, request: web.Request) -> web.Response: - """Get count of loras for each letter of the alphabet""" - try: - if self.scanner is None: - self.scanner = await ServiceRegistry.get_lora_scanner() - - # Get letter counts - letter_counts = await self.scanner.get_letter_counts() - - return web.json_response({ - 'success': True, - 'letter_counts': letter_counts - }) - except Exception as e: - logger.error(f"Error getting letter counts: {e}") - return web.json_response({ - 'success': False, - 'error': str(e) - }, status=500) - - async def get_lora_notes(self, request: web.Request) -> web.Response: - """Get notes for a specific LoRA file""" - try: - if self.scanner is None: - self.scanner = await ServiceRegistry.get_lora_scanner() - - # Get lora file name from query parameters - lora_name = request.query.get('name') - if not lora_name: - return web.Response(text='Lora file name is required', status=400) - - # Get cache data - cache = await self.scanner.get_cached_data() - - # Search for the lora in cache data - for lora in cache.raw_data: - file_name = lora['file_name'] - if file_name == lora_name: - notes = lora.get('notes', '') - - return web.json_response({ - 'success': True, - 'notes': notes - }) - - # If lora not found - return web.json_response({ - 'success': False, - 'error': 'LoRA not found in cache' - }, status=404) - - except Exception as e: - logger.error(f"Error getting lora notes: {e}", exc_info=True) - return web.json_response({ - 'success': False, - 'error': str(e) - }, status=500) - - async def get_lora_trigger_words(self, request: web.Request) -> web.Response: - """Get trigger words for a specific LoRA file""" - try: - if self.scanner is None: - self.scanner = await ServiceRegistry.get_lora_scanner() - - # Get lora file name from query parameters - lora_name = request.query.get('name') - if not lora_name: - return web.Response(text='Lora file name is required', status=400) - - # Get cache data - cache = await self.scanner.get_cached_data() - - # Search for the lora in cache data - for lora in cache.raw_data: - file_name = lora['file_name'] - if file_name == lora_name: - # Get trigger words from civitai data - civitai_data = lora.get('civitai', {}) - trigger_words = civitai_data.get('trainedWords', []) - - return web.json_response({ - 'success': True, - 'trigger_words': trigger_words - }) - - # If lora not found - return web.json_response({ - 'success': False, - 'error': 'LoRA not found in cache' - }, status=404) - - except Exception as e: - logger.error(f"Error getting lora trigger words: {e}", exc_info=True) - return web.json_response({ - 'success': False, - 'error': str(e) - }, status=500) - - async def find_duplicate_loras(self, request: web.Request) -> web.Response: - """Find loras with duplicate SHA256 hashes""" - try: - if self.scanner is None: - self.scanner = await ServiceRegistry.get_lora_scanner() - - # Get duplicate hashes from hash index - duplicates = self.scanner._hash_index.get_duplicate_hashes() - - # Format the response - result = [] - cache = await self.scanner.get_cached_data() - - for sha256, paths in duplicates.items(): - group = { - "hash": sha256, - "models": [] - } - # Find matching models for each duplicate path - for path in paths: - model = next((m for m in cache.raw_data if m['file_path'] == path), None) - if model: - group["models"].append(self._format_lora_response(model)) - - # Add the primary model too - primary_path = self.scanner._hash_index.get_path(sha256) - if primary_path and primary_path not in paths: - primary_model = next((m for m in cache.raw_data if m['file_path'] == primary_path), None) - if primary_model: - group["models"].insert(0, self._format_lora_response(primary_model)) - - if len(group["models"]) > 1: # Only include if we found multiple models - result.append(group) - - return web.json_response({ - "success": True, - "duplicates": result, - "count": len(result) - }) - except Exception as e: - logger.error(f"Error finding duplicate loras: {e}", exc_info=True) - return web.json_response({ - "success": False, - "error": str(e) - }, status=500) - - async def find_filename_conflicts(self, request: web.Request) -> web.Response: - """Find loras with conflicting filenames""" - try: - if self.scanner is None: - self.scanner = await ServiceRegistry.get_lora_scanner() - - # Get duplicate filenames from hash index - duplicates = self.scanner._hash_index.get_duplicate_filenames() - - # Format the response - result = [] - cache = await self.scanner.get_cached_data() - - for filename, paths in duplicates.items(): - group = { - "filename": filename, - "models": [] - } - # Find matching models for each path - for path in paths: - model = next((m for m in cache.raw_data if m['file_path'] == path), None) - if model: - group["models"].append(self._format_lora_response(model)) - - # Find the model from the main index too - hash_val = self.scanner._hash_index.get_hash_by_filename(filename) - if hash_val: - main_path = self.scanner._hash_index.get_path(hash_val) - if main_path and main_path not in paths: - main_model = next((m for m in cache.raw_data if m['file_path'] == main_path), None) - if main_model: - group["models"].insert(0, self._format_lora_response(main_model)) - - if group["models"]: # Only include if we found models - result.append(group) - - return web.json_response({ - "success": True, - "conflicts": result, - "count": len(result) - }) - except Exception as e: - logger.error(f"Error finding filename conflicts: {e}", exc_info=True) - return web.json_response({ - "success": False, - "error": str(e) - }, status=500) - - async def bulk_delete_loras(self, request: web.Request) -> web.Response: - """Handle bulk deletion of lora models""" - try: - if self.scanner is None: - self.scanner = await ServiceRegistry.get_lora_scanner() - - return await ModelRouteUtils.handle_bulk_delete_models(request, self.scanner) - - except Exception as e: - logger.error(f"Error in bulk delete loras: {e}", exc_info=True) - return web.json_response({ - 'success': False, - 'error': str(e) - }, status=500) - - async def relink_civitai(self, request: web.Request) -> web.Response: - """Handle CivitAI metadata re-linking request by model version ID for LoRAs""" - if self.scanner is None: - self.scanner = await ServiceRegistry.get_lora_scanner() - return await ModelRouteUtils.handle_relink_civitai(request, self.scanner) - - async def verify_duplicates(self, request: web.Request) -> web.Response: - """Handle verification of duplicate lora hashes""" - if self.scanner is None: - self.scanner = await ServiceRegistry.get_lora_scanner() - return await ModelRouteUtils.handle_verify_duplicates(request, self.scanner) diff --git a/py/routes/base_model_routes.py b/py/routes/base_model_routes.py new file mode 100644 index 00000000..33f47193 --- /dev/null +++ b/py/routes/base_model_routes.py @@ -0,0 +1,619 @@ +from abc import ABC, abstractmethod +import asyncio +import json +import logging +from aiohttp import web +from typing import Dict + +import jinja2 + +from ..utils.routes_common import ModelRouteUtils +from ..services.websocket_manager import ws_manager +from ..services.settings_manager import settings +from ..config import config + +logger = logging.getLogger(__name__) + +class BaseModelRoutes(ABC): + """Base route controller for all model types""" + + def __init__(self, service): + """Initialize the route controller + + Args: + service: Model service instance (LoraService, CheckpointService, etc.) + """ + self.service = service + self.model_type = service.model_type + self.template_env = jinja2.Environment( + loader=jinja2.FileSystemLoader(config.templates_path), + autoescape=True + ) + + def setup_routes(self, app: web.Application, prefix: str): + """Setup common routes for the model type + + Args: + app: aiohttp application + prefix: URL prefix (e.g., 'loras', 'checkpoints') + """ + # Common model management routes + app.router.add_get(f'/api/{prefix}', self.get_models) + app.router.add_post(f'/api/{prefix}/delete', self.delete_model) + app.router.add_post(f'/api/{prefix}/exclude', self.exclude_model) + app.router.add_post(f'/api/{prefix}/fetch-civitai', self.fetch_civitai) + app.router.add_post(f'/api/{prefix}/relink-civitai', self.relink_civitai) + app.router.add_post(f'/api/{prefix}/replace-preview', self.replace_preview) + app.router.add_post(f'/api/{prefix}/save-metadata', self.save_metadata) + app.router.add_post(f'/api/{prefix}/rename', self.rename_model) + app.router.add_post(f'/api/{prefix}/bulk-delete', self.bulk_delete_models) + app.router.add_post(f'/api/{prefix}/verify-duplicates', self.verify_duplicates) + + # Common query routes + app.router.add_get(f'/api/{prefix}/top-tags', self.get_top_tags) + app.router.add_get(f'/api/{prefix}/base-models', self.get_base_models) + app.router.add_get(f'/api/{prefix}/scan', self.scan_models) + app.router.add_get(f'/api/{prefix}/roots', self.get_model_roots) + app.router.add_get(f'/api/{prefix}/folders', self.get_folders) + app.router.add_get(f'/api/{prefix}/find-duplicates', self.find_duplicate_models) + app.router.add_get(f'/api/{prefix}/find-filename-conflicts', self.find_filename_conflicts) + + # Common Download management + app.router.add_post(f'/api/download-model', self.download_model) + app.router.add_get(f'/api/download-model-get', self.download_model_get) + app.router.add_get(f'/api/cancel-download-get', self.cancel_download_get) + app.router.add_get(f'/api/download-progress/{{download_id}}', self.get_download_progress) + + # CivitAI integration routes + app.router.add_post(f'/api/{prefix}/fetch-all-civitai', self.fetch_all_civitai) + # app.router.add_get(f'/api/civitai/versions/{{model_id}}', self.get_civitai_versions) + + # Add generic page route + app.router.add_get(f'/{prefix}', self.handle_models_page) + + # Setup model-specific routes + self.setup_specific_routes(app, prefix) + + @abstractmethod + def setup_specific_routes(self, app: web.Application, prefix: str): + """Setup model-specific routes - to be implemented by subclasses""" + pass + + async def handle_models_page(self, request: web.Request) -> web.Response: + """ + Generic handler for model pages (e.g., /loras, /checkpoints). + Subclasses should set self.template_env and template_name. + """ + try: + # Check if the scanner is initializing + is_initializing = ( + self.service.scanner._cache is None or + (hasattr(self.service.scanner, 'is_initializing') and callable(self.service.scanner.is_initializing) and self.service.scanner.is_initializing()) or + (hasattr(self.service.scanner, '_is_initializing') and self.service.scanner._is_initializing) + ) + + template_name = getattr(self, "template_name", None) + if not self.template_env or not template_name: + return web.Response(text="Template environment or template name not set", status=500) + + if is_initializing: + rendered = self.template_env.get_template(template_name).render( + folders=[], + is_initializing=True, + settings=settings, + request=request + ) + else: + try: + cache = await self.service.scanner.get_cached_data(force_refresh=False) + rendered = self.template_env.get_template(template_name).render( + folders=getattr(cache, "folders", []), + is_initializing=False, + settings=settings, + request=request + ) + except Exception as cache_error: + logger.error(f"Error loading cache data: {cache_error}") + rendered = self.template_env.get_template(template_name).render( + folders=[], + is_initializing=True, + settings=settings, + request=request + ) + return web.Response( + text=rendered, + content_type='text/html' + ) + except Exception as e: + logger.error(f"Error handling models page: {e}", exc_info=True) + return web.Response( + text="Error loading models page", + status=500 + ) + + async def get_models(self, request: web.Request) -> web.Response: + """Get paginated model data""" + try: + # Parse common query parameters + params = self._parse_common_params(request) + + # Get data from service + result = await self.service.get_paginated_data(**params) + + # Format response items + formatted_result = { + 'items': [await self.service.format_response(item) for item in result['items']], + 'total': result['total'], + 'page': result['page'], + 'page_size': result['page_size'], + 'total_pages': result['total_pages'] + } + + return web.json_response(formatted_result) + + except Exception as e: + logger.error(f"Error in get_{self.model_type}s: {e}", exc_info=True) + return web.json_response({"error": str(e)}, status=500) + + def _parse_common_params(self, request: web.Request) -> Dict: + """Parse common query parameters""" + # Parse basic pagination and sorting + page = int(request.query.get('page', '1')) + page_size = min(int(request.query.get('page_size', '20')), 100) + sort_by = request.query.get('sort_by', 'name') + folder = request.query.get('folder', None) + search = request.query.get('search', None) + fuzzy_search = request.query.get('fuzzy_search', 'false').lower() == 'true' + + # Parse filter arrays + base_models = request.query.getall('base_model', []) + tags = request.query.getall('tag', []) + favorites_only = request.query.get('favorites_only', 'false').lower() == 'true' + + # Parse search options + search_options = { + 'filename': request.query.get('search_filename', 'true').lower() == 'true', + 'modelname': request.query.get('search_modelname', 'true').lower() == 'true', + 'tags': request.query.get('search_tags', 'false').lower() == 'true', + 'recursive': request.query.get('recursive', 'false').lower() == 'true', + } + + # Parse hash filters if provided + hash_filters = {} + if 'hash' in request.query: + hash_filters['single_hash'] = request.query['hash'] + elif 'hashes' in request.query: + try: + hash_list = json.loads(request.query['hashes']) + if isinstance(hash_list, list): + hash_filters['multiple_hashes'] = hash_list + except (json.JSONDecodeError, TypeError): + pass + + return { + 'page': page, + 'page_size': page_size, + 'sort_by': sort_by, + 'folder': folder, + 'search': search, + 'fuzzy_search': fuzzy_search, + 'base_models': base_models, + 'tags': tags, + 'search_options': search_options, + 'hash_filters': hash_filters, + 'favorites_only': favorites_only, + # Add model-specific parameters + **self._parse_specific_params(request) + } + + def _parse_specific_params(self, request: web.Request) -> Dict: + """Parse model-specific parameters - to be overridden by subclasses""" + return {} + + # Common route handlers + async def delete_model(self, request: web.Request) -> web.Response: + """Handle model deletion request""" + return await ModelRouteUtils.handle_delete_model(request, self.service.scanner) + + async def exclude_model(self, request: web.Request) -> web.Response: + """Handle model exclusion request""" + return await ModelRouteUtils.handle_exclude_model(request, self.service.scanner) + + async def fetch_civitai(self, request: web.Request) -> web.Response: + """Handle CivitAI metadata fetch request""" + response = await ModelRouteUtils.handle_fetch_civitai(request, self.service.scanner) + + # If successful, format the metadata before returning + if response.status == 200: + data = json.loads(response.body.decode('utf-8')) + if data.get("success") and data.get("metadata"): + formatted_metadata = await self.service.format_response(data["metadata"]) + return web.json_response({ + "success": True, + "metadata": formatted_metadata + }) + + return response + + async def relink_civitai(self, request: web.Request) -> web.Response: + """Handle CivitAI metadata re-linking request""" + return await ModelRouteUtils.handle_relink_civitai(request, self.service.scanner) + + async def replace_preview(self, request: web.Request) -> web.Response: + """Handle preview image replacement""" + return await ModelRouteUtils.handle_replace_preview(request, self.service.scanner) + + async def save_metadata(self, request: web.Request) -> web.Response: + """Handle saving metadata updates""" + return await ModelRouteUtils.handle_save_metadata(request, self.service.scanner) + + async def rename_model(self, request: web.Request) -> web.Response: + """Handle renaming a model file and its associated files""" + return await ModelRouteUtils.handle_rename_model(request, self.service.scanner) + + async def bulk_delete_models(self, request: web.Request) -> web.Response: + """Handle bulk deletion of models""" + return await ModelRouteUtils.handle_bulk_delete_models(request, self.service.scanner) + + async def verify_duplicates(self, request: web.Request) -> web.Response: + """Handle verification of duplicate model hashes""" + return await ModelRouteUtils.handle_verify_duplicates(request, self.service.scanner) + + async def get_top_tags(self, request: web.Request) -> web.Response: + """Handle request for top tags sorted by frequency""" + try: + limit = int(request.query.get('limit', '20')) + if limit < 1 or limit > 100: + limit = 20 + + top_tags = await self.service.get_top_tags(limit) + + return web.json_response({ + 'success': True, + 'tags': top_tags + }) + + except Exception as e: + logger.error(f"Error getting top tags: {str(e)}", exc_info=True) + return web.json_response({ + 'success': False, + 'error': 'Internal server error' + }, status=500) + + async def get_base_models(self, request: web.Request) -> web.Response: + """Get base models used in models""" + try: + limit = int(request.query.get('limit', '20')) + if limit < 1 or limit > 100: + limit = 20 + + base_models = await self.service.get_base_models(limit) + + return web.json_response({ + 'success': True, + 'base_models': base_models + }) + except Exception as e: + logger.error(f"Error retrieving base models: {e}") + return web.json_response({ + 'success': False, + 'error': str(e) + }, status=500) + + async def scan_models(self, request: web.Request) -> web.Response: + """Force a rescan of model files""" + try: + full_rebuild = request.query.get('full_rebuild', 'false').lower() == 'true' + + await self.service.scan_models(force_refresh=True, rebuild_cache=full_rebuild) + return web.json_response({ + "status": "success", + "message": f"{self.model_type.capitalize()} scan completed" + }) + except Exception as e: + logger.error(f"Error in scan_{self.model_type}s: {e}", exc_info=True) + return web.json_response({"error": str(e)}, status=500) + + async def get_model_roots(self, request: web.Request) -> web.Response: + """Return the model root directories""" + try: + roots = self.service.get_model_roots() + return web.json_response({ + "success": True, + "roots": roots + }) + except Exception as e: + logger.error(f"Error getting {self.model_type} roots: {e}", exc_info=True) + return web.json_response({ + "success": False, + "error": str(e) + }, status=500) + + async def get_folders(self, request: web.Request) -> web.Response: + """Get all folders in the cache""" + try: + cache = await self.service.scanner.get_cached_data() + return web.json_response({ + 'folders': cache.folders + }) + except Exception as e: + logger.error(f"Error getting folders: {e}") + return web.json_response({ + 'success': False, + 'error': str(e) + }, status=500) + + async def find_duplicate_models(self, request: web.Request) -> web.Response: + """Find models with duplicate SHA256 hashes""" + try: + # Get duplicate hashes from service + duplicates = self.service.find_duplicate_hashes() + + # Format the response + result = [] + cache = await self.service.scanner.get_cached_data() + + for sha256, paths in duplicates.items(): + group = { + "hash": sha256, + "models": [] + } + # Find matching models for each path + for path in paths: + model = next((m for m in cache.raw_data if m['file_path'] == path), None) + if model: + group["models"].append(await self.service.format_response(model)) + + # Add the primary model too + primary_path = self.service.get_path_by_hash(sha256) + if primary_path and primary_path not in paths: + primary_model = next((m for m in cache.raw_data if m['file_path'] == primary_path), None) + if primary_model: + group["models"].insert(0, await self.service.format_response(primary_model)) + + if len(group["models"]) > 1: # Only include if we found multiple models + result.append(group) + + return web.json_response({ + "success": True, + "duplicates": result, + "count": len(result) + }) + except Exception as e: + logger.error(f"Error finding duplicate {self.model_type}s: {e}", exc_info=True) + return web.json_response({ + "success": False, + "error": str(e) + }, status=500) + + async def find_filename_conflicts(self, request: web.Request) -> web.Response: + """Find models with conflicting filenames""" + try: + # Get duplicate filenames from service + duplicates = self.service.find_duplicate_filenames() + + # Format the response + result = [] + cache = await self.service.scanner.get_cached_data() + + for filename, paths in duplicates.items(): + group = { + "filename": filename, + "models": [] + } + # Find matching models for each path + for path in paths: + model = next((m for m in cache.raw_data if m['file_path'] == path), None) + if model: + group["models"].append(await self.service.format_response(model)) + + # Find the model from the main index too + hash_val = self.service.scanner._hash_index.get_hash_by_filename(filename) + if hash_val: + main_path = self.service.get_path_by_hash(hash_val) + if main_path and main_path not in paths: + main_model = next((m for m in cache.raw_data if m['file_path'] == main_path), None) + if main_model: + group["models"].insert(0, await self.service.format_response(main_model)) + + if group["models"]: + result.append(group) + + return web.json_response({ + "success": True, + "conflicts": result, + "count": len(result) + }) + except Exception as e: + logger.error(f"Error finding filename conflicts for {self.model_type}s: {e}", exc_info=True) + return web.json_response({ + "success": False, + "error": str(e) + }, status=500) + + # Download management methods + async def download_model(self, request: web.Request) -> web.Response: + """Handle model download request""" + return await ModelRouteUtils.handle_download_model(request) + + async def download_model_get(self, request: web.Request) -> web.Response: + """Handle model download request via GET method""" + try: + # Extract query parameters + model_id = request.query.get('model_id') + if not model_id: + return web.Response( + status=400, + text="Missing required parameter: Please provide 'model_id'" + ) + + # Get optional parameters + model_version_id = request.query.get('model_version_id') + download_id = request.query.get('download_id') + use_default_paths = request.query.get('use_default_paths', 'false').lower() == 'true' + + # Create a data dictionary that mimics what would be received from a POST request + data = { + 'model_id': model_id + } + + # Add optional parameters only if they are provided + if model_version_id: + data['model_version_id'] = model_version_id + + if download_id: + data['download_id'] = download_id + + data['use_default_paths'] = use_default_paths + + # Create a mock request object with the data + future = asyncio.get_event_loop().create_future() + future.set_result(data) + + mock_request = type('MockRequest', (), { + 'json': lambda self=None: future + })() + + # Call the existing download handler + return await ModelRouteUtils.handle_download_model(mock_request) + + except Exception as e: + error_message = str(e) + logger.error(f"Error downloading model via GET: {error_message}", exc_info=True) + return web.Response(status=500, text=error_message) + + async def cancel_download_get(self, request: web.Request) -> web.Response: + """Handle GET request for cancelling a download by download_id""" + try: + download_id = request.query.get('download_id') + if not download_id: + return web.json_response({ + 'success': False, + 'error': 'Download ID is required' + }, status=400) + + # Create a mock request with match_info for compatibility + mock_request = type('MockRequest', (), { + 'match_info': {'download_id': download_id} + })() + return await ModelRouteUtils.handle_cancel_download(mock_request) + except Exception as e: + logger.error(f"Error cancelling download via GET: {e}", exc_info=True) + return web.json_response({ + 'success': False, + 'error': str(e) + }, status=500) + + async def get_download_progress(self, request: web.Request) -> web.Response: + """Handle request for download progress by download_id""" + try: + # Get download_id from URL path + download_id = request.match_info.get('download_id') + if not download_id: + return web.json_response({ + 'success': False, + 'error': 'Download ID is required' + }, status=400) + + progress_data = ws_manager.get_download_progress(download_id) + + if progress_data is None: + return web.json_response({ + 'success': False, + 'error': 'Download ID not found' + }, status=404) + + return web.json_response({ + 'success': True, + 'progress': progress_data.get('progress', 0) + }) + except Exception as e: + logger.error(f"Error getting download progress: {e}", exc_info=True) + return web.json_response({ + 'success': False, + 'error': str(e) + }, status=500) + + async def fetch_all_civitai(self, request: web.Request) -> web.Response: + """Fetch CivitAI metadata for all models in the background""" + try: + cache = await self.service.scanner.get_cached_data() + total = len(cache.raw_data) + processed = 0 + success = 0 + needs_resort = False + + # Prepare models to process + to_process = [ + model for model in cache.raw_data + if model.get('sha256') and (not model.get('civitai') or 'id' not in model.get('civitai')) and model.get('from_civitai', True) + ] + total_to_process = len(to_process) + + # Send initial progress + await ws_manager.broadcast({ + 'status': 'started', + 'total': total_to_process, + 'processed': 0, + 'success': 0 + }) + + # Process each model + for model in to_process: + try: + original_name = model.get('model_name') + if await ModelRouteUtils.fetch_and_update_model( + sha256=model['sha256'], + file_path=model['file_path'], + model_data=model, + update_cache_func=self.service.scanner.update_single_model_cache + ): + success += 1 + if original_name != model.get('model_name'): + needs_resort = True + + processed += 1 + + # Send progress update + await ws_manager.broadcast({ + 'status': 'processing', + 'total': total_to_process, + 'processed': processed, + 'success': success, + 'current_name': model.get('model_name', 'Unknown') + }) + + except Exception as e: + logger.error(f"Error fetching CivitAI data for {model['file_path']}: {e}") + + if needs_resort: + await cache.resort() + + # Send completion message + await ws_manager.broadcast({ + 'status': 'completed', + 'total': total_to_process, + 'processed': processed, + 'success': success + }) + + return web.json_response({ + "success": True, + "message": f"Successfully updated {success} of {processed} processed {self.model_type}s (total: {total})" + }) + + except Exception as e: + # Send error message + await ws_manager.broadcast({ + 'status': 'error', + 'error': str(e) + }) + logger.error(f"Error in fetch_all_civitai for {self.model_type}s: {e}") + return web.Response(text=str(e), status=500) + + async def get_civitai_versions(self, request: web.Request) -> web.Response: + """Get available versions for a Civitai model with local availability info""" + # This will be implemented by subclasses as they need CivitAI client access + return web.json_response({ + "error": "Not implemented in base class" + }, status=501) \ No newline at end of file diff --git a/py/routes/checkpoint_routes.py b/py/routes/checkpoint_routes.py new file mode 100644 index 00000000..4f27115e --- /dev/null +++ b/py/routes/checkpoint_routes.py @@ -0,0 +1,105 @@ +import logging +from aiohttp import web + +from .base_model_routes import BaseModelRoutes +from ..services.checkpoint_service import CheckpointService +from ..services.service_registry import ServiceRegistry + +logger = logging.getLogger(__name__) + +class CheckpointRoutes(BaseModelRoutes): + """Checkpoint-specific route controller""" + + def __init__(self): + """Initialize Checkpoint routes with Checkpoint service""" + # Service will be initialized later via setup_routes + self.service = None + self.civitai_client = None + self.template_name = "checkpoints.html" + + async def initialize_services(self): + """Initialize services from ServiceRegistry""" + checkpoint_scanner = await ServiceRegistry.get_checkpoint_scanner() + self.service = CheckpointService(checkpoint_scanner) + self.civitai_client = await ServiceRegistry.get_civitai_client() + + # Initialize parent with the service + super().__init__(self.service) + + def setup_routes(self, app: web.Application): + """Setup Checkpoint routes""" + # Schedule service initialization on app startup + app.on_startup.append(lambda _: self.initialize_services()) + + # Setup common routes with 'checkpoints' prefix (includes page route) + super().setup_routes(app, 'checkpoints') + + def setup_specific_routes(self, app: web.Application, prefix: str): + """Setup Checkpoint-specific routes""" + # Checkpoint-specific CivitAI integration + app.router.add_get(f'/api/{prefix}/civitai/versions/{{model_id}}', self.get_civitai_versions_checkpoint) + + # Checkpoint info by name + app.router.add_get(f'/api/{prefix}/info/{{name}}', self.get_checkpoint_info) + + async def get_checkpoint_info(self, request: web.Request) -> web.Response: + """Get detailed information for a specific checkpoint by name""" + try: + name = request.match_info.get('name', '') + checkpoint_info = await self.service.get_model_info_by_name(name) + + if checkpoint_info: + return web.json_response(checkpoint_info) + else: + return web.json_response({"error": "Checkpoint not found"}, status=404) + + except Exception as e: + logger.error(f"Error in get_checkpoint_info: {e}", exc_info=True) + return web.json_response({"error": str(e)}, status=500) + + async def get_civitai_versions_checkpoint(self, request: web.Request) -> web.Response: + """Get available versions for a Civitai checkpoint model with local availability info""" + try: + model_id = request.match_info['model_id'] + response = await self.civitai_client.get_model_versions(model_id) + if not response or not response.get('modelVersions'): + return web.Response(status=404, text="Model not found") + + versions = response.get('modelVersions', []) + model_type = response.get('type', '') + + # Check model type - should be Checkpoint + if model_type.lower() != 'checkpoint': + return web.json_response({ + 'error': f"Model type mismatch. Expected Checkpoint, got {model_type}" + }, status=400) + + # Check local availability for each version + for version in versions: + # Find the primary model file (type="Model" and primary=true) in the files list + model_file = next((file for file in version.get('files', []) + if file.get('type') == 'Model' and file.get('primary') == True), None) + + # If no primary file found, try to find any model file + if not model_file: + model_file = next((file for file in version.get('files', []) + if file.get('type') == 'Model'), None) + + if model_file: + sha256 = model_file.get('hashes', {}).get('SHA256') + if sha256: + # Set existsLocally and localPath at the version level + version['existsLocally'] = self.service.has_hash(sha256) + if version['existsLocally']: + version['localPath'] = self.service.get_path_by_hash(sha256) + + # Also set the model file size at the version level for easier access + version['modelSizeKB'] = model_file.get('sizeKB') + else: + # No model file found in this version + version['existsLocally'] = False + + return web.json_response(versions) + except Exception as e: + logger.error(f"Error fetching checkpoint model versions: {e}") + return web.Response(status=500, text=str(e)) \ No newline at end of file diff --git a/py/routes/checkpoints_routes.py b/py/routes/checkpoints_routes.py deleted file mode 100644 index 4edb4456..00000000 --- a/py/routes/checkpoints_routes.py +++ /dev/null @@ -1,771 +0,0 @@ -import os -import json -import jinja2 -from aiohttp import web -import logging -import asyncio - -from ..utils.routes_common import ModelRouteUtils -from ..utils.constants import NSFW_LEVELS -from ..utils.metadata_manager import MetadataManager -from ..services.websocket_manager import ws_manager -from ..services.service_registry import ServiceRegistry -from ..config import config -from ..services.settings_manager import settings -from ..utils.utils import fuzzy_match - -logger = logging.getLogger(__name__) - -class CheckpointsRoutes: - """API routes for checkpoint management""" - - def __init__(self): - self.scanner = None # Will be initialized in setup_routes - self.template_env = jinja2.Environment( - loader=jinja2.FileSystemLoader(config.templates_path), - autoescape=True - ) - self.download_manager = None # Will be initialized in setup_routes - self._download_lock = asyncio.Lock() - - async def initialize_services(self): - """Initialize services from ServiceRegistry""" - self.scanner = await ServiceRegistry.get_checkpoint_scanner() - self.download_manager = await ServiceRegistry.get_download_manager() - - def setup_routes(self, app): - """Register routes with the aiohttp app""" - # Schedule service initialization on app startup - app.on_startup.append(lambda _: self.initialize_services()) - - app.router.add_get('/checkpoints', self.handle_checkpoints_page) - app.router.add_get('/api/checkpoints', self.get_checkpoints) - app.router.add_post('/api/checkpoints/fetch-all-civitai', self.fetch_all_civitai) - app.router.add_get('/api/checkpoints/base-models', self.get_base_models) - app.router.add_get('/api/checkpoints/top-tags', self.get_top_tags) - app.router.add_get('/api/checkpoints/scan', self.scan_checkpoints) - app.router.add_get('/api/checkpoints/info/{name}', self.get_checkpoint_info) - app.router.add_get('/api/checkpoints/roots', self.get_checkpoint_roots) - app.router.add_get('/api/checkpoints/civitai/versions/{model_id}', self.get_civitai_versions) # Add new route - - # Add new routes for model management similar to LoRA routes - app.router.add_post('/api/checkpoints/delete', self.delete_model) - app.router.add_post('/api/checkpoints/exclude', self.exclude_model) # Add new exclude endpoint - app.router.add_post('/api/checkpoints/fetch-civitai', self.fetch_civitai) - app.router.add_post('/api/checkpoints/relink-civitai', self.relink_civitai) # Add new relink endpoint - app.router.add_post('/api/checkpoints/replace-preview', self.replace_preview) - app.router.add_post('/api/checkpoints/save-metadata', self.save_metadata) # Add new route - app.router.add_post('/api/checkpoints/rename', self.rename_checkpoint) # Add new rename endpoint - - # Add new routes for finding duplicates and filename conflicts - app.router.add_get('/api/checkpoints/find-duplicates', self.find_duplicate_checkpoints) - app.router.add_get('/api/checkpoints/find-filename-conflicts', self.find_filename_conflicts) - - # Add new endpoint for bulk deleting checkpoints - app.router.add_post('/api/checkpoints/bulk-delete', self.bulk_delete_checkpoints) - - # Add new endpoint for verifying duplicates - app.router.add_post('/api/checkpoints/verify-duplicates', self.verify_duplicates) - - async def get_checkpoints(self, request): - """Get paginated checkpoint data""" - try: - # Parse query parameters - page = int(request.query.get('page', '1')) - page_size = min(int(request.query.get('page_size', '20')), 100) - sort_by = request.query.get('sort_by', 'name') - folder = request.query.get('folder', None) - search = request.query.get('search', None) - fuzzy_search = request.query.get('fuzzy_search', 'false').lower() == 'true' - base_models = request.query.getall('base_model', []) - tags = request.query.getall('tag', []) - favorites_only = request.query.get('favorites_only', 'false').lower() == 'true' # Add favorites_only parameter - - # Process search options - search_options = { - 'filename': request.query.get('search_filename', 'true').lower() == 'true', - 'modelname': request.query.get('search_modelname', 'true').lower() == 'true', - 'tags': request.query.get('search_tags', 'false').lower() == 'true', - 'recursive': request.query.get('recursive', 'false').lower() == 'true', - } - - # Process hash filters if provided - hash_filters = {} - if 'hash' in request.query: - hash_filters['single_hash'] = request.query['hash'] - elif 'hashes' in request.query: - try: - hash_list = json.loads(request.query['hashes']) - if isinstance(hash_list, list): - hash_filters['multiple_hashes'] = hash_list - except (json.JSONDecodeError, TypeError): - pass - - # Get data from scanner - result = await self.get_paginated_data( - page=page, - page_size=page_size, - sort_by=sort_by, - folder=folder, - search=search, - fuzzy_search=fuzzy_search, - base_models=base_models, - tags=tags, - search_options=search_options, - hash_filters=hash_filters, - favorites_only=favorites_only # Pass favorites_only parameter - ) - - # Format response items - formatted_result = { - 'items': [self._format_checkpoint_response(cp) for cp in result['items']], - 'total': result['total'], - 'page': result['page'], - 'page_size': result['page_size'], - 'total_pages': result['total_pages'] - } - - # Return as JSON - return web.json_response(formatted_result) - - except Exception as e: - logger.error(f"Error in get_checkpoints: {e}", exc_info=True) - return web.json_response({"error": str(e)}, status=500) - - async def get_paginated_data(self, page, page_size, sort_by='name', - folder=None, search=None, fuzzy_search=False, - base_models=None, tags=None, - search_options=None, hash_filters=None, - favorites_only=False): # Add favorites_only parameter with default False - """Get paginated and filtered checkpoint data""" - cache = await self.scanner.get_cached_data() - - # Get default search options if not provided - if search_options is None: - search_options = { - 'filename': True, - 'modelname': True, - 'tags': False, - 'recursive': False, - } - - # Get the base data set - filtered_data = cache.sorted_by_date if sort_by == 'date' else cache.sorted_by_name - - # Apply hash filtering if provided (highest priority) - if hash_filters: - single_hash = hash_filters.get('single_hash') - multiple_hashes = hash_filters.get('multiple_hashes') - - if single_hash: - # Filter by single hash - single_hash = single_hash.lower() # Ensure lowercase for matching - filtered_data = [ - cp for cp in filtered_data - if cp.get('sha256', '').lower() == single_hash - ] - elif multiple_hashes: - # Filter by multiple hashes - hash_set = set(hash.lower() for hash in multiple_hashes) # Convert to set for faster lookup - filtered_data = [ - cp for cp in filtered_data - if cp.get('sha256', '').lower() in hash_set - ] - - # Jump to pagination - total_items = len(filtered_data) - start_idx = (page - 1) * page_size - end_idx = min(start_idx + page_size, total_items) - - result = { - 'items': filtered_data[start_idx:end_idx], - 'total': total_items, - 'page': page, - 'page_size': page_size, - 'total_pages': (total_items + page_size - 1) // page_size - } - - return result - - # Apply SFW filtering if enabled in settings - if settings.get('show_only_sfw', False): - filtered_data = [ - cp for cp in filtered_data - if not cp.get('preview_nsfw_level') or cp.get('preview_nsfw_level') < NSFW_LEVELS['R'] - ] - - # Apply favorites filtering if enabled - if favorites_only: - filtered_data = [ - cp for cp in filtered_data - if cp.get('favorite', False) is True - ] - - # Apply folder filtering - if folder is not None: - if search_options.get('recursive', False): - # Recursive folder filtering - include all subfolders - filtered_data = [ - cp for cp in filtered_data - if cp['folder'].startswith(folder) - ] - else: - # Exact folder filtering - filtered_data = [ - cp for cp in filtered_data - if cp['folder'] == folder - ] - - # Apply base model filtering - if base_models and len(base_models) > 0: - filtered_data = [ - cp for cp in filtered_data - if cp.get('base_model') in base_models - ] - - # Apply tag filtering - if tags and len(tags) > 0: - filtered_data = [ - cp for cp in filtered_data - if any(tag in cp.get('tags', []) for tag in tags) - ] - - # Apply search filtering - if search: - search_results = [] - - for cp in filtered_data: - # Search by file name - if search_options.get('filename', True): - if fuzzy_search: - if fuzzy_match(cp.get('file_name', ''), search): - search_results.append(cp) - continue - elif search.lower() in cp.get('file_name', '').lower(): - search_results.append(cp) - continue - - # Search by model name - if search_options.get('modelname', True): - if fuzzy_search: - if fuzzy_match(cp.get('model_name', ''), search): - search_results.append(cp) - continue - elif search.lower() in cp.get('model_name', '').lower(): - search_results.append(cp) - continue - - # Search by tags - if search_options.get('tags', False) and 'tags' in cp: - if any((fuzzy_match(tag, search) if fuzzy_search else search.lower() in tag.lower()) for tag in cp['tags']): - search_results.append(cp) - continue - - filtered_data = search_results - - # Calculate pagination - total_items = len(filtered_data) - start_idx = (page - 1) * page_size - end_idx = min(start_idx + page_size, total_items) - - result = { - 'items': filtered_data[start_idx:end_idx], - 'total': total_items, - 'page': page, - 'page_size': page_size, - 'total_pages': (total_items + page_size - 1) // page_size - } - - return result - - def _format_checkpoint_response(self, checkpoint): - """Format checkpoint data for API response""" - return { - "model_name": checkpoint["model_name"], - "file_name": checkpoint["file_name"], - "preview_url": config.get_preview_static_url(checkpoint.get("preview_url", "")), - "preview_nsfw_level": checkpoint.get("preview_nsfw_level", 0), - "base_model": checkpoint.get("base_model", ""), - "folder": checkpoint["folder"], - "sha256": checkpoint.get("sha256", ""), - "file_path": checkpoint["file_path"].replace(os.sep, "/"), - "file_size": checkpoint.get("size", 0), - "modified": checkpoint.get("modified", ""), - "tags": checkpoint.get("tags", []), - "modelDescription": checkpoint.get("modelDescription", ""), - "from_civitai": checkpoint.get("from_civitai", True), - "notes": checkpoint.get("notes", ""), - "model_type": checkpoint.get("model_type", "checkpoint"), - "favorite": checkpoint.get("favorite", False), - "civitai": ModelRouteUtils.filter_civitai_data(checkpoint.get("civitai", {})) - } - - async def fetch_all_civitai(self, request: web.Request) -> web.Response: - """Fetch CivitAI metadata for all checkpoints in the background""" - try: - cache = await self.scanner.get_cached_data() - total = len(cache.raw_data) - processed = 0 - success = 0 - needs_resort = False - - # Prepare checkpoints to process - to_process = [ - cp for cp in cache.raw_data - if cp.get('sha256') and (not cp.get('civitai') or 'id' not in cp.get('civitai')) and cp.get('from_civitai', True) - ] - total_to_process = len(to_process) - - # Send initial progress - await ws_manager.broadcast({ - 'status': 'started', - 'total': total_to_process, - 'processed': 0, - 'success': 0 - }) - - # Process each checkpoint - for cp in to_process: - try: - original_name = cp.get('model_name') - if await ModelRouteUtils.fetch_and_update_model( - sha256=cp['sha256'], - file_path=cp['file_path'], - model_data=cp, - update_cache_func=self.scanner.update_single_model_cache - ): - success += 1 - if original_name != cp.get('model_name'): - needs_resort = True - - processed += 1 - - # Send progress update - await ws_manager.broadcast({ - 'status': 'processing', - 'total': total_to_process, - 'processed': processed, - 'success': success, - 'current_name': cp.get('model_name', 'Unknown') - }) - - except Exception as e: - logger.error(f"Error fetching CivitAI data for {cp['file_path']}: {e}") - - if needs_resort: - await cache.resort(name_only=True) - - # Send completion message - await ws_manager.broadcast({ - 'status': 'completed', - 'total': total_to_process, - 'processed': processed, - 'success': success - }) - - return web.json_response({ - "success": True, - "message": f"Successfully updated {success} of {processed} processed checkpoints (total: {total})" - }) - - except Exception as e: - # Send error message - await ws_manager.broadcast({ - 'status': 'error', - 'error': str(e) - }) - logger.error(f"Error in fetch_all_civitai for checkpoints: {e}") - return web.Response(text=str(e), status=500) - - async def get_top_tags(self, request: web.Request) -> web.Response: - """Handle request for top tags sorted by frequency""" - try: - # Parse query parameters - limit = int(request.query.get('limit', '20')) - - # Validate limit - if limit < 1 or limit > 100: - limit = 20 # Default to a reasonable limit - - # Get top tags - top_tags = await self.scanner.get_top_tags(limit) - - return web.json_response({ - 'success': True, - 'tags': top_tags - }) - - except Exception as e: - logger.error(f"Error getting top tags: {str(e)}", exc_info=True) - return web.json_response({ - 'success': False, - 'error': 'Internal server error' - }, status=500) - - async def get_base_models(self, request: web.Request) -> web.Response: - """Get base models used in loras""" - try: - # Parse query parameters - limit = int(request.query.get('limit', '20')) - - # Validate limit - if limit < 1 or limit > 100: - limit = 20 # Default to a reasonable limit - - # Get base models - base_models = await self.scanner.get_base_models(limit) - - return web.json_response({ - 'success': True, - 'base_models': base_models - }) - except Exception as e: - logger.error(f"Error retrieving base models: {e}") - return web.json_response({ - 'success': False, - 'error': str(e) - }, status=500) - - async def scan_checkpoints(self, request): - """Force a rescan of checkpoint files""" - try: - # Get the full_rebuild parameter and convert to bool, default to False - full_rebuild = request.query.get('full_rebuild', 'false').lower() == 'true' - - await self.scanner.get_cached_data(force_refresh=True, rebuild_cache=full_rebuild) - return web.json_response({"status": "success", "message": "Checkpoint scan completed"}) - except Exception as e: - logger.error(f"Error in scan_checkpoints: {e}", exc_info=True) - return web.json_response({"error": str(e)}, status=500) - - async def get_checkpoint_info(self, request): - """Get detailed information for a specific checkpoint by name""" - try: - name = request.match_info.get('name', '') - checkpoint_info = await self.scanner.get_model_info_by_name(name) - - if checkpoint_info: - return web.json_response(checkpoint_info) - else: - return web.json_response({"error": "Checkpoint not found"}, status=404) - - except Exception as e: - logger.error(f"Error in get_checkpoint_info: {e}", exc_info=True) - return web.json_response({"error": str(e)}, status=500) - - async def handle_checkpoints_page(self, request: web.Request) -> web.Response: - """Handle GET /checkpoints request""" - try: - # Check if the CheckpointScanner is initializing - # It's initializing if the cache object doesn't exist yet, - # OR if the scanner explicitly says it's initializing (background task running). - is_initializing = ( - self.scanner._cache is None or - (hasattr(self.scanner, '_is_initializing') and self.scanner._is_initializing) - ) - - if is_initializing: - # If still initializing, return loading page - template = self.template_env.get_template('checkpoints.html') - rendered = template.render( - folders=[], # 空文件夹列表 - is_initializing=True, # 新增标志 - settings=settings, # Pass settings to template - request=request # Pass the request object to the template - ) - - logger.info("Checkpoints page is initializing, returning loading page") - else: - # 正常流程 - 获取已经初始化好的缓存数据 - try: - cache = await self.scanner.get_cached_data(force_refresh=False) - template = self.template_env.get_template('checkpoints.html') - rendered = template.render( - folders=cache.folders, - is_initializing=False, - settings=settings, # Pass settings to template - request=request # Pass the request object to the template - ) - except Exception as cache_error: - logger.error(f"Error loading checkpoints cache data: {cache_error}") - # 如果获取缓存失败,也显示初始化页面 - template = self.template_env.get_template('checkpoints.html') - rendered = template.render( - folders=[], - is_initializing=True, - settings=settings, - request=request - ) - logger.info("Checkpoints cache error, returning initialization page") - - return web.Response( - text=rendered, - content_type='text/html' - ) - except Exception as e: - logger.error(f"Error handling checkpoints request: {e}", exc_info=True) - return web.Response( - text="Error loading checkpoints page", - status=500 - ) - - async def delete_model(self, request: web.Request) -> web.Response: - """Handle checkpoint model deletion request""" - return await ModelRouteUtils.handle_delete_model(request, self.scanner) - - async def exclude_model(self, request: web.Request) -> web.Response: - """Handle checkpoint model exclusion request""" - return await ModelRouteUtils.handle_exclude_model(request, self.scanner) - - async def fetch_civitai(self, request: web.Request) -> web.Response: - """Handle CivitAI metadata fetch request for checkpoints""" - response = await ModelRouteUtils.handle_fetch_civitai(request, self.scanner) - - # If successful, format the metadata before returning - if response.status == 200: - data = json.loads(response.body.decode('utf-8')) - if data.get("success") and data.get("metadata"): - formatted_metadata = self._format_checkpoint_response(data["metadata"]) - return web.json_response({ - "success": True, - "metadata": formatted_metadata - }) - - # Otherwise, return the original response - return response - - async def replace_preview(self, request: web.Request) -> web.Response: - """Handle preview image replacement for checkpoints""" - return await ModelRouteUtils.handle_replace_preview(request, self.scanner) - - async def get_checkpoint_roots(self, request): - """Return the checkpoint root directories""" - try: - if self.scanner is None: - self.scanner = await ServiceRegistry.get_checkpoint_scanner() - - roots = self.scanner.get_model_roots() - return web.json_response({ - "success": True, - "roots": roots - }) - except Exception as e: - logger.error(f"Error getting checkpoint roots: {e}", exc_info=True) - return web.json_response({ - "success": False, - "error": str(e) - }, status=500) - - async def save_metadata(self, request: web.Request) -> web.Response: - """Handle saving metadata updates for checkpoints""" - try: - if self.scanner is None: - self.scanner = await ServiceRegistry.get_checkpoint_scanner() - - data = await request.json() - file_path = data.get('file_path') - if not file_path: - return web.Response(text='File path is required', status=400) - - # Remove file path from data to avoid saving it - metadata_updates = {k: v for k, v in data.items() if k != 'file_path'} - - # Get metadata file path - metadata_path = os.path.splitext(file_path)[0] + '.metadata.json' - - # Load existing metadata - metadata = await ModelRouteUtils.load_local_metadata(metadata_path) - - # Update metadata - metadata.update(metadata_updates) - - # Save updated metadata - await MetadataManager.save_metadata(file_path, metadata) - - # Update cache - await self.scanner.update_single_model_cache(file_path, file_path, metadata) - - # If model_name was updated, resort the cache - if 'model_name' in metadata_updates: - cache = await self.scanner.get_cached_data() - await cache.resort(name_only=True) - - return web.json_response({'success': True}) - - except Exception as e: - logger.error(f"Error saving checkpoint metadata: {e}", exc_info=True) - return web.Response(text=str(e), status=500) - - async def get_civitai_versions(self, request: web.Request) -> web.Response: - """Get available versions for a Civitai checkpoint model with local availability info""" - try: - if self.scanner is None: - self.scanner = await ServiceRegistry.get_checkpoint_scanner() - - # Get the civitai client from service registry - civitai_client = await ServiceRegistry.get_civitai_client() - - model_id = request.match_info['model_id'] - response = await civitai_client.get_model_versions(model_id) - if not response or not response.get('modelVersions'): - return web.Response(status=404, text="Model not found") - - versions = response.get('modelVersions', []) - model_type = response.get('type', '') - - # Check model type - should be Checkpoint - if (model_type.lower() != 'checkpoint'): - return web.json_response({ - 'error': f"Model type mismatch. Expected Checkpoint, got {model_type}" - }, status=400) - - # Check local availability for each version - for version in versions: - # Find the primary model file (type="Model" and primary=true) in the files list - model_file = next((file for file in version.get('files', []) - if file.get('type') == 'Model' and file.get('primary') == True), None) - - # If no primary file found, try to find any model file - if not model_file: - model_file = next((file for file in version.get('files', []) - if file.get('type') == 'Model'), None) - - if model_file: - sha256 = model_file.get('hashes', {}).get('SHA256') - if sha256: - # Set existsLocally and localPath at the version level - version['existsLocally'] = self.scanner.has_hash(sha256) - if version['existsLocally']: - version['localPath'] = self.scanner.get_path_by_hash(sha256) - - # Also set the model file size at the version level for easier access - version['modelSizeKB'] = model_file.get('sizeKB') - else: - # No model file found in this version - version['existsLocally'] = False - - return web.json_response(versions) - except Exception as e: - logger.error(f"Error fetching checkpoint model versions: {e}") - return web.Response(status=500, text=str(e)) - - async def find_duplicate_checkpoints(self, request: web.Request) -> web.Response: - """Find checkpoints with duplicate SHA256 hashes""" - try: - if self.scanner is None: - self.scanner = await ServiceRegistry.get_checkpoint_scanner() - - # Get duplicate hashes from hash index - duplicates = self.scanner._hash_index.get_duplicate_hashes() - - # Format the response - result = [] - cache = await self.scanner.get_cached_data() - - for sha256, paths in duplicates.items(): - group = { - "hash": sha256, - "models": [] - } - # Find matching models for each path - for path in paths: - model = next((m for m in cache.raw_data if m['file_path'] == path), None) - if model: - group["models"].append(self._format_checkpoint_response(model)) - - # Add the primary model too - primary_path = self.scanner._hash_index.get_path(sha256) - if primary_path and primary_path not in paths: - primary_model = next((m for m in cache.raw_data if m['file_path'] == primary_path), None) - if primary_model: - group["models"].insert(0, self._format_checkpoint_response(primary_model)) - - if len(group["models"]) > 1: # Only include if we found multiple models - result.append(group) - - return web.json_response({ - "success": True, - "duplicates": result, - "count": len(result) - }) - except Exception as e: - logger.error(f"Error finding duplicate checkpoints: {e}", exc_info=True) - return web.json_response({ - "success": False, - "error": str(e) - }, status=500) - - async def find_filename_conflicts(self, request: web.Request) -> web.Response: - """Find checkpoints with conflicting filenames""" - try: - if self.scanner is None: - self.scanner = await ServiceRegistry.get_checkpoint_scanner() - - # Get duplicate filenames from hash index - duplicates = self.scanner._hash_index.get_duplicate_filenames() - - # Format the response - result = [] - cache = await self.scanner.get_cached_data() - - for filename, paths in duplicates.items(): - group = { - "filename": filename, - "models": [] - } - # Find matching models for each path - for path in paths: - model = next((m for m in cache.raw_data if m['file_path'] == path), None) - if model: - group["models"].append(self._format_checkpoint_response(model)) - - # Find the model from the main index too - hash_val = self.scanner._hash_index.get_hash_by_filename(filename) - if hash_val: - main_path = self.scanner._hash_index.get_path(hash_val) - if main_path and main_path not in paths: - main_model = next((m for m in cache.raw_data if m['file_path'] == main_path), None) - if main_model: - group["models"].insert(0, self._format_checkpoint_response(main_model)) - - if group["models"]: - result.append(group) - - return web.json_response({ - "success": True, - "conflicts": result, - "count": len(result) - }) - except Exception as e: - logger.error(f"Error finding filename conflicts: {e}", exc_info=True) - return web.json_response({ - "success": False, - "error": str(e) - }, status=500) - - async def bulk_delete_checkpoints(self, request: web.Request) -> web.Response: - """Handle bulk deletion of checkpoint models""" - try: - if self.scanner is None: - self.scanner = await ServiceRegistry.get_checkpoint_scanner() - - return await ModelRouteUtils.handle_bulk_delete_models(request, self.scanner) - - except Exception as e: - logger.error(f"Error in bulk delete checkpoints: {e}", exc_info=True) - return web.json_response({ - 'success': False, - 'error': str(e) - }, status=500) - - async def relink_civitai(self, request: web.Request) -> web.Response: - """Handle CivitAI metadata re-linking request by model version ID for checkpoints""" - return await ModelRouteUtils.handle_relink_civitai(request, self.scanner) - - async def verify_duplicates(self, request: web.Request) -> web.Response: - """Handle verification of duplicate checkpoint hashes""" - return await ModelRouteUtils.handle_verify_duplicates(request, self.scanner) - - async def rename_checkpoint(self, request: web.Request) -> web.Response: - """Handle renaming a checkpoint file and its associated files""" - return await ModelRouteUtils.handle_rename_model(request, self.scanner) diff --git a/py/routes/lora_routes.py b/py/routes/lora_routes.py index 1d00b66e..ac7dc4ed 100644 --- a/py/routes/lora_routes.py +++ b/py/routes/lora_routes.py @@ -1,188 +1,512 @@ -import os -from aiohttp import web -import jinja2 -from typing import Dict +import asyncio import logging -from ..config import config -from ..services.settings_manager import settings -from ..services.service_registry import ServiceRegistry # Add ServiceRegistry import +from aiohttp import web +from typing import Dict +from server import PromptServer # type: ignore + +from .base_model_routes import BaseModelRoutes +from ..services.lora_service import LoraService +from ..services.service_registry import ServiceRegistry +from ..utils.routes_common import ModelRouteUtils +from ..utils.utils import get_lora_info logger = logging.getLogger(__name__) -logging.getLogger('asyncio').setLevel(logging.CRITICAL) -class LoraRoutes: - """Route handlers for LoRA management endpoints""" +class LoraRoutes(BaseModelRoutes): + """LoRA-specific route controller""" def __init__(self): - # Initialize service references as None, will be set during async init - self.scanner = None - self.recipe_scanner = None - self.template_env = jinja2.Environment( - loader=jinja2.FileSystemLoader(config.templates_path), - autoescape=True - ) - - async def init_services(self): - """Initialize services from ServiceRegistry""" - self.scanner = await ServiceRegistry.get_lora_scanner() - self.recipe_scanner = await ServiceRegistry.get_recipe_scanner() + """Initialize LoRA routes with LoRA service""" + # Service will be initialized later via setup_routes + self.service = None + self.civitai_client = None + self.template_name = "loras.html" + + async def initialize_services(self): + """Initialize services from ServiceRegistry""" + lora_scanner = await ServiceRegistry.get_lora_scanner() + self.service = LoraService(lora_scanner) + self.civitai_client = await ServiceRegistry.get_civitai_client() + + # Initialize parent with the service + super().__init__(self.service) - def format_lora_data(self, lora: Dict) -> Dict: - """Format LoRA data for template rendering""" - return { - "model_name": lora["model_name"], - "file_name": lora["file_name"], - "preview_url": config.get_preview_static_url(lora["preview_url"]), - "preview_nsfw_level": lora.get("preview_nsfw_level", 0), - "base_model": lora["base_model"], - "folder": lora["folder"], - "sha256": lora["sha256"], - "file_path": lora["file_path"].replace(os.sep, "/"), - "size": lora["size"], - "tags": lora["tags"], - "modelDescription": lora["modelDescription"], - "usage_tips": lora["usage_tips"], - "notes": lora["notes"], - "modified": lora["modified"], - "from_civitai": lora.get("from_civitai", True), - "civitai": self._filter_civitai_data(lora.get("civitai", {})) - } - - def _filter_civitai_data(self, data: Dict) -> Dict: - """Filter relevant fields from CivitAI data""" - if not data: - return {} - - fields = [ - "id", "modelId", "name", "createdAt", "updatedAt", - "publishedAt", "trainedWords", "baseModel", "description", - "model", "images" - ] - return {k: data[k] for k in fields if k in data} - - async def handle_loras_page(self, request: web.Request) -> web.Response: - """Handle GET /loras request""" - try: - # Ensure services are initialized - await self.init_services() - - # Check if the LoraScanner is initializing - # It's initializing if the cache object doesn't exist yet, - # OR if the scanner explicitly says it's initializing (background task running). - is_initializing = ( - self.scanner._cache is None or self.scanner.is_initializing() - ) - - if is_initializing: - # If still initializing, return loading page - template = self.template_env.get_template('loras.html') - rendered = template.render( - folders=[], - is_initializing=True, - settings=settings, - request=request - ) - - logger.info("Loras page is initializing, returning loading page") - else: - # Normal flow - get data from initialized cache - try: - cache = await self.scanner.get_cached_data(force_refresh=False) - template = self.template_env.get_template('loras.html') - rendered = template.render( - folders=cache.folders, - is_initializing=False, - settings=settings, - request=request - ) - except Exception as cache_error: - logger.error(f"Error loading cache data: {cache_error}") - template = self.template_env.get_template('loras.html') - rendered = template.render( - folders=[], - is_initializing=True, - settings=settings, - request=request - ) - logger.info("Cache error, returning initialization page") - - return web.Response( - text=rendered, - content_type='text/html' - ) - - except Exception as e: - logger.error(f"Error handling loras request: {e}", exc_info=True) - return web.Response( - text="Error loading loras page", - status=500 - ) - - async def handle_recipes_page(self, request: web.Request) -> web.Response: - """Handle GET /loras/recipes request""" - try: - # Ensure services are initialized - await self.init_services() - - # Skip initialization check and directly try to get cached data - try: - # Recipe scanner will initialize cache if needed - await self.recipe_scanner.get_cached_data(force_refresh=False) - template = self.template_env.get_template('recipes.html') - rendered = template.render( - recipes=[], # Frontend will load recipes via API - is_initializing=False, - settings=settings, - request=request - ) - except Exception as cache_error: - logger.error(f"Error loading recipe cache data: {cache_error}") - # Still keep error handling - show initializing page on error - template = self.template_env.get_template('recipes.html') - rendered = template.render( - is_initializing=True, - settings=settings, - request=request - ) - logger.info("Recipe cache error, returning initialization page") - - return web.Response( - text=rendered, - content_type='text/html' - ) - - except Exception as e: - logger.error(f"Error handling recipes request: {e}", exc_info=True) - return web.Response( - text="Error loading recipes page", - status=500 - ) - - def _format_recipe_file_url(self, file_path: str) -> str: - """Format file path for recipe image as a URL - same as in recipe_routes""" - try: - # Return the file URL directly for the first lora root's preview - recipes_dir = os.path.join(config.loras_roots[0], "recipes").replace(os.sep, '/') - if file_path.replace(os.sep, '/').startswith(recipes_dir): - relative_path = os.path.relpath(file_path, config.loras_roots[0]).replace(os.sep, '/') - return f"/loras_static/root1/preview/{relative_path}" - - # If not in recipes dir, try to create a valid URL from the file path - file_name = os.path.basename(file_path) - return f"/loras_static/root1/preview/recipes/{file_name}" - except Exception as e: - logger.error(f"Error formatting recipe file URL: {e}", exc_info=True) - return '/loras_static/images/no-preview.png' # Return default image on error - def setup_routes(self, app: web.Application): - """Register routes with the application""" - # Add an app startup handler to initialize services - app.on_startup.append(self._on_startup) + """Setup LoRA routes""" + # Schedule service initialization on app startup + app.on_startup.append(lambda _: self.initialize_services()) - # Register routes - app.router.add_get('/loras', self.handle_loras_page) - app.router.add_get('/loras/recipes', self.handle_recipes_page) + # Setup common routes with 'loras' prefix (includes page route) + super().setup_routes(app, 'loras') + + def setup_specific_routes(self, app: web.Application, prefix: str): + """Setup LoRA-specific routes""" + # LoRA-specific query routes + app.router.add_get(f'/api/{prefix}/letter-counts', self.get_letter_counts) + app.router.add_get(f'/api/{prefix}/get-notes', self.get_lora_notes) + app.router.add_get(f'/api/{prefix}/get-trigger-words', self.get_lora_trigger_words) + app.router.add_get(f'/api/lora-preview-url', self.get_lora_preview_url) + app.router.add_get(f'/api/lora-civitai-url', self.get_lora_civitai_url) + app.router.add_get(f'/api/lora-model-description', self.get_lora_model_description) - async def _on_startup(self, app): - """Initialize services when the app starts""" - await self.init_services() + # LoRA-specific management routes + app.router.add_post(f'/api/move_model', self.move_model) + app.router.add_post(f'/api/move_models_bulk', self.move_models_bulk) + + # CivitAI integration with LoRA-specific validation + app.router.add_get(f'/api/{prefix}/civitai/versions/{{model_id}}', self.get_civitai_versions_lora) + app.router.add_get(f'/api/civitai/model/version/{{modelVersionId}}', self.get_civitai_model_by_version) + app.router.add_get(f'/api/civitai/model/hash/{{hash}}', self.get_civitai_model_by_hash) + + # ComfyUI integration + app.router.add_post(f'/loramanager/get_trigger_words', self.get_trigger_words) + + def _parse_specific_params(self, request: web.Request) -> Dict: + """Parse LoRA-specific parameters""" + params = {} + + # LoRA-specific parameters + if 'first_letter' in request.query: + params['first_letter'] = request.query.get('first_letter') + + # Handle fuzzy search parameter name variation + if request.query.get('fuzzy') == 'true': + params['fuzzy_search'] = True + + # Handle additional filter parameters for LoRAs + if 'lora_hash' in request.query: + if not params.get('hash_filters'): + params['hash_filters'] = {} + params['hash_filters']['single_hash'] = request.query['lora_hash'].lower() + elif 'lora_hashes' in request.query: + if not params.get('hash_filters'): + params['hash_filters'] = {} + params['hash_filters']['multiple_hashes'] = [h.lower() for h in request.query['lora_hashes'].split(',')] + + return params + + # LoRA-specific route handlers + async def get_letter_counts(self, request: web.Request) -> web.Response: + """Get count of LoRAs for each letter of the alphabet""" + try: + letter_counts = await self.service.get_letter_counts() + return web.json_response({ + 'success': True, + 'letter_counts': letter_counts + }) + except Exception as e: + logger.error(f"Error getting letter counts: {e}") + return web.json_response({ + 'success': False, + 'error': str(e) + }, status=500) + + async def get_lora_notes(self, request: web.Request) -> web.Response: + """Get notes for a specific LoRA file""" + try: + lora_name = request.query.get('name') + if not lora_name: + return web.Response(text='Lora file name is required', status=400) + + notes = await self.service.get_lora_notes(lora_name) + if notes is not None: + return web.json_response({ + 'success': True, + 'notes': notes + }) + else: + return web.json_response({ + 'success': False, + 'error': 'LoRA not found in cache' + }, status=404) + + except Exception as e: + logger.error(f"Error getting lora notes: {e}", exc_info=True) + return web.json_response({ + 'success': False, + 'error': str(e) + }, status=500) + + async def get_lora_trigger_words(self, request: web.Request) -> web.Response: + """Get trigger words for a specific LoRA file""" + try: + lora_name = request.query.get('name') + if not lora_name: + return web.Response(text='Lora file name is required', status=400) + + trigger_words = await self.service.get_lora_trigger_words(lora_name) + return web.json_response({ + 'success': True, + 'trigger_words': trigger_words + }) + + except Exception as e: + logger.error(f"Error getting lora trigger words: {e}", exc_info=True) + return web.json_response({ + 'success': False, + 'error': str(e) + }, status=500) + + async def get_lora_preview_url(self, request: web.Request) -> web.Response: + """Get the static preview URL for a LoRA file""" + try: + lora_name = request.query.get('name') + if not lora_name: + return web.Response(text='Lora file name is required', status=400) + + preview_url = await self.service.get_lora_preview_url(lora_name) + if preview_url: + return web.json_response({ + 'success': True, + 'preview_url': preview_url + }) + else: + return web.json_response({ + 'success': False, + 'error': 'No preview URL found for the specified lora' + }, status=404) + + except Exception as e: + logger.error(f"Error getting lora preview URL: {e}", exc_info=True) + return web.json_response({ + 'success': False, + 'error': str(e) + }, status=500) + + async def get_lora_civitai_url(self, request: web.Request) -> web.Response: + """Get the Civitai URL for a LoRA file""" + try: + lora_name = request.query.get('name') + if not lora_name: + return web.Response(text='Lora file name is required', status=400) + + result = await self.service.get_lora_civitai_url(lora_name) + if result['civitai_url']: + return web.json_response({ + 'success': True, + **result + }) + else: + return web.json_response({ + 'success': False, + 'error': 'No Civitai data found for the specified lora' + }, status=404) + + except Exception as e: + logger.error(f"Error getting lora Civitai URL: {e}", exc_info=True) + return web.json_response({ + 'success': False, + 'error': str(e) + }, status=500) + + # Override get_models to add LoRA-specific response data + async def get_models(self, request: web.Request) -> web.Response: + """Get paginated LoRA data with LoRA-specific fields""" + try: + # Parse common query parameters + params = self._parse_common_params(request) + + # Get data from service + result = await self.service.get_paginated_data(**params) + + # Get all available folders from cache for LoRA-specific response + cache = await self.service.scanner.get_cached_data() + + # Format response items with LoRA-specific structure + formatted_result = { + 'items': [await self.service.format_response(item) for item in result['items']], + 'folders': cache.folders, # LoRA-specific: include folders in response + 'total': result['total'], + 'page': result['page'], + 'page_size': result['page_size'], + 'total_pages': result['total_pages'] + } + + return web.json_response(formatted_result) + + except Exception as e: + logger.error(f"Error in get_loras: {e}", exc_info=True) + return web.json_response({"error": str(e)}, status=500) + + # CivitAI integration methods + async def get_civitai_versions_lora(self, request: web.Request) -> web.Response: + """Get available versions for a Civitai LoRA model with local availability info""" + try: + model_id = request.match_info['model_id'] + response = await self.civitai_client.get_model_versions(model_id) + if not response or not response.get('modelVersions'): + return web.Response(status=404, text="Model not found") + + versions = response.get('modelVersions', []) + model_type = response.get('type', '') + + # Check model type - should be LORA, LoCon, or DORA + from ..utils.constants import VALID_LORA_TYPES + if model_type.lower() not in VALID_LORA_TYPES: + return web.json_response({ + 'error': f"Model type mismatch. Expected LORA or LoCon, got {model_type}" + }, status=400) + + # Check local availability for each version + for version in versions: + # Find the model file (type="Model") in the files list + model_file = next((file for file in version.get('files', []) + if file.get('type') == 'Model'), None) + + if model_file: + sha256 = model_file.get('hashes', {}).get('SHA256') + if sha256: + # Set existsLocally and localPath at the version level + version['existsLocally'] = self.service.has_hash(sha256) + if version['existsLocally']: + version['localPath'] = self.service.get_path_by_hash(sha256) + + # Also set the model file size at the version level for easier access + version['modelSizeKB'] = model_file.get('sizeKB') + else: + # No model file found in this version + version['existsLocally'] = False + + return web.json_response(versions) + except Exception as e: + logger.error(f"Error fetching LoRA model versions: {e}") + return web.Response(status=500, text=str(e)) + + async def get_civitai_model_by_version(self, request: web.Request) -> web.Response: + """Get CivitAI model details by model version ID""" + try: + model_version_id = request.match_info.get('modelVersionId') + + # Get model details from Civitai API + model, error_msg = await self.civitai_client.get_model_version_info(model_version_id) + + if not model: + # Log warning for failed model retrieval + logger.warning(f"Failed to fetch model version {model_version_id}: {error_msg}") + + # Determine status code based on error message + status_code = 404 if error_msg and "not found" in error_msg.lower() else 500 + + return web.json_response({ + "success": False, + "error": error_msg or "Failed to fetch model information" + }, status=status_code) + + return web.json_response(model) + except Exception as e: + logger.error(f"Error fetching model details: {e}") + return web.json_response({ + "success": False, + "error": str(e) + }, status=500) + + async def get_civitai_model_by_hash(self, request: web.Request) -> web.Response: + """Get CivitAI model details by hash""" + try: + hash = request.match_info.get('hash') + model = await self.civitai_client.get_model_by_hash(hash) + return web.json_response(model) + except Exception as e: + logger.error(f"Error fetching model details by hash: {e}") + return web.json_response({ + "success": False, + "error": str(e) + }, status=500) + + # Model management methods + async def move_model(self, request: web.Request) -> web.Response: + """Handle model move request""" + try: + data = await request.json() + file_path = data.get('file_path') # full path of the model file + target_path = data.get('target_path') # folder path to move the model to + + if not file_path or not target_path: + return web.Response(text='File path and target path are required', status=400) + + # Check if source and destination are the same + import os + source_dir = os.path.dirname(file_path) + if os.path.normpath(source_dir) == os.path.normpath(target_path): + logger.info(f"Source and target directories are the same: {source_dir}") + return web.json_response({'success': True, 'message': 'Source and target directories are the same'}) + + # Check if target file already exists + file_name = os.path.basename(file_path) + target_file_path = os.path.join(target_path, file_name).replace(os.sep, '/') + + if os.path.exists(target_file_path): + return web.json_response({ + 'success': False, + 'error': f"Target file already exists: {target_file_path}" + }, status=409) # 409 Conflict + + # Call scanner to handle the move operation + success = await self.service.scanner.move_model(file_path, target_path) + + if success: + return web.json_response({'success': True}) + else: + return web.Response(text='Failed to move model', status=500) + + except Exception as e: + logger.error(f"Error moving model: {e}", exc_info=True) + return web.Response(text=str(e), status=500) + + async def move_models_bulk(self, request: web.Request) -> web.Response: + """Handle bulk model move request""" + try: + data = await request.json() + file_paths = data.get('file_paths', []) # list of full paths of the model files + target_path = data.get('target_path') # folder path to move the models to + + if not file_paths or not target_path: + return web.Response(text='File paths and target path are required', status=400) + + results = [] + import os + for file_path in file_paths: + # Check if source and destination are the same + source_dir = os.path.dirname(file_path) + if os.path.normpath(source_dir) == os.path.normpath(target_path): + results.append({ + "path": file_path, + "success": True, + "message": "Source and target directories are the same" + }) + continue + + # Check if target file already exists + file_name = os.path.basename(file_path) + target_file_path = os.path.join(target_path, file_name).replace(os.sep, '/') + + if os.path.exists(target_file_path): + results.append({ + "path": file_path, + "success": False, + "message": f"Target file already exists: {target_file_path}" + }) + continue + + # Try to move the model + success = await self.service.scanner.move_model(file_path, target_path) + results.append({ + "path": file_path, + "success": success, + "message": "Success" if success else "Failed to move model" + }) + + # Count successes and failures + success_count = sum(1 for r in results if r["success"]) + failure_count = len(results) - success_count + + return web.json_response({ + 'success': True, + 'message': f'Moved {success_count} of {len(file_paths)} models', + 'results': results, + 'success_count': success_count, + 'failure_count': failure_count + }) + + except Exception as e: + logger.error(f"Error moving models in bulk: {e}", exc_info=True) + return web.Response(text=str(e), status=500) + + async def get_lora_model_description(self, request: web.Request) -> web.Response: + """Get model description for a Lora model""" + try: + # Get parameters + model_id = request.query.get('model_id') + file_path = request.query.get('file_path') + + if not model_id: + return web.json_response({ + 'success': False, + 'error': 'Model ID is required' + }, status=400) + + # Check if we already have the description stored in metadata + description = None + tags = [] + creator = {} + if file_path: + import os + from ..utils.metadata_manager import MetadataManager + metadata_path = os.path.splitext(file_path)[0] + '.metadata.json' + metadata = await ModelRouteUtils.load_local_metadata(metadata_path) + description = metadata.get('modelDescription') + tags = metadata.get('tags', []) + creator = metadata.get('creator', {}) + + # If description is not in metadata, fetch from CivitAI + if not description: + logger.info(f"Fetching model metadata for model ID: {model_id}") + model_metadata, _ = await self.civitai_client.get_model_metadata(model_id) + + if model_metadata: + description = model_metadata.get('description') + tags = model_metadata.get('tags', []) + creator = model_metadata.get('creator', {}) + + # Save the metadata to file if we have a file path and got metadata + if file_path: + try: + metadata_path = os.path.splitext(file_path)[0] + '.metadata.json' + metadata = await ModelRouteUtils.load_local_metadata(metadata_path) + + metadata['modelDescription'] = description + metadata['tags'] = tags + # Ensure the civitai dict exists + if 'civitai' not in metadata: + metadata['civitai'] = {} + # Store creator in the civitai nested structure + metadata['civitai']['creator'] = creator + + await MetadataManager.save_metadata(file_path, metadata, True) + except Exception as e: + logger.error(f"Error saving model metadata: {e}") + + return web.json_response({ + 'success': True, + 'description': description or "

No model description available.

", + 'tags': tags, + 'creator': creator + }) + + except Exception as e: + logger.error(f"Error getting model metadata: {e}") + return web.json_response({ + 'success': False, + 'error': str(e) + }, status=500) + + async def get_trigger_words(self, request: web.Request) -> web.Response: + """Get trigger words for specified LoRA models""" + try: + json_data = await request.json() + lora_names = json_data.get("lora_names", []) + node_ids = json_data.get("node_ids", []) + + all_trigger_words = [] + for lora_name in lora_names: + _, trigger_words = get_lora_info(lora_name) + all_trigger_words.extend(trigger_words) + + # Format the trigger words + trigger_words_text = ",, ".join(all_trigger_words) if all_trigger_words else "" + + # Send update to all connected trigger word toggle nodes + for node_id in node_ids: + PromptServer.instance.send_sync("trigger_word_update", { + "id": node_id, + "message": trigger_words_text + }) + + return web.json_response({"success": True}) + + except Exception as e: + logger.error(f"Error getting trigger words: {e}") + return web.json_response({ + "success": False, + "error": str(e) + }, status=500) diff --git a/py/routes/misc_routes.py b/py/routes/misc_routes.py index 019aa82e..596e0323 100644 --- a/py/routes/misc_routes.py +++ b/py/routes/misc_routes.py @@ -633,9 +633,8 @@ class MiscRoutes: }, status=400) # Get both lora and checkpoint scanners - registry = ServiceRegistry.get_instance() - lora_scanner = await registry.get_lora_scanner() - checkpoint_scanner = await registry.get_checkpoint_scanner() + lora_scanner = await ServiceRegistry.get_lora_scanner() + checkpoint_scanner = await ServiceRegistry.get_checkpoint_scanner() # If modelVersionId is provided, check for specific version if model_version_id_str: diff --git a/py/routes/recipe_routes.py b/py/routes/recipe_routes.py index 393a5ae3..cf1c67c2 100644 --- a/py/routes/recipe_routes.py +++ b/py/routes/recipe_routes.py @@ -1,6 +1,7 @@ import os import time import base64 +import jinja2 import numpy as np from PIL import Image import io @@ -15,6 +16,7 @@ from ..utils.exif_utils import ExifUtils from ..recipes import RecipeParserFactory from ..utils.constants import CARD_PREVIEW_WIDTH +from ..services.settings_manager import settings from ..config import config # Check if running in standalone mode @@ -39,7 +41,10 @@ class RecipeRoutes: # Initialize service references as None, will be set during async init self.recipe_scanner = None self.civitai_client = None - # Remove WorkflowParser instance + self.template_env = jinja2.Environment( + loader=jinja2.FileSystemLoader(config.templates_path), + autoescape=True + ) # Pre-warm the cache self._init_cache_task = None @@ -53,6 +58,8 @@ class RecipeRoutes: def setup_routes(cls, app: web.Application): """Register API routes""" routes = cls() + app.router.add_get('/loras/recipes', routes.handle_recipes_page) + app.router.add_get('/api/recipes', routes.get_recipes) app.router.add_get('/api/recipe/{recipe_id}', routes.get_recipe_detail) app.router.add_post('/api/recipes/analyze-image', routes.analyze_recipe_image) @@ -114,6 +121,46 @@ class RecipeRoutes: await self.recipe_scanner.get_cached_data(force_refresh=True) except Exception as e: logger.error(f"Error pre-warming recipe cache: {e}", exc_info=True) + + async def handle_recipes_page(self, request: web.Request) -> web.Response: + """Handle GET /loras/recipes request""" + try: + # Ensure services are initialized + await self.init_services() + + # Skip initialization check and directly try to get cached data + try: + # Recipe scanner will initialize cache if needed + await self.recipe_scanner.get_cached_data(force_refresh=False) + template = self.template_env.get_template('recipes.html') + rendered = template.render( + recipes=[], # Frontend will load recipes via API + is_initializing=False, + settings=settings, + request=request + ) + except Exception as cache_error: + logger.error(f"Error loading recipe cache data: {cache_error}") + # Still keep error handling - show initializing page on error + template = self.template_env.get_template('recipes.html') + rendered = template.render( + is_initializing=True, + settings=settings, + request=request + ) + logger.info("Recipe cache error, returning initialization page") + + return web.Response( + text=rendered, + content_type='text/html' + ) + + except Exception as e: + logger.error(f"Error handling recipes request: {e}", exc_info=True) + return web.Response( + text="Error loading recipes page", + status=500 + ) async def get_recipes(self, request: web.Request) -> web.Response: """API endpoint for getting paginated recipes""" @@ -1101,7 +1148,7 @@ class RecipeRoutes: for lora_name, lora_strength in lora_matches: try: # Get lora info from scanner - lora_info = await self.recipe_scanner._lora_scanner.get_lora_info_by_name(lora_name) + lora_info = await self.recipe_scanner._lora_scanner.get_model_info_by_name(lora_name) # Create lora entry lora_entry = { @@ -1120,7 +1167,7 @@ class RecipeRoutes: # Get base model from lora scanner for the available loras base_model_counts = {} for lora in loras_data: - lora_info = await self.recipe_scanner._lora_scanner.get_lora_info_by_name(lora.get("file_name", "")) + lora_info = await self.recipe_scanner._lora_scanner.get_model_info_by_name(lora.get("file_name", "")) if lora_info and "base_model" in lora_info: base_model = lora_info["base_model"] base_model_counts[base_model] = base_model_counts.get(base_model, 0) + 1 @@ -1210,7 +1257,7 @@ class RecipeRoutes: if lora.get("isDeleted", False): continue - if not self.recipe_scanner._lora_scanner.has_lora_hash(lora.get("hash", "")): + if not self.recipe_scanner._lora_scanner.has_hash(lora.get("hash", "")): continue # Get the strength @@ -1318,7 +1365,7 @@ class RecipeRoutes: return web.json_response({"error": "Recipe not found"}, status=404) # Find target LoRA by name - target_lora = await lora_scanner.get_lora_info_by_name(target_name) + target_lora = await lora_scanner.get_model_info_by_name(target_name) if not target_lora: return web.json_response({"error": f"Local LoRA not found with name: {target_name}"}, status=404) @@ -1430,9 +1477,9 @@ class RecipeRoutes: if 'loras' in recipe: for lora in recipe['loras']: if 'hash' in lora and lora['hash']: - lora['inLibrary'] = self.recipe_scanner._lora_scanner.has_lora_hash(lora['hash'].lower()) + lora['inLibrary'] = self.recipe_scanner._lora_scanner.has_hash(lora['hash'].lower()) lora['preview_url'] = self.recipe_scanner._lora_scanner.get_preview_url_by_hash(lora['hash'].lower()) - lora['localPath'] = self.recipe_scanner._lora_scanner.get_lora_path_by_hash(lora['hash'].lower()) + lora['localPath'] = self.recipe_scanner._lora_scanner.get_path_by_hash(lora['hash'].lower()) # Ensure file_url is set (needed by frontend) if 'file_path' in recipe: diff --git a/py/routes/update_routes.py b/py/routes/update_routes.py index 3b45c0f1..0d126cb5 100644 --- a/py/routes/update_routes.py +++ b/py/routes/update_routes.py @@ -1,11 +1,13 @@ import os +import subprocess import aiohttp import logging import toml -import subprocess +import git from datetime import datetime from aiohttp import web -from typing import Dict, Any, List +from typing import Dict, List + logger = logging.getLogger(__name__) @@ -17,6 +19,7 @@ class UpdateRoutes: """Register update check routes""" app.router.add_get('/api/check-updates', UpdateRoutes.check_updates) app.router.add_get('/api/version-info', UpdateRoutes.get_version_info) + app.router.add_post('/api/perform-update', UpdateRoutes.perform_update) @staticmethod async def check_updates(request): @@ -25,6 +28,8 @@ class UpdateRoutes: Returns update status and version information """ try: + nightly = request.query.get('nightly', 'false').lower() == 'true' + # Read local version from pyproject.toml local_version = UpdateRoutes._get_local_version() @@ -32,13 +37,21 @@ class UpdateRoutes: git_info = UpdateRoutes._get_git_info() # Fetch remote version from GitHub - remote_version, changelog = await UpdateRoutes._get_remote_version() + if nightly: + remote_version, changelog = await UpdateRoutes._get_nightly_version() + else: + remote_version, changelog = await UpdateRoutes._get_remote_version() # Compare versions - update_available = UpdateRoutes._compare_versions( - local_version.replace('v', ''), - remote_version.replace('v', '') - ) + if nightly: + # For nightly, compare commit hashes + update_available = UpdateRoutes._compare_nightly_versions(git_info, remote_version) + else: + # For stable, compare semantic versions + update_available = UpdateRoutes._compare_versions( + local_version.replace('v', ''), + remote_version.replace('v', '') + ) return web.json_response({ 'success': True, @@ -46,7 +59,8 @@ class UpdateRoutes: 'latest_version': remote_version, 'update_available': update_available, 'changelog': changelog, - 'git_info': git_info + 'git_info': git_info, + 'nightly': nightly }) except Exception as e: @@ -55,7 +69,7 @@ class UpdateRoutes: 'success': False, 'error': str(e) }) - + @staticmethod async def get_version_info(request): """ @@ -84,6 +98,168 @@ class UpdateRoutes: 'error': str(e) }) + @staticmethod + async def perform_update(request): + """ + Perform Git-based update to latest release tag or main branch + """ + try: + # Parse request body + body = await request.json() if request.has_body else {} + nightly = body.get('nightly', False) + + # Get current plugin directory + current_dir = os.path.dirname(os.path.abspath(__file__)) + plugin_root = os.path.dirname(os.path.dirname(current_dir)) + + # Backup settings.json if it exists + settings_path = os.path.join(plugin_root, 'settings.json') + settings_backup = None + if os.path.exists(settings_path): + with open(settings_path, 'r', encoding='utf-8') as f: + settings_backup = f.read() + logger.info("Backed up settings.json") + + # Perform Git update + success, new_version = await UpdateRoutes._perform_git_update(plugin_root, nightly) + + # Restore settings.json if we backed it up + if settings_backup and success: + with open(settings_path, 'w', encoding='utf-8') as f: + f.write(settings_backup) + logger.info("Restored settings.json") + + if success: + return web.json_response({ + 'success': True, + 'message': f'Successfully updated to {new_version}', + 'new_version': new_version + }) + else: + return web.json_response({ + 'success': False, + 'error': 'Failed to complete Git update' + }) + + except Exception as e: + logger.error(f"Failed to perform update: {e}", exc_info=True) + return web.json_response({ + 'success': False, + 'error': str(e) + }) + + @staticmethod + async def _get_nightly_version() -> tuple[str, List[str]]: + """ + Fetch latest commit from main branch + """ + repo_owner = "willmiao" + repo_name = "ComfyUI-Lora-Manager" + + # Use GitHub API to fetch the latest commit from main branch + github_url = f"https://api.github.com/repos/{repo_owner}/{repo_name}/commits/main" + + try: + async with aiohttp.ClientSession() as session: + async with session.get(github_url, headers={'Accept': 'application/vnd.github+json'}) as response: + if response.status != 200: + logger.warning(f"Failed to fetch GitHub commit: {response.status}") + return "main", [] + + data = await response.json() + commit_sha = data.get('sha', '')[:7] # Short hash + commit_message = data.get('commit', {}).get('message', '') + + # Format as "main-{short_hash}" + version = f"main-{commit_sha}" + + # Use commit message as changelog + changelog = [commit_message] if commit_message else [] + + return version, changelog + + except Exception as e: + logger.error(f"Error fetching nightly version: {e}", exc_info=True) + return "main", [] + + @staticmethod + def _compare_nightly_versions(local_git_info: Dict[str, str], remote_version: str) -> bool: + """ + Compare local commit hash with remote main branch + """ + try: + local_hash = local_git_info.get('short_hash', 'unknown') + if local_hash == 'unknown': + return True # Assume update available if we can't get local hash + + # Extract remote hash from version string (format: "main-{hash}") + if '-' in remote_version: + remote_hash = remote_version.split('-')[-1] + return local_hash != remote_hash + + return True # Default to update available + + except Exception as e: + logger.error(f"Error comparing nightly versions: {e}") + return False + + @staticmethod + async def _perform_git_update(plugin_root: str, nightly: bool = False) -> tuple[bool, str]: + """ + Perform Git-based update using GitPython + + Args: + plugin_root: Path to the plugin root directory + nightly: Whether to update to main branch or latest release + + Returns: + tuple: (success, new_version) + """ + try: + # Open the Git repository + repo = git.Repo(plugin_root) + + # Fetch latest changes + origin = repo.remotes.origin + origin.fetch() + + if nightly: + # Switch to main branch and pull latest + main_branch = 'main' + if main_branch not in [branch.name for branch in repo.branches]: + # Create local main branch if it doesn't exist + repo.create_head(main_branch, origin.refs.main) + + repo.heads[main_branch].checkout() + origin.pull(main_branch) + + # Get new commit hash + new_version = f"main-{repo.head.commit.hexsha[:7]}" + + else: + # Get latest release tag + tags = sorted(repo.tags, key=lambda t: t.commit.committed_datetime, reverse=True) + if not tags: + logger.error("No tags found in repository") + return False, "" + + latest_tag = tags[0] + + # Checkout to latest tag + repo.git.checkout(latest_tag.name) + + new_version = latest_tag.name + + logger.info(f"Successfully updated to {new_version}") + return True, new_version + + except git.exc.GitError as e: + logger.error(f"Git error during update: {e}") + return False, "" + except Exception as e: + logger.error(f"Error during Git update: {e}") + return False, "" + @staticmethod def _get_local_version() -> str: """Get local plugin version from pyproject.toml""" diff --git a/py/services/base_model_service.py b/py/services/base_model_service.py new file mode 100644 index 00000000..72242daa --- /dev/null +++ b/py/services/base_model_service.py @@ -0,0 +1,259 @@ +from abc import ABC, abstractmethod +from typing import Dict, List, Optional, Type +import logging + +from ..utils.models import BaseModelMetadata +from ..utils.constants import NSFW_LEVELS +from .settings_manager import settings +from ..utils.utils import fuzzy_match + +logger = logging.getLogger(__name__) + +class BaseModelService(ABC): + """Base service class for all model types""" + + def __init__(self, model_type: str, scanner, metadata_class: Type[BaseModelMetadata]): + """Initialize the service + + Args: + model_type: Type of model (lora, checkpoint, etc.) + scanner: Model scanner instance + metadata_class: Metadata class for this model type + """ + self.model_type = model_type + self.scanner = scanner + self.metadata_class = metadata_class + + async def get_paginated_data(self, page: int, page_size: int, sort_by: str = 'name', + folder: str = None, search: str = None, fuzzy_search: bool = False, + base_models: list = None, tags: list = None, + search_options: dict = None, hash_filters: dict = None, + favorites_only: bool = False, **kwargs) -> Dict: + """Get paginated and filtered model data + + Args: + page: Page number (1-based) + page_size: Number of items per page + sort_by: Sort criteria, e.g. 'name', 'name:asc', 'name:desc', 'date', 'date:asc', 'date:desc' + folder: Folder filter + search: Search term + fuzzy_search: Whether to use fuzzy search + base_models: List of base models to filter by + tags: List of tags to filter by + search_options: Search options dict + hash_filters: Hash filtering options + favorites_only: Filter for favorites only + **kwargs: Additional model-specific filters + + Returns: + Dict containing paginated results + """ + cache = await self.scanner.get_cached_data() + + # Parse sort_by into sort_key and order + if ':' in sort_by: + sort_key, order = sort_by.split(':', 1) + sort_key = sort_key.strip() + order = order.strip().lower() + if order not in ('asc', 'desc'): + order = 'asc' + else: + sort_key = sort_by.strip() + order = 'asc' + + # Get default search options if not provided + if search_options is None: + search_options = { + 'filename': True, + 'modelname': True, + 'tags': False, + 'recursive': False, + } + + # Get the base data set using new sort logic + filtered_data = await cache.get_sorted_data(sort_key, order) + + # Apply hash filtering if provided (highest priority) + if hash_filters: + filtered_data = await self._apply_hash_filters(filtered_data, hash_filters) + + # Jump to pagination for hash filters + return self._paginate(filtered_data, page, page_size) + + # Apply common filters + filtered_data = await self._apply_common_filters( + filtered_data, folder, base_models, tags, favorites_only, search_options + ) + + # Apply search filtering + if search: + filtered_data = await self._apply_search_filters( + filtered_data, search, fuzzy_search, search_options + ) + + # Apply model-specific filters + filtered_data = await self._apply_specific_filters(filtered_data, **kwargs) + + return self._paginate(filtered_data, page, page_size) + + async def _apply_hash_filters(self, data: List[Dict], hash_filters: Dict) -> List[Dict]: + """Apply hash-based filtering""" + single_hash = hash_filters.get('single_hash') + multiple_hashes = hash_filters.get('multiple_hashes') + + if single_hash: + # Filter by single hash + single_hash = single_hash.lower() + return [ + item for item in data + if item.get('sha256', '').lower() == single_hash + ] + elif multiple_hashes: + # Filter by multiple hashes + hash_set = set(hash.lower() for hash in multiple_hashes) + return [ + item for item in data + if item.get('sha256', '').lower() in hash_set + ] + + return data + + async def _apply_common_filters(self, data: List[Dict], folder: str = None, + base_models: list = None, tags: list = None, + favorites_only: bool = False, search_options: dict = None) -> List[Dict]: + """Apply common filters that work across all model types""" + # Apply SFW filtering if enabled in settings + if settings.get('show_only_sfw', False): + data = [ + item for item in data + if not item.get('preview_nsfw_level') or item.get('preview_nsfw_level') < NSFW_LEVELS['R'] + ] + + # Apply favorites filtering if enabled + if favorites_only: + data = [ + item for item in data + if item.get('favorite', False) is True + ] + + # Apply folder filtering + if folder is not None: + if search_options and search_options.get('recursive', False): + # Recursive folder filtering - include all subfolders + data = [ + item for item in data + if item['folder'].startswith(folder) + ] + else: + # Exact folder filtering + data = [ + item for item in data + if item['folder'] == folder + ] + + # Apply base model filtering + if base_models and len(base_models) > 0: + data = [ + item for item in data + if item.get('base_model') in base_models + ] + + # Apply tag filtering + if tags and len(tags) > 0: + data = [ + item for item in data + if any(tag in item.get('tags', []) for tag in tags) + ] + + return data + + async def _apply_search_filters(self, data: List[Dict], search: str, + fuzzy_search: bool, search_options: dict) -> List[Dict]: + """Apply search filtering""" + search_results = [] + + for item in data: + # Search by file name + if search_options.get('filename', True): + if fuzzy_search: + if fuzzy_match(item.get('file_name', ''), search): + search_results.append(item) + continue + elif search.lower() in item.get('file_name', '').lower(): + search_results.append(item) + continue + + # Search by model name + if search_options.get('modelname', True): + if fuzzy_search: + if fuzzy_match(item.get('model_name', ''), search): + search_results.append(item) + continue + elif search.lower() in item.get('model_name', '').lower(): + search_results.append(item) + continue + + # Search by tags + if search_options.get('tags', False) and 'tags' in item: + if any((fuzzy_match(tag, search) if fuzzy_search else search.lower() in tag.lower()) + for tag in item['tags']): + search_results.append(item) + continue + + return search_results + + async def _apply_specific_filters(self, data: List[Dict], **kwargs) -> List[Dict]: + """Apply model-specific filters - to be overridden by subclasses if needed""" + return data + + def _paginate(self, data: List[Dict], page: int, page_size: int) -> Dict: + """Apply pagination to filtered data""" + total_items = len(data) + start_idx = (page - 1) * page_size + end_idx = min(start_idx + page_size, total_items) + + return { + 'items': data[start_idx:end_idx], + 'total': total_items, + 'page': page, + 'page_size': page_size, + 'total_pages': (total_items + page_size - 1) // page_size + } + + @abstractmethod + async def format_response(self, model_data: Dict) -> Dict: + """Format model data for API response - must be implemented by subclasses""" + pass + + # Common service methods that delegate to scanner + async def get_top_tags(self, limit: int = 20) -> List[Dict]: + """Get top tags sorted by frequency""" + return await self.scanner.get_top_tags(limit) + + async def get_base_models(self, limit: int = 20) -> List[Dict]: + """Get base models sorted by frequency""" + return await self.scanner.get_base_models(limit) + + def has_hash(self, sha256: str) -> bool: + """Check if a model with given hash exists""" + return self.scanner.has_hash(sha256) + + def get_path_by_hash(self, sha256: str) -> Optional[str]: + """Get file path for a model by its hash""" + return self.scanner.get_path_by_hash(sha256) + + def get_hash_by_path(self, file_path: str) -> Optional[str]: + """Get hash for a model by its file path""" + return self.scanner.get_hash_by_path(file_path) + + async def scan_models(self, force_refresh: bool = False, rebuild_cache: bool = False): + """Trigger model scanning""" + return await self.scanner.get_cached_data(force_refresh=force_refresh, rebuild_cache=rebuild_cache) + + async def get_model_info_by_name(self, name: str): + """Get model information by name""" + return await self.scanner.get_model_info_by_name(name) + + def get_model_roots(self) -> List[str]: + """Get model root directories""" + return self.scanner.get_model_roots() \ No newline at end of file diff --git a/py/services/checkpoint_scanner.py b/py/services/checkpoint_scanner.py index 26733ab5..d4696631 100644 --- a/py/services/checkpoint_scanner.py +++ b/py/services/checkpoint_scanner.py @@ -1,112 +1,26 @@ -import os import logging -import asyncio -from typing import List, Dict, Optional, Set -import folder_paths # type: ignore +from typing import List from ..utils.models import CheckpointMetadata from ..config import config from .model_scanner import ModelScanner from .model_hash_index import ModelHashIndex -from .service_registry import ServiceRegistry logger = logging.getLogger(__name__) class CheckpointScanner(ModelScanner): """Service for scanning and managing checkpoint files""" - _instance = None - _lock = asyncio.Lock() - - def __new__(cls): - if cls._instance is None: - cls._instance = super().__new__(cls) - return cls._instance - def __init__(self): - if not hasattr(self, '_initialized'): - # Define supported file extensions - file_extensions = {'.safetensors', '.ckpt', '.pt', '.pth', '.sft', '.gguf'} - super().__init__( - model_type="checkpoint", - model_class=CheckpointMetadata, - file_extensions=file_extensions, - hash_index=ModelHashIndex() - ) - self._initialized = True + # Define supported file extensions + file_extensions = {'.safetensors', '.ckpt', '.pt', '.pth', '.sft', '.gguf'} + super().__init__( + model_type="checkpoint", + model_class=CheckpointMetadata, + file_extensions=file_extensions, + hash_index=ModelHashIndex() + ) - @classmethod - async def get_instance(cls): - """Get singleton instance with async support""" - async with cls._lock: - if cls._instance is None: - cls._instance = cls() - return cls._instance - def get_model_roots(self) -> List[str]: """Get checkpoint root directories""" - return config.base_models_roots - - async def scan_all_models(self) -> List[Dict]: - """Scan all checkpoint directories and return metadata""" - all_checkpoints = [] - - # Create scan tasks for each directory - scan_tasks = [] - for root in self.get_model_roots(): - task = asyncio.create_task(self._scan_directory(root)) - scan_tasks.append(task) - - # Wait for all tasks to complete - for task in scan_tasks: - try: - checkpoints = await task - all_checkpoints.extend(checkpoints) - except Exception as e: - logger.error(f"Error scanning checkpoint directory: {e}") - - return all_checkpoints - - async def _scan_directory(self, root_path: str) -> List[Dict]: - """Scan a directory for checkpoint files""" - checkpoints = [] - original_root = root_path - - async def scan_recursive(path: str, visited_paths: set): - try: - real_path = os.path.realpath(path) - if real_path in visited_paths: - logger.debug(f"Skipping already visited path: {path}") - return - visited_paths.add(real_path) - - with os.scandir(path) as it: - entries = list(it) - for entry in entries: - try: - if entry.is_file(follow_symlinks=True): - # Check if file has supported extension - ext = os.path.splitext(entry.name)[1].lower() - if ext in self.file_extensions: - file_path = entry.path.replace(os.sep, "/") - await self._process_single_file(file_path, original_root, checkpoints) - await asyncio.sleep(0) - elif entry.is_dir(follow_symlinks=True): - # For directories, continue scanning with original path - await scan_recursive(entry.path, visited_paths) - except Exception as e: - logger.error(f"Error processing entry {entry.path}: {e}") - except Exception as e: - logger.error(f"Error scanning {path}: {e}") - - await scan_recursive(root_path, set()) - return checkpoints - - async def _process_single_file(self, file_path: str, root_path: str, checkpoints: list): - """Process a single checkpoint file and add to results""" - try: - result = await self._process_model_file(file_path, root_path) - if result: - checkpoints.append(result) - except Exception as e: - logger.error(f"Error processing {file_path}: {e}") \ No newline at end of file + return config.base_models_roots \ No newline at end of file diff --git a/py/services/checkpoint_service.py b/py/services/checkpoint_service.py new file mode 100644 index 00000000..cfc4b118 --- /dev/null +++ b/py/services/checkpoint_service.py @@ -0,0 +1,51 @@ +import os +import logging +from typing import Dict, List, Optional + +from .base_model_service import BaseModelService +from ..utils.models import CheckpointMetadata +from ..config import config +from ..utils.routes_common import ModelRouteUtils + +logger = logging.getLogger(__name__) + +class CheckpointService(BaseModelService): + """Checkpoint-specific service implementation""" + + def __init__(self, scanner): + """Initialize Checkpoint service + + Args: + scanner: Checkpoint scanner instance + """ + super().__init__("checkpoint", scanner, CheckpointMetadata) + + async def format_response(self, checkpoint_data: Dict) -> Dict: + """Format Checkpoint data for API response""" + return { + "model_name": checkpoint_data["model_name"], + "file_name": checkpoint_data["file_name"], + "preview_url": config.get_preview_static_url(checkpoint_data.get("preview_url", "")), + "preview_nsfw_level": checkpoint_data.get("preview_nsfw_level", 0), + "base_model": checkpoint_data.get("base_model", ""), + "folder": checkpoint_data["folder"], + "sha256": checkpoint_data.get("sha256", ""), + "file_path": checkpoint_data["file_path"].replace(os.sep, "/"), + "file_size": checkpoint_data.get("size", 0), + "modified": checkpoint_data.get("modified", ""), + "tags": checkpoint_data.get("tags", []), + "modelDescription": checkpoint_data.get("modelDescription", ""), + "from_civitai": checkpoint_data.get("from_civitai", True), + "notes": checkpoint_data.get("notes", ""), + "model_type": checkpoint_data.get("model_type", "checkpoint"), + "favorite": checkpoint_data.get("favorite", False), + "civitai": ModelRouteUtils.filter_civitai_data(checkpoint_data.get("civitai", {})) + } + + def find_duplicate_hashes(self) -> Dict: + """Find Checkpoints with duplicate SHA256 hashes""" + return self.scanner._hash_index.get_duplicate_hashes() + + def find_duplicate_filenames(self) -> Dict: + """Find Checkpoints with conflicting filenames""" + return self.scanner._hash_index.get_duplicate_filenames() \ No newline at end of file diff --git a/py/services/lora_scanner.py b/py/services/lora_scanner.py index 59ff58b5..6feff477 100644 --- a/py/services/lora_scanner.py +++ b/py/services/lora_scanner.py @@ -1,15 +1,10 @@ -import os import logging -import asyncio -from typing import List, Dict, Optional +from typing import List from ..utils.models import LoraMetadata from ..config import config from .model_scanner import ModelScanner from .model_hash_index import ModelHashIndex # Changed from LoraHashIndex to ModelHashIndex -from .settings_manager import settings -from ..utils.constants import NSFW_LEVELS -from ..utils.utils import fuzzy_match import sys logger = logging.getLogger(__name__) @@ -17,404 +12,21 @@ logger = logging.getLogger(__name__) class LoraScanner(ModelScanner): """Service for scanning and managing LoRA files""" - _instance = None - _lock = asyncio.Lock() - - def __new__(cls): - if cls._instance is None: - cls._instance = super().__new__(cls) - return cls._instance - def __init__(self): - # Ensure initialization happens only once - if not hasattr(self, '_initialized'): - # Define supported file extensions - file_extensions = {'.safetensors'} - - # Initialize parent class with ModelHashIndex - super().__init__( - model_type="lora", - model_class=LoraMetadata, - file_extensions=file_extensions, - hash_index=ModelHashIndex() # Changed from LoraHashIndex to ModelHashIndex - ) - self._initialized = True - - @classmethod - async def get_instance(cls): - """Get singleton instance with async support""" - async with cls._lock: - if cls._instance is None: - cls._instance = cls() - return cls._instance + # Define supported file extensions + file_extensions = {'.safetensors'} + + # Initialize parent class with ModelHashIndex + super().__init__( + model_type="lora", + model_class=LoraMetadata, + file_extensions=file_extensions, + hash_index=ModelHashIndex() # Changed from LoraHashIndex to ModelHashIndex + ) def get_model_roots(self) -> List[str]: """Get lora root directories""" return config.loras_roots - - async def scan_all_models(self) -> List[Dict]: - """Scan all LoRA directories and return metadata""" - all_loras = [] - - # Create scan tasks for each directory - scan_tasks = [] - for lora_root in self.get_model_roots(): - task = asyncio.create_task(self._scan_directory(lora_root)) - scan_tasks.append(task) - - # Wait for all tasks to complete - for task in scan_tasks: - try: - loras = await task - all_loras.extend(loras) - except Exception as e: - logger.error(f"Error scanning directory: {e}") - - return all_loras - - async def _scan_directory(self, root_path: str) -> List[Dict]: - """Scan a single directory for LoRA files""" - loras = [] - original_root = root_path # Save original root path - - async def scan_recursive(path: str, visited_paths: set): - """Recursively scan directory, avoiding circular symlinks""" - try: - real_path = os.path.realpath(path) - if real_path in visited_paths: - logger.debug(f"Skipping already visited path: {path}") - return - visited_paths.add(real_path) - - with os.scandir(path) as it: - entries = list(it) - for entry in entries: - try: - if entry.is_file(follow_symlinks=True) and any(entry.name.endswith(ext) for ext in self.file_extensions): - # Use original path instead of real path - file_path = entry.path.replace(os.sep, "/") - await self._process_single_file(file_path, original_root, loras) - await asyncio.sleep(0) - elif entry.is_dir(follow_symlinks=True): - # For directories, continue scanning with original path - await scan_recursive(entry.path, visited_paths) - except Exception as e: - logger.error(f"Error processing entry {entry.path}: {e}") - except Exception as e: - logger.error(f"Error scanning {path}: {e}") - - await scan_recursive(root_path, set()) - return loras - - async def _process_single_file(self, file_path: str, root_path: str, loras: list): - """Process a single file and add to results list""" - try: - result = await self._process_model_file(file_path, root_path) - if result: - loras.append(result) - except Exception as e: - logger.error(f"Error processing {file_path}: {e}") - - async def get_paginated_data(self, page: int, page_size: int, sort_by: str = 'name', - folder: str = None, search: str = None, fuzzy_search: bool = False, - base_models: list = None, tags: list = None, - search_options: dict = None, hash_filters: dict = None, - favorites_only: bool = False, first_letter: str = None) -> Dict: - """Get paginated and filtered lora data - - Args: - page: Current page number (1-based) - page_size: Number of items per page - sort_by: Sort method ('name' or 'date') - folder: Filter by folder path - search: Search term - fuzzy_search: Use fuzzy matching for search - base_models: List of base models to filter by - tags: List of tags to filter by - search_options: Dictionary with search options (filename, modelname, tags, recursive) - hash_filters: Dictionary with hash filtering options (single_hash or multiple_hashes) - favorites_only: Filter for favorite models only - first_letter: Filter by first letter of model name - """ - cache = await self.get_cached_data() - - # Get default search options if not provided - if search_options is None: - search_options = { - 'filename': True, - 'modelname': True, - 'tags': False, - 'recursive': False, - } - - # Get the base data set - filtered_data = cache.sorted_by_date if sort_by == 'date' else cache.sorted_by_name - - # Apply hash filtering if provided (highest priority) - if hash_filters: - single_hash = hash_filters.get('single_hash') - multiple_hashes = hash_filters.get('multiple_hashes') - - if single_hash: - # Filter by single hash - single_hash = single_hash.lower() # Ensure lowercase for matching - filtered_data = [ - lora for lora in filtered_data - if lora.get('sha256', '').lower() == single_hash - ] - elif multiple_hashes: - # Filter by multiple hashes - hash_set = set(hash.lower() for hash in multiple_hashes) # Convert to set for faster lookup - filtered_data = [ - lora for lora in filtered_data - if lora.get('sha256', '').lower() in hash_set - ] - - - # Jump to pagination - total_items = len(filtered_data) - start_idx = (page - 1) * page_size - end_idx = min(start_idx + page_size, total_items) - - result = { - 'items': filtered_data[start_idx:end_idx], - 'total': total_items, - 'page': page, - 'page_size': page_size, - 'total_pages': (total_items + page_size - 1) // page_size - } - - return result - - # Apply SFW filtering if enabled - if settings.get('show_only_sfw', False): - filtered_data = [ - lora for lora in filtered_data - if not lora.get('preview_nsfw_level') or lora.get('preview_nsfw_level') < NSFW_LEVELS['R'] - ] - - # Apply favorites filtering if enabled - if favorites_only: - filtered_data = [ - lora for lora in filtered_data - if lora.get('favorite', False) is True - ] - - # Apply first letter filtering - if first_letter: - filtered_data = self._filter_by_first_letter(filtered_data, first_letter) - - # Apply folder filtering - if folder is not None: - if search_options.get('recursive', False): - # Recursive folder filtering - include all subfolders - filtered_data = [ - lora for lora in filtered_data - if lora['folder'].startswith(folder) - ] - else: - # Exact folder filtering - filtered_data = [ - lora for lora in filtered_data - if lora['folder'] == folder - ] - - # Apply base model filtering - if base_models and len(base_models) > 0: - filtered_data = [ - lora for lora in filtered_data - if lora.get('base_model') in base_models - ] - - # Apply tag filtering - if tags and len(tags) > 0: - filtered_data = [ - lora for lora in filtered_data - if any(tag in lora.get('tags', []) for tag in tags) - ] - - # Apply search filtering - if search: - search_results = [] - search_opts = search_options or {} - - for lora in filtered_data: - # Search by file name - if search_opts.get('filename', True): - if fuzzy_match(lora.get('file_name', ''), search): - search_results.append(lora) - continue - - # Search by model name - if search_opts.get('modelname', True): - if fuzzy_match(lora.get('model_name', ''), search): - search_results.append(lora) - continue - - # Search by tags - if search_opts.get('tags', False) and 'tags' in lora: - if any(fuzzy_match(tag, search) for tag in lora['tags']): - search_results.append(lora) - continue - - filtered_data = search_results - - # Calculate pagination - total_items = len(filtered_data) - start_idx = (page - 1) * page_size - end_idx = min(start_idx + page_size, total_items) - - result = { - 'items': filtered_data[start_idx:end_idx], - 'total': total_items, - 'page': page, - 'page_size': page_size, - 'total_pages': (total_items + page_size - 1) // page_size - } - - return result - - def _filter_by_first_letter(self, data, letter): - """Filter data by first letter of model name - - Special handling: - - '#': Numbers (0-9) - - '@': Special characters (not alphanumeric) - - '漢': CJK characters - """ - filtered_data = [] - - for lora in data: - model_name = lora.get('model_name', '') - if not model_name: - continue - - first_char = model_name[0].upper() - - if letter == '#' and first_char.isdigit(): - filtered_data.append(lora) - elif letter == '@' and not first_char.isalnum(): - # Special characters (not alphanumeric) - filtered_data.append(lora) - elif letter == '漢' and self._is_cjk_character(first_char): - # CJK characters - filtered_data.append(lora) - elif letter.upper() == first_char: - # Regular alphabet matching - filtered_data.append(lora) - - return filtered_data - - def _is_cjk_character(self, char): - """Check if character is a CJK character""" - # Define Unicode ranges for CJK characters - cjk_ranges = [ - (0x4E00, 0x9FFF), # CJK Unified Ideographs - (0x3400, 0x4DBF), # CJK Unified Ideographs Extension A - (0x20000, 0x2A6DF), # CJK Unified Ideographs Extension B - (0x2A700, 0x2B73F), # CJK Unified Ideographs Extension C - (0x2B740, 0x2B81F), # CJK Unified Ideographs Extension D - (0x2B820, 0x2CEAF), # CJK Unified Ideographs Extension E - (0x2CEB0, 0x2EBEF), # CJK Unified Ideographs Extension F - (0x30000, 0x3134F), # CJK Unified Ideographs Extension G - (0xF900, 0xFAFF), # CJK Compatibility Ideographs - (0x3300, 0x33FF), # CJK Compatibility - (0x3200, 0x32FF), # Enclosed CJK Letters and Months - (0x3100, 0x312F), # Bopomofo - (0x31A0, 0x31BF), # Bopomofo Extended - (0x3040, 0x309F), # Hiragana - (0x30A0, 0x30FF), # Katakana - (0x31F0, 0x31FF), # Katakana Phonetic Extensions - (0xAC00, 0xD7AF), # Hangul Syllables - (0x1100, 0x11FF), # Hangul Jamo - (0xA960, 0xA97F), # Hangul Jamo Extended-A - (0xD7B0, 0xD7FF), # Hangul Jamo Extended-B - ] - - code_point = ord(char) - return any(start <= code_point <= end for start, end in cjk_ranges) - - async def get_letter_counts(self): - """Get count of models for each letter of the alphabet""" - cache = await self.get_cached_data() - data = cache.sorted_by_name - - # Define letter categories - letters = { - '#': 0, # Numbers - 'A': 0, 'B': 0, 'C': 0, 'D': 0, 'E': 0, 'F': 0, 'G': 0, 'H': 0, - 'I': 0, 'J': 0, 'K': 0, 'L': 0, 'M': 0, 'N': 0, 'O': 0, 'P': 0, - 'Q': 0, 'R': 0, 'S': 0, 'T': 0, 'U': 0, 'V': 0, 'W': 0, 'X': 0, - 'Y': 0, 'Z': 0, - '@': 0, # Special characters - '漢': 0 # CJK characters - } - - # Count models for each letter - for lora in data: - model_name = lora.get('model_name', '') - if not model_name: - continue - - first_char = model_name[0].upper() - - if first_char.isdigit(): - letters['#'] += 1 - elif first_char in letters: - letters[first_char] += 1 - elif self._is_cjk_character(first_char): - letters['漢'] += 1 - elif not first_char.isalnum(): - letters['@'] += 1 - - return letters - - # Lora-specific hash index functionality - def has_lora_hash(self, sha256: str) -> bool: - """Check if a LoRA with given hash exists""" - return self.has_hash(sha256) - - def get_lora_path_by_hash(self, sha256: str) -> Optional[str]: - """Get file path for a LoRA by its hash""" - return self.get_path_by_hash(sha256) - - def get_lora_hash_by_path(self, file_path: str) -> Optional[str]: - """Get hash for a LoRA by its file path""" - return self.get_hash_by_path(file_path) - - async def get_top_tags(self, limit: int = 20) -> List[Dict[str, any]]: - """Get top tags sorted by count""" - # Make sure cache is initialized - await self.get_cached_data() - - # Sort tags by count in descending order - sorted_tags = sorted( - [{"tag": tag, "count": count} for tag, count in self._tags_count.items()], - key=lambda x: x['count'], - reverse=True - ) - - # Return limited number - return sorted_tags[:limit] - - async def get_base_models(self, limit: int = 20) -> List[Dict[str, any]]: - """Get base models used in loras sorted by frequency""" - # Make sure cache is initialized - cache = await self.get_cached_data() - - # Count base model occurrences - base_model_counts = {} - for lora in cache.raw_data: - if 'base_model' in lora and lora['base_model']: - base_model = lora['base_model'] - base_model_counts[base_model] = base_model_counts.get(base_model, 0) + 1 - - # Sort base models by count - sorted_models = [{'name': model, 'count': count} for model, count in base_model_counts.items()] - sorted_models.sort(key=lambda x: x['count'], reverse=True) - - # Return limited number - return sorted_models[:limit] async def diagnose_hash_index(self): """Diagnostic method to verify hash index functionality""" @@ -451,19 +63,3 @@ class LoraScanner(ModelScanner): test_hash_result = self._hash_index.get_hash(test_path) print(f"Test reverse lookup: {test_path} -> {test_hash_result[:8]}...\n\n", file=sys.stderr) - async def get_lora_info_by_name(self, name): - """Get LoRA information by name""" - try: - # Get cached data - cache = await self.get_cached_data() - - # Find the LoRA by name - for lora in cache.raw_data: - if lora.get("file_name") == name: - return lora - - return None - except Exception as e: - logger.error(f"Error getting LoRA info by name: {e}", exc_info=True) - return None - diff --git a/py/services/lora_service.py b/py/services/lora_service.py new file mode 100644 index 00000000..7649f75b --- /dev/null +++ b/py/services/lora_service.py @@ -0,0 +1,212 @@ +import os +import logging +from typing import Dict, List, Optional + +from .base_model_service import BaseModelService +from ..utils.models import LoraMetadata +from ..config import config +from ..utils.routes_common import ModelRouteUtils + +logger = logging.getLogger(__name__) + +class LoraService(BaseModelService): + """LoRA-specific service implementation""" + + def __init__(self, scanner): + """Initialize LoRA service + + Args: + scanner: LoRA scanner instance + """ + super().__init__("lora", scanner, LoraMetadata) + + async def format_response(self, lora_data: Dict) -> Dict: + """Format LoRA data for API response""" + return { + "model_name": lora_data["model_name"], + "file_name": lora_data["file_name"], + "preview_url": config.get_preview_static_url(lora_data.get("preview_url", "")), + "preview_nsfw_level": lora_data.get("preview_nsfw_level", 0), + "base_model": lora_data.get("base_model", ""), + "folder": lora_data["folder"], + "sha256": lora_data.get("sha256", ""), + "file_path": lora_data["file_path"].replace(os.sep, "/"), + "file_size": lora_data.get("size", 0), + "modified": lora_data.get("modified", ""), + "tags": lora_data.get("tags", []), + "modelDescription": lora_data.get("modelDescription", ""), + "from_civitai": lora_data.get("from_civitai", True), + "usage_tips": lora_data.get("usage_tips", ""), + "notes": lora_data.get("notes", ""), + "favorite": lora_data.get("favorite", False), + "civitai": ModelRouteUtils.filter_civitai_data(lora_data.get("civitai", {})) + } + + async def _apply_specific_filters(self, data: List[Dict], **kwargs) -> List[Dict]: + """Apply LoRA-specific filters""" + # Handle first_letter filter for LoRAs + first_letter = kwargs.get('first_letter') + if first_letter: + data = self._filter_by_first_letter(data, first_letter) + + return data + + def _filter_by_first_letter(self, data: List[Dict], letter: str) -> List[Dict]: + """Filter data by first letter of model name + + Special handling: + - '#': Numbers (0-9) + - '@': Special characters (not alphanumeric) + - '漢': CJK characters + """ + filtered_data = [] + + for lora in data: + model_name = lora.get('model_name', '') + if not model_name: + continue + + first_char = model_name[0].upper() + + if letter == '#' and first_char.isdigit(): + filtered_data.append(lora) + elif letter == '@' and not first_char.isalnum(): + # Special characters (not alphanumeric) + filtered_data.append(lora) + elif letter == '漢' and self._is_cjk_character(first_char): + # CJK characters + filtered_data.append(lora) + elif letter.upper() == first_char: + # Regular alphabet matching + filtered_data.append(lora) + + return filtered_data + + def _is_cjk_character(self, char: str) -> bool: + """Check if character is a CJK character""" + # Define Unicode ranges for CJK characters + cjk_ranges = [ + (0x4E00, 0x9FFF), # CJK Unified Ideographs + (0x3400, 0x4DBF), # CJK Unified Ideographs Extension A + (0x20000, 0x2A6DF), # CJK Unified Ideographs Extension B + (0x2A700, 0x2B73F), # CJK Unified Ideographs Extension C + (0x2B740, 0x2B81F), # CJK Unified Ideographs Extension D + (0x2B820, 0x2CEAF), # CJK Unified Ideographs Extension E + (0x2CEB0, 0x2EBEF), # CJK Unified Ideographs Extension F + (0x30000, 0x3134F), # CJK Unified Ideographs Extension G + (0xF900, 0xFAFF), # CJK Compatibility Ideographs + (0x3300, 0x33FF), # CJK Compatibility + (0x3200, 0x32FF), # Enclosed CJK Letters and Months + (0x3100, 0x312F), # Bopomofo + (0x31A0, 0x31BF), # Bopomofo Extended + (0x3040, 0x309F), # Hiragana + (0x30A0, 0x30FF), # Katakana + (0x31F0, 0x31FF), # Katakana Phonetic Extensions + (0xAC00, 0xD7AF), # Hangul Syllables + (0x1100, 0x11FF), # Hangul Jamo + (0xA960, 0xA97F), # Hangul Jamo Extended-A + (0xD7B0, 0xD7FF), # Hangul Jamo Extended-B + ] + + code_point = ord(char) + return any(start <= code_point <= end for start, end in cjk_ranges) + + # LoRA-specific methods + async def get_letter_counts(self) -> Dict[str, int]: + """Get count of LoRAs for each letter of the alphabet""" + cache = await self.scanner.get_cached_data() + data = cache.raw_data + + # Define letter categories + letters = { + '#': 0, # Numbers + 'A': 0, 'B': 0, 'C': 0, 'D': 0, 'E': 0, 'F': 0, 'G': 0, 'H': 0, + 'I': 0, 'J': 0, 'K': 0, 'L': 0, 'M': 0, 'N': 0, 'O': 0, 'P': 0, + 'Q': 0, 'R': 0, 'S': 0, 'T': 0, 'U': 0, 'V': 0, 'W': 0, 'X': 0, + 'Y': 0, 'Z': 0, + '@': 0, # Special characters + '漢': 0 # CJK characters + } + + # Count models for each letter + for lora in data: + model_name = lora.get('model_name', '') + if not model_name: + continue + + first_char = model_name[0].upper() + + if first_char.isdigit(): + letters['#'] += 1 + elif first_char in letters: + letters[first_char] += 1 + elif self._is_cjk_character(first_char): + letters['漢'] += 1 + elif not first_char.isalnum(): + letters['@'] += 1 + + return letters + + async def get_lora_notes(self, lora_name: str) -> Optional[str]: + """Get notes for a specific LoRA file""" + cache = await self.scanner.get_cached_data() + + for lora in cache.raw_data: + if lora['file_name'] == lora_name: + return lora.get('notes', '') + + return None + + async def get_lora_trigger_words(self, lora_name: str) -> List[str]: + """Get trigger words for a specific LoRA file""" + cache = await self.scanner.get_cached_data() + + for lora in cache.raw_data: + if lora['file_name'] == lora_name: + civitai_data = lora.get('civitai', {}) + return civitai_data.get('trainedWords', []) + + return [] + + async def get_lora_preview_url(self, lora_name: str) -> Optional[str]: + """Get the static preview URL for a LoRA file""" + cache = await self.scanner.get_cached_data() + + for lora in cache.raw_data: + if lora['file_name'] == lora_name: + preview_url = lora.get('preview_url') + if preview_url: + return config.get_preview_static_url(preview_url) + + return None + + async def get_lora_civitai_url(self, lora_name: str) -> Dict[str, Optional[str]]: + """Get the Civitai URL for a LoRA file""" + cache = await self.scanner.get_cached_data() + + for lora in cache.raw_data: + if lora['file_name'] == lora_name: + civitai_data = lora.get('civitai', {}) + model_id = civitai_data.get('modelId') + version_id = civitai_data.get('id') + + if model_id: + civitai_url = f"https://civitai.com/models/{model_id}" + if version_id: + civitai_url += f"?modelVersionId={version_id}" + + return { + 'civitai_url': civitai_url, + 'model_id': str(model_id), + 'version_id': str(version_id) if version_id else None + } + + return {'civitai_url': None, 'model_id': None, 'version_id': None} + + def find_duplicate_hashes(self) -> Dict: + """Find LoRAs with duplicate SHA256 hashes""" + return self.scanner._hash_index.get_duplicate_hashes() + + def find_duplicate_filenames(self) -> Dict: + """Find LoRAs with conflicting filenames""" + return self.scanner._hash_index.get_duplicate_filenames() \ No newline at end of file diff --git a/py/services/model_cache.py b/py/services/model_cache.py index 8494531e..f67b2444 100644 --- a/py/services/model_cache.py +++ b/py/services/model_cache.py @@ -1,37 +1,85 @@ import asyncio -from typing import List, Dict +from typing import List, Dict, Tuple from dataclasses import dataclass from operator import itemgetter from natsort import natsorted +# Supported sort modes: (sort_key, order) +# order: 'asc' for ascending, 'desc' for descending +SUPPORTED_SORT_MODES = [ + ('name', 'asc'), + ('name', 'desc'), + ('date', 'asc'), + ('date', 'desc'), + ('size', 'asc'), + ('size', 'desc'), +] + @dataclass class ModelCache: - """Cache structure for model data""" + """Cache structure for model data with extensible sorting""" raw_data: List[Dict] - sorted_by_name: List[Dict] - sorted_by_date: List[Dict] folders: List[str] def __post_init__(self): self._lock = asyncio.Lock() + # Cache for last sort: (sort_key, order) -> sorted list + self._last_sort: Tuple[str, str] = (None, None) + self._last_sorted_data: List[Dict] = [] + # Default sort on init + asyncio.create_task(self.resort()) - async def resort(self, name_only: bool = False): - """Resort all cached data views""" + async def resort(self): + """Resort cached data according to last sort mode if set""" async with self._lock: - self.sorted_by_name = natsorted( - self.raw_data, - key=lambda x: x['model_name'].lower() # Case-insensitive sort - ) - if not name_only: - self.sorted_by_date = sorted( - self.raw_data, - key=itemgetter('modified'), - reverse=True - ) - # Update folder list + if self._last_sort != (None, None): + sort_key, order = self._last_sort + sorted_data = self._sort_data(self.raw_data, sort_key, order) + self._last_sorted_data = sorted_data + # Update folder list + # else: do nothing + all_folders = set(l['folder'] for l in self.raw_data) self.folders = sorted(list(all_folders), key=lambda x: x.lower()) + def _sort_data(self, data: List[Dict], sort_key: str, order: str) -> List[Dict]: + """Sort data by sort_key and order""" + reverse = (order == 'desc') + if sort_key == 'name': + # Natural sort by model_name, case-insensitive + return natsorted( + data, + key=lambda x: x['model_name'].lower(), + reverse=reverse + ) + elif sort_key == 'date': + # Sort by modified timestamp + return sorted( + data, + key=itemgetter('modified'), + reverse=reverse + ) + elif sort_key == 'size': + # Sort by file size + return sorted( + data, + key=itemgetter('size'), + reverse=reverse + ) + else: + # Fallback: no sort + return list(data) + + async def get_sorted_data(self, sort_key: str = 'name', order: str = 'asc') -> List[Dict]: + """Get sorted data by sort_key and order, using cache if possible""" + async with self._lock: + if (sort_key, order) == self._last_sort: + return self._last_sorted_data + sorted_data = self._sort_data(self.raw_data, sort_key, order) + self._last_sort = (sort_key, order) + self._last_sorted_data = sorted_data + return sorted_data + async def update_preview_url(self, file_path: str, preview_url: str, preview_nsfw_level: int) -> bool: """Update preview_url for a specific model in all cached data diff --git a/py/services/model_scanner.py b/py/services/model_scanner.py index fdd9c020..b9bebda3 100644 --- a/py/services/model_scanner.py +++ b/py/services/model_scanner.py @@ -5,7 +5,6 @@ import asyncio import time import shutil from typing import List, Dict, Optional, Type, Set -import msgpack # Add MessagePack import for efficient serialization from ..utils.models import BaseModelMetadata from ..config import config @@ -19,17 +18,33 @@ from .websocket_manager import ws_manager logger = logging.getLogger(__name__) -# Define cache version to handle future format changes -# Version history: -# 1 - Initial version -# 2 - Added duplicate_filenames and duplicate_hashes tracking -# 3 - Added _excluded_models list to cache -CACHE_VERSION = 3 - class ModelScanner: """Base service for scanning and managing model files""" - _lock = asyncio.Lock() + _instances = {} # Dictionary to store instances by class + _locks = {} # Dictionary to store locks by class + + def __new__(cls, *args, **kwargs): + """Implement singleton pattern for each subclass""" + if cls not in cls._instances: + cls._instances[cls] = super().__new__(cls) + return cls._instances[cls] + + @classmethod + def _get_lock(cls): + """Get or create a lock for this class""" + if cls not in cls._locks: + cls._locks[cls] = asyncio.Lock() + return cls._locks[cls] + + @classmethod + async def get_instance(cls): + """Get singleton instance with async support""" + lock = cls._get_lock() + async with lock: + if cls not in cls._instances: + cls._instances[cls] = cls() + return cls._instances[cls] def __init__(self, model_type: str, model_class: Type[BaseModelMetadata], file_extensions: Set[str], hash_index: Optional[ModelHashIndex] = None): """Initialize the scanner @@ -40,6 +55,10 @@ class ModelScanner: file_extensions: Set of supported file extensions including the dot (e.g. {'.safetensors'}) hash_index: Hash index instance (optional) """ + # Ensure initialization happens only once per instance + if hasattr(self, '_initialized'): + return + self.model_type = model_type self.model_class = model_class self.file_extensions = file_extensions @@ -48,202 +67,15 @@ class ModelScanner: self._tags_count = {} # Dictionary to store tag counts self._is_initializing = False # Flag to track initialization state self._excluded_models = [] # List to track excluded models - self._dirs_last_modified = {} # Track directory modification times - self._use_cache_files = False # Flag to control cache file usage, default to disabled - - # Clear cache files if disabled - if not self._use_cache_files: - self._clear_cache_files() + self._initialized = True # Register this service asyncio.create_task(self._register_service()) - def _clear_cache_files(self): - """Clear existing cache files if they exist""" - try: - cache_path = self._get_cache_file_path() - if cache_path and os.path.exists(cache_path): - os.remove(cache_path) - logger.info(f"Cleared {self.model_type} cache file: {cache_path}") - except Exception as e: - logger.error(f"Error clearing {self.model_type} cache file: {e}") - async def _register_service(self): """Register this instance with the ServiceRegistry""" service_name = f"{self.model_type}_scanner" await ServiceRegistry.register_service(service_name, self) - - def _get_cache_file_path(self) -> Optional[str]: - """Get the path to the cache file""" - # Get the directory where this module is located - current_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.realpath(__file__)))) - - # Create a cache directory within the project if it doesn't exist - cache_dir = os.path.join(current_dir, "cache") - os.makedirs(cache_dir, exist_ok=True) - - # Create filename based on model type - cache_filename = f"lm_{self.model_type}_cache.msgpack" - return os.path.join(cache_dir, cache_filename) - - def _prepare_for_msgpack(self, data): - """Preprocess data to accommodate MessagePack serialization limitations - - Converts integers exceeding safe range to strings - - Args: - data: Any type of data structure - - Returns: - Preprocessed data structure with large integers converted to strings - """ - if isinstance(data, dict): - return {k: self._prepare_for_msgpack(v) for k, v in data.items()} - elif isinstance(data, list): - return [self._prepare_for_msgpack(item) for item in data] - elif isinstance(data, int) and (data > 9007199254740991 or data < -9007199254740991): - # Convert integers exceeding JavaScript's safe integer range (2^53-1) to strings - return str(data) - else: - return data - - async def _save_cache_to_disk(self) -> bool: - """Save cache data to disk using MessagePack""" - if not self._use_cache_files: - logger.debug(f"Cache files disabled for {self.model_type}, skipping save") - return False - - if self._cache is None or not self._cache.raw_data: - logger.debug(f"No {self.model_type} cache data to save") - return False - - cache_path = self._get_cache_file_path() - if not cache_path: - logger.warning(f"Cannot determine {self.model_type} cache file location") - return False - - try: - # Create cache data structure - cache_data = { - "version": CACHE_VERSION, - "timestamp": time.time(), - "model_type": self.model_type, - "raw_data": self._cache.raw_data, - "hash_index": { - "hash_to_path": self._hash_index._hash_to_path, - "filename_to_hash": self._hash_index._filename_to_hash, # Fix: changed from path_to_hash to filename_to_hash - "duplicate_hashes": self._hash_index._duplicate_hashes, - "duplicate_filenames": self._hash_index._duplicate_filenames - }, - "tags_count": self._tags_count, - "dirs_last_modified": self._get_dirs_last_modified(), - "excluded_models": self._excluded_models # Add excluded_models to cache data - } - - # Preprocess data to handle large integers - processed_cache_data = self._prepare_for_msgpack(cache_data) - - # Write to temporary file first (atomic operation) - temp_path = f"{cache_path}.tmp" - with open(temp_path, 'wb') as f: - msgpack.pack(processed_cache_data, f) - - # Replace the old file with the new one - if os.path.exists(cache_path): - os.replace(temp_path, cache_path) - else: - os.rename(temp_path, cache_path) - - logger.info(f"Saved {self.model_type} cache with {len(self._cache.raw_data)} models to {cache_path}") - logger.debug(f"Hash index stats - hash_to_path: {len(self._hash_index._hash_to_path)}, filename_to_hash: {len(self._hash_index._filename_to_hash)}, duplicate_hashes: {len(self._hash_index._duplicate_hashes)}, duplicate_filenames: {len(self._hash_index._duplicate_filenames)}") - return True - except Exception as e: - logger.error(f"Error saving {self.model_type} cache to disk: {e}") - # Try to clean up temp file if it exists - if 'temp_path' in locals() and os.path.exists(temp_path): - try: - os.remove(temp_path) - except: - pass - return False - - def _get_dirs_last_modified(self) -> Dict[str, float]: - """Get last modified time for all model directories""" - dirs_info = {} - for root in self.get_model_roots(): - if os.path.exists(root): - dirs_info[root] = os.path.getmtime(root) - # Also check immediate subdirectories for changes - try: - with os.scandir(root) as it: - for entry in it: - if entry.is_dir(follow_symlinks=True): - dirs_info[entry.path] = entry.stat().st_mtime - except Exception as e: - logger.error(f"Error getting directory info for {root}: {e}") - return dirs_info - - def _is_cache_valid(self, cache_data: Dict) -> bool: - """Validate if the loaded cache is still valid""" - if not cache_data or cache_data.get("version") != CACHE_VERSION: - logger.info(f"Cache invalid - version mismatch. Got: {cache_data.get('version')}, Expected: {CACHE_VERSION}") - return False - - if cache_data.get("model_type") != self.model_type: - logger.info(f"Cache invalid - model type mismatch. Got: {cache_data.get('model_type')}, Expected: {self.model_type}") - return False - - return True - - async def _load_cache_from_disk(self) -> bool: - """Load cache data from disk using MessagePack""" - if not self._use_cache_files: - logger.info(f"Cache files disabled for {self.model_type}, skipping load") - return False - - start_time = time.time() - cache_path = self._get_cache_file_path() - if not cache_path or not os.path.exists(cache_path): - return False - - try: - with open(cache_path, 'rb') as f: - cache_data = msgpack.unpack(f) - - # Validate cache data - if not self._is_cache_valid(cache_data): - logger.info(f"{self.model_type.capitalize()} cache file found but invalid or outdated") - return False - - # Load data into memory - self._cache = ModelCache( - raw_data=cache_data["raw_data"], - sorted_by_name=[], - sorted_by_date=[], - folders=[] - ) - - # Load hash index - hash_index_data = cache_data.get("hash_index", {}) - self._hash_index._hash_to_path = hash_index_data.get("hash_to_path", {}) - self._hash_index._filename_to_hash = hash_index_data.get("filename_to_hash", {}) # Fix: changed from path_to_hash to filename_to_hash - self._hash_index._duplicate_hashes = hash_index_data.get("duplicate_hashes", {}) - self._hash_index._duplicate_filenames = hash_index_data.get("duplicate_filenames", {}) - - # Load tags count - self._tags_count = cache_data.get("tags_count", {}) - - # Load excluded models - self._excluded_models = cache_data.get("excluded_models", []) - - # Resort the cache - await self._cache.resort() - - logger.info(f"Loaded {self.model_type} cache from disk with {len(self._cache.raw_data)} models in {time.time() - start_time:.2f} seconds") - return True - except Exception as e: - logger.error(f"Error loading {self.model_type} cache from disk: {e}") - return False async def initialize_in_background(self) -> None: """Initialize cache in background using thread pool""" @@ -252,8 +84,6 @@ class ModelScanner: if self._cache is None: self._cache = ModelCache( raw_data=[], - sorted_by_name=[], - sorted_by_date=[], folders=[] ) @@ -271,21 +101,6 @@ class ModelScanner: 'scanner_type': self.model_type, 'pageType': page_type }) - - cache_loaded = await self._load_cache_from_disk() - - if cache_loaded: - # Cache loaded successfully, broadcast complete message - await ws_manager.broadcast_init_progress({ - 'stage': 'finalizing', - 'progress': 100, - 'status': 'complete', - 'details': f"Loaded {len(self._cache.raw_data)} {self.model_type} files from cache.", - 'scanner_type': self.model_type, - 'pageType': page_type - }) - self._is_initializing = False - return # If cache loading failed, proceed with full scan await ws_manager.broadcast_init_progress({ @@ -332,9 +147,6 @@ class ModelScanner: logger.info(f"{self.model_type.capitalize()} cache initialized in {time.time() - start_time:.2f} seconds. Found {len(self._cache.raw_data)} models") - # Save the cache to disk after initialization - await self._save_cache_to_disk() - # Send completion message await asyncio.sleep(0.5) # Small delay to ensure final progress message is sent await ws_manager.broadcast_init_progress({ @@ -509,40 +321,21 @@ class ModelScanner: Args: force_refresh: Whether to refresh the cache - rebuild_cache: Whether to completely rebuild the cache by reloading from disk first + rebuild_cache: Whether to completely rebuild the cache """ # If cache is not initialized, return an empty cache # Actual initialization should be done via initialize_in_background if self._cache is None and not force_refresh: return ModelCache( raw_data=[], - sorted_by_name=[], - sorted_by_date=[], folders=[] ) # If force refresh is requested, initialize the cache directly if force_refresh: - # If rebuild_cache is True, try to reload from disk before reconciliation if rebuild_cache: - logger.info(f"{self.model_type.capitalize()} Scanner: Attempting to rebuild cache from disk...") - cache_loaded = await self._load_cache_from_disk() - if cache_loaded: - logger.info(f"{self.model_type.capitalize()} Scanner: Successfully reloaded cache from disk") - else: - logger.info(f"{self.model_type.capitalize()} Scanner: Could not reload cache from disk, proceeding with complete rebuild") - # If loading from disk failed, do a complete rebuild and save to disk - await self._initialize_cache() - await self._save_cache_to_disk() - return self._cache - - if self._cache is None: - # For initial creation, do a full initialization await self._initialize_cache() - # Save the newly built cache - await self._save_cache_to_disk() else: - # For subsequent refreshes, use fast reconciliation await self._reconcile_cache() return self._cache @@ -577,8 +370,6 @@ class ModelScanner: # Update cache self._cache = ModelCache( raw_data=raw_data, - sorted_by_name=[], - sorted_by_date=[], folders=[] ) @@ -592,8 +383,6 @@ class ModelScanner: if self._cache is None: self._cache = ModelCache( raw_data=[], - sorted_by_name=[], - sorted_by_date=[], folders=[] ) finally: @@ -735,19 +524,74 @@ class ModelScanner: # Resort cache await self._cache.resort() - # Save updated cache to disk - await self._save_cache_to_disk() - logger.info(f"{self.model_type.capitalize()} Scanner: Cache reconciliation completed in {time.time() - start_time:.2f} seconds. Added {total_added}, removed {total_removed} models.") except Exception as e: logger.error(f"{self.model_type.capitalize()} Scanner: Error reconciling cache: {e}", exc_info=True) finally: self._is_initializing = False # Unset flag - # These methods should be implemented in child classes async def scan_all_models(self) -> List[Dict]: """Scan all model directories and return metadata""" - raise NotImplementedError("Subclasses must implement scan_all_models") + all_models = [] + + # Create scan tasks for each directory + scan_tasks = [] + for model_root in self.get_model_roots(): + task = asyncio.create_task(self._scan_directory(model_root)) + scan_tasks.append(task) + + # Wait for all tasks to complete + for task in scan_tasks: + try: + models = await task + all_models.extend(models) + except Exception as e: + logger.error(f"Error scanning directory: {e}") + + return all_models + + async def _scan_directory(self, root_path: str) -> List[Dict]: + """Scan a single directory for model files""" + models = [] + original_root = root_path # Save original root path + + async def scan_recursive(path: str, visited_paths: set): + """Recursively scan directory, avoiding circular symlinks""" + try: + real_path = os.path.realpath(path) + if real_path in visited_paths: + logger.debug(f"Skipping already visited path: {path}") + return + visited_paths.add(real_path) + + with os.scandir(path) as it: + entries = list(it) + for entry in entries: + try: + if entry.is_file(follow_symlinks=True) and any(entry.name.endswith(ext) for ext in self.file_extensions): + # Use original path instead of real path + file_path = entry.path.replace(os.sep, "/") + await self._process_single_file(file_path, original_root, models) + await asyncio.sleep(0) + elif entry.is_dir(follow_symlinks=True): + # For directories, continue scanning with original path + await scan_recursive(entry.path, visited_paths) + except Exception as e: + logger.error(f"Error processing entry {entry.path}: {e}") + except Exception as e: + logger.error(f"Error scanning {path}: {e}") + + await scan_recursive(root_path, set()) + return models + + async def _process_single_file(self, file_path: str, root_path: str, models: list): + """Process a single file and add to results list""" + try: + result = await self._process_model_file(file_path, root_path) + if result: + models.append(result) + except Exception as e: + logger.error(f"Error processing {file_path}: {e}") def is_initializing(self) -> bool: """Check if the scanner is currently initializing""" @@ -931,7 +775,7 @@ class ModelScanner: logger.error(f"Error processing {file_path}: {e}") async def add_model_to_cache(self, metadata_dict: Dict, folder: str = '') -> bool: - """Add a model to the cache and save to disk + """Add a model to the cache Args: metadata_dict: The model metadata dictionary @@ -960,9 +804,6 @@ class ModelScanner: # Update the hash index self._hash_index.add_entry(metadata_dict['sha256'], metadata_dict['file_path']) - - # Save to disk - await self._save_cache_to_disk() return True except Exception as e: logger.error(f"Error adding model to cache: {e}") @@ -1102,9 +943,6 @@ class ModelScanner: await cache.resort() - # Save the updated cache - await self._save_cache_to_disk() - return True def has_hash(self, sha256: str) -> bool: @@ -1198,11 +1036,7 @@ class ModelScanner: if self._cache is None: return False - updated = await self._cache.update_preview_url(file_path, preview_url, preview_nsfw_level) - if updated: - # Save updated cache to disk - await self._save_cache_to_disk() - return updated + return await self._cache.update_preview_url(file_path, preview_url, preview_nsfw_level) async def bulk_delete_models(self, file_paths: List[str]) -> Dict: """Delete multiple models and update cache in a batch operation @@ -1334,9 +1168,6 @@ class ModelScanner: # Resort cache await self._cache.resort() - # Save updated cache to disk - await self._save_cache_to_disk() - return True except Exception as e: diff --git a/py/services/model_service_factory.py b/py/services/model_service_factory.py new file mode 100644 index 00000000..6cc8a3a3 --- /dev/null +++ b/py/services/model_service_factory.py @@ -0,0 +1,137 @@ +from typing import Dict, Type, Any +import logging + +logger = logging.getLogger(__name__) + +class ModelServiceFactory: + """Factory for managing model services and routes""" + + _services: Dict[str, Type] = {} + _routes: Dict[str, Type] = {} + _initialized_services: Dict[str, Any] = {} + _initialized_routes: Dict[str, Any] = {} + + @classmethod + def register_model_type(cls, model_type: str, service_class: Type, route_class: Type): + """Register a new model type with its service and route classes + + Args: + model_type: The model type identifier (e.g., 'lora', 'checkpoint') + service_class: The service class for this model type + route_class: The route class for this model type + """ + cls._services[model_type] = service_class + cls._routes[model_type] = route_class + logger.info(f"Registered model type '{model_type}' with service {service_class.__name__} and routes {route_class.__name__}") + + @classmethod + def get_service_class(cls, model_type: str) -> Type: + """Get service class for a model type + + Args: + model_type: The model type identifier + + Returns: + The service class for the model type + + Raises: + ValueError: If model type is not registered + """ + if model_type not in cls._services: + raise ValueError(f"Unknown model type: {model_type}") + return cls._services[model_type] + + @classmethod + def get_route_class(cls, model_type: str) -> Type: + """Get route class for a model type + + Args: + model_type: The model type identifier + + Returns: + The route class for the model type + + Raises: + ValueError: If model type is not registered + """ + if model_type not in cls._routes: + raise ValueError(f"Unknown model type: {model_type}") + return cls._routes[model_type] + + @classmethod + def get_route_instance(cls, model_type: str): + """Get or create route instance for a model type + + Args: + model_type: The model type identifier + + Returns: + The route instance for the model type + """ + if model_type not in cls._initialized_routes: + route_class = cls.get_route_class(model_type) + cls._initialized_routes[model_type] = route_class() + return cls._initialized_routes[model_type] + + @classmethod + def setup_all_routes(cls, app): + """Setup routes for all registered model types + + Args: + app: The aiohttp application instance + """ + logger.info(f"Setting up routes for {len(cls._services)} registered model types") + + for model_type in cls._services.keys(): + try: + routes_instance = cls.get_route_instance(model_type) + routes_instance.setup_routes(app) + logger.info(f"Successfully set up routes for {model_type}") + except Exception as e: + logger.error(f"Failed to setup routes for {model_type}: {e}", exc_info=True) + + @classmethod + def get_registered_types(cls) -> list: + """Get list of all registered model types + + Returns: + List of registered model type identifiers + """ + return list(cls._services.keys()) + + @classmethod + def is_registered(cls, model_type: str) -> bool: + """Check if a model type is registered + + Args: + model_type: The model type identifier + + Returns: + True if the model type is registered, False otherwise + """ + return model_type in cls._services + + @classmethod + def clear_registrations(cls): + """Clear all registrations - mainly for testing purposes""" + cls._services.clear() + cls._routes.clear() + cls._initialized_services.clear() + cls._initialized_routes.clear() + logger.info("Cleared all model type registrations") + + +def register_default_model_types(): + """Register the default model types (LoRA and Checkpoint)""" + from ..services.lora_service import LoraService + from ..services.checkpoint_service import CheckpointService + from ..routes.lora_routes import LoraRoutes + from ..routes.checkpoint_routes import CheckpointRoutes + + # Register LoRA model type + ModelServiceFactory.register_model_type('lora', LoraService, LoraRoutes) + + # Register Checkpoint model type + ModelServiceFactory.register_model_type('checkpoint', CheckpointService, CheckpointRoutes) + + logger.info("Registered default model types: lora, checkpoint") \ No newline at end of file diff --git a/py/services/recipe_scanner.py b/py/services/recipe_scanner.py index 6d8491ae..89bbef14 100644 --- a/py/services/recipe_scanner.py +++ b/py/services/recipe_scanner.py @@ -393,8 +393,8 @@ class RecipeScanner: if 'hash' in lora and (not lora.get('file_name') or not lora['file_name']): hash_value = lora['hash'] - if self._lora_scanner.has_lora_hash(hash_value): - lora_path = self._lora_scanner.get_lora_path_by_hash(hash_value) + if self._lora_scanner.has_hash(hash_value): + lora_path = self._lora_scanner.get_path_by_hash(hash_value) if lora_path: file_name = os.path.splitext(os.path.basename(lora_path))[0] lora['file_name'] = file_name @@ -465,7 +465,7 @@ class RecipeScanner: # Count occurrences of each base model for lora in loras: if 'hash' in lora: - lora_path = self._lora_scanner.get_lora_path_by_hash(lora['hash']) + lora_path = self._lora_scanner.get_path_by_hash(lora['hash']) if lora_path: base_model = await self._get_base_model_for_lora(lora_path) if base_model: @@ -603,9 +603,9 @@ class RecipeScanner: if 'loras' in item: for lora in item['loras']: if 'hash' in lora and lora['hash']: - lora['inLibrary'] = self._lora_scanner.has_lora_hash(lora['hash'].lower()) + lora['inLibrary'] = self._lora_scanner.has_hash(lora['hash'].lower()) lora['preview_url'] = self._lora_scanner.get_preview_url_by_hash(lora['hash'].lower()) - lora['localPath'] = self._lora_scanner.get_lora_path_by_hash(lora['hash'].lower()) + lora['localPath'] = self._lora_scanner.get_path_by_hash(lora['hash'].lower()) result = { 'items': paginated_items, @@ -655,9 +655,9 @@ class RecipeScanner: for lora in formatted_recipe['loras']: if 'hash' in lora and lora['hash']: lora_hash = lora['hash'].lower() - lora['inLibrary'] = self._lora_scanner.has_lora_hash(lora_hash) + lora['inLibrary'] = self._lora_scanner.has_hash(lora_hash) lora['preview_url'] = self._lora_scanner.get_preview_url_by_hash(lora_hash) - lora['localPath'] = self._lora_scanner.get_lora_path_by_hash(lora_hash) + lora['localPath'] = self._lora_scanner.get_path_by_hash(lora_hash) return formatted_recipe diff --git a/py/services/service_registry.py b/py/services/service_registry.py index 15c00a3f..6cefb4d4 100644 --- a/py/services/service_registry.py +++ b/py/services/service_registry.py @@ -7,106 +7,176 @@ logger = logging.getLogger(__name__) T = TypeVar('T') # Define a type variable for service types class ServiceRegistry: - """Centralized registry for service singletons""" + """Central registry for managing singleton services""" - _instance = None _services: Dict[str, Any] = {} - _lock = asyncio.Lock() + _locks: Dict[str, asyncio.Lock] = {} @classmethod - def get_instance(cls): - """Get singleton instance of the registry""" - if cls._instance is None: - cls._instance = cls() - return cls._instance + async def register_service(cls, name: str, service: Any) -> None: + """Register a service instance with the registry + + Args: + name: Service name identifier + service: Service instance to register + """ + cls._services[name] = service + logger.debug(f"Registered service: {name}") @classmethod - async def register_service(cls, service_name: str, service_instance: Any) -> None: - """Register a service instance with the registry""" - registry = cls.get_instance() - async with cls._lock: - registry._services[service_name] = service_instance - logger.debug(f"Registered service: {service_name}") + async def get_service(cls, name: str) -> Optional[Any]: + """Get a service instance by name + + Args: + name: Service name identifier + + Returns: + Service instance or None if not found + """ + return cls._services.get(name) @classmethod - async def get_service(cls, service_name: str) -> Any: - """Get a service instance by name""" - registry = cls.get_instance() - async with cls._lock: - if service_name not in registry._services: - logger.debug(f"Service {service_name} not found in registry") - return None - return registry._services[service_name] + def _get_lock(cls, name: str) -> asyncio.Lock: + """Get or create a lock for a service + + Args: + name: Service name identifier + + Returns: + AsyncIO lock for the service + """ + if name not in cls._locks: + cls._locks[name] = asyncio.Lock() + return cls._locks[name] - @classmethod - def get_service_sync(cls, service_name: str) -> Any: - """Get a service instance by name (synchronous version)""" - registry = cls.get_instance() - if service_name not in registry._services: - logger.debug(f"Service {service_name} not found in registry") - return None - return registry._services[service_name] - - # Convenience methods for common services @classmethod async def get_lora_scanner(cls): - """Get the LoraScanner instance""" - from .lora_scanner import LoraScanner - scanner = await cls.get_service("lora_scanner") - if scanner is None: - scanner = await LoraScanner.get_instance() - await cls.register_service("lora_scanner", scanner) - return scanner + """Get or create LoRA scanner instance""" + service_name = "lora_scanner" + if service_name in cls._services: + return cls._services[service_name] + + async with cls._get_lock(service_name): + # Double-check after acquiring lock + if service_name in cls._services: + return cls._services[service_name] + + # Import here to avoid circular imports + from .lora_scanner import LoraScanner + + scanner = await LoraScanner.get_instance() + cls._services[service_name] = scanner + logger.debug(f"Created and registered {service_name}") + return scanner + @classmethod async def get_checkpoint_scanner(cls): - """Get the CheckpointScanner instance""" - from .checkpoint_scanner import CheckpointScanner - scanner = await cls.get_service("checkpoint_scanner") - if scanner is None: + """Get or create Checkpoint scanner instance""" + service_name = "checkpoint_scanner" + + if service_name in cls._services: + return cls._services[service_name] + + async with cls._get_lock(service_name): + # Double-check after acquiring lock + if service_name in cls._services: + return cls._services[service_name] + + # Import here to avoid circular imports + from .checkpoint_scanner import CheckpointScanner + scanner = await CheckpointScanner.get_instance() - await cls.register_service("checkpoint_scanner", scanner) - return scanner - - @classmethod - async def get_civitai_client(cls): - """Get the CivitaiClient instance""" - from .civitai_client import CivitaiClient - client = await cls.get_service("civitai_client") - if client is None: - client = await CivitaiClient.get_instance() - await cls.register_service("civitai_client", client) - return client - - @classmethod - async def get_download_manager(cls): - """Get the DownloadManager instance""" - from .download_manager import DownloadManager - manager = await cls.get_service("download_manager") - if manager is None: - manager = await DownloadManager.get_instance() - await cls.register_service("download_manager", manager) - return manager - + cls._services[service_name] = scanner + logger.debug(f"Created and registered {service_name}") + return scanner + @classmethod async def get_recipe_scanner(cls): - """Get the RecipeScanner instance""" - from .recipe_scanner import RecipeScanner - scanner = await cls.get_service("recipe_scanner") - if scanner is None: - lora_scanner = await cls.get_lora_scanner() - scanner = RecipeScanner(lora_scanner) - await cls.register_service("recipe_scanner", scanner) - return scanner - + """Get or create Recipe scanner instance""" + service_name = "recipe_scanner" + + if service_name in cls._services: + return cls._services[service_name] + + async with cls._get_lock(service_name): + # Double-check after acquiring lock + if service_name in cls._services: + return cls._services[service_name] + + # Import here to avoid circular imports + from .recipe_scanner import RecipeScanner + + scanner = await RecipeScanner.get_instance() + cls._services[service_name] = scanner + logger.debug(f"Created and registered {service_name}") + return scanner + + @classmethod + async def get_civitai_client(cls): + """Get or create CivitAI client instance""" + service_name = "civitai_client" + + if service_name in cls._services: + return cls._services[service_name] + + async with cls._get_lock(service_name): + # Double-check after acquiring lock + if service_name in cls._services: + return cls._services[service_name] + + # Import here to avoid circular imports + from .civitai_client import CivitaiClient + + client = await CivitaiClient.get_instance() + cls._services[service_name] = client + logger.debug(f"Created and registered {service_name}") + return client + + @classmethod + async def get_download_manager(cls): + """Get or create Download manager instance""" + service_name = "download_manager" + + if service_name in cls._services: + return cls._services[service_name] + + async with cls._get_lock(service_name): + # Double-check after acquiring lock + if service_name in cls._services: + return cls._services[service_name] + + # Import here to avoid circular imports + from .download_manager import DownloadManager + + manager = DownloadManager() + cls._services[service_name] = manager + logger.debug(f"Created and registered {service_name}") + return manager + @classmethod async def get_websocket_manager(cls): - """Get the WebSocketManager instance""" - from .websocket_manager import ws_manager - manager = await cls.get_service("websocket_manager") - if manager is None: - # ws_manager is already a global instance in websocket_manager.py + """Get or create WebSocket manager instance""" + service_name = "websocket_manager" + + if service_name in cls._services: + return cls._services[service_name] + + async with cls._get_lock(service_name): + # Double-check after acquiring lock + if service_name in cls._services: + return cls._services[service_name] + + # Import here to avoid circular imports from .websocket_manager import ws_manager - await cls.register_service("websocket_manager", ws_manager) - manager = ws_manager - return manager \ No newline at end of file + + cls._services[service_name] = ws_manager + logger.debug(f"Registered {service_name}") + return ws_manager + + @classmethod + def clear_services(cls): + """Clear all registered services - mainly for testing""" + cls._services.clear() + cls._locks.clear() + logger.info("Cleared all registered services") \ No newline at end of file diff --git a/py/utils/routes_common.py b/py/utils/routes_common.py index e151d216..99f679cd 100644 --- a/py/utils/routes_common.py +++ b/py/utils/routes_common.py @@ -566,9 +566,10 @@ class ModelRouteUtils: return web.Response(text=str(e), status=500) @staticmethod - async def handle_download_model(request: web.Request, download_manager: DownloadManager) -> web.Response: + async def handle_download_model(request: web.Request) -> web.Response: """Handle model download request""" try: + download_manager = await ServiceRegistry.get_download_manager() data = await request.json() # Get or generate a download ID @@ -663,17 +664,17 @@ class ModelRouteUtils: }, status=500) @staticmethod - async def handle_cancel_download(request: web.Request, download_manager: DownloadManager) -> web.Response: + async def handle_cancel_download(request: web.Request) -> web.Response: """Handle cancellation of a download task Args: request: The aiohttp request - download_manager: The download manager instance Returns: web.Response: The HTTP response """ try: + download_manager = await ServiceRegistry.get_download_manager() download_id = request.match_info.get('download_id') if not download_id: return web.json_response({ @@ -701,17 +702,17 @@ class ModelRouteUtils: }, status=500) @staticmethod - async def handle_list_downloads(request: web.Request, download_manager: DownloadManager) -> web.Response: + async def handle_list_downloads(request: web.Request) -> web.Response: """Get list of active downloads Args: request: The aiohttp request - download_manager: The download manager instance Returns: web.Response: The HTTP response with list of downloads """ try: + download_manager = await ServiceRegistry.get_download_manager() result = await download_manager.get_active_downloads() return web.json_response(result) except Exception as e: @@ -1047,3 +1048,56 @@ class ModelRouteUtils: 'success': False, 'error': str(e) }, status=500) + + @staticmethod + async def handle_save_metadata(request: web.Request, scanner) -> web.Response: + """Handle saving metadata updates + + Args: + request: The aiohttp request + scanner: The model scanner instance + + Returns: + web.Response: The HTTP response + """ + try: + data = await request.json() + file_path = data.get('file_path') + if not file_path: + return web.Response(text='File path is required', status=400) + + # Remove file path from data to avoid saving it + metadata_updates = {k: v for k, v in data.items() if k != 'file_path'} + + # Get metadata file path + metadata_path = os.path.splitext(file_path)[0] + '.metadata.json' + + # Load existing metadata + metadata = await ModelRouteUtils.load_local_metadata(metadata_path) + + # Handle nested updates (for civitai.trainedWords) + for key, value in metadata_updates.items(): + if isinstance(value, dict) and key in metadata and isinstance(metadata[key], dict): + # Deep update for nested dictionaries + for nested_key, nested_value in value.items(): + metadata[key][nested_key] = nested_value + else: + # Regular update for top-level keys + metadata[key] = value + + # Save updated metadata + await MetadataManager.save_metadata(file_path, metadata) + + # Update cache + await scanner.update_single_model_cache(file_path, file_path, metadata) + + # If model_name was updated, resort the cache + if 'model_name' in metadata_updates: + cache = await scanner.get_cached_data() + await cache.resort() + + return web.json_response({'success': True}) + + except Exception as e: + logger.error(f"Error saving metadata: {e}", exc_info=True) + return web.Response(text=str(e), status=500) diff --git a/pyproject.toml b/pyproject.toml index 2af0a264..0b82f326 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,7 +14,7 @@ dependencies = [ "requests", "toml", "natsort", - "msgpack" + "GitPython" ] [project.urls] diff --git a/requirements.txt b/requirements.txt index db115810..87c9f47f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,5 +9,5 @@ requests toml numpy natsort -msgpack pyyaml +GitPython diff --git a/standalone.py b/standalone.py index 0120ab53..8890e05f 100644 --- a/standalone.py +++ b/standalone.py @@ -106,6 +106,22 @@ logger = logging.getLogger("lora-manager-standalone") # Configure aiohttp access logger to be less verbose logging.getLogger('aiohttp.access').setLevel(logging.WARNING) +# Add specific suppression for connection reset errors +class ConnectionResetFilter(logging.Filter): + def filter(self, record): + # Filter out connection reset errors that are not critical + if "ConnectionResetError" in str(record.getMessage()): + return False + if "_call_connection_lost" in str(record.getMessage()): + return False + if "WinError 10054" in str(record.getMessage()): + return False + return True + +# Apply the filter to asyncio logger +asyncio_logger = logging.getLogger("asyncio") +asyncio_logger.addFilter(ConnectionResetFilter()) + # Now we can import the global config from our local modules from py.config import config @@ -118,17 +134,6 @@ class StandaloneServer: # Ensure the app's access logger is configured to reduce verbosity self.app._subapps = [] # Ensure this exists to avoid AttributeError - - # Configure access logging for the app - self.app.on_startup.append(self._configure_access_logger) - - async def _configure_access_logger(self, app): - """Configure access logger to reduce verbosity""" - logging.getLogger('aiohttp.access').setLevel(logging.WARNING) - - # If using aiohttp>=3.8.0, configure access logger through app directly - if hasattr(app, 'access_logger'): - app.access_logger.setLevel(logging.WARNING) async def setup(self): """Set up the standalone server""" @@ -218,9 +223,6 @@ class StandaloneLoraManager(LoraManager): # Store app in a global-like location for compatibility sys.modules['server'].PromptServer.instance = server_instance - - # Configure aiohttp access logger to be less verbose - logging.getLogger('aiohttp.access').setLevel(logging.WARNING) added_targets = set() # Track already added target paths @@ -314,35 +316,39 @@ class StandaloneLoraManager(LoraManager): app.router.add_static('/loras_static', config.static_path) # Setup feature routes - from py.routes.lora_routes import LoraRoutes - from py.routes.api_routes import ApiRoutes + from py.services.model_service_factory import ModelServiceFactory, register_default_model_types from py.routes.recipe_routes import RecipeRoutes - from py.routes.checkpoints_routes import CheckpointsRoutes from py.routes.update_routes import UpdateRoutes from py.routes.misc_routes import MiscRoutes from py.routes.example_images_routes import ExampleImagesRoutes from py.routes.stats_routes import StatsRoutes + from py.services.websocket_manager import ws_manager - lora_routes = LoraRoutes() - checkpoints_routes = CheckpointsRoutes() + + register_default_model_types() + + # Setup all model routes using the factory + ModelServiceFactory.setup_all_routes(app) + stats_routes = StatsRoutes() # Initialize routes - lora_routes.setup_routes(app) - checkpoints_routes.setup_routes(app) stats_routes.setup_routes(app) - ApiRoutes.setup_routes(app) RecipeRoutes.setup_routes(app) UpdateRoutes.setup_routes(app) MiscRoutes.setup_routes(app) ExampleImagesRoutes.setup_routes(app) + + # Setup WebSocket routes that are shared across all model types + app.router.add_get('/ws/fetch-progress', ws_manager.handle_connection) + app.router.add_get('/ws/download-progress', ws_manager.handle_download_connection) + app.router.add_get('/ws/init-progress', ws_manager.handle_init_connection) # Schedule service initialization app.on_startup.append(lambda app: cls._initialize_services()) # Add cleanup app.on_shutdown.append(cls._cleanup) - app.on_shutdown.append(ApiRoutes.cleanup) def parse_args(): """Parse command line arguments""" @@ -367,9 +373,6 @@ async def main(): # Set log level logging.getLogger().setLevel(getattr(logging, args.log_level)) - # Explicitly configure aiohttp access logger regardless of selected log level - logging.getLogger('aiohttp.access').setLevel(logging.WARNING) - # Create the server instance server = StandaloneServer() diff --git a/static/css/base.css b/static/css/base.css index 1f983dae..99aedbb7 100644 --- a/static/css/base.css +++ b/static/css/base.css @@ -50,8 +50,8 @@ html, body { --lora-border: oklch(90% 0.02 256 / 0.15); --lora-text: oklch(95% 0.02 256); --lora-error: oklch(75% 0.32 29); - --lora-warning: oklch(var(--lora-warning-l) var(--lora-warning-c) var(--lora-warning-h)); /* Modified to be used with oklch() */ - --lora-success: oklch(var(--lora-success-l) var(--lora-success-c) var(--lora-success-h)); /* New green success color */ + --lora-warning: oklch(var(--lora-warning-l) var(--lora-warning-c) var(--lora-warning-h)); + --lora-success: oklch(var(--lora-success-l) var(--lora-success-c) var(--lora-success-h)); /* Spacing Scale */ --space-1: calc(8px * 1); diff --git a/static/css/components/header.css b/static/css/components/header.css index 13918428..0eb2237a 100644 --- a/static/css/components/header.css +++ b/static/css/components/header.css @@ -223,11 +223,6 @@ opacity: 1; } -.update-badge.hidden, -.update-badge:not(.visible) { - opacity: 0; -} - /* Mobile adjustments */ @media (max-width: 768px) { .app-title { diff --git a/static/css/components/modal.css b/static/css/components/modal.css index 500afa08..031ed917 100644 --- a/static/css/components/modal.css +++ b/static/css/components/modal.css @@ -172,6 +172,91 @@ body.modal-open { opacity: 1; } +/* Update Modal specific styles */ +.update-actions { + display: flex; + flex-direction: column; + gap: var(--space-2); + align-items: stretch; + flex-wrap: nowrap; +} + +.update-link { + color: var(--lora-accent); + text-decoration: none; + display: flex; + align-items: center; + gap: 8px; + font-size: 0.95em; +} + +.update-link:hover { + text-decoration: underline; +} + +/* Update progress styles */ +.update-progress { + background: rgba(0, 0, 0, 0.03); + border: 1px solid var(--lora-border); + border-radius: var(--border-radius-sm); + padding: var(--space-2); + margin: var(--space-2) 0; +} + +[data-theme="dark"] .update-progress { + background: rgba(255, 255, 255, 0.03); +} + +.progress-info { + display: flex; + flex-direction: column; + gap: var(--space-1); +} + +.progress-text { + font-size: 0.9em; + color: var(--text-color); + opacity: 0.8; +} + +.progress-bar { + width: 100%; + height: 8px; + background-color: rgba(0, 0, 0, 0.1); + border-radius: 4px; + overflow: hidden; +} + +[data-theme="dark"] .progress-bar { + background-color: rgba(255, 255, 255, 0.1); +} + +.progress-fill { + height: 100%; + background-color: var(--lora-accent); + width: 0%; + transition: width 0.3s ease; + border-radius: 4px; +} + +/* Update button states */ +#updateBtn { + min-width: 120px; +} + +#updateBtn.updating { + background-color: var(--lora-warning); + cursor: not-allowed; +} + +#updateBtn.success { + background-color: var(--lora-success); +} + +#updateBtn.error { + background-color: var(--lora-error); +} + /* Settings styles */ .settings-toggle { width: 36px; diff --git a/static/css/layout.css b/static/css/layout.css index b948a5f2..182e50b5 100644 --- a/static/css/layout.css +++ b/static/css/layout.css @@ -182,6 +182,31 @@ box-shadow: 0 1px 2px rgba(0, 0, 0, 0.05); } +/* Style for optgroups */ +.control-group select optgroup { + font-weight: 600; + font-style: normal; + color: var(--text-color); + background-color: var(--card-bg); +} + +.control-group select option { + padding: 4px 8px; + background-color: var(--card-bg); + color: var(--text-color); +} + +/* Dark theme optgroup styling */ +[data-theme="dark"] .control-group select optgroup { + background-color: var(--card-bg); + color: var(--text-color); +} + +[data-theme="dark"] .control-group select option { + background-color: var(--card-bg); + color: var(--text-color); +} + .control-group select:hover { border-color: var(--lora-accent); background-color: var(--bg-color); diff --git a/static/js/api/baseModelApi.js b/static/js/api/baseModelApi.js index 43abbedf..fe915f86 100644 --- a/static/js/api/baseModelApi.js +++ b/static/js/api/baseModelApi.js @@ -54,25 +54,16 @@ export async function fetchModelsPage(options = {}) { if (pageState.filters) { // Handle tags filters if (pageState.filters.tags && pageState.filters.tags.length > 0) { - // Checkpoints API expects individual 'tag' parameters, Loras API expects comma-separated 'tags' - if (modelType === 'checkpoint') { - pageState.filters.tags.forEach(tag => { - params.append('tag', tag); - }); - } else { - params.append('tags', pageState.filters.tags.join(',')); - } + pageState.filters.tags.forEach(tag => { + params.append('tag', tag); + }); } // Handle base model filters if (pageState.filters.baseModel && pageState.filters.baseModel.length > 0) { - if (modelType === 'checkpoint') { - pageState.filters.baseModel.forEach(model => { - params.append('base_model', model); - }); - } else { - params.append('base_models', pageState.filters.baseModel.join(',')); - } + pageState.filters.baseModel.forEach(model => { + params.append('base_model', model); + }); } } @@ -277,7 +268,7 @@ export async function deleteModel(filePath, modelType = 'lora') { const endpoint = modelType === 'checkpoint' ? '/api/checkpoints/delete' - : '/api/delete_model'; + : '/api/loras/delete'; const response = await fetch(endpoint, { method: 'POST', @@ -454,7 +445,7 @@ export async function refreshSingleModelMetadata(filePath, modelType = 'lora') { const endpoint = modelType === 'checkpoint' ? '/api/checkpoints/fetch-civitai' - : '/api/fetch-civitai'; + : '/api/loras/fetch-civitai'; const response = await fetch(endpoint, { method: 'POST', @@ -557,7 +548,7 @@ export async function uploadPreview(filePath, file, modelType = 'lora', nsfwLeve // Set endpoint based on model type const endpoint = modelType === 'checkpoint' ? '/api/checkpoints/replace-preview' - : '/api/replace_preview'; + : '/api/loras/replace_preview'; const response = await fetch(endpoint, { method: 'POST', diff --git a/static/js/api/loraApi.js b/static/js/api/loraApi.js index 3c6f0c86..9d5dd1bc 100644 --- a/static/js/api/loraApi.js +++ b/static/js/api/loraApi.js @@ -89,7 +89,7 @@ export async function fetchLorasPage(page = 1, pageSize = 100) { export async function fetchCivitai() { return fetchCivitaiMetadata({ modelType: 'lora', - fetchEndpoint: '/api/fetch-all-civitai', + fetchEndpoint: '/api/loras/fetch-all-civitai', resetAndReloadFunction: resetAndReload }); } diff --git a/static/js/components/ContextMenu/ModelContextMenuMixin.js b/static/js/components/ContextMenu/ModelContextMenuMixin.js index 7c0bc0c3..cd58bc0f 100644 --- a/static/js/components/ContextMenu/ModelContextMenuMixin.js +++ b/static/js/components/ContextMenu/ModelContextMenuMixin.js @@ -125,7 +125,7 @@ export const ModelContextMenuMixin = { const endpoint = this.modelType === 'checkpoint' ? '/api/checkpoints/relink-civitai' : - '/api/relink-civitai'; + '/api/loras/relink-civitai'; const response = await fetch(endpoint, { method: 'POST', diff --git a/static/js/components/controls/PageControls.js b/static/js/components/controls/PageControls.js index 42ad0062..53a2d8ef 100644 --- a/static/js/components/controls/PageControls.js +++ b/static/js/components/controls/PageControls.js @@ -1,5 +1,5 @@ // PageControls.js - Manages controls for both LoRAs and Checkpoints pages -import { state, getCurrentPageState, setCurrentPageType } from '../../state/index.js'; +import { getCurrentPageState, setCurrentPageType } from '../../state/index.js'; import { getStorageItem, setStorageItem, getSessionItem, setSessionItem } from '../../utils/storageHelpers.js'; import { showToast } from '../../utils/uiHelpers.js'; @@ -41,6 +41,9 @@ export class PageControls { this.pageState.isLoading = false; this.pageState.hasMore = true; + // Set default sort based on page type + this.pageState.sortBy = this.pageType === 'loras' ? 'name:asc' : 'name:asc'; + // Load sort preference this.loadSortPreference(); } @@ -326,14 +329,36 @@ export class PageControls { loadSortPreference() { const savedSort = getStorageItem(`${this.pageType}_sort`); if (savedSort) { - this.pageState.sortBy = savedSort; + // Handle legacy format conversion + const convertedSort = this.convertLegacySortFormat(savedSort); + this.pageState.sortBy = convertedSort; const sortSelect = document.getElementById('sortSelect'); if (sortSelect) { - sortSelect.value = savedSort; + sortSelect.value = convertedSort; } } } + /** + * Convert legacy sort format to new format + * @param {string} sortValue - The sort value to convert + * @returns {string} - Converted sort value + */ + convertLegacySortFormat(sortValue) { + // Convert old format to new format with direction + switch (sortValue) { + case 'name': + return 'name:asc'; + case 'date': + return 'date:desc'; // Newest first is more intuitive default + case 'size': + return 'size:desc'; // Largest first is more intuitive default + default: + // If it's already in new format or unknown, return as is + return sortValue.includes(':') ? sortValue : 'name:asc'; + } + } + /** * Save sort preference to storage * @param {string} sortValue - The sort value to save diff --git a/static/js/managers/DownloadManager.js b/static/js/managers/DownloadManager.js index 7584bf0f..b3d0433e 100644 --- a/static/js/managers/DownloadManager.js +++ b/static/js/managers/DownloadManager.js @@ -87,7 +87,7 @@ export class DownloadManager { throw new Error('Invalid Civitai URL format'); } - const response = await fetch(`/api/civitai/versions/${this.modelId}`); + const response = await fetch(`/api/loras/civitai/versions/${this.modelId}`); if (!response.ok) { const errorData = await response.json().catch(() => ({})); if (errorData && errorData.error && errorData.error.includes('Model type mismatch')) { @@ -254,7 +254,7 @@ export class DownloadManager { try { // Fetch LoRA roots - const rootsResponse = await fetch('/api/lora-roots'); + const rootsResponse = await fetch('/api/loras/roots'); if (!rootsResponse.ok) { throw new Error('Failed to fetch LoRA roots'); } @@ -272,7 +272,7 @@ export class DownloadManager { } // Fetch folders dynamically - const foldersResponse = await fetch('/api/folders'); + const foldersResponse = await fetch('/api/loras/folders'); if (!foldersResponse.ok) { throw new Error('Failed to fetch folders'); } diff --git a/static/js/managers/MoveManager.js b/static/js/managers/MoveManager.js index 45e6a011..e72c8274 100644 --- a/static/js/managers/MoveManager.js +++ b/static/js/managers/MoveManager.js @@ -74,7 +74,7 @@ class MoveManager { try { // Fetch LoRA roots - const rootsResponse = await fetch('/api/lora-roots'); + const rootsResponse = await fetch('/api/loras/roots'); if (!rootsResponse.ok) { throw new Error('Failed to fetch LoRA roots'); } @@ -96,7 +96,7 @@ class MoveManager { } // Fetch folders dynamically - const foldersResponse = await fetch('/api/folders'); + const foldersResponse = await fetch('/api/loras/folders'); if (!foldersResponse.ok) { throw new Error('Failed to fetch folders'); } @@ -190,7 +190,7 @@ class MoveManager { // Refresh folder tags after successful move try { - const foldersResponse = await fetch('/api/folders'); + const foldersResponse = await fetch('/api/loras/folders'); if (foldersResponse.ok) { const foldersData = await foldersResponse.json(); updateFolderTags(foldersData.folders); diff --git a/static/js/managers/SettingsManager.js b/static/js/managers/SettingsManager.js index 828cb461..c0f0f662 100644 --- a/static/js/managers/SettingsManager.js +++ b/static/js/managers/SettingsManager.js @@ -161,7 +161,7 @@ export class SettingsManager { if (!defaultLoraRootSelect) return; // Fetch lora roots - const response = await fetch('/api/lora-roots'); + const response = await fetch('/api/loras/roots'); if (!response.ok) { throw new Error('Failed to fetch LoRA roots'); } diff --git a/static/js/managers/UpdateService.js b/static/js/managers/UpdateService.js index 52a62603..dc95f98d 100644 --- a/static/js/managers/UpdateService.js +++ b/static/js/managers/UpdateService.js @@ -3,7 +3,7 @@ import { getStorageItem, setStorageItem } from '../utils/storageHelpers.js'; export class UpdateService { constructor() { - this.updateCheckInterval = 24 * 60 * 60 * 1000; // 24 hours + this.updateCheckInterval = 60 * 60 * 1000; // 1 hour this.currentVersion = "v0.0.0"; // Initialize with default values this.latestVersion = "v0.0.0"; // Initialize with default values this.updateInfo = null; @@ -13,8 +13,10 @@ export class UpdateService { branch: "unknown", commit_date: "unknown" }; - this.updateNotificationsEnabled = getStorageItem('show_update_notifications'); + this.updateNotificationsEnabled = getStorageItem('show_update_notifications', true); this.lastCheckTime = parseInt(getStorageItem('last_update_check') || '0'); + this.isUpdating = false; + this.nightlyMode = getStorageItem('nightly_updates', false); } initialize() { @@ -28,23 +30,44 @@ export class UpdateService { this.updateBadgeVisibility(); }); } + + const updateBtn = document.getElementById('updateBtn'); + if (updateBtn) { + updateBtn.addEventListener('click', () => this.performUpdate()); + } + + // Register event listener for nightly update toggle + const nightlyCheckbox = document.getElementById('nightlyUpdateToggle'); + if (nightlyCheckbox) { + nightlyCheckbox.checked = this.nightlyMode; + nightlyCheckbox.addEventListener('change', (e) => { + this.nightlyMode = e.target.checked; + setStorageItem('nightly_updates', e.target.checked); + this.updateNightlyWarning(); + this.updateModalContent(); + // Re-check for updates when switching channels + this.manualCheckForUpdates(); + }); + this.updateNightlyWarning(); + } // Perform update check if needed this.checkForUpdates().then(() => { // Ensure badges are updated after checking this.updateBadgeVisibility(); }); - - // Set up event listener for update button - // const updateToggle = document.getElementById('updateToggleBtn'); - // if (updateToggle) { - // updateToggle.addEventListener('click', () => this.toggleUpdateModal()); - // } // Immediately update modal content with current values (even if from default) this.updateModalContent(); } + updateNightlyWarning() { + const warning = document.getElementById('nightlyWarning'); + if (warning) { + warning.style.display = this.nightlyMode ? 'flex' : 'none'; + } + } + async checkForUpdates() { // Check if we should perform an update check const now = Date.now(); @@ -59,8 +82,8 @@ export class UpdateService { } try { - // Call backend API to check for updates - const response = await fetch('/api/check-updates'); + // Call backend API to check for updates with nightly flag + const response = await fetch(`/api/check-updates?nightly=${this.nightlyMode}`); const data = await response.json(); if (data.success) { @@ -137,8 +160,8 @@ export class UpdateService { const shouldShow = this.updateNotificationsEnabled && this.updateAvailable; if (updateBadge) { - updateBadge.classList.toggle('hidden', !shouldShow); - console.log("Update badge visibility:", !shouldShow ? "hidden" : "visible"); + updateBadge.classList.toggle('visible', shouldShow); + console.log("Update badge visibility:", shouldShow ? "visible" : "hidden"); } } @@ -157,7 +180,17 @@ export class UpdateService { const newVersionEl = modal.querySelector('.new-version .version-number'); if (currentVersionEl) currentVersionEl.textContent = this.currentVersion; - if (newVersionEl) newVersionEl.textContent = this.latestVersion; + + if (newVersionEl) { + newVersionEl.textContent = this.latestVersion; + } + + // Update update button state + const updateBtn = modal.querySelector('#updateBtn'); + if (updateBtn) { + updateBtn.classList.toggle('disabled', !this.updateAvailable || this.isUpdating); + updateBtn.disabled = !this.updateAvailable || this.isUpdating; + } // Update git info const gitInfoEl = modal.querySelector('.git-info'); @@ -218,6 +251,131 @@ export class UpdateService { } } + async performUpdate() { + if (!this.updateAvailable || this.isUpdating) { + return; + } + + try { + this.isUpdating = true; + this.updateUpdateUI('updating', 'Updating...'); + this.showUpdateProgress(true); + + // Update progress + this.updateProgress(10, 'Preparing update...'); + + const response = await fetch('/api/perform-update', { + method: 'POST', + headers: { + 'Content-Type': 'application/json' + }, + body: JSON.stringify({ + nightly: this.nightlyMode + }) + }); + + this.updateProgress(50, 'Installing update...'); + + const data = await response.json(); + + if (data.success) { + this.updateProgress(100, 'Update completed successfully!'); + this.updateUpdateUI('success', 'Updated!'); + + // Show success message and suggest restart + setTimeout(() => { + this.showUpdateCompleteMessage(data.new_version); + }, 1000); + + } else { + throw new Error(data.error || 'Update failed'); + } + + } catch (error) { + console.error('Update failed:', error); + this.updateUpdateUI('error', 'Update Failed'); + this.updateProgress(0, `Update failed: ${error.message}`); + + // Hide progress after error + setTimeout(() => { + this.showUpdateProgress(false); + }, 3000); + } finally { + this.isUpdating = false; + } + } + + updateUpdateUI(state, text) { + const updateBtn = document.getElementById('updateBtn'); + const updateBtnText = document.getElementById('updateBtnText'); + + if (updateBtn && updateBtnText) { + // Remove existing state classes + updateBtn.classList.remove('updating', 'success', 'error', 'disabled'); + + // Add new state class + if (state !== 'normal') { + updateBtn.classList.add(state); + } + + // Update button text + updateBtnText.textContent = text; + + // Update disabled state + updateBtn.disabled = (state === 'updating' || state === 'disabled'); + } + } + + showUpdateProgress(show) { + const progressContainer = document.getElementById('updateProgress'); + if (progressContainer) { + progressContainer.style.display = show ? 'block' : 'none'; + } + } + + updateProgress(percentage, text) { + const progressFill = document.getElementById('updateProgressFill'); + const progressText = document.getElementById('updateProgressText'); + + if (progressFill) { + progressFill.style.width = `${percentage}%`; + } + + if (progressText) { + progressText.textContent = text; + } + } + + showUpdateCompleteMessage(newVersion) { + const modal = document.getElementById('updateModal'); + if (!modal) return; + + // Update the modal content to show completion + const progressText = document.getElementById('updateProgressText'); + if (progressText) { + progressText.innerHTML = ` +
+ + Successfully updated to ${newVersion}! +

+ + Please restart ComfyUI to complete the update process. + +
+ `; + } + + // Update current version display + this.currentVersion = newVersion; + this.updateAvailable = false; + + // Refresh the modal content + setTimeout(() => { + this.updateModalContent(); + this.showUpdateProgress(false); + }, 2000); + } + // Simple markdown parser for changelog items parseMarkdown(text) { if (!text) return ''; diff --git a/static/js/managers/import/FolderBrowser.js b/static/js/managers/import/FolderBrowser.js index 33504ee8..edd4d6a4 100644 --- a/static/js/managers/import/FolderBrowser.js +++ b/static/js/managers/import/FolderBrowser.js @@ -99,7 +99,7 @@ export class FolderBrowser { } // Fetch LoRA roots - const rootsResponse = await fetch('/api/lora-roots'); + const rootsResponse = await fetch('/api/loras/roots'); if (!rootsResponse.ok) { throw new Error(`Failed to fetch LoRA roots: ${rootsResponse.status}`); } @@ -119,7 +119,7 @@ export class FolderBrowser { } // Fetch folders - const foldersResponse = await fetch('/api/folders'); + const foldersResponse = await fetch('/api/loras/folders'); if (!foldersResponse.ok) { throw new Error(`Failed to fetch folders: ${foldersResponse.status}`); } diff --git a/templates/components/controls.html b/templates/components/controls.html index 65b53350..3c36b1bd 100644 --- a/templates/components/controls.html +++ b/templates/components/controls.html @@ -11,8 +11,18 @@
- +
diff --git a/templates/components/modals.html b/templates/components/modals.html index 7caa12dd..7f152c90 100644 --- a/templates/components/modals.html +++ b/templates/components/modals.html @@ -476,9 +476,26 @@ v0.0.0
- - View on GitHub - + +
+ + View on GitHub + + +
+ + + +