mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-21 21:22:11 -03:00
feat: add LoraDemoNode and LoraRandomizerNode with documentation
- Import and register two new nodes: LoraDemoNode and LoraRandomizerNode - Update import exception handling for better readability with multi-line formatting - Add comprehensive documentation file `docs/custom-node-ui-output.md` for UI output usage in custom nodes - Ensure proper node registration in NODE_CLASS_MAPPINGS for ComfyUI integration - Maintain backward compatibility with existing node structure and import fallbacks
This commit is contained in:
41
__init__.py
41
__init__.py
@@ -9,8 +9,12 @@ try: # pragma: no cover - import fallback for pytest collection
|
|||||||
from .py.nodes.wanvideo_lora_select import WanVideoLoraSelectLM
|
from .py.nodes.wanvideo_lora_select import WanVideoLoraSelectLM
|
||||||
from .py.nodes.wanvideo_lora_select_from_text import WanVideoLoraSelectFromText
|
from .py.nodes.wanvideo_lora_select_from_text import WanVideoLoraSelectFromText
|
||||||
from .py.nodes.lora_pool import LoraPoolNode
|
from .py.nodes.lora_pool import LoraPoolNode
|
||||||
|
from .py.nodes.lora_demo import LoraDemoNode
|
||||||
|
from .py.nodes.lora_randomizer import LoraRandomizerNode
|
||||||
from .py.metadata_collector import init as init_metadata_collector
|
from .py.metadata_collector import init as init_metadata_collector
|
||||||
except ImportError: # pragma: no cover - allows running under pytest without package install
|
except (
|
||||||
|
ImportError
|
||||||
|
): # pragma: no cover - allows running under pytest without package install
|
||||||
import importlib
|
import importlib
|
||||||
import pathlib
|
import pathlib
|
||||||
import sys
|
import sys
|
||||||
@@ -21,15 +25,27 @@ except ImportError: # pragma: no cover - allows running under pytest without pa
|
|||||||
|
|
||||||
PromptLoraManager = importlib.import_module("py.nodes.prompt").PromptLoraManager
|
PromptLoraManager = importlib.import_module("py.nodes.prompt").PromptLoraManager
|
||||||
LoraManager = importlib.import_module("py.lora_manager").LoraManager
|
LoraManager = importlib.import_module("py.lora_manager").LoraManager
|
||||||
LoraManagerLoader = importlib.import_module("py.nodes.lora_loader").LoraManagerLoader
|
LoraManagerLoader = importlib.import_module(
|
||||||
LoraManagerTextLoader = importlib.import_module("py.nodes.lora_loader").LoraManagerTextLoader
|
"py.nodes.lora_loader"
|
||||||
TriggerWordToggle = importlib.import_module("py.nodes.trigger_word_toggle").TriggerWordToggle
|
).LoraManagerLoader
|
||||||
|
LoraManagerTextLoader = importlib.import_module(
|
||||||
|
"py.nodes.lora_loader"
|
||||||
|
).LoraManagerTextLoader
|
||||||
|
TriggerWordToggle = importlib.import_module(
|
||||||
|
"py.nodes.trigger_word_toggle"
|
||||||
|
).TriggerWordToggle
|
||||||
LoraStacker = importlib.import_module("py.nodes.lora_stacker").LoraStacker
|
LoraStacker = importlib.import_module("py.nodes.lora_stacker").LoraStacker
|
||||||
SaveImageLM = importlib.import_module("py.nodes.save_image").SaveImageLM
|
SaveImageLM = importlib.import_module("py.nodes.save_image").SaveImageLM
|
||||||
DebugMetadata = importlib.import_module("py.nodes.debug_metadata").DebugMetadata
|
DebugMetadata = importlib.import_module("py.nodes.debug_metadata").DebugMetadata
|
||||||
WanVideoLoraSelectLM = importlib.import_module("py.nodes.wanvideo_lora_select").WanVideoLoraSelectLM
|
WanVideoLoraSelectLM = importlib.import_module(
|
||||||
WanVideoLoraSelectFromText = importlib.import_module("py.nodes.wanvideo_lora_select_from_text").WanVideoLoraSelectFromText
|
"py.nodes.wanvideo_lora_select"
|
||||||
|
).WanVideoLoraSelectLM
|
||||||
|
WanVideoLoraSelectFromText = importlib.import_module(
|
||||||
|
"py.nodes.wanvideo_lora_select_from_text"
|
||||||
|
).WanVideoLoraSelectFromText
|
||||||
LoraPoolNode = importlib.import_module("py.nodes.lora_pool").LoraPoolNode
|
LoraPoolNode = importlib.import_module("py.nodes.lora_pool").LoraPoolNode
|
||||||
|
LoraDemoNode = importlib.import_module("py.nodes.lora_demo").LoraDemoNode
|
||||||
|
LoraRandomizerNode = importlib.import_module("py.nodes.lora_randomizer").LoraRandomizerNode
|
||||||
init_metadata_collector = importlib.import_module("py.metadata_collector").init
|
init_metadata_collector = importlib.import_module("py.metadata_collector").init
|
||||||
|
|
||||||
NODE_CLASS_MAPPINGS = {
|
NODE_CLASS_MAPPINGS = {
|
||||||
@@ -42,7 +58,9 @@ NODE_CLASS_MAPPINGS = {
|
|||||||
DebugMetadata.NAME: DebugMetadata,
|
DebugMetadata.NAME: DebugMetadata,
|
||||||
WanVideoLoraSelectLM.NAME: WanVideoLoraSelectLM,
|
WanVideoLoraSelectLM.NAME: WanVideoLoraSelectLM,
|
||||||
WanVideoLoraSelectFromText.NAME: WanVideoLoraSelectFromText,
|
WanVideoLoraSelectFromText.NAME: WanVideoLoraSelectFromText,
|
||||||
LoraPoolNode.NAME: LoraPoolNode
|
LoraPoolNode.NAME: LoraPoolNode,
|
||||||
|
LoraDemoNode.NAME: LoraDemoNode,
|
||||||
|
LoraRandomizerNode.NAME: LoraRandomizerNode,
|
||||||
}
|
}
|
||||||
|
|
||||||
WEB_DIRECTORY = "./web/comfyui"
|
WEB_DIRECTORY = "./web/comfyui"
|
||||||
@@ -50,15 +68,20 @@ WEB_DIRECTORY = "./web/comfyui"
|
|||||||
# Check and build Vue widgets if needed (development mode)
|
# Check and build Vue widgets if needed (development mode)
|
||||||
try:
|
try:
|
||||||
from .py.vue_widget_builder import check_and_build_vue_widgets
|
from .py.vue_widget_builder import check_and_build_vue_widgets
|
||||||
|
|
||||||
# Auto-build in development, warn only if fails
|
# Auto-build in development, warn only if fails
|
||||||
check_and_build_vue_widgets(auto_build=True, warn_only=True)
|
check_and_build_vue_widgets(auto_build=True, warn_only=True)
|
||||||
except ImportError:
|
except ImportError:
|
||||||
# Fallback for pytest
|
# Fallback for pytest
|
||||||
import importlib
|
import importlib
|
||||||
check_and_build_vue_widgets = importlib.import_module("py.vue_widget_builder").check_and_build_vue_widgets
|
|
||||||
|
check_and_build_vue_widgets = importlib.import_module(
|
||||||
|
"py.vue_widget_builder"
|
||||||
|
).check_and_build_vue_widgets
|
||||||
check_and_build_vue_widgets(auto_build=True, warn_only=True)
|
check_and_build_vue_widgets(auto_build=True, warn_only=True)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
logging.warning(f"[LoRA Manager] Vue widget build check skipped: {e}")
|
logging.warning(f"[LoRA Manager] Vue widget build check skipped: {e}")
|
||||||
|
|
||||||
# Initialize metadata collector
|
# Initialize metadata collector
|
||||||
@@ -66,4 +89,4 @@ init_metadata_collector()
|
|||||||
|
|
||||||
# Register routes on import
|
# Register routes on import
|
||||||
LoraManager.add_routes()
|
LoraManager.add_routes()
|
||||||
__all__ = ['NODE_CLASS_MAPPINGS', 'WEB_DIRECTORY']
|
__all__ = ["NODE_CLASS_MAPPINGS", "WEB_DIRECTORY"]
|
||||||
|
|||||||
95
py/nodes/lora_demo.py
Normal file
95
py/nodes/lora_demo.py
Normal file
@@ -0,0 +1,95 @@
|
|||||||
|
"""
|
||||||
|
Lora Demo Node - Demonstrates LORAS custom widget type usage.
|
||||||
|
|
||||||
|
This node accepts LORAS widget input and outputs a summary string.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import random
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class LoraDemoNode:
|
||||||
|
"""Demo node that uses LORAS custom widget type."""
|
||||||
|
|
||||||
|
NAME = "Lora Demo (LoraManager)"
|
||||||
|
CATEGORY = "Lora Manager/demo"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"loras": ("LORAS", {}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("STRING",)
|
||||||
|
RETURN_NAMES = ("summary",)
|
||||||
|
|
||||||
|
FUNCTION = "process"
|
||||||
|
OUTPUT_NODE = False
|
||||||
|
|
||||||
|
async def process(self, loras):
|
||||||
|
"""
|
||||||
|
Process LoRAs input and return summary + UI data for widget.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
loras: List of LoRA dictionaries with structure:
|
||||||
|
[{'name': str, 'strength': float, 'clipStrength': float, 'active': bool, ...}]
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with 'result' (for workflow) and 'ui' (for frontend display)
|
||||||
|
"""
|
||||||
|
from ..services.service_registry import ServiceRegistry
|
||||||
|
|
||||||
|
# Get lora scanner to access available loras
|
||||||
|
scanner = await ServiceRegistry.get_lora_scanner()
|
||||||
|
|
||||||
|
# Get available loras from cache
|
||||||
|
available_loras = []
|
||||||
|
try:
|
||||||
|
cache_data = await scanner.get_cached_data(force_refresh=False)
|
||||||
|
if cache_data and hasattr(cache_data, "raw_data"):
|
||||||
|
available_loras = cache_data.raw_data
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"[LoraDemoNode] Failed to get lora cache: {e}")
|
||||||
|
|
||||||
|
# Randomly select 3-5 loras
|
||||||
|
num_to_select = random.randint(3, 5)
|
||||||
|
if len(available_loras) < num_to_select:
|
||||||
|
num_to_select = len(available_loras)
|
||||||
|
|
||||||
|
selected_loras = (
|
||||||
|
random.sample(available_loras, num_to_select) if num_to_select > 0 else []
|
||||||
|
)
|
||||||
|
|
||||||
|
# Generate random loras data for widget
|
||||||
|
widget_loras = []
|
||||||
|
for lora in selected_loras:
|
||||||
|
strength = round(random.uniform(0.1, 1.0), 2)
|
||||||
|
widget_loras.append(
|
||||||
|
{
|
||||||
|
"name": lora.get("file_name", "Unknown"),
|
||||||
|
"strength": strength,
|
||||||
|
"clipStrength": strength,
|
||||||
|
"active": True,
|
||||||
|
"expanded": False,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create summary string
|
||||||
|
active_names = [l["name"] for l in widget_loras]
|
||||||
|
summary = f"Randomized {len(active_names)} LoRAs: {', '.join(active_names)}"
|
||||||
|
|
||||||
|
logger.info(f"[LoraDemoNode] {summary}")
|
||||||
|
|
||||||
|
# Return format: result for workflow + ui for frontend
|
||||||
|
return {"result": (summary,), "ui": {"loras": widget_loras}}
|
||||||
|
|
||||||
|
|
||||||
|
# Node class mappings for ComfyUI
|
||||||
|
NODE_CLASS_MAPPINGS = {"LoraDemoNode": LoraDemoNode}
|
||||||
|
|
||||||
|
# Display name mappings
|
||||||
|
NODE_DISPLAY_NAME_MAPPINGS = {"LoraDemoNode": "LoRA Demo"}
|
||||||
@@ -31,28 +31,28 @@ class LoraPoolNode:
|
|||||||
"hidden": {
|
"hidden": {
|
||||||
# Hidden input to pass through unique node ID for frontend
|
# Hidden input to pass through unique node ID for frontend
|
||||||
"unique_id": "UNIQUE_ID",
|
"unique_id": "UNIQUE_ID",
|
||||||
}
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
RETURN_TYPES = ("LORA_POOL_CONFIG",)
|
RETURN_TYPES = ("POOL_CONFIG",)
|
||||||
RETURN_NAMES = ("pool_config",)
|
RETURN_NAMES = ("POOL_CONFIG",)
|
||||||
|
|
||||||
FUNCTION = "process"
|
FUNCTION = "process"
|
||||||
OUTPUT_NODE = False
|
OUTPUT_NODE = False
|
||||||
|
|
||||||
def process(self, pool_config, unique_id=None):
|
def process(self, pool_config, unique_id=None):
|
||||||
"""
|
"""
|
||||||
Pass through the pool configuration.
|
Pass through the pool configuration filters.
|
||||||
|
|
||||||
The config is generated entirely by the frontend widget.
|
The config is generated entirely by the frontend widget.
|
||||||
This function validates and passes through the configuration.
|
This function validates and returns only the filters field.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
pool_config: Dict containing filter criteria from widget
|
pool_config: Dict containing filter criteria from widget
|
||||||
unique_id: Node's unique ID (hidden)
|
unique_id: Node's unique ID (hidden)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tuple containing the validated pool_config
|
Tuple containing the filters dict from pool_config
|
||||||
"""
|
"""
|
||||||
# Validate required structure
|
# Validate required structure
|
||||||
if not isinstance(pool_config, dict):
|
if not isinstance(pool_config, dict):
|
||||||
@@ -63,10 +63,13 @@ class LoraPoolNode:
|
|||||||
if "version" not in pool_config:
|
if "version" not in pool_config:
|
||||||
pool_config["version"] = 1
|
pool_config["version"] = 1
|
||||||
|
|
||||||
# Log for debugging
|
# Extract filters field
|
||||||
logger.debug(f"[LoraPoolNode] Processing config: {pool_config}")
|
filters = pool_config.get("filters", self._default_config()["filters"])
|
||||||
|
|
||||||
return (pool_config,)
|
# Log for debugging
|
||||||
|
logger.debug(f"[LoraPoolNode] Processing filters: {filters}")
|
||||||
|
|
||||||
|
return (filters,)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _default_config():
|
def _default_config():
|
||||||
@@ -76,23 +79,16 @@ class LoraPoolNode:
|
|||||||
"filters": {
|
"filters": {
|
||||||
"baseModels": [],
|
"baseModels": [],
|
||||||
"tags": {"include": [], "exclude": []},
|
"tags": {"include": [], "exclude": []},
|
||||||
"folder": {"path": None, "recursive": True},
|
"folders": {"include": [], "exclude": []},
|
||||||
"favoritesOnly": False,
|
"favoritesOnly": False,
|
||||||
"license": {
|
"license": {"noCreditRequired": False, "allowSelling": False},
|
||||||
"noCreditRequired": None,
|
|
||||||
"allowSellingGeneratedContent": None
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
"preview": {"matchCount": 0, "lastUpdated": 0}
|
"preview": {"matchCount": 0, "lastUpdated": 0},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
# Node class mappings for ComfyUI
|
# Node class mappings for ComfyUI
|
||||||
NODE_CLASS_MAPPINGS = {
|
NODE_CLASS_MAPPINGS = {"LoraPoolNode": LoraPoolNode}
|
||||||
"LoraPoolNode": LoraPoolNode
|
|
||||||
}
|
|
||||||
|
|
||||||
# Display name mappings
|
# Display name mappings
|
||||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
NODE_DISPLAY_NAME_MAPPINGS = {"LoraPoolNode": "LoRA Pool (Filter)"}
|
||||||
"LoraPoolNode": "LoRA Pool (Filter)"
|
|
||||||
}
|
|
||||||
|
|||||||
297
py/nodes/lora_randomizer.py
Normal file
297
py/nodes/lora_randomizer.py
Normal file
@@ -0,0 +1,297 @@
|
|||||||
|
"""
|
||||||
|
Lora Randomizer Node - Randomly selects LoRAs from a pool with configurable settings.
|
||||||
|
|
||||||
|
This node accepts optional pool_config input to filter available LoRAs, and outputs
|
||||||
|
a LORA_STACK with randomly selected LoRAs. Supports both frontend roll (fixed selection)
|
||||||
|
and backend roll (randomizes each execution).
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import random
|
||||||
|
import os
|
||||||
|
from ..utils.utils import get_lora_info
|
||||||
|
from .utils import extract_lora_name
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class LoraRandomizerNode:
|
||||||
|
"""Node that randomly selects LoRAs from a pool"""
|
||||||
|
|
||||||
|
NAME = "Lora Randomizer (LoraManager)"
|
||||||
|
CATEGORY = "Lora Manager/randomizer"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"randomizer_config": ("RANDOMIZER_CONFIG", {}),
|
||||||
|
"loras": ("LORAS", {}),
|
||||||
|
},
|
||||||
|
"optional": {
|
||||||
|
"pool_config": ("POOL_CONFIG", {}),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("LORA_STACK",)
|
||||||
|
RETURN_NAMES = ("lora_stack",)
|
||||||
|
|
||||||
|
FUNCTION = "randomize"
|
||||||
|
OUTPUT_NODE = False
|
||||||
|
|
||||||
|
async def randomize(self, randomizer_config, loras, pool_config=None):
|
||||||
|
"""
|
||||||
|
Randomize LoRAs based on configuration and pool filters.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
randomizer_config: Dict with randomizer settings (count, strength ranges, roll mode)
|
||||||
|
loras: List of LoRA dicts from LORAS widget (includes locked state)
|
||||||
|
pool_config: Optional config from LoRA Pool node for filtering
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with 'result' (LORA_STACK tuple) and 'ui' (for widget display)
|
||||||
|
"""
|
||||||
|
from ..services.service_registry import ServiceRegistry
|
||||||
|
|
||||||
|
# Get lora scanner to access available loras
|
||||||
|
scanner = await ServiceRegistry.get_lora_scanner()
|
||||||
|
|
||||||
|
# Parse randomizer settings
|
||||||
|
count_mode = randomizer_config.get("count_mode", "range")
|
||||||
|
count_fixed = randomizer_config.get("count_fixed", 5)
|
||||||
|
count_min = randomizer_config.get("count_min", 3)
|
||||||
|
count_max = randomizer_config.get("count_max", 7)
|
||||||
|
model_strength_min = randomizer_config.get("model_strength_min", 0.0)
|
||||||
|
model_strength_max = randomizer_config.get("model_strength_max", 1.0)
|
||||||
|
use_same_clip_strength = randomizer_config.get("use_same_clip_strength", True)
|
||||||
|
clip_strength_min = randomizer_config.get("clip_strength_min", 0.0)
|
||||||
|
clip_strength_max = randomizer_config.get("clip_strength_max", 1.0)
|
||||||
|
roll_mode = randomizer_config.get("roll_mode", "frontend")
|
||||||
|
|
||||||
|
# Determine target count
|
||||||
|
if count_mode == "fixed":
|
||||||
|
target_count = count_fixed
|
||||||
|
else:
|
||||||
|
target_count = random.randint(count_min, count_max)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"[LoraRandomizerNode] Target count: {target_count}, Roll mode: {roll_mode}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Extract locked LoRAs from input
|
||||||
|
locked_loras = [lora for lora in loras if lora.get("locked", False)]
|
||||||
|
locked_count = len(locked_loras)
|
||||||
|
|
||||||
|
logger.info(f"[LoraRandomizerNode] Locked LoRAs: {locked_count}")
|
||||||
|
|
||||||
|
# Get available loras from cache
|
||||||
|
try:
|
||||||
|
cache_data = await scanner.get_cached_data(force_refresh=False)
|
||||||
|
if cache_data and hasattr(cache_data, "raw_data"):
|
||||||
|
available_loras = cache_data.raw_data
|
||||||
|
else:
|
||||||
|
available_loras = []
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"[LoraRandomizerNode] Failed to get lora cache: {e}")
|
||||||
|
available_loras = []
|
||||||
|
|
||||||
|
# Apply pool filters if provided
|
||||||
|
if pool_config:
|
||||||
|
available_loras = await self._apply_pool_filters(
|
||||||
|
available_loras, pool_config, scanner
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"[LoraRandomizerNode] Available LoRAs after filtering: {len(available_loras)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Calculate how many new LoRAs to select
|
||||||
|
# In frontend mode, if loras already has data, preserve unlocked ones if roll_mode requires
|
||||||
|
# For simplicity in backend mode, we regenerate all unlocked slots
|
||||||
|
slots_needed = target_count - locked_count
|
||||||
|
|
||||||
|
if slots_needed < 0:
|
||||||
|
slots_needed = 0
|
||||||
|
# Too many locked, trim to target
|
||||||
|
locked_loras = locked_loras[:target_count]
|
||||||
|
locked_count = len(locked_loras)
|
||||||
|
|
||||||
|
# Filter out locked LoRAs from available pool
|
||||||
|
locked_names = {lora["name"] for lora in locked_loras}
|
||||||
|
available_pool = [
|
||||||
|
l for l in available_loras if l["file_name"] not in locked_names
|
||||||
|
]
|
||||||
|
|
||||||
|
# Ensure we don't try to select more than available
|
||||||
|
if slots_needed > len(available_pool):
|
||||||
|
slots_needed = len(available_pool)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"[LoraRandomizerNode] Selecting {slots_needed} new LoRAs from {len(available_pool)} available"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Random sample
|
||||||
|
selected = []
|
||||||
|
if slots_needed > 0:
|
||||||
|
selected = random.sample(available_pool, slots_needed)
|
||||||
|
|
||||||
|
# Generate random strengths for selected LoRAs
|
||||||
|
result_loras = []
|
||||||
|
for lora in selected:
|
||||||
|
model_str = round(random.uniform(model_strength_min, model_strength_max), 2)
|
||||||
|
|
||||||
|
if use_same_clip_strength:
|
||||||
|
clip_str = model_str
|
||||||
|
else:
|
||||||
|
clip_str = round(
|
||||||
|
random.uniform(clip_strength_min, clip_strength_max), 2
|
||||||
|
)
|
||||||
|
|
||||||
|
result_loras.append(
|
||||||
|
{
|
||||||
|
"name": lora["file_name"],
|
||||||
|
"strength": model_str,
|
||||||
|
"clipStrength": clip_str,
|
||||||
|
"active": True,
|
||||||
|
"expanded": abs(model_str - clip_str) > 0.001,
|
||||||
|
"locked": False,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Merge with locked LoRAs
|
||||||
|
result_loras.extend(locked_loras)
|
||||||
|
|
||||||
|
logger.info(f"[LoraRandomizerNode] Final LoRA count: {len(result_loras)}")
|
||||||
|
|
||||||
|
# Build LORA_STACK output
|
||||||
|
lora_stack = []
|
||||||
|
for lora in result_loras:
|
||||||
|
if not lora.get("active", False):
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Get file path
|
||||||
|
lora_path, trigger_words = get_lora_info(lora["name"])
|
||||||
|
if not lora_path:
|
||||||
|
logger.warning(
|
||||||
|
f"[LoraRandomizerNode] Could not find path for LoRA: {lora['name']}"
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Normalize path separators
|
||||||
|
lora_path = lora_path.replace("/", os.sep)
|
||||||
|
|
||||||
|
# Extract strengths
|
||||||
|
model_strength = lora.get("strength", 1.0)
|
||||||
|
clip_strength = lora.get("clipStrength", model_strength)
|
||||||
|
|
||||||
|
lora_stack.append((lora_path, model_strength, clip_strength))
|
||||||
|
|
||||||
|
# Return format: result for workflow + ui for frontend display
|
||||||
|
return {"result": (lora_stack,), "ui": {"loras": result_loras}}
|
||||||
|
|
||||||
|
async def _apply_pool_filters(self, available_loras, pool_config, scanner):
|
||||||
|
"""
|
||||||
|
Apply pool_config filters to available LoRAs.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
available_loras: List of all LoRA dicts
|
||||||
|
pool_config: Dict with filter settings from LoRA Pool node
|
||||||
|
scanner: Scanner instance for accessing filter utilities
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Filtered list of LoRA dicts
|
||||||
|
"""
|
||||||
|
from ..services.lora_service import LoraService
|
||||||
|
from ..services.model_query import FilterCriteria
|
||||||
|
|
||||||
|
# Create lora service instance for filtering
|
||||||
|
lora_service = LoraService(scanner)
|
||||||
|
|
||||||
|
# Extract filter parameters from pool_config
|
||||||
|
selected_base_models = pool_config.get("baseModels", [])
|
||||||
|
tags_dict = pool_config.get("tags", {})
|
||||||
|
include_tags = tags_dict.get("include", [])
|
||||||
|
exclude_tags = tags_dict.get("exclude", [])
|
||||||
|
folders_dict = pool_config.get("folders", {})
|
||||||
|
include_folders = folders_dict.get("include", [])
|
||||||
|
exclude_folders = folders_dict.get("exclude", [])
|
||||||
|
license_dict = pool_config.get("license", {})
|
||||||
|
no_credit_required = license_dict.get("noCreditRequired", False)
|
||||||
|
allow_selling = license_dict.get("allowSelling", False)
|
||||||
|
|
||||||
|
# Build tag filters dict
|
||||||
|
tag_filters = {}
|
||||||
|
for tag in include_tags:
|
||||||
|
tag_filters[tag] = "include"
|
||||||
|
for tag in exclude_tags:
|
||||||
|
tag_filters[tag] = "exclude"
|
||||||
|
|
||||||
|
# Build folder filter
|
||||||
|
# LoRA Pool uses include/exclude folders, we need to apply this logic
|
||||||
|
# For now, we'll filter based on folder path matching
|
||||||
|
if include_folders or exclude_folders:
|
||||||
|
filtered = []
|
||||||
|
for lora in available_loras:
|
||||||
|
folder = lora.get("folder", "")
|
||||||
|
|
||||||
|
# Check exclude folders first
|
||||||
|
excluded = False
|
||||||
|
for exclude_folder in exclude_folders:
|
||||||
|
if folder.startswith(exclude_folder):
|
||||||
|
excluded = True
|
||||||
|
break
|
||||||
|
|
||||||
|
if excluded:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Check include folders
|
||||||
|
if include_folders:
|
||||||
|
included = False
|
||||||
|
for include_folder in include_folders:
|
||||||
|
if folder.startswith(include_folder):
|
||||||
|
included = True
|
||||||
|
break
|
||||||
|
if not included:
|
||||||
|
continue
|
||||||
|
|
||||||
|
filtered.append(lora)
|
||||||
|
|
||||||
|
available_loras = filtered
|
||||||
|
|
||||||
|
# Apply base model filter
|
||||||
|
if selected_base_models:
|
||||||
|
available_loras = [
|
||||||
|
lora
|
||||||
|
for lora in available_loras
|
||||||
|
if lora.get("base_model") in selected_base_models
|
||||||
|
]
|
||||||
|
|
||||||
|
# Apply tag filters
|
||||||
|
if tag_filters:
|
||||||
|
criteria = FilterCriteria(tags=tag_filters)
|
||||||
|
available_loras = lora_service.filter_set.apply(available_loras, criteria)
|
||||||
|
|
||||||
|
# Apply license filters
|
||||||
|
if no_credit_required:
|
||||||
|
available_loras = [
|
||||||
|
lora
|
||||||
|
for lora in available_loras
|
||||||
|
if not lora.get("civitai", {}).get("allowNoCredit", True)
|
||||||
|
]
|
||||||
|
|
||||||
|
if allow_selling:
|
||||||
|
available_loras = [
|
||||||
|
lora
|
||||||
|
for lora in available_loras
|
||||||
|
if lora.get("civitai", {}).get("allowCommercialUse", ["None"])[0]
|
||||||
|
!= "None"
|
||||||
|
]
|
||||||
|
|
||||||
|
return available_loras
|
||||||
|
|
||||||
|
|
||||||
|
# Node class mappings for ComfyUI
|
||||||
|
NODE_CLASS_MAPPINGS = {"LoraRandomizerNode": LoraRandomizerNode}
|
||||||
|
|
||||||
|
# Display name mappings
|
||||||
|
NODE_DISPLAY_NAME_MAPPINGS = {"LoraRandomizerNode": "LoRA Randomizer"}
|
||||||
@@ -45,6 +45,9 @@ class LoraRoutes(BaseModelRoutes):
|
|||||||
registrar.add_prefixed_route('GET', '/api/lm/{prefix}/get-trigger-words', prefix, self.get_lora_trigger_words)
|
registrar.add_prefixed_route('GET', '/api/lm/{prefix}/get-trigger-words', prefix, self.get_lora_trigger_words)
|
||||||
registrar.add_prefixed_route('GET', '/api/lm/{prefix}/usage-tips-by-path', prefix, self.get_lora_usage_tips_by_path)
|
registrar.add_prefixed_route('GET', '/api/lm/{prefix}/usage-tips-by-path', prefix, self.get_lora_usage_tips_by_path)
|
||||||
|
|
||||||
|
# Randomizer routes
|
||||||
|
registrar.add_prefixed_route('POST', '/api/lm/{prefix}/random-sample', prefix, self.get_random_loras)
|
||||||
|
|
||||||
# ComfyUI integration
|
# ComfyUI integration
|
||||||
registrar.add_prefixed_route('POST', '/api/lm/{prefix}/get_trigger_words', prefix, self.get_trigger_words)
|
registrar.add_prefixed_route('POST', '/api/lm/{prefix}/get_trigger_words', prefix, self.get_trigger_words)
|
||||||
|
|
||||||
@@ -215,6 +218,74 @@ class LoraRoutes(BaseModelRoutes):
|
|||||||
'error': str(e)
|
'error': str(e)
|
||||||
}, status=500)
|
}, status=500)
|
||||||
|
|
||||||
|
async def get_random_loras(self, request: web.Request) -> web.Response:
|
||||||
|
"""Get random LoRAs based on filters and strength ranges"""
|
||||||
|
try:
|
||||||
|
json_data = await request.json()
|
||||||
|
|
||||||
|
# Parse parameters
|
||||||
|
count = json_data.get('count', 5)
|
||||||
|
count_min = json_data.get('count_min')
|
||||||
|
count_max = json_data.get('count_max')
|
||||||
|
model_strength_min = float(json_data.get('model_strength_min', 0.0))
|
||||||
|
model_strength_max = float(json_data.get('model_strength_max', 1.0))
|
||||||
|
use_same_clip_strength = json_data.get('use_same_clip_strength', True)
|
||||||
|
clip_strength_min = float(json_data.get('clip_strength_min', 0.0))
|
||||||
|
clip_strength_max = float(json_data.get('clip_strength_max', 1.0))
|
||||||
|
locked_loras = json_data.get('locked_loras', [])
|
||||||
|
pool_config = json_data.get('pool_config')
|
||||||
|
|
||||||
|
# Determine target count
|
||||||
|
if count_min is not None and count_max is not None:
|
||||||
|
import random
|
||||||
|
target_count = random.randint(count_min, count_max)
|
||||||
|
else:
|
||||||
|
target_count = count
|
||||||
|
|
||||||
|
# Validate parameters
|
||||||
|
if target_count < 1 or target_count > 100:
|
||||||
|
return web.json_response({
|
||||||
|
'success': False,
|
||||||
|
'error': 'Count must be between 1 and 100'
|
||||||
|
}, status=400)
|
||||||
|
|
||||||
|
if model_strength_min < 0 or model_strength_max > 10:
|
||||||
|
return web.json_response({
|
||||||
|
'success': False,
|
||||||
|
'error': 'Model strength must be between 0 and 10'
|
||||||
|
}, status=400)
|
||||||
|
|
||||||
|
# Get random LoRAs from service
|
||||||
|
result_loras = await self.service.get_random_loras(
|
||||||
|
count=target_count,
|
||||||
|
model_strength_min=model_strength_min,
|
||||||
|
model_strength_max=model_strength_max,
|
||||||
|
use_same_clip_strength=use_same_clip_strength,
|
||||||
|
clip_strength_min=clip_strength_min,
|
||||||
|
clip_strength_max=clip_strength_max,
|
||||||
|
locked_loras=locked_loras,
|
||||||
|
pool_config=pool_config
|
||||||
|
)
|
||||||
|
|
||||||
|
return web.json_response({
|
||||||
|
'success': True,
|
||||||
|
'loras': result_loras,
|
||||||
|
'count': len(result_loras)
|
||||||
|
})
|
||||||
|
|
||||||
|
except ValueError as e:
|
||||||
|
logger.error(f"Invalid parameter for random LoRAs: {e}")
|
||||||
|
return web.json_response({
|
||||||
|
'success': False,
|
||||||
|
'error': str(e)
|
||||||
|
}, status=400)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error getting random LoRAs: {e}", exc_info=True)
|
||||||
|
return web.json_response({
|
||||||
|
'success': False,
|
||||||
|
'error': str(e)
|
||||||
|
}, status=500)
|
||||||
|
|
||||||
async def get_trigger_words(self, request: web.Request) -> web.Response:
|
async def get_trigger_words(self, request: web.Request) -> web.Response:
|
||||||
"""Get trigger words for specified LoRA models"""
|
"""Get trigger words for specified LoRA models"""
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -178,7 +178,188 @@ class LoraService(BaseModelService):
|
|||||||
def find_duplicate_hashes(self) -> Dict:
|
def find_duplicate_hashes(self) -> Dict:
|
||||||
"""Find LoRAs with duplicate SHA256 hashes"""
|
"""Find LoRAs with duplicate SHA256 hashes"""
|
||||||
return self.scanner._hash_index.get_duplicate_hashes()
|
return self.scanner._hash_index.get_duplicate_hashes()
|
||||||
|
|
||||||
def find_duplicate_filenames(self) -> Dict:
|
def find_duplicate_filenames(self) -> Dict:
|
||||||
"""Find LoRAs with conflicting filenames"""
|
"""Find LoRAs with conflicting filenames"""
|
||||||
return self.scanner._hash_index.get_duplicate_filenames()
|
return self.scanner._hash_index.get_duplicate_filenames()
|
||||||
|
|
||||||
|
async def get_random_loras(
|
||||||
|
self,
|
||||||
|
count: int,
|
||||||
|
model_strength_min: float = 0.0,
|
||||||
|
model_strength_max: float = 1.0,
|
||||||
|
use_same_clip_strength: bool = True,
|
||||||
|
clip_strength_min: float = 0.0,
|
||||||
|
clip_strength_max: float = 1.0,
|
||||||
|
locked_loras: Optional[List[Dict]] = None,
|
||||||
|
pool_config: Optional[Dict] = None
|
||||||
|
) -> List[Dict]:
|
||||||
|
"""
|
||||||
|
Get random LoRAs with specified strength ranges.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
count: Number of LoRAs to select
|
||||||
|
model_strength_min: Minimum model strength
|
||||||
|
model_strength_max: Maximum model strength
|
||||||
|
use_same_clip_strength: Whether to use same strength for clip
|
||||||
|
clip_strength_min: Minimum clip strength
|
||||||
|
clip_strength_max: Maximum clip strength
|
||||||
|
locked_loras: List of locked LoRA dicts to preserve
|
||||||
|
pool_config: Optional pool config for filtering
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of LoRA dicts with randomized strengths
|
||||||
|
"""
|
||||||
|
import random
|
||||||
|
|
||||||
|
if locked_loras is None:
|
||||||
|
locked_loras = []
|
||||||
|
|
||||||
|
# Get available loras from cache
|
||||||
|
cache = await self.scanner.get_cached_data(force_refresh=False)
|
||||||
|
available_loras = cache.raw_data if cache else []
|
||||||
|
|
||||||
|
# Apply pool filters if provided
|
||||||
|
if pool_config:
|
||||||
|
available_loras = await self._apply_pool_filters(
|
||||||
|
available_loras, pool_config
|
||||||
|
)
|
||||||
|
|
||||||
|
# Calculate slots needed (total - locked)
|
||||||
|
locked_count = len(locked_loras)
|
||||||
|
slots_needed = count - locked_count
|
||||||
|
|
||||||
|
if slots_needed < 0:
|
||||||
|
slots_needed = 0
|
||||||
|
# Too many locked, trim to target
|
||||||
|
locked_loras = locked_loras[:count]
|
||||||
|
|
||||||
|
# Filter out locked LoRAs from available pool
|
||||||
|
locked_names = {lora['name'] for lora in locked_loras}
|
||||||
|
available_pool = [
|
||||||
|
l for l in available_loras
|
||||||
|
if l['model_name'] not in locked_names
|
||||||
|
]
|
||||||
|
|
||||||
|
# Ensure we don't try to select more than available
|
||||||
|
if slots_needed > len(available_pool):
|
||||||
|
slots_needed = len(available_pool)
|
||||||
|
|
||||||
|
# Random sample
|
||||||
|
selected = []
|
||||||
|
if slots_needed > 0:
|
||||||
|
selected = random.sample(available_pool, slots_needed)
|
||||||
|
|
||||||
|
# Generate random strengths for selected LoRAs
|
||||||
|
result_loras = []
|
||||||
|
for lora in selected:
|
||||||
|
model_str = round(
|
||||||
|
random.uniform(model_strength_min, model_strength_max), 2
|
||||||
|
)
|
||||||
|
|
||||||
|
if use_same_clip_strength:
|
||||||
|
clip_str = model_str
|
||||||
|
else:
|
||||||
|
clip_str = round(
|
||||||
|
random.uniform(clip_strength_min, clip_strength_max), 2
|
||||||
|
)
|
||||||
|
|
||||||
|
result_loras.append({
|
||||||
|
'name': lora['model_name'],
|
||||||
|
'strength': model_str,
|
||||||
|
'clipStrength': clip_str,
|
||||||
|
'active': True,
|
||||||
|
'expanded': abs(model_str - clip_str) > 0.001,
|
||||||
|
'locked': False
|
||||||
|
})
|
||||||
|
|
||||||
|
# Merge with locked LoRAs
|
||||||
|
result_loras.extend(locked_loras)
|
||||||
|
|
||||||
|
return result_loras
|
||||||
|
|
||||||
|
async def _apply_pool_filters(self, available_loras: List[Dict], pool_config: Dict) -> List[Dict]:
|
||||||
|
"""
|
||||||
|
Apply pool_config filters to available LoRAs.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
available_loras: List of all LoRA dicts
|
||||||
|
pool_config: Dict with filter settings from LoRA Pool node
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Filtered list of LoRA dicts
|
||||||
|
"""
|
||||||
|
from .model_query import FilterCriteria
|
||||||
|
|
||||||
|
# Extract filter parameters from pool_config
|
||||||
|
selected_base_models = pool_config.get('selected_base_models', [])
|
||||||
|
include_tags = pool_config.get('include_tags', [])
|
||||||
|
exclude_tags = pool_config.get('exclude_tags', [])
|
||||||
|
include_folders = pool_config.get('include_folders', [])
|
||||||
|
exclude_folders = pool_config.get('exclude_folders', [])
|
||||||
|
no_credit_required = pool_config.get('no_credit_required', False)
|
||||||
|
allow_selling = pool_config.get('allow_selling', False)
|
||||||
|
|
||||||
|
# Build tag filters dict
|
||||||
|
tag_filters = {}
|
||||||
|
for tag in include_tags:
|
||||||
|
tag_filters[tag] = 'include'
|
||||||
|
for tag in exclude_tags:
|
||||||
|
tag_filters[tag] = 'exclude'
|
||||||
|
|
||||||
|
# Build folder filter
|
||||||
|
if include_folders or exclude_folders:
|
||||||
|
filtered = []
|
||||||
|
for lora in available_loras:
|
||||||
|
folder = lora.get('folder', '')
|
||||||
|
|
||||||
|
# Check exclude folders first
|
||||||
|
excluded = False
|
||||||
|
for exclude_folder in exclude_folders:
|
||||||
|
if folder.startswith(exclude_folder):
|
||||||
|
excluded = True
|
||||||
|
break
|
||||||
|
|
||||||
|
if excluded:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Check include folders
|
||||||
|
if include_folders:
|
||||||
|
included = False
|
||||||
|
for include_folder in include_folders:
|
||||||
|
if folder.startswith(include_folder):
|
||||||
|
included = True
|
||||||
|
break
|
||||||
|
if not included:
|
||||||
|
continue
|
||||||
|
|
||||||
|
filtered.append(lora)
|
||||||
|
|
||||||
|
available_loras = filtered
|
||||||
|
|
||||||
|
# Apply base model filter
|
||||||
|
if selected_base_models:
|
||||||
|
available_loras = [
|
||||||
|
lora for lora in available_loras
|
||||||
|
if lora.get('base_model') in selected_base_models
|
||||||
|
]
|
||||||
|
|
||||||
|
# Apply tag filters
|
||||||
|
if tag_filters:
|
||||||
|
criteria = FilterCriteria(tags=tag_filters)
|
||||||
|
available_loras = self.filter_set.apply(available_loras, criteria)
|
||||||
|
|
||||||
|
# Apply license filters
|
||||||
|
if no_credit_required:
|
||||||
|
available_loras = [
|
||||||
|
lora for lora in available_loras
|
||||||
|
if not lora.get('civitai', {}).get('allowNoCredit', True)
|
||||||
|
]
|
||||||
|
|
||||||
|
if allow_selling:
|
||||||
|
available_loras = [
|
||||||
|
lora for lora in available_loras
|
||||||
|
if lora.get('civitai', {}).get('allowCommercialUse', ['None'])[0] != 'None'
|
||||||
|
]
|
||||||
|
|
||||||
|
return available_loras
|
||||||
|
|||||||
191
tests/routes/test_randomizer_endpoints.py
Normal file
191
tests/routes/test_randomizer_endpoints.py
Normal file
@@ -0,0 +1,191 @@
|
|||||||
|
import json
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from py.routes.lora_routes import LoraRoutes
|
||||||
|
|
||||||
|
|
||||||
|
class DummyRequest:
|
||||||
|
def __init__(self, *, query=None, match_info=None, json_data=None):
|
||||||
|
self.query = query or {}
|
||||||
|
self.match_info = match_info or {}
|
||||||
|
self._json_data = json_data or {}
|
||||||
|
|
||||||
|
async def json(self):
|
||||||
|
return self._json_data
|
||||||
|
|
||||||
|
|
||||||
|
class StubLoraService:
|
||||||
|
"""Stub service for testing randomizer endpoints"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.random_loras = []
|
||||||
|
|
||||||
|
async def get_random_loras(self, **kwargs):
|
||||||
|
return self.random_loras
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def routes():
|
||||||
|
handler = LoraRoutes()
|
||||||
|
handler.service = StubLoraService()
|
||||||
|
return handler
|
||||||
|
|
||||||
|
|
||||||
|
async def test_get_random_loras_success(routes):
|
||||||
|
"""Test successful random LoRA generation"""
|
||||||
|
routes.service.random_loras = [
|
||||||
|
{
|
||||||
|
'name': 'test_lora_1',
|
||||||
|
'strength': 0.8,
|
||||||
|
'clipStrength': 0.8,
|
||||||
|
'active': True,
|
||||||
|
'expanded': False,
|
||||||
|
'locked': False
|
||||||
|
},
|
||||||
|
{
|
||||||
|
'name': 'test_lora_2',
|
||||||
|
'strength': 0.6,
|
||||||
|
'clipStrength': 0.6,
|
||||||
|
'active': True,
|
||||||
|
'expanded': False,
|
||||||
|
'locked': False
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
request = DummyRequest(json_data={
|
||||||
|
'count': 5,
|
||||||
|
'model_strength_min': 0.5,
|
||||||
|
'model_strength_max': 1.0,
|
||||||
|
'use_same_clip_strength': True,
|
||||||
|
'locked_loras': []
|
||||||
|
})
|
||||||
|
|
||||||
|
response = await routes.get_random_loras(request)
|
||||||
|
payload = json.loads(response.text)
|
||||||
|
|
||||||
|
assert response.status == 200
|
||||||
|
assert payload['success'] is True
|
||||||
|
assert 'loras' in payload
|
||||||
|
assert payload['count'] == 2
|
||||||
|
|
||||||
|
|
||||||
|
async def test_get_random_loras_with_range(routes):
|
||||||
|
"""Test random LoRAs with count range"""
|
||||||
|
routes.service.random_loras = [
|
||||||
|
{
|
||||||
|
'name': 'test_lora_1',
|
||||||
|
'strength': 0.8,
|
||||||
|
'clipStrength': 0.8,
|
||||||
|
'active': True,
|
||||||
|
'expanded': False,
|
||||||
|
'locked': False
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
request = DummyRequest(json_data={
|
||||||
|
'count_min': 3,
|
||||||
|
'count_max': 7,
|
||||||
|
'model_strength_min': 0.0,
|
||||||
|
'model_strength_max': 1.0,
|
||||||
|
'use_same_clip_strength': True
|
||||||
|
})
|
||||||
|
|
||||||
|
response = await routes.get_random_loras(request)
|
||||||
|
payload = json.loads(response.text)
|
||||||
|
|
||||||
|
assert response.status == 200
|
||||||
|
assert payload['success'] is True
|
||||||
|
|
||||||
|
|
||||||
|
async def test_get_random_loras_invalid_count(routes):
|
||||||
|
"""Test invalid count parameter"""
|
||||||
|
request = DummyRequest(json_data={
|
||||||
|
'count': 150, # Over limit
|
||||||
|
'model_strength_min': 0.0,
|
||||||
|
'model_strength_max': 1.0
|
||||||
|
})
|
||||||
|
|
||||||
|
response = await routes.get_random_loras(request)
|
||||||
|
payload = json.loads(response.text)
|
||||||
|
|
||||||
|
assert response.status == 400
|
||||||
|
assert payload['success'] is False
|
||||||
|
assert 'Count must be between 1 and 100' in payload['error']
|
||||||
|
|
||||||
|
|
||||||
|
async def test_get_random_loras_invalid_strength(routes):
|
||||||
|
"""Test invalid strength range"""
|
||||||
|
request = DummyRequest(json_data={
|
||||||
|
'count': 5,
|
||||||
|
'model_strength_min': -0.5, # Invalid
|
||||||
|
'model_strength_max': 1.0
|
||||||
|
})
|
||||||
|
|
||||||
|
response = await routes.get_random_loras(request)
|
||||||
|
payload = json.loads(response.text)
|
||||||
|
|
||||||
|
assert response.status == 400
|
||||||
|
assert payload['success'] is False
|
||||||
|
|
||||||
|
|
||||||
|
async def test_get_random_loras_with_locked(routes):
|
||||||
|
"""Test random LoRAs with locked items"""
|
||||||
|
routes.service.random_loras = [
|
||||||
|
{
|
||||||
|
'name': 'new_lora',
|
||||||
|
'strength': 0.7,
|
||||||
|
'clipStrength': 0.7,
|
||||||
|
'active': True,
|
||||||
|
'expanded': False,
|
||||||
|
'locked': False
|
||||||
|
},
|
||||||
|
{
|
||||||
|
'name': 'locked_lora',
|
||||||
|
'strength': 0.9,
|
||||||
|
'clipStrength': 0.9,
|
||||||
|
'active': True,
|
||||||
|
'expanded': False,
|
||||||
|
'locked': True
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
request = DummyRequest(json_data={
|
||||||
|
'count': 5,
|
||||||
|
'model_strength_min': 0.5,
|
||||||
|
'model_strength_max': 1.0,
|
||||||
|
'use_same_clip_strength': True,
|
||||||
|
'locked_loras': [
|
||||||
|
{
|
||||||
|
'name': 'locked_lora',
|
||||||
|
'strength': 0.9,
|
||||||
|
'clipStrength': 0.9,
|
||||||
|
'active': True,
|
||||||
|
'expanded': False,
|
||||||
|
'locked': True
|
||||||
|
}
|
||||||
|
]
|
||||||
|
})
|
||||||
|
|
||||||
|
response = await routes.get_random_loras(request)
|
||||||
|
payload = json.loads(response.text)
|
||||||
|
|
||||||
|
assert response.status == 200
|
||||||
|
assert payload['success'] is True
|
||||||
|
|
||||||
|
|
||||||
|
async def test_get_random_loras_error(routes, monkeypatch):
|
||||||
|
"""Test error handling"""
|
||||||
|
async def failing(*_args, **_kwargs):
|
||||||
|
raise RuntimeError("Service error")
|
||||||
|
|
||||||
|
routes.service.get_random_loras = failing
|
||||||
|
request = DummyRequest(json_data={'count': 5})
|
||||||
|
|
||||||
|
response = await routes.get_random_loras(request)
|
||||||
|
payload = json.loads(response.text)
|
||||||
|
|
||||||
|
assert response.status == 500
|
||||||
|
assert payload['success'] is False
|
||||||
|
assert 'error' in payload
|
||||||
110
vue-widgets/src/components/LoraRandomizerWidget.vue
Normal file
110
vue-widgets/src/components/LoraRandomizerWidget.vue
Normal file
@@ -0,0 +1,110 @@
|
|||||||
|
<template>
|
||||||
|
<div class="lora-randomizer-widget">
|
||||||
|
<LoraRandomizerSettingsView
|
||||||
|
:count-mode="state.countMode.value"
|
||||||
|
:count-fixed="state.countFixed.value"
|
||||||
|
:count-min="state.countMin.value"
|
||||||
|
:count-max="state.countMax.value"
|
||||||
|
:model-strength-min="state.modelStrengthMin.value"
|
||||||
|
:model-strength-max="state.modelStrengthMax.value"
|
||||||
|
:use-same-clip-strength="state.useSameClipStrength.value"
|
||||||
|
:clip-strength-min="state.clipStrengthMin.value"
|
||||||
|
:clip-strength-max="state.clipStrengthMax.value"
|
||||||
|
:roll-mode="state.rollMode.value"
|
||||||
|
:is-rolling="state.isRolling.value"
|
||||||
|
:is-clip-strength-disabled="state.isClipStrengthDisabled.value"
|
||||||
|
@update:count-mode="state.countMode.value = $event"
|
||||||
|
@update:count-fixed="state.countFixed.value = $event"
|
||||||
|
@update:count-min="state.countMin.value = $event"
|
||||||
|
@update:count-max="state.countMax.value = $event"
|
||||||
|
@update:model-strength-min="state.modelStrengthMin.value = $event"
|
||||||
|
@update:model-strength-max="state.modelStrengthMax.value = $event"
|
||||||
|
@update:use-same-clip-strength="state.useSameClipStrength.value = $event"
|
||||||
|
@update:clip-strength-min="state.clipStrengthMin.value = $event"
|
||||||
|
@update:clip-strength-max="state.clipStrengthMax.value = $event"
|
||||||
|
@update:roll-mode="state.rollMode.value = $event"
|
||||||
|
@roll="handleRoll"
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
</template>
|
||||||
|
|
||||||
|
<script setup lang="ts">
|
||||||
|
import { onMounted } from 'vue'
|
||||||
|
import LoraRandomizerSettingsView from './lora-randomizer/LoraRandomizerSettingsView.vue'
|
||||||
|
import { useLoraRandomizerState } from '../composables/useLoraRandomizerState'
|
||||||
|
import type { ComponentWidget, RandomizerConfig } from '../composables/types'
|
||||||
|
|
||||||
|
// Props
|
||||||
|
const props = defineProps<{
|
||||||
|
widget: ComponentWidget
|
||||||
|
node: { id: number }
|
||||||
|
}>()
|
||||||
|
|
||||||
|
// State management
|
||||||
|
const state = useLoraRandomizerState(props.widget)
|
||||||
|
|
||||||
|
// Handle roll button click
|
||||||
|
const handleRoll = async () => {
|
||||||
|
try {
|
||||||
|
console.log('[LoraRandomizerWidget] Roll button clicked')
|
||||||
|
|
||||||
|
// Get pool config from connected input (if any)
|
||||||
|
// This would need to be passed from the node's pool_config input
|
||||||
|
const poolConfig = null // TODO: Get from node input if connected
|
||||||
|
|
||||||
|
// Get locked loras from the loras widget
|
||||||
|
// This would need to be retrieved from the loras widget on the node
|
||||||
|
const lockedLoras: any[] = [] // TODO: Get from loras widget
|
||||||
|
|
||||||
|
// Call API to get random loras
|
||||||
|
const randomLoras = await state.rollLoras(poolConfig, lockedLoras)
|
||||||
|
|
||||||
|
console.log('[LoraRandomizerWidget] Got random LoRAs:', randomLoras)
|
||||||
|
|
||||||
|
// Update the loras widget with the new selection
|
||||||
|
// This will be handled by emitting an event or directly updating the loras widget
|
||||||
|
// For now, we'll emit a custom event that the parent widget handler can catch
|
||||||
|
if (typeof (props.widget as any).onRoll === 'function') {
|
||||||
|
(props.widget as any).onRoll(randomLoras)
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
console.error('[LoraRandomizerWidget] Error rolling LoRAs:', error)
|
||||||
|
alert('Failed to roll LoRAs: ' + (error as Error).message)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Lifecycle
|
||||||
|
onMounted(async () => {
|
||||||
|
// Setup serialization
|
||||||
|
props.widget.serializeValue = async () => {
|
||||||
|
const config = state.buildConfig()
|
||||||
|
console.log('[LoraRandomizerWidget] Serializing config:', config)
|
||||||
|
return config
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle external value updates (e.g., loading workflow, paste)
|
||||||
|
props.widget.onSetValue = (v) => {
|
||||||
|
console.log('[LoraRandomizerWidget] Restoring from config:', v)
|
||||||
|
state.restoreFromConfig(v as RandomizerConfig)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Restore from saved value
|
||||||
|
if (props.widget.value) {
|
||||||
|
console.log('[LoraRandomizerWidget] Restoring from saved value:', props.widget.value)
|
||||||
|
state.restoreFromConfig(props.widget.value as RandomizerConfig)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
</script>
|
||||||
|
|
||||||
|
<style scoped>
|
||||||
|
.lora-randomizer-widget {
|
||||||
|
padding: 12px;
|
||||||
|
background: rgba(40, 44, 52, 0.6);
|
||||||
|
border-radius: 4px;
|
||||||
|
height: 100%;
|
||||||
|
display: flex;
|
||||||
|
flex-direction: column;
|
||||||
|
overflow: hidden;
|
||||||
|
box-sizing: border-box;
|
||||||
|
}
|
||||||
|
</style>
|
||||||
@@ -0,0 +1,359 @@
|
|||||||
|
<template>
|
||||||
|
<div class="randomizer-settings">
|
||||||
|
<div class="settings-header">
|
||||||
|
<h3 class="settings-title">RANDOMIZER SETTINGS</h3>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<!-- LoRA Count -->
|
||||||
|
<div class="setting-section">
|
||||||
|
<label class="setting-label">LoRA Count</label>
|
||||||
|
<div class="count-mode-selector">
|
||||||
|
<label class="radio-label">
|
||||||
|
<input
|
||||||
|
type="radio"
|
||||||
|
name="count-mode"
|
||||||
|
value="fixed"
|
||||||
|
:checked="countMode === 'fixed'"
|
||||||
|
@change="$emit('update:countMode', 'fixed')"
|
||||||
|
/>
|
||||||
|
<span>Fixed:</span>
|
||||||
|
<input
|
||||||
|
type="number"
|
||||||
|
class="number-input"
|
||||||
|
:value="countFixed"
|
||||||
|
:disabled="countMode !== 'fixed'"
|
||||||
|
min="1"
|
||||||
|
max="100"
|
||||||
|
@input="$emit('update:countFixed', parseInt(($event.target as HTMLInputElement).value))"
|
||||||
|
/>
|
||||||
|
</label>
|
||||||
|
</div>
|
||||||
|
<div class="count-mode-selector">
|
||||||
|
<label class="radio-label">
|
||||||
|
<input
|
||||||
|
type="radio"
|
||||||
|
name="count-mode"
|
||||||
|
value="range"
|
||||||
|
:checked="countMode === 'range'"
|
||||||
|
@change="$emit('update:countMode', 'range')"
|
||||||
|
/>
|
||||||
|
<span>Range:</span>
|
||||||
|
<input
|
||||||
|
type="number"
|
||||||
|
class="number-input"
|
||||||
|
:value="countMin"
|
||||||
|
:disabled="countMode !== 'range'"
|
||||||
|
min="1"
|
||||||
|
max="100"
|
||||||
|
@input="$emit('update:countMin', parseInt(($event.target as HTMLInputElement).value))"
|
||||||
|
/>
|
||||||
|
<span>to</span>
|
||||||
|
<input
|
||||||
|
type="number"
|
||||||
|
class="number-input"
|
||||||
|
:value="countMax"
|
||||||
|
:disabled="countMode !== 'range'"
|
||||||
|
min="1"
|
||||||
|
max="100"
|
||||||
|
@input="$emit('update:countMax', parseInt(($event.target as HTMLInputElement).value))"
|
||||||
|
/>
|
||||||
|
</label>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<!-- Model Strength Range -->
|
||||||
|
<div class="setting-section">
|
||||||
|
<label class="setting-label">Model Strength Range</label>
|
||||||
|
<div class="strength-inputs">
|
||||||
|
<div class="strength-input-group">
|
||||||
|
<label>Min:</label>
|
||||||
|
<input
|
||||||
|
type="number"
|
||||||
|
class="number-input"
|
||||||
|
:value="modelStrengthMin"
|
||||||
|
min="0"
|
||||||
|
max="10"
|
||||||
|
step="0.1"
|
||||||
|
@input="$emit('update:modelStrengthMin', parseFloat(($event.target as HTMLInputElement).value))"
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
<div class="strength-input-group">
|
||||||
|
<label>Max:</label>
|
||||||
|
<input
|
||||||
|
type="number"
|
||||||
|
class="number-input"
|
||||||
|
:value="modelStrengthMax"
|
||||||
|
min="0"
|
||||||
|
max="10"
|
||||||
|
step="0.1"
|
||||||
|
@input="$emit('update:modelStrengthMax', parseFloat(($event.target as HTMLInputElement).value))"
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<!-- Clip Strength Range -->
|
||||||
|
<div class="setting-section">
|
||||||
|
<label class="setting-label">Clip Strength Range</label>
|
||||||
|
<div class="checkbox-group">
|
||||||
|
<label class="checkbox-label">
|
||||||
|
<input
|
||||||
|
type="checkbox"
|
||||||
|
:checked="useSameClipStrength"
|
||||||
|
@change="$emit('update:useSameClipStrength', ($event.target as HTMLInputElement).checked)"
|
||||||
|
/>
|
||||||
|
<span>Same as model</span>
|
||||||
|
</label>
|
||||||
|
</div>
|
||||||
|
<div class="strength-inputs" :class="{ disabled: isClipStrengthDisabled }">
|
||||||
|
<div class="strength-input-group">
|
||||||
|
<label>Min:</label>
|
||||||
|
<input
|
||||||
|
type="number"
|
||||||
|
class="number-input"
|
||||||
|
:value="clipStrengthMin"
|
||||||
|
:disabled="isClipStrengthDisabled"
|
||||||
|
min="0"
|
||||||
|
max="10"
|
||||||
|
step="0.1"
|
||||||
|
@input="$emit('update:clipStrengthMin', parseFloat(($event.target as HTMLInputElement).value))"
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
<div class="strength-input-group">
|
||||||
|
<label>Max:</label>
|
||||||
|
<input
|
||||||
|
type="number"
|
||||||
|
class="number-input"
|
||||||
|
:value="clipStrengthMax"
|
||||||
|
:disabled="isClipStrengthDisabled"
|
||||||
|
min="0"
|
||||||
|
max="10"
|
||||||
|
step="0.1"
|
||||||
|
@input="$emit('update:clipStrengthMax', parseFloat(($event.target as HTMLInputElement).value))"
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<!-- Roll Mode -->
|
||||||
|
<div class="setting-section">
|
||||||
|
<label class="setting-label">Roll Mode</label>
|
||||||
|
<div class="roll-mode-selector">
|
||||||
|
<label class="radio-label">
|
||||||
|
<input
|
||||||
|
type="radio"
|
||||||
|
name="roll-mode"
|
||||||
|
value="frontend"
|
||||||
|
:checked="rollMode === 'frontend'"
|
||||||
|
@change="$emit('update:rollMode', 'frontend')"
|
||||||
|
/>
|
||||||
|
<span>Frontend Roll (fixed until re-rolled)</span>
|
||||||
|
</label>
|
||||||
|
<button
|
||||||
|
class="roll-button"
|
||||||
|
:disabled="rollMode !== 'frontend' || isRolling"
|
||||||
|
@click="$emit('roll')"
|
||||||
|
>
|
||||||
|
<span v-if="!isRolling">🎲 Roll</span>
|
||||||
|
<span v-else>Rolling...</span>
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
<div class="roll-mode-selector">
|
||||||
|
<label class="radio-label">
|
||||||
|
<input
|
||||||
|
type="radio"
|
||||||
|
name="roll-mode"
|
||||||
|
value="backend"
|
||||||
|
:checked="rollMode === 'backend'"
|
||||||
|
@change="$emit('update:rollMode', 'backend')"
|
||||||
|
/>
|
||||||
|
<span>Backend Roll (randomizes each execution)</span>
|
||||||
|
</label>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</template>
|
||||||
|
|
||||||
|
<script setup lang="ts">
|
||||||
|
defineProps<{
|
||||||
|
countMode: 'fixed' | 'range'
|
||||||
|
countFixed: number
|
||||||
|
countMin: number
|
||||||
|
countMax: number
|
||||||
|
modelStrengthMin: number
|
||||||
|
modelStrengthMax: number
|
||||||
|
useSameClipStrength: boolean
|
||||||
|
clipStrengthMin: number
|
||||||
|
clipStrengthMax: number
|
||||||
|
rollMode: 'frontend' | 'backend'
|
||||||
|
isRolling: boolean
|
||||||
|
isClipStrengthDisabled: boolean
|
||||||
|
}>()
|
||||||
|
|
||||||
|
defineEmits<{
|
||||||
|
'update:countMode': [value: 'fixed' | 'range']
|
||||||
|
'update:countFixed': [value: number]
|
||||||
|
'update:countMin': [value: number]
|
||||||
|
'update:countMax': [value: number]
|
||||||
|
'update:modelStrengthMin': [value: number]
|
||||||
|
'update:modelStrengthMax': [value: number]
|
||||||
|
'update:useSameClipStrength': [value: boolean]
|
||||||
|
'update:clipStrengthMin': [value: number]
|
||||||
|
'update:clipStrengthMax': [value: number]
|
||||||
|
'update:rollMode': [value: 'frontend' | 'backend']
|
||||||
|
roll: []
|
||||||
|
}>()
|
||||||
|
</script>
|
||||||
|
|
||||||
|
<style scoped>
|
||||||
|
.randomizer-settings {
|
||||||
|
display: flex;
|
||||||
|
flex-direction: column;
|
||||||
|
gap: 16px;
|
||||||
|
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
|
||||||
|
color: #e4e4e7;
|
||||||
|
}
|
||||||
|
|
||||||
|
.settings-header {
|
||||||
|
display: flex;
|
||||||
|
align-items: center;
|
||||||
|
justify-content: space-between;
|
||||||
|
margin-bottom: 8px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.settings-title {
|
||||||
|
font-size: 11px;
|
||||||
|
font-weight: 600;
|
||||||
|
letter-spacing: 0.05em;
|
||||||
|
color: #a1a1aa;
|
||||||
|
margin: 0;
|
||||||
|
text-transform: uppercase;
|
||||||
|
}
|
||||||
|
|
||||||
|
.setting-section {
|
||||||
|
display: flex;
|
||||||
|
flex-direction: column;
|
||||||
|
gap: 8px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.setting-label {
|
||||||
|
font-size: 12px;
|
||||||
|
font-weight: 500;
|
||||||
|
color: #d4d4d8;
|
||||||
|
}
|
||||||
|
|
||||||
|
.count-mode-selector,
|
||||||
|
.roll-mode-selector {
|
||||||
|
display: flex;
|
||||||
|
align-items: center;
|
||||||
|
gap: 8px;
|
||||||
|
padding: 6px 8px;
|
||||||
|
background: rgba(30, 30, 36, 0.5);
|
||||||
|
border-radius: 4px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.radio-label {
|
||||||
|
display: flex;
|
||||||
|
align-items: center;
|
||||||
|
gap: 8px;
|
||||||
|
font-size: 13px;
|
||||||
|
color: #e4e4e7;
|
||||||
|
cursor: pointer;
|
||||||
|
flex: 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
.radio-label input[type='radio'] {
|
||||||
|
cursor: pointer;
|
||||||
|
}
|
||||||
|
|
||||||
|
.radio-label input[type='radio']:disabled {
|
||||||
|
cursor: not-allowed;
|
||||||
|
}
|
||||||
|
|
||||||
|
.number-input {
|
||||||
|
width: 60px;
|
||||||
|
padding: 4px 8px;
|
||||||
|
background: rgba(20, 20, 24, 0.6);
|
||||||
|
border: 1px solid rgba(255, 255, 255, 0.1);
|
||||||
|
border-radius: 3px;
|
||||||
|
color: #e4e4e7;
|
||||||
|
font-size: 13px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.number-input:disabled {
|
||||||
|
opacity: 0.5;
|
||||||
|
cursor: not-allowed;
|
||||||
|
}
|
||||||
|
|
||||||
|
.strength-inputs {
|
||||||
|
display: flex;
|
||||||
|
gap: 12px;
|
||||||
|
padding: 6px 8px;
|
||||||
|
background: rgba(30, 30, 36, 0.5);
|
||||||
|
border-radius: 4px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.strength-inputs.disabled {
|
||||||
|
opacity: 0.5;
|
||||||
|
}
|
||||||
|
|
||||||
|
.strength-input-group {
|
||||||
|
display: flex;
|
||||||
|
align-items: center;
|
||||||
|
gap: 6px;
|
||||||
|
flex: 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
.strength-input-group label {
|
||||||
|
font-size: 12px;
|
||||||
|
color: #d4d4d8;
|
||||||
|
}
|
||||||
|
|
||||||
|
.checkbox-group {
|
||||||
|
padding: 6px 8px;
|
||||||
|
background: rgba(30, 30, 36, 0.5);
|
||||||
|
border-radius: 4px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.checkbox-label {
|
||||||
|
display: flex;
|
||||||
|
align-items: center;
|
||||||
|
gap: 8px;
|
||||||
|
font-size: 13px;
|
||||||
|
color: #e4e4e7;
|
||||||
|
cursor: pointer;
|
||||||
|
}
|
||||||
|
|
||||||
|
.checkbox-label input[type='checkbox'] {
|
||||||
|
cursor: pointer;
|
||||||
|
}
|
||||||
|
|
||||||
|
.roll-button {
|
||||||
|
padding: 6px 16px;
|
||||||
|
background: linear-gradient(135deg, #3b82f6 0%, #2563eb 100%);
|
||||||
|
border: none;
|
||||||
|
border-radius: 4px;
|
||||||
|
color: white;
|
||||||
|
font-size: 13px;
|
||||||
|
font-weight: 500;
|
||||||
|
cursor: pointer;
|
||||||
|
transition: all 0.2s;
|
||||||
|
white-space: nowrap;
|
||||||
|
}
|
||||||
|
|
||||||
|
.roll-button:hover:not(:disabled) {
|
||||||
|
background: linear-gradient(135deg, #2563eb 0%, #1d4ed8 100%);
|
||||||
|
transform: translateY(-1px);
|
||||||
|
box-shadow: 0 2px 8px rgba(59, 130, 246, 0.3);
|
||||||
|
}
|
||||||
|
|
||||||
|
.roll-button:active:not(:disabled) {
|
||||||
|
transform: translateY(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
.roll-button:disabled {
|
||||||
|
opacity: 0.5;
|
||||||
|
cursor: not-allowed;
|
||||||
|
background: linear-gradient(135deg, #52525b 0%, #3f3f46 100%);
|
||||||
|
}
|
||||||
|
</style>
|
||||||
@@ -37,13 +37,6 @@ export interface FolderTreeNode {
|
|||||||
children?: FolderTreeNode[]
|
children?: FolderTreeNode[]
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface ComponentWidget {
|
|
||||||
serializeValue?: () => Promise<LoraPoolConfig>
|
|
||||||
value?: LoraPoolConfig | LegacyLoraPoolConfig
|
|
||||||
onSetValue?: (v: LoraPoolConfig | LegacyLoraPoolConfig) => void
|
|
||||||
updateConfig?: (v: LoraPoolConfig) => void
|
|
||||||
}
|
|
||||||
|
|
||||||
// Legacy config for migration (v1)
|
// Legacy config for migration (v1)
|
||||||
export interface LegacyLoraPoolConfig {
|
export interface LegacyLoraPoolConfig {
|
||||||
version: 1
|
version: 1
|
||||||
@@ -59,3 +52,33 @@ export interface LegacyLoraPoolConfig {
|
|||||||
}
|
}
|
||||||
preview: { matchCount: number; lastUpdated: number }
|
preview: { matchCount: number; lastUpdated: number }
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Randomizer config
|
||||||
|
export interface RandomizerConfig {
|
||||||
|
count_mode: 'fixed' | 'range'
|
||||||
|
count_fixed: number
|
||||||
|
count_min: number
|
||||||
|
count_max: number
|
||||||
|
model_strength_min: number
|
||||||
|
model_strength_max: number
|
||||||
|
use_same_clip_strength: boolean
|
||||||
|
clip_strength_min: number
|
||||||
|
clip_strength_max: number
|
||||||
|
roll_mode: 'frontend' | 'backend'
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface LoraEntry {
|
||||||
|
name: string
|
||||||
|
strength: number
|
||||||
|
clipStrength: number
|
||||||
|
active: boolean
|
||||||
|
expanded: boolean
|
||||||
|
locked: boolean
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface ComponentWidget {
|
||||||
|
serializeValue?: () => Promise<LoraPoolConfig | RandomizerConfig>
|
||||||
|
value?: LoraPoolConfig | LegacyLoraPoolConfig | RandomizerConfig
|
||||||
|
onSetValue?: (v: LoraPoolConfig | LegacyLoraPoolConfig | RandomizerConfig) => void
|
||||||
|
updateConfig?: (v: LoraPoolConfig | RandomizerConfig) => void
|
||||||
|
}
|
||||||
|
|||||||
142
vue-widgets/src/composables/useLoraRandomizerState.ts
Normal file
142
vue-widgets/src/composables/useLoraRandomizerState.ts
Normal file
@@ -0,0 +1,142 @@
|
|||||||
|
import { ref, computed } from 'vue'
|
||||||
|
import type { ComponentWidget, RandomizerConfig, LoraEntry } from './types'
|
||||||
|
|
||||||
|
export function useLoraRandomizerState(widget: ComponentWidget) {
|
||||||
|
// State refs
|
||||||
|
const countMode = ref<'fixed' | 'range'>('range')
|
||||||
|
const countFixed = ref(5)
|
||||||
|
const countMin = ref(3)
|
||||||
|
const countMax = ref(7)
|
||||||
|
const modelStrengthMin = ref(0.0)
|
||||||
|
const modelStrengthMax = ref(1.0)
|
||||||
|
const useSameClipStrength = ref(true)
|
||||||
|
const clipStrengthMin = ref(0.0)
|
||||||
|
const clipStrengthMax = ref(1.0)
|
||||||
|
const rollMode = ref<'frontend' | 'backend'>('frontend')
|
||||||
|
const isRolling = ref(false)
|
||||||
|
|
||||||
|
// Build config object from current state
|
||||||
|
const buildConfig = (): RandomizerConfig => ({
|
||||||
|
count_mode: countMode.value,
|
||||||
|
count_fixed: countFixed.value,
|
||||||
|
count_min: countMin.value,
|
||||||
|
count_max: countMax.value,
|
||||||
|
model_strength_min: modelStrengthMin.value,
|
||||||
|
model_strength_max: modelStrengthMax.value,
|
||||||
|
use_same_clip_strength: useSameClipStrength.value,
|
||||||
|
clip_strength_min: clipStrengthMin.value,
|
||||||
|
clip_strength_max: clipStrengthMax.value,
|
||||||
|
roll_mode: rollMode.value,
|
||||||
|
})
|
||||||
|
|
||||||
|
// Restore state from config object
|
||||||
|
const restoreFromConfig = (config: RandomizerConfig) => {
|
||||||
|
countMode.value = config.count_mode || 'range'
|
||||||
|
countFixed.value = config.count_fixed || 5
|
||||||
|
countMin.value = config.count_min || 3
|
||||||
|
countMax.value = config.count_max || 7
|
||||||
|
modelStrengthMin.value = config.model_strength_min ?? 0.0
|
||||||
|
modelStrengthMax.value = config.model_strength_max ?? 1.0
|
||||||
|
useSameClipStrength.value = config.use_same_clip_strength ?? true
|
||||||
|
clipStrengthMin.value = config.clip_strength_min ?? 0.0
|
||||||
|
clipStrengthMax.value = config.clip_strength_max ?? 1.0
|
||||||
|
rollMode.value = config.roll_mode || 'frontend'
|
||||||
|
}
|
||||||
|
|
||||||
|
// Roll loras - call API to get random selection
|
||||||
|
const rollLoras = async (
|
||||||
|
poolConfig: any | null,
|
||||||
|
lockedLoras: LoraEntry[]
|
||||||
|
): Promise<LoraEntry[]> => {
|
||||||
|
try {
|
||||||
|
isRolling.value = true
|
||||||
|
|
||||||
|
const config = buildConfig()
|
||||||
|
|
||||||
|
// Build request body
|
||||||
|
const requestBody: any = {
|
||||||
|
model_strength_min: config.model_strength_min,
|
||||||
|
model_strength_max: config.model_strength_max,
|
||||||
|
use_same_clip_strength: config.use_same_clip_strength,
|
||||||
|
clip_strength_min: config.clip_strength_min,
|
||||||
|
clip_strength_max: config.clip_strength_max,
|
||||||
|
locked_loras: lockedLoras,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add count parameters
|
||||||
|
if (config.count_mode === 'fixed') {
|
||||||
|
requestBody.count = config.count_fixed
|
||||||
|
} else {
|
||||||
|
requestBody.count_min = config.count_min
|
||||||
|
requestBody.count_max = config.count_max
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add pool config if provided
|
||||||
|
if (poolConfig) {
|
||||||
|
// Convert pool config to backend format
|
||||||
|
requestBody.pool_config = {
|
||||||
|
selected_base_models: poolConfig.filters?.baseModels || [],
|
||||||
|
include_tags: poolConfig.filters?.tags?.include || [],
|
||||||
|
exclude_tags: poolConfig.filters?.tags?.exclude || [],
|
||||||
|
include_folders: poolConfig.filters?.folders?.include || [],
|
||||||
|
exclude_folders: poolConfig.filters?.folders?.exclude || [],
|
||||||
|
no_credit_required: poolConfig.filters?.license?.noCreditRequired || false,
|
||||||
|
allow_selling: poolConfig.filters?.license?.allowSelling || false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Call API endpoint
|
||||||
|
const response = await fetch('/api/lm/loras/random-sample', {
|
||||||
|
method: 'POST',
|
||||||
|
headers: {
|
||||||
|
'Content-Type': 'application/json',
|
||||||
|
},
|
||||||
|
body: JSON.stringify(requestBody),
|
||||||
|
})
|
||||||
|
|
||||||
|
if (!response.ok) {
|
||||||
|
const errorData = await response.json()
|
||||||
|
throw new Error(errorData.error || 'Failed to fetch random LoRAs')
|
||||||
|
}
|
||||||
|
|
||||||
|
const data = await response.json()
|
||||||
|
|
||||||
|
if (!data.success) {
|
||||||
|
throw new Error(data.error || 'Failed to get random LoRAs')
|
||||||
|
}
|
||||||
|
|
||||||
|
return data.loras || []
|
||||||
|
} catch (error) {
|
||||||
|
console.error('[LoraRandomizerState] Error rolling LoRAs:', error)
|
||||||
|
throw error
|
||||||
|
} finally {
|
||||||
|
isRolling.value = false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Computed properties
|
||||||
|
const isClipStrengthDisabled = computed(() => useSameClipStrength.value)
|
||||||
|
|
||||||
|
return {
|
||||||
|
// State refs
|
||||||
|
countMode,
|
||||||
|
countFixed,
|
||||||
|
countMin,
|
||||||
|
countMax,
|
||||||
|
modelStrengthMin,
|
||||||
|
modelStrengthMax,
|
||||||
|
useSameClipStrength,
|
||||||
|
clipStrengthMin,
|
||||||
|
clipStrengthMax,
|
||||||
|
rollMode,
|
||||||
|
isRolling,
|
||||||
|
|
||||||
|
// Computed
|
||||||
|
isClipStrengthDisabled,
|
||||||
|
|
||||||
|
// Methods
|
||||||
|
buildConfig,
|
||||||
|
restoreFromConfig,
|
||||||
|
rollLoras,
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,13 +1,17 @@
|
|||||||
import { createApp, type App as VueApp } from 'vue'
|
import { createApp, type App as VueApp } from 'vue'
|
||||||
import PrimeVue from 'primevue/config'
|
import PrimeVue from 'primevue/config'
|
||||||
import LoraPoolWidget from '@/components/LoraPoolWidget.vue'
|
import LoraPoolWidget from '@/components/LoraPoolWidget.vue'
|
||||||
import type { LoraPoolConfig, LegacyLoraPoolConfig } from './composables/types'
|
import LoraRandomizerWidget from '@/components/LoraRandomizerWidget.vue'
|
||||||
|
import type { LoraPoolConfig, LegacyLoraPoolConfig, RandomizerConfig } from './composables/types'
|
||||||
|
|
||||||
// @ts-ignore - ComfyUI external module
|
// @ts-ignore - ComfyUI external module
|
||||||
import { app } from '../../../scripts/app.js'
|
import { app } from '../../../scripts/app.js'
|
||||||
|
|
||||||
const vueApps = new Map<number, VueApp>()
|
const vueApps = new Map<number, VueApp>()
|
||||||
|
|
||||||
|
// Cache for dynamically loaded addLorasWidget module
|
||||||
|
let addLorasWidgetCache: any = null
|
||||||
|
|
||||||
// @ts-ignore
|
// @ts-ignore
|
||||||
function createLoraPoolWidget(node) {
|
function createLoraPoolWidget(node) {
|
||||||
const container = document.createElement('div')
|
const container = document.createElement('div')
|
||||||
@@ -78,14 +82,109 @@ function createLoraPoolWidget(node) {
|
|||||||
return { widget }
|
return { widget }
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// @ts-ignore
|
||||||
|
function createLoraRandomizerWidget(node) {
|
||||||
|
const container = document.createElement('div')
|
||||||
|
container.id = `lora-randomizer-widget-${node.id}`
|
||||||
|
container.style.width = '100%'
|
||||||
|
container.style.height = '100%'
|
||||||
|
container.style.display = 'flex'
|
||||||
|
container.style.flexDirection = 'column'
|
||||||
|
container.style.overflow = 'hidden'
|
||||||
|
|
||||||
|
let internalValue: RandomizerConfig | undefined
|
||||||
|
|
||||||
|
const widget = node.addDOMWidget(
|
||||||
|
'randomizer_config',
|
||||||
|
'RANDOMIZER_CONFIG',
|
||||||
|
container,
|
||||||
|
{
|
||||||
|
getValue() {
|
||||||
|
return internalValue
|
||||||
|
},
|
||||||
|
setValue(v: RandomizerConfig) {
|
||||||
|
internalValue = v
|
||||||
|
if (typeof widget.onSetValue === 'function') {
|
||||||
|
widget.onSetValue(v)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
serialize: true,
|
||||||
|
getMinHeight() {
|
||||||
|
return 500
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
widget.updateConfig = (v: RandomizerConfig) => {
|
||||||
|
internalValue = v
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle roll event from Vue component
|
||||||
|
widget.onRoll = (randomLoras: any[]) => {
|
||||||
|
console.log('[createLoraRandomizerWidget] Roll event received:', randomLoras)
|
||||||
|
|
||||||
|
// Find the loras widget on this node and update it
|
||||||
|
const lorasWidget = node.widgets.find((w: any) => w.name === 'loras')
|
||||||
|
if (lorasWidget) {
|
||||||
|
lorasWidget.value = randomLoras
|
||||||
|
console.log('[createLoraRandomizerWidget] Updated loras widget with rolled LoRAs')
|
||||||
|
} else {
|
||||||
|
console.warn('[createLoraRandomizerWidget] loras widget not found on node')
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const vueApp = createApp(LoraRandomizerWidget, {
|
||||||
|
widget,
|
||||||
|
node
|
||||||
|
})
|
||||||
|
|
||||||
|
vueApp.use(PrimeVue, {
|
||||||
|
unstyled: true,
|
||||||
|
ripple: false
|
||||||
|
})
|
||||||
|
|
||||||
|
vueApp.mount(container)
|
||||||
|
vueApps.set(node.id + 10000, vueApp) // Offset to avoid collision with pool widget
|
||||||
|
|
||||||
|
widget.computeLayoutSize = () => {
|
||||||
|
const minWidth = 500
|
||||||
|
const minHeight = 500
|
||||||
|
const maxHeight = 500
|
||||||
|
|
||||||
|
return { minHeight, minWidth, maxHeight }
|
||||||
|
}
|
||||||
|
|
||||||
|
widget.onRemove = () => {
|
||||||
|
const vueApp = vueApps.get(node.id + 10000)
|
||||||
|
if (vueApp) {
|
||||||
|
vueApp.unmount()
|
||||||
|
vueApps.delete(node.id + 10000)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return { widget }
|
||||||
|
}
|
||||||
|
|
||||||
app.registerExtension({
|
app.registerExtension({
|
||||||
name: 'LoraManager.VueWidgets',
|
name: 'LoraManager.VueWidgets',
|
||||||
|
|
||||||
getCustomWidgets() {
|
getCustomWidgets() {
|
||||||
return {
|
return {
|
||||||
// @ts-ignore
|
// @ts-ignore
|
||||||
LORA_POOL_CONFIG(node) {
|
LORA_POOL_CONFIG(node) {
|
||||||
return createLoraPoolWidget(node)
|
return createLoraPoolWidget(node)
|
||||||
|
},
|
||||||
|
// @ts-ignore
|
||||||
|
RANDOMIZER_CONFIG(node) {
|
||||||
|
return createLoraRandomizerWidget(node)
|
||||||
|
},
|
||||||
|
// @ts-ignore
|
||||||
|
async LORAS(node: any) {
|
||||||
|
if (!addLorasWidgetCache) {
|
||||||
|
const module = await import(/* @vite-ignore */ '../loras_widget.js')
|
||||||
|
addLorasWidgetCache = module.addLorasWidget
|
||||||
|
}
|
||||||
|
return addLorasWidgetCache(node, 'loras', {}, null)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -21,7 +21,8 @@ export default defineConfig({
|
|||||||
},
|
},
|
||||||
rollupOptions: {
|
rollupOptions: {
|
||||||
external: [
|
external: [
|
||||||
'../../../scripts/app.js'
|
'../../../scripts/app.js',
|
||||||
|
'../loras_widget.js'
|
||||||
],
|
],
|
||||||
output: {
|
output: {
|
||||||
dir: '../web/comfyui/vue-widgets',
|
dir: '../web/comfyui/vue-widgets',
|
||||||
|
|||||||
40
web/comfyui/lora_demo_widget.js
Normal file
40
web/comfyui/lora_demo_widget.js
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
import { app } from "../../../scripts/app.js";
|
||||||
|
|
||||||
|
app.registerExtension({
|
||||||
|
name: "LoraManager.LoraDemo",
|
||||||
|
|
||||||
|
// Hook into node creation
|
||||||
|
async nodeCreated(node) {
|
||||||
|
if (node.comfyClass !== "Lora Demo (LoraManager)") {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Store original onExecuted
|
||||||
|
const originalOnExecuted = node.onExecuted?.bind(node);
|
||||||
|
|
||||||
|
// Override onExecuted to handle UI updates
|
||||||
|
node.onExecuted = function(output) {
|
||||||
|
// Check if output has loras data
|
||||||
|
if (output?.loras && Array.isArray(output.loras)) {
|
||||||
|
console.log("[LoraDemoNode] Received loras data from backend:", output.loras);
|
||||||
|
|
||||||
|
// Find the loras widget on this node
|
||||||
|
const lorasWidget = node.widgets.find(w => w.name === 'loras');
|
||||||
|
|
||||||
|
if (lorasWidget) {
|
||||||
|
// Update widget value with backend data
|
||||||
|
lorasWidget.value = output.loras;
|
||||||
|
|
||||||
|
console.log(`[LoraDemoNode] Updated widget with ${output.loras.length} loras`);
|
||||||
|
} else {
|
||||||
|
console.warn("[LoraDemoNode] loras widget not found on node");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Call original onExecuted if it exists
|
||||||
|
if (originalOnExecuted) {
|
||||||
|
return originalOnExecuted(output);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
});
|
||||||
44
web/comfyui/lora_randomizer_widget.js
Normal file
44
web/comfyui/lora_randomizer_widget.js
Normal file
@@ -0,0 +1,44 @@
|
|||||||
|
import { app } from "../../../scripts/app.js";
|
||||||
|
|
||||||
|
app.registerExtension({
|
||||||
|
name: "LoraManager.LoraRandomizer",
|
||||||
|
|
||||||
|
// Hook into node creation
|
||||||
|
async nodeCreated(node) {
|
||||||
|
if (node.comfyClass !== "Lora Randomizer (LoraManager)") {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
console.log("[LoraRandomizerWidget] Node created:", node.id);
|
||||||
|
|
||||||
|
// Store original onExecuted
|
||||||
|
const originalOnExecuted = node.onExecuted?.bind(node);
|
||||||
|
|
||||||
|
// Override onExecuted to handle UI updates
|
||||||
|
node.onExecuted = function(output) {
|
||||||
|
console.log("[LoraRandomizerWidget] Node executed with output:", output);
|
||||||
|
|
||||||
|
// Check if output has loras data
|
||||||
|
if (output?.loras && Array.isArray(output.loras)) {
|
||||||
|
console.log("[LoraRandomizerWidget] Received loras data from backend:", output.loras);
|
||||||
|
|
||||||
|
// Find the loras widget on this node
|
||||||
|
const lorasWidget = node.widgets.find(w => w.name === 'loras');
|
||||||
|
|
||||||
|
if (lorasWidget) {
|
||||||
|
// Update widget value with backend data
|
||||||
|
lorasWidget.value = output.loras;
|
||||||
|
|
||||||
|
console.log(`[LoraRandomizerWidget] Updated widget with ${output.loras.length} loras`);
|
||||||
|
} else {
|
||||||
|
console.warn("[LoraRandomizerWidget] loras widget not found on node");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Call original onExecuted if it exists
|
||||||
|
if (originalOnExecuted) {
|
||||||
|
return originalOnExecuted(output);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
});
|
||||||
@@ -97,27 +97,27 @@ app.registerExtension({
|
|||||||
if (isUpdating) return;
|
if (isUpdating) return;
|
||||||
isUpdating = true;
|
isUpdating = true;
|
||||||
|
|
||||||
try {
|
try {
|
||||||
// Update this stacker's direct trigger toggles with its own active loras
|
// Update this stacker's direct trigger toggles with its own active loras
|
||||||
// Only if the stacker node itself is active (mode 0 for Always, mode 3 for On Trigger)
|
// Only if the stacker node itself is active (mode 0 for Always, mode 3 for On Trigger)
|
||||||
const isNodeActive = this.mode === undefined || this.mode === 0 || this.mode === 3;
|
const isNodeActive = this.mode === undefined || this.mode === 0 || this.mode === 3;
|
||||||
const activeLoraNames = new Set();
|
const activeLoraNames = new Set();
|
||||||
if (isNodeActive) {
|
if (isNodeActive) {
|
||||||
value.forEach((lora) => {
|
value.forEach((lora) => {
|
||||||
if (lora.active) {
|
if (lora.active) {
|
||||||
activeLoraNames.add(lora.name);
|
activeLoraNames.add(lora.name);
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
}
|
||||||
|
updateConnectedTriggerWords(this, activeLoraNames);
|
||||||
|
|
||||||
|
// Find all Lora Loader nodes in the chain that might need updates
|
||||||
|
updateDownstreamLoaders(this);
|
||||||
|
} finally {
|
||||||
|
isUpdating = false;
|
||||||
}
|
}
|
||||||
updateConnectedTriggerWords(this, activeLoraNames);
|
|
||||||
|
|
||||||
// Find all Lora Loader nodes in the chain that might need updates
|
scheduleInputSync(value);
|
||||||
updateDownstreamLoaders(this);
|
|
||||||
} finally {
|
|
||||||
isUpdating = false;
|
|
||||||
}
|
|
||||||
|
|
||||||
scheduleInputSync(value);
|
|
||||||
});
|
});
|
||||||
|
|
||||||
this.lorasWidget = result.widget;
|
this.lorasWidget = result.widget;
|
||||||
|
|||||||
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
Reference in New Issue
Block a user