Files
ComfyUI-Lora-Manager/tests/routes/test_example_images_route_registrar_handlers.py
Will Miao a6e23a7630 feat(example-images): add NSFW level setting endpoint
Add new POST endpoint `/api/lm/example-images/set-nsfw-level` to allow updating NSFW classification for individual example images. The endpoint supports both regular and custom images, validates required parameters, and updates the corresponding model metadata. This enables users to manually adjust NSFW ratings for better content filtering.
2025-12-09 20:37:16 +08:00

239 lines
8.0 KiB
Python

from __future__ import annotations

import json
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from dataclasses import dataclass
from typing import Any, Dict

from aiohttp import ClientResponse
from aiohttp import web
from aiohttp.test_utils import TestClient, TestServer

from py.routes.example_images_route_registrar import ExampleImagesRouteRegistrar
from py.routes.handlers.example_images_handlers import (
    ExampleImagesDownloadHandler,
    ExampleImagesFileHandler,
    ExampleImagesHandlerSet,
    ExampleImagesManagementHandler,
)
from py.services.use_cases.example_images import (
    DownloadExampleImagesInProgressError,
    ImportExampleImagesValidationError,
)
from py.utils.example_images_download_manager import (
    DownloadInProgressError,
    DownloadNotRunningError,
)


class StubDownloadUseCase:
    """Test double standing in for the download use case.

    Every payload handed to :meth:`execute` is remembered in ``payloads``.
    When a test assigns an exception to ``error``, :meth:`execute` raises it
    instead of returning a result.
    """

    def __init__(self) -> None:
        # Payloads received so far, in call order.
        self.payloads: list[dict[str, Any]] = []
        # Exception to raise from execute(); left as None for the success path.
        self.error: Exception | None = None

    async def execute(self, payload: dict[str, Any]) -> dict[str, Any]:
        """Record ``payload``, then raise the configured error or echo it back."""
        self.payloads.append(payload)
        failure = self.error
        if failure:
            raise failure
        return {"success": True, "payload": payload}
class StubDownloadManager:
    """Test double standing in for the example-images download manager.

    Pause/resume/stop calls are counted, force-download payloads are
    recorded, and each of those four operations can be made to fail by
    assigning an exception to the matching ``*_error`` attribute.
    """

    def __init__(self) -> None:
        # How many times each control operation has been invoked.
        self.pause_calls = 0
        self.resume_calls = 0
        self.stop_calls = 0
        # Payloads passed to start_force_download(), in call order.
        self.force_payloads: list[dict[str, Any]] = []
        # Per-operation exception to raise; None leaves the operation successful.
        self.pause_error: Exception | None = None
        self.resume_error: Exception | None = None
        self.stop_error: Exception | None = None
        self.force_error: Exception | None = None

    async def get_status(self, request: web.Request) -> dict[str, Any]:
        """Return a fixed "idle" status; ``request`` is ignored."""
        return {"success": True, "status": "idle"}

    async def pause_download(self, request: web.Request) -> dict[str, Any]:
        """Count the call, then raise ``pause_error`` or report "paused"."""
        self.pause_calls += 1
        if not self.pause_error:
            return {"success": True, "message": "paused"}
        raise self.pause_error

    async def resume_download(self, request: web.Request) -> dict[str, Any]:
        """Count the call, then raise ``resume_error`` or report "resumed"."""
        self.resume_calls += 1
        if not self.resume_error:
            return {"success": True, "message": "resumed"}
        raise self.resume_error

    async def stop_download(self, request: web.Request) -> dict[str, Any]:
        """Count the call, then raise ``stop_error`` or report "stopping"."""
        self.stop_calls += 1
        if not self.stop_error:
            return {"success": True, "message": "stopping"}
        raise self.stop_error

    async def start_force_download(self, payload: dict[str, Any]) -> dict[str, Any]:
        """Record ``payload``, then raise ``force_error`` or echo the payload back."""
        self.force_payloads.append(payload)
        if not self.force_error:
            return {"success": True, "payload": payload}
        raise self.force_error
class StubImportUseCase:
    """Test double standing in for the import use case.

    Keeps every request passed to :meth:`execute` in ``requests`` and raises
    ``error`` instead of returning when a test has set it.
    """

    def __init__(self) -> None:
        # Requests received so far, in call order.
        self.requests: list[web.Request] = []
        # Exception to raise from execute(); left as None for the success path.
        self.error: Exception | None = None

    async def execute(self, request: web.Request) -> dict[str, Any]:
        """Record ``request``, then raise the configured error or report success."""
        self.requests.append(request)
        if not self.error:
            return {"success": True}
        raise self.error
class StubProcessor:
    """Test double standing in for the example-images processor.

    Each method stores the incoming request and answers with a small fixed
    JSON response.
    """

    def __init__(self) -> None:
        # Requests seen by delete_custom_image(), in call order.
        self.delete_calls: list[web.Request] = []
        # Requests seen by set_example_image_nsfw_level(), in call order.
        self.nsfw_calls: list[web.Request] = []

    async def delete_custom_image(self, request: web.Request) -> web.Response:
        """Record ``request`` and respond with ``{"deleted": true}``."""
        self.delete_calls.append(request)
        body = {"deleted": True}
        return web.json_response(body)

    async def set_example_image_nsfw_level(self, request: web.Request) -> web.Response:
        """Record ``request`` and respond with ``{"updated": true}``."""
        self.nsfw_calls.append(request)
        body = {"updated": True}
        return web.json_response(body)
class StubCleanupService:
    """Test double standing in for the cleanup service; counts invocations."""

    def __init__(self) -> None:
        # Number of times cleanup_example_image_folders() has been awaited.
        self.calls = 0

    async def cleanup_example_image_folders(self) -> dict[str, Any]:
        """Bump the call counter and report success."""
        self.calls = self.calls + 1
        return {"success": True}
