test(routes): cover lora manager lifecycle

This commit is contained in:
pixelpaws
2025-10-05 22:19:10 +08:00
parent cb1f08d556
commit 8fcb979544

View File

@@ -0,0 +1,213 @@
from __future__ import annotations
import asyncio
import logging
from pathlib import Path
from types import SimpleNamespace
import pytest
from aiohttp import web
from py import lora_manager
class _DummyScanner:
def __init__(self, name: str) -> None:
self.name = name
self.initialized = False
async def initialize_in_background(self) -> None:
self.initialized = True
class _DummyWSManager:
async def handle_connection(self, request): # pragma: no cover - interface stub
return None
async def handle_download_connection(self, request): # pragma: no cover - interface stub
return None
async def handle_init_connection(self, request): # pragma: no cover - interface stub
return None
async def test_lora_manager_lifecycle(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None:
app = web.Application()
monkeypatch.setattr(lora_manager.PromptServer, "instance", SimpleNamespace(app=app))
added_static_routes: list[tuple[str, Path]] = []
def record_static_route(path: str, directory: str, *_, **__) -> SimpleNamespace:
added_static_routes.append((path, Path(directory)))
return SimpleNamespace()
added_get_routes: list[tuple[str, object]] = []
def record_get_route(path: str, handler, *_, **__) -> SimpleNamespace:
added_get_routes.append((path, handler))
return SimpleNamespace()
monkeypatch.setattr(app.router, "add_static", record_static_route)
monkeypatch.setattr(app.router, "add_get", record_get_route)
register_calls: list[bool] = []
monkeypatch.setattr(lora_manager, "register_default_model_types", lambda: register_calls.append(True))
model_factory_calls: list[object] = []
monkeypatch.setattr(lora_manager.ModelServiceFactory, "setup_all_routes", lambda app_: model_factory_calls.append(app_))
monkeypatch.setattr(lora_manager.ModelServiceFactory, "get_registered_types", lambda: ["dummy"])
stats_setup: list[object] = []
class FakeStatsRoutes:
def setup_routes(self, app_: web.Application) -> None:
stats_setup.append(app_)
monkeypatch.setattr(lora_manager, "StatsRoutes", FakeStatsRoutes)
recipe_setup: list[object] = []
class FakeRecipeRoutes:
@staticmethod
def setup_routes(app_: web.Application) -> None:
recipe_setup.append(app_)
monkeypatch.setattr(lora_manager, "RecipeRoutes", FakeRecipeRoutes)
update_setup: list[object] = []
class FakeUpdateRoutes:
@staticmethod
def setup_routes(app_: web.Application) -> None:
update_setup.append(app_)
monkeypatch.setattr(lora_manager, "UpdateRoutes", FakeUpdateRoutes)
misc_setup: list[object] = []
class FakeMiscRoutes:
@staticmethod
def setup_routes(app_: web.Application) -> None:
misc_setup.append(app_)
monkeypatch.setattr(lora_manager, "MiscRoutes", FakeMiscRoutes)
example_setup: list[tuple[object, object]] = []
class FakeExampleImagesRoutes:
@staticmethod
def setup_routes(app_: web.Application, *, ws_manager) -> None:
example_setup.append((app_, ws_manager))
monkeypatch.setattr(lora_manager, "ExampleImagesRoutes", FakeExampleImagesRoutes)
preview_setup: list[object] = []
class FakePreviewRoutes:
@staticmethod
def setup_routes(app_: web.Application) -> None:
preview_setup.append(app_)
monkeypatch.setattr(lora_manager, "PreviewRoutes", FakePreviewRoutes)
fake_ws = _DummyWSManager()
monkeypatch.setattr(lora_manager, "ws_manager", fake_ws)
example_images_root = tmp_path / "example_images"
example_images_root.mkdir()
monkeypatch.setattr(
lora_manager.settings,
"get",
lambda key, default=None: str(example_images_root) if key == "example_images_path" else default,
)
loras_root = tmp_path / "loras"
loras_root.mkdir()
(loras_root / "model.ckpt.bak").write_text("stale")
nested = loras_root / "nested"
nested.mkdir()
(nested / "nested_model.ckpt.bak").write_text("stale")
checkpoints_root = tmp_path / "checkpoints"
checkpoints_root.mkdir()
(checkpoints_root / "checkpoint.safetensors.bak").write_text("old")
embeddings_root = tmp_path / "embeddings"
embeddings_root.mkdir()
(embeddings_root / "embedding.pt.bak").write_text("old")
monkeypatch.setattr(lora_manager.config, "loras_roots", [str(loras_root)])
monkeypatch.setattr(lora_manager.config, "base_models_roots", [str(checkpoints_root)])
monkeypatch.setattr(lora_manager.config, "embeddings_roots", [str(embeddings_root)])
scanners = {
"lora": _DummyScanner("lora"),
"checkpoint": _DummyScanner("checkpoint"),
"embedding": _DummyScanner("embedding"),
"recipe": _DummyScanner("recipe"),
}
registry_calls: list[str] = []
async def _stub(name: str, value):
registry_calls.append(name)
return value
monkeypatch.setattr(lora_manager.ServiceRegistry, "get_civitai_client", lambda: _stub("civitai_client", object()))
monkeypatch.setattr(lora_manager.ServiceRegistry, "get_download_manager", lambda: _stub("download_manager", object()))
monkeypatch.setattr(lora_manager.ServiceRegistry, "get_websocket_manager", lambda: _stub("websocket_manager", object()))
monkeypatch.setattr(lora_manager.ServiceRegistry, "get_lora_scanner", lambda: _stub("lora_scanner", scanners["lora"]))
monkeypatch.setattr(lora_manager.ServiceRegistry, "get_checkpoint_scanner", lambda: _stub("checkpoint_scanner", scanners["checkpoint"]))
monkeypatch.setattr(lora_manager.ServiceRegistry, "get_embedding_scanner", lambda: _stub("embedding_scanner", scanners["embedding"]))
monkeypatch.setattr(lora_manager.ServiceRegistry, "get_recipe_scanner", lambda: _stub("recipe_scanner", scanners["recipe"]))
migration_calls: list[bool] = []
async def fake_migration() -> None:
migration_calls.append(True)
monkeypatch.setattr(
lora_manager.ExampleImagesMigration,
"check_and_run_migrations",
staticmethod(fake_migration),
)
original_create_task = asyncio.create_task
scheduled_tasks: list[asyncio.Task] = []
def track_create_task(coro, *, name=None):
task = original_create_task(coro, name=name)
scheduled_tasks.append(task)
return task
monkeypatch.setattr(asyncio, "create_task", track_create_task)
asyncio_logger = logging.getLogger("asyncio")
original_filters = list(asyncio_logger.filters)
try:
lora_manager.LoraManager.add_routes()
assert lora_manager.LoraManager._cleanup in app.on_shutdown
assert app.on_startup, "startup hooks should be registered"
assert register_calls == [True]
assert model_factory_calls == [app]
assert stats_setup == [app]
assert recipe_setup == [app]
assert update_setup == [app]
assert misc_setup == [app]
assert example_setup == [(app, fake_ws)]
assert preview_setup == [app]
assert {path for path, _ in added_static_routes} == {
"/example_images_static",
"/locales",
"/loras_static",
}
get_paths = {path for path, _ in added_get_routes}
assert {"/ws/fetch-progress", "/ws/download-progress", "/ws/init-progress"}.issubset(get_paths)
assert any(filter_obj.__class__.__name__ == "ConnectionResetFilter" for filter_obj in asyncio_logger.filters)
finally:
asyncio_logger.filters[:] = original_filters
await lora_manager.LoraManager._initialize_services()
pending = [task for task in scheduled_tasks if not task.done()]
if pending:
await asyncio.gather(*pending)
task_names = {task.get_name() for task in scheduled_tasks}
assert {"lora_cache_init", "checkpoint_cache_init", "embedding_cache_init", "recipe_cache_init", "post_init_tasks", "cleanup_bak_files"}.issubset(task_names)
for scanner in scanners.values():
assert scanner.initialized is True
assert migration_calls == [True]
for root in (loras_root, checkpoints_root, embeddings_root):
assert not any(path.suffix == ".bak" for path in root.rglob("*")), f"Backup files remain in {root}"
assert {"civitai_client", "download_manager", "websocket_manager", "lora_scanner", "checkpoint_scanner", "embedding_scanner", "recipe_scanner"}.issubset(registry_calls)