mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-24 06:32:12 -03:00
feat(lora-randomizer): refactor randomization logic and add input preprocessing
- Add `_preprocess_loras_input` method to handle different widget input formats - Move core randomization logic to `LoraService` for better separation of concerns - Update `_select_loras` method to use new service-based approach - Add comprehensive test fixtures for license filtering scenarios - Include debug print statement for pool config inspection during development This refactor improves code organization by centralizing business logic in the service layer while maintaining backward compatibility with existing widget inputs.
This commit is contained in:
@@ -225,12 +225,15 @@ class LoraService(BaseModelService):
|
||||
clip_strength_max: float = 1.0,
|
||||
locked_loras: Optional[List[Dict]] = None,
|
||||
pool_config: Optional[Dict] = None,
|
||||
count_mode: str = "fixed",
|
||||
count_min: int = 3,
|
||||
count_max: int = 7,
|
||||
) -> List[Dict]:
|
||||
"""
|
||||
Get random LoRAs with specified strength ranges.
|
||||
|
||||
Args:
|
||||
count: Number of LoRAs to select
|
||||
count: Number of LoRAs to select (if count_mode='fixed')
|
||||
model_strength_min: Minimum model strength
|
||||
model_strength_max: Maximum model strength
|
||||
use_same_clip_strength: Whether to use same strength for clip
|
||||
@@ -238,6 +241,9 @@ class LoraService(BaseModelService):
|
||||
clip_strength_max: Maximum clip strength
|
||||
locked_loras: List of locked LoRA dicts to preserve
|
||||
pool_config: Optional pool config for filtering
|
||||
count_mode: How to determine count ('fixed' or 'range')
|
||||
count_min: Minimum count for range mode
|
||||
count_max: Maximum count for range mode
|
||||
|
||||
Returns:
|
||||
List of LoRA dicts with randomized strengths
|
||||
@@ -247,6 +253,12 @@ class LoraService(BaseModelService):
|
||||
if locked_loras is None:
|
||||
locked_loras = []
|
||||
|
||||
# Determine target count based on count_mode
|
||||
if count_mode == "fixed":
|
||||
target_count = count
|
||||
else:
|
||||
target_count = random.randint(count_min, count_max)
|
||||
|
||||
# Get available loras from cache
|
||||
cache = await self.scanner.get_cached_data(force_refresh=False)
|
||||
available_loras = cache.raw_data if cache else []
|
||||
@@ -259,12 +271,12 @@ class LoraService(BaseModelService):
|
||||
|
||||
# Calculate slots needed (total - locked)
|
||||
locked_count = len(locked_loras)
|
||||
slots_needed = count - locked_count
|
||||
slots_needed = target_count - locked_count
|
||||
|
||||
if slots_needed < 0:
|
||||
slots_needed = 0
|
||||
# Too many locked, trim to target
|
||||
locked_loras = locked_loras[:count]
|
||||
locked_loras = locked_loras[:target_count]
|
||||
|
||||
# Filter out locked LoRAs from available pool
|
||||
locked_names = {lora["name"] for lora in locked_loras}
|
||||
@@ -324,14 +336,19 @@ class LoraService(BaseModelService):
|
||||
"""
|
||||
from .model_query import FilterCriteria
|
||||
|
||||
# Extract filter parameters from pool_config
|
||||
selected_base_models = pool_config.get("selected_base_models", [])
|
||||
include_tags = pool_config.get("include_tags", [])
|
||||
exclude_tags = pool_config.get("exclude_tags", [])
|
||||
include_folders = pool_config.get("include_folders", [])
|
||||
exclude_folders = pool_config.get("exclude_folders", [])
|
||||
no_credit_required = pool_config.get("no_credit_required", False)
|
||||
allow_selling = pool_config.get("allow_selling", False)
|
||||
filter_section = pool_config
|
||||
|
||||
# Extract filter parameters
|
||||
selected_base_models = filter_section.get("baseModels", [])
|
||||
tags_dict = filter_section.get("tags", {})
|
||||
include_tags = tags_dict.get("include", [])
|
||||
exclude_tags = tags_dict.get("exclude", [])
|
||||
folders_dict = filter_section.get("folders", {})
|
||||
include_folders = folders_dict.get("include", [])
|
||||
exclude_folders = folders_dict.get("exclude", [])
|
||||
license_dict = filter_section.get("license", {})
|
||||
no_credit_required = license_dict.get("noCreditRequired", False)
|
||||
allow_selling = license_dict.get("allowSelling", False)
|
||||
|
||||
# Build tag filters dict
|
||||
tag_filters = {}
|
||||
@@ -384,13 +401,13 @@ class LoraService(BaseModelService):
|
||||
available_loras = self.filter_set.apply(available_loras, criteria)
|
||||
|
||||
# Apply license filters
|
||||
# Note: no_credit_required=True means filter out models where credit is NOT required
|
||||
# (i.e., keep only models where credit IS required)
|
||||
# no_credit_required=True means keep only models where credit is NOT required
|
||||
# (i.e., allowNoCredit=True, which is bit 0 = 1 in license_flags)
|
||||
if no_credit_required:
|
||||
available_loras = [
|
||||
lora
|
||||
for lora in available_loras
|
||||
if not (lora.get("license_flags", 127) & (1 << 0))
|
||||
if bool(lora.get("license_flags", 127) & (1 << 0))
|
||||
]
|
||||
|
||||
# allow_selling=True means keep only models where selling generated content is allowed
|
||||
|
||||
Reference in New Issue
Block a user