Add API endpoint for fetching trained words and implement dropdown suggestions in the trigger words editor. See #147

This commit is contained in:
Will Miao
2025-06-02 17:04:33 +08:00
parent 396924f4cc
commit 99d2ba26b9
4 changed files with 508 additions and 83 deletions

View File

@@ -1,8 +1,9 @@
from safetensors import safe_open
from typing import Dict
from typing import Dict, List, Tuple
from .model_utils import determine_base_model
import os
import logging
import json
logger = logging.getLogger(__name__)
@@ -80,4 +81,35 @@ async def extract_checkpoint_metadata(file_path: str) -> dict:
except Exception as e:
logger.error(f"Error extracting checkpoint metadata for {file_path}: {e}")
# Return default values
return {'base_model': 'Unknown', 'model_type': 'checkpoint'}
return {'base_model': 'Unknown', 'model_type': 'checkpoint'}
async def extract_trained_words(file_path: str) -> List[Tuple[str, int]]:
    """Extract trained words from a safetensors file and sort by frequency.

    The words are read from the ``ss_tag_frequency`` metadata entry, a JSON
    string whose outer keys (usually "image_dir" or "img") each map to a
    ``{word: frequency}`` dictionary. Frequencies from every outer entry are
    summed, so models whose metadata lists several outer entries report all
    of their words instead of only those under the first key.

    Args:
        file_path: Path to the safetensors file

    Returns:
        List of (word, frequency) tuples sorted by frequency (highest first).
        An empty list is returned when the file has no usable
        ``ss_tag_frequency`` metadata or when reading/parsing fails (the
        failure is logged, not raised).
    """
    try:
        with safe_open(file_path, framework="pt", device="cpu") as f:
            metadata = f.metadata()

        raw_frequencies = metadata.get("ss_tag_frequency") if metadata else None
        if not raw_frequencies:
            return []

        # Parse the JSON string into a dictionary
        tag_data = json.loads(raw_frequencies)
        if not isinstance(tag_data, dict):
            return []

        # Merge the inner {word: frequency} dictionaries of all outer keys;
        # a word that appears under several keys gets its counts added up.
        merged: Dict[str, int] = {}
        for words_dict in tag_data.values():
            if not isinstance(words_dict, dict):
                # Skip malformed entries rather than discarding everything
                continue
            for word, frequency in words_dict.items():
                merged[word] = merged.get(word, 0) + frequency

        # Sort words by frequency (highest first); sorted() is stable, so
        # words with equal frequency keep their metadata order.
        return sorted(merged.items(), key=lambda item: item[1], reverse=True)
    except Exception as e:
        # Best-effort extraction: a broken file must not break the caller
        logger.error(f"Error extracting trained words from {file_path}: {str(e)}")
        return []