Enhance trained words extraction and display: include class tokens in response and update UI accordingly. See #147

This commit is contained in:
Will Miao
2025-06-04 12:03:36 +08:00
parent 4b96c650eb
commit b4e7feed06
4 changed files with 167 additions and 46 deletions

View File

@@ -1005,13 +1005,14 @@ class MiscRoutes:
'error': 'File is not a safetensors file'
}, status=400)
# Extract trained words
trained_words = await extract_trained_words(file_path)
# Extract trained words and class_tokens
trained_words, class_tokens = await extract_trained_words(file_path)
# Return result
# Return result with both trained words and class tokens
return web.json_response({
'success': True,
'trained_words': trained_words
'trained_words': trained_words,
'class_tokens': class_tokens
})
except Exception as e:

View File

@@ -83,18 +83,36 @@ async def extract_checkpoint_metadata(file_path: str) -> dict:
# Return default values
return {'base_model': 'Unknown', 'model_type': 'checkpoint'}
async def extract_trained_words(file_path: str) -> List[Tuple[str, int]]:
async def extract_trained_words(file_path: str) -> Tuple[List[Tuple[str, int]], str]:
"""Extract trained words from a safetensors file and sort by frequency
Args:
file_path: Path to the safetensors file
Returns:
List of (word, frequency) tuples sorted by frequency (highest first)
Tuple of:
- List of (word, frequency) tuples sorted by frequency (highest first)
- class_tokens value (or None if not found)
"""
class_tokens = None
try:
with safe_open(file_path, framework="pt", device="cpu") as f:
metadata = f.metadata()
# Extract class_tokens from ss_datasets if present
if metadata and "ss_datasets" in metadata:
try:
datasets_data = json.loads(metadata["ss_datasets"])
# Look for class_tokens in the first subset
if datasets_data and isinstance(datasets_data, list) and datasets_data[0].get("subsets"):
subsets = datasets_data[0].get("subsets", [])
if subsets and isinstance(subsets, list) and len(subsets) > 0:
class_tokens = subsets[0].get("class_tokens")
except Exception as e:
logger.error(f"Error parsing ss_datasets for class_tokens: {str(e)}")
# Extract tag frequency as before
if metadata and "ss_tag_frequency" in metadata:
# Parse the JSON string into a dictionary
tag_data = json.loads(metadata["ss_tag_frequency"])
@@ -108,8 +126,8 @@ async def extract_trained_words(file_path: str) -> List[Tuple[str, int]]:
# Sort words by frequency (highest first)
sorted_words = sorted(words_dict.items(), key=lambda x: x[1], reverse=True)
return sorted_words
return sorted_words, class_tokens
except Exception as e:
logger.error(f"Error extracting trained words from {file_path}: {str(e)}")
return []
return [], class_tokens