checkpoint

This commit is contained in:
Will Miao
2025-03-11 19:29:31 +08:00
parent 5a6c412845
commit ad56cafd62
7 changed files with 255 additions and 16 deletions

View File

@@ -8,7 +8,7 @@ from .utils import FlexibleOptionalInputType, any_type
class LoraManagerLoader:
NAME = "Lora Loader (LoraManager)"
CATEGORY = "loaders"
CATEGORY = "Lora Manager/loaders"
@classmethod
def INPUT_TYPES(cls):
@@ -23,7 +23,10 @@ class LoraManagerLoader:
"placeholder": "LoRA syntax input: <lora:name:strength>"
}),
},
"optional": FlexibleOptionalInputType(any_type),
"optional": {
**FlexibleOptionalInputType(any_type),
"lora_stack": ("LORA_STACK", {"default": None}),
}
}
RETURN_TYPES = ("MODEL", "CLIP", IO.STRING)
@@ -49,11 +52,32 @@ class LoraManagerLoader:
return relative_path, trigger_words
return lora_name, [] # Fallback if not found
def load_loras(self, model, clip, text, **kwargs):
"""Loads multiple LoRAs based on the kwargs input."""
def extract_lora_name(self, lora_path):
"""Extract the lora name from a lora path (e.g., 'IL\\aorunIllstrious.safetensors' -> 'aorunIllstrious')"""
# Get the basename without extension
basename = os.path.basename(lora_path)
return os.path.splitext(basename)[0]
def load_loras(self, model, clip, text, lora_stack=None, **kwargs):
print("load_loras kwargs: ", kwargs)
"""Loads multiple LoRAs based on the kwargs input and lora_stack."""
loaded_loras = []
all_trigger_words = []
# First process lora_stack if available
if lora_stack:
for lora_path, model_strength, clip_strength in lora_stack:
# Apply the LoRA using the provided path and strengths
model, clip = LoraLoader().load_lora(model, clip, lora_path, model_strength, clip_strength)
# Extract lora name for trigger words lookup
lora_name = self.extract_lora_name(lora_path)
_, trigger_words = asyncio.run(self.get_lora_info(lora_name))
all_trigger_words.extend(trigger_words)
loaded_loras.append(f"{lora_name}: {model_strength}")
# Then process loras from kwargs
if 'loras' in kwargs:
for lora in kwargs['loras']:
if not lora.get('active', False):

94
py/nodes/lora_stacker.py Normal file
View File

@@ -0,0 +1,94 @@
from comfy.comfy_types import IO # type: ignore
from ..services.lora_scanner import LoraScanner
from ..config import config
import asyncio
import os
from .utils import FlexibleOptionalInputType, any_type
class LoraStacker:
NAME = "Lora Stacker (LoraManager)"
CATEGORY = "Lora Manager/stackers"
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"text": (IO.STRING, {
"multiline": True,
"dynamicPrompts": True,
"tooltip": "Format: <lora:lora_name:strength> separated by spaces or punctuation",
"placeholder": "LoRA syntax input: <lora:name:strength>"
}),
},
"optional": {
**FlexibleOptionalInputType(any_type),
"lora_stack": ("LORA_STACK", {"default": None}),
}
}
RETURN_TYPES = ("LORA_STACK", IO.STRING)
RETURN_NAMES = ("LORA_STACK", "trigger_words")
FUNCTION = "stack_loras"
async def get_lora_info(self, lora_name):
"""Get the lora path and trigger words from cache"""
scanner = await LoraScanner.get_instance()
cache = await scanner.get_cached_data()
for item in cache.raw_data:
if item.get('file_name') == lora_name:
file_path = item.get('file_path')
if file_path:
for root in config.loras_roots:
root = root.replace(os.sep, '/')
if file_path.startswith(root):
relative_path = os.path.relpath(file_path, root).replace(os.sep, '/')
# Get trigger words from civitai metadata
civitai = item.get('civitai', {})
trigger_words = civitai.get('trainedWords', []) if civitai else []
return relative_path, trigger_words
return lora_name, [] # Fallback if not found
def extract_lora_name(self, lora_path):
"""Extract the lora name from a lora path (e.g., 'IL\\aorunIllstrious.safetensors' -> 'aorunIllstrious')"""
# Get the basename without extension
basename = os.path.basename(lora_path)
return os.path.splitext(basename)[0]
def stack_loras(self, text, lora_stack=None, **kwargs):
print("stack_loras kwargs: ", kwargs)
"""Stacks multiple LoRAs based on the kwargs input without loading them."""
stack = []
all_trigger_words = []
# Process existing lora_stack if available
if lora_stack:
stack.extend(lora_stack)
# Get trigger words from existing stack entries
for lora_path, _, _ in lora_stack:
lora_name = self.extract_lora_name(lora_path)
_, trigger_words = asyncio.run(self.get_lora_info(lora_name))
all_trigger_words.extend(trigger_words)
if 'loras' in kwargs:
for lora in kwargs['loras']:
if not lora.get('active', False):
continue
lora_name = lora['name']
model_strength = float(lora['strength'])
clip_strength = model_strength # Using same strength for both as in the original loader
# Get lora path and trigger words
lora_path, trigger_words = asyncio.run(self.get_lora_info(lora_name))
# Add to stack without loading
stack.append((lora_path, model_strength, clip_strength))
# Add trigger words to collection
all_trigger_words.extend(trigger_words)
# use ',, ' to separate trigger words for group mode
trigger_words_text = ",, ".join(all_trigger_words) if all_trigger_words else ""
return (stack, trigger_words_text)

View File

@@ -4,7 +4,7 @@ from .utils import FlexibleOptionalInputType, any_type
class TriggerWordToggle:
NAME = "TriggerWord Toggle (LoraManager)"
CATEGORY = "lora manager"
CATEGORY = "Lora Manager/utils"
DESCRIPTION = "Toggle trigger words on/off"
@classmethod
@@ -13,10 +13,7 @@ class TriggerWordToggle:
"required": {
"group_mode": ("BOOLEAN", {"default": True}),
},
"optional": {
**FlexibleOptionalInputType(any_type),
"trigger_words": ("STRING", {"default": "", "defaultInput": True}),
},
"optional": FlexibleOptionalInputType(any_type),
"hidden": {
"id": "UNIQUE_ID", # 会被 ComfyUI 自动替换为唯一ID
},
@@ -26,7 +23,9 @@ class TriggerWordToggle:
RETURN_NAMES = ("filtered_trigger_words",)
FUNCTION = "process_trigger_words"
def process_trigger_words(self, id, trigger_words="", **kwargs):
def process_trigger_words(self, id, **kwargs):
print("trigger_words ", kwargs)
trigger_words = kwargs.get("trigger_words", "")
# Send trigger words to frontend
PromptServer.instance.send_sync("trigger_word_update", {
"id": id,