mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-21 21:22:11 -03:00
chore(tests): clean integration route header
This commit is contained in:
160
tests/routes/test_route_integration.py
Normal file
160
tests/routes/test_route_integration.py
Normal file
@@ -0,0 +1,160 @@
|
||||
"""End-to-end integration tests for aiohttp route registrars."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from contextlib import asynccontextmanager
|
||||
from types import SimpleNamespace
|
||||
from typing import AsyncIterator, Dict, Iterable, List, Sequence
|
||||
|
||||
from aiohttp import web
|
||||
from aiohttp.test_utils import TestClient, TestServer
|
||||
|
||||
from py.routes.lora_routes import LoraRoutes
|
||||
from py.services.service_registry import ServiceRegistry
|
||||
from py.services.websocket_manager import ws_manager as global_ws_manager
|
||||
|
||||
|
||||
class IntegrationCache:
|
||||
"""Minimal cache implementation satisfying the service contract."""
|
||||
|
||||
def __init__(self, items: Sequence[Dict[str, object]]) -> None:
|
||||
self.raw_data: List[Dict[str, object]] = [dict(item) for item in items]
|
||||
self.folders: List[str] = ["/"]
|
||||
|
||||
async def get_sorted_data(self, *_: object, **__: object) -> List[Dict[str, object]]:
|
||||
"""Return cached data without additional sorting."""
|
||||
return [dict(item) for item in self.raw_data]
|
||||
|
||||
async def resort(self) -> None:
|
||||
"""Resort is a no-op for the static fixture data."""
|
||||
return None
|
||||
|
||||
|
||||
class IntegrationScanner:
|
||||
"""Scanner double that registers with ServiceRegistry expectations."""
|
||||
|
||||
def __init__(self, items: Iterable[Dict[str, object]]) -> None:
|
||||
self.model_type = "lora"
|
||||
self._cache = IntegrationCache(list(items))
|
||||
self._hash_index = SimpleNamespace(
|
||||
removed_paths=[],
|
||||
remove_by_path=lambda path: self._hash_index.removed_paths.append(path),
|
||||
get_duplicate_hashes=lambda: {},
|
||||
get_duplicate_filenames=lambda: {},
|
||||
)
|
||||
self._tags_count: Dict[str, int] = {}
|
||||
self._excluded_models: List[str] = []
|
||||
|
||||
async def get_cached_data(self, *_: object, **__: object) -> IntegrationCache:
|
||||
return self._cache
|
||||
|
||||
def get_model_roots(self) -> List[str]: # pragma: no cover - documented surface
|
||||
return ["/"]
|
||||
|
||||
async def bulk_delete_models(self, file_paths: Iterable[str]) -> Dict[str, object]:
|
||||
existing_paths = {item["file_path"] for item in self._cache.raw_data}
|
||||
deleted = [path for path in file_paths if path in existing_paths]
|
||||
self._cache.raw_data = [
|
||||
item for item in self._cache.raw_data if item["file_path"] not in deleted
|
||||
]
|
||||
await self._cache.resort()
|
||||
for path in deleted:
|
||||
self._hash_index.remove_by_path(path)
|
||||
return {"success": True, "deleted": deleted}
|
||||
|
||||
async def check_model_version_exists(self, *_: object, **__: object) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def aiohttp_client(app: web.Application) -> AsyncIterator[TestClient]:
|
||||
"""Spin up a TestClient with lifecycle management."""
|
||||
|
||||
server = TestServer(app)
|
||||
client = TestClient(server)
|
||||
await client.start_server()
|
||||
try:
|
||||
yield client
|
||||
finally:
|
||||
await client.close()
|
||||
|
||||
|
||||
def test_lora_route_stack_returns_real_data():
|
||||
"""Spin up LoRA routes and ensure ServiceRegistry-powered wiring succeeds."""
|
||||
|
||||
async def scenario() -> None:
|
||||
ServiceRegistry.clear_services()
|
||||
|
||||
fixture_item = {
|
||||
"model_name": "Alpha",
|
||||
"file_name": "alpha.safetensors",
|
||||
"folder": "root",
|
||||
"file_path": "/tmp/alpha.safetensors",
|
||||
"size": 128,
|
||||
"modified": "2024-01-01T00:00:00Z",
|
||||
"tags": ["integration"],
|
||||
"civitai": {"trainedWords": ["alpha"]},
|
||||
"preview_url": "",
|
||||
"preview_nsfw_level": 0,
|
||||
"base_model": "SD1",
|
||||
"usage_tips": "Use gently",
|
||||
"notes": "Integration sample",
|
||||
"from_civitai": True,
|
||||
}
|
||||
scanner = IntegrationScanner([fixture_item])
|
||||
await ServiceRegistry.register_service("lora_scanner", scanner)
|
||||
|
||||
app = web.Application()
|
||||
routes = LoraRoutes()
|
||||
routes.setup_routes(app)
|
||||
|
||||
async with aiohttp_client(app) as client:
|
||||
response = await client.get("/api/lm/loras/list")
|
||||
payload = await response.json()
|
||||
|
||||
assert response.status == 200
|
||||
assert payload["total"] == 1
|
||||
returned = payload["items"][0]
|
||||
assert returned["model_name"] == "Alpha"
|
||||
assert returned["file_name"] == "alpha.safetensors"
|
||||
assert returned["usage_tips"] == "Use gently"
|
||||
|
||||
asyncio.run(scenario())
|
||||
ServiceRegistry.clear_services()
|
||||
|
||||
|
||||
def test_websocket_routes_broadcast_through_registry():
|
||||
"""Ensure websocket endpoints accept connections and relay broadcasts."""
|
||||
|
||||
async def scenario() -> None:
|
||||
ServiceRegistry.clear_services()
|
||||
ws_manager = await ServiceRegistry.get_websocket_manager()
|
||||
|
||||
app = web.Application()
|
||||
app.router.add_get("/ws/fetch-progress", ws_manager.handle_connection)
|
||||
app.router.add_get("/ws/download-progress", ws_manager.handle_download_connection)
|
||||
|
||||
async with aiohttp_client(app) as client:
|
||||
fetch_ws = await client.ws_connect("/ws/fetch-progress")
|
||||
await ws_manager.broadcast({"kind": "ping"})
|
||||
message = await asyncio.wait_for(fetch_ws.receive_json(), timeout=1)
|
||||
assert message == {"kind": "ping"}
|
||||
|
||||
download_ws = await client.ws_connect("/ws/download-progress?id=session-1")
|
||||
greeting = await asyncio.wait_for(download_ws.receive_json(), timeout=1)
|
||||
assert greeting["type"] == "download_id"
|
||||
assert greeting["download_id"] == "session-1"
|
||||
|
||||
await ws_manager.broadcast_download_progress("session-1", {"progress": 55})
|
||||
progress = await asyncio.wait_for(download_ws.receive_json(), timeout=1)
|
||||
assert progress["progress"] == 55
|
||||
|
||||
await fetch_ws.close()
|
||||
await download_ws.close()
|
||||
|
||||
# Ensure the registry cached instance matches the module-level singleton.
|
||||
assert ws_manager is global_ws_manager
|
||||
|
||||
asyncio.run(scenario())
|
||||
ServiceRegistry.clear_services()
|
||||
Reference in New Issue
Block a user