import asyncio
import logging
import os
from typing import Dict

from aiohttp import web
from server import PromptServer  # type: ignore

from .base_model_routes import BaseModelRoutes
from ..services.lora_service import LoraService
from ..services.service_registry import ServiceRegistry
from ..utils.routes_common import ModelRouteUtils
from ..utils.utils import get_lora_info

logger = logging.getLogger(__name__)


class LoraRoutes(BaseModelRoutes):
    """LoRA-specific route controller.

    Services are resolved from ``ServiceRegistry`` when the aiohttp app starts
    (see ``setup_routes``), so ``self.service`` and ``self.civitai_client`` are
    ``None`` until ``initialize_services`` has run.
    """

    def __init__(self):
        """Initialize LoRA routes; services are attached later on app startup."""
        # Service will be initialized later via setup_routes
        self.service = None
        self.civitai_client = None
        self.template_name = "loras.html"

    async def initialize_services(self):
        """Resolve the LoRA scanner and Civitai client from ServiceRegistry."""
        lora_scanner = await ServiceRegistry.get_lora_scanner()
        self.service = LoraService(lora_scanner)
        self.civitai_client = await ServiceRegistry.get_civitai_client()

        # Initialize parent with the service now that it exists
        super().__init__(self.service)

    def setup_routes(self, app: web.Application):
        """Register LoRA routes on ``app`` and schedule service initialization."""
        # Schedule service initialization on app startup
        app.on_startup.append(lambda _: self.initialize_services())

        # Setup common routes with 'loras' prefix (includes page route)
        super().setup_routes(app, 'loras')

    def setup_specific_routes(self, app: web.Application, prefix: str):
        """Register the LoRA-only endpoints.

        Args:
            app: The aiohttp application to add routes to.
            prefix: Model-type path segment used by the prefixed query routes.
        """
        # LoRA-specific query routes
        app.router.add_get(f'/api/{prefix}/letter-counts', self.get_letter_counts)
        app.router.add_get(f'/api/{prefix}/get-notes', self.get_lora_notes)
        app.router.add_get(f'/api/{prefix}/get-trigger-words', self.get_lora_trigger_words)
        app.router.add_get('/api/lora-preview-url', self.get_lora_preview_url)
        app.router.add_get('/api/lora-civitai-url', self.get_lora_civitai_url)
        app.router.add_get('/api/lora-model-description', self.get_lora_model_description)
        app.router.add_get('/api/folders', self.get_folders)
        app.router.add_get('/api/lora-roots', self.get_lora_roots)

        # LoRA-specific management routes
        app.router.add_post('/api/move_model', self.move_model)
        app.router.add_post('/api/move_models_bulk', self.move_models_bulk)

        # CivitAI integration with LoRA-specific validation.
        # The {name} segments are aiohttp route variables, not string formatting.
        app.router.add_get('/api/civitai/versions/{model_id}', self.get_civitai_versions_lora)
        app.router.add_get('/api/civitai/model/version/{modelVersionId}', self.get_civitai_model_by_version)
        app.router.add_get('/api/civitai/model/hash/{hash}', self.get_civitai_model_by_hash)

        # ComfyUI integration
        app.router.add_post('/loramanager/get_trigger_words', self.get_trigger_words)

    def _parse_specific_params(self, request: web.Request) -> Dict:
        """Parse LoRA-specific query parameters into service keyword arguments.

        Recognized query parameters:
            first_letter: passed through unchanged as ``first_letter``.
            fuzzy: the value ``'true'`` sets ``fuzzy_search=True``.
            lora_hash / lora_hashes: lower-cased into ``hash_filters``
                (``single_hash`` / ``multiple_hashes``); ``lora_hash`` wins
                when both are present.
        """
        params = {}
        query = request.query

        # LoRA-specific parameters
        if 'first_letter' in query:
            params['first_letter'] = query.get('first_letter')

        # Handle fuzzy search parameter name variation
        if query.get('fuzzy') == 'true':
            params['fuzzy_search'] = True

        # Handle additional filter parameters for LoRAs
        if 'lora_hash' in query:
            params['hash_filters'] = {'single_hash': query['lora_hash'].lower()}
        elif 'lora_hashes' in query:
            # Strip whitespace and drop empty entries (e.g. from trailing commas)
            hashes = (h.strip().lower() for h in query['lora_hashes'].split(','))
            params['hash_filters'] = {'multiple_hashes': [h for h in hashes if h]}

        return params

    # LoRA-specific route handlers
    async def get_letter_counts(self, request: web.Request) -> web.Response:
        """Get count of LoRAs for each letter of the alphabet."""
        try:
            letter_counts = await self.service.get_letter_counts()
            return web.json_response({
                'success': True,
                'letter_counts': letter_counts
            })
        except Exception as e:
            logger.error(f"Error getting letter counts: {e}")
            return web.json_response({
                'success': False,
                'error': str(e)
            }, status=500)

    async def get_lora_notes(self, request: web.Request) -> web.Response:
        """Get notes for the LoRA named by the ``name`` query parameter.

        Returns 400 without a name and 404 when the service returns ``None``.
        """
        try:
            lora_name = request.query.get('name')
            if not lora_name:
                return web.Response(text='Lora file name is required', status=400)

            notes = await self.service.get_lora_notes(lora_name)
            if notes is not None:
                return web.json_response({
                    'success': True,
                    'notes': notes
                })
            return web.json_response({
                'success': False,
                'error': 'LoRA not found in cache'
            }, status=404)
        except Exception as e:
            logger.error(f"Error getting lora notes: {e}", exc_info=True)
            return web.json_response({
                'success': False,
                'error': str(e)
            }, status=500)

    async def get_lora_trigger_words(self, request: web.Request) -> web.Response:
        """Get trigger words for the LoRA named by the ``name`` query parameter."""
        try:
            lora_name = request.query.get('name')
            if not lora_name:
                return web.Response(text='Lora file name is required', status=400)

            trigger_words = await self.service.get_lora_trigger_words(lora_name)
            return web.json_response({
                'success': True,
                'trigger_words': trigger_words
            })
        except Exception as e:
            logger.error(f"Error getting lora trigger words: {e}", exc_info=True)
            return web.json_response({
                'success': False,
                'error': str(e)
            }, status=500)

    async def get_lora_preview_url(self, request: web.Request) -> web.Response:
        """Get the static preview URL for a LoRA file (404 when none is known)."""
        try:
            lora_name = request.query.get('name')
            if not lora_name:
                return web.Response(text='Lora file name is required', status=400)

            preview_url = await self.service.get_lora_preview_url(lora_name)
            if preview_url:
                return web.json_response({
                    'success': True,
                    'preview_url': preview_url
                })
            return web.json_response({
                'success': False,
                'error': 'No preview URL found for the specified lora'
            }, status=404)
        except Exception as e:
            logger.error(f"Error getting lora preview URL: {e}", exc_info=True)
            return web.json_response({
                'success': False,
                'error': str(e)
            }, status=500)

    async def get_lora_civitai_url(self, request: web.Request) -> web.Response:
        """Get the Civitai URL for a LoRA file.

        On success the service result dict is merged into the response body.
        Returns 404 when the service result is empty or has no ``civitai_url``.
        """
        try:
            lora_name = request.query.get('name')
            if not lora_name:
                return web.Response(text='Lora file name is required', status=400)

            result = await self.service.get_lora_civitai_url(lora_name)
            # Use .get so a missing key or empty result yields 404, not a 500 KeyError
            if result and result.get('civitai_url'):
                return web.json_response({
                    'success': True,
                    **result
                })
            return web.json_response({
                'success': False,
                'error': 'No Civitai data found for the specified lora'
            }, status=404)
        except Exception as e:
            logger.error(f"Error getting lora Civitai URL: {e}", exc_info=True)
            return web.json_response({
                'success': False,
                'error': str(e)
            }, status=500)

    async def get_folders(self, request: web.Request) -> web.Response:
        """Get all folders in the scanner cache."""
        try:
            cache = await self.service.scanner.get_cached_data()
            return web.json_response({
                'folders': cache.folders
            })
        except Exception as e:
            logger.error(f"Error getting folders: {e}")
            return web.json_response({
                'success': False,
                'error': str(e)
            }, status=500)

    async def get_lora_roots(self, request: web.Request) -> web.Response:
        """Get all configured LoRA root directories."""
        try:
            return web.json_response({
                'roots': self.service.get_model_roots()
            })
        except Exception as e:
            logger.error(f"Error getting LoRA roots: {e}")
            return web.json_response({
                'success': False,
                'error': str(e)
            }, status=500)

    # Override get_models to add LoRA-specific response data
    async def get_models(self, request: web.Request) -> web.Response:
        """Get paginated LoRA data, including the cached folder list."""
        try:
            # Parse common query parameters
            params = self._parse_common_params(request)

            # Get data from service
            result = await self.service.get_paginated_data(**params)

            # Get all available folders from cache for LoRA-specific response
            cache = await self.service.scanner.get_cached_data()

            # Format response items with LoRA-specific structure
            formatted_result = {
                'items': [await self.service.format_response(item) for item in result['items']],
                'folders': cache.folders,  # LoRA-specific: include folders in response
                'total': result['total'],
                'page': result['page'],
                'page_size': result['page_size'],
                'total_pages': result['total_pages']
            }

            return web.json_response(formatted_result)
        except Exception as e:
            logger.error(f"Error in get_models: {e}", exc_info=True)
            return web.json_response({"error": str(e)}, status=500)

    # CivitAI integration methods
    async def get_civitai_versions_lora(self, request: web.Request) -> web.Response:
        """Get versions of a Civitai LoRA model, annotated with local availability.

        Each version gets ``existsLocally``. Versions with a Model file also get
        ``modelSizeKB``, and ``localPath`` when the file's SHA256 is known
        locally. Returns 404 when Civitai reports no versions and 400 when the
        model type is not a LoRA type.
        """
        try:
            model_id = request.match_info['model_id']
            response = await self.civitai_client.get_model_versions(model_id)
            if not response or not response.get('modelVersions'):
                return web.Response(status=404, text="Model not found")

            versions = response.get('modelVersions', [])
            # 'type' may be missing or null in the payload; avoid None.lower()
            model_type = response.get('type') or ''

            # Check model type - should be LORA, LoCon, or DORA
            from ..utils.constants import VALID_LORA_TYPES
            if model_type.lower() not in VALID_LORA_TYPES:
                return web.json_response({
                    'error': f"Model type mismatch. Expected LORA, LoCon, or DORA, got {model_type}"
                }, status=400)

            # Check local availability for each version
            for version in versions:
                # Find the model file (type="Model") in the files list
                model_file = next(
                    (file for file in (version.get('files') or []) if file.get('type') == 'Model'),
                    None,
                )
                if not model_file:
                    # No model file found in this version
                    version['existsLocally'] = False
                    continue

                sha256 = (model_file.get('hashes') or {}).get('SHA256')
                if sha256:
                    # Set existsLocally and localPath at the version level
                    version['existsLocally'] = self.service.has_hash(sha256)
                    if version['existsLocally']:
                        version['localPath'] = self.service.get_path_by_hash(sha256)
                else:
                    # Without a hash, local availability cannot be confirmed
                    version['existsLocally'] = False

                # Also set the model file size at the version level for easier access
                version['modelSizeKB'] = model_file.get('sizeKB')

            return web.json_response(versions)
        except Exception as e:
            logger.error(f"Error fetching LoRA model versions: {e}")
            return web.Response(status=500, text=str(e))

    async def get_civitai_model_by_version(self, request: web.Request) -> web.Response:
        """Get CivitAI model details by model version ID.

        Returns 404 when the client's error message contains "not found",
        otherwise 500 on failure.
        """
        try:
            model_version_id = request.match_info.get('modelVersionId')

            # Get model details from Civitai API
            model, error_msg = await self.civitai_client.get_model_version_info(model_version_id)

            if not model:
                # Log warning for failed model retrieval
                logger.warning(f"Failed to fetch model version {model_version_id}: {error_msg}")

                # Determine status code based on error message
                status_code = 404 if error_msg and "not found" in error_msg.lower() else 500

                return web.json_response({
                    "success": False,
                    "error": error_msg or "Failed to fetch model information"
                }, status=status_code)

            return web.json_response(model)
        except Exception as e:
            logger.error(f"Error fetching model details: {e}")
            return web.json_response({
                "success": False,
                "error": str(e)
            }, status=500)

    async def get_civitai_model_by_hash(self, request: web.Request) -> web.Response:
        """Get CivitAI model details by file hash (client result returned as-is)."""
        try:
            model_hash = request.match_info.get('hash')
            model = await self.civitai_client.get_model_by_hash(model_hash)
            return web.json_response(model)
        except Exception as e:
            logger.error(f"Error fetching model details by hash: {e}")
            return web.json_response({
                "success": False,
                "error": str(e)
            }, status=500)

    # Model management methods
    async def move_model(self, request: web.Request) -> web.Response:
        """Move one model file into a target folder.

        JSON body: ``file_path`` (full path of the model file) and
        ``target_path`` (destination folder). A move into the same folder
        succeeds as a no-op. Returns 409 when a file with the same name already
        exists at the destination.
        """
        try:
            data = await request.json()
            file_path = data.get('file_path')  # full path of the model file
            target_path = data.get('target_path')  # folder path to move the model to

            if not file_path or not target_path:
                return web.Response(text='File path and target path are required', status=400)

            # Check if source and destination are the same
            source_dir = os.path.dirname(file_path)
            if os.path.normpath(source_dir) == os.path.normpath(target_path):
                logger.info(f"Source and target directories are the same: {source_dir}")
                return web.json_response({'success': True, 'message': 'Source and target directories are the same'})

            # Check if target file already exists
            file_name = os.path.basename(file_path)
            target_file_path = os.path.join(target_path, file_name).replace(os.sep, '/')

            if os.path.exists(target_file_path):
                return web.json_response({
                    'success': False,
                    'error': f"Target file already exists: {target_file_path}"
                }, status=409)  # 409 Conflict

            # Call scanner to handle the move operation
            success = await self.service.scanner.move_model(file_path, target_path)

            if success:
                return web.json_response({'success': True})
            return web.Response(text='Failed to move model', status=500)
        except Exception as e:
            logger.error(f"Error moving model: {e}", exc_info=True)
            return web.Response(text=str(e), status=500)

    async def move_models_bulk(self, request: web.Request) -> web.Response:
        """Move several model files into one target folder.

        JSON body: ``file_paths`` (list of full paths) and ``target_path``.
        Each file gets a per-item result. The overall response reports
        ``success: True`` with success/failure counts even if some moves fail.
        """
        try:
            data = await request.json()
            file_paths = data.get('file_paths', [])  # list of full paths of the model files
            target_path = data.get('target_path')  # folder path to move the models to

            if not file_paths or not target_path:
                return web.Response(text='File paths and target path are required', status=400)

            results = []
            for file_path in file_paths:
                # Check if source and destination are the same
                source_dir = os.path.dirname(file_path)
                if os.path.normpath(source_dir) == os.path.normpath(target_path):
                    results.append({
                        "path": file_path,
                        "success": True,
                        "message": "Source and target directories are the same"
                    })
                    continue

                # Check if target file already exists
                file_name = os.path.basename(file_path)
                target_file_path = os.path.join(target_path, file_name).replace(os.sep, '/')

                if os.path.exists(target_file_path):
                    results.append({
                        "path": file_path,
                        "success": False,
                        "message": f"Target file already exists: {target_file_path}"
                    })
                    continue

                # Try to move the model
                success = await self.service.scanner.move_model(file_path, target_path)
                results.append({
                    "path": file_path,
                    "success": success,
                    "message": "Success" if success else "Failed to move model"
                })

            # Count successes and failures
            success_count = sum(1 for r in results if r["success"])
            failure_count = len(results) - success_count

            return web.json_response({
                'success': True,
                'message': f'Moved {success_count} of {len(file_paths)} models',
                'results': results,
                'success_count': success_count,
                'failure_count': failure_count
            })
        except Exception as e:
            logger.error(f"Error moving models in bulk: {e}", exc_info=True)
            return web.Response(text=str(e), status=500)

    async def get_lora_model_description(self, request: web.Request) -> web.Response:
        """Get the description, tags and creator for a LoRA model.

        Query parameters: ``model_id`` (required) and ``file_path`` (optional).
        With ``file_path``, the local ``<stem>.metadata.json`` is consulted first.
        When no description is cached, the data is fetched from Civitai and,
        when ``file_path`` is given, written back to that metadata file on a
        best-effort basis (save errors are logged, not raised).
        """
        try:
            # Get parameters
            model_id = request.query.get('model_id')
            file_path = request.query.get('file_path')

            if not model_id:
                return web.json_response({
                    'success': False,
                    'error': 'Model ID is required'
                }, status=400)

            # Check if we already have the description stored in metadata
            description = None
            tags = []
            creator = {}

            if file_path:
                metadata_path = os.path.splitext(file_path)[0] + '.metadata.json'
                metadata = await ModelRouteUtils.load_local_metadata(metadata_path) or {}
                description = metadata.get('modelDescription')
                tags = metadata.get('tags', [])
                # The save path below stores creator under metadata['civitai'],
                # so check there as well as at the top level.
                civitai_info = metadata.get('civitai')
                nested_creator = civitai_info.get('creator') if isinstance(civitai_info, dict) else None
                creator = metadata.get('creator') or nested_creator or {}

            # If description is not in metadata, fetch from CivitAI
            if not description:
                logger.info(f"Fetching model metadata for model ID: {model_id}")
                model_metadata, _ = await self.civitai_client.get_model_metadata(model_id)

                if model_metadata:
                    description = model_metadata.get('description')
                    tags = model_metadata.get('tags', [])
                    creator = model_metadata.get('creator', {})

                    # Save the metadata to file if we have a file path and got metadata
                    if file_path:
                        try:
                            from ..utils.metadata_manager import MetadataManager

                            metadata_path = os.path.splitext(file_path)[0] + '.metadata.json'
                            metadata = await ModelRouteUtils.load_local_metadata(metadata_path)

                            metadata['modelDescription'] = description
                            metadata['tags'] = tags
                            # Ensure the civitai entry is a dict (it may be absent or null)
                            if not isinstance(metadata.get('civitai'), dict):
                                metadata['civitai'] = {}
                            # Store creator in the civitai nested structure
                            metadata['civitai']['creator'] = creator

                            await MetadataManager.save_metadata(file_path, metadata, True)
                        except Exception as e:
                            logger.error(f"Error saving model metadata: {e}")

            return web.json_response({
                'success': True,
                # NOTE(review): this fallback literal was reconstructed from a
                # line-broken string in the source; confirm the intended text
                # (it may originally have contained HTML markup).
                'description': description or "\nNo model description available.\n",
                'tags': tags,
                'creator': creator
            })
        except Exception as e:
            logger.error(f"Error getting model metadata: {e}")
            return web.json_response({
                'success': False,
                'error': str(e)
            }, status=500)

    async def get_trigger_words(self, request: web.Request) -> web.Response:
        """Push combined trigger words for the given LoRAs to ComfyUI nodes.

        JSON body: ``lora_names`` and ``node_ids``. Trigger words from every
        LoRA are joined with ``",, "`` and sent as a ``trigger_word_update``
        event to each node id via ``PromptServer.instance.send_sync``.
        """
        try:
            json_data = await request.json()
            lora_names = json_data.get("lora_names", [])
            node_ids = json_data.get("node_ids", [])

            all_trigger_words = []
            for lora_name in lora_names:
                _, trigger_words = get_lora_info(lora_name)
                all_trigger_words.extend(trigger_words)

            # Format the trigger words; the ",, " separator is kept unchanged
            # (presumably the receiving node splits on it - verify before changing).
            trigger_words_text = ",, ".join(all_trigger_words) if all_trigger_words else ""

            # Send update to all connected trigger word toggle nodes
            for node_id in node_ids:
                PromptServer.instance.send_sync("trigger_word_update", {
                    "id": node_id,
                    "message": trigger_words_text
                })

            return web.json_response({"success": True})
        except Exception as e:
            logger.error(f"Error getting trigger words: {e}")
            return web.json_response({
                "success": False,
                "error": str(e)
            }, status=500)