feat(trigger): add optional strength adjustment for trigger words

Add `allow_strength_adjustment` parameter to enable mouse wheel adjustment of trigger word strengths. When enabled, strength values are preserved and can be modified interactively. Also improves trigger word parsing by handling whitespace more consistently and adding debug logging for trigger data inspection.
This commit is contained in:
Will Miao
2025-11-09 22:24:23 +08:00
parent f81ff2efe9
commit 4dd8ce778e
4 changed files with 265 additions and 117 deletions

View File

@@ -23,6 +23,10 @@ class TriggerWordToggle:
"default": True, "default": True,
"tooltip": "Sets the default initial state (active or inactive) when trigger words are added." "tooltip": "Sets the default initial state (active or inactive) when trigger words are added."
}), }),
"allow_strength_adjustment": ("BOOLEAN", {
"default": False,
"tooltip": "Enable mouse wheel adjustment of each trigger word's strength."
}),
}, },
"optional": FlexibleOptionalInputType(any_type), "optional": FlexibleOptionalInputType(any_type),
"hidden": { "hidden": {
@@ -47,7 +51,14 @@ class TriggerWordToggle:
else: else:
return data return data
def process_trigger_words(self, id, group_mode, default_active, **kwargs): def process_trigger_words(
self,
id,
group_mode,
default_active,
allow_strength_adjustment=False,
**kwargs,
):
# Handle both old and new formats for trigger_words # Handle both old and new formats for trigger_words
trigger_words_data = self._get_toggle_data(kwargs, 'orinalMessage') trigger_words_data = self._get_toggle_data(kwargs, 'orinalMessage')
trigger_words = trigger_words_data if isinstance(trigger_words_data, str) else "" trigger_words = trigger_words_data if isinstance(trigger_words_data, str) else ""
@@ -73,50 +84,60 @@ class TriggerWordToggle:
# Extract strength if it's in the format "(word:strength)" # Extract strength if it's in the format "(word:strength)"
strength_match = re.match(r'\((.+):([\d.]+)\)', text) strength_match = re.match(r'\((.+):([\d.]+)\)', text)
if strength_match: if strength_match:
original_word = strength_match.group(1) original_word = strength_match.group(1).strip()
strength = float(strength_match.group(2)) strength = float(strength_match.group(2))
active_state[original_word] = active active_state[original_word] = active
strength_map[original_word] = strength if allow_strength_adjustment:
strength_map[original_word] = strength
else: else:
active_state[text] = active active_state[text.strip()] = active
if group_mode: if group_mode:
# Split by two or more consecutive commas to get groups if isinstance(trigger_data, list):
groups = re.split(r',{2,}', trigger_words) filtered_groups = []
# Remove leading/trailing whitespace from each group for item in trigger_data:
groups = [group.strip() for group in groups] text = (item.get('text') or "").strip()
if not text:
continue
if item.get('active', False):
filtered_groups.append(text)
# Process groups: keep those not in toggle_trigger_words or those that are active if filtered_groups:
filtered_groups = [] filtered_triggers = ', '.join(filtered_groups)
for group in groups: else:
# Check if this group contains any words that are in the active_state filtered_triggers = ""
group_words = [word.strip() for word in group.split(',')]
active_group_words = []
for word in group_words:
# Remove any existing strength formatting for comparison
word_comparison = re.sub(r'\((.+):([\d.]+)\)', r'\1', word).strip()
if word_comparison not in active_state or active_state[word_comparison]:
# If this word has a strength value, use that instead of the original
if word_comparison in strength_map:
active_group_words.append(f"({word_comparison}:{strength_map[word_comparison]:.2f})")
else:
# Preserve existing strength formatting if the word was previously modified
# Check if the original word had strength formatting
strength_match = re.match(r'\((.+):([\d.]+)\)', word)
if strength_match:
active_group_words.append(word)
else:
active_group_words.append(word)
if active_group_words:
filtered_groups.append(', '.join(active_group_words))
if filtered_groups:
filtered_triggers = ', '.join(filtered_groups)
else: else:
filtered_triggers = "" # Split by two or more consecutive commas to get groups
groups = re.split(r',{2,}', trigger_words)
# Remove leading/trailing whitespace from each group
groups = [group.strip() for group in groups]
# Process groups: keep those not in toggle_trigger_words or those that are active
filtered_groups = []
for group in groups:
# Check if this group contains any words that are in the active_state
group_words = [word.strip() for word in group.split(',')]
active_group_words = []
for word in group_words:
word_comparison = re.sub(r'\((.+):([\d.]+)\)', r'\1', word).strip()
if word_comparison not in active_state or active_state[word_comparison]:
active_group_words.append(
self._format_word_output(
word_comparison,
strength_map,
allow_strength_adjustment,
)
)
if active_group_words:
filtered_groups.append(', '.join(active_group_words))
if filtered_groups:
filtered_triggers = ', '.join(filtered_groups)
else:
filtered_triggers = ""
else: else:
# Normal mode: split by commas and treat each word as a separate tag # Normal mode: split by commas and treat each word as a separate tag
original_words = [word.strip() for word in trigger_words.split(',')] original_words = [word.strip() for word in trigger_words.split(',')]
@@ -129,17 +150,13 @@ class TriggerWordToggle:
word_comparison = re.sub(r'\((.+):([\d.]+)\)', r'\1', word).strip() word_comparison = re.sub(r'\((.+):([\d.]+)\)', r'\1', word).strip()
if word_comparison not in active_state or active_state[word_comparison]: if word_comparison not in active_state or active_state[word_comparison]:
# If this word has a strength value, use that instead of the original filtered_words.append(
if word_comparison in strength_map: self._format_word_output(
filtered_words.append(f"({word_comparison}:{strength_map[word_comparison]:.2f})") word_comparison,
else: strength_map,
# Preserve existing strength formatting if the word was previously modified allow_strength_adjustment,
# Check if the original word had strength formatting )
strength_match = re.match(r'\((.+):([\d.]+)\)', word) )
if strength_match:
filtered_words.append(word)
else:
filtered_words.append(word)
if filtered_words: if filtered_words:
filtered_triggers = ', '.join(filtered_words) filtered_triggers = ', '.join(filtered_words)
@@ -150,3 +167,8 @@ class TriggerWordToggle:
logger.error(f"Error processing trigger words: {e}") logger.error(f"Error processing trigger words: {e}")
return (filtered_triggers,) return (filtered_triggers,)
def _format_word_output(self, base_word, strength_map, allow_strength_adjustment):
if allow_strength_adjustment and base_word in strength_map:
return f"({base_word}:{strength_map[base_word]:.2f})"
return base_word

View File

@@ -0,0 +1,26 @@
from py.nodes.trigger_word_toggle import TriggerWordToggle
def test_group_mode_preserves_parenthesized_groups():
node = TriggerWordToggle()
trigger_data = [
{'text': 'flat color, dark theme', 'active': True, 'strength': None, 'highlighted': False},
{'text': '(a, really, long, test, trigger, word:1.06)', 'active': True, 'strength': 1.06, 'highlighted': False},
{'text': '(sinozick style:0.94)', 'active': True, 'strength': 0.94, 'highlighted': False},
]
original_message = (
"flat color, dark theme, (a, really, long, test, trigger, word:1.06), "
"(sinozick style:0.94)"
)
filtered, = node.process_trigger_words(
id="node",
group_mode=True,
default_active=True,
allow_strength_adjustment=False,
orinalMessage=original_message,
toggle_trigger_words=trigger_data,
)
assert filtered == original_message

View File

@@ -1,10 +1,12 @@
import { forwardMiddleMouseToCanvas } from "./utils.js"; import { forwardMiddleMouseToCanvas } from "./utils.js";
export function addTagsWidget(node, name, opts, callback, wheelSensitivity = 0.02) { export function addTagsWidget(node, name, opts, callback, wheelSensitivity = 0.02, options = {}) {
// Create container for tags // Create container for tags
const container = document.createElement("div"); const container = document.createElement("div");
container.className = "comfy-tags-container"; container.className = "comfy-tags-container";
const { allowStrengthAdjustment = true } = options;
forwardMiddleMouseToCanvas(container); forwardMiddleMouseToCanvas(container);
// Set initial height // Set initial height
@@ -41,6 +43,7 @@ export function addTagsWidget(node, name, opts, callback, wheelSensitivity = 0.0
} }
const normalizedTags = tagsData; const normalizedTags = tagsData;
const showStrengthInfo = widget.allowStrengthAdjustment ?? allowStrengthAdjustment;
if (normalizedTags.length === 0) { if (normalizedTags.length === 0) {
// Show message when no tags are present // Show message when no tags are present
@@ -82,16 +85,44 @@ export function addTagsWidget(node, name, opts, callback, wheelSensitivity = 0.0
const tagEl = document.createElement("div"); const tagEl = document.createElement("div");
tagEl.className = "comfy-tag"; tagEl.className = "comfy-tag";
updateTagStyle(tagEl, active, highlighted, strength); const textSpan = document.createElement("span");
textSpan.className = "comfy-tag-text";
textSpan.textContent = text;
Object.assign(textSpan.style, {
display: "inline-block",
overflow: "hidden",
textOverflow: "ellipsis",
whiteSpace: "nowrap",
minWidth: "0",
flexGrow: "1",
});
tagEl.appendChild(textSpan);
// Set the text content to include strength if present const strengthBadge = showStrengthInfo ? document.createElement("span") : null;
// Always show strength if it has been modified to avoid layout shift if (strengthBadge) {
if (strength !== undefined && strength !== null) { strengthBadge.className = "comfy-tag-strength";
tagEl.textContent = `${text}:${strength.toFixed(2)}`; Object.assign(strengthBadge.style, {
} else { fontSize: "11px",
tagEl.textContent = text; fontWeight: "600",
padding: "1px 6px",
borderRadius: "999px",
letterSpacing: "0.2px",
backgroundColor: "rgba(255,255,255,0.08)",
color: "rgba(255,255,255,0.95)",
border: "1px solid rgba(255,255,255,0.25)",
lineHeight: "normal",
minWidth: "34px",
textAlign: "center",
pointerEvents: "none",
opacity: "0",
visibility: "hidden",
transition: "opacity 0.2s ease",
});
tagEl.appendChild(strengthBadge);
} }
tagEl.title = text; // Set tooltip for full content
updateTagStyle(tagEl, active, highlighted, strength);
updateStrengthDisplay(tagEl, strength, text, showStrengthInfo);
// Add click handler to toggle state // Add click handler to toggle state
tagEl.addEventListener("click", (e) => { tagEl.addEventListener("click", (e) => {
@@ -100,12 +131,14 @@ export function addTagsWidget(node, name, opts, callback, wheelSensitivity = 0.0
// Toggle active state for this specific tag using its index // Toggle active state for this specific tag using its index
const updatedTags = [...widget.value]; const updatedTags = [...widget.value];
updatedTags[index].active = !updatedTags[index].active; updatedTags[index].active = !updatedTags[index].active;
textSpan.textContent = updatedTags[index].text;
updateTagStyle( updateTagStyle(
tagEl, tagEl,
updatedTags[index].active, updatedTags[index].active,
updatedTags[index].highlighted, updatedTags[index].highlighted,
updatedTags[index].strength updatedTags[index].strength
); );
updateStrengthDisplay(tagEl, updatedTags[index].strength, updatedTags[index].text);
tagEl.dataset.active = updatedTags[index].active ? "true" : "false"; tagEl.dataset.active = updatedTags[index].active ? "true" : "false";
tagEl.dataset.highlighted = updatedTags[index].highlighted ? "true" : "false"; tagEl.dataset.highlighted = updatedTags[index].highlighted ? "true" : "false";
@@ -114,48 +147,42 @@ export function addTagsWidget(node, name, opts, callback, wheelSensitivity = 0.0
}); });
// Add mouse wheel handler to adjust strength // Add mouse wheel handler to adjust strength
tagEl.addEventListener("wheel", (e) => { if (showStrengthInfo) {
e.preventDefault(); tagEl.addEventListener("wheel", (e) => {
e.stopPropagation(); e.preventDefault();
e.stopPropagation();
// Only adjust strength if the mouse is over the tag // Only adjust strength if the mouse is over the tag
const updatedTags = [...widget.value]; const updatedTags = [...widget.value];
let currentStrength = updatedTags[index].strength; let currentStrength = updatedTags[index].strength;
// If no strength is set, default to 1.0 // If no strength is set, default to 1.0
if (currentStrength === undefined || currentStrength === null) { if (currentStrength === undefined || currentStrength === null) {
currentStrength = 1.0; currentStrength = 1.0;
} }
// Adjust strength based on scroll direction // Adjust strength based on scroll direction
// DeltaY < 0 is scroll up, deltaY > 0 is scroll down // DeltaY < 0 is scroll up, deltaY > 0 is scroll down
if (e.deltaY < 0) { if (e.deltaY < 0) {
// Scroll up: increase strength by wheelSensitivity // Scroll up: increase strength by wheelSensitivity
currentStrength += wheelSensitivity; currentStrength += wheelSensitivity;
} else { } else {
// Scroll down: decrease strength by wheelSensitivity // Scroll down: decrease strength by wheelSensitivity
currentStrength -= wheelSensitivity; currentStrength -= wheelSensitivity;
} }
// Ensure strength doesn't go below 0 // Ensure strength doesn't go below 0
currentStrength = Math.max(0, currentStrength); currentStrength = Math.max(0, currentStrength);
// Update the strength value // Update the strength value
updatedTags[index].strength = currentStrength; updatedTags[index].strength = currentStrength;
textSpan.textContent = updatedTags[index].text;
// Update the tag display to show the strength value updateStrengthDisplay(tagEl, currentStrength, updatedTags[index].text, showStrengthInfo);
// Always show strength once it has been modified to avoid layout shift
tagEl.textContent = `${updatedTags[index].text}:${currentStrength.toFixed(2)}`;
updateTagStyle( widget.value = updatedTags;
tagEl, });
updatedTags[index].active, }
updatedTags[index].highlighted,
updatedTags[index].strength
);
widget.value = updatedTags;
});
rowContainer.appendChild(tagEl); rowContainer.appendChild(tagEl);
tagCount++; tagCount++;
@@ -190,7 +217,7 @@ export function addTagsWidget(node, name, opts, callback, wheelSensitivity = 0.0
// Helper function to update tag style based on active state // Helper function to update tag style based on active state
function updateTagStyle(tagEl, active, highlighted = false, strength = null) { function updateTagStyle(tagEl, active, highlighted = false, strength = null) {
const baseStyles = { const baseStyles = {
padding: "3px 10px", // Adjusted vertical padding to balance text padding: "3px 10px",
borderRadius: "6px", borderRadius: "6px",
maxWidth: "200px", maxWidth: "200px",
overflow: "hidden", overflow: "hidden",
@@ -200,7 +227,9 @@ export function addTagsWidget(node, name, opts, callback, wheelSensitivity = 0.0
cursor: "pointer", cursor: "pointer",
transition: "all 0.2s ease", transition: "all 0.2s ease",
border: "1px solid transparent", border: "1px solid transparent",
display: "inline-block", // inline-block for better text truncation display: "inline-flex",
alignItems: "center",
gap: "6px",
boxShadow: "0 1px 2px rgba(0,0,0,0.1)", boxShadow: "0 1px 2px rgba(0,0,0,0.1)",
margin: "1px", margin: "1px",
userSelect: "none", userSelect: "none",
@@ -214,7 +243,6 @@ export function addTagsWidget(node, name, opts, callback, wheelSensitivity = 0.0
maxWidth: "200px", maxWidth: "200px",
lineHeight: "16px", // Added explicit line-height lineHeight: "16px", // Added explicit line-height
verticalAlign: "middle", // Added vertical alignment verticalAlign: "middle", // Added vertical alignment
position: "relative", // For better text positioning
textAlign: "center", // Center text horizontally textAlign: "center", // Center text horizontally
}; };
@@ -263,6 +291,42 @@ export function addTagsWidget(node, name, opts, callback, wheelSensitivity = 0.0
tagEl.dataset.highlighted = highlighted ? "true" : "false"; tagEl.dataset.highlighted = highlighted ? "true" : "false";
} }
function formatStrengthValue(value) {
if (value === undefined || value === null) {
return null;
}
const num = Number(value);
if (!Number.isFinite(num)) {
return null;
}
return num.toFixed(2);
}
function updateStrengthDisplay(tagEl, strength, baseText, showStrengthInfo) {
if (!showStrengthInfo) {
tagEl.title = baseText;
return;
}
const badge = tagEl.querySelector(".comfy-tag-strength");
if (!badge) {
tagEl.title = baseText;
return;
}
const displayValue = strength === undefined || strength === null ? 1 : strength;
const formatted = formatStrengthValue(displayValue);
if (formatted !== null) {
badge.textContent = formatted;
badge.style.opacity = "1";
badge.style.visibility = "visible";
tagEl.title = `${baseText} (${formatted})`;
} else {
badge.textContent = "";
badge.style.opacity = "0";
badge.style.visibility = "hidden";
tagEl.title = baseText;
}
}
// Store the value as array // Store the value as array
let widgetValue = initialTagsData; let widgetValue = initialTagsData;

View File

@@ -75,13 +75,20 @@ app.registerExtension({
requestAnimationFrame(async () => { requestAnimationFrame(async () => {
// Get the wheel sensitivity setting // Get the wheel sensitivity setting
const wheelSensitivity = getWheelSensitivity(); const wheelSensitivity = getWheelSensitivity();
const groupModeWidget = node.widgets[0];
const defaultActiveWidget = node.widgets[1];
const strengthAdjustmentWidget = node.widgets[2];
const initialStrengthAdjustment = Boolean(strengthAdjustmentWidget?.value);
// Get the widget object directly from the returned object // Get the widget object directly from the returned object
const result = addTagsWidget(node, "toggle_trigger_words", { const result = addTagsWidget(node, "toggle_trigger_words", {
defaultVal: [] defaultVal: []
}, null, wheelSensitivity); }, null, wheelSensitivity, {
allowStrengthAdjustment: initialStrengthAdjustment
});
node.tagWidget = result.widget; node.tagWidget = result.widget;
node.tagWidget.allowStrengthAdjustment = initialStrengthAdjustment;
const normalizeTagText = (text) => const normalizeTagText = (text) =>
(typeof text === 'string' ? text.trim().toLowerCase() : ''); (typeof text === 'string' ? text.trim().toLowerCase() : '');
@@ -148,31 +155,40 @@ app.registerExtension({
hiddenWidget.type = CONVERTED_TYPE; hiddenWidget.type = CONVERTED_TYPE;
hiddenWidget.hidden = true; hiddenWidget.hidden = true;
hiddenWidget.computeSize = () => [0, -4]; hiddenWidget.computeSize = () => [0, -4];
node.originalMessageWidget = hiddenWidget;
// Restore saved value if exists // Restore saved value if exists
const tagWidgetIndex = node.widgets.indexOf(result.widget);
const originalMessageWidgetIndex = node.widgets.indexOf(hiddenWidget);
if (node.widgets_values && node.widgets_values.length > 0) { if (node.widgets_values && node.widgets_values.length > 0) {
// 0 is group mode, 1 is default_active, 2 is tag widget, 3 is original message if (tagWidgetIndex >= 0) {
const savedValue = node.widgets_values[2]; const savedValue = node.widgets_values[tagWidgetIndex];
if (savedValue) { if (savedValue) {
result.widget.value = Array.isArray(savedValue) ? savedValue : []; result.widget.value = Array.isArray(savedValue) ? savedValue : [];
}
} }
const originalMessage = node.widgets_values[3]; if (originalMessageWidgetIndex >= 0) {
if (originalMessage) { const originalMessage = node.widgets_values[originalMessageWidgetIndex];
hiddenWidget.value = originalMessage; if (originalMessage) {
hiddenWidget.value = originalMessage;
}
} }
} }
requestAnimationFrame(() => node.applyTriggerHighlightState?.()); requestAnimationFrame(() => node.applyTriggerHighlightState?.());
const groupModeWidget = node.widgets[0];
groupModeWidget.callback = (value) => { groupModeWidget.callback = (value) => {
if (node.widgets[3].value) { if (node.originalMessageWidget?.value) {
this.updateTagsBasedOnMode(node, node.widgets[3].value, value); this.updateTagsBasedOnMode(
node,
node.originalMessageWidget.value,
value,
Boolean(strengthAdjustmentWidget?.value)
);
} }
} }
// Add callback for default_active widget // Add callback for default_active widget
const defaultActiveWidget = node.widgets[1];
defaultActiveWidget.callback = (value) => { defaultActiveWidget.callback = (value) => {
// Set all existing tags' active state to the new value // Set all existing tags' active state to the new value
if (node.tagWidget && node.tagWidget.value) { if (node.tagWidget && node.tagWidget.value) {
@@ -185,6 +201,21 @@ app.registerExtension({
} }
} }
if (strengthAdjustmentWidget) {
strengthAdjustmentWidget.callback = (value) => {
const allowStrengthAdjustment = Boolean(value);
if (node.tagWidget) {
node.tagWidget.allowStrengthAdjustment = allowStrengthAdjustment;
}
this.updateTagsBasedOnMode(
node,
node.originalMessageWidget?.value || "",
groupModeWidget?.value ?? false,
allowStrengthAdjustment
);
};
}
// Override the serializeValue method to properly format trigger words with strength // Override the serializeValue method to properly format trigger words with strength
const originalSerializeValue = result.widget.serializeValue; const originalSerializeValue = result.widget.serializeValue;
result.widget.serializeValue = function() { result.widget.serializeValue = function() {
@@ -215,18 +246,23 @@ app.registerExtension({
} }
// Store the original message for mode switching // Store the original message for mode switching
node.widgets[3].value = message; if (node.originalMessageWidget) {
node.originalMessageWidget.value = message;
}
if (node.tagWidget) { if (node.tagWidget) {
// Parse tags based on current group mode // Parse tags based on current group mode
const groupMode = node.widgets[0] ? node.widgets[0].value : false; const groupMode = node.widgets[0] ? node.widgets[0].value : false;
this.updateTagsBasedOnMode(node, message, groupMode); const allowStrengthAdjustment = Boolean(node.widgets[2]?.value);
node.tagWidget.allowStrengthAdjustment = allowStrengthAdjustment;
this.updateTagsBasedOnMode(node, message, groupMode, allowStrengthAdjustment);
} }
}, },
// Update tags display based on group mode // Update tags display based on group mode
updateTagsBasedOnMode(node, message, groupMode) { updateTagsBasedOnMode(node, message, groupMode, allowStrengthAdjustment = false) {
if (!node.tagWidget) return; if (!node.tagWidget) return;
node.tagWidget.allowStrengthAdjustment = allowStrengthAdjustment;
const existingTags = node.tagWidget.value || []; const existingTags = node.tagWidget.value || [];
const existingTagMap = {}; const existingTagMap = {};
@@ -235,7 +271,7 @@ app.registerExtension({
existingTags.forEach(tag => { existingTags.forEach(tag => {
existingTagMap[tag.text] = { existingTagMap[tag.text] = {
active: tag.active, active: tag.active,
strength: tag.strength strength: allowStrengthAdjustment ? tag.strength : null
}; };
}); });