mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-21 21:22:11 -03:00
Add new POST endpoint `/api/lm/example-images/set-nsfw-level` to allow updating NSFW classification for individual example images. The endpoint supports both regular and custom images, validates required parameters, and updates the corresponding model metadata. This enables users to manually adjust NSFW ratings for better content filtering.
239 lines
8.0 KiB
Python
239 lines
8.0 KiB
Python
from __future__ import annotations
|
|
|
|
import json
|
|
from contextlib import asynccontextmanager
|
|
from dataclasses import dataclass
|
|
from typing import Any, Dict
|
|
|
|
from aiohttp import web
|
|
from aiohttp.test_utils import TestClient, TestServer
|
|
|
|
from py.routes.example_images_route_registrar import ExampleImagesRouteRegistrar
|
|
from py.routes.handlers.example_images_handlers import (
|
|
ExampleImagesDownloadHandler,
|
|
ExampleImagesFileHandler,
|
|
ExampleImagesHandlerSet,
|
|
ExampleImagesManagementHandler,
|
|
)
|
|
from py.services.use_cases.example_images import (
|
|
DownloadExampleImagesInProgressError,
|
|
ImportExampleImagesValidationError,
|
|
)
|
|
from py.utils.example_images_download_manager import (
|
|
DownloadInProgressError,
|
|
DownloadNotRunningError,
|
|
)
|
|
|
|
|
|
class StubDownloadUseCase:
|
|
def __init__(self) -> None:
|
|
self.payloads: list[dict[str, Any]] = []
|
|
self.error: Exception | None = None
|
|
|
|
async def execute(self, payload: dict[str, Any]) -> dict[str, Any]:
|
|
self.payloads.append(payload)
|
|
if self.error:
|
|
raise self.error
|
|
return {"success": True, "payload": payload}
|
|
|
|
|
|
class StubDownloadManager:
|
|
def __init__(self) -> None:
|
|
self.pause_calls = 0
|
|
self.resume_calls = 0
|
|
self.stop_calls = 0
|
|
self.force_payloads: list[dict[str, Any]] = []
|
|
self.pause_error: Exception | None = None
|
|
self.resume_error: Exception | None = None
|
|
self.stop_error: Exception | None = None
|
|
self.force_error: Exception | None = None
|
|
|
|
async def get_status(self, request: web.Request) -> dict[str, Any]:
|
|
return {"success": True, "status": "idle"}
|
|
|
|
async def pause_download(self, request: web.Request) -> dict[str, Any]:
|
|
self.pause_calls += 1
|
|
if self.pause_error:
|
|
raise self.pause_error
|
|
return {"success": True, "message": "paused"}
|
|
|
|
async def resume_download(self, request: web.Request) -> dict[str, Any]:
|
|
self.resume_calls += 1
|
|
if self.resume_error:
|
|
raise self.resume_error
|
|
return {"success": True, "message": "resumed"}
|
|
|
|
async def stop_download(self, request: web.Request) -> dict[str, Any]:
|
|
self.stop_calls += 1
|
|
if self.stop_error:
|
|
raise self.stop_error
|
|
return {"success": True, "message": "stopping"}
|
|
|
|
async def start_force_download(self, payload: dict[str, Any]) -> dict[str, Any]:
|
|
self.force_payloads.append(payload)
|
|
if self.force_error:
|
|
raise self.force_error
|
|
return {"success": True, "payload": payload}
|
|
|
|
|
|
class StubImportUseCase:
|
|
def __init__(self) -> None:
|
|
self.requests: list[web.Request] = []
|
|
self.error: Exception | None = None
|
|
|
|
async def execute(self, request: web.Request) -> dict[str, Any]:
|
|
self.requests.append(request)
|
|
if self.error:
|
|
raise self.error
|
|
return {"success": True}
|
|
|
|
|
|
class StubProcessor:
|
|
def __init__(self) -> None:
|
|
self.delete_calls: list[web.Request] = []
|
|
self.nsfw_calls: list[web.Request] = []
|
|
|
|
async def delete_custom_image(self, request: web.Request) -> web.Response:
|
|
self.delete_calls.append(request)
|
|
return web.json_response({"deleted": True})
|
|
|
|
async def set_example_image_nsfw_level(self, request: web.Request) -> web.Response:
|
|
self.nsfw_calls.append(request)
|
|
return web.json_response({"updated": True})
|
|
|
|
|
|
class StubCleanupService:
|
|
def __init__(self) -> None:
|
|
self.calls = 0
|
|
|
|
async def cleanup_example_image_folders(self) -> dict[str, Any]:
|
|
self.calls += 1
|
|
return {"success": True}
|
|
|
|
|
|
class StubFileManager:
|
|
async def open_folder(self, request: web.Request) -> web.Response:
|
|
return web.json_response({"opened": True})
|
|
|
|
async def get_files(self, request: web.Request) -> web.Response:
|
|
return web.json_response({"files": []})
|
|
|
|
async def has_images(self, request: web.Request) -> web.Response:
|
|
return web.json_response({"has": False})
|
|
|
|
|
|
@dataclass
|
|
class RegistrarHarness:
|
|
client: TestClient
|
|
download_use_case: StubDownloadUseCase
|
|
download_manager: StubDownloadManager
|
|
import_use_case: StubImportUseCase
|
|
|
|
|
|
@asynccontextmanager
|
|
async def registrar_app() -> RegistrarHarness:
|
|
app = web.Application()
|
|
|
|
download_use_case = StubDownloadUseCase()
|
|
download_manager = StubDownloadManager()
|
|
import_use_case = StubImportUseCase()
|
|
processor = StubProcessor()
|
|
cleanup_service = StubCleanupService()
|
|
file_manager = StubFileManager()
|
|
|
|
handler_set = ExampleImagesHandlerSet(
|
|
download=ExampleImagesDownloadHandler(download_use_case, download_manager),
|
|
management=ExampleImagesManagementHandler(import_use_case, processor, cleanup_service),
|
|
files=ExampleImagesFileHandler(file_manager),
|
|
)
|
|
|
|
registrar = ExampleImagesRouteRegistrar(app)
|
|
registrar.register_routes(handler_set.to_route_mapping())
|
|
|
|
server = TestServer(app)
|
|
client = TestClient(server)
|
|
await client.start_server()
|
|
|
|
try:
|
|
yield RegistrarHarness(
|
|
client=client,
|
|
download_use_case=download_use_case,
|
|
download_manager=download_manager,
|
|
import_use_case=import_use_case,
|
|
)
|
|
finally:
|
|
await client.close()
|
|
|
|
|
|
async def _json(response: web.StreamResponse) -> Dict[str, Any]:
|
|
text = await response.text()
|
|
return json.loads(text) if text else {}
|
|
|
|
|
|
async def test_download_route_surfaces_in_progress_error():
|
|
async with registrar_app() as harness:
|
|
progress = {"status": "running"}
|
|
harness.download_use_case.error = DownloadExampleImagesInProgressError(progress)
|
|
|
|
response = await harness.client.post(
|
|
"/api/lm/download-example-images",
|
|
json={"model_types": ["lora"]},
|
|
)
|
|
|
|
assert response.status == 400
|
|
body = await _json(response)
|
|
assert body["status"] == progress
|
|
assert body["error"] == "Download already in progress"
|
|
|
|
|
|
async def test_force_download_translates_manager_errors():
|
|
async with registrar_app() as harness:
|
|
snapshot = {"status": "running"}
|
|
harness.download_manager.force_error = DownloadInProgressError(snapshot)
|
|
|
|
response = await harness.client.post(
|
|
"/api/lm/force-download-example-images",
|
|
json={"model_hashes": ["abc"]},
|
|
)
|
|
|
|
assert response.status == 400
|
|
body = await _json(response)
|
|
assert body["status"] == snapshot
|
|
assert body["error"] == "Download already in progress"
|
|
|
|
|
|
async def test_pause_and_resume_return_client_errors_when_not_running():
|
|
async with registrar_app() as harness:
|
|
harness.download_manager.pause_error = DownloadNotRunningError()
|
|
harness.download_manager.resume_error = DownloadNotRunningError("Stopped")
|
|
harness.download_manager.stop_error = DownloadNotRunningError("Not running")
|
|
|
|
pause_response = await harness.client.post("/api/lm/pause-example-images")
|
|
resume_response = await harness.client.post("/api/lm/resume-example-images")
|
|
stop_response = await harness.client.post("/api/lm/stop-example-images")
|
|
|
|
assert pause_response.status == 400
|
|
assert resume_response.status == 400
|
|
assert stop_response.status == 400
|
|
|
|
pause_body = await _json(pause_response)
|
|
resume_body = await _json(resume_response)
|
|
stop_body = await _json(stop_response)
|
|
assert pause_body == {"success": False, "error": "No download in progress"}
|
|
assert resume_body == {"success": False, "error": "Stopped"}
|
|
assert stop_body == {"success": False, "error": "Not running"}
|
|
|
|
|
|
async def test_import_route_returns_validation_errors():
|
|
async with registrar_app() as harness:
|
|
harness.import_use_case.error = ImportExampleImagesValidationError("bad payload")
|
|
|
|
response = await harness.client.post(
|
|
"/api/lm/import-example-images",
|
|
json={"model_hash": "missing"},
|
|
)
|
|
|
|
assert response.status == 400
|
|
body = await _json(response)
|
|
assert body == {"success": False, "error": "bad payload"}
|