feat(lora-randomizer): refactor randomization logic and add input preprocessing

- Add `_preprocess_loras_input` method to handle different widget input formats
- Move core randomization logic to `LoraService` for better separation of concerns
- Update `_select_loras` method to use new service-based approach
- Add comprehensive test fixtures for license filtering scenarios
- Include debug print statement for pool config inspection during development

This refactor improves code organization by centralizing business logic in the service layer while maintaining backward compatibility with existing widget inputs.
This commit is contained in:
Will Miao
2026-01-13 15:47:59 +08:00
parent 1ebd2c93a0
commit 514846cd4a
8 changed files with 457 additions and 229 deletions

View File

@@ -39,6 +39,22 @@ class LoraRandomizerNode:
FUNCTION = "randomize"
OUTPUT_NODE = False
def _preprocess_loras_input(self, loras):
"""
Preprocess loras input to handle different widget formats.
Args:
loras: Input from widget, either:
- List of LoRA dicts (expected format)
- Dict with '__value__' key containing the list
Returns:
List of LoRA dicts
"""
if isinstance(loras, dict) and "__value__" in loras:
return loras["__value__"]
return loras
async def randomize(self, randomizer_config, loras, pool_config=None):
"""
Randomize LoRAs based on configuration and pool filters.
@@ -53,6 +69,8 @@ class LoraRandomizerNode:
"""
from ..services.service_registry import ServiceRegistry
loras = self._preprocess_loras_input(loras)
roll_mode = randomizer_config.get("roll_mode", "always")
logger.debug(f"[LoraRandomizerNode] roll_mode: {roll_mode}")
@@ -64,6 +82,8 @@ class LoraRandomizerNode:
scanner, randomizer_config, loras, pool_config
)
print("pool config", pool_config)
execution_stack = self._build_execution_stack_from_input(loras)
return {
@@ -120,6 +140,8 @@ class LoraRandomizerNode:
Returns:
List of LoRA dicts for UI display
"""
from ..services.lora_service import LoraService
# Parse randomizer settings
count_mode = randomizer_config.get("count_mode", "range")
count_fixed = randomizer_config.get("count_fixed", 5)
@@ -131,183 +153,23 @@ class LoraRandomizerNode:
clip_strength_min = randomizer_config.get("clip_strength_min", 0.0)
clip_strength_max = randomizer_config.get("clip_strength_max", 1.0)
# Determine target count
if count_mode == "fixed":
target_count = count_fixed
else:
target_count = random.randint(count_min, count_max)
# Extract locked LoRAs from input
locked_loras = [lora for lora in input_loras if lora.get("locked", False)]
locked_count = len(locked_loras)
# Get available loras from cache
try:
cache_data = await scanner.get_cached_data(force_refresh=False)
if cache_data and hasattr(cache_data, "raw_data"):
available_loras = cache_data.raw_data
else:
available_loras = []
except Exception as e:
logger.warning(f"[LoraRandomizerNode] Failed to get lora cache: {e}")
available_loras = []
# Apply pool filters if provided
if pool_config:
available_loras = await self._apply_pool_filters(
available_loras, pool_config, scanner
)
# Calculate how many new LoRAs to select
slots_needed = target_count - locked_count
if slots_needed < 0:
slots_needed = 0
# Too many locked, trim to target
locked_loras = locked_loras[:target_count]
locked_count = len(locked_loras)
# Filter out locked LoRAs from available pool
locked_names = {lora["name"] for lora in locked_loras}
available_pool = [
l for l in available_loras if l["file_name"] not in locked_names
]
# Ensure we don't try to select more than available
if slots_needed > len(available_pool):
slots_needed = len(available_pool)
# Random sample
selected = []
if slots_needed > 0:
selected = random.sample(available_pool, slots_needed)
# Generate random strengths for selected LoRAs
result_loras = []
for lora in selected:
model_str = round(random.uniform(model_strength_min, model_strength_max), 2)
if use_same_clip_strength:
clip_str = model_str
else:
clip_str = round(
random.uniform(clip_strength_min, clip_strength_max), 2
)
result_loras.append(
{
"name": lora["file_name"],
"strength": model_str,
"clipStrength": clip_str,
"active": True,
"expanded": abs(model_str - clip_str) > 0.001,
"locked": False,
}
)
# Merge with locked LoRAs
result_loras.extend(locked_loras)
# Use LoraService to generate random LoRAs
lora_service = LoraService(scanner)
result_loras = await lora_service.get_random_loras(
count=count_fixed,
model_strength_min=model_strength_min,
model_strength_max=model_strength_max,
use_same_clip_strength=use_same_clip_strength,
clip_strength_min=clip_strength_min,
clip_strength_max=clip_strength_max,
locked_loras=locked_loras,
pool_config=pool_config,
count_mode=count_mode,
count_min=count_min,
count_max=count_max,
)
return result_loras
async def _apply_pool_filters(self, available_loras, pool_config, scanner):
"""
Apply pool_config filters to available LoRAs.
Args:
available_loras: List of all LoRA dicts
pool_config: Dict with filter settings from LoRA Pool node
scanner: Scanner instance for accessing filter utilities
Returns:
Filtered list of LoRA dicts
"""
from ..services.lora_service import LoraService
from ..services.model_query import FilterCriteria
# Create lora service instance for filtering
lora_service = LoraService(scanner)
# Extract filter parameters from pool_config
selected_base_models = pool_config.get("baseModels", [])
tags_dict = pool_config.get("tags", {})
include_tags = tags_dict.get("include", [])
exclude_tags = tags_dict.get("exclude", [])
folders_dict = pool_config.get("folders", {})
include_folders = folders_dict.get("include", [])
exclude_folders = folders_dict.get("exclude", [])
license_dict = pool_config.get("license", {})
no_credit_required = license_dict.get("noCreditRequired", False)
allow_selling = license_dict.get("allowSelling", False)
# Build tag filters dict
tag_filters = {}
for tag in include_tags:
tag_filters[tag] = "include"
for tag in exclude_tags:
tag_filters[tag] = "exclude"
# Build folder filter
# LoRA Pool uses include/exclude folders, we need to apply this logic
# For now, we'll filter based on folder path matching
if include_folders or exclude_folders:
filtered = []
for lora in available_loras:
folder = lora.get("folder", "")
# Check exclude folders first
excluded = False
for exclude_folder in exclude_folders:
if folder.startswith(exclude_folder):
excluded = True
break
if excluded:
continue
# Check include folders
if include_folders:
included = False
for include_folder in include_folders:
if folder.startswith(include_folder):
included = True
break
if not included:
continue
filtered.append(lora)
available_loras = filtered
# Apply base model filter
if selected_base_models:
available_loras = [
lora
for lora in available_loras
if lora.get("base_model") in selected_base_models
]
# Apply tag filters
if tag_filters:
criteria = FilterCriteria(tags=tag_filters)
available_loras = lora_service.filter_set.apply(available_loras, criteria)
# Apply license filters
# Note: no_credit_required=True means filter out models where credit is NOT required
# (i.e., keep only models where credit IS required)
if no_credit_required:
available_loras = [
lora
for lora in available_loras
if not (lora.get("license_flags", 127) & (1 << 0))
]
# allow_selling=True means keep only models where selling generated content is allowed
if allow_selling:
available_loras = [
lora
for lora in available_loras
if bool(lora.get("license_flags", 127) & (1 << 1))
]
return available_loras

View File

@@ -225,12 +225,15 @@ class LoraService(BaseModelService):
clip_strength_max: float = 1.0,
locked_loras: Optional[List[Dict]] = None,
pool_config: Optional[Dict] = None,
count_mode: str = "fixed",
count_min: int = 3,
count_max: int = 7,
) -> List[Dict]:
"""
Get random LoRAs with specified strength ranges.
Args:
count: Number of LoRAs to select
count: Number of LoRAs to select (if count_mode='fixed')
model_strength_min: Minimum model strength
model_strength_max: Maximum model strength
use_same_clip_strength: Whether to use same strength for clip
@@ -238,6 +241,9 @@ class LoraService(BaseModelService):
clip_strength_max: Maximum clip strength
locked_loras: List of locked LoRA dicts to preserve
pool_config: Optional pool config for filtering
count_mode: How to determine count ('fixed' or 'range')
count_min: Minimum count for range mode
count_max: Maximum count for range mode
Returns:
List of LoRA dicts with randomized strengths
@@ -247,6 +253,12 @@ class LoraService(BaseModelService):
if locked_loras is None:
locked_loras = []
# Determine target count based on count_mode
if count_mode == "fixed":
target_count = count
else:
target_count = random.randint(count_min, count_max)
# Get available loras from cache
cache = await self.scanner.get_cached_data(force_refresh=False)
available_loras = cache.raw_data if cache else []
@@ -259,12 +271,12 @@ class LoraService(BaseModelService):
# Calculate slots needed (total - locked)
locked_count = len(locked_loras)
slots_needed = count - locked_count
slots_needed = target_count - locked_count
if slots_needed < 0:
slots_needed = 0
# Too many locked, trim to target
locked_loras = locked_loras[:count]
locked_loras = locked_loras[:target_count]
# Filter out locked LoRAs from available pool
locked_names = {lora["name"] for lora in locked_loras}
@@ -324,14 +336,19 @@ class LoraService(BaseModelService):
"""
from .model_query import FilterCriteria
# Extract filter parameters from pool_config
selected_base_models = pool_config.get("selected_base_models", [])
include_tags = pool_config.get("include_tags", [])
exclude_tags = pool_config.get("exclude_tags", [])
include_folders = pool_config.get("include_folders", [])
exclude_folders = pool_config.get("exclude_folders", [])
no_credit_required = pool_config.get("no_credit_required", False)
allow_selling = pool_config.get("allow_selling", False)
filter_section = pool_config
# Extract filter parameters
selected_base_models = filter_section.get("baseModels", [])
tags_dict = filter_section.get("tags", {})
include_tags = tags_dict.get("include", [])
exclude_tags = tags_dict.get("exclude", [])
folders_dict = filter_section.get("folders", {})
include_folders = folders_dict.get("include", [])
exclude_folders = folders_dict.get("exclude", [])
license_dict = filter_section.get("license", {})
no_credit_required = license_dict.get("noCreditRequired", False)
allow_selling = license_dict.get("allowSelling", False)
# Build tag filters dict
tag_filters = {}
@@ -384,13 +401,13 @@ class LoraService(BaseModelService):
available_loras = self.filter_set.apply(available_loras, criteria)
# Apply license filters
# Note: no_credit_required=True means filter out models where credit is NOT required
# (i.e., keep only models where credit IS required)
# no_credit_required=True means keep only models where credit is NOT required
# (i.e., allowNoCredit=True, which is bit 0 = 1 in license_flags)
if no_credit_required:
available_loras = [
lora
for lora in available_loras
if not (lora.get("license_flags", 127) & (1 << 0))
if bool(lora.get("license_flags", 127) & (1 << 0))
]
# allow_selling=True means keep only models where selling generated content is allowed