mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-22 05:32:12 -03:00
161 lines
5.9 KiB
Python
161 lines
5.9 KiB
Python
"""End-to-end integration tests for aiohttp route registrars."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
from contextlib import asynccontextmanager
|
|
from types import SimpleNamespace
|
|
from typing import AsyncIterator, Dict, Iterable, List, Sequence
|
|
|
|
from aiohttp import web
|
|
from aiohttp.test_utils import TestClient, TestServer
|
|
|
|
from py.routes.lora_routes import LoraRoutes
|
|
from py.services.service_registry import ServiceRegistry
|
|
from py.services.websocket_manager import ws_manager as global_ws_manager
|
|
|
|
|
|
class IntegrationCache:
    """In-memory cache double exposing only the pieces the LoRA services read.

    ``raw_data`` holds private shallow copies of the fixture rows and
    ``folders`` lists a single root entry.
    """

    def __init__(self, items: Sequence[Dict[str, object]]) -> None:
        # Copy every row so later edits to the caller's dicts cannot leak in.
        self.raw_data: List[Dict[str, object]] = list(map(dict, items))
        self.folders: List[str] = ["/"]

    async def get_sorted_data(self, *_: object, **__: object) -> List[Dict[str, object]]:
        """Hand back fresh shallow copies of the rows in their stored order.

        Any sort arguments are accepted and ignored.
        """
        return list(map(dict, self.raw_data))

    async def resort(self) -> None:
        """Do nothing: the static fixture rows never need re-sorting."""
|
|
|
|
|
|
class IntegrationScanner:
    """Stand-in for the LoRA scanner, handed to routes through ServiceRegistry.

    Rows are served from an ``IntegrationCache`` built from *items*; paths
    passed to the hash index's ``remove_by_path`` are recorded in
    ``_hash_index.removed_paths`` so tests can inspect them.
    """

    def __init__(self, items: Iterable[Dict[str, object]]) -> None:
        def _record_removal(path):
            # Resolve the index at call time so the live ``removed_paths``
            # list is always the one appended to.
            self._hash_index.removed_paths.append(path)

        self.model_type = "lora"
        self._cache = IntegrationCache(list(items))
        self._hash_index = SimpleNamespace(
            removed_paths=[],
            remove_by_path=_record_removal,
            get_duplicate_hashes=lambda: {},
            get_duplicate_filenames=lambda: {},
        )
        self._tags_count: Dict[str, int] = {}
        self._excluded_models: List[str] = []

    async def get_cached_data(self, *_: object, **__: object) -> IntegrationCache:
        """Return the shared cache object; all arguments are ignored."""
        return self._cache

    def get_model_roots(self) -> List[str]:  # pragma: no cover - documented surface
        """Report a single root directory."""
        return ["/"]

    async def bulk_delete_models(self, file_paths: Iterable[str]) -> Dict[str, object]:
        """Remove cached rows whose ``file_path`` appears in *file_paths*.

        Paths unknown to the cache are skipped. The ``deleted`` list keeps the
        order (and any repeats) of *file_paths*, and each entry is also passed
        to the hash index's ``remove_by_path``.
        """
        rows = self._cache.raw_data
        known_paths = {row["file_path"] for row in rows}
        removed = [candidate for candidate in file_paths if candidate in known_paths]
        removed_lookup = set(removed)

        self._cache.raw_data = [row for row in rows if row["file_path"] not in removed_lookup]
        await self._cache.resort()

        for candidate in removed:
            self._hash_index.remove_by_path(candidate)

        return {"success": True, "deleted": removed}

    async def check_model_version_exists(self, *_: object, **__: object) -> bool:
        """Always claim the requested model version is absent."""
        return False
|
|
|
|
|
|
@asynccontextmanager
async def aiohttp_client(app: web.Application) -> AsyncIterator[TestClient]:
    """Serve *app* on a ``TestServer`` and yield a started ``TestClient``.

    Args:
        app: The aiohttp application under test.

    Yields:
        A ``TestClient`` whose server has been started.

    ``client.close()`` is always awaited on exit, whether the body raises,
    completes normally, or server startup itself fails.
    """
    client = TestClient(TestServer(app))
    try:
        # Start inside the ``try`` so ``client.close()`` still runs if
        # ``start_server()`` raises; otherwise whatever the client allocated
        # in its constructor would be leaked.
        await client.start_server()
        yield client
    finally:
        await client.close()
|
|
|
|
|
|
def test_lora_route_stack_returns_real_data():
    """Spin up LoRA routes and ensure ServiceRegistry-powered wiring succeeds.

    A single fixture LoRA is registered through ``IntegrationScanner`` and the
    ``/api/lm/loras/list`` endpoint must report exactly that item.
    """

    async def scenario() -> None:
        ServiceRegistry.clear_services()

        fixture_item = {
            "model_name": "Alpha",
            "file_name": "alpha.safetensors",
            "folder": "root",
            "file_path": "/tmp/alpha.safetensors",
            "size": 128,
            "modified": "2024-01-01T00:00:00Z",
            "tags": ["integration"],
            "civitai": {"trainedWords": ["alpha"]},
            "preview_url": "",
            "preview_nsfw_level": 0,
            "base_model": "SD1",
            "usage_tips": "Use gently",
            "notes": "Integration sample",
            "from_civitai": True,
        }
        scanner = IntegrationScanner([fixture_item])
        await ServiceRegistry.register_service("lora_scanner", scanner)

        app = web.Application()
        routes = LoraRoutes()
        routes.setup_routes(app)

        async with aiohttp_client(app) as client:
            response = await client.get("/api/lm/loras/list")
            # Check the status before decoding: a non-200 error page is then
            # reported verbatim instead of as a JSON decode failure.
            assert response.status == 200, await response.text()
            payload = await response.json()

        assert payload["total"] == 1
        assert len(payload["items"]) == 1
        returned = payload["items"][0]
        assert returned["model_name"] == "Alpha"
        assert returned["file_name"] == "alpha.safetensors"
        assert returned["usage_tips"] == "Use gently"

    try:
        asyncio.run(scenario())
    finally:
        # Reset even when the scenario fails so the fake scanner cannot leak
        # into later tests through the shared registry.
        ServiceRegistry.clear_services()
|
|
|
|
|
|
def test_websocket_routes_broadcast_through_registry():
    """Ensure websocket endpoints accept connections and relay broadcasts.

    Connects to both the fetch-progress and download-progress endpoints,
    checks that broadcasts arrive, and verifies the registry hands out the
    module-level websocket manager singleton.
    """

    async def scenario() -> None:
        ServiceRegistry.clear_services()
        ws_manager = await ServiceRegistry.get_websocket_manager()

        app = web.Application()
        app.router.add_get("/ws/fetch-progress", ws_manager.handle_connection)
        app.router.add_get("/ws/download-progress", ws_manager.handle_download_connection)

        async with aiohttp_client(app) as client:
            fetch_ws = await client.ws_connect("/ws/fetch-progress")
            await ws_manager.broadcast({"kind": "ping"})
            message = await asyncio.wait_for(fetch_ws.receive_json(), timeout=1)
            assert message == {"kind": "ping"}

            download_ws = await client.ws_connect("/ws/download-progress?id=session-1")
            greeting = await asyncio.wait_for(download_ws.receive_json(), timeout=1)
            assert greeting["type"] == "download_id"
            assert greeting["download_id"] == "session-1"

            await ws_manager.broadcast_download_progress("session-1", {"progress": 55})
            progress = await asyncio.wait_for(download_ws.receive_json(), timeout=1)
            assert progress["progress"] == 55

            await fetch_ws.close()
            await download_ws.close()

        # Ensure the registry cached instance matches the module-level singleton.
        assert ws_manager is global_ws_manager

    try:
        asyncio.run(scenario())
    finally:
        # Reset even when the scenario fails so registry state cannot leak
        # into later tests.
        ServiceRegistry.clear_services()
|