mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-21 21:22:11 -03:00
test(routes): cover lora manager lifecycle
This commit is contained in:
213
tests/routes/test_lora_manager_lifecycle.py
Normal file
213
tests/routes/test_lora_manager_lifecycle.py
Normal file
@@ -0,0 +1,213 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
from aiohttp import web
|
||||
|
||||
from py import lora_manager
|
||||
|
||||
|
||||
class _DummyScanner:
|
||||
def __init__(self, name: str) -> None:
|
||||
self.name = name
|
||||
self.initialized = False
|
||||
|
||||
async def initialize_in_background(self) -> None:
|
||||
self.initialized = True
|
||||
|
||||
|
||||
class _DummyWSManager:
|
||||
async def handle_connection(self, request): # pragma: no cover - interface stub
|
||||
return None
|
||||
|
||||
async def handle_download_connection(self, request): # pragma: no cover - interface stub
|
||||
return None
|
||||
|
||||
async def handle_init_connection(self, request): # pragma: no cover - interface stub
|
||||
return None
|
||||
|
||||
|
||||
async def test_lora_manager_lifecycle(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None:
|
||||
app = web.Application()
|
||||
monkeypatch.setattr(lora_manager.PromptServer, "instance", SimpleNamespace(app=app))
|
||||
|
||||
added_static_routes: list[tuple[str, Path]] = []
|
||||
def record_static_route(path: str, directory: str, *_, **__) -> SimpleNamespace:
|
||||
added_static_routes.append((path, Path(directory)))
|
||||
return SimpleNamespace()
|
||||
|
||||
added_get_routes: list[tuple[str, object]] = []
|
||||
def record_get_route(path: str, handler, *_, **__) -> SimpleNamespace:
|
||||
added_get_routes.append((path, handler))
|
||||
return SimpleNamespace()
|
||||
|
||||
monkeypatch.setattr(app.router, "add_static", record_static_route)
|
||||
monkeypatch.setattr(app.router, "add_get", record_get_route)
|
||||
|
||||
register_calls: list[bool] = []
|
||||
monkeypatch.setattr(lora_manager, "register_default_model_types", lambda: register_calls.append(True))
|
||||
|
||||
model_factory_calls: list[object] = []
|
||||
monkeypatch.setattr(lora_manager.ModelServiceFactory, "setup_all_routes", lambda app_: model_factory_calls.append(app_))
|
||||
monkeypatch.setattr(lora_manager.ModelServiceFactory, "get_registered_types", lambda: ["dummy"])
|
||||
|
||||
stats_setup: list[object] = []
|
||||
class FakeStatsRoutes:
|
||||
def setup_routes(self, app_: web.Application) -> None:
|
||||
stats_setup.append(app_)
|
||||
monkeypatch.setattr(lora_manager, "StatsRoutes", FakeStatsRoutes)
|
||||
|
||||
recipe_setup: list[object] = []
|
||||
class FakeRecipeRoutes:
|
||||
@staticmethod
|
||||
def setup_routes(app_: web.Application) -> None:
|
||||
recipe_setup.append(app_)
|
||||
monkeypatch.setattr(lora_manager, "RecipeRoutes", FakeRecipeRoutes)
|
||||
|
||||
update_setup: list[object] = []
|
||||
class FakeUpdateRoutes:
|
||||
@staticmethod
|
||||
def setup_routes(app_: web.Application) -> None:
|
||||
update_setup.append(app_)
|
||||
monkeypatch.setattr(lora_manager, "UpdateRoutes", FakeUpdateRoutes)
|
||||
|
||||
misc_setup: list[object] = []
|
||||
class FakeMiscRoutes:
|
||||
@staticmethod
|
||||
def setup_routes(app_: web.Application) -> None:
|
||||
misc_setup.append(app_)
|
||||
monkeypatch.setattr(lora_manager, "MiscRoutes", FakeMiscRoutes)
|
||||
|
||||
example_setup: list[tuple[object, object]] = []
|
||||
class FakeExampleImagesRoutes:
|
||||
@staticmethod
|
||||
def setup_routes(app_: web.Application, *, ws_manager) -> None:
|
||||
example_setup.append((app_, ws_manager))
|
||||
monkeypatch.setattr(lora_manager, "ExampleImagesRoutes", FakeExampleImagesRoutes)
|
||||
|
||||
preview_setup: list[object] = []
|
||||
class FakePreviewRoutes:
|
||||
@staticmethod
|
||||
def setup_routes(app_: web.Application) -> None:
|
||||
preview_setup.append(app_)
|
||||
monkeypatch.setattr(lora_manager, "PreviewRoutes", FakePreviewRoutes)
|
||||
|
||||
fake_ws = _DummyWSManager()
|
||||
monkeypatch.setattr(lora_manager, "ws_manager", fake_ws)
|
||||
|
||||
example_images_root = tmp_path / "example_images"
|
||||
example_images_root.mkdir()
|
||||
monkeypatch.setattr(
|
||||
lora_manager.settings,
|
||||
"get",
|
||||
lambda key, default=None: str(example_images_root) if key == "example_images_path" else default,
|
||||
)
|
||||
|
||||
loras_root = tmp_path / "loras"
|
||||
loras_root.mkdir()
|
||||
(loras_root / "model.ckpt.bak").write_text("stale")
|
||||
nested = loras_root / "nested"
|
||||
nested.mkdir()
|
||||
(nested / "nested_model.ckpt.bak").write_text("stale")
|
||||
|
||||
checkpoints_root = tmp_path / "checkpoints"
|
||||
checkpoints_root.mkdir()
|
||||
(checkpoints_root / "checkpoint.safetensors.bak").write_text("old")
|
||||
|
||||
embeddings_root = tmp_path / "embeddings"
|
||||
embeddings_root.mkdir()
|
||||
(embeddings_root / "embedding.pt.bak").write_text("old")
|
||||
|
||||
monkeypatch.setattr(lora_manager.config, "loras_roots", [str(loras_root)])
|
||||
monkeypatch.setattr(lora_manager.config, "base_models_roots", [str(checkpoints_root)])
|
||||
monkeypatch.setattr(lora_manager.config, "embeddings_roots", [str(embeddings_root)])
|
||||
|
||||
scanners = {
|
||||
"lora": _DummyScanner("lora"),
|
||||
"checkpoint": _DummyScanner("checkpoint"),
|
||||
"embedding": _DummyScanner("embedding"),
|
||||
"recipe": _DummyScanner("recipe"),
|
||||
}
|
||||
|
||||
registry_calls: list[str] = []
|
||||
|
||||
async def _stub(name: str, value):
|
||||
registry_calls.append(name)
|
||||
return value
|
||||
|
||||
monkeypatch.setattr(lora_manager.ServiceRegistry, "get_civitai_client", lambda: _stub("civitai_client", object()))
|
||||
monkeypatch.setattr(lora_manager.ServiceRegistry, "get_download_manager", lambda: _stub("download_manager", object()))
|
||||
monkeypatch.setattr(lora_manager.ServiceRegistry, "get_websocket_manager", lambda: _stub("websocket_manager", object()))
|
||||
monkeypatch.setattr(lora_manager.ServiceRegistry, "get_lora_scanner", lambda: _stub("lora_scanner", scanners["lora"]))
|
||||
monkeypatch.setattr(lora_manager.ServiceRegistry, "get_checkpoint_scanner", lambda: _stub("checkpoint_scanner", scanners["checkpoint"]))
|
||||
monkeypatch.setattr(lora_manager.ServiceRegistry, "get_embedding_scanner", lambda: _stub("embedding_scanner", scanners["embedding"]))
|
||||
monkeypatch.setattr(lora_manager.ServiceRegistry, "get_recipe_scanner", lambda: _stub("recipe_scanner", scanners["recipe"]))
|
||||
|
||||
migration_calls: list[bool] = []
|
||||
|
||||
async def fake_migration() -> None:
|
||||
migration_calls.append(True)
|
||||
|
||||
monkeypatch.setattr(
|
||||
lora_manager.ExampleImagesMigration,
|
||||
"check_and_run_migrations",
|
||||
staticmethod(fake_migration),
|
||||
)
|
||||
|
||||
original_create_task = asyncio.create_task
|
||||
scheduled_tasks: list[asyncio.Task] = []
|
||||
|
||||
def track_create_task(coro, *, name=None):
|
||||
task = original_create_task(coro, name=name)
|
||||
scheduled_tasks.append(task)
|
||||
return task
|
||||
|
||||
monkeypatch.setattr(asyncio, "create_task", track_create_task)
|
||||
|
||||
asyncio_logger = logging.getLogger("asyncio")
|
||||
original_filters = list(asyncio_logger.filters)
|
||||
try:
|
||||
lora_manager.LoraManager.add_routes()
|
||||
assert lora_manager.LoraManager._cleanup in app.on_shutdown
|
||||
assert app.on_startup, "startup hooks should be registered"
|
||||
assert register_calls == [True]
|
||||
assert model_factory_calls == [app]
|
||||
assert stats_setup == [app]
|
||||
assert recipe_setup == [app]
|
||||
assert update_setup == [app]
|
||||
assert misc_setup == [app]
|
||||
assert example_setup == [(app, fake_ws)]
|
||||
assert preview_setup == [app]
|
||||
assert {path for path, _ in added_static_routes} == {
|
||||
"/example_images_static",
|
||||
"/locales",
|
||||
"/loras_static",
|
||||
}
|
||||
get_paths = {path for path, _ in added_get_routes}
|
||||
assert {"/ws/fetch-progress", "/ws/download-progress", "/ws/init-progress"}.issubset(get_paths)
|
||||
assert any(filter_obj.__class__.__name__ == "ConnectionResetFilter" for filter_obj in asyncio_logger.filters)
|
||||
finally:
|
||||
asyncio_logger.filters[:] = original_filters
|
||||
|
||||
await lora_manager.LoraManager._initialize_services()
|
||||
|
||||
pending = [task for task in scheduled_tasks if not task.done()]
|
||||
if pending:
|
||||
await asyncio.gather(*pending)
|
||||
|
||||
task_names = {task.get_name() for task in scheduled_tasks}
|
||||
assert {"lora_cache_init", "checkpoint_cache_init", "embedding_cache_init", "recipe_cache_init", "post_init_tasks", "cleanup_bak_files"}.issubset(task_names)
|
||||
|
||||
for scanner in scanners.values():
|
||||
assert scanner.initialized is True
|
||||
|
||||
assert migration_calls == [True]
|
||||
|
||||
for root in (loras_root, checkpoints_root, embeddings_root):
|
||||
assert not any(path.suffix == ".bak" for path in root.rglob("*")), f"Backup files remain in {root}"
|
||||
|
||||
assert {"civitai_client", "download_manager", "websocket_manager", "lora_scanner", "checkpoint_scanner", "embedding_scanner", "recipe_scanner"}.issubset(registry_calls)
|
||||
Reference in New Issue
Block a user