mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-25 15:15:44 -03:00
Enhance trained words extraction and display: include class tokens in response and update UI accordingly. See #147
This commit is contained in:
@@ -1005,13 +1005,14 @@ class MiscRoutes:
|
|||||||
'error': 'File is not a safetensors file'
|
'error': 'File is not a safetensors file'
|
||||||
}, status=400)
|
}, status=400)
|
||||||
|
|
||||||
# Extract trained words
|
# Extract trained words and class_tokens
|
||||||
trained_words = await extract_trained_words(file_path)
|
trained_words, class_tokens = await extract_trained_words(file_path)
|
||||||
|
|
||||||
# Return result
|
# Return result with both trained words and class tokens
|
||||||
return web.json_response({
|
return web.json_response({
|
||||||
'success': True,
|
'success': True,
|
||||||
'trained_words': trained_words
|
'trained_words': trained_words,
|
||||||
|
'class_tokens': class_tokens
|
||||||
})
|
})
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
@@ -83,18 +83,36 @@ async def extract_checkpoint_metadata(file_path: str) -> dict:
|
|||||||
# Return default values
|
# Return default values
|
||||||
return {'base_model': 'Unknown', 'model_type': 'checkpoint'}
|
return {'base_model': 'Unknown', 'model_type': 'checkpoint'}
|
||||||
|
|
||||||
async def extract_trained_words(file_path: str) -> List[Tuple[str, int]]:
|
async def extract_trained_words(file_path: str) -> Tuple[List[Tuple[str, int]], str]:
|
||||||
"""Extract trained words from a safetensors file and sort by frequency
|
"""Extract trained words from a safetensors file and sort by frequency
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
file_path: Path to the safetensors file
|
file_path: Path to the safetensors file
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List of (word, frequency) tuples sorted by frequency (highest first)
|
Tuple of:
|
||||||
|
- List of (word, frequency) tuples sorted by frequency (highest first)
|
||||||
|
- class_tokens value (or None if not found)
|
||||||
"""
|
"""
|
||||||
|
class_tokens = None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
with safe_open(file_path, framework="pt", device="cpu") as f:
|
with safe_open(file_path, framework="pt", device="cpu") as f:
|
||||||
metadata = f.metadata()
|
metadata = f.metadata()
|
||||||
|
|
||||||
|
# Extract class_tokens from ss_datasets if present
|
||||||
|
if metadata and "ss_datasets" in metadata:
|
||||||
|
try:
|
||||||
|
datasets_data = json.loads(metadata["ss_datasets"])
|
||||||
|
# Look for class_tokens in the first subset
|
||||||
|
if datasets_data and isinstance(datasets_data, list) and datasets_data[0].get("subsets"):
|
||||||
|
subsets = datasets_data[0].get("subsets", [])
|
||||||
|
if subsets and isinstance(subsets, list) and len(subsets) > 0:
|
||||||
|
class_tokens = subsets[0].get("class_tokens")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error parsing ss_datasets for class_tokens: {str(e)}")
|
||||||
|
|
||||||
|
# Extract tag frequency as before
|
||||||
if metadata and "ss_tag_frequency" in metadata:
|
if metadata and "ss_tag_frequency" in metadata:
|
||||||
# Parse the JSON string into a dictionary
|
# Parse the JSON string into a dictionary
|
||||||
tag_data = json.loads(metadata["ss_tag_frequency"])
|
tag_data = json.loads(metadata["ss_tag_frequency"])
|
||||||
@@ -108,8 +126,8 @@ async def extract_trained_words(file_path: str) -> List[Tuple[str, int]]:
|
|||||||
|
|
||||||
# Sort words by frequency (highest first)
|
# Sort words by frequency (highest first)
|
||||||
sorted_words = sorted(words_dict.items(), key=lambda x: x[1], reverse=True)
|
sorted_words = sorted(words_dict.items(), key=lambda x: x[1], reverse=True)
|
||||||
return sorted_words
|
return sorted_words, class_tokens
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error extracting trained words from {file_path}: {str(e)}")
|
logger.error(f"Error extracting trained words from {file_path}: {str(e)}")
|
||||||
|
|
||||||
return []
|
return [], class_tokens
|
||||||
@@ -1513,3 +1513,31 @@
|
|||||||
background: oklch(var(--lora-accent-l) var(--lora-accent-c) var(--lora-accent-h) / 0.1);
|
background: oklch(var(--lora-accent-l) var(--lora-accent-c) var(--lora-accent-h) / 0.1);
|
||||||
border-color: var(--lora-accent);
|
border-color: var(--lora-accent);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* Class tokens styling */
|
||||||
|
.class-tokens-container {
|
||||||
|
padding: 10px;
|
||||||
|
display: flex;
|
||||||
|
flex-wrap: wrap;
|
||||||
|
gap: 8px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.class-token-item {
|
||||||
|
background: oklch(var(--lora-accent-l) var(--lora-accent-c) var(--lora-accent-h) / 0.1) !important;
|
||||||
|
border: 1px solid var(--lora-accent) !important;
|
||||||
|
}
|
||||||
|
|
||||||
|
.token-badge {
|
||||||
|
background: var(--lora-accent);
|
||||||
|
color: white;
|
||||||
|
font-size: 0.7em;
|
||||||
|
padding: 2px 5px;
|
||||||
|
border-radius: 8px;
|
||||||
|
white-space: nowrap;
|
||||||
|
}
|
||||||
|
|
||||||
|
.dropdown-separator {
|
||||||
|
height: 1px;
|
||||||
|
background: var(--lora-border);
|
||||||
|
margin: 5px 10px;
|
||||||
|
}
|
||||||
@@ -8,32 +8,36 @@ import { saveModelMetadata } from '../../api/loraApi.js';
|
|||||||
/**
|
/**
|
||||||
* Fetch trained words for a model
|
* Fetch trained words for a model
|
||||||
* @param {string} filePath - Path to the model file
|
* @param {string} filePath - Path to the model file
|
||||||
* @returns {Promise<Array>} - Array of [word, frequency] pairs
|
* @returns {Promise<Object>} - Object with trained words and class tokens
|
||||||
*/
|
*/
|
||||||
async function fetchTrainedWords(filePath) {
|
async function fetchTrainedWords(filePath) {
|
||||||
try {
|
try {
|
||||||
const response = await fetch(`/api/trained-words?file_path=${encodeURIComponent(filePath)}`);
|
const response = await fetch(`/api/trained-words?file_path=${encodeURIComponent(filePath)}`);
|
||||||
const data = await response.json();
|
const data = await response.json();
|
||||||
|
|
||||||
if (data.success && data.trained_words) {
|
if (data.success) {
|
||||||
return data.trained_words; // Returns array of [word, frequency] pairs
|
return {
|
||||||
|
trainedWords: data.trained_words || [], // Returns array of [word, frequency] pairs
|
||||||
|
classTokens: data.class_tokens // Can be null or a string
|
||||||
|
};
|
||||||
} else {
|
} else {
|
||||||
throw new Error(data.error || 'Failed to fetch trained words');
|
throw new Error(data.error || 'Failed to fetch trained words');
|
||||||
}
|
}
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
console.error('Error fetching trained words:', error);
|
console.error('Error fetching trained words:', error);
|
||||||
showToast('Could not load trained words', 'error');
|
showToast('Could not load trained words', 'error');
|
||||||
return [];
|
return { trainedWords: [], classTokens: null };
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Create suggestion dropdown with trained words as tags
|
* Create suggestion dropdown with trained words as tags
|
||||||
* @param {Array} trainedWords - Array of [word, frequency] pairs
|
* @param {Array} trainedWords - Array of [word, frequency] pairs
|
||||||
|
* @param {string|null} classTokens - Class tokens from training
|
||||||
* @param {Array} existingWords - Already added trigger words
|
* @param {Array} existingWords - Already added trigger words
|
||||||
* @returns {HTMLElement} - Dropdown element
|
* @returns {HTMLElement} - Dropdown element
|
||||||
*/
|
*/
|
||||||
function createSuggestionDropdown(trainedWords, existingWords = []) {
|
function createSuggestionDropdown(trainedWords, classTokens, existingWords = []) {
|
||||||
const dropdown = document.createElement('div');
|
const dropdown = document.createElement('div');
|
||||||
dropdown.className = 'trained-words-dropdown';
|
dropdown.className = 'trained-words-dropdown';
|
||||||
|
|
||||||
@@ -41,23 +45,85 @@ function createSuggestionDropdown(trainedWords, existingWords = []) {
|
|||||||
const header = document.createElement('div');
|
const header = document.createElement('div');
|
||||||
header.className = 'trained-words-header';
|
header.className = 'trained-words-header';
|
||||||
|
|
||||||
if (!trainedWords || trainedWords.length === 0) {
|
// No suggestions case
|
||||||
|
if ((!trainedWords || trainedWords.length === 0) && !classTokens) {
|
||||||
header.innerHTML = '<span>No suggestions available</span>';
|
header.innerHTML = '<span>No suggestions available</span>';
|
||||||
dropdown.appendChild(header);
|
dropdown.appendChild(header);
|
||||||
dropdown.innerHTML += '<div class="no-trained-words">No trained words found in this model. You can manually enter trigger words.</div>';
|
dropdown.innerHTML += '<div class="no-trained-words">No trained words or class tokens found in this model. You can manually enter trigger words.</div>';
|
||||||
return dropdown;
|
return dropdown;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Sort by frequency (highest first)
|
// Sort trained words by frequency (highest first) if available
|
||||||
|
if (trainedWords && trainedWords.length > 0) {
|
||||||
trainedWords.sort((a, b) => b[1] - a[1]);
|
trainedWords.sort((a, b) => b[1] - a[1]);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add class tokens section if available
|
||||||
|
if (classTokens) {
|
||||||
|
// Add class tokens header
|
||||||
|
const classTokensHeader = document.createElement('div');
|
||||||
|
classTokensHeader.className = 'trained-words-header';
|
||||||
|
classTokensHeader.innerHTML = `
|
||||||
|
<span>Class Token</span>
|
||||||
|
<small>Add to your prompt for best results</small>
|
||||||
|
`;
|
||||||
|
dropdown.appendChild(classTokensHeader);
|
||||||
|
|
||||||
|
// Add class tokens container
|
||||||
|
const classTokensContainer = document.createElement('div');
|
||||||
|
classTokensContainer.className = 'class-tokens-container';
|
||||||
|
|
||||||
|
// Create a special item for the class token
|
||||||
|
const tokenItem = document.createElement('div');
|
||||||
|
tokenItem.className = `trained-word-item class-token-item ${existingWords.includes(classTokens) ? 'already-added' : ''}`;
|
||||||
|
tokenItem.title = `Class token: ${classTokens}`;
|
||||||
|
tokenItem.innerHTML = `
|
||||||
|
<span class="trained-word-text">${classTokens}</span>
|
||||||
|
<div class="trained-word-meta">
|
||||||
|
<span class="token-badge">Class Token</span>
|
||||||
|
${existingWords.includes(classTokens) ?
|
||||||
|
'<span class="added-indicator"><i class="fas fa-check"></i></span>' : ''}
|
||||||
|
</div>
|
||||||
|
`;
|
||||||
|
|
||||||
|
// Add click handler if not already added
|
||||||
|
if (!existingWords.includes(classTokens)) {
|
||||||
|
tokenItem.addEventListener('click', () => {
|
||||||
|
// Automatically add this word
|
||||||
|
addNewTriggerWord(classTokens);
|
||||||
|
|
||||||
|
// Also populate the input field for potential editing
|
||||||
|
const input = document.querySelector('.new-trigger-word-input');
|
||||||
|
if (input) input.value = classTokens;
|
||||||
|
|
||||||
|
// Focus on the input
|
||||||
|
if (input) input.focus();
|
||||||
|
|
||||||
|
// Update dropdown without removing it
|
||||||
|
updateTrainedWordsDropdown();
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
classTokensContainer.appendChild(tokenItem);
|
||||||
|
dropdown.appendChild(classTokensContainer);
|
||||||
|
|
||||||
|
// Add separator if we also have trained words
|
||||||
|
if (trainedWords && trainedWords.length > 0) {
|
||||||
|
const separator = document.createElement('div');
|
||||||
|
separator.className = 'dropdown-separator';
|
||||||
|
dropdown.appendChild(separator);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add trained words header if we have any
|
||||||
|
if (trainedWords && trainedWords.length > 0) {
|
||||||
header.innerHTML = `
|
header.innerHTML = `
|
||||||
<span>Suggestions from training data</span>
|
<span>Word Suggestions</span>
|
||||||
<small>${trainedWords.length} words found</small>
|
<small>${trainedWords.length} words found</small>
|
||||||
`;
|
`;
|
||||||
dropdown.appendChild(header);
|
dropdown.appendChild(header);
|
||||||
|
|
||||||
// Create tag container
|
// Create tag container for trained words
|
||||||
const container = document.createElement('div');
|
const container = document.createElement('div');
|
||||||
container.className = 'trained-words-container';
|
container.className = 'trained-words-container';
|
||||||
|
|
||||||
@@ -97,6 +163,11 @@ function createSuggestionDropdown(trainedWords, existingWords = []) {
|
|||||||
});
|
});
|
||||||
|
|
||||||
dropdown.appendChild(container);
|
dropdown.appendChild(container);
|
||||||
|
} else if (!classTokens) {
|
||||||
|
// If we have neither class tokens nor trained words
|
||||||
|
dropdown.innerHTML += '<div class="no-trained-words">No word suggestions found in this model. You can manually enter trigger words.</div>';
|
||||||
|
}
|
||||||
|
|
||||||
return dropdown;
|
return dropdown;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -171,6 +242,7 @@ export function renderTriggerWords(words, filePath) {
|
|||||||
export function setupTriggerWordsEditMode() {
|
export function setupTriggerWordsEditMode() {
|
||||||
// Store trained words data
|
// Store trained words data
|
||||||
let trainedWordsList = [];
|
let trainedWordsList = [];
|
||||||
|
let classTokensValue = null;
|
||||||
let isTrainedWordsLoaded = false;
|
let isTrainedWordsLoaded = false;
|
||||||
// Store original trigger words for restoring on cancel
|
// Store original trigger words for restoring on cancel
|
||||||
let originalTriggerWords = [];
|
let originalTriggerWords = [];
|
||||||
@@ -228,7 +300,9 @@ export function setupTriggerWordsEditMode() {
|
|||||||
|
|
||||||
// Asynchronously load trained words if not already loaded
|
// Asynchronously load trained words if not already loaded
|
||||||
if (!isTrainedWordsLoaded) {
|
if (!isTrainedWordsLoaded) {
|
||||||
trainedWordsList = await fetchTrainedWords(filePath);
|
const result = await fetchTrainedWords(filePath);
|
||||||
|
trainedWordsList = result.trainedWords;
|
||||||
|
classTokensValue = result.classTokens;
|
||||||
isTrainedWordsLoaded = true;
|
isTrainedWordsLoaded = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -236,7 +310,7 @@ export function setupTriggerWordsEditMode() {
|
|||||||
loadingIndicator.remove();
|
loadingIndicator.remove();
|
||||||
|
|
||||||
// Create and display suggestion dropdown
|
// Create and display suggestion dropdown
|
||||||
const dropdown = createSuggestionDropdown(trainedWordsList, existingWords);
|
const dropdown = createSuggestionDropdown(trainedWordsList, classTokensValue, existingWords);
|
||||||
addForm.appendChild(dropdown);
|
addForm.appendChild(dropdown);
|
||||||
|
|
||||||
// Focus the input
|
// Focus the input
|
||||||
|
|||||||
Reference in New Issue
Block a user