diff --git a/py/routes/misc_routes.py b/py/routes/misc_routes.py index 9ce82f2b..76f3a44d 100644 --- a/py/routes/misc_routes.py +++ b/py/routes/misc_routes.py @@ -1005,13 +1005,14 @@ class MiscRoutes: 'error': 'File is not a safetensors file' }, status=400) - # Extract trained words - trained_words = await extract_trained_words(file_path) + # Extract trained words and class_tokens + trained_words, class_tokens = await extract_trained_words(file_path) - # Return result + # Return result with both trained words and class tokens return web.json_response({ 'success': True, - 'trained_words': trained_words + 'trained_words': trained_words, + 'class_tokens': class_tokens }) except Exception as e: diff --git a/py/utils/lora_metadata.py b/py/utils/lora_metadata.py index a91f57bb..7c8e7bc4 100644 --- a/py/utils/lora_metadata.py +++ b/py/utils/lora_metadata.py @@ -83,18 +83,36 @@ async def extract_checkpoint_metadata(file_path: str) -> dict: # Return default values return {'base_model': 'Unknown', 'model_type': 'checkpoint'} -async def extract_trained_words(file_path: str) -> List[Tuple[str, int]]: +async def extract_trained_words(file_path: str) -> Tuple[List[Tuple[str, int]], str]: """Extract trained words from a safetensors file and sort by frequency Args: file_path: Path to the safetensors file Returns: - List of (word, frequency) tuples sorted by frequency (highest first) + Tuple of: + - List of (word, frequency) tuples sorted by frequency (highest first) + - class_tokens value (or None if not found) """ + class_tokens = None + try: with safe_open(file_path, framework="pt", device="cpu") as f: metadata = f.metadata() + + # Extract class_tokens from ss_datasets if present + if metadata and "ss_datasets" in metadata: + try: + datasets_data = json.loads(metadata["ss_datasets"]) + # Look for class_tokens in the first subset + if datasets_data and isinstance(datasets_data, list) and datasets_data[0].get("subsets"): + subsets = datasets_data[0].get("subsets", []) + if subsets and isinstance(subsets, list) and len(subsets) > 0: + class_tokens = subsets[0].get("class_tokens") + except Exception as e: + logger.error(f"Error parsing ss_datasets for class_tokens: {str(e)}") + + # Extract tag frequency as before if metadata and "ss_tag_frequency" in metadata: # Parse the JSON string into a dictionary tag_data = json.loads(metadata["ss_tag_frequency"]) @@ -108,8 +126,8 @@ async def extract_trained_words(file_path: str) -> List[Tuple[str, int]]: # Sort words by frequency (highest first) sorted_words = sorted(words_dict.items(), key=lambda x: x[1], reverse=True) - return sorted_words + return sorted_words, class_tokens except Exception as e: logger.error(f"Error extracting trained words from {file_path}: {str(e)}") - return [] \ No newline at end of file + return [], class_tokens \ No newline at end of file diff --git a/static/css/components/lora-modal.css b/static/css/components/lora-modal.css index 89b02512..2263bd75 100644 --- a/static/css/components/lora-modal.css +++ b/static/css/components/lora-modal.css @@ -1512,4 +1512,32 @@ .creator-info:hover { background: oklch(var(--lora-accent-l) var(--lora-accent-c) var(--lora-accent-h) / 0.1); border-color: var(--lora-accent); +} + +/* Class tokens styling */ +.class-tokens-container { + padding: 10px; + display: flex; + flex-wrap: wrap; + gap: 8px; +} + +.class-token-item { + background: oklch(var(--lora-accent-l) var(--lora-accent-c) var(--lora-accent-h) / 0.1) !important; + border: 1px solid var(--lora-accent) !important; +} + +.token-badge { + background: var(--lora-accent); + color: white; + font-size: 0.7em; + padding: 2px 5px; + border-radius: 8px; + white-space: nowrap; +} + +.dropdown-separator { + height: 1px; + background: var(--lora-border); + margin: 5px 10px; } \ No newline at end of file diff --git a/static/js/components/loraModal/TriggerWords.js b/static/js/components/loraModal/TriggerWords.js index 7abaecc2..ab147608 100644 --- a/static/js/components/loraModal/TriggerWords.js +++ b/static/js/components/loraModal/TriggerWords.js @@ -8,32 +8,36 @@ import { saveModelMetadata } from '../../api/loraApi.js'; /** * Fetch trained words for a model * @param {string} filePath - Path to the model file - * @returns {Promise} - Array of [word, frequency] pairs + * @returns {Promise} - Object with trained words and class tokens */ async function fetchTrainedWords(filePath) { try { const response = await fetch(`/api/trained-words?file_path=${encodeURIComponent(filePath)}`); const data = await response.json(); - if (data.success && data.trained_words) { - return data.trained_words; // Returns array of [word, frequency] pairs + if (data.success) { + return { + trainedWords: data.trained_words || [], // Returns array of [word, frequency] pairs + classTokens: data.class_tokens // Can be null or a string + }; } else { throw new Error(data.error || 'Failed to fetch trained words'); } } catch (error) { console.error('Error fetching trained words:', error); showToast('Could not load trained words', 'error'); - return []; + return { trainedWords: [], classTokens: null }; } } /** * Create suggestion dropdown with trained words as tags * @param {Array} trainedWords - Array of [word, frequency] pairs + * @param {string|null} classTokens - Class tokens from training * @param {Array} existingWords - Already added trigger words * @returns {HTMLElement} - Dropdown element */ -function createSuggestionDropdown(trainedWords, existingWords = []) { +function createSuggestionDropdown(trainedWords, classTokens, existingWords = []) { const dropdown = document.createElement('div'); dropdown.className = 'trained-words-dropdown'; @@ -41,49 +45,56 @@ function createSuggestionDropdown(trainedWords, existingWords = []) { const header = document.createElement('div'); header.className = 'trained-words-header'; - if (!trainedWords || trainedWords.length === 0) { + // No suggestions case + if ((!trainedWords || trainedWords.length === 0) && !classTokens) { header.innerHTML = 'No suggestions available'; dropdown.appendChild(header); - dropdown.innerHTML += '
No trained words found in this model. You can manually enter trigger words.
'; + dropdown.innerHTML += '
No trained words or class tokens found in this model. You can manually enter trigger words.
'; return dropdown; } - // Sort by frequency (highest first) - trainedWords.sort((a, b) => b[1] - a[1]); + // Sort trained words by frequency (highest first) if available + if (trainedWords && trainedWords.length > 0) { + trainedWords.sort((a, b) => b[1] - a[1]); + } - header.innerHTML = ` - Suggestions from training data - ${trainedWords.length} words found - `; - dropdown.appendChild(header); - - // Create tag container - const container = document.createElement('div'); - container.className = 'trained-words-container'; - - // Add each trained word as a tag - trainedWords.forEach(([word, frequency]) => { - const isAdded = existingWords.includes(word); + // Add class tokens section if available + if (classTokens) { + // Add class tokens header + const classTokensHeader = document.createElement('div'); + classTokensHeader.className = 'trained-words-header'; + classTokensHeader.innerHTML = ` + Class Token + Add to your prompt for best results + `; + dropdown.appendChild(classTokensHeader); - const item = document.createElement('div'); - item.className = `trained-word-item ${isAdded ? 'already-added' : ''}`; - item.title = word; // Show full word on hover if truncated - item.innerHTML = ` - ${word} + // Add class tokens container + const classTokensContainer = document.createElement('div'); + classTokensContainer.className = 'class-tokens-container'; + + // Create a special item for the class token + const tokenItem = document.createElement('div'); + tokenItem.className = `trained-word-item class-token-item ${existingWords.includes(classTokens) ? 'already-added' : ''}`; + tokenItem.title = `Class token: ${classTokens}`; + tokenItem.innerHTML = ` + ${classTokens}
- ${frequency} - ${isAdded ? '' : ''} + Class Token + ${existingWords.includes(classTokens) ? + '' : ''}
`; - if (!isAdded) { - item.addEventListener('click', () => { + // Add click handler if not already added + if (!existingWords.includes(classTokens)) { + tokenItem.addEventListener('click', () => { // Automatically add this word - addNewTriggerWord(word); + addNewTriggerWord(classTokens); // Also populate the input field for potential editing const input = document.querySelector('.new-trigger-word-input'); - if (input) input.value = word; + if (input) input.value = classTokens; // Focus on the input if (input) input.focus(); @@ -93,10 +104,70 @@ function createSuggestionDropdown(trainedWords, existingWords = []) { }); } - container.appendChild(item); - }); + classTokensContainer.appendChild(tokenItem); + dropdown.appendChild(classTokensContainer); + + // Add separator if we also have trained words + if (trainedWords && trainedWords.length > 0) { + const separator = document.createElement('div'); + separator.className = 'dropdown-separator'; + dropdown.appendChild(separator); + } + } + + // Add trained words header if we have any + if (trainedWords && trainedWords.length > 0) { + header.innerHTML = ` + Word Suggestions + ${trainedWords.length} words found + `; + dropdown.appendChild(header); + + // Create tag container for trained words + const container = document.createElement('div'); + container.className = 'trained-words-container'; + + // Add each trained word as a tag + trainedWords.forEach(([word, frequency]) => { + const isAdded = existingWords.includes(word); + + const item = document.createElement('div'); + item.className = `trained-word-item ${isAdded ? 'already-added' : ''}`; + item.title = word; // Show full word on hover if truncated + item.innerHTML = ` + ${word} +
+ ${frequency} + ${isAdded ? '' : ''} +
+ `; + + if (!isAdded) { + item.addEventListener('click', () => { + // Automatically add this word + addNewTriggerWord(word); + + // Also populate the input field for potential editing + const input = document.querySelector('.new-trigger-word-input'); + if (input) input.value = word; + + // Focus on the input + if (input) input.focus(); + + // Update dropdown without removing it + updateTrainedWordsDropdown(); + }); + } + + container.appendChild(item); + }); + + dropdown.appendChild(container); + } else if (!classTokens) { + // If we have neither class tokens nor trained words + dropdown.innerHTML += '
No word suggestions found in this model. You can manually enter trigger words.
'; + } - dropdown.appendChild(container); return dropdown; } @@ -171,6 +242,7 @@ export function renderTriggerWords(words, filePath) { export function setupTriggerWordsEditMode() { // Store trained words data let trainedWordsList = []; + let classTokensValue = null; let isTrainedWordsLoaded = false; // Store original trigger words for restoring on cancel let originalTriggerWords = []; @@ -228,7 +300,9 @@ export function setupTriggerWordsEditMode() { // Asynchronously load trained words if not already loaded if (!isTrainedWordsLoaded) { - trainedWordsList = await fetchTrainedWords(filePath); + const result = await fetchTrainedWords(filePath); + trainedWordsList = result.trainedWords; + classTokensValue = result.classTokens; isTrainedWordsLoaded = true; } @@ -236,7 +310,7 @@ export function setupTriggerWordsEditMode() { loadingIndicator.remove(); // Create and display suggestion dropdown - const dropdown = createSuggestionDropdown(trainedWordsList, existingWords); + const dropdown = createSuggestionDropdown(trainedWordsList, classTokensValue, existingWords); addForm.appendChild(dropdown); // Focus the input