diff --git a/.gitignore b/.gitignore
index a5af59a1..708ef925 100644
--- a/.gitignore
+++ b/.gitignore
@@ -12,6 +12,9 @@ coverage/
.coverage
model_cache/
+# agent
+.opencode/
+
# Vue widgets development cache (but keep build output)
vue-widgets/node_modules/
vue-widgets/.vite/
diff --git a/py/routes/lora_routes.py b/py/routes/lora_routes.py
index 35e16c8b..ecc30a32 100644
--- a/py/routes/lora_routes.py
+++ b/py/routes/lora_routes.py
@@ -12,14 +12,15 @@ from ..utils.utils import get_lora_info
logger = logging.getLogger(__name__)
+
class LoraRoutes(BaseModelRoutes):
"""LoRA-specific route controller"""
-
+
def __init__(self):
"""Initialize LoRA routes with LoRA service"""
super().__init__()
self.template_name = "loras.html"
-
+
async def initialize_services(self):
"""Initialize services from ServiceRegistry"""
lora_scanner = await ServiceRegistry.get_lora_scanner()
@@ -29,231 +30,225 @@ class LoraRoutes(BaseModelRoutes):
# Attach service dependencies
self.attach_service(self.service)
-
+
def setup_routes(self, app: web.Application):
"""Setup LoRA routes"""
# Schedule service initialization on app startup
app.on_startup.append(lambda _: self.initialize_services())
# Setup common routes with 'loras' prefix (includes page route)
- super().setup_routes(app, 'loras')
+ super().setup_routes(app, "loras")
def setup_specific_routes(self, registrar: ModelRouteRegistrar, prefix: str):
"""Setup LoRA-specific routes"""
# LoRA-specific query routes
- registrar.add_prefixed_route('GET', '/api/lm/{prefix}/letter-counts', prefix, self.get_letter_counts)
- registrar.add_prefixed_route('GET', '/api/lm/{prefix}/get-trigger-words', prefix, self.get_lora_trigger_words)
- registrar.add_prefixed_route('GET', '/api/lm/{prefix}/usage-tips-by-path', prefix, self.get_lora_usage_tips_by_path)
+ registrar.add_prefixed_route(
+ "GET", "/api/lm/{prefix}/letter-counts", prefix, self.get_letter_counts
+ )
+ registrar.add_prefixed_route(
+ "GET",
+ "/api/lm/{prefix}/get-trigger-words",
+ prefix,
+ self.get_lora_trigger_words,
+ )
+ registrar.add_prefixed_route(
+ "GET",
+ "/api/lm/{prefix}/usage-tips-by-path",
+ prefix,
+ self.get_lora_usage_tips_by_path,
+ )
# Randomizer routes
- registrar.add_prefixed_route('POST', '/api/lm/{prefix}/random-sample', prefix, self.get_random_loras)
+ registrar.add_prefixed_route(
+ "POST", "/api/lm/{prefix}/random-sample", prefix, self.get_random_loras
+ )
# ComfyUI integration
- registrar.add_prefixed_route('POST', '/api/lm/{prefix}/get_trigger_words', prefix, self.get_trigger_words)
-
+ registrar.add_prefixed_route(
+ "POST", "/api/lm/{prefix}/get_trigger_words", prefix, self.get_trigger_words
+ )
+
def _parse_specific_params(self, request: web.Request) -> Dict:
"""Parse LoRA-specific parameters"""
params = {}
-
+
# LoRA-specific parameters
- if 'first_letter' in request.query:
- params['first_letter'] = request.query.get('first_letter')
-
+ if "first_letter" in request.query:
+ params["first_letter"] = request.query.get("first_letter")
+
# Handle fuzzy search parameter name variation
- if request.query.get('fuzzy') == 'true':
- params['fuzzy_search'] = True
-
+ if request.query.get("fuzzy") == "true":
+ params["fuzzy_search"] = True
+
# Handle additional filter parameters for LoRAs
- if 'lora_hash' in request.query:
- if not params.get('hash_filters'):
- params['hash_filters'] = {}
- params['hash_filters']['single_hash'] = request.query['lora_hash'].lower()
- elif 'lora_hashes' in request.query:
- if not params.get('hash_filters'):
- params['hash_filters'] = {}
- params['hash_filters']['multiple_hashes'] = [h.lower() for h in request.query['lora_hashes'].split(',')]
-
+ if "lora_hash" in request.query:
+ if not params.get("hash_filters"):
+ params["hash_filters"] = {}
+ params["hash_filters"]["single_hash"] = request.query["lora_hash"].lower()
+ elif "lora_hashes" in request.query:
+ if not params.get("hash_filters"):
+ params["hash_filters"] = {}
+ params["hash_filters"]["multiple_hashes"] = [
+ h.lower() for h in request.query["lora_hashes"].split(",")
+ ]
+
return params
-
+
def _validate_civitai_model_type(self, model_type: str) -> bool:
"""Validate CivitAI model type for LoRA"""
from ..utils.constants import VALID_LORA_TYPES
+
return model_type.lower() in VALID_LORA_TYPES
-
+
def _get_expected_model_types(self) -> str:
"""Get expected model types string for error messages"""
return "LORA, LoCon, or DORA"
-
+
# LoRA-specific route handlers
async def get_letter_counts(self, request: web.Request) -> web.Response:
"""Get count of LoRAs for each letter of the alphabet"""
try:
letter_counts = await self.service.get_letter_counts()
- return web.json_response({
- 'success': True,
- 'letter_counts': letter_counts
- })
+ return web.json_response({"success": True, "letter_counts": letter_counts})
except Exception as e:
logger.error(f"Error getting letter counts: {e}")
- return web.json_response({
- 'success': False,
- 'error': str(e)
- }, status=500)
-
+ return web.json_response({"success": False, "error": str(e)}, status=500)
+
async def get_lora_notes(self, request: web.Request) -> web.Response:
"""Get notes for a specific LoRA file"""
try:
- lora_name = request.query.get('name')
+ lora_name = request.query.get("name")
if not lora_name:
- return web.Response(text='Lora file name is required', status=400)
-
+ return web.Response(text="Lora file name is required", status=400)
+
notes = await self.service.get_lora_notes(lora_name)
if notes is not None:
- return web.json_response({
- 'success': True,
- 'notes': notes
- })
+ return web.json_response({"success": True, "notes": notes})
else:
- return web.json_response({
- 'success': False,
- 'error': 'LoRA not found in cache'
- }, status=404)
-
+ return web.json_response(
+ {"success": False, "error": "LoRA not found in cache"}, status=404
+ )
+
except Exception as e:
logger.error(f"Error getting lora notes: {e}", exc_info=True)
- return web.json_response({
- 'success': False,
- 'error': str(e)
- }, status=500)
-
+ return web.json_response({"success": False, "error": str(e)}, status=500)
+
async def get_lora_trigger_words(self, request: web.Request) -> web.Response:
"""Get trigger words for a specific LoRA file"""
try:
- lora_name = request.query.get('name')
+ lora_name = request.query.get("name")
if not lora_name:
- return web.Response(text='Lora file name is required', status=400)
-
+ return web.Response(text="Lora file name is required", status=400)
+
trigger_words = await self.service.get_lora_trigger_words(lora_name)
- return web.json_response({
- 'success': True,
- 'trigger_words': trigger_words
- })
-
+ return web.json_response({"success": True, "trigger_words": trigger_words})
+
except Exception as e:
logger.error(f"Error getting lora trigger words: {e}", exc_info=True)
- return web.json_response({
- 'success': False,
- 'error': str(e)
- }, status=500)
-
+ return web.json_response({"success": False, "error": str(e)}, status=500)
+
async def get_lora_usage_tips_by_path(self, request: web.Request) -> web.Response:
"""Get usage tips for a LoRA by its relative path"""
try:
- relative_path = request.query.get('relative_path')
+ relative_path = request.query.get("relative_path")
if not relative_path:
- return web.Response(text='Relative path is required', status=400)
-
- usage_tips = await self.service.get_lora_usage_tips_by_relative_path(relative_path)
- return web.json_response({
- 'success': True,
- 'usage_tips': usage_tips or ''
- })
-
+ return web.Response(text="Relative path is required", status=400)
+
+ usage_tips = await self.service.get_lora_usage_tips_by_relative_path(
+ relative_path
+ )
+ return web.json_response({"success": True, "usage_tips": usage_tips or ""})
+
except Exception as e:
logger.error(f"Error getting lora usage tips by path: {e}", exc_info=True)
- return web.json_response({
- 'success': False,
- 'error': str(e)
- }, status=500)
-
+ return web.json_response({"success": False, "error": str(e)}, status=500)
+
async def get_lora_preview_url(self, request: web.Request) -> web.Response:
"""Get the static preview URL for a LoRA file"""
try:
- lora_name = request.query.get('name')
+ lora_name = request.query.get("name")
if not lora_name:
- return web.Response(text='Lora file name is required', status=400)
-
+ return web.Response(text="Lora file name is required", status=400)
+
preview_url = await self.service.get_lora_preview_url(lora_name)
if preview_url:
- return web.json_response({
- 'success': True,
- 'preview_url': preview_url
- })
+ return web.json_response({"success": True, "preview_url": preview_url})
else:
- return web.json_response({
- 'success': False,
- 'error': 'No preview URL found for the specified lora'
- }, status=404)
-
+ return web.json_response(
+ {
+ "success": False,
+ "error": "No preview URL found for the specified lora",
+ },
+ status=404,
+ )
+
except Exception as e:
logger.error(f"Error getting lora preview URL: {e}", exc_info=True)
- return web.json_response({
- 'success': False,
- 'error': str(e)
- }, status=500)
-
+ return web.json_response({"success": False, "error": str(e)}, status=500)
+
async def get_lora_civitai_url(self, request: web.Request) -> web.Response:
"""Get the Civitai URL for a LoRA file"""
try:
- lora_name = request.query.get('name')
+ lora_name = request.query.get("name")
if not lora_name:
- return web.Response(text='Lora file name is required', status=400)
-
+ return web.Response(text="Lora file name is required", status=400)
+
result = await self.service.get_lora_civitai_url(lora_name)
- if result['civitai_url']:
- return web.json_response({
- 'success': True,
- **result
- })
+ if result["civitai_url"]:
+ return web.json_response({"success": True, **result})
else:
- return web.json_response({
- 'success': False,
- 'error': 'No Civitai data found for the specified lora'
- }, status=404)
-
+ return web.json_response(
+ {
+ "success": False,
+ "error": "No Civitai data found for the specified lora",
+ },
+ status=404,
+ )
+
except Exception as e:
logger.error(f"Error getting lora Civitai URL: {e}", exc_info=True)
- return web.json_response({
- 'success': False,
- 'error': str(e)
- }, status=500)
-
+ return web.json_response({"success": False, "error": str(e)}, status=500)
+
async def get_random_loras(self, request: web.Request) -> web.Response:
"""Get random LoRAs based on filters and strength ranges"""
try:
json_data = await request.json()
# Parse parameters
- count = json_data.get('count', 5)
- count_min = json_data.get('count_min')
- count_max = json_data.get('count_max')
- model_strength_min = float(json_data.get('model_strength_min', 0.0))
- model_strength_max = float(json_data.get('model_strength_max', 1.0))
- use_same_clip_strength = json_data.get('use_same_clip_strength', True)
- clip_strength_min = float(json_data.get('clip_strength_min', 0.0))
- clip_strength_max = float(json_data.get('clip_strength_max', 1.0))
- locked_loras = json_data.get('locked_loras', [])
- pool_config = json_data.get('pool_config')
+ count = json_data.get("count", 5)
+ count_min = json_data.get("count_min")
+ count_max = json_data.get("count_max")
+ model_strength_min = float(json_data.get("model_strength_min", 0.0))
+ model_strength_max = float(json_data.get("model_strength_max", 1.0))
+ use_same_clip_strength = json_data.get("use_same_clip_strength", True)
+ clip_strength_min = float(json_data.get("clip_strength_min", 0.0))
+ clip_strength_max = float(json_data.get("clip_strength_max", 1.0))
+ locked_loras = json_data.get("locked_loras", [])
+ pool_config = json_data.get("pool_config")
# Determine target count
if count_min is not None and count_max is not None:
import random
+
target_count = random.randint(count_min, count_max)
else:
target_count = count
# Validate parameters
if target_count < 1 or target_count > 100:
- return web.json_response({
- 'success': False,
- 'error': 'Count must be between 1 and 100'
- }, status=400)
+ return web.json_response(
+ {"success": False, "error": "Count must be between 1 and 100"},
+ status=400,
+ )
- if model_strength_min < 0 or model_strength_max > 10:
- return web.json_response({
- 'success': False,
- 'error': 'Model strength must be between 0 and 10'
- }, status=400)
+ if model_strength_min < -10 or model_strength_max > 10:
+ return web.json_response(
+ {
+ "success": False,
+ "error": "Model strength must be between -10 and 10",
+ },
+ status=400,
+ )
# Get random LoRAs from service
result_loras = await self.service.get_random_loras(
@@ -264,27 +259,19 @@ class LoraRoutes(BaseModelRoutes):
clip_strength_min=clip_strength_min,
clip_strength_max=clip_strength_max,
locked_loras=locked_loras,
- pool_config=pool_config
+ pool_config=pool_config,
)
- return web.json_response({
- 'success': True,
- 'loras': result_loras,
- 'count': len(result_loras)
- })
+ return web.json_response(
+ {"success": True, "loras": result_loras, "count": len(result_loras)}
+ )
except ValueError as e:
logger.error(f"Invalid parameter for random LoRAs: {e}")
- return web.json_response({
- 'success': False,
- 'error': str(e)
- }, status=400)
+ return web.json_response({"success": False, "error": str(e)}, status=400)
except Exception as e:
logger.error(f"Error getting random LoRAs: {e}", exc_info=True)
- return web.json_response({
- 'success': False,
- 'error': str(e)
- }, status=500)
+ return web.json_response({"success": False, "error": str(e)}, status=500)
async def get_trigger_words(self, request: web.Request) -> web.Response:
"""Get trigger words for specified LoRA models"""
@@ -292,15 +279,17 @@ class LoraRoutes(BaseModelRoutes):
json_data = await request.json()
lora_names = json_data.get("lora_names", [])
node_ids = json_data.get("node_ids", [])
-
+
all_trigger_words = []
for lora_name in lora_names:
_, trigger_words = get_lora_info(lora_name)
all_trigger_words.extend(trigger_words)
-
+
# Format the trigger words
- trigger_words_text = ",, ".join(all_trigger_words) if all_trigger_words else ""
-
+ trigger_words_text = (
+ ",, ".join(all_trigger_words) if all_trigger_words else ""
+ )
+
# Send update to all connected trigger word toggle nodes
for entry in node_ids:
node_identifier = entry
@@ -314,21 +303,15 @@ class LoraRoutes(BaseModelRoutes):
except (TypeError, ValueError):
parsed_node_id = node_identifier
- payload = {
- "id": parsed_node_id,
- "message": trigger_words_text
- }
+ payload = {"id": parsed_node_id, "message": trigger_words_text}
if graph_identifier is not None:
payload["graph_id"] = str(graph_identifier)
PromptServer.instance.send_sync("trigger_word_update", payload)
-
+
return web.json_response({"success": True})
except Exception as e:
logger.error(f"Error getting trigger words: {e}")
- return web.json_response({
- "success": False,
- "error": str(e)
- }, status=500)
+ return web.json_response({"success": False, "error": str(e)}, status=500)
diff --git a/tests/routes/test_randomizer_endpoints.py b/tests/routes/test_randomizer_endpoints.py
index b8283cf3..4541a849 100644
--- a/tests/routes/test_randomizer_endpoints.py
+++ b/tests/routes/test_randomizer_endpoints.py
@@ -37,155 +37,167 @@ async def test_get_random_loras_success(routes):
"""Test successful random LoRA generation"""
routes.service.random_loras = [
{
- 'name': 'test_lora_1',
- 'strength': 0.8,
- 'clipStrength': 0.8,
- 'active': True,
- 'expanded': False,
- 'locked': False
+ "name": "test_lora_1",
+ "strength": 0.8,
+ "clipStrength": 0.8,
+ "active": True,
+ "expanded": False,
+ "locked": False,
},
{
- 'name': 'test_lora_2',
- 'strength': 0.6,
- 'clipStrength': 0.6,
- 'active': True,
- 'expanded': False,
- 'locked': False
- }
+ "name": "test_lora_2",
+ "strength": 0.6,
+ "clipStrength": 0.6,
+ "active": True,
+ "expanded": False,
+ "locked": False,
+ },
]
- request = DummyRequest(json_data={
- 'count': 5,
- 'model_strength_min': 0.5,
- 'model_strength_max': 1.0,
- 'use_same_clip_strength': True,
- 'locked_loras': []
- })
+ request = DummyRequest(
+ json_data={
+ "count": 5,
+ "model_strength_min": 0.5,
+ "model_strength_max": 1.0,
+ "use_same_clip_strength": True,
+ "locked_loras": [],
+ }
+ )
response = await routes.get_random_loras(request)
payload = json.loads(response.text)
assert response.status == 200
- assert payload['success'] is True
- assert 'loras' in payload
- assert payload['count'] == 2
+ assert payload["success"] is True
+ assert "loras" in payload
+ assert payload["count"] == 2
async def test_get_random_loras_with_range(routes):
"""Test random LoRAs with count range"""
routes.service.random_loras = [
{
- 'name': 'test_lora_1',
- 'strength': 0.8,
- 'clipStrength': 0.8,
- 'active': True,
- 'expanded': False,
- 'locked': False
+ "name": "test_lora_1",
+ "strength": 0.8,
+ "clipStrength": 0.8,
+ "active": True,
+ "expanded": False,
+ "locked": False,
}
]
- request = DummyRequest(json_data={
- 'count_min': 3,
- 'count_max': 7,
- 'model_strength_min': 0.0,
- 'model_strength_max': 1.0,
- 'use_same_clip_strength': True
- })
+ request = DummyRequest(
+ json_data={
+ "count_min": 3,
+ "count_max": 7,
+ "model_strength_min": 0.0,
+ "model_strength_max": 1.0,
+ "use_same_clip_strength": True,
+ }
+ )
response = await routes.get_random_loras(request)
payload = json.loads(response.text)
assert response.status == 200
- assert payload['success'] is True
+ assert payload["success"] is True
async def test_get_random_loras_invalid_count(routes):
"""Test invalid count parameter"""
- request = DummyRequest(json_data={
- 'count': 150, # Over limit
- 'model_strength_min': 0.0,
- 'model_strength_max': 1.0
- })
+ request = DummyRequest(
+ json_data={
+ "count": 150, # Over limit
+ "model_strength_min": 0.0,
+ "model_strength_max": 1.0,
+ }
+ )
response = await routes.get_random_loras(request)
payload = json.loads(response.text)
assert response.status == 400
- assert payload['success'] is False
- assert 'Count must be between 1 and 100' in payload['error']
+ assert payload["success"] is False
+ assert "Count must be between 1 and 100" in payload["error"]
async def test_get_random_loras_invalid_strength(routes):
"""Test invalid strength range"""
- request = DummyRequest(json_data={
- 'count': 5,
- 'model_strength_min': -0.5, # Invalid
- 'model_strength_max': 1.0
- })
+ request = DummyRequest(
+ json_data={
+ "count": 5,
+ "model_strength_min": -11, # Invalid (below -10)
+ "model_strength_max": 1.0,
+ }
+ )
response = await routes.get_random_loras(request)
payload = json.loads(response.text)
assert response.status == 400
- assert payload['success'] is False
+ assert payload["success"] is False
+ assert "Model strength must be between -10 and 10" in payload["error"]
async def test_get_random_loras_with_locked(routes):
"""Test random LoRAs with locked items"""
routes.service.random_loras = [
{
- 'name': 'new_lora',
- 'strength': 0.7,
- 'clipStrength': 0.7,
- 'active': True,
- 'expanded': False,
- 'locked': False
+ "name": "new_lora",
+ "strength": 0.7,
+ "clipStrength": 0.7,
+ "active": True,
+ "expanded": False,
+ "locked": False,
},
{
- 'name': 'locked_lora',
- 'strength': 0.9,
- 'clipStrength': 0.9,
- 'active': True,
- 'expanded': False,
- 'locked': True
- }
+ "name": "locked_lora",
+ "strength": 0.9,
+ "clipStrength": 0.9,
+ "active": True,
+ "expanded": False,
+ "locked": True,
+ },
]
- request = DummyRequest(json_data={
- 'count': 5,
- 'model_strength_min': 0.5,
- 'model_strength_max': 1.0,
- 'use_same_clip_strength': True,
- 'locked_loras': [
- {
- 'name': 'locked_lora',
- 'strength': 0.9,
- 'clipStrength': 0.9,
- 'active': True,
- 'expanded': False,
- 'locked': True
- }
- ]
- })
+ request = DummyRequest(
+ json_data={
+ "count": 5,
+ "model_strength_min": 0.5,
+ "model_strength_max": 1.0,
+ "use_same_clip_strength": True,
+ "locked_loras": [
+ {
+ "name": "locked_lora",
+ "strength": 0.9,
+ "clipStrength": 0.9,
+ "active": True,
+ "expanded": False,
+ "locked": True,
+ }
+ ],
+ }
+ )
response = await routes.get_random_loras(request)
payload = json.loads(response.text)
assert response.status == 200
- assert payload['success'] is True
+ assert payload["success"] is True
async def test_get_random_loras_error(routes, monkeypatch):
"""Test error handling"""
+
async def failing(*_args, **_kwargs):
raise RuntimeError("Service error")
routes.service.get_random_loras = failing
- request = DummyRequest(json_data={'count': 5})
+ request = DummyRequest(json_data={"count": 5})
response = await routes.get_random_loras(request)
payload = json.loads(response.text)
assert response.status == 500
- assert payload['success'] is False
- assert 'error' in payload
+ assert payload["success"] is False
+ assert "error" in payload
diff --git a/vue-widgets/src/components/LoraPoolWidget.vue b/vue-widgets/src/components/LoraPoolWidget.vue
index e531861b..12321ee9 100644
--- a/vue-widgets/src/components/LoraPoolWidget.vue
+++ b/vue-widgets/src/components/LoraPoolWidget.vue
@@ -104,7 +104,7 @@ onMounted(async () => {
// Handle external value updates (e.g., loading workflow, paste)
props.widget.onSetValue = (v) => {
- state.restoreFromConfig(v)
+ state.restoreFromConfig(v as LoraPoolConfig | LegacyLoraPoolConfig)
state.refreshPreview()
}
diff --git a/vue-widgets/src/components/lora-randomizer/LoraRandomizerSettingsView.vue b/vue-widgets/src/components/lora-randomizer/LoraRandomizerSettingsView.vue
index b9afe89b..fd6138bd 100644
--- a/vue-widgets/src/components/lora-randomizer/LoraRandomizerSettingsView.vue
+++ b/vue-widgets/src/components/lora-randomizer/LoraRandomizerSettingsView.vue
@@ -7,8 +7,8 @@
-
-
-
-
-