diff --git a/.gitignore b/.gitignore index a5af59a1..708ef925 100644 --- a/.gitignore +++ b/.gitignore @@ -12,6 +12,9 @@ coverage/ .coverage model_cache/ +# agent +.opencode/ + # Vue widgets development cache (but keep build output) vue-widgets/node_modules/ vue-widgets/.vite/ diff --git a/py/routes/lora_routes.py b/py/routes/lora_routes.py index 35e16c8b..ecc30a32 100644 --- a/py/routes/lora_routes.py +++ b/py/routes/lora_routes.py @@ -12,14 +12,15 @@ from ..utils.utils import get_lora_info logger = logging.getLogger(__name__) + class LoraRoutes(BaseModelRoutes): """LoRA-specific route controller""" - + def __init__(self): """Initialize LoRA routes with LoRA service""" super().__init__() self.template_name = "loras.html" - + async def initialize_services(self): """Initialize services from ServiceRegistry""" lora_scanner = await ServiceRegistry.get_lora_scanner() @@ -29,231 +30,225 @@ class LoraRoutes(BaseModelRoutes): # Attach service dependencies self.attach_service(self.service) - + def setup_routes(self, app: web.Application): """Setup LoRA routes""" # Schedule service initialization on app startup app.on_startup.append(lambda _: self.initialize_services()) # Setup common routes with 'loras' prefix (includes page route) - super().setup_routes(app, 'loras') + super().setup_routes(app, "loras") def setup_specific_routes(self, registrar: ModelRouteRegistrar, prefix: str): """Setup LoRA-specific routes""" # LoRA-specific query routes - registrar.add_prefixed_route('GET', '/api/lm/{prefix}/letter-counts', prefix, self.get_letter_counts) - registrar.add_prefixed_route('GET', '/api/lm/{prefix}/get-trigger-words', prefix, self.get_lora_trigger_words) - registrar.add_prefixed_route('GET', '/api/lm/{prefix}/usage-tips-by-path', prefix, self.get_lora_usage_tips_by_path) + registrar.add_prefixed_route( + "GET", "/api/lm/{prefix}/letter-counts", prefix, self.get_letter_counts + ) + registrar.add_prefixed_route( + "GET", + "/api/lm/{prefix}/get-trigger-words", + prefix, + self.get_lora_trigger_words, + ) + registrar.add_prefixed_route( + "GET", + "/api/lm/{prefix}/usage-tips-by-path", + prefix, + self.get_lora_usage_tips_by_path, + ) # Randomizer routes - registrar.add_prefixed_route('POST', '/api/lm/{prefix}/random-sample', prefix, self.get_random_loras) + registrar.add_prefixed_route( + "POST", "/api/lm/{prefix}/random-sample", prefix, self.get_random_loras + ) # ComfyUI integration - registrar.add_prefixed_route('POST', '/api/lm/{prefix}/get_trigger_words', prefix, self.get_trigger_words) - + registrar.add_prefixed_route( + "POST", "/api/lm/{prefix}/get_trigger_words", prefix, self.get_trigger_words + ) + def _parse_specific_params(self, request: web.Request) -> Dict: """Parse LoRA-specific parameters""" params = {} - + # LoRA-specific parameters - if 'first_letter' in request.query: - params['first_letter'] = request.query.get('first_letter') - + if "first_letter" in request.query: + params["first_letter"] = request.query.get("first_letter") + # Handle fuzzy search parameter name variation - if request.query.get('fuzzy') == 'true': - params['fuzzy_search'] = True - + if request.query.get("fuzzy") == "true": + params["fuzzy_search"] = True + # Handle additional filter parameters for LoRAs - if 'lora_hash' in request.query: - if not params.get('hash_filters'): - params['hash_filters'] = {} - params['hash_filters']['single_hash'] = request.query['lora_hash'].lower() - elif 'lora_hashes' in request.query: - if not params.get('hash_filters'): - params['hash_filters'] = {} - params['hash_filters']['multiple_hashes'] = [h.lower() for h in request.query['lora_hashes'].split(',')] - + if "lora_hash" in request.query: + if not params.get("hash_filters"): + params["hash_filters"] = {} + params["hash_filters"]["single_hash"] = request.query["lora_hash"].lower() + elif "lora_hashes" in request.query: + if not params.get("hash_filters"): + params["hash_filters"] = {} + params["hash_filters"]["multiple_hashes"] = [ + h.lower() for h in request.query["lora_hashes"].split(",") + ] + return params - + def _validate_civitai_model_type(self, model_type: str) -> bool: """Validate CivitAI model type for LoRA""" from ..utils.constants import VALID_LORA_TYPES + return model_type.lower() in VALID_LORA_TYPES - + def _get_expected_model_types(self) -> str: """Get expected model types string for error messages""" return "LORA, LoCon, or DORA" - + # LoRA-specific route handlers async def get_letter_counts(self, request: web.Request) -> web.Response: """Get count of LoRAs for each letter of the alphabet""" try: letter_counts = await self.service.get_letter_counts() - return web.json_response({ - 'success': True, - 'letter_counts': letter_counts - }) + return web.json_response({"success": True, "letter_counts": letter_counts}) except Exception as e: logger.error(f"Error getting letter counts: {e}") - return web.json_response({ - 'success': False, - 'error': str(e) - }, status=500) - + return web.json_response({"success": False, "error": str(e)}, status=500) + async def get_lora_notes(self, request: web.Request) -> web.Response: """Get notes for a specific LoRA file""" try: - lora_name = request.query.get('name') + lora_name = request.query.get("name") if not lora_name: - return web.Response(text='Lora file name is required', status=400) - + return web.Response(text="Lora file name is required", status=400) + notes = await self.service.get_lora_notes(lora_name) if notes is not None: - return web.json_response({ - 'success': True, - 'notes': notes - }) + return web.json_response({"success": True, "notes": notes}) else: - return web.json_response({ - 'success': False, - 'error': 'LoRA not found in cache' - }, status=404) - + return web.json_response( + {"success": False, "error": "LoRA not found in cache"}, status=404 + ) + except Exception as e: logger.error(f"Error getting lora notes: {e}", exc_info=True) - return web.json_response({ - 'success': False, - 'error': str(e) - }, status=500) - + return web.json_response({"success": False, "error": str(e)}, status=500) + async def get_lora_trigger_words(self, request: web.Request) -> web.Response: """Get trigger words for a specific LoRA file""" try: - lora_name = request.query.get('name') + lora_name = request.query.get("name") if not lora_name: - return web.Response(text='Lora file name is required', status=400) - + return web.Response(text="Lora file name is required", status=400) + trigger_words = await self.service.get_lora_trigger_words(lora_name) - return web.json_response({ - 'success': True, - 'trigger_words': trigger_words - }) - + return web.json_response({"success": True, "trigger_words": trigger_words}) + except Exception as e: logger.error(f"Error getting lora trigger words: {e}", exc_info=True) - return web.json_response({ - 'success': False, - 'error': str(e) - }, status=500) - + return web.json_response({"success": False, "error": str(e)}, status=500) + async def get_lora_usage_tips_by_path(self, request: web.Request) -> web.Response: """Get usage tips for a LoRA by its relative path""" try: - relative_path = request.query.get('relative_path') + relative_path = request.query.get("relative_path") if not relative_path: - return web.Response(text='Relative path is required', status=400) - - usage_tips = await self.service.get_lora_usage_tips_by_relative_path(relative_path) - return web.json_response({ - 'success': True, - 'usage_tips': usage_tips or '' - }) - + return web.Response(text="Relative path is required", status=400) + + usage_tips = await self.service.get_lora_usage_tips_by_relative_path( + relative_path + ) + return web.json_response({"success": True, "usage_tips": usage_tips or ""}) + except Exception as e: logger.error(f"Error getting lora usage tips by path: {e}", exc_info=True) - return web.json_response({ - 'success': False, - 'error': str(e) - }, status=500) - + return web.json_response({"success": False, "error": str(e)}, status=500) + async def get_lora_preview_url(self, request: web.Request) -> web.Response: """Get the static preview URL for a LoRA file""" try: - lora_name = request.query.get('name') + lora_name = request.query.get("name") if not lora_name: - return web.Response(text='Lora file name is required', status=400) - + return web.Response(text="Lora file name is required", status=400) + preview_url = await self.service.get_lora_preview_url(lora_name) if preview_url: - return web.json_response({ - 'success': True, - 'preview_url': preview_url - }) + return web.json_response({"success": True, "preview_url": preview_url}) else: - return web.json_response({ - 'success': False, - 'error': 'No preview URL found for the specified lora' - }, status=404) - + return web.json_response( + { + "success": False, + "error": "No preview URL found for the specified lora", + }, + status=404, + ) + except Exception as e: logger.error(f"Error getting lora preview URL: {e}", exc_info=True) - return web.json_response({ - 'success': False, - 'error': str(e) - }, status=500) - + return web.json_response({"success": False, "error": str(e)}, status=500) + async def get_lora_civitai_url(self, request: web.Request) -> web.Response: """Get the Civitai URL for a LoRA file""" try: - lora_name = request.query.get('name') + lora_name = request.query.get("name") if not lora_name: - return web.Response(text='Lora file name is required', status=400) - + return web.Response(text="Lora file name is required", status=400) + result = await self.service.get_lora_civitai_url(lora_name) - if result['civitai_url']: - return web.json_response({ - 'success': True, - **result - }) + if result["civitai_url"]: + return web.json_response({"success": True, **result}) else: - return web.json_response({ - 'success': False, - 'error': 'No Civitai data found for the specified lora' - }, status=404) - + return web.json_response( + { + "success": False, + "error": "No Civitai data found for the specified lora", + }, + status=404, + ) + except Exception as e: logger.error(f"Error getting lora Civitai URL: {e}", exc_info=True) - return web.json_response({ - 'success': False, - 'error': str(e) - }, status=500) - + return web.json_response({"success": False, "error": str(e)}, status=500) + async def get_random_loras(self, request: web.Request) -> web.Response: """Get random LoRAs based on filters and strength ranges""" try: json_data = await request.json() # Parse parameters - count = json_data.get('count', 5) - count_min = json_data.get('count_min') - count_max = json_data.get('count_max') - model_strength_min = float(json_data.get('model_strength_min', 0.0)) - model_strength_max = float(json_data.get('model_strength_max', 1.0)) - use_same_clip_strength = json_data.get('use_same_clip_strength', True) - clip_strength_min = float(json_data.get('clip_strength_min', 0.0)) - clip_strength_max = float(json_data.get('clip_strength_max', 1.0)) - locked_loras = json_data.get('locked_loras', []) - pool_config = json_data.get('pool_config') + count = json_data.get("count", 5) + count_min = json_data.get("count_min") + count_max = json_data.get("count_max") + model_strength_min = float(json_data.get("model_strength_min", 0.0)) + model_strength_max = float(json_data.get("model_strength_max", 1.0)) + use_same_clip_strength = json_data.get("use_same_clip_strength", True) + clip_strength_min = float(json_data.get("clip_strength_min", 0.0)) + clip_strength_max = float(json_data.get("clip_strength_max", 1.0)) + locked_loras = json_data.get("locked_loras", []) + pool_config = json_data.get("pool_config") # Determine target count if count_min is not None and count_max is not None: import random + target_count = random.randint(count_min, count_max) else: target_count = count # Validate parameters if target_count < 1 or target_count > 100: - return web.json_response({ - 'success': False, - 'error': 'Count must be between 1 and 100' - }, status=400) + return web.json_response( + {"success": False, "error": "Count must be between 1 and 100"}, + status=400, + ) - if model_strength_min < 0 or model_strength_max > 10: - return web.json_response({ - 'success': False, - 'error': 'Model strength must be between 0 and 10' - }, status=400) + if model_strength_min < -10 or model_strength_max > 10: + return web.json_response( + { + "success": False, + "error": "Model strength must be between -10 and 10", + }, + status=400, + ) # Get random LoRAs from service result_loras = await self.service.get_random_loras( @@ -264,27 +259,19 @@ class LoraRoutes(BaseModelRoutes): clip_strength_min=clip_strength_min, clip_strength_max=clip_strength_max, locked_loras=locked_loras, - pool_config=pool_config + pool_config=pool_config, ) - return web.json_response({ - 'success': True, - 'loras': result_loras, - 'count': len(result_loras) - }) + return web.json_response( + {"success": True, "loras": result_loras, "count": len(result_loras)} + ) except ValueError as e: logger.error(f"Invalid parameter for random LoRAs: {e}") - return web.json_response({ - 'success': False, - 'error': str(e) - }, status=400) + return web.json_response({"success": False, "error": str(e)}, status=400) except Exception as e: logger.error(f"Error getting random LoRAs: {e}", exc_info=True) - return web.json_response({ - 'success': False, - 'error': str(e) - }, status=500) + return web.json_response({"success": False, "error": str(e)}, status=500) async def get_trigger_words(self, request: web.Request) -> web.Response: """Get trigger words for specified LoRA models""" @@ -292,15 +279,17 @@ class LoraRoutes(BaseModelRoutes): json_data = await request.json() lora_names = json_data.get("lora_names", []) node_ids = json_data.get("node_ids", []) - + all_trigger_words = [] for lora_name in lora_names: _, trigger_words = get_lora_info(lora_name) all_trigger_words.extend(trigger_words) - + # Format the trigger words - trigger_words_text = ",, ".join(all_trigger_words) if all_trigger_words else "" - + trigger_words_text = ( + ",, ".join(all_trigger_words) if all_trigger_words else "" + ) + # Send update to all connected trigger word toggle nodes for entry in node_ids: node_identifier = entry @@ -314,21 +303,15 @@ class LoraRoutes(BaseModelRoutes): except (TypeError, ValueError): parsed_node_id = node_identifier - payload = { - "id": parsed_node_id, - "message": trigger_words_text - } + payload = {"id": parsed_node_id, "message": trigger_words_text} if graph_identifier is not None: payload["graph_id"] = str(graph_identifier) PromptServer.instance.send_sync("trigger_word_update", payload) - + return web.json_response({"success": True}) except Exception as e: logger.error(f"Error getting trigger words: {e}") - return web.json_response({ - "success": False, - "error": str(e) - }, status=500) + return web.json_response({"success": False, "error": str(e)}, status=500) diff --git a/tests/routes/test_randomizer_endpoints.py b/tests/routes/test_randomizer_endpoints.py index b8283cf3..4541a849 100644 --- a/tests/routes/test_randomizer_endpoints.py +++ b/tests/routes/test_randomizer_endpoints.py @@ -37,155 +37,167 @@ async def test_get_random_loras_success(routes): """Test successful random LoRA generation""" routes.service.random_loras = [ { - 'name': 'test_lora_1', - 'strength': 0.8, - 'clipStrength': 0.8, - 'active': True, - 'expanded': False, - 'locked': False + "name": "test_lora_1", + "strength": 0.8, + "clipStrength": 0.8, + "active": True, + "expanded": False, + "locked": False, }, { - 'name': 'test_lora_2', - 'strength': 0.6, - 'clipStrength': 0.6, - 'active': True, - 'expanded': False, - 'locked': False - } + "name": "test_lora_2", + "strength": 0.6, + "clipStrength": 0.6, + "active": True, + "expanded": False, + "locked": False, + }, ] - request = DummyRequest(json_data={ - 'count': 5, - 'model_strength_min': 0.5, - 'model_strength_max': 1.0, - 'use_same_clip_strength': True, - 'locked_loras': [] - }) + request = DummyRequest( + json_data={ + "count": 5, + "model_strength_min": 0.5, + "model_strength_max": 1.0, + "use_same_clip_strength": True, + "locked_loras": [], + } + ) response = await routes.get_random_loras(request) payload = json.loads(response.text) assert response.status == 200 - assert payload['success'] is True - assert 'loras' in payload - assert payload['count'] == 2 + assert payload["success"] is True + assert "loras" in payload + assert payload["count"] == 2 async def test_get_random_loras_with_range(routes): """Test random LoRAs with count range""" routes.service.random_loras = [ { - 'name': 'test_lora_1', - 'strength': 0.8, - 'clipStrength': 0.8, - 'active': True, - 'expanded': False, - 'locked': False + "name": "test_lora_1", + "strength": 0.8, + "clipStrength": 0.8, + "active": True, + "expanded": False, + "locked": False, } ] - request = DummyRequest(json_data={ - 'count_min': 3, - 'count_max': 7, - 'model_strength_min': 0.0, - 'model_strength_max': 1.0, - 'use_same_clip_strength': True - }) + request = DummyRequest( + json_data={ + "count_min": 3, + "count_max": 7, + "model_strength_min": 0.0, + "model_strength_max": 1.0, + "use_same_clip_strength": True, + } + ) response = await routes.get_random_loras(request) payload = json.loads(response.text) assert response.status == 200 - assert payload['success'] is True + assert payload["success"] is True async def test_get_random_loras_invalid_count(routes): """Test invalid count parameter""" - request = DummyRequest(json_data={ - 'count': 150, # Over limit - 'model_strength_min': 0.0, - 'model_strength_max': 1.0 - }) + request = DummyRequest( + json_data={ + "count": 150, # Over limit + "model_strength_min": 0.0, + "model_strength_max": 1.0, + } + ) response = await routes.get_random_loras(request) payload = json.loads(response.text) assert response.status == 400 - assert payload['success'] is False - assert 'Count must be between 1 and 100' in payload['error'] + assert payload["success"] is False + assert "Count must be between 1 and 100" in payload["error"] async def test_get_random_loras_invalid_strength(routes): """Test invalid strength range""" - request = DummyRequest(json_data={ - 'count': 5, - 'model_strength_min': -0.5, # Invalid - 'model_strength_max': 1.0 - }) + request = DummyRequest( + json_data={ + "count": 5, + "model_strength_min": -11, # Invalid (below -10) + "model_strength_max": 1.0, + } + ) response = await routes.get_random_loras(request) payload = json.loads(response.text) assert response.status == 400 - assert payload['success'] is False + assert payload["success"] is False + assert "Model strength must be between -10 and 10" in payload["error"] async def test_get_random_loras_with_locked(routes): """Test random LoRAs with locked items""" routes.service.random_loras = [ { - 'name': 'new_lora', - 'strength': 0.7, - 'clipStrength': 0.7, - 'active': True, - 'expanded': False, - 'locked': False + "name": "new_lora", + "strength": 0.7, + "clipStrength": 0.7, + "active": True, + "expanded": False, + "locked": False, }, { - 'name': 'locked_lora', - 'strength': 0.9, - 'clipStrength': 0.9, - 'active': True, - 'expanded': False, - 'locked': True - } + "name": "locked_lora", + "strength": 0.9, + "clipStrength": 0.9, + "active": True, + "expanded": False, + "locked": True, + }, ] - request = DummyRequest(json_data={ - 'count': 5, - 'model_strength_min': 0.5, - 'model_strength_max': 1.0, - 'use_same_clip_strength': True, - 'locked_loras': [ - { - 'name': 'locked_lora', - 'strength': 0.9, - 'clipStrength': 0.9, - 'active': True, - 'expanded': False, - 'locked': True - } - ] - }) + request = DummyRequest( + json_data={ + "count": 5, + "model_strength_min": 0.5, + "model_strength_max": 1.0, + "use_same_clip_strength": True, + "locked_loras": [ + { + "name": "locked_lora", + "strength": 0.9, + "clipStrength": 0.9, + "active": True, + "expanded": False, + "locked": True, + } + ], + } + ) response = await routes.get_random_loras(request) payload = json.loads(response.text) assert response.status == 200 - assert payload['success'] is True + assert payload["success"] is True async def test_get_random_loras_error(routes, monkeypatch): """Test error handling""" + async def failing(*_args, **_kwargs): raise RuntimeError("Service error") routes.service.get_random_loras = failing - request = DummyRequest(json_data={'count': 5}) + request = DummyRequest(json_data={"count": 5}) response = await routes.get_random_loras(request) payload = json.loads(response.text) assert response.status == 500 - assert payload['success'] is False - assert 'error' in payload + assert payload["success"] is False + assert "error" in payload diff --git a/vue-widgets/src/components/LoraPoolWidget.vue b/vue-widgets/src/components/LoraPoolWidget.vue index e531861b..12321ee9 100644 --- a/vue-widgets/src/components/LoraPoolWidget.vue +++ b/vue-widgets/src/components/LoraPoolWidget.vue @@ -104,7 +104,7 @@ onMounted(async () => { // Handle external value updates (e.g., loading workflow, paste) props.widget.onSetValue = (v) => { - state.restoreFromConfig(v) + state.restoreFromConfig(v as LoraPoolConfig | LegacyLoraPoolConfig) state.refreshPreview() } diff --git a/vue-widgets/src/components/lora-randomizer/LoraRandomizerSettingsView.vue b/vue-widgets/src/components/lora-randomizer/LoraRandomizerSettingsView.vue index b9afe89b..fd6138bd 100644 --- a/vue-widgets/src/components/lora-randomizer/LoraRandomizerSettingsView.vue +++ b/vue-widgets/src/components/lora-randomizer/LoraRandomizerSettingsView.vue @@ -7,8 +7,8 @@
-
-
-
-
- - -
-
- - -
+
+
- -
-