mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-23 14:12:11 -03:00
feat: add .opencode to gitignore and refactor lora routes
- Add .opencode directory to gitignore for agent-related files - Refactor lora_routes.py with consistent string formatting and improved route registration - Add DualRangeSlider Vue component for enhanced UI controls
This commit is contained in:
@@ -12,14 +12,15 @@ from ..utils.utils import get_lora_info
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LoraRoutes(BaseModelRoutes):
|
||||
"""LoRA-specific route controller"""
|
||||
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize LoRA routes with LoRA service"""
|
||||
super().__init__()
|
||||
self.template_name = "loras.html"
|
||||
|
||||
|
||||
async def initialize_services(self):
|
||||
"""Initialize services from ServiceRegistry"""
|
||||
lora_scanner = await ServiceRegistry.get_lora_scanner()
|
||||
@@ -29,231 +30,225 @@ class LoraRoutes(BaseModelRoutes):
|
||||
|
||||
# Attach service dependencies
|
||||
self.attach_service(self.service)
|
||||
|
||||
|
||||
def setup_routes(self, app: web.Application):
|
||||
"""Setup LoRA routes"""
|
||||
# Schedule service initialization on app startup
|
||||
app.on_startup.append(lambda _: self.initialize_services())
|
||||
|
||||
# Setup common routes with 'loras' prefix (includes page route)
|
||||
super().setup_routes(app, 'loras')
|
||||
super().setup_routes(app, "loras")
|
||||
|
||||
def setup_specific_routes(self, registrar: ModelRouteRegistrar, prefix: str):
|
||||
"""Setup LoRA-specific routes"""
|
||||
# LoRA-specific query routes
|
||||
registrar.add_prefixed_route('GET', '/api/lm/{prefix}/letter-counts', prefix, self.get_letter_counts)
|
||||
registrar.add_prefixed_route('GET', '/api/lm/{prefix}/get-trigger-words', prefix, self.get_lora_trigger_words)
|
||||
registrar.add_prefixed_route('GET', '/api/lm/{prefix}/usage-tips-by-path', prefix, self.get_lora_usage_tips_by_path)
|
||||
registrar.add_prefixed_route(
|
||||
"GET", "/api/lm/{prefix}/letter-counts", prefix, self.get_letter_counts
|
||||
)
|
||||
registrar.add_prefixed_route(
|
||||
"GET",
|
||||
"/api/lm/{prefix}/get-trigger-words",
|
||||
prefix,
|
||||
self.get_lora_trigger_words,
|
||||
)
|
||||
registrar.add_prefixed_route(
|
||||
"GET",
|
||||
"/api/lm/{prefix}/usage-tips-by-path",
|
||||
prefix,
|
||||
self.get_lora_usage_tips_by_path,
|
||||
)
|
||||
|
||||
# Randomizer routes
|
||||
registrar.add_prefixed_route('POST', '/api/lm/{prefix}/random-sample', prefix, self.get_random_loras)
|
||||
registrar.add_prefixed_route(
|
||||
"POST", "/api/lm/{prefix}/random-sample", prefix, self.get_random_loras
|
||||
)
|
||||
|
||||
# ComfyUI integration
|
||||
registrar.add_prefixed_route('POST', '/api/lm/{prefix}/get_trigger_words', prefix, self.get_trigger_words)
|
||||
|
||||
registrar.add_prefixed_route(
|
||||
"POST", "/api/lm/{prefix}/get_trigger_words", prefix, self.get_trigger_words
|
||||
)
|
||||
|
||||
def _parse_specific_params(self, request: web.Request) -> Dict:
|
||||
"""Parse LoRA-specific parameters"""
|
||||
params = {}
|
||||
|
||||
|
||||
# LoRA-specific parameters
|
||||
if 'first_letter' in request.query:
|
||||
params['first_letter'] = request.query.get('first_letter')
|
||||
|
||||
if "first_letter" in request.query:
|
||||
params["first_letter"] = request.query.get("first_letter")
|
||||
|
||||
# Handle fuzzy search parameter name variation
|
||||
if request.query.get('fuzzy') == 'true':
|
||||
params['fuzzy_search'] = True
|
||||
|
||||
if request.query.get("fuzzy") == "true":
|
||||
params["fuzzy_search"] = True
|
||||
|
||||
# Handle additional filter parameters for LoRAs
|
||||
if 'lora_hash' in request.query:
|
||||
if not params.get('hash_filters'):
|
||||
params['hash_filters'] = {}
|
||||
params['hash_filters']['single_hash'] = request.query['lora_hash'].lower()
|
||||
elif 'lora_hashes' in request.query:
|
||||
if not params.get('hash_filters'):
|
||||
params['hash_filters'] = {}
|
||||
params['hash_filters']['multiple_hashes'] = [h.lower() for h in request.query['lora_hashes'].split(',')]
|
||||
|
||||
if "lora_hash" in request.query:
|
||||
if not params.get("hash_filters"):
|
||||
params["hash_filters"] = {}
|
||||
params["hash_filters"]["single_hash"] = request.query["lora_hash"].lower()
|
||||
elif "lora_hashes" in request.query:
|
||||
if not params.get("hash_filters"):
|
||||
params["hash_filters"] = {}
|
||||
params["hash_filters"]["multiple_hashes"] = [
|
||||
h.lower() for h in request.query["lora_hashes"].split(",")
|
||||
]
|
||||
|
||||
return params
|
||||
|
||||
|
||||
def _validate_civitai_model_type(self, model_type: str) -> bool:
|
||||
"""Validate CivitAI model type for LoRA"""
|
||||
from ..utils.constants import VALID_LORA_TYPES
|
||||
|
||||
return model_type.lower() in VALID_LORA_TYPES
|
||||
|
||||
|
||||
def _get_expected_model_types(self) -> str:
|
||||
"""Get expected model types string for error messages"""
|
||||
return "LORA, LoCon, or DORA"
|
||||
|
||||
|
||||
# LoRA-specific route handlers
|
||||
async def get_letter_counts(self, request: web.Request) -> web.Response:
|
||||
"""Get count of LoRAs for each letter of the alphabet"""
|
||||
try:
|
||||
letter_counts = await self.service.get_letter_counts()
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
'letter_counts': letter_counts
|
||||
})
|
||||
return web.json_response({"success": True, "letter_counts": letter_counts})
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting letter counts: {e}")
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
}, status=500)
|
||||
|
||||
return web.json_response({"success": False, "error": str(e)}, status=500)
|
||||
|
||||
async def get_lora_notes(self, request: web.Request) -> web.Response:
|
||||
"""Get notes for a specific LoRA file"""
|
||||
try:
|
||||
lora_name = request.query.get('name')
|
||||
lora_name = request.query.get("name")
|
||||
if not lora_name:
|
||||
return web.Response(text='Lora file name is required', status=400)
|
||||
|
||||
return web.Response(text="Lora file name is required", status=400)
|
||||
|
||||
notes = await self.service.get_lora_notes(lora_name)
|
||||
if notes is not None:
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
'notes': notes
|
||||
})
|
||||
return web.json_response({"success": True, "notes": notes})
|
||||
else:
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': 'LoRA not found in cache'
|
||||
}, status=404)
|
||||
|
||||
return web.json_response(
|
||||
{"success": False, "error": "LoRA not found in cache"}, status=404
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting lora notes: {e}", exc_info=True)
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
}, status=500)
|
||||
|
||||
return web.json_response({"success": False, "error": str(e)}, status=500)
|
||||
|
||||
async def get_lora_trigger_words(self, request: web.Request) -> web.Response:
|
||||
"""Get trigger words for a specific LoRA file"""
|
||||
try:
|
||||
lora_name = request.query.get('name')
|
||||
lora_name = request.query.get("name")
|
||||
if not lora_name:
|
||||
return web.Response(text='Lora file name is required', status=400)
|
||||
|
||||
return web.Response(text="Lora file name is required", status=400)
|
||||
|
||||
trigger_words = await self.service.get_lora_trigger_words(lora_name)
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
'trigger_words': trigger_words
|
||||
})
|
||||
|
||||
return web.json_response({"success": True, "trigger_words": trigger_words})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting lora trigger words: {e}", exc_info=True)
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
}, status=500)
|
||||
|
||||
return web.json_response({"success": False, "error": str(e)}, status=500)
|
||||
|
||||
async def get_lora_usage_tips_by_path(self, request: web.Request) -> web.Response:
|
||||
"""Get usage tips for a LoRA by its relative path"""
|
||||
try:
|
||||
relative_path = request.query.get('relative_path')
|
||||
relative_path = request.query.get("relative_path")
|
||||
if not relative_path:
|
||||
return web.Response(text='Relative path is required', status=400)
|
||||
|
||||
usage_tips = await self.service.get_lora_usage_tips_by_relative_path(relative_path)
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
'usage_tips': usage_tips or ''
|
||||
})
|
||||
|
||||
return web.Response(text="Relative path is required", status=400)
|
||||
|
||||
usage_tips = await self.service.get_lora_usage_tips_by_relative_path(
|
||||
relative_path
|
||||
)
|
||||
return web.json_response({"success": True, "usage_tips": usage_tips or ""})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting lora usage tips by path: {e}", exc_info=True)
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
}, status=500)
|
||||
|
||||
return web.json_response({"success": False, "error": str(e)}, status=500)
|
||||
|
||||
async def get_lora_preview_url(self, request: web.Request) -> web.Response:
|
||||
"""Get the static preview URL for a LoRA file"""
|
||||
try:
|
||||
lora_name = request.query.get('name')
|
||||
lora_name = request.query.get("name")
|
||||
if not lora_name:
|
||||
return web.Response(text='Lora file name is required', status=400)
|
||||
|
||||
return web.Response(text="Lora file name is required", status=400)
|
||||
|
||||
preview_url = await self.service.get_lora_preview_url(lora_name)
|
||||
if preview_url:
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
'preview_url': preview_url
|
||||
})
|
||||
return web.json_response({"success": True, "preview_url": preview_url})
|
||||
else:
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': 'No preview URL found for the specified lora'
|
||||
}, status=404)
|
||||
|
||||
return web.json_response(
|
||||
{
|
||||
"success": False,
|
||||
"error": "No preview URL found for the specified lora",
|
||||
},
|
||||
status=404,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting lora preview URL: {e}", exc_info=True)
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
}, status=500)
|
||||
|
||||
return web.json_response({"success": False, "error": str(e)}, status=500)
|
||||
|
||||
async def get_lora_civitai_url(self, request: web.Request) -> web.Response:
|
||||
"""Get the Civitai URL for a LoRA file"""
|
||||
try:
|
||||
lora_name = request.query.get('name')
|
||||
lora_name = request.query.get("name")
|
||||
if not lora_name:
|
||||
return web.Response(text='Lora file name is required', status=400)
|
||||
|
||||
return web.Response(text="Lora file name is required", status=400)
|
||||
|
||||
result = await self.service.get_lora_civitai_url(lora_name)
|
||||
if result['civitai_url']:
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
**result
|
||||
})
|
||||
if result["civitai_url"]:
|
||||
return web.json_response({"success": True, **result})
|
||||
else:
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': 'No Civitai data found for the specified lora'
|
||||
}, status=404)
|
||||
|
||||
return web.json_response(
|
||||
{
|
||||
"success": False,
|
||||
"error": "No Civitai data found for the specified lora",
|
||||
},
|
||||
status=404,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting lora Civitai URL: {e}", exc_info=True)
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
}, status=500)
|
||||
|
||||
return web.json_response({"success": False, "error": str(e)}, status=500)
|
||||
|
||||
async def get_random_loras(self, request: web.Request) -> web.Response:
|
||||
"""Get random LoRAs based on filters and strength ranges"""
|
||||
try:
|
||||
json_data = await request.json()
|
||||
|
||||
# Parse parameters
|
||||
count = json_data.get('count', 5)
|
||||
count_min = json_data.get('count_min')
|
||||
count_max = json_data.get('count_max')
|
||||
model_strength_min = float(json_data.get('model_strength_min', 0.0))
|
||||
model_strength_max = float(json_data.get('model_strength_max', 1.0))
|
||||
use_same_clip_strength = json_data.get('use_same_clip_strength', True)
|
||||
clip_strength_min = float(json_data.get('clip_strength_min', 0.0))
|
||||
clip_strength_max = float(json_data.get('clip_strength_max', 1.0))
|
||||
locked_loras = json_data.get('locked_loras', [])
|
||||
pool_config = json_data.get('pool_config')
|
||||
count = json_data.get("count", 5)
|
||||
count_min = json_data.get("count_min")
|
||||
count_max = json_data.get("count_max")
|
||||
model_strength_min = float(json_data.get("model_strength_min", 0.0))
|
||||
model_strength_max = float(json_data.get("model_strength_max", 1.0))
|
||||
use_same_clip_strength = json_data.get("use_same_clip_strength", True)
|
||||
clip_strength_min = float(json_data.get("clip_strength_min", 0.0))
|
||||
clip_strength_max = float(json_data.get("clip_strength_max", 1.0))
|
||||
locked_loras = json_data.get("locked_loras", [])
|
||||
pool_config = json_data.get("pool_config")
|
||||
|
||||
# Determine target count
|
||||
if count_min is not None and count_max is not None:
|
||||
import random
|
||||
|
||||
target_count = random.randint(count_min, count_max)
|
||||
else:
|
||||
target_count = count
|
||||
|
||||
# Validate parameters
|
||||
if target_count < 1 or target_count > 100:
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': 'Count must be between 1 and 100'
|
||||
}, status=400)
|
||||
return web.json_response(
|
||||
{"success": False, "error": "Count must be between 1 and 100"},
|
||||
status=400,
|
||||
)
|
||||
|
||||
if model_strength_min < 0 or model_strength_max > 10:
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': 'Model strength must be between 0 and 10'
|
||||
}, status=400)
|
||||
if model_strength_min < -10 or model_strength_max > 10:
|
||||
return web.json_response(
|
||||
{
|
||||
"success": False,
|
||||
"error": "Model strength must be between -10 and 10",
|
||||
},
|
||||
status=400,
|
||||
)
|
||||
|
||||
# Get random LoRAs from service
|
||||
result_loras = await self.service.get_random_loras(
|
||||
@@ -264,27 +259,19 @@ class LoraRoutes(BaseModelRoutes):
|
||||
clip_strength_min=clip_strength_min,
|
||||
clip_strength_max=clip_strength_max,
|
||||
locked_loras=locked_loras,
|
||||
pool_config=pool_config
|
||||
pool_config=pool_config,
|
||||
)
|
||||
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
'loras': result_loras,
|
||||
'count': len(result_loras)
|
||||
})
|
||||
return web.json_response(
|
||||
{"success": True, "loras": result_loras, "count": len(result_loras)}
|
||||
)
|
||||
|
||||
except ValueError as e:
|
||||
logger.error(f"Invalid parameter for random LoRAs: {e}")
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
}, status=400)
|
||||
return web.json_response({"success": False, "error": str(e)}, status=400)
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting random LoRAs: {e}", exc_info=True)
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
}, status=500)
|
||||
return web.json_response({"success": False, "error": str(e)}, status=500)
|
||||
|
||||
async def get_trigger_words(self, request: web.Request) -> web.Response:
|
||||
"""Get trigger words for specified LoRA models"""
|
||||
@@ -292,15 +279,17 @@ class LoraRoutes(BaseModelRoutes):
|
||||
json_data = await request.json()
|
||||
lora_names = json_data.get("lora_names", [])
|
||||
node_ids = json_data.get("node_ids", [])
|
||||
|
||||
|
||||
all_trigger_words = []
|
||||
for lora_name in lora_names:
|
||||
_, trigger_words = get_lora_info(lora_name)
|
||||
all_trigger_words.extend(trigger_words)
|
||||
|
||||
|
||||
# Format the trigger words
|
||||
trigger_words_text = ",, ".join(all_trigger_words) if all_trigger_words else ""
|
||||
|
||||
trigger_words_text = (
|
||||
",, ".join(all_trigger_words) if all_trigger_words else ""
|
||||
)
|
||||
|
||||
# Send update to all connected trigger word toggle nodes
|
||||
for entry in node_ids:
|
||||
node_identifier = entry
|
||||
@@ -314,21 +303,15 @@ class LoraRoutes(BaseModelRoutes):
|
||||
except (TypeError, ValueError):
|
||||
parsed_node_id = node_identifier
|
||||
|
||||
payload = {
|
||||
"id": parsed_node_id,
|
||||
"message": trigger_words_text
|
||||
}
|
||||
payload = {"id": parsed_node_id, "message": trigger_words_text}
|
||||
|
||||
if graph_identifier is not None:
|
||||
payload["graph_id"] = str(graph_identifier)
|
||||
|
||||
PromptServer.instance.send_sync("trigger_word_update", payload)
|
||||
|
||||
|
||||
return web.json_response({"success": True})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting trigger words: {e}")
|
||||
return web.json_response({
|
||||
"success": False,
|
||||
"error": str(e)
|
||||
}, status=500)
|
||||
return web.json_response({"success": False, "error": str(e)}, status=500)
|
||||
|
||||
Reference in New Issue
Block a user