diff --git a/py/nodes/trigger_word_toggle.py b/py/nodes/trigger_word_toggle.py index 99318341..3ef225c4 100644 --- a/py/nodes/trigger_word_toggle.py +++ b/py/nodes/trigger_word_toggle.py @@ -72,103 +72,81 @@ class TriggerWordToggle: # Convert to list if it's a JSON string if isinstance(trigger_data, str): trigger_data = json.loads(trigger_data) - - # Create dictionaries to track active state of words or groups - # Also track strength values for each trigger word - active_state = {} - strength_map = {} - - for item in trigger_data: - text = item['text'] - active = item.get('active', False) - # Extract strength if it's in the format "(word:strength)" - strength_match = re.match(r'\((.+):([\d.]+)\)', text) - if strength_match: - original_word = strength_match.group(1).strip() - strength = float(strength_match.group(2)) - active_state[original_word] = active + + if isinstance(trigger_data, list): + if group_mode: if allow_strength_adjustment: - strength_map[original_word] = strength - else: - active_state[text.strip()] = active - - if group_mode: - if isinstance(trigger_data, list): - filtered_groups = [] - for item in trigger_data: - text = (item.get('text') or "").strip() - if not text: - continue - if item.get('active', False): - filtered_groups.append(text) - - if filtered_groups: - filtered_triggers = ', '.join(filtered_groups) - else: - filtered_triggers = "" - else: - # Split by two or more consecutive commas to get groups - groups = re.split(r',{2,}', trigger_words) - # Remove leading/trailing whitespace from each group - groups = [group.strip() for group in groups] - - # Process groups: keep those not in toggle_trigger_words or those that are active - filtered_groups = [] - for group in groups: - # Check if this group contains any words that are in the active_state - group_words = [word.strip() for word in group.split(',')] - active_group_words = [] - - for word in group_words: - word_comparison = re.sub(r'\((.+):([\d.]+)\)', r'\1', word).strip() - - if word_comparison not in active_state or active_state[word_comparison]: - active_group_words.append( - self._format_word_output( - word_comparison, - strength_map, - allow_strength_adjustment, - ) - ) - - if active_group_words: - filtered_groups.append(', '.join(active_group_words)) - - if filtered_groups: - filtered_triggers = ', '.join(filtered_groups) - else: - filtered_triggers = "" - else: - # Normal mode: split by commas and treat each word as a separate tag - original_words = [word.strip() for word in trigger_words.split(',')] - # Filter out empty strings - original_words = [word for word in original_words if word] - - filtered_words = [] - for word in original_words: - # Remove any existing strength formatting for comparison - word_comparison = re.sub(r'\((.+):([\d.]+)\)', r'\1', word).strip() - - if word_comparison not in active_state or active_state[word_comparison]: - filtered_words.append( + parsed_items = [ + self._parse_trigger_item(item, allow_strength_adjustment) + for item in trigger_data + ] + filtered_groups = [ self._format_word_output( - word_comparison, - strength_map, + item["text"], + item["strength"], allow_strength_adjustment, ) - ) - - if filtered_words: - filtered_triggers = ', '.join(filtered_words) + for item in parsed_items + if item["text"] and item["active"] + ] + else: + filtered_groups = [ + (item.get('text') or "").strip() + for item in trigger_data + if (item.get('text') or "").strip() and item.get('active', False) + ] + filtered_triggers = ', '.join(filtered_groups) if filtered_groups else "" else: - filtered_triggers = "" - + parsed_items = [ + self._parse_trigger_item(item, allow_strength_adjustment) + for item in trigger_data + ] + filtered_words = [ + self._format_word_output( + item["text"], + item["strength"], + allow_strength_adjustment, + ) + for item in parsed_items + if item["text"] and item["active"] + ] + filtered_triggers = ', '.join(filtered_words) if filtered_words else "" + else: + # Fallback to original message parsing if data is not in the expected list format + if group_mode: + groups = re.split(r',{2,}', trigger_words) + groups = [group.strip() for group in groups if group.strip()] + filtered_triggers = ', '.join(groups) + else: + words = [word.strip() for word in trigger_words.split(',') if word.strip()] + filtered_triggers = ', '.join(words) + except Exception as e: logger.error(f"Error processing trigger words: {e}") return (filtered_triggers,) - def _format_word_output(self, base_word, strength_map, allow_strength_adjustment): - if allow_strength_adjustment and base_word in strength_map: - return f"({base_word}:{strength_map[base_word]:.2f})" + def _parse_trigger_item(self, item, allow_strength_adjustment): + text = (item.get('text') or "").strip() + active = bool(item.get('active', False)) + strength = item.get('strength') + + strength_match = re.match(r'^\((.+):([\d.]+)\)$', text) + if strength_match: + text = strength_match.group(1).strip() + if strength is None: + try: + strength = float(strength_match.group(2)) + except ValueError: + strength = None + + return { + "text": text, + "active": active, + "strength": strength if allow_strength_adjustment else None, + } + + def _format_word_output(self, base_word, strength, allow_strength_adjustment): + if allow_strength_adjustment and strength is not None: + return f"({base_word}:{strength:.2f})" return base_word diff --git a/tests/nodes/test_trigger_word_toggle.py b/tests/nodes/test_trigger_word_toggle.py index 64a92463..82e651c1 100644 --- a/tests/nodes/test_trigger_word_toggle.py +++ b/tests/nodes/test_trigger_word_toggle.py @@ -24,3 +24,61 @@ def test_group_mode_preserves_parenthesized_groups(): ) assert filtered == original_message + + +def test_duplicate_words_keep_individual_active_states(): + node = TriggerWordToggle() + trigger_data = [ + {'text': 'A', 'active': True, 'strength': None, 'highlighted': False}, + {'text': 'A', 'active': False, 'strength': None, 'highlighted': False}, + ] + + filtered, = node.process_trigger_words( + id="node", + group_mode=False, + default_active=True, + allow_strength_adjustment=False, + orinalMessage="A, A", + toggle_trigger_words=trigger_data, + ) + + assert filtered == "A" + + +def test_duplicate_words_preserve_strength_per_instance(): + node = TriggerWordToggle() + trigger_data = [ + {'text': '(A:0.50)', 'active': False, 'strength': 0.50, 'highlighted': False}, + {'text': 'A', 'active': True, 'strength': 1.2, 'highlighted': False}, + {'text': '(A:0.75)', 'active': True, 'strength': 0.75, 'highlighted': False}, + ] + + filtered, = node.process_trigger_words( + id="node", + group_mode=False, + default_active=True, + allow_strength_adjustment=True, + orinalMessage="A, A, A", + toggle_trigger_words=trigger_data, + ) + + assert filtered == "(A:1.20), (A:0.75)" + + +def test_duplicate_groups_respect_active_state(): + node = TriggerWordToggle() + trigger_data = [ + {'text': 'A, B', 'active': False, 'strength': None, 'highlighted': False}, + {'text': 'A, B', 'active': True, 'strength': None, 'highlighted': False}, + ] + + filtered, = node.process_trigger_words( + id="node", + group_mode=True, + default_active=True, + allow_strength_adjustment=False, + orinalMessage="A, B,, A, B", + toggle_trigger_words=trigger_data, + ) + + assert filtered == "A, B" diff --git a/web/comfyui/trigger_word_toggle.js b/web/comfyui/trigger_word_toggle.js index 55d1baf1..b548a96f 100644 --- a/web/comfyui/trigger_word_toggle.js +++ b/web/comfyui/trigger_word_toggle.js @@ -265,15 +265,24 @@ app.registerExtension({ node.tagWidget.allowStrengthAdjustment = allowStrengthAdjustment; const existingTags = node.tagWidget.value || []; - const existingTagMap = {}; - - // Create a map of existing tags and their active states and strengths - existingTags.forEach(tag => { - existingTagMap[tag.text] = { + const existingTagState = existingTags.reduce((acc, tag) => { + const key = tag.text; + if (!acc[key]) { + acc[key] = []; + } + acc[key].push({ active: tag.active, - strength: allowStrengthAdjustment ? tag.strength : null - }; - }); + strength: allowStrengthAdjustment ? tag.strength : null, + }); + return acc; + }, {}); + const consumeExistingState = (text) => { + const states = existingTagState[text]; + if (states && states.length > 0) { + return states.shift(); + } + return null; + }; // Get default active state from the widget const defaultActive = node.widgets[1] ? node.widgets[1].value : true; @@ -292,7 +301,7 @@ app.registerExtension({ .filter(group => group) .map(group => { // Check if this group already exists with strength info - const existing = existingTagMap[group]; + const existing = consumeExistingState(group); return { text: group, // Use existing values if available, otherwise use defaults @@ -315,16 +324,16 @@ app.registerExtension({ tagArray = message .split(',') .map(word => word.trim()) - .filter(word => word) - .map(word => { - // Check if this word already exists with strength info - const existing = existingTagMap[word]; - return { - text: word, - // Use existing values if available, otherwise use defaults - active: existing ? existing.active : defaultActive, - strength: existing ? existing.strength : null - }; + .filter(word => word) + .map(word => { + // Check if this word already exists with strength info + const existing = consumeExistingState(word); + return { + text: word, + // Use existing values if available, otherwise use defaults + active: existing ? existing.active : defaultActive, + strength: existing ? existing.strength : null + }; }); }