From 23679ec3f5933fefada76e444abb2258a620783b Mon Sep 17 00:00:00 2001 From: pixelpaws Date: Thu, 25 Sep 2025 14:17:45 +0800 Subject: [PATCH] chore(tests): clean integration route header --- tests/routes/test_route_integration.py | 160 +++++++++++++++++++++++++ 1 file changed, 160 insertions(+) create mode 100644 tests/routes/test_route_integration.py diff --git a/tests/routes/test_route_integration.py b/tests/routes/test_route_integration.py new file mode 100644 index 00000000..dc0a1261 --- /dev/null +++ b/tests/routes/test_route_integration.py @@ -0,0 +1,160 @@ +"""End-to-end integration tests for aiohttp route registrars.""" + +from __future__ import annotations + +import asyncio +from contextlib import asynccontextmanager +from types import SimpleNamespace +from typing import AsyncIterator, Dict, Iterable, List, Sequence + +from aiohttp import web +from aiohttp.test_utils import TestClient, TestServer + +from py.routes.lora_routes import LoraRoutes +from py.services.service_registry import ServiceRegistry +from py.services.websocket_manager import ws_manager as global_ws_manager + + +class IntegrationCache: + """Minimal cache implementation satisfying the service contract.""" + + def __init__(self, items: Sequence[Dict[str, object]]) -> None: + self.raw_data: List[Dict[str, object]] = [dict(item) for item in items] + self.folders: List[str] = ["/"] + + async def get_sorted_data(self, *_: object, **__: object) -> List[Dict[str, object]]: + """Return cached data without additional sorting.""" + return [dict(item) for item in self.raw_data] + + async def resort(self) -> None: + """Resort is a no-op for the static fixture data.""" + return None + + +class IntegrationScanner: + """Scanner double that registers with ServiceRegistry expectations.""" + + def __init__(self, items: Iterable[Dict[str, object]]) -> None: + self.model_type = "lora" + self._cache = IntegrationCache(list(items)) + self._hash_index = SimpleNamespace( + removed_paths=[], + remove_by_path=lambda path: self._hash_index.removed_paths.append(path), + get_duplicate_hashes=lambda: {}, + get_duplicate_filenames=lambda: {}, + ) + self._tags_count: Dict[str, int] = {} + self._excluded_models: List[str] = [] + + async def get_cached_data(self, *_: object, **__: object) -> IntegrationCache: + return self._cache + + def get_model_roots(self) -> List[str]: # pragma: no cover - documented surface + return ["/"] + + async def bulk_delete_models(self, file_paths: Iterable[str]) -> Dict[str, object]: + existing_paths = {item["file_path"] for item in self._cache.raw_data} + deleted = [path for path in file_paths if path in existing_paths] + self._cache.raw_data = [ + item for item in self._cache.raw_data if item["file_path"] not in deleted + ] + await self._cache.resort() + for path in deleted: + self._hash_index.remove_by_path(path) + return {"success": True, "deleted": deleted} + + async def check_model_version_exists(self, *_: object, **__: object) -> bool: + return False + + +@asynccontextmanager +async def aiohttp_client(app: web.Application) -> AsyncIterator[TestClient]: + """Spin up a TestClient with lifecycle management.""" + + server = TestServer(app) + client = TestClient(server) + await client.start_server() + try: + yield client + finally: + await client.close() + + +def test_lora_route_stack_returns_real_data(): + """Spin up LoRA routes and ensure ServiceRegistry-powered wiring succeeds.""" + + async def scenario() -> None: + ServiceRegistry.clear_services() + + fixture_item = { + "model_name": "Alpha", + "file_name": "alpha.safetensors", + "folder": "root", + "file_path": "/tmp/alpha.safetensors", + "size": 128, + "modified": "2024-01-01T00:00:00Z", + "tags": ["integration"], + "civitai": {"trainedWords": ["alpha"]}, + "preview_url": "", + "preview_nsfw_level": 0, + "base_model": "SD1", + "usage_tips": "Use gently", + "notes": "Integration sample", + "from_civitai": True, + } + scanner = IntegrationScanner([fixture_item]) + await ServiceRegistry.register_service("lora_scanner", scanner) + + app = web.Application() + routes = LoraRoutes() + routes.setup_routes(app) + + async with aiohttp_client(app) as client: + response = await client.get("/api/lm/loras/list") + payload = await response.json() + + assert response.status == 200 + assert payload["total"] == 1 + returned = payload["items"][0] + assert returned["model_name"] == "Alpha" + assert returned["file_name"] == "alpha.safetensors" + assert returned["usage_tips"] == "Use gently" + + asyncio.run(scenario()) + ServiceRegistry.clear_services() + + +def test_websocket_routes_broadcast_through_registry(): + """Ensure websocket endpoints accept connections and relay broadcasts.""" + + async def scenario() -> None: + ServiceRegistry.clear_services() + ws_manager = await ServiceRegistry.get_websocket_manager() + + app = web.Application() + app.router.add_get("/ws/fetch-progress", ws_manager.handle_connection) + app.router.add_get("/ws/download-progress", ws_manager.handle_download_connection) + + async with aiohttp_client(app) as client: + fetch_ws = await client.ws_connect("/ws/fetch-progress") + await ws_manager.broadcast({"kind": "ping"}) + message = await asyncio.wait_for(fetch_ws.receive_json(), timeout=1) + assert message == {"kind": "ping"} + + download_ws = await client.ws_connect("/ws/download-progress?id=session-1") + greeting = await asyncio.wait_for(download_ws.receive_json(), timeout=1) + assert greeting["type"] == "download_id" + assert greeting["download_id"] == "session-1" + + await ws_manager.broadcast_download_progress("session-1", {"progress": 55}) + progress = await asyncio.wait_for(download_ws.receive_json(), timeout=1) + assert progress["progress"] == 55 + + await fetch_ws.close() + await download_ws.close() + + # Ensure the registry cached instance matches the module-level singleton. + assert ws_manager is global_ws_manager + + asyncio.run(scenario()) + ServiceRegistry.clear_services()