mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-21 21:22:11 -03:00
Enhance trained words extraction and display: include class tokens in response and update UI accordingly. See #147
This commit is contained in:
@@ -1005,13 +1005,14 @@ class MiscRoutes:
|
||||
'error': 'File is not a safetensors file'
|
||||
}, status=400)
|
||||
|
||||
# Extract trained words
|
||||
trained_words = await extract_trained_words(file_path)
|
||||
# Extract trained words and class_tokens
|
||||
trained_words, class_tokens = await extract_trained_words(file_path)
|
||||
|
||||
# Return result
|
||||
# Return result with both trained words and class tokens
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
'trained_words': trained_words
|
||||
'trained_words': trained_words,
|
||||
'class_tokens': class_tokens
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
|
||||
@@ -83,18 +83,36 @@ async def extract_checkpoint_metadata(file_path: str) -> dict:
|
||||
# Return default values
|
||||
return {'base_model': 'Unknown', 'model_type': 'checkpoint'}
|
||||
|
||||
async def extract_trained_words(file_path: str) -> List[Tuple[str, int]]:
|
||||
async def extract_trained_words(file_path: str) -> Tuple[List[Tuple[str, int]], str]:
|
||||
"""Extract trained words from a safetensors file and sort by frequency
|
||||
|
||||
Args:
|
||||
file_path: Path to the safetensors file
|
||||
|
||||
Returns:
|
||||
List of (word, frequency) tuples sorted by frequency (highest first)
|
||||
Tuple of:
|
||||
- List of (word, frequency) tuples sorted by frequency (highest first)
|
||||
- class_tokens value (or None if not found)
|
||||
"""
|
||||
class_tokens = None
|
||||
|
||||
try:
|
||||
with safe_open(file_path, framework="pt", device="cpu") as f:
|
||||
metadata = f.metadata()
|
||||
|
||||
# Extract class_tokens from ss_datasets if present
|
||||
if metadata and "ss_datasets" in metadata:
|
||||
try:
|
||||
datasets_data = json.loads(metadata["ss_datasets"])
|
||||
# Look for class_tokens in the first subset
|
||||
if datasets_data and isinstance(datasets_data, list) and datasets_data[0].get("subsets"):
|
||||
subsets = datasets_data[0].get("subsets", [])
|
||||
if subsets and isinstance(subsets, list) and len(subsets) > 0:
|
||||
class_tokens = subsets[0].get("class_tokens")
|
||||
except Exception as e:
|
||||
logger.error(f"Error parsing ss_datasets for class_tokens: {str(e)}")
|
||||
|
||||
# Extract tag frequency as before
|
||||
if metadata and "ss_tag_frequency" in metadata:
|
||||
# Parse the JSON string into a dictionary
|
||||
tag_data = json.loads(metadata["ss_tag_frequency"])
|
||||
@@ -108,8 +126,8 @@ async def extract_trained_words(file_path: str) -> List[Tuple[str, int]]:
|
||||
|
||||
# Sort words by frequency (highest first)
|
||||
sorted_words = sorted(words_dict.items(), key=lambda x: x[1], reverse=True)
|
||||
return sorted_words
|
||||
return sorted_words, class_tokens
|
||||
except Exception as e:
|
||||
logger.error(f"Error extracting trained words from {file_path}: {str(e)}")
|
||||
|
||||
return []
|
||||
return [], class_tokens
|
||||
Reference in New Issue
Block a user