class StubFileManager:
    """Test double standing in for the file manager.

    Every method ignores its request and answers with a fixed JSON response.
    """

    async def open_folder(self, request: web.Request) -> web.Response:
        """Respond with ``{"opened": true}``."""
        body = {"opened": True}
        return web.json_response(body)

    async def get_files(self, request: web.Request) -> web.Response:
        """Respond with an empty ``files`` list."""
        body = {"files": []}
        return web.json_response(body)

    async def has_images(self, request: web.Request) -> web.Response:
        """Respond with ``{"has": false}``."""
        body = {"has": False}
        return web.json_response(body)
@dataclass
class RegistrarHarness:
    """Bundle handed to tests by ``registrar_app``.

    Groups the running test client with the stub collaborators whose
    behaviour the tests configure.
    """

    # Client used to issue HTTP requests against the app under test.
    client: TestClient
    # Stub behind the download route.
    download_use_case: StubDownloadUseCase
    # Stub behind the status/pause/resume/stop/force-download routes.
    download_manager: StubDownloadManager
    # Stub behind the import route.
    import_use_case: StubImportUseCase
@asynccontextmanager
async def registrar_app() -> AsyncIterator[RegistrarHarness]:
    """Run an aiohttp app with the example-images routes backed by stubs.

    The real handler classes and route registrar are wired around stub
    collaborators, the app is served through an aiohttp ``TestClient``, and
    the client is closed again when the ``async with`` block exits.

    Yields:
        RegistrarHarness: the started client together with the download
        use case, download manager and import use case stubs, so tests can
        inject errors and inspect recorded calls.
    """
    app = web.Application()

    download_use_case = StubDownloadUseCase()
    download_manager = StubDownloadManager()
    import_use_case = StubImportUseCase()
    processor = StubProcessor()
    cleanup_service = StubCleanupService()
    file_manager = StubFileManager()

    handler_set = ExampleImagesHandlerSet(
        download=ExampleImagesDownloadHandler(download_use_case, download_manager),
        management=ExampleImagesManagementHandler(import_use_case, processor, cleanup_service),
        files=ExampleImagesFileHandler(file_manager),
    )
    registrar = ExampleImagesRouteRegistrar(app)
    registrar.register_routes(handler_set.to_route_mapping())

    server = TestServer(app)
    client = TestClient(server)
    try:
        # Start inside the try block so client.close() still runs when the
        # server fails to come up, instead of leaking the client.
        await client.start_server()
        yield RegistrarHarness(
            client=client,
            download_use_case=download_use_case,
            download_manager=download_manager,
            import_use_case=import_use_case,
        )
    finally:
        await client.close()
async def _json(response: ClientResponse) -> dict[str, Any]:
    """Decode the JSON body of a response returned by the test client.

    Args:
        response: Client-side response object; only its awaitable ``text()``
            method is used.

    Returns:
        The parsed JSON document, or an empty dict when the body is empty.

    Raises:
        json.JSONDecodeError: If the body is non-empty but not valid JSON.
    """
    text = await response.text()
    # An empty body is treated as "no data" rather than a decode error.
    return json.loads(text) if text else {}
async def test_download_route_surfaces_in_progress_error():
    """The download route answers 400 and echoes progress when a download is already running."""
    async with registrar_app() as harness:
        running = {"status": "running"}
        harness.download_use_case.error = DownloadExampleImagesInProgressError(running)

        request_body = {"model_types": ["lora"]}
        resp = await harness.client.post(
            "/api/lm/download-example-images",
            json=request_body,
        )

        assert resp.status == 400
        payload = await _json(resp)
        assert payload["status"] == running
        assert payload["error"] == "Download already in progress"
async def test_force_download_translates_manager_errors():
    """The force-download route turns the manager's in-progress error into a 400 response."""
    async with registrar_app() as harness:
        running = {"status": "running"}
        harness.download_manager.force_error = DownloadInProgressError(running)

        request_body = {"model_hashes": ["abc"]}
        resp = await harness.client.post(
            "/api/lm/force-download-example-images",
            json=request_body,
        )

        assert resp.status == 400
        payload = await _json(resp)
        assert payload["status"] == running
        assert payload["error"] == "Download already in progress"
async def test_pause_and_resume_return_client_errors_when_not_running():
    """Pause, resume and stop each answer 400 with the manager's error message."""
    async with registrar_app() as harness:
        manager = harness.download_manager
        manager.pause_error = DownloadNotRunningError()
        manager.resume_error = DownloadNotRunningError("Stopped")
        manager.stop_error = DownloadNotRunningError("Not running")

        client = harness.client
        pause_resp = await client.post("/api/lm/pause-example-images")
        resume_resp = await client.post("/api/lm/resume-example-images")
        stop_resp = await client.post("/api/lm/stop-example-images")

        assert pause_resp.status == 400
        assert resume_resp.status == 400
        assert stop_resp.status == 400

        pause_payload = await _json(pause_resp)
        resume_payload = await _json(resume_resp)
        stop_payload = await _json(stop_resp)

        assert pause_payload == {"success": False, "error": "No download in progress"}
        assert resume_payload == {"success": False, "error": "Stopped"}
        assert stop_payload == {"success": False, "error": "Not running"}
async def test_import_route_returns_validation_errors():
    """The import route answers 400 with the validation error's message."""
    async with registrar_app() as harness:
        harness.import_use_case.error = ImportExampleImagesValidationError("bad payload")

        request_body = {"model_hash": "missing"}
        resp = await harness.client.post(
            "/api/lm/import-example-images",
            json=request_body,
        )

        assert resp.status == 400
        payload = await _json(resp)
        assert payload == {"success": False, "error": "bad payload"}