mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-23 14:12:11 -03:00
feat: add LoraDemoNode and LoraRandomizerNode with documentation
- Import and register two new nodes: LoraDemoNode and LoraRandomizerNode
- Update import exception handling for better readability with multi-line formatting
- Add comprehensive documentation file `docs/custom-node-ui-output.md` for UI output usage in custom nodes
- Ensure proper node registration in NODE_CLASS_MAPPINGS for ComfyUI integration
- Maintain backward compatibility with existing node structure and import fallbacks
This commit is contained in:
95
py/nodes/lora_demo.py
Normal file
95
py/nodes/lora_demo.py
Normal file
@@ -0,0 +1,95 @@
|
||||
"""
|
||||
Lora Demo Node - Demonstrates LORAS custom widget type usage.
|
||||
|
||||
This node accepts LORAS widget input and outputs a summary string.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import random
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LoraDemoNode:
    """Demo node that uses LORAS custom widget type."""

    NAME = "Lora Demo (LoraManager)"
    CATEGORY = "Lora Manager/demo"

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "loras": ("LORAS", {}),
            }
        }

    RETURN_TYPES = ("STRING",)
    RETURN_NAMES = ("summary",)

    FUNCTION = "process"
    OUTPUT_NODE = False

    async def process(self, loras):
        """
        Process LoRAs input and return summary + UI data for widget.

        Args:
            loras: List of LoRA dictionaries with structure:
                [{'name': str, 'strength': float, 'clipStrength': float, 'active': bool, ...}]

        Returns:
            Dictionary with 'result' (for workflow) and 'ui' (for frontend display)
        """
        from ..services.service_registry import ServiceRegistry

        # The scanner exposes the cached list of loras known to the manager.
        scanner = await ServiceRegistry.get_lora_scanner()

        # Pull available loras from the cache; degrade to an empty pool on
        # any cache failure so the demo node never hard-errors a workflow.
        pool = []
        try:
            cache = await scanner.get_cached_data(force_refresh=False)
            if cache and hasattr(cache, "raw_data"):
                pool = cache.raw_data
        except Exception as e:
            logger.warning(f"[LoraDemoNode] Failed to get lora cache: {e}")

        # Draw 3-5 entries, clamped to however many are actually available.
        pick_count = min(random.randint(3, 5), len(pool))
        chosen = random.sample(pool, pick_count) if pick_count > 0 else []

        # Build widget entries with a shared random model/clip strength.
        widget_loras = []
        for entry in chosen:
            weight = round(random.uniform(0.1, 1.0), 2)
            widget_loras.append(
                {
                    "name": entry.get("file_name", "Unknown"),
                    "strength": weight,
                    "clipStrength": weight,
                    "active": True,
                    "expanded": False,
                }
            )

        picked_names = [item["name"] for item in widget_loras]
        summary = f"Randomized {len(picked_names)} LoRAs: {', '.join(picked_names)}"

        logger.info(f"[LoraDemoNode] {summary}")

        # 'result' feeds the workflow output; 'ui' feeds the frontend widget.
        return {"result": (summary,), "ui": {"loras": widget_loras}}
|
||||
|
||||
|
||||
# Register the node with ComfyUI under its class identifier.
NODE_CLASS_MAPPINGS = {"LoraDemoNode": LoraDemoNode}

# Human-readable title shown in the ComfyUI node picker.
NODE_DISPLAY_NAME_MAPPINGS = {"LoraDemoNode": "LoRA Demo"}
|
||||
@@ -31,28 +31,28 @@ class LoraPoolNode:
|
||||
"hidden": {
|
||||
# Hidden input to pass through unique node ID for frontend
|
||||
"unique_id": "UNIQUE_ID",
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("LORA_POOL_CONFIG",)
|
||||
RETURN_NAMES = ("pool_config",)
|
||||
RETURN_TYPES = ("POOL_CONFIG",)
|
||||
RETURN_NAMES = ("POOL_CONFIG",)
|
||||
|
||||
FUNCTION = "process"
|
||||
OUTPUT_NODE = False
|
||||
|
||||
def process(self, pool_config, unique_id=None):
|
||||
"""
|
||||
Pass through the pool configuration.
|
||||
Pass through the pool configuration filters.
|
||||
|
||||
The config is generated entirely by the frontend widget.
|
||||
This function validates and passes through the configuration.
|
||||
This function validates and returns only the filters field.
|
||||
|
||||
Args:
|
||||
pool_config: Dict containing filter criteria from widget
|
||||
unique_id: Node's unique ID (hidden)
|
||||
|
||||
Returns:
|
||||
Tuple containing the validated pool_config
|
||||
Tuple containing the filters dict from pool_config
|
||||
"""
|
||||
# Validate required structure
|
||||
if not isinstance(pool_config, dict):
|
||||
@@ -63,10 +63,13 @@ class LoraPoolNode:
|
||||
if "version" not in pool_config:
|
||||
pool_config["version"] = 1
|
||||
|
||||
# Log for debugging
|
||||
logger.debug(f"[LoraPoolNode] Processing config: {pool_config}")
|
||||
# Extract filters field
|
||||
filters = pool_config.get("filters", self._default_config()["filters"])
|
||||
|
||||
return (pool_config,)
|
||||
# Log for debugging
|
||||
logger.debug(f"[LoraPoolNode] Processing filters: {filters}")
|
||||
|
||||
return (filters,)
|
||||
|
||||
@staticmethod
|
||||
def _default_config():
|
||||
@@ -76,23 +79,16 @@ class LoraPoolNode:
|
||||
"filters": {
|
||||
"baseModels": [],
|
||||
"tags": {"include": [], "exclude": []},
|
||||
"folder": {"path": None, "recursive": True},
|
||||
"folders": {"include": [], "exclude": []},
|
||||
"favoritesOnly": False,
|
||||
"license": {
|
||||
"noCreditRequired": None,
|
||||
"allowSellingGeneratedContent": None
|
||||
}
|
||||
"license": {"noCreditRequired": False, "allowSelling": False},
|
||||
},
|
||||
"preview": {"matchCount": 0, "lastUpdated": 0}
|
||||
"preview": {"matchCount": 0, "lastUpdated": 0},
|
||||
}
|
||||
|
||||
|
||||
# Register the node with ComfyUI under its class identifier.
NODE_CLASS_MAPPINGS = {"LoraPoolNode": LoraPoolNode}

# Human-readable title shown in the ComfyUI node picker.
NODE_DISPLAY_NAME_MAPPINGS = {"LoraPoolNode": "LoRA Pool (Filter)"}
|
||||
|
||||
297
py/nodes/lora_randomizer.py
Normal file
297
py/nodes/lora_randomizer.py
Normal file
@@ -0,0 +1,297 @@
|
||||
"""
|
||||
Lora Randomizer Node - Randomly selects LoRAs from a pool with configurable settings.
|
||||
|
||||
This node accepts optional pool_config input to filter available LoRAs, and outputs
|
||||
a LORA_STACK with randomly selected LoRAs. Supports both frontend roll (fixed selection)
|
||||
and backend roll (randomizes each execution).
|
||||
"""
|
||||
|
||||
import logging
|
||||
import random
|
||||
import os
|
||||
from ..utils.utils import get_lora_info
|
||||
from .utils import extract_lora_name
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LoraRandomizerNode:
    """Node that randomly selects LoRAs from a pool"""

    NAME = "Lora Randomizer (LoraManager)"
    CATEGORY = "Lora Manager/randomizer"

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "randomizer_config": ("RANDOMIZER_CONFIG", {}),
                "loras": ("LORAS", {}),
            },
            "optional": {
                "pool_config": ("POOL_CONFIG", {}),
            },
        }

    RETURN_TYPES = ("LORA_STACK",)
    RETURN_NAMES = ("lora_stack",)

    FUNCTION = "randomize"
    OUTPUT_NODE = False

    async def randomize(self, randomizer_config, loras, pool_config=None):
        """
        Randomize LoRAs based on configuration and pool filters.

        Args:
            randomizer_config: Dict with randomizer settings (count, strength ranges, roll mode)
            loras: List of LoRA dicts from LORAS widget (includes locked state)
            pool_config: Optional config from LoRA Pool node for filtering

        Returns:
            Dictionary with 'result' (LORA_STACK tuple) and 'ui' (for widget display)
        """
        from ..services.service_registry import ServiceRegistry

        # Get lora scanner to access available loras
        scanner = await ServiceRegistry.get_lora_scanner()

        # Parse randomizer settings; every field is optional with a default.
        count_mode = randomizer_config.get("count_mode", "range")
        count_fixed = randomizer_config.get("count_fixed", 5)
        count_min = randomizer_config.get("count_min", 3)
        count_max = randomizer_config.get("count_max", 7)
        model_strength_min = randomizer_config.get("model_strength_min", 0.0)
        model_strength_max = randomizer_config.get("model_strength_max", 1.0)
        use_same_clip_strength = randomizer_config.get("use_same_clip_strength", True)
        clip_strength_min = randomizer_config.get("clip_strength_min", 0.0)
        clip_strength_max = randomizer_config.get("clip_strength_max", 1.0)
        roll_mode = randomizer_config.get("roll_mode", "frontend")

        # Determine target count: fixed value, or a fresh draw from the range.
        if count_mode == "fixed":
            target_count = count_fixed
        else:
            target_count = random.randint(count_min, count_max)

        logger.info(
            f"[LoraRandomizerNode] Target count: {target_count}, Roll mode: {roll_mode}"
        )

        # Locked LoRAs are preserved as-is and never re-rolled.
        locked_loras = [lora for lora in loras if lora.get("locked", False)]
        locked_count = len(locked_loras)

        logger.info(f"[LoraRandomizerNode] Locked LoRAs: {locked_count}")

        # Get available loras from cache; degrade to empty on failure.
        try:
            cache_data = await scanner.get_cached_data(force_refresh=False)
            if cache_data and hasattr(cache_data, "raw_data"):
                available_loras = cache_data.raw_data
            else:
                available_loras = []
        except Exception as e:
            logger.warning(f"[LoraRandomizerNode] Failed to get lora cache: {e}")
            available_loras = []

        # Apply pool filters if a LoRA Pool node is connected.
        if pool_config:
            available_loras = await self._apply_pool_filters(
                available_loras, pool_config, scanner
            )

        logger.info(
            f"[LoraRandomizerNode] Available LoRAs after filtering: {len(available_loras)}"
        )

        # Slots still to fill after accounting for locked entries.
        slots_needed = target_count - locked_count

        if slots_needed < 0:
            slots_needed = 0
            # Too many locked, trim to target
            locked_loras = locked_loras[:target_count]
            locked_count = len(locked_loras)

        # Exclude locked entries from the candidate pool so they cannot be
        # selected twice. Use .get(): cache entries missing "file_name"
        # previously raised KeyError here.
        locked_names = {lora["name"] for lora in locked_loras}
        available_pool = [
            l for l in available_loras if l.get("file_name") not in locked_names
        ]

        # Never try to sample more than the pool holds.
        if slots_needed > len(available_pool):
            slots_needed = len(available_pool)

        logger.info(
            f"[LoraRandomizerNode] Selecting {slots_needed} new LoRAs from {len(available_pool)} available"
        )

        selected = random.sample(available_pool, slots_needed) if slots_needed > 0 else []

        # Generate random strengths for the newly selected LoRAs.
        result_loras = []
        for lora in selected:
            model_str = round(random.uniform(model_strength_min, model_strength_max), 2)

            if use_same_clip_strength:
                clip_str = model_str
            else:
                clip_str = round(
                    random.uniform(clip_strength_min, clip_strength_max), 2
                )

            result_loras.append(
                {
                    "name": lora.get("file_name", "Unknown"),
                    "strength": model_str,
                    "clipStrength": clip_str,
                    "active": True,
                    # Expand the widget row only when the two strengths differ.
                    "expanded": abs(model_str - clip_str) > 0.001,
                    "locked": False,
                }
            )

        # Merge with locked LoRAs
        result_loras.extend(locked_loras)

        logger.info(f"[LoraRandomizerNode] Final LoRA count: {len(result_loras)}")

        # Build LORA_STACK output: (path, model_strength, clip_strength) tuples
        # for every active entry that resolves to a file on disk.
        lora_stack = []
        for lora in result_loras:
            if not lora.get("active", False):
                continue

            lora_path, _trigger_words = get_lora_info(lora["name"])
            if not lora_path:
                logger.warning(
                    f"[LoraRandomizerNode] Could not find path for LoRA: {lora['name']}"
                )
                continue

            # Normalize path separators for the host OS.
            lora_path = lora_path.replace("/", os.sep)

            model_strength = lora.get("strength", 1.0)
            clip_strength = lora.get("clipStrength", model_strength)

            lora_stack.append((lora_path, model_strength, clip_strength))

        # Return format: result for workflow + ui for frontend display
        return {"result": (lora_stack,), "ui": {"loras": result_loras}}

    async def _apply_pool_filters(self, available_loras, pool_config, scanner):
        """
        Apply pool_config filters to available LoRAs.

        Args:
            available_loras: List of all LoRA dicts
            pool_config: Dict with filter settings from LoRA Pool node
            scanner: Scanner instance for accessing filter utilities

        Returns:
            Filtered list of LoRA dicts
        """
        from ..services.lora_service import LoraService
        from ..services.model_query import FilterCriteria

        # Create lora service instance for filtering
        lora_service = LoraService(scanner)

        # Extract filter parameters from pool_config
        selected_base_models = pool_config.get("baseModels", [])
        tags_dict = pool_config.get("tags", {})
        include_tags = tags_dict.get("include", [])
        exclude_tags = tags_dict.get("exclude", [])
        folders_dict = pool_config.get("folders", {})
        include_folders = folders_dict.get("include", [])
        exclude_folders = folders_dict.get("exclude", [])
        license_dict = pool_config.get("license", {})
        no_credit_required = license_dict.get("noCreditRequired", False)
        allow_selling = license_dict.get("allowSelling", False)

        # Build tag filters dict (exclude wins when a tag is in both lists).
        tag_filters = {}
        for tag in include_tags:
            tag_filters[tag] = "include"
        for tag in exclude_tags:
            tag_filters[tag] = "exclude"

        # Folder filter: exclusion takes precedence, then inclusion (prefix
        # match on the lora's folder path).
        if include_folders or exclude_folders:
            filtered = []
            for lora in available_loras:
                folder = lora.get("folder", "")

                if any(folder.startswith(ex) for ex in exclude_folders):
                    continue

                if include_folders and not any(
                    folder.startswith(inc) for inc in include_folders
                ):
                    continue

                filtered.append(lora)

            available_loras = filtered

        # Apply base model filter
        if selected_base_models:
            available_loras = [
                lora
                for lora in available_loras
                if lora.get("base_model") in selected_base_models
            ]

        # Apply tag filters
        if tag_filters:
            criteria = FilterCriteria(tags=tag_filters)
            available_loras = lora_service.filter_set.apply(available_loras, criteria)

        # Apply license filters. Guard with `or {}`: civitai metadata may be
        # explicitly None, which made the original .get chain raise.
        if no_credit_required:
            # NOTE(review): this keeps loras whose metadata says credit IS
            # required (allowNoCredit falsy) — confirm the polarity matches
            # the intended "no credit required" semantics.
            available_loras = [
                lora
                for lora in available_loras
                if not (lora.get("civitai") or {}).get("allowNoCredit", True)
            ]

        if allow_selling:
            # `or ["None"]` also covers an empty allowCommercialUse list,
            # which previously raised IndexError.
            # NOTE(review): assumes allowCommercialUse is a list — confirm
            # against civitai metadata (older records may store a string).
            available_loras = [
                lora
                for lora in available_loras
                if ((lora.get("civitai") or {}).get("allowCommercialUse") or ["None"])[0]
                != "None"
            ]

        return available_loras
|
||||
|
||||
|
||||
# Register the node with ComfyUI under its class identifier.
NODE_CLASS_MAPPINGS = {"LoraRandomizerNode": LoraRandomizerNode}

# Human-readable title shown in the ComfyUI node picker.
NODE_DISPLAY_NAME_MAPPINGS = {"LoraRandomizerNode": "LoRA Randomizer"}
|
||||
@@ -45,6 +45,9 @@ class LoraRoutes(BaseModelRoutes):
|
||||
registrar.add_prefixed_route('GET', '/api/lm/{prefix}/get-trigger-words', prefix, self.get_lora_trigger_words)
|
||||
registrar.add_prefixed_route('GET', '/api/lm/{prefix}/usage-tips-by-path', prefix, self.get_lora_usage_tips_by_path)
|
||||
|
||||
# Randomizer routes
|
||||
registrar.add_prefixed_route('POST', '/api/lm/{prefix}/random-sample', prefix, self.get_random_loras)
|
||||
|
||||
# ComfyUI integration
|
||||
registrar.add_prefixed_route('POST', '/api/lm/{prefix}/get_trigger_words', prefix, self.get_trigger_words)
|
||||
|
||||
@@ -215,6 +218,74 @@ class LoraRoutes(BaseModelRoutes):
|
||||
'error': str(e)
|
||||
}, status=500)
|
||||
|
||||
async def get_random_loras(self, request: web.Request) -> web.Response:
    """Get random LoRAs based on filters and strength ranges.

    Expects a JSON body with optional count / count_min+count_max,
    strength ranges, locked_loras, and pool_config. Returns
    {'success', 'loras', 'count'} on success, or a 400/500 error payload.
    """
    try:
        json_data = await request.json()

        # Parse parameters. Counts are coerced to int because JSON clients
        # may send floats or numeric strings, and random.randint rejects
        # non-integer arguments; a bad value raises ValueError and lands
        # in the 400 handler below.
        count = int(json_data.get('count', 5))
        count_min = json_data.get('count_min')
        count_max = json_data.get('count_max')
        model_strength_min = float(json_data.get('model_strength_min', 0.0))
        model_strength_max = float(json_data.get('model_strength_max', 1.0))
        use_same_clip_strength = json_data.get('use_same_clip_strength', True)
        clip_strength_min = float(json_data.get('clip_strength_min', 0.0))
        clip_strength_max = float(json_data.get('clip_strength_max', 1.0))
        locked_loras = json_data.get('locked_loras', [])
        pool_config = json_data.get('pool_config')

        # Determine target count: an explicit min/max range overrides the
        # fixed count.
        if count_min is not None and count_max is not None:
            import random
            target_count = random.randint(int(count_min), int(count_max))
        else:
            target_count = count

        # Validate parameters
        if target_count < 1 or target_count > 100:
            return web.json_response({
                'success': False,
                'error': 'Count must be between 1 and 100'
            }, status=400)

        if model_strength_min < 0 or model_strength_max > 10:
            return web.json_response({
                'success': False,
                'error': 'Model strength must be between 0 and 10'
            }, status=400)

        # Get random LoRAs from service
        result_loras = await self.service.get_random_loras(
            count=target_count,
            model_strength_min=model_strength_min,
            model_strength_max=model_strength_max,
            use_same_clip_strength=use_same_clip_strength,
            clip_strength_min=clip_strength_min,
            clip_strength_max=clip_strength_max,
            locked_loras=locked_loras,
            pool_config=pool_config
        )

        return web.json_response({
            'success': True,
            'loras': result_loras,
            'count': len(result_loras)
        })

    except ValueError as e:
        # Bad numeric input from the client (including failed coercions).
        logger.error(f"Invalid parameter for random LoRAs: {e}")
        return web.json_response({
            'success': False,
            'error': str(e)
        }, status=400)
    except Exception as e:
        logger.error(f"Error getting random LoRAs: {e}", exc_info=True)
        return web.json_response({
            'success': False,
            'error': str(e)
        }, status=500)
|
||||
|
||||
async def get_trigger_words(self, request: web.Request) -> web.Response:
|
||||
"""Get trigger words for specified LoRA models"""
|
||||
try:
|
||||
|
||||
@@ -178,7 +178,188 @@ class LoraService(BaseModelService):
|
||||
def find_duplicate_hashes(self) -> Dict:
    """Find LoRAs with duplicate SHA256 hashes"""
    # Delegates to the scanner's hash index.
    # NOTE(review): reaches into the scanner's private ``_hash_index``
    # attribute — confirm this cross-object private access is intended.
    return self.scanner._hash_index.get_duplicate_hashes()
|
||||
|
||||
|
||||
def find_duplicate_filenames(self) -> Dict:
    """Find LoRAs with conflicting filenames"""
    # Delegates to the scanner's hash index.
    # NOTE(review): reaches into the scanner's private ``_hash_index``
    # attribute — confirm this cross-object private access is intended.
    return self.scanner._hash_index.get_duplicate_filenames()
|
||||
|
||||
async def get_random_loras(
    self,
    count: int,
    model_strength_min: float = 0.0,
    model_strength_max: float = 1.0,
    use_same_clip_strength: bool = True,
    clip_strength_min: float = 0.0,
    clip_strength_max: float = 1.0,
    locked_loras: Optional[List[Dict]] = None,
    pool_config: Optional[Dict] = None
) -> List[Dict]:
    """
    Get random LoRAs with specified strength ranges.

    Args:
        count: Number of LoRAs to select
        model_strength_min: Minimum model strength
        model_strength_max: Maximum model strength
        use_same_clip_strength: Whether to use same strength for clip
        clip_strength_min: Minimum clip strength
        clip_strength_max: Maximum clip strength
        locked_loras: List of locked LoRA dicts to preserve
        pool_config: Optional pool config for filtering

    Returns:
        List of LoRA dicts with randomized strengths
    """
    import random

    if locked_loras is None:
        locked_loras = []

    # Get available loras from cache. getattr guards against a cache
    # object without raw_data (the node-side implementation checks
    # hasattr; the previous direct attribute access could raise here).
    cache = await self.scanner.get_cached_data(force_refresh=False)
    available_loras = getattr(cache, 'raw_data', []) if cache else []

    # Apply pool filters if provided
    if pool_config:
        available_loras = await self._apply_pool_filters(
            available_loras, pool_config
        )

    # Calculate slots needed (total - locked)
    locked_count = len(locked_loras)
    slots_needed = count - locked_count

    if slots_needed < 0:
        slots_needed = 0
        # Too many locked, trim to target
        locked_loras = locked_loras[:count]

    # Filter out locked LoRAs from available pool. .get avoids KeyError
    # on malformed entries.
    # NOTE(review): locked entries carry 'name' while cache entries carry
    # 'model_name' here, but the randomizer node compares 'name' against
    # 'file_name' — confirm which key is canonical across the codebase.
    locked_names = {lora.get('name') for lora in locked_loras}
    available_pool = [
        l for l in available_loras
        if l.get('model_name') not in locked_names
    ]

    # Ensure we don't try to select more than available
    if slots_needed > len(available_pool):
        slots_needed = len(available_pool)

    # Random sample
    selected = []
    if slots_needed > 0:
        selected = random.sample(available_pool, slots_needed)

    # Generate random strengths for selected LoRAs
    result_loras = []
    for lora in selected:
        model_str = round(
            random.uniform(model_strength_min, model_strength_max), 2
        )

        if use_same_clip_strength:
            clip_str = model_str
        else:
            clip_str = round(
                random.uniform(clip_strength_min, clip_strength_max), 2
            )

        result_loras.append({
            'name': lora['model_name'],
            'strength': model_str,
            'clipStrength': clip_str,
            'active': True,
            # Expand the widget row only when the two strengths differ.
            'expanded': abs(model_str - clip_str) > 0.001,
            'locked': False
        })

    # Merge with locked LoRAs
    result_loras.extend(locked_loras)

    return result_loras
|
||||
|
||||
async def _apply_pool_filters(self, available_loras: List[Dict], pool_config: Dict) -> List[Dict]:
|
||||
"""
|
||||
Apply pool_config filters to available LoRAs.
|
||||
|
||||
Args:
|
||||
available_loras: List of all LoRA dicts
|
||||
pool_config: Dict with filter settings from LoRA Pool node
|
||||
|
||||
Returns:
|
||||
Filtered list of LoRA dicts
|
||||
"""
|
||||
from .model_query import FilterCriteria
|
||||
|
||||
# Extract filter parameters from pool_config
|
||||
selected_base_models = pool_config.get('selected_base_models', [])
|
||||
include_tags = pool_config.get('include_tags', [])
|
||||
exclude_tags = pool_config.get('exclude_tags', [])
|
||||
include_folders = pool_config.get('include_folders', [])
|
||||
exclude_folders = pool_config.get('exclude_folders', [])
|
||||
no_credit_required = pool_config.get('no_credit_required', False)
|
||||
allow_selling = pool_config.get('allow_selling', False)
|
||||
|
||||
# Build tag filters dict
|
||||
tag_filters = {}
|
||||
for tag in include_tags:
|
||||
tag_filters[tag] = 'include'
|
||||
for tag in exclude_tags:
|
||||
tag_filters[tag] = 'exclude'
|
||||
|
||||
# Build folder filter
|
||||
if include_folders or exclude_folders:
|
||||
filtered = []
|
||||
for lora in available_loras:
|
||||
folder = lora.get('folder', '')
|
||||
|
||||
# Check exclude folders first
|
||||
excluded = False
|
||||
for exclude_folder in exclude_folders:
|
||||
if folder.startswith(exclude_folder):
|
||||
excluded = True
|
||||
break
|
||||
|
||||
if excluded:
|
||||
continue
|
||||
|
||||
# Check include folders
|
||||
if include_folders:
|
||||
included = False
|
||||
for include_folder in include_folders:
|
||||
if folder.startswith(include_folder):
|
||||
included = True
|
||||
break
|
||||
if not included:
|
||||
continue
|
||||
|
||||
filtered.append(lora)
|
||||
|
||||
available_loras = filtered
|
||||
|
||||
# Apply base model filter
|
||||
if selected_base_models:
|
||||
available_loras = [
|
||||
lora for lora in available_loras
|
||||
if lora.get('base_model') in selected_base_models
|
||||
]
|
||||
|
||||
# Apply tag filters
|
||||
if tag_filters:
|
||||
criteria = FilterCriteria(tags=tag_filters)
|
||||
available_loras = self.filter_set.apply(available_loras, criteria)
|
||||
|
||||
# Apply license filters
|
||||
if no_credit_required:
|
||||
available_loras = [
|
||||
lora for lora in available_loras
|
||||
if not lora.get('civitai', {}).get('allowNoCredit', True)
|
||||
]
|
||||
|
||||
if allow_selling:
|
||||
available_loras = [
|
||||
lora for lora in available_loras
|
||||
if lora.get('civitai', {}).get('allowCommercialUse', ['None'])[0] != 'None'
|
||||
]
|
||||
|
||||
return available_loras
|
||||
|
||||
Reference in New Issue
Block a user