diff --git a/py/nodes/lora_randomizer.py b/py/nodes/lora_randomizer.py index abb0793d..32f3a2c3 100644 --- a/py/nodes/lora_randomizer.py +++ b/py/nodes/lora_randomizer.py @@ -152,6 +152,15 @@ class LoraRandomizerNode: use_same_clip_strength = randomizer_config.get("use_same_clip_strength", True) clip_strength_min = randomizer_config.get("clip_strength_min", 0.0) clip_strength_max = randomizer_config.get("clip_strength_max", 1.0) + use_recommended_strength = randomizer_config.get( + "use_recommended_strength", False + ) + recommended_strength_scale_min = randomizer_config.get( + "recommended_strength_scale_min", 0.5 + ) + recommended_strength_scale_max = randomizer_config.get( + "recommended_strength_scale_max", 1.0 + ) # Extract locked LoRAs from input locked_loras = [lora for lora in input_loras if lora.get("locked", False)] @@ -170,6 +179,9 @@ class LoraRandomizerNode: count_mode=count_mode, count_min=count_min, count_max=count_max, + use_recommended_strength=use_recommended_strength, + recommended_strength_scale_min=recommended_strength_scale_min, + recommended_strength_scale_max=recommended_strength_scale_max, ) return result_loras diff --git a/py/routes/lora_routes.py b/py/routes/lora_routes.py index ecc30a32..cbeab57a 100644 --- a/py/routes/lora_routes.py +++ b/py/routes/lora_routes.py @@ -225,6 +225,13 @@ class LoraRoutes(BaseModelRoutes): clip_strength_max = float(json_data.get("clip_strength_max", 1.0)) locked_loras = json_data.get("locked_loras", []) pool_config = json_data.get("pool_config") + use_recommended_strength = json_data.get("use_recommended_strength", False) + recommended_strength_scale_min = float( + json_data.get("recommended_strength_scale_min", 0.5) + ) + recommended_strength_scale_max = float( + json_data.get("recommended_strength_scale_max", 1.0) + ) # Determine target count if count_min is not None and count_max is not None: @@ -260,6 +267,9 @@ class LoraRoutes(BaseModelRoutes): clip_strength_max=clip_strength_max, locked_loras=locked_loras, pool_config=pool_config, + use_recommended_strength=use_recommended_strength, + recommended_strength_scale_min=recommended_strength_scale_min, + recommended_strength_scale_max=recommended_strength_scale_max, ) return web.json_response( diff --git a/py/services/lora_service.py b/py/services/lora_service.py index 3c900423..f55b1fdf 100644 --- a/py/services/lora_service.py +++ b/py/services/lora_service.py @@ -228,6 +228,9 @@ class LoraService(BaseModelService): count_mode: str = "fixed", count_min: int = 3, count_max: int = 7, + use_recommended_strength: bool = False, + recommended_strength_scale_min: float = 0.5, + recommended_strength_scale_max: float = 1.0, ) -> List[Dict]: """ Get random LoRAs with specified strength ranges. @@ -244,11 +247,37 @@ class LoraService(BaseModelService): count_mode: How to determine count ('fixed' or 'range') count_min: Minimum count for range mode count_max: Maximum count for range mode + use_recommended_strength: Whether to use recommended strength from usage_tips + recommended_strength_scale_min: Minimum scale factor for recommended strength + recommended_strength_scale_max: Maximum scale factor for recommended strength Returns: List of LoRA dicts with randomized strengths """ import random + import json + + def get_recommended_strength(lora_data: Dict) -> Optional[float]: + """Parse usage_tips JSON and extract recommended strength""" + try: + usage_tips = lora_data.get("usage_tips", "") + if not usage_tips: + return None + tips_data = json.loads(usage_tips) + return tips_data.get("strength") + except (json.JSONDecodeError, TypeError, AttributeError): + return None + + def get_recommended_clip_strength(lora_data: Dict) -> Optional[float]: + """Parse usage_tips JSON and extract recommended clip strength""" + try: + usage_tips = lora_data.get("usage_tips", "") + if not usage_tips: + return None + tips_data = json.loads(usage_tips) + return tips_data.get("clipStrength") + except (json.JSONDecodeError, TypeError, AttributeError): + return None if locked_loras is None: locked_loras = [] @@ -296,10 +325,35 @@ class LoraService(BaseModelService): # Generate random strengths for selected LoRAs result_loras = [] for lora in selected: - model_str = round(random.uniform(model_strength_min, model_strength_max), 2) + if use_recommended_strength: + recommended_strength = get_recommended_strength(lora) + if recommended_strength is not None: + scale = random.uniform( + recommended_strength_scale_min, recommended_strength_scale_max + ) + model_str = round(recommended_strength * scale, 2) + else: + model_str = round( + random.uniform(model_strength_min, model_strength_max), 2 + ) + else: + model_str = round( + random.uniform(model_strength_min, model_strength_max), 2 + ) if use_same_clip_strength: clip_str = model_str + elif use_recommended_strength: + recommended_clip_strength = get_recommended_clip_strength(lora) + if recommended_clip_strength is not None: + scale = random.uniform( + recommended_strength_scale_min, recommended_strength_scale_max + ) + clip_str = round(recommended_clip_strength * scale, 2) + else: + clip_str = round( + random.uniform(clip_strength_min, clip_strength_max), 2 + ) else: clip_str = round( random.uniform(clip_strength_min, clip_strength_max), 2 diff --git a/tests/nodes/test_lora_randomizer.py b/tests/nodes/test_lora_randomizer.py index d22c58a8..fe8ff3e7 100644 --- a/tests/nodes/test_lora_randomizer.py +++ b/tests/nodes/test_lora_randomizer.py @@ -51,6 +51,9 @@ def randomizer_config_fixed(): "clip_strength_min": 0.5, "clip_strength_max": 1.0, "roll_mode": "fixed", + "use_recommended_strength": False, + "recommended_strength_scale_min": 0.5, + "recommended_strength_scale_max": 1.0, } @@ -68,6 +71,9 @@ def randomizer_config_always(): "clip_strength_min": 0.5, "clip_strength_max": 1.0, "roll_mode": "always", + "use_recommended_strength": False, + "recommended_strength_scale_min": 0.5, + "recommended_strength_scale_max": 1.0, } @@ -319,3 +325,81 @@ async def test_execution_stack_always_from_input_loras_not_ui_loras( assert execution_stack[0][2] == 0.8 assert execution_stack[1][1] == 0.6 assert execution_stack[1][2] == 0.6 + + +@pytest.fixture +def randomizer_config_with_recommended_strength(): + """Randomizer config with recommended strength enabled""" + return { + "count_mode": "fixed", + "count_fixed": 3, + "count_min": 2, + "count_max": 5, + "model_strength_min": 0.5, + "model_strength_max": 1.0, + "use_same_clip_strength": True, + "clip_strength_min": 0.5, + "clip_strength_max": 1.0, + "roll_mode": "always", + "use_recommended_strength": True, + "recommended_strength_scale_min": 0.6, + "recommended_strength_scale_max": 0.8, + } + + +@pytest.mark.asyncio +async def test_recommended_strength_config_passed_to_service( + randomizer_node, + sample_loras, + randomizer_config_with_recommended_strength, + mock_scanner, + monkeypatch, +): + """Test that recommended strength config is passed to service when enabled""" + from py.services.lora_service import LoraService + from unittest.mock import AsyncMock, patch + + # Mock LoraService.get_random_loras to verify parameters + mock_get_random_loras = AsyncMock( + return_value=[ + { + "name": "new_lora.safetensors", + "strength": 0.7, + "clipStrength": 0.7, + "active": True, + "expanded": False, + "locked": False, + } + ] + ) + + with patch.object(LoraService, "__init__", return_value=None): + with patch.object(LoraService, "get_random_loras", mock_get_random_loras): + monkeypatch.setattr( + service_registry.ServiceRegistry, + "get_lora_scanner", + AsyncMock(return_value=mock_scanner), + ) + + mock_scanner._cache.raw_data = [ + { + "file_name": "new_lora.safetensors", + "file_path": "/path/to/new_lora.safetensors", + "folder": "", + } + ] + + result = await randomizer_node.randomize( + randomizer_config_with_recommended_strength, + sample_loras, + pool_config=None, + ) + + # Verify service was called + assert mock_get_random_loras.called + + # Verify recommended strength parameters were passed + call_kwargs = mock_get_random_loras.call_args[1] + assert call_kwargs["use_recommended_strength"] is True + assert call_kwargs["recommended_strength_scale_min"] == 0.6 + assert call_kwargs["recommended_strength_scale_max"] == 0.8 diff --git a/tests/routes/test_randomizer_endpoints.py b/tests/routes/test_randomizer_endpoints.py index 4541a849..48374698 100644 --- a/tests/routes/test_randomizer_endpoints.py +++ b/tests/routes/test_randomizer_endpoints.py @@ -21,8 +21,10 @@ class StubLoraService: def __init__(self): self.random_loras = [] + self.last_get_random_loras_kwargs = {} async def get_random_loras(self, **kwargs): + self.last_get_random_loras_kwargs = kwargs return self.random_loras @@ -201,3 +203,56 @@ async def test_get_random_loras_error(routes, monkeypatch): assert response.status == 500 assert payload["success"] is False assert "error" in payload + + +async def test_get_random_loras_with_recommended_strength_enabled(routes): + """Test random LoRAs with recommended strength feature enabled""" + request = DummyRequest( + json_data={ + "count": 5, + "model_strength_min": 0.5, + "model_strength_max": 1.0, + "use_same_clip_strength": True, + "use_recommended_strength": True, + "recommended_strength_scale_min": 0.6, + "recommended_strength_scale_max": 0.8, + "locked_loras": [], + } + ) + + response = await routes.get_random_loras(request) + payload = json.loads(response.text) + + assert response.status == 200 + assert payload["success"] is True + + # Verify parameters were passed to service + kwargs = routes.service.last_get_random_loras_kwargs + assert kwargs["use_recommended_strength"] is True + assert kwargs["recommended_strength_scale_min"] == 0.6 + assert kwargs["recommended_strength_scale_max"] == 0.8 + + +async def test_get_random_loras_with_recommended_strength_disabled(routes): + """Test random LoRAs with recommended strength feature disabled (default)""" + request = DummyRequest( + json_data={ + "count": 5, + "model_strength_min": 0.5, + "model_strength_max": 1.0, + "use_same_clip_strength": True, + "locked_loras": [], + } + ) + + response = await routes.get_random_loras(request) + payload = json.loads(response.text) + + assert response.status == 200 + assert payload["success"] is True + + # Verify default parameters were passed to service + kwargs = routes.service.last_get_random_loras_kwargs + assert kwargs["use_recommended_strength"] is False + assert kwargs["recommended_strength_scale_min"] == 0.5 + assert kwargs["recommended_strength_scale_max"] == 1.0 diff --git a/vue-widgets/src/components/LoraRandomizerWidget.vue b/vue-widgets/src/components/LoraRandomizerWidget.vue index 5fb5beaf..6e83fbbc 100644 --- a/vue-widgets/src/components/LoraRandomizerWidget.vue +++ b/vue-widgets/src/components/LoraRandomizerWidget.vue @@ -16,6 +16,9 @@ :last-used="state.lastUsed.value" :current-loras="currentLoras" :can-reuse-last="canReuseLast" + :use-recommended-strength="state.useRecommendedStrength.value" + :recommended-strength-scale-min="state.recommendedStrengthScaleMin.value" + :recommended-strength-scale-max="state.recommendedStrengthScaleMax.value" @update:count-mode="state.countMode.value = $event" @update:count-fixed="state.countFixed.value = $event" @update:count-min="state.countMin.value = $event" @@ -26,6 +29,9 @@ @update:clip-strength-min="state.clipStrengthMin.value = $event" @update:clip-strength-max="state.clipStrengthMax.value = $event" @update:roll-mode="state.rollMode.value = $event" + @update:use-recommended-strength="state.useRecommendedStrength.value = $event" + @update:recommended-strength-scale-min="state.recommendedStrengthScaleMin.value = $event" + @update:recommended-strength-scale-max="state.recommendedStrengthScaleMax.value = $event" @generate-fixed="handleGenerateFixed" @always-randomize="handleAlwaysRandomize" @reuse-last="handleReuseLast" diff --git a/vue-widgets/src/components/lora-randomizer/LoraRandomizerSettingsView.vue b/vue-widgets/src/components/lora-randomizer/LoraRandomizerSettingsView.vue index 1a316b1b..433a584b 100644 --- a/vue-widgets/src/components/lora-randomizer/LoraRandomizerSettingsView.vue +++ b/vue-widgets/src/components/lora-randomizer/LoraRandomizerSettingsView.vue @@ -73,6 +73,40 @@ + +
{{ subtitle }}
\n{{ subtitle }}
\n