mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-21 21:22:11 -03:00
- Import and register two new nodes: LoraDemoNode and LoraRandomizerNode - Update import exception handling for better readability with multi-line formatting - Add comprehensive documentation file `docs/custom-node-ui-output.md` for UI output usage in custom nodes - Ensure proper node registration in NODE_CLASS_MAPPINGS for ComfyUI integration - Maintain backward compatibility with existing node structure and import fallbacks
192 lines
4.8 KiB
Python
192 lines
4.8 KiB
Python
import json
|
|
from unittest.mock import MagicMock
|
|
|
|
import pytest
|
|
|
|
from py.routes.lora_routes import LoraRoutes
|
|
|
|
|
|
class DummyRequest:
    """Minimal stand-in for the request object passed to ``LoraRoutes`` handlers.

    Only what these tests need is modelled: ``query``, ``match_info`` and an
    awaitable ``json()`` returning the request body. Any falsy argument
    (including ``None``) is replaced with a fresh empty dict.
    """

    def __init__(self, *, query=None, match_info=None, json_data=None):
        # A new dict per instance, so instances never share mutable defaults.
        self.query = query if query else {}
        self.match_info = match_info if match_info else {}
        self._json_data = json_data if json_data else {}

    async def json(self):
        """Return the stored body, mirroring ``await request.json()``."""
        return self._json_data
|
|
|
|
|
|
class StubLoraService:
    """Stub service for testing randomizer endpoints.

    ``get_random_loras`` returns the canned ``random_loras`` list verbatim.
    The keyword arguments of every call are appended to ``calls``, so tests
    can check what the route actually forwarded to the service.
    """

    def __init__(self):
        # Canned result returned by get_random_loras; tests overwrite it.
        self.random_loras = []
        # One dict of keyword arguments per get_random_loras call, in call order.
        self.calls = []

    async def get_random_loras(self, **kwargs):
        """Record ``kwargs`` in ``calls`` and return ``random_loras`` unchanged."""
        self.calls.append(kwargs)
        return self.random_loras
|
|
|
|
|
|
@pytest.fixture
def routes():
    """Provide a ``LoraRoutes`` handler wired to a fresh ``StubLoraService``."""
    lora_routes = LoraRoutes()
    # Replace the service so each test controls what get_random_loras returns.
    lora_routes.service = StubLoraService()
    return lora_routes
|
|
|
|
|
|
async def test_get_random_loras_success(routes):
    """A valid fixed-count request returns 200 with the stubbed LoRAs."""
    routes.service.random_loras = [
        dict(name=lora_name, strength=weight, clipStrength=weight,
             active=True, expanded=False, locked=False)
        for lora_name, weight in (('test_lora_1', 0.8), ('test_lora_2', 0.6))
    ]
    body = dict(
        count=5,
        model_strength_min=0.5,
        model_strength_max=1.0,
        use_same_clip_strength=True,
        locked_loras=[],
    )

    response = await routes.get_random_loras(DummyRequest(json_data=body))
    payload = json.loads(response.text)

    assert response.status == 200
    assert payload['success'] is True
    assert 'loras' in payload
    # The stub only supplies two LoRAs, so 2 (not the requested 5) is expected.
    assert payload['count'] == 2
|
|
|
|
|
|
async def test_get_random_loras_with_range(routes):
    """A request using a count_min/count_max range is accepted with 200."""
    only_lora = dict(name='test_lora_1', strength=0.8, clipStrength=0.8,
                     active=True, expanded=False, locked=False)
    routes.service.random_loras = [only_lora]
    range_body = dict(count_min=3, count_max=7,
                      model_strength_min=0.0, model_strength_max=1.0,
                      use_same_clip_strength=True)

    response = await routes.get_random_loras(DummyRequest(json_data=range_body))
    payload = json.loads(response.text)

    assert response.status == 200
    assert payload['success'] is True
|
|
|
|
|
|
async def test_get_random_loras_invalid_count(routes):
    """A count outside 1..100 is rejected with a 400 and an explanatory error."""
    too_many = 150  # Over limit
    bad_request = DummyRequest(json_data=dict(
        count=too_many,
        model_strength_min=0.0,
        model_strength_max=1.0,
    ))

    response = await routes.get_random_loras(bad_request)
    payload = json.loads(response.text)

    assert response.status == 400
    assert payload['success'] is False
    assert 'Count must be between 1 and 100' in payload['error']
|
|
|
|
|
|
async def test_get_random_loras_invalid_strength(routes):
    """A negative minimum model strength is rejected with a 400."""
    negative_min = -0.5  # Invalid
    bad_request = DummyRequest(json_data=dict(
        count=5,
        model_strength_min=negative_min,
        model_strength_max=1.0,
    ))

    response = await routes.get_random_loras(bad_request)
    payload = json.loads(response.text)

    assert response.status == 400
    assert payload['success'] is False
|
|
|
|
|
|
async def test_get_random_loras_with_locked(routes):
    """A request carrying locked LoRAs is accepted with 200."""

    def make_lora(lora_name, weight, is_locked):
        # Build a fresh LoRA entry dict on every call (no shared objects).
        return dict(name=lora_name, strength=weight, clipStrength=weight,
                    active=True, expanded=False, locked=is_locked)

    routes.service.random_loras = [
        make_lora('new_lora', 0.7, False),
        make_lora('locked_lora', 0.9, True),
    ]
    locked_request = DummyRequest(json_data=dict(
        count=5,
        model_strength_min=0.5,
        model_strength_max=1.0,
        use_same_clip_strength=True,
        locked_loras=[make_lora('locked_lora', 0.9, True)],
    ))

    response = await routes.get_random_loras(locked_request)
    payload = json.loads(response.text)

    assert response.status == 200
    assert payload['success'] is True
|
|
|
|
|
|
async def test_get_random_loras_error(routes, monkeypatch):
    """A service exception is reported as a 500 response with an error payload."""

    async def failing(*_args, **_kwargs):
        raise RuntimeError("Service error")

    # Patch through the requested monkeypatch fixture (previously requested but
    # unused) so the replacement is scoped to this test and undone at teardown.
    monkeypatch.setattr(routes.service, "get_random_loras", failing)
    request = DummyRequest(json_data={'count': 5})

    response = await routes.get_random_loras(request)
    payload = json.loads(response.text)

    assert response.status == 500
    assert payload['success'] is False
    assert 'error' in payload
|