diff --git a/py/nodes/lora_cycler.py b/py/nodes/lora_cycler.py
index 8fdf5a84..3b9e786d 100644
--- a/py/nodes/lora_cycler.py
+++ b/py/nodes/lora_cycler.py
@@ -8,6 +8,7 @@ and tracks the cycle progress which persists across workflow save/load.
 
 import logging
 import os
+
 from ..utils.utils import get_lora_info
 
 logger = logging.getLogger(__name__)
@@ -54,6 +55,9 @@ class LoraCyclerLM:
         current_index = cycler_config.get("current_index", 1)  # 1-based
         model_strength = float(cycler_config.get("model_strength", 1.0))
         clip_strength = float(cycler_config.get("clip_strength", 1.0))
+        use_same_clip_strength = cycler_config.get("use_same_clip_strength", True)
+        use_preset_strength = cycler_config.get("use_preset_strength", False)
+        preset_strength_scale = float(cycler_config.get("preset_strength_scale", 1.0))
         sort_by = "filename"
 
         # Include "no lora" option
@@ -131,6 +135,39 @@ class LoraCyclerLM:
             else:
                 # Normalize path separators
                 lora_path = lora_path.replace("/", os.sep)
+
+            if use_preset_strength:
+                lora_metadata = await lora_service.get_lora_metadata_by_filename(
+                    current_lora["file_name"]
+                )
+                if lora_metadata:
+                    recommended_strength = (
+                        lora_service.get_recommended_strength_from_lora_data(
+                            lora_metadata
+                        )
+                    )
+                    if recommended_strength is not None:
+                        model_strength = round(
+                            recommended_strength * preset_strength_scale, 2
+                        )
+
+                    if use_same_clip_strength:
+                        clip_strength = model_strength
+                    else:
+                        recommended_clip_strength = (
+                            lora_service.get_recommended_clip_strength_from_lora_data(
+                                lora_metadata
+                            )
+                        )
+                        if recommended_clip_strength is not None:
+                            clip_strength = round(
+                                recommended_clip_strength * preset_strength_scale, 2
+                            )
+                elif use_same_clip_strength:
+                    clip_strength = model_strength
+            elif use_same_clip_strength:
+                clip_strength = model_strength
+
             lora_stack = [(lora_path, model_strength, clip_strength)]
 
         # Calculate next index (wrap to 1 if at end)
diff --git a/py/services/lora_service.py b/py/services/lora_service.py
index 83dd2506..dc3407af 100644
--- a/py/services/lora_service.py
+++ b/py/services/lora_service.py
@@ -1,5 +1,6 @@
-import os
 import logging
+import json
+import os
 from typing import Dict, List, Optional
 
 from .base_model_service import BaseModelService
@@ -278,6 +279,42 @@ class LoraService(BaseModelService):
 
         return None
 
+    @staticmethod
+    def get_recommended_strength_from_lora_data(lora_data: Dict) -> Optional[float]:
+        """Parse usage_tips JSON and extract recommended model strength."""
+        try:
+            usage_tips = lora_data.get("usage_tips", "")
+            if not usage_tips:
+                return None
+            tips_data = json.loads(usage_tips)
+            return tips_data.get("strength")
+        except (json.JSONDecodeError, TypeError, AttributeError):
+            return None
+
+    @staticmethod
+    def get_recommended_clip_strength_from_lora_data(
+        lora_data: Dict,
+    ) -> Optional[float]:
+        """Parse usage_tips JSON and extract recommended clip strength."""
+        try:
+            usage_tips = lora_data.get("usage_tips", "")
+            if not usage_tips:
+                return None
+            tips_data = json.loads(usage_tips)
+            return tips_data.get("clipStrength")
+        except (json.JSONDecodeError, TypeError, AttributeError):
+            return None
+
+    async def get_lora_metadata_by_filename(self, filename: str) -> Optional[Dict]:
+        """Return cached raw metadata for a LoRA matching the given filename."""
+        cache = await self.scanner.get_cached_data(force_refresh=False)
+
+        for lora in cache.raw_data if cache else []:
+            if lora.get("file_name") == filename:
+                return lora
+
+        return None
+
     def find_duplicate_hashes(self) -> Dict:
         """Find LoRAs with duplicate SHA256 hashes"""
         return self.scanner._hash_index.get_duplicate_hashes()
