From 99d2ba26b9c94f80d555bd64ce67930fdd072749 Mon Sep 17 00:00:00 2001 From: Will Miao <13051207myq@gmail.com> Date: Mon, 2 Jun 2025 17:04:33 +0800 Subject: [PATCH] Add API endpoint for fetching trained words and implement dropdown suggestions in the trigger words editor. See #147 --- py/routes/misc_routes.py | 51 +++ py/utils/lora_metadata.py | 36 +- static/css/components/lora-modal.css | 161 +++++++- .../js/components/loraModal/TriggerWords.js | 343 ++++++++++++++---- 4 files changed, 508 insertions(+), 83 deletions(-) diff --git a/py/routes/misc_routes.py b/py/routes/misc_routes.py index e50a99b6..17267360 100644 --- a/py/routes/misc_routes.py +++ b/py/routes/misc_routes.py @@ -15,6 +15,7 @@ from ..services.service_registry import ServiceRegistry from ..utils.constants import SUPPORTED_MEDIA_EXTENSIONS from ..services.civitai_client import CivitaiClient from ..utils.routes_common import ModelRouteUtils +from ..utils.lora_metadata import extract_trained_words logger = logging.getLogger(__name__) @@ -61,6 +62,9 @@ class MiscRoutes: # Add new route for opening example images folder app.router.add_post('/api/open-example-images-folder', MiscRoutes.open_example_images_folder) + # Add new route for getting trained words + app.router.add_get('/api/trained-words', MiscRoutes.get_trained_words) + @staticmethod async def clear_cache(request): """Clear all cache files from the cache folder""" @@ -955,3 +959,50 @@ class MiscRoutes: 'success': False, 'error': str(e) }, status=500) + + @staticmethod + async def get_trained_words(request): + """ + Get trained words from a safetensors file, sorted by frequency + + Expects a query parameter: + file_path: Path to the safetensors file + """ + try: + # Get file path from query parameters + file_path = request.query.get('file_path') + + if not file_path: + return web.json_response({ + 'success': False, + 'error': 'Missing file_path parameter' + }, status=400) + + # Check if file exists and is a safetensors file + if not os.path.exists(file_path): + return web.json_response({ + 'success': False, + 'error': f"File not found: {file_path}" + }, status=404) + + if not file_path.lower().endswith('.safetensors'): + return web.json_response({ + 'success': False, + 'error': 'File is not a safetensors file' + }, status=400) + + # Extract trained words + trained_words = await extract_trained_words(file_path) + + # Return result + return web.json_response({ + 'success': True, + 'trained_words': trained_words + }) + + except Exception as e: + logger.error(f"Failed to get trained words: {e}", exc_info=True) + return web.json_response({ + 'success': False, + 'error': str(e) + }, status=500) diff --git a/py/utils/lora_metadata.py b/py/utils/lora_metadata.py index 3dcecd75..a91f57bb 100644 --- a/py/utils/lora_metadata.py +++ b/py/utils/lora_metadata.py @@ -1,8 +1,9 @@ from safetensors import safe_open -from typing import Dict +from typing import Dict, List, Tuple from .model_utils import determine_base_model import os import logging +import json logger = logging.getLogger(__name__) @@ -80,4 +81,35 @@ async def extract_checkpoint_metadata(file_path: str) -> dict: except Exception as e: logger.error(f"Error extracting checkpoint metadata for {file_path}: {e}") # Return default values - return {'base_model': 'Unknown', 'model_type': 'checkpoint'} \ No newline at end of file + return {'base_model': 'Unknown', 'model_type': 'checkpoint'} + +async def extract_trained_words(file_path: str) -> List[Tuple[str, int]]: + """Extract trained words from a safetensors file and sort by frequency + + Args: + file_path: Path to the safetensors file + + Returns: + List of (word, frequency) tuples sorted by frequency (highest first) + """ + try: + with safe_open(file_path, framework="pt", device="cpu") as f: + metadata = f.metadata() + if metadata and "ss_tag_frequency" in metadata: + # Parse the JSON string into a dictionary + tag_data = json.loads(metadata["ss_tag_frequency"]) + + # The structure may have an outer key (like "image_dir" or "img") + # We need to get the inner dictionary with the actual word frequencies + if tag_data: + # Get the first key (usually "image_dir" or "img") + first_key = list(tag_data.keys())[0] + words_dict = tag_data[first_key] + + # Sort words by frequency (highest first) + sorted_words = sorted(words_dict.items(), key=lambda x: x[1], reverse=True) + return sorted_words + except Exception as e: + logger.error(f"Error extracting trained words from {file_path}: {str(e)}") + + return [] \ No newline at end of file diff --git a/static/css/components/lora-modal.css b/static/css/components/lora-modal.css index cdbf087c..89b02512 100644 --- a/static/css/components/lora-modal.css +++ b/static/css/components/lora-modal.css @@ -132,7 +132,7 @@ } .scroll-indicator:hover { - background: oklch(var(--lora-accent) / 0.1); + background: oklch(var(--lora-accent-l) var(--lora-accent-c) var(--lora-accent-h) / 0.1); transform: translateY(-1px); } @@ -241,7 +241,7 @@ /* Keep the hover effect using accent color */ .trigger-word-tag:hover { - background: oklch(var(--lora-accent) / 0.1); + background: oklch(var(--lora-accent-l) var(--lora-accent-c) var(--lora-accent-h) / 0.1); border-color: var(--lora-accent); } @@ -301,7 +301,7 @@ } .trigger-words-edit-controls button:hover { - background: oklch(var(--lora-accent) / 0.1); + background: oklch(var(--lora-accent-l) var(--lora-accent-c) var(--lora-accent-h) / 0.1); border-color: var(--lora-accent); } @@ -324,6 +324,7 @@ margin-top: var(--space-2); display: flex; gap: var(--space-1); + position: relative; /* Added for dropdown positioning */ } .new-trigger-word-input { @@ -346,7 +347,7 @@ padding: 4px 8px; border-radius: var(--border-radius-xs); border: 1px solid var(--border-color); - background: var(--bg-color); + background: var (--bg-color); color: var(--text-color); font-size: 0.85em; cursor: pointer; @@ -371,6 +372,146 @@ background: rgba(255, 255, 255, 0.05); } +/* Trained Words Loading Indicator */ +.trained-words-loading { + display: flex; + align-items: center; + justify-content: center; + margin: var(--space-1) 0; + color: var(--text-color); + opacity: 0.7; + font-size: 0.9em; + gap: 8px; +} + +.trained-words-loading i { + color: var(--lora-accent); +} + +/* Trained Words Dropdown Styles */ +.trained-words-dropdown { + position: absolute; + top: 100%; + left: 0; + right: 0; + background: var(--bg-color); + border: 1px solid var(--border-color); + border-radius: var(--border-radius-sm); + margin-top: 4px; + z-index: 100; + box-shadow: 0 4px 12px rgba(0, 0, 0, 0.15); + overflow: hidden; + display: flex; + flex-direction: column; +} + +.trained-words-header { + display: flex; + justify-content: space-between; + align-items: center; + padding: 8px 12px; + background: var(--card-bg); + border-bottom: 1px solid var(--border-color); +} + +.trained-words-header span { + font-size: 0.9em; + font-weight: 500; + color: var(--text-color); +} + +.trained-words-header small { + font-size: 0.8em; + opacity: 0.7; +} + +.trained-words-container { + max-height: 200px; + overflow-y: auto; + padding: 10px; + display: flex; + flex-wrap: wrap; + gap: 8px; + align-content: flex-start; +} + +.trained-word-item { + display: inline-flex; + align-items: center; + justify-content: space-between; + padding: 5px 10px; + cursor: pointer; + transition: all 0.2s ease; + border-radius: var(--border-radius-xs); + background: var(--lora-surface); + border: 1px solid var(--lora-border); + max-width: 150px; +} + +.trained-word-item:hover { + background: oklch(var(--lora-accent-l) var(--lora-accent-c) var(--lora-accent-h) / 0.1); + border-color: var(--lora-accent); +} + +.trained-word-item.already-added { + opacity: 0.7; + cursor: default; +} + +.trained-word-item.already-added:hover { + background: var(--lora-surface); + border-color: var(--lora-border); +} + +.trained-word-text { + color: var(--lora-accent); + font-size: 0.9em; + white-space: nowrap; + overflow: hidden; + text-overflow: ellipsis; + margin-right: 4px; + max-width: 100px; +} + +.trained-word-meta { + display: flex; + align-items: center; + gap: 4px; + flex-shrink: 0; +} + +.trained-word-freq { + color: var (--text-color); + font-size: 0.75em; + background: rgba(0, 0, 0, 0.05); + border-radius: 10px; + min-width: 20px; + padding: 1px 5px; + text-align: center; + line-height: 1.2; +} + +[data-theme="dark"] .trained-word-freq { + background: rgba(255, 255, 255, 0.05); +} + +.added-indicator { + color: var(--lora-accent); + display: flex; + align-items: center; + justify-content: center; + font-size: 0.75em; +} + +.no-trained-words { + padding: 16px 12px; + text-align: center; + color: var(--text-color); + opacity: 0.7; + font-style: italic; + font-size: 0.9em; +} + /* Editable Fields */ .editable-field { position: relative; @@ -515,7 +656,7 @@ } .preset-tag:hover { - background: oklch(var(--lora-accent) / 0.1); + background: oklch(var(--lora-accent-l) var(--lora-accent-c) var(--lora-accent-h) / 0.1); border-color: var(--lora-accent); } @@ -549,10 +690,6 @@ position: relative; } -.file-name-wrapper:hover { - background: oklch(var(--lora-accent) / 0.1); -} - .file-name-content { padding: 2px 4px; border-radius: var(--border-radius-xs); @@ -749,7 +886,7 @@ .tab-btn:hover { opacity: 1; - background: oklch(var(--lora-accent) / 0.05); + background: oklch(var(--lora-accent-l) var(--lora-accent-c) var(--lora-accent-h) / 0.05); } .tab-btn.active { @@ -931,7 +1068,7 @@ .model-description-content pre { background: rgba(0, 0, 0, 0.05); border-radius: var(--border-radius-xs); - padding: var(--space-1); + padding: var (--space-1); white-space: pre-wrap; margin: 1em 0; overflow-x: auto; @@ -1373,6 +1510,6 @@ /* Optional: add hover effect for creator info */ .creator-info:hover { - background: oklch(var(--lora-accent) / 0.1); + background: oklch(var(--lora-accent-l) var(--lora-accent-c) var(--lora-accent-h) / 0.1); border-color: var(--lora-accent); } \ No newline at end of file diff --git a/static/js/components/loraModal/TriggerWords.js b/static/js/components/loraModal/TriggerWords.js index 5c9004f3..7abaecc2 100644 --- a/static/js/components/loraModal/TriggerWords.js +++ b/static/js/components/loraModal/TriggerWords.js @@ -1,15 +1,110 @@ /** * TriggerWords.js - * 处理LoRA模型触发词相关的功能模块 + * Module that handles trigger word functionality for LoRA models */ import { showToast, copyToClipboard } from '../../utils/uiHelpers.js'; import { saveModelMetadata } from '../../api/loraApi.js'; /** - * 渲染触发词 - * @param {Array} words - 触发词数组 - * @param {string} filePath - 文件路径 - * @returns {string} HTML内容 + * Fetch trained words for a model + * @param {string} filePath - Path to the model file + * @returns {Promise} - Array of [word, frequency] pairs + */ +async function fetchTrainedWords(filePath) { + try { + const response = await fetch(`/api/trained-words?file_path=${encodeURIComponent(filePath)}`); + const data = await response.json(); + + if (data.success && data.trained_words) { + return data.trained_words; // Returns array of [word, frequency] pairs + } else { + throw new Error(data.error || 'Failed to fetch trained words'); + } + } catch (error) { + console.error('Error fetching trained words:', error); + showToast('Could not load trained words', 'error'); + return []; + } +} + +/** + * Create suggestion dropdown with trained words as tags + * @param {Array} trainedWords - Array of [word, frequency] pairs + * @param {Array} existingWords - Already added trigger words + * @returns {HTMLElement} - Dropdown element + */ +function createSuggestionDropdown(trainedWords, existingWords = []) { + const dropdown = document.createElement('div'); + dropdown.className = 'trained-words-dropdown'; + + // Create header + const header = document.createElement('div'); + header.className = 'trained-words-header'; + + if (!trainedWords || trainedWords.length === 0) { + header.innerHTML = 'No suggestions available'; + dropdown.appendChild(header); + dropdown.innerHTML += '
No trained words found in this model. You can manually enter trigger words.
'; + return dropdown; + } + + // Sort by frequency (highest first) + trainedWords.sort((a, b) => b[1] - a[1]); + + header.innerHTML = ` + Suggestions from training data + ${trainedWords.length} words found + `; + dropdown.appendChild(header); + + // Create tag container + const container = document.createElement('div'); + container.className = 'trained-words-container'; + + // Add each trained word as a tag + trainedWords.forEach(([word, frequency]) => { + const isAdded = existingWords.includes(word); + + const item = document.createElement('div'); + item.className = `trained-word-item ${isAdded ? 'already-added' : ''}`; + item.title = word; // Show full word on hover if truncated + item.innerHTML = ` + ${word} +
+ ${frequency} + ${isAdded ? '' : ''} +
+ `; + + if (!isAdded) { + item.addEventListener('click', () => { + // Automatically add this word + addNewTriggerWord(word); + + // Also populate the input field for potential editing + const input = document.querySelector('.new-trigger-word-input'); + if (input) input.value = word; + + // Focus on the input + if (input) input.focus(); + + // Update dropdown without removing it + updateTrainedWordsDropdown(); + }); + } + + container.appendChild(item); + }); + + dropdown.appendChild(container); + return dropdown; +} + +/** + * Render trigger words + * @param {Array} words - Array of trigger words + * @param {string} filePath - File path + * @returns {string} HTML content */ export function renderTriggerWords(words, filePath) { if (!words.length) return ` @@ -24,19 +119,14 @@ export function renderTriggerWords(words, filePath) { No trigger word needed + - `; @@ -63,44 +153,53 @@ export function renderTriggerWords(words, filePath) { `).join('')} + - `; } /** - * 设置触发词编辑模式 + * Set up trigger words edit mode */ export function setupTriggerWordsEditMode() { + // Store trained words data + let trainedWordsList = []; + let isTrainedWordsLoaded = false; + // Store original trigger words for restoring on cancel + let originalTriggerWords = []; + const editBtn = document.querySelector('.edit-trigger-words-btn'); if (!editBtn) return; - editBtn.addEventListener('click', function() { + editBtn.addEventListener('click', async function() { const triggerWordsSection = this.closest('.trigger-words'); const isEditMode = triggerWordsSection.classList.toggle('edit-mode'); + const filePath = this.dataset.filePath; // Toggle edit mode UI elements const triggerWordTags = triggerWordsSection.querySelectorAll('.trigger-word-tag'); const editControls = triggerWordsSection.querySelector('.trigger-words-edit-controls'); + const addForm = triggerWordsSection.querySelector('.add-trigger-word-form'); const noTriggerWords = triggerWordsSection.querySelector('.no-trigger-words'); const tagsContainer = triggerWordsSection.querySelector('.trigger-words-tags'); if (isEditMode) { this.innerHTML = ''; // Change to cancel icon this.title = "Cancel editing"; + + // Store original trigger words for potential restoration + originalTriggerWords = Array.from(triggerWordTags).map(tag => tag.dataset.word); + + // Show edit controls and input form editControls.style.display = 'flex'; + addForm.style.display = 'flex'; // If we have no trigger words yet, hide the "No trigger word needed" text // and show the empty tags container @@ -115,10 +214,44 @@ export function setupTriggerWordsEditMode() { tag.querySelector('.trigger-word-copy').style.display = 'none'; tag.querySelector('.delete-trigger-word-btn').style.display = 'block'; }); + + // Load trained words and display dropdown when entering edit mode + // Add loading indicator + const loadingIndicator = document.createElement('div'); + loadingIndicator.className = 'trained-words-loading'; + loadingIndicator.innerHTML = ' Loading suggestions...'; + addForm.appendChild(loadingIndicator); + + // Get currently added trigger words + const currentTags = triggerWordsSection.querySelectorAll('.trigger-word-tag'); + const existingWords = Array.from(currentTags).map(tag => tag.dataset.word); + + // Asynchronously load trained words if not already loaded + if (!isTrainedWordsLoaded) { + trainedWordsList = await fetchTrainedWords(filePath); + isTrainedWordsLoaded = true; + } + + // Remove loading indicator + loadingIndicator.remove(); + + // Create and display suggestion dropdown + const dropdown = createSuggestionDropdown(trainedWordsList, existingWords); + addForm.appendChild(dropdown); + + // Focus the input + addForm.querySelector('input').focus(); + } else { this.innerHTML = ''; // Change back to edit icon this.title = "Edit trigger words"; + + // Hide edit controls and input form editControls.style.display = 'none'; + addForm.style.display = 'none'; + + // BUGFIX: Restore original trigger words when canceling edit + restoreOriginalTriggerWords(triggerWordsSection, originalTriggerWords); // If we have no trigger words, show the "No trigger word needed" text // and hide the empty tags container @@ -128,57 +261,26 @@ export function setupTriggerWordsEditMode() { if (tagsContainer) tagsContainer.style.display = 'none'; } - // Restore original state - triggerWordTags.forEach(tag => { - const word = tag.dataset.word; - tag.onclick = () => copyTriggerWord(word); - tag.querySelector('.trigger-word-copy').style.display = 'flex'; - tag.querySelector('.delete-trigger-word-btn').style.display = 'none'; - }); - - // Hide add form if open - triggerWordsSection.querySelector('.add-trigger-word-form').style.display = 'none'; + // Remove dropdown if present + const dropdown = document.querySelector('.trained-words-dropdown'); + if (dropdown) dropdown.remove(); } }); - // Set up add trigger word button - const addBtn = document.querySelector('.add-trigger-word-btn'); - if (addBtn) { - addBtn.addEventListener('click', function() { - const triggerWordsSection = this.closest('.trigger-words'); - const addForm = triggerWordsSection.querySelector('.add-trigger-word-form'); - addForm.style.display = 'flex'; - addForm.querySelector('input').focus(); - }); - } - - // Set up confirm and cancel add buttons - const confirmAddBtn = document.querySelector('.confirm-add-trigger-word-btn'); - const cancelAddBtn = document.querySelector('.cancel-add-trigger-word-btn'); + // Set up input for adding trigger words const triggerWordInput = document.querySelector('.new-trigger-word-input'); - if (confirmAddBtn && triggerWordInput) { - confirmAddBtn.addEventListener('click', function() { - addNewTriggerWord(triggerWordInput.value); - }); - + if (triggerWordInput) { // Add keydown event to input triggerWordInput.addEventListener('keydown', function(e) { if (e.key === 'Enter') { e.preventDefault(); addNewTriggerWord(this.value); + this.value = ''; // Clear input after adding } }); } - if (cancelAddBtn) { - cancelAddBtn.addEventListener('click', function() { - const addForm = this.closest('.add-trigger-word-form'); - addForm.style.display = 'none'; - addForm.querySelector('input').value = ''; - }); - } - // Set up save button const saveBtn = document.querySelector('.save-trigger-words-btn'); if (saveBtn) { @@ -191,13 +293,59 @@ export function setupTriggerWordsEditMode() { e.stopPropagation(); const tag = this.closest('.trigger-word-tag'); tag.remove(); + + // Update status of items in the trained words dropdown + updateTrainedWordsDropdown(); }); }); } /** - * 添加新触发词 - * @param {string} word - 要添加的触发词 + * Restore original trigger words when canceling edit + * @param {HTMLElement} section - The trigger words section + * @param {Array} originalWords - Original trigger words + */ +function restoreOriginalTriggerWords(section, originalWords) { + const tagsContainer = section.querySelector('.trigger-words-tags'); + const noTriggerWords = section.querySelector('.no-trigger-words'); + + if (!tagsContainer) return; + + // Clear current tags + tagsContainer.innerHTML = ''; + + if (originalWords.length === 0) { + if (noTriggerWords) noTriggerWords.style.display = ''; + tagsContainer.style.display = 'none'; + return; + } + + // Hide "no trigger words" message + if (noTriggerWords) noTriggerWords.style.display = 'none'; + tagsContainer.style.display = 'flex'; + + // Recreate original tags + originalWords.forEach(word => { + const tag = document.createElement('div'); + tag.className = 'trigger-word-tag'; + tag.dataset.word = word; + tag.onclick = () => copyTriggerWord(word); + tag.innerHTML = ` + ${word} + + + + + `; + tagsContainer.appendChild(tag); + }); +} + +/** + * Add a new trigger word + * @param {string} word - Trigger word to add */ function addNewTriggerWord(word) { word = word.trim(); @@ -265,18 +413,75 @@ function addNewTriggerWord(word) { const deleteBtn = newTag.querySelector('.delete-trigger-word-btn'); deleteBtn.addEventListener('click', function() { newTag.remove(); + // Update dropdown after removing + updateTrainedWordsDropdown(); }); tagsContainer.appendChild(newTag); - // Clear and hide the input form - const triggerWordInput = document.querySelector('.new-trigger-word-input'); - triggerWordInput.value = ''; - document.querySelector('.add-trigger-word-form').style.display = 'none'; + // Update status of items in the trained words dropdown + updateTrainedWordsDropdown(); } /** - * 保存触发词 + * Update status of items in the trained words dropdown + */ +function updateTrainedWordsDropdown() { + const dropdown = document.querySelector('.trained-words-dropdown'); + if (!dropdown) return; + + // Get all current trigger words + const currentTags = document.querySelectorAll('.trigger-word-tag'); + const existingWords = Array.from(currentTags).map(tag => tag.dataset.word); + + // Update status of each item in dropdown + dropdown.querySelectorAll('.trained-word-item').forEach(item => { + const wordText = item.querySelector('.trained-word-text').textContent; + const isAdded = existingWords.includes(wordText); + + if (isAdded) { + item.classList.add('already-added'); + + // Add indicator if it doesn't exist + let indicator = item.querySelector('.added-indicator'); + if (!indicator) { + const meta = item.querySelector('.trained-word-meta'); + indicator = document.createElement('span'); + indicator.className = 'added-indicator'; + indicator.innerHTML = ''; + meta.appendChild(indicator); + } + + // Remove click event + item.onclick = null; + } else { + // Re-enable items that are no longer in the list + item.classList.remove('already-added'); + + // Remove indicator if it exists + const indicator = item.querySelector('.added-indicator'); + if (indicator) indicator.remove(); + + // Restore click event if not already set + if (!item.onclick) { + item.onclick = () => { + const word = item.querySelector('.trained-word-text').textContent; + addNewTriggerWord(word); + + // Also populate the input field + const input = document.querySelector('.new-trigger-word-input'); + if (input) input.value = word; + + // Focus the input + if (input) input.focus(); + }; + } + } + }); +} + +/** + * Save trigger words */ async function saveTriggerWords() { const filePath = document.querySelector('.edit-trigger-words-btn').dataset.filePath; @@ -331,8 +536,8 @@ async function saveTriggerWords() { } /** - * 复制触发词到剪贴板 - * @param {string} word - 要复制的触发词 + * Copy a trigger word to clipboard + * @param {string} word - Word to copy */ window.copyTriggerWord = async function(word) { try {