mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-21 21:22:11 -03:00
feat: add LoraDemoNode and LoraRandomizerNode with documentation
- Import and register two new nodes: LoraDemoNode and LoraRandomizerNode - Update import exception handling for better readability with multi-line formatting - Add comprehensive documentation file `docs/custom-node-ui-output.md` for UI output usage in custom nodes - Ensure proper node registration in NODE_CLASS_MAPPINGS for ComfyUI integration - Maintain backward compatibility with existing node structure and import fallbacks
This commit is contained in:
41
__init__.py
41
__init__.py
@@ -9,8 +9,12 @@ try: # pragma: no cover - import fallback for pytest collection
|
||||
from .py.nodes.wanvideo_lora_select import WanVideoLoraSelectLM
|
||||
from .py.nodes.wanvideo_lora_select_from_text import WanVideoLoraSelectFromText
|
||||
from .py.nodes.lora_pool import LoraPoolNode
|
||||
from .py.nodes.lora_demo import LoraDemoNode
|
||||
from .py.nodes.lora_randomizer import LoraRandomizerNode
|
||||
from .py.metadata_collector import init as init_metadata_collector
|
||||
except ImportError: # pragma: no cover - allows running under pytest without package install
|
||||
except (
|
||||
ImportError
|
||||
): # pragma: no cover - allows running under pytest without package install
|
||||
import importlib
|
||||
import pathlib
|
||||
import sys
|
||||
@@ -21,15 +25,27 @@ except ImportError: # pragma: no cover - allows running under pytest without pa
|
||||
|
||||
PromptLoraManager = importlib.import_module("py.nodes.prompt").PromptLoraManager
|
||||
LoraManager = importlib.import_module("py.lora_manager").LoraManager
|
||||
LoraManagerLoader = importlib.import_module("py.nodes.lora_loader").LoraManagerLoader
|
||||
LoraManagerTextLoader = importlib.import_module("py.nodes.lora_loader").LoraManagerTextLoader
|
||||
TriggerWordToggle = importlib.import_module("py.nodes.trigger_word_toggle").TriggerWordToggle
|
||||
LoraManagerLoader = importlib.import_module(
|
||||
"py.nodes.lora_loader"
|
||||
).LoraManagerLoader
|
||||
LoraManagerTextLoader = importlib.import_module(
|
||||
"py.nodes.lora_loader"
|
||||
).LoraManagerTextLoader
|
||||
TriggerWordToggle = importlib.import_module(
|
||||
"py.nodes.trigger_word_toggle"
|
||||
).TriggerWordToggle
|
||||
LoraStacker = importlib.import_module("py.nodes.lora_stacker").LoraStacker
|
||||
SaveImageLM = importlib.import_module("py.nodes.save_image").SaveImageLM
|
||||
DebugMetadata = importlib.import_module("py.nodes.debug_metadata").DebugMetadata
|
||||
WanVideoLoraSelectLM = importlib.import_module("py.nodes.wanvideo_lora_select").WanVideoLoraSelectLM
|
||||
WanVideoLoraSelectFromText = importlib.import_module("py.nodes.wanvideo_lora_select_from_text").WanVideoLoraSelectFromText
|
||||
WanVideoLoraSelectLM = importlib.import_module(
|
||||
"py.nodes.wanvideo_lora_select"
|
||||
).WanVideoLoraSelectLM
|
||||
WanVideoLoraSelectFromText = importlib.import_module(
|
||||
"py.nodes.wanvideo_lora_select_from_text"
|
||||
).WanVideoLoraSelectFromText
|
||||
LoraPoolNode = importlib.import_module("py.nodes.lora_pool").LoraPoolNode
|
||||
LoraDemoNode = importlib.import_module("py.nodes.lora_demo").LoraDemoNode
|
||||
LoraRandomizerNode = importlib.import_module("py.nodes.lora_randomizer").LoraRandomizerNode
|
||||
init_metadata_collector = importlib.import_module("py.metadata_collector").init
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
@@ -42,7 +58,9 @@ NODE_CLASS_MAPPINGS = {
|
||||
DebugMetadata.NAME: DebugMetadata,
|
||||
WanVideoLoraSelectLM.NAME: WanVideoLoraSelectLM,
|
||||
WanVideoLoraSelectFromText.NAME: WanVideoLoraSelectFromText,
|
||||
LoraPoolNode.NAME: LoraPoolNode
|
||||
LoraPoolNode.NAME: LoraPoolNode,
|
||||
LoraDemoNode.NAME: LoraDemoNode,
|
||||
LoraRandomizerNode.NAME: LoraRandomizerNode,
|
||||
}
|
||||
|
||||
WEB_DIRECTORY = "./web/comfyui"
|
||||
@@ -50,15 +68,20 @@ WEB_DIRECTORY = "./web/comfyui"
|
||||
# Check and build Vue widgets if needed (development mode)
|
||||
try:
|
||||
from .py.vue_widget_builder import check_and_build_vue_widgets
|
||||
|
||||
# Auto-build in development, warn only if fails
|
||||
check_and_build_vue_widgets(auto_build=True, warn_only=True)
|
||||
except ImportError:
|
||||
# Fallback for pytest
|
||||
import importlib
|
||||
check_and_build_vue_widgets = importlib.import_module("py.vue_widget_builder").check_and_build_vue_widgets
|
||||
|
||||
check_and_build_vue_widgets = importlib.import_module(
|
||||
"py.vue_widget_builder"
|
||||
).check_and_build_vue_widgets
|
||||
check_and_build_vue_widgets(auto_build=True, warn_only=True)
|
||||
except Exception as e:
|
||||
import logging
|
||||
|
||||
logging.warning(f"[LoRA Manager] Vue widget build check skipped: {e}")
|
||||
|
||||
# Initialize metadata collector
|
||||
@@ -66,4 +89,4 @@ init_metadata_collector()
|
||||
|
||||
# Register routes on import
|
||||
LoraManager.add_routes()
|
||||
__all__ = ['NODE_CLASS_MAPPINGS', 'WEB_DIRECTORY']
|
||||
__all__ = ["NODE_CLASS_MAPPINGS", "WEB_DIRECTORY"]
|
||||
|
||||
95
py/nodes/lora_demo.py
Normal file
95
py/nodes/lora_demo.py
Normal file
@@ -0,0 +1,95 @@
|
||||
"""
|
||||
Lora Demo Node - Demonstrates LORAS custom widget type usage.
|
||||
|
||||
This node accepts LORAS widget input and outputs a summary string.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import random
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LoraDemoNode:
|
||||
"""Demo node that uses LORAS custom widget type."""
|
||||
|
||||
NAME = "Lora Demo (LoraManager)"
|
||||
CATEGORY = "Lora Manager/demo"
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"loras": ("LORAS", {}),
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("STRING",)
|
||||
RETURN_NAMES = ("summary",)
|
||||
|
||||
FUNCTION = "process"
|
||||
OUTPUT_NODE = False
|
||||
|
||||
async def process(self, loras):
|
||||
"""
|
||||
Process LoRAs input and return summary + UI data for widget.
|
||||
|
||||
Args:
|
||||
loras: List of LoRA dictionaries with structure:
|
||||
[{'name': str, 'strength': float, 'clipStrength': float, 'active': bool, ...}]
|
||||
|
||||
Returns:
|
||||
Dictionary with 'result' (for workflow) and 'ui' (for frontend display)
|
||||
"""
|
||||
from ..services.service_registry import ServiceRegistry
|
||||
|
||||
# Get lora scanner to access available loras
|
||||
scanner = await ServiceRegistry.get_lora_scanner()
|
||||
|
||||
# Get available loras from cache
|
||||
available_loras = []
|
||||
try:
|
||||
cache_data = await scanner.get_cached_data(force_refresh=False)
|
||||
if cache_data and hasattr(cache_data, "raw_data"):
|
||||
available_loras = cache_data.raw_data
|
||||
except Exception as e:
|
||||
logger.warning(f"[LoraDemoNode] Failed to get lora cache: {e}")
|
||||
|
||||
# Randomly select 3-5 loras
|
||||
num_to_select = random.randint(3, 5)
|
||||
if len(available_loras) < num_to_select:
|
||||
num_to_select = len(available_loras)
|
||||
|
||||
selected_loras = (
|
||||
random.sample(available_loras, num_to_select) if num_to_select > 0 else []
|
||||
)
|
||||
|
||||
# Generate random loras data for widget
|
||||
widget_loras = []
|
||||
for lora in selected_loras:
|
||||
strength = round(random.uniform(0.1, 1.0), 2)
|
||||
widget_loras.append(
|
||||
{
|
||||
"name": lora.get("file_name", "Unknown"),
|
||||
"strength": strength,
|
||||
"clipStrength": strength,
|
||||
"active": True,
|
||||
"expanded": False,
|
||||
}
|
||||
)
|
||||
|
||||
# Create summary string
|
||||
active_names = [l["name"] for l in widget_loras]
|
||||
summary = f"Randomized {len(active_names)} LoRAs: {', '.join(active_names)}"
|
||||
|
||||
logger.info(f"[LoraDemoNode] {summary}")
|
||||
|
||||
# Return format: result for workflow + ui for frontend
|
||||
return {"result": (summary,), "ui": {"loras": widget_loras}}
|
||||
|
||||
|
||||
# Node class mappings for ComfyUI
|
||||
NODE_CLASS_MAPPINGS = {"LoraDemoNode": LoraDemoNode}
|
||||
|
||||
# Display name mappings
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {"LoraDemoNode": "LoRA Demo"}
|
||||
@@ -31,28 +31,28 @@ class LoraPoolNode:
|
||||
"hidden": {
|
||||
# Hidden input to pass through unique node ID for frontend
|
||||
"unique_id": "UNIQUE_ID",
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("LORA_POOL_CONFIG",)
|
||||
RETURN_NAMES = ("pool_config",)
|
||||
RETURN_TYPES = ("POOL_CONFIG",)
|
||||
RETURN_NAMES = ("POOL_CONFIG",)
|
||||
|
||||
FUNCTION = "process"
|
||||
OUTPUT_NODE = False
|
||||
|
||||
def process(self, pool_config, unique_id=None):
|
||||
"""
|
||||
Pass through the pool configuration.
|
||||
Pass through the pool configuration filters.
|
||||
|
||||
The config is generated entirely by the frontend widget.
|
||||
This function validates and passes through the configuration.
|
||||
This function validates and returns only the filters field.
|
||||
|
||||
Args:
|
||||
pool_config: Dict containing filter criteria from widget
|
||||
unique_id: Node's unique ID (hidden)
|
||||
|
||||
Returns:
|
||||
Tuple containing the validated pool_config
|
||||
Tuple containing the filters dict from pool_config
|
||||
"""
|
||||
# Validate required structure
|
||||
if not isinstance(pool_config, dict):
|
||||
@@ -63,10 +63,13 @@ class LoraPoolNode:
|
||||
if "version" not in pool_config:
|
||||
pool_config["version"] = 1
|
||||
|
||||
# Log for debugging
|
||||
logger.debug(f"[LoraPoolNode] Processing config: {pool_config}")
|
||||
# Extract filters field
|
||||
filters = pool_config.get("filters", self._default_config()["filters"])
|
||||
|
||||
return (pool_config,)
|
||||
# Log for debugging
|
||||
logger.debug(f"[LoraPoolNode] Processing filters: {filters}")
|
||||
|
||||
return (filters,)
|
||||
|
||||
@staticmethod
|
||||
def _default_config():
|
||||
@@ -76,23 +79,16 @@ class LoraPoolNode:
|
||||
"filters": {
|
||||
"baseModels": [],
|
||||
"tags": {"include": [], "exclude": []},
|
||||
"folder": {"path": None, "recursive": True},
|
||||
"folders": {"include": [], "exclude": []},
|
||||
"favoritesOnly": False,
|
||||
"license": {
|
||||
"noCreditRequired": None,
|
||||
"allowSellingGeneratedContent": None
|
||||
}
|
||||
"license": {"noCreditRequired": False, "allowSelling": False},
|
||||
},
|
||||
"preview": {"matchCount": 0, "lastUpdated": 0}
|
||||
"preview": {"matchCount": 0, "lastUpdated": 0},
|
||||
}
|
||||
|
||||
|
||||
# Node class mappings for ComfyUI
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"LoraPoolNode": LoraPoolNode
|
||||
}
|
||||
NODE_CLASS_MAPPINGS = {"LoraPoolNode": LoraPoolNode}
|
||||
|
||||
# Display name mappings
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"LoraPoolNode": "LoRA Pool (Filter)"
|
||||
}
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {"LoraPoolNode": "LoRA Pool (Filter)"}
|
||||
|
||||
297
py/nodes/lora_randomizer.py
Normal file
297
py/nodes/lora_randomizer.py
Normal file
@@ -0,0 +1,297 @@
|
||||
"""
|
||||
Lora Randomizer Node - Randomly selects LoRAs from a pool with configurable settings.
|
||||
|
||||
This node accepts optional pool_config input to filter available LoRAs, and outputs
|
||||
a LORA_STACK with randomly selected LoRAs. Supports both frontend roll (fixed selection)
|
||||
and backend roll (randomizes each execution).
|
||||
"""
|
||||
|
||||
import logging
|
||||
import random
|
||||
import os
|
||||
from ..utils.utils import get_lora_info
|
||||
from .utils import extract_lora_name
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LoraRandomizerNode:
|
||||
"""Node that randomly selects LoRAs from a pool"""
|
||||
|
||||
NAME = "Lora Randomizer (LoraManager)"
|
||||
CATEGORY = "Lora Manager/randomizer"
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"randomizer_config": ("RANDOMIZER_CONFIG", {}),
|
||||
"loras": ("LORAS", {}),
|
||||
},
|
||||
"optional": {
|
||||
"pool_config": ("POOL_CONFIG", {}),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("LORA_STACK",)
|
||||
RETURN_NAMES = ("lora_stack",)
|
||||
|
||||
FUNCTION = "randomize"
|
||||
OUTPUT_NODE = False
|
||||
|
||||
async def randomize(self, randomizer_config, loras, pool_config=None):
|
||||
"""
|
||||
Randomize LoRAs based on configuration and pool filters.
|
||||
|
||||
Args:
|
||||
randomizer_config: Dict with randomizer settings (count, strength ranges, roll mode)
|
||||
loras: List of LoRA dicts from LORAS widget (includes locked state)
|
||||
pool_config: Optional config from LoRA Pool node for filtering
|
||||
|
||||
Returns:
|
||||
Dictionary with 'result' (LORA_STACK tuple) and 'ui' (for widget display)
|
||||
"""
|
||||
from ..services.service_registry import ServiceRegistry
|
||||
|
||||
# Get lora scanner to access available loras
|
||||
scanner = await ServiceRegistry.get_lora_scanner()
|
||||
|
||||
# Parse randomizer settings
|
||||
count_mode = randomizer_config.get("count_mode", "range")
|
||||
count_fixed = randomizer_config.get("count_fixed", 5)
|
||||
count_min = randomizer_config.get("count_min", 3)
|
||||
count_max = randomizer_config.get("count_max", 7)
|
||||
model_strength_min = randomizer_config.get("model_strength_min", 0.0)
|
||||
model_strength_max = randomizer_config.get("model_strength_max", 1.0)
|
||||
use_same_clip_strength = randomizer_config.get("use_same_clip_strength", True)
|
||||
clip_strength_min = randomizer_config.get("clip_strength_min", 0.0)
|
||||
clip_strength_max = randomizer_config.get("clip_strength_max", 1.0)
|
||||
roll_mode = randomizer_config.get("roll_mode", "frontend")
|
||||
|
||||
# Determine target count
|
||||
if count_mode == "fixed":
|
||||
target_count = count_fixed
|
||||
else:
|
||||
target_count = random.randint(count_min, count_max)
|
||||
|
||||
logger.info(
|
||||
f"[LoraRandomizerNode] Target count: {target_count}, Roll mode: {roll_mode}"
|
||||
)
|
||||
|
||||
# Extract locked LoRAs from input
|
||||
locked_loras = [lora for lora in loras if lora.get("locked", False)]
|
||||
locked_count = len(locked_loras)
|
||||
|
||||
logger.info(f"[LoraRandomizerNode] Locked LoRAs: {locked_count}")
|
||||
|
||||
# Get available loras from cache
|
||||
try:
|
||||
cache_data = await scanner.get_cached_data(force_refresh=False)
|
||||
if cache_data and hasattr(cache_data, "raw_data"):
|
||||
available_loras = cache_data.raw_data
|
||||
else:
|
||||
available_loras = []
|
||||
except Exception as e:
|
||||
logger.warning(f"[LoraRandomizerNode] Failed to get lora cache: {e}")
|
||||
available_loras = []
|
||||
|
||||
# Apply pool filters if provided
|
||||
if pool_config:
|
||||
available_loras = await self._apply_pool_filters(
|
||||
available_loras, pool_config, scanner
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"[LoraRandomizerNode] Available LoRAs after filtering: {len(available_loras)}"
|
||||
)
|
||||
|
||||
# Calculate how many new LoRAs to select
|
||||
# In frontend mode, if loras already has data, preserve unlocked ones if roll_mode requires
|
||||
# For simplicity in backend mode, we regenerate all unlocked slots
|
||||
slots_needed = target_count - locked_count
|
||||
|
||||
if slots_needed < 0:
|
||||
slots_needed = 0
|
||||
# Too many locked, trim to target
|
||||
locked_loras = locked_loras[:target_count]
|
||||
locked_count = len(locked_loras)
|
||||
|
||||
# Filter out locked LoRAs from available pool
|
||||
locked_names = {lora["name"] for lora in locked_loras}
|
||||
available_pool = [
|
||||
l for l in available_loras if l["file_name"] not in locked_names
|
||||
]
|
||||
|
||||
# Ensure we don't try to select more than available
|
||||
if slots_needed > len(available_pool):
|
||||
slots_needed = len(available_pool)
|
||||
|
||||
logger.info(
|
||||
f"[LoraRandomizerNode] Selecting {slots_needed} new LoRAs from {len(available_pool)} available"
|
||||
)
|
||||
|
||||
# Random sample
|
||||
selected = []
|
||||
if slots_needed > 0:
|
||||
selected = random.sample(available_pool, slots_needed)
|
||||
|
||||
# Generate random strengths for selected LoRAs
|
||||
result_loras = []
|
||||
for lora in selected:
|
||||
model_str = round(random.uniform(model_strength_min, model_strength_max), 2)
|
||||
|
||||
if use_same_clip_strength:
|
||||
clip_str = model_str
|
||||
else:
|
||||
clip_str = round(
|
||||
random.uniform(clip_strength_min, clip_strength_max), 2
|
||||
)
|
||||
|
||||
result_loras.append(
|
||||
{
|
||||
"name": lora["file_name"],
|
||||
"strength": model_str,
|
||||
"clipStrength": clip_str,
|
||||
"active": True,
|
||||
"expanded": abs(model_str - clip_str) > 0.001,
|
||||
"locked": False,
|
||||
}
|
||||
)
|
||||
|
||||
# Merge with locked LoRAs
|
||||
result_loras.extend(locked_loras)
|
||||
|
||||
logger.info(f"[LoraRandomizerNode] Final LoRA count: {len(result_loras)}")
|
||||
|
||||
# Build LORA_STACK output
|
||||
lora_stack = []
|
||||
for lora in result_loras:
|
||||
if not lora.get("active", False):
|
||||
continue
|
||||
|
||||
# Get file path
|
||||
lora_path, trigger_words = get_lora_info(lora["name"])
|
||||
if not lora_path:
|
||||
logger.warning(
|
||||
f"[LoraRandomizerNode] Could not find path for LoRA: {lora['name']}"
|
||||
)
|
||||
continue
|
||||
|
||||
# Normalize path separators
|
||||
lora_path = lora_path.replace("/", os.sep)
|
||||
|
||||
# Extract strengths
|
||||
model_strength = lora.get("strength", 1.0)
|
||||
clip_strength = lora.get("clipStrength", model_strength)
|
||||
|
||||
lora_stack.append((lora_path, model_strength, clip_strength))
|
||||
|
||||
# Return format: result for workflow + ui for frontend display
|
||||
return {"result": (lora_stack,), "ui": {"loras": result_loras}}
|
||||
|
||||
async def _apply_pool_filters(self, available_loras, pool_config, scanner):
|
||||
"""
|
||||
Apply pool_config filters to available LoRAs.
|
||||
|
||||
Args:
|
||||
available_loras: List of all LoRA dicts
|
||||
pool_config: Dict with filter settings from LoRA Pool node
|
||||
scanner: Scanner instance for accessing filter utilities
|
||||
|
||||
Returns:
|
||||
Filtered list of LoRA dicts
|
||||
"""
|
||||
from ..services.lora_service import LoraService
|
||||
from ..services.model_query import FilterCriteria
|
||||
|
||||
# Create lora service instance for filtering
|
||||
lora_service = LoraService(scanner)
|
||||
|
||||
# Extract filter parameters from pool_config
|
||||
selected_base_models = pool_config.get("baseModels", [])
|
||||
tags_dict = pool_config.get("tags", {})
|
||||
include_tags = tags_dict.get("include", [])
|
||||
exclude_tags = tags_dict.get("exclude", [])
|
||||
folders_dict = pool_config.get("folders", {})
|
||||
include_folders = folders_dict.get("include", [])
|
||||
exclude_folders = folders_dict.get("exclude", [])
|
||||
license_dict = pool_config.get("license", {})
|
||||
no_credit_required = license_dict.get("noCreditRequired", False)
|
||||
allow_selling = license_dict.get("allowSelling", False)
|
||||
|
||||
# Build tag filters dict
|
||||
tag_filters = {}
|
||||
for tag in include_tags:
|
||||
tag_filters[tag] = "include"
|
||||
for tag in exclude_tags:
|
||||
tag_filters[tag] = "exclude"
|
||||
|
||||
# Build folder filter
|
||||
# LoRA Pool uses include/exclude folders, we need to apply this logic
|
||||
# For now, we'll filter based on folder path matching
|
||||
if include_folders or exclude_folders:
|
||||
filtered = []
|
||||
for lora in available_loras:
|
||||
folder = lora.get("folder", "")
|
||||
|
||||
# Check exclude folders first
|
||||
excluded = False
|
||||
for exclude_folder in exclude_folders:
|
||||
if folder.startswith(exclude_folder):
|
||||
excluded = True
|
||||
break
|
||||
|
||||
if excluded:
|
||||
continue
|
||||
|
||||
# Check include folders
|
||||
if include_folders:
|
||||
included = False
|
||||
for include_folder in include_folders:
|
||||
if folder.startswith(include_folder):
|
||||
included = True
|
||||
break
|
||||
if not included:
|
||||
continue
|
||||
|
||||
filtered.append(lora)
|
||||
|
||||
available_loras = filtered
|
||||
|
||||
# Apply base model filter
|
||||
if selected_base_models:
|
||||
available_loras = [
|
||||
lora
|
||||
for lora in available_loras
|
||||
if lora.get("base_model") in selected_base_models
|
||||
]
|
||||
|
||||
# Apply tag filters
|
||||
if tag_filters:
|
||||
criteria = FilterCriteria(tags=tag_filters)
|
||||
available_loras = lora_service.filter_set.apply(available_loras, criteria)
|
||||
|
||||
# Apply license filters
|
||||
if no_credit_required:
|
||||
available_loras = [
|
||||
lora
|
||||
for lora in available_loras
|
||||
if not lora.get("civitai", {}).get("allowNoCredit", True)
|
||||
]
|
||||
|
||||
if allow_selling:
|
||||
available_loras = [
|
||||
lora
|
||||
for lora in available_loras
|
||||
if lora.get("civitai", {}).get("allowCommercialUse", ["None"])[0]
|
||||
!= "None"
|
||||
]
|
||||
|
||||
return available_loras
|
||||
|
||||
|
||||
# Node class mappings for ComfyUI
|
||||
NODE_CLASS_MAPPINGS = {"LoraRandomizerNode": LoraRandomizerNode}
|
||||
|
||||
# Display name mappings
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {"LoraRandomizerNode": "LoRA Randomizer"}
|
||||
@@ -45,6 +45,9 @@ class LoraRoutes(BaseModelRoutes):
|
||||
registrar.add_prefixed_route('GET', '/api/lm/{prefix}/get-trigger-words', prefix, self.get_lora_trigger_words)
|
||||
registrar.add_prefixed_route('GET', '/api/lm/{prefix}/usage-tips-by-path', prefix, self.get_lora_usage_tips_by_path)
|
||||
|
||||
# Randomizer routes
|
||||
registrar.add_prefixed_route('POST', '/api/lm/{prefix}/random-sample', prefix, self.get_random_loras)
|
||||
|
||||
# ComfyUI integration
|
||||
registrar.add_prefixed_route('POST', '/api/lm/{prefix}/get_trigger_words', prefix, self.get_trigger_words)
|
||||
|
||||
@@ -215,6 +218,74 @@ class LoraRoutes(BaseModelRoutes):
|
||||
'error': str(e)
|
||||
}, status=500)
|
||||
|
||||
async def get_random_loras(self, request: web.Request) -> web.Response:
|
||||
"""Get random LoRAs based on filters and strength ranges"""
|
||||
try:
|
||||
json_data = await request.json()
|
||||
|
||||
# Parse parameters
|
||||
count = json_data.get('count', 5)
|
||||
count_min = json_data.get('count_min')
|
||||
count_max = json_data.get('count_max')
|
||||
model_strength_min = float(json_data.get('model_strength_min', 0.0))
|
||||
model_strength_max = float(json_data.get('model_strength_max', 1.0))
|
||||
use_same_clip_strength = json_data.get('use_same_clip_strength', True)
|
||||
clip_strength_min = float(json_data.get('clip_strength_min', 0.0))
|
||||
clip_strength_max = float(json_data.get('clip_strength_max', 1.0))
|
||||
locked_loras = json_data.get('locked_loras', [])
|
||||
pool_config = json_data.get('pool_config')
|
||||
|
||||
# Determine target count
|
||||
if count_min is not None and count_max is not None:
|
||||
import random
|
||||
target_count = random.randint(count_min, count_max)
|
||||
else:
|
||||
target_count = count
|
||||
|
||||
# Validate parameters
|
||||
if target_count < 1 or target_count > 100:
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': 'Count must be between 1 and 100'
|
||||
}, status=400)
|
||||
|
||||
if model_strength_min < 0 or model_strength_max > 10:
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': 'Model strength must be between 0 and 10'
|
||||
}, status=400)
|
||||
|
||||
# Get random LoRAs from service
|
||||
result_loras = await self.service.get_random_loras(
|
||||
count=target_count,
|
||||
model_strength_min=model_strength_min,
|
||||
model_strength_max=model_strength_max,
|
||||
use_same_clip_strength=use_same_clip_strength,
|
||||
clip_strength_min=clip_strength_min,
|
||||
clip_strength_max=clip_strength_max,
|
||||
locked_loras=locked_loras,
|
||||
pool_config=pool_config
|
||||
)
|
||||
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
'loras': result_loras,
|
||||
'count': len(result_loras)
|
||||
})
|
||||
|
||||
except ValueError as e:
|
||||
logger.error(f"Invalid parameter for random LoRAs: {e}")
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
}, status=400)
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting random LoRAs: {e}", exc_info=True)
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
}, status=500)
|
||||
|
||||
async def get_trigger_words(self, request: web.Request) -> web.Response:
|
||||
"""Get trigger words for specified LoRA models"""
|
||||
try:
|
||||
|
||||
@@ -178,7 +178,188 @@ class LoraService(BaseModelService):
|
||||
def find_duplicate_hashes(self) -> Dict:
|
||||
"""Find LoRAs with duplicate SHA256 hashes"""
|
||||
return self.scanner._hash_index.get_duplicate_hashes()
|
||||
|
||||
|
||||
def find_duplicate_filenames(self) -> Dict:
|
||||
"""Find LoRAs with conflicting filenames"""
|
||||
return self.scanner._hash_index.get_duplicate_filenames()
|
||||
|
||||
async def get_random_loras(
|
||||
self,
|
||||
count: int,
|
||||
model_strength_min: float = 0.0,
|
||||
model_strength_max: float = 1.0,
|
||||
use_same_clip_strength: bool = True,
|
||||
clip_strength_min: float = 0.0,
|
||||
clip_strength_max: float = 1.0,
|
||||
locked_loras: Optional[List[Dict]] = None,
|
||||
pool_config: Optional[Dict] = None
|
||||
) -> List[Dict]:
|
||||
"""
|
||||
Get random LoRAs with specified strength ranges.
|
||||
|
||||
Args:
|
||||
count: Number of LoRAs to select
|
||||
model_strength_min: Minimum model strength
|
||||
model_strength_max: Maximum model strength
|
||||
use_same_clip_strength: Whether to use same strength for clip
|
||||
clip_strength_min: Minimum clip strength
|
||||
clip_strength_max: Maximum clip strength
|
||||
locked_loras: List of locked LoRA dicts to preserve
|
||||
pool_config: Optional pool config for filtering
|
||||
|
||||
Returns:
|
||||
List of LoRA dicts with randomized strengths
|
||||
"""
|
||||
import random
|
||||
|
||||
if locked_loras is None:
|
||||
locked_loras = []
|
||||
|
||||
# Get available loras from cache
|
||||
cache = await self.scanner.get_cached_data(force_refresh=False)
|
||||
available_loras = cache.raw_data if cache else []
|
||||
|
||||
# Apply pool filters if provided
|
||||
if pool_config:
|
||||
available_loras = await self._apply_pool_filters(
|
||||
available_loras, pool_config
|
||||
)
|
||||
|
||||
# Calculate slots needed (total - locked)
|
||||
locked_count = len(locked_loras)
|
||||
slots_needed = count - locked_count
|
||||
|
||||
if slots_needed < 0:
|
||||
slots_needed = 0
|
||||
# Too many locked, trim to target
|
||||
locked_loras = locked_loras[:count]
|
||||
|
||||
# Filter out locked LoRAs from available pool
|
||||
locked_names = {lora['name'] for lora in locked_loras}
|
||||
available_pool = [
|
||||
l for l in available_loras
|
||||
if l['model_name'] not in locked_names
|
||||
]
|
||||
|
||||
# Ensure we don't try to select more than available
|
||||
if slots_needed > len(available_pool):
|
||||
slots_needed = len(available_pool)
|
||||
|
||||
# Random sample
|
||||
selected = []
|
||||
if slots_needed > 0:
|
||||
selected = random.sample(available_pool, slots_needed)
|
||||
|
||||
# Generate random strengths for selected LoRAs
|
||||
result_loras = []
|
||||
for lora in selected:
|
||||
model_str = round(
|
||||
random.uniform(model_strength_min, model_strength_max), 2
|
||||
)
|
||||
|
||||
if use_same_clip_strength:
|
||||
clip_str = model_str
|
||||
else:
|
||||
clip_str = round(
|
||||
random.uniform(clip_strength_min, clip_strength_max), 2
|
||||
)
|
||||
|
||||
result_loras.append({
|
||||
'name': lora['model_name'],
|
||||
'strength': model_str,
|
||||
'clipStrength': clip_str,
|
||||
'active': True,
|
||||
'expanded': abs(model_str - clip_str) > 0.001,
|
||||
'locked': False
|
||||
})
|
||||
|
||||
# Merge with locked LoRAs
|
||||
result_loras.extend(locked_loras)
|
||||
|
||||
return result_loras
|
||||
|
||||
async def _apply_pool_filters(self, available_loras: List[Dict], pool_config: Dict) -> List[Dict]:
|
||||
"""
|
||||
Apply pool_config filters to available LoRAs.
|
||||
|
||||
Args:
|
||||
available_loras: List of all LoRA dicts
|
||||
pool_config: Dict with filter settings from LoRA Pool node
|
||||
|
||||
Returns:
|
||||
Filtered list of LoRA dicts
|
||||
"""
|
||||
from .model_query import FilterCriteria
|
||||
|
||||
# Extract filter parameters from pool_config
|
||||
selected_base_models = pool_config.get('selected_base_models', [])
|
||||
include_tags = pool_config.get('include_tags', [])
|
||||
exclude_tags = pool_config.get('exclude_tags', [])
|
||||
include_folders = pool_config.get('include_folders', [])
|
||||
exclude_folders = pool_config.get('exclude_folders', [])
|
||||
no_credit_required = pool_config.get('no_credit_required', False)
|
||||
allow_selling = pool_config.get('allow_selling', False)
|
||||
|
||||
# Build tag filters dict
|
||||
tag_filters = {}
|
||||
for tag in include_tags:
|
||||
tag_filters[tag] = 'include'
|
||||
for tag in exclude_tags:
|
||||
tag_filters[tag] = 'exclude'
|
||||
|
||||
# Build folder filter
|
||||
if include_folders or exclude_folders:
|
||||
filtered = []
|
||||
for lora in available_loras:
|
||||
folder = lora.get('folder', '')
|
||||
|
||||
# Check exclude folders first
|
||||
excluded = False
|
||||
for exclude_folder in exclude_folders:
|
||||
if folder.startswith(exclude_folder):
|
||||
excluded = True
|
||||
break
|
||||
|
||||
if excluded:
|
||||
continue
|
||||
|
||||
# Check include folders
|
||||
if include_folders:
|
||||
included = False
|
||||
for include_folder in include_folders:
|
||||
if folder.startswith(include_folder):
|
||||
included = True
|
||||
break
|
||||
if not included:
|
||||
continue
|
||||
|
||||
filtered.append(lora)
|
||||
|
||||
available_loras = filtered
|
||||
|
||||
# Apply base model filter
|
||||
if selected_base_models:
|
||||
available_loras = [
|
||||
lora for lora in available_loras
|
||||
if lora.get('base_model') in selected_base_models
|
||||
]
|
||||
|
||||
# Apply tag filters
|
||||
if tag_filters:
|
||||
criteria = FilterCriteria(tags=tag_filters)
|
||||
available_loras = self.filter_set.apply(available_loras, criteria)
|
||||
|
||||
# Apply license filters
|
||||
if no_credit_required:
|
||||
available_loras = [
|
||||
lora for lora in available_loras
|
||||
if not lora.get('civitai', {}).get('allowNoCredit', True)
|
||||
]
|
||||
|
||||
if allow_selling:
|
||||
available_loras = [
|
||||
lora for lora in available_loras
|
||||
if lora.get('civitai', {}).get('allowCommercialUse', ['None'])[0] != 'None'
|
||||
]
|
||||
|
||||
return available_loras
|
||||
|
||||
191
tests/routes/test_randomizer_endpoints.py
Normal file
191
tests/routes/test_randomizer_endpoints.py
Normal file
@@ -0,0 +1,191 @@
|
||||
import json
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from py.routes.lora_routes import LoraRoutes
|
||||
|
||||
|
||||
class DummyRequest:
|
||||
def __init__(self, *, query=None, match_info=None, json_data=None):
|
||||
self.query = query or {}
|
||||
self.match_info = match_info or {}
|
||||
self._json_data = json_data or {}
|
||||
|
||||
async def json(self):
|
||||
return self._json_data
|
||||
|
||||
|
||||
class StubLoraService:
|
||||
"""Stub service for testing randomizer endpoints"""
|
||||
|
||||
def __init__(self):
|
||||
self.random_loras = []
|
||||
|
||||
async def get_random_loras(self, **kwargs):
|
||||
return self.random_loras
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def routes():
|
||||
handler = LoraRoutes()
|
||||
handler.service = StubLoraService()
|
||||
return handler
|
||||
|
||||
|
||||
async def test_get_random_loras_success(routes):
|
||||
"""Test successful random LoRA generation"""
|
||||
routes.service.random_loras = [
|
||||
{
|
||||
'name': 'test_lora_1',
|
||||
'strength': 0.8,
|
||||
'clipStrength': 0.8,
|
||||
'active': True,
|
||||
'expanded': False,
|
||||
'locked': False
|
||||
},
|
||||
{
|
||||
'name': 'test_lora_2',
|
||||
'strength': 0.6,
|
||||
'clipStrength': 0.6,
|
||||
'active': True,
|
||||
'expanded': False,
|
||||
'locked': False
|
||||
}
|
||||
]
|
||||
|
||||
request = DummyRequest(json_data={
|
||||
'count': 5,
|
||||
'model_strength_min': 0.5,
|
||||
'model_strength_max': 1.0,
|
||||
'use_same_clip_strength': True,
|
||||
'locked_loras': []
|
||||
})
|
||||
|
||||
response = await routes.get_random_loras(request)
|
||||
payload = json.loads(response.text)
|
||||
|
||||
assert response.status == 200
|
||||
assert payload['success'] is True
|
||||
assert 'loras' in payload
|
||||
assert payload['count'] == 2
|
||||
|
||||
|
||||
async def test_get_random_loras_with_range(routes):
|
||||
"""Test random LoRAs with count range"""
|
||||
routes.service.random_loras = [
|
||||
{
|
||||
'name': 'test_lora_1',
|
||||
'strength': 0.8,
|
||||
'clipStrength': 0.8,
|
||||
'active': True,
|
||||
'expanded': False,
|
||||
'locked': False
|
||||
}
|
||||
]
|
||||
|
||||
request = DummyRequest(json_data={
|
||||
'count_min': 3,
|
||||
'count_max': 7,
|
||||
'model_strength_min': 0.0,
|
||||
'model_strength_max': 1.0,
|
||||
'use_same_clip_strength': True
|
||||
})
|
||||
|
||||
response = await routes.get_random_loras(request)
|
||||
payload = json.loads(response.text)
|
||||
|
||||
assert response.status == 200
|
||||
assert payload['success'] is True
|
||||
|
||||
|
||||
async def test_get_random_loras_invalid_count(routes):
|
||||
"""Test invalid count parameter"""
|
||||
request = DummyRequest(json_data={
|
||||
'count': 150, # Over limit
|
||||
'model_strength_min': 0.0,
|
||||
'model_strength_max': 1.0
|
||||
})
|
||||
|
||||
response = await routes.get_random_loras(request)
|
||||
payload = json.loads(response.text)
|
||||
|
||||
assert response.status == 400
|
||||
assert payload['success'] is False
|
||||
assert 'Count must be between 1 and 100' in payload['error']
|
||||
|
||||
|
||||
async def test_get_random_loras_invalid_strength(routes):
|
||||
"""Test invalid strength range"""
|
||||
request = DummyRequest(json_data={
|
||||
'count': 5,
|
||||
'model_strength_min': -0.5, # Invalid
|
||||
'model_strength_max': 1.0
|
||||
})
|
||||
|
||||
response = await routes.get_random_loras(request)
|
||||
payload = json.loads(response.text)
|
||||
|
||||
assert response.status == 400
|
||||
assert payload['success'] is False
|
||||
|
||||
|
||||
async def test_get_random_loras_with_locked(routes):
|
||||
"""Test random LoRAs with locked items"""
|
||||
routes.service.random_loras = [
|
||||
{
|
||||
'name': 'new_lora',
|
||||
'strength': 0.7,
|
||||
'clipStrength': 0.7,
|
||||
'active': True,
|
||||
'expanded': False,
|
||||
'locked': False
|
||||
},
|
||||
{
|
||||
'name': 'locked_lora',
|
||||
'strength': 0.9,
|
||||
'clipStrength': 0.9,
|
||||
'active': True,
|
||||
'expanded': False,
|
||||
'locked': True
|
||||
}
|
||||
]
|
||||
|
||||
request = DummyRequest(json_data={
|
||||
'count': 5,
|
||||
'model_strength_min': 0.5,
|
||||
'model_strength_max': 1.0,
|
||||
'use_same_clip_strength': True,
|
||||
'locked_loras': [
|
||||
{
|
||||
'name': 'locked_lora',
|
||||
'strength': 0.9,
|
||||
'clipStrength': 0.9,
|
||||
'active': True,
|
||||
'expanded': False,
|
||||
'locked': True
|
||||
}
|
||||
]
|
||||
})
|
||||
|
||||
response = await routes.get_random_loras(request)
|
||||
payload = json.loads(response.text)
|
||||
|
||||
assert response.status == 200
|
||||
assert payload['success'] is True
|
||||
|
||||
|
||||
async def test_get_random_loras_error(routes, monkeypatch):
|
||||
"""Test error handling"""
|
||||
async def failing(*_args, **_kwargs):
|
||||
raise RuntimeError("Service error")
|
||||
|
||||
routes.service.get_random_loras = failing
|
||||
request = DummyRequest(json_data={'count': 5})
|
||||
|
||||
response = await routes.get_random_loras(request)
|
||||
payload = json.loads(response.text)
|
||||
|
||||
assert response.status == 500
|
||||
assert payload['success'] is False
|
||||
assert 'error' in payload
|
||||
110
vue-widgets/src/components/LoraRandomizerWidget.vue
Normal file
110
vue-widgets/src/components/LoraRandomizerWidget.vue
Normal file
@@ -0,0 +1,110 @@
|
||||
<template>
|
||||
<div class="lora-randomizer-widget">
|
||||
<LoraRandomizerSettingsView
|
||||
:count-mode="state.countMode.value"
|
||||
:count-fixed="state.countFixed.value"
|
||||
:count-min="state.countMin.value"
|
||||
:count-max="state.countMax.value"
|
||||
:model-strength-min="state.modelStrengthMin.value"
|
||||
:model-strength-max="state.modelStrengthMax.value"
|
||||
:use-same-clip-strength="state.useSameClipStrength.value"
|
||||
:clip-strength-min="state.clipStrengthMin.value"
|
||||
:clip-strength-max="state.clipStrengthMax.value"
|
||||
:roll-mode="state.rollMode.value"
|
||||
:is-rolling="state.isRolling.value"
|
||||
:is-clip-strength-disabled="state.isClipStrengthDisabled.value"
|
||||
@update:count-mode="state.countMode.value = $event"
|
||||
@update:count-fixed="state.countFixed.value = $event"
|
||||
@update:count-min="state.countMin.value = $event"
|
||||
@update:count-max="state.countMax.value = $event"
|
||||
@update:model-strength-min="state.modelStrengthMin.value = $event"
|
||||
@update:model-strength-max="state.modelStrengthMax.value = $event"
|
||||
@update:use-same-clip-strength="state.useSameClipStrength.value = $event"
|
||||
@update:clip-strength-min="state.clipStrengthMin.value = $event"
|
||||
@update:clip-strength-max="state.clipStrengthMax.value = $event"
|
||||
@update:roll-mode="state.rollMode.value = $event"
|
||||
@roll="handleRoll"
|
||||
/>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import { onMounted } from 'vue'
|
||||
import LoraRandomizerSettingsView from './lora-randomizer/LoraRandomizerSettingsView.vue'
|
||||
import { useLoraRandomizerState } from '../composables/useLoraRandomizerState'
|
||||
import type { ComponentWidget, RandomizerConfig } from '../composables/types'
|
||||
|
||||
// Props
|
||||
const props = defineProps<{
|
||||
widget: ComponentWidget
|
||||
node: { id: number }
|
||||
}>()
|
||||
|
||||
// State management
|
||||
const state = useLoraRandomizerState(props.widget)
|
||||
|
||||
// Handle roll button click
|
||||
const handleRoll = async () => {
|
||||
try {
|
||||
console.log('[LoraRandomizerWidget] Roll button clicked')
|
||||
|
||||
// Get pool config from connected input (if any)
|
||||
// This would need to be passed from the node's pool_config input
|
||||
const poolConfig = null // TODO: Get from node input if connected
|
||||
|
||||
// Get locked loras from the loras widget
|
||||
// This would need to be retrieved from the loras widget on the node
|
||||
const lockedLoras: any[] = [] // TODO: Get from loras widget
|
||||
|
||||
// Call API to get random loras
|
||||
const randomLoras = await state.rollLoras(poolConfig, lockedLoras)
|
||||
|
||||
console.log('[LoraRandomizerWidget] Got random LoRAs:', randomLoras)
|
||||
|
||||
// Update the loras widget with the new selection
|
||||
// This will be handled by emitting an event or directly updating the loras widget
|
||||
// For now, we'll emit a custom event that the parent widget handler can catch
|
||||
if (typeof (props.widget as any).onRoll === 'function') {
|
||||
(props.widget as any).onRoll(randomLoras)
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('[LoraRandomizerWidget] Error rolling LoRAs:', error)
|
||||
alert('Failed to roll LoRAs: ' + (error as Error).message)
|
||||
}
|
||||
}
|
||||
|
||||
// Lifecycle
|
||||
onMounted(async () => {
|
||||
// Setup serialization
|
||||
props.widget.serializeValue = async () => {
|
||||
const config = state.buildConfig()
|
||||
console.log('[LoraRandomizerWidget] Serializing config:', config)
|
||||
return config
|
||||
}
|
||||
|
||||
// Handle external value updates (e.g., loading workflow, paste)
|
||||
props.widget.onSetValue = (v) => {
|
||||
console.log('[LoraRandomizerWidget] Restoring from config:', v)
|
||||
state.restoreFromConfig(v as RandomizerConfig)
|
||||
}
|
||||
|
||||
// Restore from saved value
|
||||
if (props.widget.value) {
|
||||
console.log('[LoraRandomizerWidget] Restoring from saved value:', props.widget.value)
|
||||
state.restoreFromConfig(props.widget.value as RandomizerConfig)
|
||||
}
|
||||
})
|
||||
</script>
|
||||
|
||||
<style scoped>
|
||||
.lora-randomizer-widget {
|
||||
padding: 12px;
|
||||
background: rgba(40, 44, 52, 0.6);
|
||||
border-radius: 4px;
|
||||
height: 100%;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
overflow: hidden;
|
||||
box-sizing: border-box;
|
||||
}
|
||||
</style>
|
||||
@@ -0,0 +1,359 @@
|
||||
<template>
|
||||
<div class="randomizer-settings">
|
||||
<div class="settings-header">
|
||||
<h3 class="settings-title">RANDOMIZER SETTINGS</h3>
|
||||
</div>
|
||||
|
||||
<!-- LoRA Count -->
|
||||
<div class="setting-section">
|
||||
<label class="setting-label">LoRA Count</label>
|
||||
<div class="count-mode-selector">
|
||||
<label class="radio-label">
|
||||
<input
|
||||
type="radio"
|
||||
name="count-mode"
|
||||
value="fixed"
|
||||
:checked="countMode === 'fixed'"
|
||||
@change="$emit('update:countMode', 'fixed')"
|
||||
/>
|
||||
<span>Fixed:</span>
|
||||
<input
|
||||
type="number"
|
||||
class="number-input"
|
||||
:value="countFixed"
|
||||
:disabled="countMode !== 'fixed'"
|
||||
min="1"
|
||||
max="100"
|
||||
@input="$emit('update:countFixed', parseInt(($event.target as HTMLInputElement).value))"
|
||||
/>
|
||||
</label>
|
||||
</div>
|
||||
<div class="count-mode-selector">
|
||||
<label class="radio-label">
|
||||
<input
|
||||
type="radio"
|
||||
name="count-mode"
|
||||
value="range"
|
||||
:checked="countMode === 'range'"
|
||||
@change="$emit('update:countMode', 'range')"
|
||||
/>
|
||||
<span>Range:</span>
|
||||
<input
|
||||
type="number"
|
||||
class="number-input"
|
||||
:value="countMin"
|
||||
:disabled="countMode !== 'range'"
|
||||
min="1"
|
||||
max="100"
|
||||
@input="$emit('update:countMin', parseInt(($event.target as HTMLInputElement).value))"
|
||||
/>
|
||||
<span>to</span>
|
||||
<input
|
||||
type="number"
|
||||
class="number-input"
|
||||
:value="countMax"
|
||||
:disabled="countMode !== 'range'"
|
||||
min="1"
|
||||
max="100"
|
||||
@input="$emit('update:countMax', parseInt(($event.target as HTMLInputElement).value))"
|
||||
/>
|
||||
</label>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Model Strength Range -->
|
||||
<div class="setting-section">
|
||||
<label class="setting-label">Model Strength Range</label>
|
||||
<div class="strength-inputs">
|
||||
<div class="strength-input-group">
|
||||
<label>Min:</label>
|
||||
<input
|
||||
type="number"
|
||||
class="number-input"
|
||||
:value="modelStrengthMin"
|
||||
min="0"
|
||||
max="10"
|
||||
step="0.1"
|
||||
@input="$emit('update:modelStrengthMin', parseFloat(($event.target as HTMLInputElement).value))"
|
||||
/>
|
||||
</div>
|
||||
<div class="strength-input-group">
|
||||
<label>Max:</label>
|
||||
<input
|
||||
type="number"
|
||||
class="number-input"
|
||||
:value="modelStrengthMax"
|
||||
min="0"
|
||||
max="10"
|
||||
step="0.1"
|
||||
@input="$emit('update:modelStrengthMax', parseFloat(($event.target as HTMLInputElement).value))"
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Clip Strength Range -->
|
||||
<div class="setting-section">
|
||||
<label class="setting-label">Clip Strength Range</label>
|
||||
<div class="checkbox-group">
|
||||
<label class="checkbox-label">
|
||||
<input
|
||||
type="checkbox"
|
||||
:checked="useSameClipStrength"
|
||||
@change="$emit('update:useSameClipStrength', ($event.target as HTMLInputElement).checked)"
|
||||
/>
|
||||
<span>Same as model</span>
|
||||
</label>
|
||||
</div>
|
||||
<div class="strength-inputs" :class="{ disabled: isClipStrengthDisabled }">
|
||||
<div class="strength-input-group">
|
||||
<label>Min:</label>
|
||||
<input
|
||||
type="number"
|
||||
class="number-input"
|
||||
:value="clipStrengthMin"
|
||||
:disabled="isClipStrengthDisabled"
|
||||
min="0"
|
||||
max="10"
|
||||
step="0.1"
|
||||
@input="$emit('update:clipStrengthMin', parseFloat(($event.target as HTMLInputElement).value))"
|
||||
/>
|
||||
</div>
|
||||
<div class="strength-input-group">
|
||||
<label>Max:</label>
|
||||
<input
|
||||
type="number"
|
||||
class="number-input"
|
||||
:value="clipStrengthMax"
|
||||
:disabled="isClipStrengthDisabled"
|
||||
min="0"
|
||||
max="10"
|
||||
step="0.1"
|
||||
@input="$emit('update:clipStrengthMax', parseFloat(($event.target as HTMLInputElement).value))"
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Roll Mode -->
|
||||
<div class="setting-section">
|
||||
<label class="setting-label">Roll Mode</label>
|
||||
<div class="roll-mode-selector">
|
||||
<label class="radio-label">
|
||||
<input
|
||||
type="radio"
|
||||
name="roll-mode"
|
||||
value="frontend"
|
||||
:checked="rollMode === 'frontend'"
|
||||
@change="$emit('update:rollMode', 'frontend')"
|
||||
/>
|
||||
<span>Frontend Roll (fixed until re-rolled)</span>
|
||||
</label>
|
||||
<button
|
||||
class="roll-button"
|
||||
:disabled="rollMode !== 'frontend' || isRolling"
|
||||
@click="$emit('roll')"
|
||||
>
|
||||
<span v-if="!isRolling">🎲 Roll</span>
|
||||
<span v-else>Rolling...</span>
|
||||
</button>
|
||||
</div>
|
||||
<div class="roll-mode-selector">
|
||||
<label class="radio-label">
|
||||
<input
|
||||
type="radio"
|
||||
name="roll-mode"
|
||||
value="backend"
|
||||
:checked="rollMode === 'backend'"
|
||||
@change="$emit('update:rollMode', 'backend')"
|
||||
/>
|
||||
<span>Backend Roll (randomizes each execution)</span>
|
||||
</label>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
defineProps<{
|
||||
countMode: 'fixed' | 'range'
|
||||
countFixed: number
|
||||
countMin: number
|
||||
countMax: number
|
||||
modelStrengthMin: number
|
||||
modelStrengthMax: number
|
||||
useSameClipStrength: boolean
|
||||
clipStrengthMin: number
|
||||
clipStrengthMax: number
|
||||
rollMode: 'frontend' | 'backend'
|
||||
isRolling: boolean
|
||||
isClipStrengthDisabled: boolean
|
||||
}>()
|
||||
|
||||
defineEmits<{
|
||||
'update:countMode': [value: 'fixed' | 'range']
|
||||
'update:countFixed': [value: number]
|
||||
'update:countMin': [value: number]
|
||||
'update:countMax': [value: number]
|
||||
'update:modelStrengthMin': [value: number]
|
||||
'update:modelStrengthMax': [value: number]
|
||||
'update:useSameClipStrength': [value: boolean]
|
||||
'update:clipStrengthMin': [value: number]
|
||||
'update:clipStrengthMax': [value: number]
|
||||
'update:rollMode': [value: 'frontend' | 'backend']
|
||||
roll: []
|
||||
}>()
|
||||
</script>
|
||||
|
||||
<style scoped>
|
||||
.randomizer-settings {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 16px;
|
||||
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
|
||||
color: #e4e4e7;
|
||||
}
|
||||
|
||||
.settings-header {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: space-between;
|
||||
margin-bottom: 8px;
|
||||
}
|
||||
|
||||
.settings-title {
|
||||
font-size: 11px;
|
||||
font-weight: 600;
|
||||
letter-spacing: 0.05em;
|
||||
color: #a1a1aa;
|
||||
margin: 0;
|
||||
text-transform: uppercase;
|
||||
}
|
||||
|
||||
.setting-section {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 8px;
|
||||
}
|
||||
|
||||
.setting-label {
|
||||
font-size: 12px;
|
||||
font-weight: 500;
|
||||
color: #d4d4d8;
|
||||
}
|
||||
|
||||
.count-mode-selector,
|
||||
.roll-mode-selector {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 8px;
|
||||
padding: 6px 8px;
|
||||
background: rgba(30, 30, 36, 0.5);
|
||||
border-radius: 4px;
|
||||
}
|
||||
|
||||
.radio-label {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 8px;
|
||||
font-size: 13px;
|
||||
color: #e4e4e7;
|
||||
cursor: pointer;
|
||||
flex: 1;
|
||||
}
|
||||
|
||||
.radio-label input[type='radio'] {
|
||||
cursor: pointer;
|
||||
}
|
||||
|
||||
.radio-label input[type='radio']:disabled {
|
||||
cursor: not-allowed;
|
||||
}
|
||||
|
||||
.number-input {
|
||||
width: 60px;
|
||||
padding: 4px 8px;
|
||||
background: rgba(20, 20, 24, 0.6);
|
||||
border: 1px solid rgba(255, 255, 255, 0.1);
|
||||
border-radius: 3px;
|
||||
color: #e4e4e7;
|
||||
font-size: 13px;
|
||||
}
|
||||
|
||||
.number-input:disabled {
|
||||
opacity: 0.5;
|
||||
cursor: not-allowed;
|
||||
}
|
||||
|
||||
.strength-inputs {
|
||||
display: flex;
|
||||
gap: 12px;
|
||||
padding: 6px 8px;
|
||||
background: rgba(30, 30, 36, 0.5);
|
||||
border-radius: 4px;
|
||||
}
|
||||
|
||||
.strength-inputs.disabled {
|
||||
opacity: 0.5;
|
||||
}
|
||||
|
||||
.strength-input-group {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 6px;
|
||||
flex: 1;
|
||||
}
|
||||
|
||||
.strength-input-group label {
|
||||
font-size: 12px;
|
||||
color: #d4d4d8;
|
||||
}
|
||||
|
||||
.checkbox-group {
|
||||
padding: 6px 8px;
|
||||
background: rgba(30, 30, 36, 0.5);
|
||||
border-radius: 4px;
|
||||
}
|
||||
|
||||
.checkbox-label {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 8px;
|
||||
font-size: 13px;
|
||||
color: #e4e4e7;
|
||||
cursor: pointer;
|
||||
}
|
||||
|
||||
.checkbox-label input[type='checkbox'] {
|
||||
cursor: pointer;
|
||||
}
|
||||
|
||||
.roll-button {
|
||||
padding: 6px 16px;
|
||||
background: linear-gradient(135deg, #3b82f6 0%, #2563eb 100%);
|
||||
border: none;
|
||||
border-radius: 4px;
|
||||
color: white;
|
||||
font-size: 13px;
|
||||
font-weight: 500;
|
||||
cursor: pointer;
|
||||
transition: all 0.2s;
|
||||
white-space: nowrap;
|
||||
}
|
||||
|
||||
.roll-button:hover:not(:disabled) {
|
||||
background: linear-gradient(135deg, #2563eb 0%, #1d4ed8 100%);
|
||||
transform: translateY(-1px);
|
||||
box-shadow: 0 2px 8px rgba(59, 130, 246, 0.3);
|
||||
}
|
||||
|
||||
.roll-button:active:not(:disabled) {
|
||||
transform: translateY(0);
|
||||
}
|
||||
|
||||
.roll-button:disabled {
|
||||
opacity: 0.5;
|
||||
cursor: not-allowed;
|
||||
background: linear-gradient(135deg, #52525b 0%, #3f3f46 100%);
|
||||
}
|
||||
</style>
|
||||
@@ -37,13 +37,6 @@ export interface FolderTreeNode {
|
||||
children?: FolderTreeNode[]
|
||||
}
|
||||
|
||||
export interface ComponentWidget {
|
||||
serializeValue?: () => Promise<LoraPoolConfig>
|
||||
value?: LoraPoolConfig | LegacyLoraPoolConfig
|
||||
onSetValue?: (v: LoraPoolConfig | LegacyLoraPoolConfig) => void
|
||||
updateConfig?: (v: LoraPoolConfig) => void
|
||||
}
|
||||
|
||||
// Legacy config for migration (v1)
|
||||
export interface LegacyLoraPoolConfig {
|
||||
version: 1
|
||||
@@ -59,3 +52,33 @@ export interface LegacyLoraPoolConfig {
|
||||
}
|
||||
preview: { matchCount: number; lastUpdated: number }
|
||||
}
|
||||
|
||||
// Randomizer config
|
||||
export interface RandomizerConfig {
|
||||
count_mode: 'fixed' | 'range'
|
||||
count_fixed: number
|
||||
count_min: number
|
||||
count_max: number
|
||||
model_strength_min: number
|
||||
model_strength_max: number
|
||||
use_same_clip_strength: boolean
|
||||
clip_strength_min: number
|
||||
clip_strength_max: number
|
||||
roll_mode: 'frontend' | 'backend'
|
||||
}
|
||||
|
||||
export interface LoraEntry {
|
||||
name: string
|
||||
strength: number
|
||||
clipStrength: number
|
||||
active: boolean
|
||||
expanded: boolean
|
||||
locked: boolean
|
||||
}
|
||||
|
||||
export interface ComponentWidget {
|
||||
serializeValue?: () => Promise<LoraPoolConfig | RandomizerConfig>
|
||||
value?: LoraPoolConfig | LegacyLoraPoolConfig | RandomizerConfig
|
||||
onSetValue?: (v: LoraPoolConfig | LegacyLoraPoolConfig | RandomizerConfig) => void
|
||||
updateConfig?: (v: LoraPoolConfig | RandomizerConfig) => void
|
||||
}
|
||||
|
||||
142
vue-widgets/src/composables/useLoraRandomizerState.ts
Normal file
142
vue-widgets/src/composables/useLoraRandomizerState.ts
Normal file
@@ -0,0 +1,142 @@
|
||||
import { ref, computed } from 'vue'
|
||||
import type { ComponentWidget, RandomizerConfig, LoraEntry } from './types'
|
||||
|
||||
export function useLoraRandomizerState(widget: ComponentWidget) {
|
||||
// State refs
|
||||
const countMode = ref<'fixed' | 'range'>('range')
|
||||
const countFixed = ref(5)
|
||||
const countMin = ref(3)
|
||||
const countMax = ref(7)
|
||||
const modelStrengthMin = ref(0.0)
|
||||
const modelStrengthMax = ref(1.0)
|
||||
const useSameClipStrength = ref(true)
|
||||
const clipStrengthMin = ref(0.0)
|
||||
const clipStrengthMax = ref(1.0)
|
||||
const rollMode = ref<'frontend' | 'backend'>('frontend')
|
||||
const isRolling = ref(false)
|
||||
|
||||
// Build config object from current state
|
||||
const buildConfig = (): RandomizerConfig => ({
|
||||
count_mode: countMode.value,
|
||||
count_fixed: countFixed.value,
|
||||
count_min: countMin.value,
|
||||
count_max: countMax.value,
|
||||
model_strength_min: modelStrengthMin.value,
|
||||
model_strength_max: modelStrengthMax.value,
|
||||
use_same_clip_strength: useSameClipStrength.value,
|
||||
clip_strength_min: clipStrengthMin.value,
|
||||
clip_strength_max: clipStrengthMax.value,
|
||||
roll_mode: rollMode.value,
|
||||
})
|
||||
|
||||
// Restore state from config object
|
||||
const restoreFromConfig = (config: RandomizerConfig) => {
|
||||
countMode.value = config.count_mode || 'range'
|
||||
countFixed.value = config.count_fixed || 5
|
||||
countMin.value = config.count_min || 3
|
||||
countMax.value = config.count_max || 7
|
||||
modelStrengthMin.value = config.model_strength_min ?? 0.0
|
||||
modelStrengthMax.value = config.model_strength_max ?? 1.0
|
||||
useSameClipStrength.value = config.use_same_clip_strength ?? true
|
||||
clipStrengthMin.value = config.clip_strength_min ?? 0.0
|
||||
clipStrengthMax.value = config.clip_strength_max ?? 1.0
|
||||
rollMode.value = config.roll_mode || 'frontend'
|
||||
}
|
||||
|
||||
// Roll loras - call API to get random selection
|
||||
const rollLoras = async (
|
||||
poolConfig: any | null,
|
||||
lockedLoras: LoraEntry[]
|
||||
): Promise<LoraEntry[]> => {
|
||||
try {
|
||||
isRolling.value = true
|
||||
|
||||
const config = buildConfig()
|
||||
|
||||
// Build request body
|
||||
const requestBody: any = {
|
||||
model_strength_min: config.model_strength_min,
|
||||
model_strength_max: config.model_strength_max,
|
||||
use_same_clip_strength: config.use_same_clip_strength,
|
||||
clip_strength_min: config.clip_strength_min,
|
||||
clip_strength_max: config.clip_strength_max,
|
||||
locked_loras: lockedLoras,
|
||||
}
|
||||
|
||||
// Add count parameters
|
||||
if (config.count_mode === 'fixed') {
|
||||
requestBody.count = config.count_fixed
|
||||
} else {
|
||||
requestBody.count_min = config.count_min
|
||||
requestBody.count_max = config.count_max
|
||||
}
|
||||
|
||||
// Add pool config if provided
|
||||
if (poolConfig) {
|
||||
// Convert pool config to backend format
|
||||
requestBody.pool_config = {
|
||||
selected_base_models: poolConfig.filters?.baseModels || [],
|
||||
include_tags: poolConfig.filters?.tags?.include || [],
|
||||
exclude_tags: poolConfig.filters?.tags?.exclude || [],
|
||||
include_folders: poolConfig.filters?.folders?.include || [],
|
||||
exclude_folders: poolConfig.filters?.folders?.exclude || [],
|
||||
no_credit_required: poolConfig.filters?.license?.noCreditRequired || false,
|
||||
allow_selling: poolConfig.filters?.license?.allowSelling || false,
|
||||
}
|
||||
}
|
||||
|
||||
// Call API endpoint
|
||||
const response = await fetch('/api/lm/loras/random-sample', {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
},
|
||||
body: JSON.stringify(requestBody),
|
||||
})
|
||||
|
||||
if (!response.ok) {
|
||||
const errorData = await response.json()
|
||||
throw new Error(errorData.error || 'Failed to fetch random LoRAs')
|
||||
}
|
||||
|
||||
const data = await response.json()
|
||||
|
||||
if (!data.success) {
|
||||
throw new Error(data.error || 'Failed to get random LoRAs')
|
||||
}
|
||||
|
||||
return data.loras || []
|
||||
} catch (error) {
|
||||
console.error('[LoraRandomizerState] Error rolling LoRAs:', error)
|
||||
throw error
|
||||
} finally {
|
||||
isRolling.value = false
|
||||
}
|
||||
}
|
||||
|
||||
// Computed properties
|
||||
const isClipStrengthDisabled = computed(() => useSameClipStrength.value)
|
||||
|
||||
return {
|
||||
// State refs
|
||||
countMode,
|
||||
countFixed,
|
||||
countMin,
|
||||
countMax,
|
||||
modelStrengthMin,
|
||||
modelStrengthMax,
|
||||
useSameClipStrength,
|
||||
clipStrengthMin,
|
||||
clipStrengthMax,
|
||||
rollMode,
|
||||
isRolling,
|
||||
|
||||
// Computed
|
||||
isClipStrengthDisabled,
|
||||
|
||||
// Methods
|
||||
buildConfig,
|
||||
restoreFromConfig,
|
||||
rollLoras,
|
||||
}
|
||||
}
|
||||
@@ -1,13 +1,17 @@
|
||||
import { createApp, type App as VueApp } from 'vue'
|
||||
import PrimeVue from 'primevue/config'
|
||||
import LoraPoolWidget from '@/components/LoraPoolWidget.vue'
|
||||
import type { LoraPoolConfig, LegacyLoraPoolConfig } from './composables/types'
|
||||
import LoraRandomizerWidget from '@/components/LoraRandomizerWidget.vue'
|
||||
import type { LoraPoolConfig, LegacyLoraPoolConfig, RandomizerConfig } from './composables/types'
|
||||
|
||||
// @ts-ignore - ComfyUI external module
|
||||
import { app } from '../../../scripts/app.js'
|
||||
|
||||
const vueApps = new Map<number, VueApp>()
|
||||
|
||||
// Cache for dynamically loaded addLorasWidget module
|
||||
let addLorasWidgetCache: any = null
|
||||
|
||||
// @ts-ignore
|
||||
function createLoraPoolWidget(node) {
|
||||
const container = document.createElement('div')
|
||||
@@ -78,14 +82,109 @@ function createLoraPoolWidget(node) {
|
||||
return { widget }
|
||||
}
|
||||
|
||||
// @ts-ignore
|
||||
function createLoraRandomizerWidget(node) {
|
||||
const container = document.createElement('div')
|
||||
container.id = `lora-randomizer-widget-${node.id}`
|
||||
container.style.width = '100%'
|
||||
container.style.height = '100%'
|
||||
container.style.display = 'flex'
|
||||
container.style.flexDirection = 'column'
|
||||
container.style.overflow = 'hidden'
|
||||
|
||||
let internalValue: RandomizerConfig | undefined
|
||||
|
||||
const widget = node.addDOMWidget(
|
||||
'randomizer_config',
|
||||
'RANDOMIZER_CONFIG',
|
||||
container,
|
||||
{
|
||||
getValue() {
|
||||
return internalValue
|
||||
},
|
||||
setValue(v: RandomizerConfig) {
|
||||
internalValue = v
|
||||
if (typeof widget.onSetValue === 'function') {
|
||||
widget.onSetValue(v)
|
||||
}
|
||||
},
|
||||
serialize: true,
|
||||
getMinHeight() {
|
||||
return 500
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
widget.updateConfig = (v: RandomizerConfig) => {
|
||||
internalValue = v
|
||||
}
|
||||
|
||||
// Handle roll event from Vue component
|
||||
widget.onRoll = (randomLoras: any[]) => {
|
||||
console.log('[createLoraRandomizerWidget] Roll event received:', randomLoras)
|
||||
|
||||
// Find the loras widget on this node and update it
|
||||
const lorasWidget = node.widgets.find((w: any) => w.name === 'loras')
|
||||
if (lorasWidget) {
|
||||
lorasWidget.value = randomLoras
|
||||
console.log('[createLoraRandomizerWidget] Updated loras widget with rolled LoRAs')
|
||||
} else {
|
||||
console.warn('[createLoraRandomizerWidget] loras widget not found on node')
|
||||
}
|
||||
}
|
||||
|
||||
const vueApp = createApp(LoraRandomizerWidget, {
|
||||
widget,
|
||||
node
|
||||
})
|
||||
|
||||
vueApp.use(PrimeVue, {
|
||||
unstyled: true,
|
||||
ripple: false
|
||||
})
|
||||
|
||||
vueApp.mount(container)
|
||||
vueApps.set(node.id + 10000, vueApp) // Offset to avoid collision with pool widget
|
||||
|
||||
widget.computeLayoutSize = () => {
|
||||
const minWidth = 500
|
||||
const minHeight = 500
|
||||
const maxHeight = 500
|
||||
|
||||
return { minHeight, minWidth, maxHeight }
|
||||
}
|
||||
|
||||
widget.onRemove = () => {
|
||||
const vueApp = vueApps.get(node.id + 10000)
|
||||
if (vueApp) {
|
||||
vueApp.unmount()
|
||||
vueApps.delete(node.id + 10000)
|
||||
}
|
||||
}
|
||||
|
||||
return { widget }
|
||||
}
|
||||
|
||||
app.registerExtension({
|
||||
name: 'LoraManager.VueWidgets',
|
||||
|
||||
getCustomWidgets() {
|
||||
getCustomWidgets() {
|
||||
return {
|
||||
// @ts-ignore
|
||||
LORA_POOL_CONFIG(node) {
|
||||
return createLoraPoolWidget(node)
|
||||
},
|
||||
// @ts-ignore
|
||||
RANDOMIZER_CONFIG(node) {
|
||||
return createLoraRandomizerWidget(node)
|
||||
},
|
||||
// @ts-ignore
|
||||
async LORAS(node: any) {
|
||||
if (!addLorasWidgetCache) {
|
||||
const module = await import(/* @vite-ignore */ '../loras_widget.js')
|
||||
addLorasWidgetCache = module.addLorasWidget
|
||||
}
|
||||
return addLorasWidgetCache(node, 'loras', {}, null)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -21,7 +21,8 @@ export default defineConfig({
|
||||
},
|
||||
rollupOptions: {
|
||||
external: [
|
||||
'../../../scripts/app.js'
|
||||
'../../../scripts/app.js',
|
||||
'../loras_widget.js'
|
||||
],
|
||||
output: {
|
||||
dir: '../web/comfyui/vue-widgets',
|
||||
|
||||
40
web/comfyui/lora_demo_widget.js
Normal file
40
web/comfyui/lora_demo_widget.js
Normal file
@@ -0,0 +1,40 @@
|
||||
import { app } from "../../../scripts/app.js";
|
||||
|
||||
app.registerExtension({
|
||||
name: "LoraManager.LoraDemo",
|
||||
|
||||
// Hook into node creation
|
||||
async nodeCreated(node) {
|
||||
if (node.comfyClass !== "Lora Demo (LoraManager)") {
|
||||
return;
|
||||
}
|
||||
|
||||
// Store original onExecuted
|
||||
const originalOnExecuted = node.onExecuted?.bind(node);
|
||||
|
||||
// Override onExecuted to handle UI updates
|
||||
node.onExecuted = function(output) {
|
||||
// Check if output has loras data
|
||||
if (output?.loras && Array.isArray(output.loras)) {
|
||||
console.log("[LoraDemoNode] Received loras data from backend:", output.loras);
|
||||
|
||||
// Find the loras widget on this node
|
||||
const lorasWidget = node.widgets.find(w => w.name === 'loras');
|
||||
|
||||
if (lorasWidget) {
|
||||
// Update widget value with backend data
|
||||
lorasWidget.value = output.loras;
|
||||
|
||||
console.log(`[LoraDemoNode] Updated widget with ${output.loras.length} loras`);
|
||||
} else {
|
||||
console.warn("[LoraDemoNode] loras widget not found on node");
|
||||
}
|
||||
}
|
||||
|
||||
// Call original onExecuted if it exists
|
||||
if (originalOnExecuted) {
|
||||
return originalOnExecuted(output);
|
||||
}
|
||||
};
|
||||
}
|
||||
});
|
||||
44
web/comfyui/lora_randomizer_widget.js
Normal file
44
web/comfyui/lora_randomizer_widget.js
Normal file
@@ -0,0 +1,44 @@
|
||||
import { app } from "../../../scripts/app.js";
|
||||
|
||||
app.registerExtension({
|
||||
name: "LoraManager.LoraRandomizer",
|
||||
|
||||
// Hook into node creation
|
||||
async nodeCreated(node) {
|
||||
if (node.comfyClass !== "Lora Randomizer (LoraManager)") {
|
||||
return;
|
||||
}
|
||||
|
||||
console.log("[LoraRandomizerWidget] Node created:", node.id);
|
||||
|
||||
// Store original onExecuted
|
||||
const originalOnExecuted = node.onExecuted?.bind(node);
|
||||
|
||||
// Override onExecuted to handle UI updates
|
||||
node.onExecuted = function(output) {
|
||||
console.log("[LoraRandomizerWidget] Node executed with output:", output);
|
||||
|
||||
// Check if output has loras data
|
||||
if (output?.loras && Array.isArray(output.loras)) {
|
||||
console.log("[LoraRandomizerWidget] Received loras data from backend:", output.loras);
|
||||
|
||||
// Find the loras widget on this node
|
||||
const lorasWidget = node.widgets.find(w => w.name === 'loras');
|
||||
|
||||
if (lorasWidget) {
|
||||
// Update widget value with backend data
|
||||
lorasWidget.value = output.loras;
|
||||
|
||||
console.log(`[LoraRandomizerWidget] Updated widget with ${output.loras.length} loras`);
|
||||
} else {
|
||||
console.warn("[LoraRandomizerWidget] loras widget not found on node");
|
||||
}
|
||||
}
|
||||
|
||||
// Call original onExecuted if it exists
|
||||
if (originalOnExecuted) {
|
||||
return originalOnExecuted(output);
|
||||
}
|
||||
};
|
||||
}
|
||||
});
|
||||
@@ -97,27 +97,27 @@ app.registerExtension({
|
||||
if (isUpdating) return;
|
||||
isUpdating = true;
|
||||
|
||||
try {
|
||||
// Update this stacker's direct trigger toggles with its own active loras
|
||||
// Only if the stacker node itself is active (mode 0 for Always, mode 3 for On Trigger)
|
||||
const isNodeActive = this.mode === undefined || this.mode === 0 || this.mode === 3;
|
||||
const activeLoraNames = new Set();
|
||||
if (isNodeActive) {
|
||||
value.forEach((lora) => {
|
||||
if (lora.active) {
|
||||
activeLoraNames.add(lora.name);
|
||||
}
|
||||
});
|
||||
try {
|
||||
// Update this stacker's direct trigger toggles with its own active loras
|
||||
// Only if the stacker node itself is active (mode 0 for Always, mode 3 for On Trigger)
|
||||
const isNodeActive = this.mode === undefined || this.mode === 0 || this.mode === 3;
|
||||
const activeLoraNames = new Set();
|
||||
if (isNodeActive) {
|
||||
value.forEach((lora) => {
|
||||
if (lora.active) {
|
||||
activeLoraNames.add(lora.name);
|
||||
}
|
||||
});
|
||||
}
|
||||
updateConnectedTriggerWords(this, activeLoraNames);
|
||||
|
||||
// Find all Lora Loader nodes in the chain that might need updates
|
||||
updateDownstreamLoaders(this);
|
||||
} finally {
|
||||
isUpdating = false;
|
||||
}
|
||||
updateConnectedTriggerWords(this, activeLoraNames);
|
||||
|
||||
// Find all Lora Loader nodes in the chain that might need updates
|
||||
updateDownstreamLoaders(this);
|
||||
} finally {
|
||||
isUpdating = false;
|
||||
}
|
||||
|
||||
scheduleInputSync(value);
|
||||
scheduleInputSync(value);
|
||||
});
|
||||
|
||||
this.lorasWidget = result.widget;
|
||||
|
||||
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
Reference in New Issue
Block a user