mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-24 06:32:12 -03:00
Add API endpoint for fetching trained words and implement dropdown suggestions in the trigger words editor. See #147
This commit is contained in:
@@ -15,6 +15,7 @@ from ..services.service_registry import ServiceRegistry
|
||||
from ..utils.constants import SUPPORTED_MEDIA_EXTENSIONS
|
||||
from ..services.civitai_client import CivitaiClient
|
||||
from ..utils.routes_common import ModelRouteUtils
|
||||
from ..utils.lora_metadata import extract_trained_words
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -61,6 +62,9 @@ class MiscRoutes:
|
||||
# Add new route for opening example images folder
|
||||
app.router.add_post('/api/open-example-images-folder', MiscRoutes.open_example_images_folder)
|
||||
|
||||
# Add new route for getting trained words
|
||||
app.router.add_get('/api/trained-words', MiscRoutes.get_trained_words)
|
||||
|
||||
@staticmethod
|
||||
async def clear_cache(request):
|
||||
"""Clear all cache files from the cache folder"""
|
||||
@@ -955,3 +959,50 @@ class MiscRoutes:
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
}, status=500)
|
||||
|
||||
@staticmethod
|
||||
async def get_trained_words(request):
|
||||
"""
|
||||
Get trained words from a safetensors file, sorted by frequency
|
||||
|
||||
Expects a query parameter:
|
||||
file_path: Path to the safetensors file
|
||||
"""
|
||||
try:
|
||||
# Get file path from query parameters
|
||||
file_path = request.query.get('file_path')
|
||||
|
||||
if not file_path:
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': 'Missing file_path parameter'
|
||||
}, status=400)
|
||||
|
||||
# Check if file exists and is a safetensors file
|
||||
if not os.path.exists(file_path):
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': f"File not found: {file_path}"
|
||||
}, status=404)
|
||||
|
||||
if not file_path.lower().endswith('.safetensors'):
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': 'File is not a safetensors file'
|
||||
}, status=400)
|
||||
|
||||
# Extract trained words
|
||||
trained_words = await extract_trained_words(file_path)
|
||||
|
||||
# Return result
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
'trained_words': trained_words
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get trained words: {e}", exc_info=True)
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
}, status=500)
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
from safetensors import safe_open
|
||||
from typing import Dict
|
||||
from typing import Dict, List, Tuple
|
||||
from .model_utils import determine_base_model
|
||||
import os
|
||||
import logging
|
||||
import json
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -80,4 +81,35 @@ async def extract_checkpoint_metadata(file_path: str) -> dict:
|
||||
except Exception as e:
|
||||
logger.error(f"Error extracting checkpoint metadata for {file_path}: {e}")
|
||||
# Return default values
|
||||
return {'base_model': 'Unknown', 'model_type': 'checkpoint'}
|
||||
return {'base_model': 'Unknown', 'model_type': 'checkpoint'}
|
||||
|
||||
async def extract_trained_words(file_path: str) -> List[Tuple[str, int]]:
|
||||
"""Extract trained words from a safetensors file and sort by frequency
|
||||
|
||||
Args:
|
||||
file_path: Path to the safetensors file
|
||||
|
||||
Returns:
|
||||
List of (word, frequency) tuples sorted by frequency (highest first)
|
||||
"""
|
||||
try:
|
||||
with safe_open(file_path, framework="pt", device="cpu") as f:
|
||||
metadata = f.metadata()
|
||||
if metadata and "ss_tag_frequency" in metadata:
|
||||
# Parse the JSON string into a dictionary
|
||||
tag_data = json.loads(metadata["ss_tag_frequency"])
|
||||
|
||||
# The structure may have an outer key (like "image_dir" or "img")
|
||||
# We need to get the inner dictionary with the actual word frequencies
|
||||
if tag_data:
|
||||
# Get the first key (usually "image_dir" or "img")
|
||||
first_key = list(tag_data.keys())[0]
|
||||
words_dict = tag_data[first_key]
|
||||
|
||||
# Sort words by frequency (highest first)
|
||||
sorted_words = sorted(words_dict.items(), key=lambda x: x[1], reverse=True)
|
||||
return sorted_words
|
||||
except Exception as e:
|
||||
logger.error(f"Error extracting trained words from {file_path}: {str(e)}")
|
||||
|
||||
return []
|
||||
Reference in New Issue
Block a user