@@ -328,34 +365,10 @@ class LoraService(BaseModelService):
             List of LoRA dicts with randomized strengths
         """
 
         import random
-        import json
-
         # Use a local Random instance to avoid affecting global random state
         # This ensures each execution with a different seed produces different results
         rng = random.Random(seed)
 
-        def get_recommended_strength(lora_data: Dict) -> Optional[float]:
-            """Parse usage_tips JSON and extract recommended strength"""
-            try:
-                usage_tips = lora_data.get("usage_tips", "")
-                if not usage_tips:
-                    return None
-                tips_data = json.loads(usage_tips)
-                return tips_data.get("strength")
-            except (json.JSONDecodeError, TypeError, AttributeError):
-                return None
-
-        def get_recommended_clip_strength(lora_data: Dict) -> Optional[float]:
-            """Parse usage_tips JSON and extract recommended clip strength"""
-            try:
-                usage_tips = lora_data.get("usage_tips", "")
-                if not usage_tips:
-                    return None
-                tips_data = json.loads(usage_tips)
-                return tips_data.get("clipStrength")
-            except (json.JSONDecodeError, TypeError, AttributeError):
-                return None
-
         if locked_loras is None:
             locked_loras = []
@@ -403,7 +416,9 @@ class LoraService(BaseModelService):
         result_loras = []
         for lora in selected:
             if use_recommended_strength:
-                recommended_strength = get_recommended_strength(lora)
+                recommended_strength = self.get_recommended_strength_from_lora_data(
+                    lora
+                )
                 if recommended_strength is not None:
                     scale = rng.uniform(
                         recommended_strength_scale_min, recommended_strength_scale_max
@@ -421,7 +436,9 @@ class LoraService(BaseModelService):
             if use_same_clip_strength:
                 clip_str = model_str
             elif use_recommended_strength:
-                recommended_clip_strength = get_recommended_clip_strength(lora)
+                recommended_clip_strength = (
+                    self.get_recommended_clip_strength_from_lora_data(lora)
+                )
                 if recommended_clip_strength is not None:
                     scale = rng.uniform(
                         recommended_strength_scale_min, recommended_strength_scale_max
diff --git a/tests/nodes/test_lora_cycler.py b/tests/nodes/test_lora_cycler.py
new file mode 100644
index 00000000..0bef21e8
--- /dev/null
+++ b/tests/nodes/test_lora_cycler.py
@@ -0,0 +1,109 @@
+"""Tests for preset strength behavior in LoraCyclerLM."""
+
+from unittest.mock import AsyncMock
+
+import pytest
+
+from py.nodes.lora_cycler import LoraCyclerLM
+from py.services import service_registry
+
+
+@pytest.fixture
+def cycler_node():
+    return LoraCyclerLM()
+
+
+@pytest.fixture
+def cycler_config():
+    return {
+        "current_index": 1,
+        "model_strength": 0.8,
+        "clip_strength": 0.6,
+        "use_same_clip_strength": False,
+        "use_preset_strength": True,
+        "preset_strength_scale": 1.5,
+        "include_no_lora": False,
+    }
+
+
+@pytest.mark.asyncio
+async def test_cycler_uses_scaled_preset_strength_when_available(
+    cycler_node, cycler_config, mock_scanner, monkeypatch
+):
+    monkeypatch.setattr(
+        service_registry.ServiceRegistry,
+        "get_lora_scanner",
+        AsyncMock(return_value=mock_scanner),
+    )
+
+    mock_scanner._cache.raw_data = [
+        {
+            "file_name": "preset_lora.safetensors",
+            "file_path": "/models/loras/preset_lora.safetensors",
+            "folder": "",
+            "usage_tips": '{"strength": 0.7, "clipStrength": 0.5}',
+        }
+    ]
+
+    result = await cycler_node.cycle(cycler_config)
+
+    assert result["result"][0] == [
+        ("/models/loras/preset_lora.safetensors", 1.05, 0.75)
+    ]
+
+
+@pytest.mark.asyncio
+async def test_cycler_falls_back_to_manual_strength_when_preset_missing(
+    cycler_node, cycler_config, mock_scanner, monkeypatch
+):
+    monkeypatch.setattr(
+        service_registry.ServiceRegistry,
+ "get_lora_scanner", + AsyncMock(return_value=mock_scanner), + ) + + mock_scanner._cache.raw_data = [ + { + "file_name": "manual_lora.safetensors", + "file_path": "/models/loras/manual_lora.safetensors", + "folder": "", + "usage_tips": "", + } + ] + + result = await cycler_node.cycle(cycler_config) + + assert result["result"][0] == [ + ("/models/loras/manual_lora.safetensors", 0.8, 0.6) + ] + + +@pytest.mark.asyncio +async def test_cycler_syncs_clip_to_model_when_same_clip_strength_enabled( + cycler_node, cycler_config, mock_scanner, monkeypatch +): + monkeypatch.setattr( + service_registry.ServiceRegistry, + "get_lora_scanner", + AsyncMock(return_value=mock_scanner), + ) + + mock_scanner._cache.raw_data = [ + { + "file_name": "preset_lora.safetensors", + "file_path": "/models/loras/preset_lora.safetensors", + "folder": "", + "usage_tips": '{"strength": 0.7, "clipStrength": 0.3}', + } + ] + + result = await cycler_node.cycle( + { + **cycler_config, + "use_same_clip_strength": True, + } + ) + + assert result["result"][0] == [ + ("/models/loras/preset_lora.safetensors", 1.05, 1.05) + ] diff --git a/vue-widgets/src/components/LoraCyclerWidget.vue b/vue-widgets/src/components/LoraCyclerWidget.vue index 244f6c4b..20af817b 100644 --- a/vue-widgets/src/components/LoraCyclerWidget.vue +++ b/vue-widgets/src/components/LoraCyclerWidget.vue @@ -8,6 +8,8 @@ :model-strength="state.modelStrength.value" :clip-strength="state.clipStrength.value" :use-custom-clip-range="state.useCustomClipRange.value" + :use-preset-strength="state.usePresetStrength.value" + :preset-strength-scale="state.presetStrengthScale.value" :is-clip-strength-disabled="state.isClipStrengthDisabled.value" :is-loading="state.isLoading.value" :repeat-count="state.repeatCount.value" @@ -22,6 +24,8 @@ @update:model-strength="state.modelStrength.value = $event" @update:clip-strength="state.clipStrength.value = $event" @update:use-custom-clip-range="handleUseCustomClipRangeChange" + @update:use-preset-strength="state.usePresetStrength.value = $event" + @update:preset-strength-scale="state.presetStrengthScale.value = $event" @update:repeat-count="handleRepeatCountChange" @update:include-no-lora="handleIncludeNoLoraChange" @toggle-pause="handleTogglePause" diff --git a/vue-widgets/src/components/lora-cycler/LoraCyclerSettingsView.vue b/vue-widgets/src/components/lora-cycler/LoraCyclerSettingsView.vue index f4fb84a3..d867aba7 100644 --- a/vue-widgets/src/components/lora-cycler/LoraCyclerSettingsView.vue +++ b/vue-widgets/src/components/lora-cycler/LoraCyclerSettingsView.vue @@ -131,6 +131,38 @@ + +