import asyncio
import logging
import random
from typing import Dict

from aiohttp import web

from server import PromptServer  # type: ignore

from .base_model_routes import BaseModelRoutes
from .model_route_registrar import ModelRouteRegistrar
from ..services.lora_service import LoraService
from ..services.service_registry import ServiceRegistry
from ..utils.utils import get_lora_info

logger = logging.getLogger(__name__)


class LoraRoutes(BaseModelRoutes):
    """LoRA-specific route controller.

    Registers the LoRA HTTP endpoints on top of the common model routes
    provided by ``BaseModelRoutes`` and delegates the actual work to a
    ``LoraService`` instance created in ``initialize_services``.
    """

    def __init__(self):
        """Initialize LoRA routes and select the LoRA page template."""
        super().__init__()
        self.template_name = "loras.html"

    async def initialize_services(self):
        """Create the ``LoraService`` from ``ServiceRegistry`` dependencies.

        Fetches the LoRA scanner and the model update service, builds the
        service, and registers both on this controller.
        """
        lora_scanner = await ServiceRegistry.get_lora_scanner()
        update_service = await ServiceRegistry.get_model_update_service()
        self.service = LoraService(lora_scanner, update_service=update_service)
        self.set_model_update_service(update_service)

        # Attach service dependencies
        self.attach_service(self.service)

    def setup_routes(self, app: web.Application):
        """Register LoRA routes on ``app`` and schedule service start-up.

        Args:
            app: The aiohttp application to register routes on.
        """
        # Schedule service initialization on app startup. The lambda returns
        # the coroutine produced by initialize_services() so that the startup
        # signal can await it.
        app.on_startup.append(lambda _: self.initialize_services())

        # Setup common routes with 'loras' prefix (includes page route)
        super().setup_routes(app, "loras")

    def setup_specific_routes(self, registrar: ModelRouteRegistrar, prefix: str):
        """Register the routes that only exist for LoRA models.

        Args:
            registrar: Route registrar used to add the prefixed routes.
            prefix: Value passed to the registrar for the ``{prefix}``
                placeholder of each route template.
        """
        # LoRA-specific query routes
        registrar.add_prefixed_route(
            "GET", "/api/lm/{prefix}/letter-counts", prefix, self.get_letter_counts
        )
        registrar.add_prefixed_route(
            "GET",
            "/api/lm/{prefix}/get-trigger-words",
            prefix,
            self.get_lora_trigger_words,
        )
        registrar.add_prefixed_route(
            "GET",
            "/api/lm/{prefix}/usage-tips-by-path",
            prefix,
            self.get_lora_usage_tips_by_path,
        )

        # Randomizer routes
        registrar.add_prefixed_route(
            "POST", "/api/lm/{prefix}/random-sample", prefix, self.get_random_loras
        )

        # Cycler routes
        registrar.add_prefixed_route(
            "POST", "/api/lm/{prefix}/cycler-list", prefix, self.get_cycler_list
        )

        # ComfyUI integration
        registrar.add_prefixed_route(
            "POST", "/api/lm/{prefix}/get_trigger_words", prefix, self.get_trigger_words
        )

    def _parse_specific_params(self, request: web.Request) -> Dict:
        """Parse LoRA-specific query parameters.

        Args:
            request: Incoming request whose query string is inspected.

        Returns:
            A dict that may contain ``first_letter``, ``fuzzy_search`` and
            ``hash_filters`` (with either ``single_hash`` or
            ``multiple_hashes``; ``lora_hash`` wins when both are supplied).
        """
        params = {}
        query = request.query

        # LoRA-specific parameters
        if "first_letter" in query:
            params["first_letter"] = query.get("first_letter")

        # Handle fuzzy search parameter name variation
        if query.get("fuzzy") == "true":
            params["fuzzy_search"] = True

        # Handle additional filter parameters for LoRAs. Hashes are lowercased
        # before being handed on.
        if "lora_hash" in query:
            params["hash_filters"] = {"single_hash": query["lora_hash"].lower()}
        elif "lora_hashes" in query:
            params["hash_filters"] = {
                "multiple_hashes": [h.lower() for h in query["lora_hashes"].split(",")]
            }

        return params

    def _validate_civitai_model_type(self, model_type: str) -> bool:
        """Return True when ``model_type`` (case-insensitive) is a valid LoRA type."""
        from ..utils.constants import VALID_LORA_TYPES

        return model_type.lower() in VALID_LORA_TYPES

    def _get_expected_model_types(self) -> str:
        """Get expected model types string for error messages."""
        return "LORA, LoCon, or DORA"

    @staticmethod
    def _coerce_number(json_data: Dict, key: str, default, cast):
        """Read ``key`` from ``json_data`` and convert it with ``cast``.

        Args:
            json_data: Parsed JSON object from the request body.
            key: Name of the field to read.
            default: Value used when the field is absent.
            cast: Conversion callable, e.g. ``int`` or ``float``.

        Returns:
            The converted value.

        Raises:
            ValueError: If the value cannot be converted. ``TypeError`` and
                ``OverflowError`` are re-raised as ``ValueError`` so callers
                can report every bad input the same way.
        """
        raw_value = json_data.get(key, default)
        try:
            return cast(raw_value)
        except (TypeError, ValueError, OverflowError) as err:
            raise ValueError(f"Invalid value for '{key}': {raw_value!r}") from err

    # LoRA-specific route handlers
    async def get_letter_counts(self, request: web.Request) -> web.Response:
        """Get count of LoRAs for each letter of the alphabet.

        Returns:
            JSON ``{"success": True, "letter_counts": ...}``, or a 500 JSON
            error when the service call fails.
        """
        try:
            letter_counts = await self.service.get_letter_counts()
            return web.json_response({"success": True, "letter_counts": letter_counts})
        except Exception as e:
            # exc_info keeps the traceback, matching the other handlers.
            logger.error(f"Error getting letter counts: {e}", exc_info=True)
            return web.json_response({"success": False, "error": str(e)}, status=500)

    async def get_lora_notes(self, request: web.Request) -> web.Response:
        """Get notes for a specific LoRA file.

        Reads the ``name`` query parameter. Responds 400 (plain text) when it
        is missing, 404 when the service returns ``None``, 500 on errors.
        """
        try:
            lora_name = request.query.get("name")
            if not lora_name:
                return web.Response(text="Lora file name is required", status=400)

            notes = await self.service.get_lora_notes(lora_name)
            if notes is None:
                return web.json_response(
                    {"success": False, "error": "LoRA not found in cache"}, status=404
                )
            return web.json_response({"success": True, "notes": notes})
        except Exception as e:
            logger.error(f"Error getting lora notes: {e}", exc_info=True)
            return web.json_response({"success": False, "error": str(e)}, status=500)

    async def get_lora_trigger_words(self, request: web.Request) -> web.Response:
        """Get trigger words for a specific LoRA file.

        Reads the ``name`` query parameter. Responds 400 (plain text) when it
        is missing and 500 on errors.
        """
        try:
            lora_name = request.query.get("name")
            if not lora_name:
                return web.Response(text="Lora file name is required", status=400)

            trigger_words = await self.service.get_lora_trigger_words(lora_name)
            return web.json_response({"success": True, "trigger_words": trigger_words})
        except Exception as e:
            logger.error(f"Error getting lora trigger words: {e}", exc_info=True)
            return web.json_response({"success": False, "error": str(e)}, status=500)

    async def get_lora_usage_tips_by_path(self, request: web.Request) -> web.Response:
        """Get usage tips for a LoRA by its relative path.

        Reads the ``relative_path`` query parameter. Responds 400 (plain
        text) when it is missing; a falsy service result is reported as an
        empty string.
        """
        try:
            relative_path = request.query.get("relative_path")
            if not relative_path:
                return web.Response(text="Relative path is required", status=400)

            usage_tips = await self.service.get_lora_usage_tips_by_relative_path(
                relative_path
            )
            return web.json_response({"success": True, "usage_tips": usage_tips or ""})
        except Exception as e:
            logger.error(f"Error getting lora usage tips by path: {e}", exc_info=True)
            return web.json_response({"success": False, "error": str(e)}, status=500)

    async def get_lora_preview_url(self, request: web.Request) -> web.Response:
        """Get the static preview URL for a LoRA file.

        Reads the ``name`` query parameter. Responds 400 (plain text) when it
        is missing, 404 when the service returns no URL, 500 on errors.
        """
        try:
            lora_name = request.query.get("name")
            if not lora_name:
                return web.Response(text="Lora file name is required", status=400)

            preview_url = await self.service.get_lora_preview_url(lora_name)
            if not preview_url:
                return web.json_response(
                    {
                        "success": False,
                        "error": "No preview URL found for the specified lora",
                    },
                    status=404,
                )
            return web.json_response({"success": True, "preview_url": preview_url})
        except Exception as e:
            logger.error(f"Error getting lora preview URL: {e}", exc_info=True)
            return web.json_response({"success": False, "error": str(e)}, status=500)

    async def get_lora_civitai_url(self, request: web.Request) -> web.Response:
        """Get the Civitai URL for a LoRA file.

        Reads the ``name`` query parameter. Responds 400 (plain text) when it
        is missing, 404 when the service result has no ``civitai_url``, 500
        on errors. On success the whole service result is merged into the
        response.
        """
        try:
            lora_name = request.query.get("name")
            if not lora_name:
                return web.Response(text="Lora file name is required", status=400)

            result = await self.service.get_lora_civitai_url(lora_name)
            # Safe lookup: an empty result or a missing key means "not found"
            # rather than an internal error.
            if not (result and result.get("civitai_url")):
                return web.json_response(
                    {
                        "success": False,
                        "error": "No Civitai data found for the specified lora",
                    },
                    status=404,
                )
            return web.json_response({"success": True, **result})
        except Exception as e:
            logger.error(f"Error getting lora Civitai URL: {e}", exc_info=True)
            return web.json_response({"success": False, "error": str(e)}, status=500)

    async def get_random_loras(self, request: web.Request) -> web.Response:
        """Get random LoRAs based on filters and strength ranges.

        The JSON body may contain ``count`` (default 5) or a
        ``count_min``/``count_max`` pair (a random count in that inclusive
        range is used when both are given), strength ranges, ``locked_loras``,
        ``pool_config`` and the recommended-strength options.

        Returns:
            JSON ``{"success": True, "loras": [...], "count": n}``; 400 for
            invalid parameters (``ValueError``), 500 for other errors.
        """
        try:
            json_data = await request.json()
            if not isinstance(json_data, dict):
                raise ValueError("Request body must be a JSON object")

            number = self._coerce_number

            # Parse parameters; malformed numbers raise ValueError -> 400.
            model_strength_min = number(json_data, "model_strength_min", 0.0, float)
            model_strength_max = number(json_data, "model_strength_max", 1.0, float)
            use_same_clip_strength = json_data.get("use_same_clip_strength", True)
            clip_strength_min = number(json_data, "clip_strength_min", 0.0, float)
            clip_strength_max = number(json_data, "clip_strength_max", 1.0, float)
            locked_loras = json_data.get("locked_loras", [])
            pool_config = json_data.get("pool_config")
            use_recommended_strength = json_data.get("use_recommended_strength", False)
            recommended_strength_scale_min = number(
                json_data, "recommended_strength_scale_min", 0.5, float
            )
            recommended_strength_scale_max = number(
                json_data, "recommended_strength_scale_max", 1.0, float
            )

            # Determine target count. "count" is only consulted when no
            # complete min/max pair was supplied.
            if (
                json_data.get("count_min") is not None
                and json_data.get("count_max") is not None
            ):
                lowest = number(json_data, "count_min", None, int)
                highest = number(json_data, "count_max", None, int)
                if lowest > highest:
                    raise ValueError("count_min must not be greater than count_max")
                target_count = random.randint(lowest, highest)
            else:
                target_count = number(json_data, "count", 5, int)

            # Validate parameters
            if target_count < 1 or target_count > 100:
                return web.json_response(
                    {"success": False, "error": "Count must be between 1 and 100"},
                    status=400,
                )

            if model_strength_min < -10 or model_strength_max > 10:
                return web.json_response(
                    {
                        "success": False,
                        "error": "Model strength must be between -10 and 10",
                    },
                    status=400,
                )

            # Get random LoRAs from service
            result_loras = await self.service.get_random_loras(
                count=target_count,
                model_strength_min=model_strength_min,
                model_strength_max=model_strength_max,
                use_same_clip_strength=use_same_clip_strength,
                clip_strength_min=clip_strength_min,
                clip_strength_max=clip_strength_max,
                locked_loras=locked_loras,
                pool_config=pool_config,
                use_recommended_strength=use_recommended_strength,
                recommended_strength_scale_min=recommended_strength_scale_min,
                recommended_strength_scale_max=recommended_strength_scale_max,
            )

            return web.json_response(
                {"success": True, "loras": result_loras, "count": len(result_loras)}
            )

        except ValueError as e:
            logger.error(f"Invalid parameter for random LoRAs: {e}")
            return web.json_response({"success": False, "error": str(e)}, status=400)
        except Exception as e:
            logger.error(f"Error getting random LoRAs: {e}", exc_info=True)
            return web.json_response({"success": False, "error": str(e)}, status=500)

    async def get_cycler_list(self, request: web.Request) -> web.Response:
        """Get filtered and sorted LoRA list for cycler widget.

        The JSON body may contain ``pool_config`` and ``sort_by`` (default
        ``"filename"``); both are passed to the service unchanged.
        """
        try:
            json_data = await request.json()

            # Parse parameters
            pool_config = json_data.get("pool_config")
            sort_by = json_data.get("sort_by", "filename")

            # Get cycler list from service
            lora_list = await self.service.get_cycler_list(
                pool_config=pool_config, sort_by=sort_by
            )

            return web.json_response(
                {"success": True, "loras": lora_list, "count": len(lora_list)}
            )

        except Exception as e:
            logger.error(f"Error getting cycler list: {e}", exc_info=True)
            return web.json_response({"success": False, "error": str(e)}, status=500)

    async def get_trigger_words(self, request: web.Request) -> web.Response:
        """Collect trigger words for the given LoRAs and push them to nodes.

        The JSON body may contain ``lora_names`` and ``node_ids``. Each entry
        of ``node_ids`` is either a bare node id or a dict with ``node_id``
        and optional ``graph_id``. One ``trigger_word_update`` message is
        sent per entry through ``PromptServer``.
        """
        try:
            json_data = await request.json()
            lora_names = json_data.get("lora_names", [])
            node_ids = json_data.get("node_ids", [])

            all_trigger_words = []
            for lora_name in lora_names:
                _, trigger_words = get_lora_info(lora_name)
                all_trigger_words.extend(trigger_words)

            # Format the trigger words.
            # NOTE(review): the ",, " separator is kept exactly as-is —
            # presumably the receiving node splits on it; confirm before
            # changing.
            trigger_words_text = (
                ",, ".join(all_trigger_words) if all_trigger_words else ""
            )

            # Send update to all connected trigger word toggle nodes
            for entry in node_ids:
                node_identifier = entry
                graph_identifier = None
                if isinstance(entry, dict):
                    node_identifier = entry.get("node_id")
                    graph_identifier = entry.get("graph_id")

                # Prefer an integer id; fall back to the raw value when the
                # identifier is not numeric.
                try:
                    parsed_node_id = int(node_identifier)
                except (TypeError, ValueError):
                    parsed_node_id = node_identifier

                payload = {"id": parsed_node_id, "message": trigger_words_text}
                if graph_identifier is not None:
                    payload["graph_id"] = str(graph_identifier)

                PromptServer.instance.send_sync("trigger_word_update", payload)

            return web.json_response({"success": True})

        except Exception as e:
            # exc_info keeps the traceback, matching the other handlers.
            logger.error(f"Error getting trigger words: {e}", exc_info=True)
            return web.json_response({"success": False, "error": str(e)}, status=500)