mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-21 21:22:11 -03:00
- Add /api/example-images/check-pending endpoint to quickly check models needing downloads - Improve DownloadManager.start_download() to return immediately without blocking - Add _handle_download_task_done callback for proper error handling and progress saving - Add check_pending_models() method for lightweight pre-download validation - Update frontend ExampleImagesManager to use new check-pending endpoint - Add comprehensive tests for new functionality
375 lines
13 KiB
Python
375 lines
13 KiB
Python
from __future__ import annotations
|
|
|
|
import json
|
|
from contextlib import asynccontextmanager
|
|
from dataclasses import dataclass
|
|
from typing import Any, Dict
|
|
|
|
from aiohttp import web
|
|
from aiohttp.test_utils import TestClient, TestServer
|
|
|
|
from py.routes.example_images_route_registrar import ExampleImagesRouteRegistrar
|
|
from py.routes.handlers.example_images_handlers import (
|
|
ExampleImagesDownloadHandler,
|
|
ExampleImagesFileHandler,
|
|
ExampleImagesHandlerSet,
|
|
ExampleImagesManagementHandler,
|
|
)
|
|
from py.services.use_cases.example_images import (
|
|
DownloadExampleImagesInProgressError,
|
|
ImportExampleImagesValidationError,
|
|
)
|
|
from py.utils.example_images_download_manager import (
|
|
DownloadInProgressError,
|
|
DownloadNotRunningError,
|
|
)
|
|
|
|
|
|
class StubDownloadUseCase:
|
|
def __init__(self) -> None:
|
|
self.payloads: list[dict[str, Any]] = []
|
|
self.error: Exception | None = None
|
|
|
|
async def execute(self, payload: dict[str, Any]) -> dict[str, Any]:
|
|
self.payloads.append(payload)
|
|
if self.error:
|
|
raise self.error
|
|
return {"success": True, "payload": payload}
|
|
|
|
|
|
class StubDownloadManager:
|
|
def __init__(self) -> None:
|
|
self.pause_calls = 0
|
|
self.resume_calls = 0
|
|
self.stop_calls = 0
|
|
self.force_payloads: list[dict[str, Any]] = []
|
|
self.pause_error: Exception | None = None
|
|
self.resume_error: Exception | None = None
|
|
self.stop_error: Exception | None = None
|
|
self.force_error: Exception | None = None
|
|
self.check_pending_result: dict[str, Any] | None = None
|
|
self.check_pending_calls: list[list[str]] = []
|
|
|
|
async def get_status(self, request: web.Request) -> dict[str, Any]:
|
|
return {"success": True, "status": "idle"}
|
|
|
|
async def pause_download(self, request: web.Request) -> dict[str, Any]:
|
|
self.pause_calls += 1
|
|
if self.pause_error:
|
|
raise self.pause_error
|
|
return {"success": True, "message": "paused"}
|
|
|
|
async def resume_download(self, request: web.Request) -> dict[str, Any]:
|
|
self.resume_calls += 1
|
|
if self.resume_error:
|
|
raise self.resume_error
|
|
return {"success": True, "message": "resumed"}
|
|
|
|
async def stop_download(self, request: web.Request) -> dict[str, Any]:
|
|
self.stop_calls += 1
|
|
if self.stop_error:
|
|
raise self.stop_error
|
|
return {"success": True, "message": "stopping"}
|
|
|
|
async def start_force_download(self, payload: dict[str, Any]) -> dict[str, Any]:
|
|
self.force_payloads.append(payload)
|
|
if self.force_error:
|
|
raise self.force_error
|
|
return {"success": True, "payload": payload}
|
|
|
|
async def check_pending_models(self, model_types: list[str]) -> dict[str, Any]:
|
|
self.check_pending_calls.append(model_types)
|
|
if self.check_pending_result is not None:
|
|
return self.check_pending_result
|
|
return {
|
|
"success": True,
|
|
"is_downloading": False,
|
|
"total_models": 100,
|
|
"pending_count": 10,
|
|
"processed_count": 90,
|
|
"failed_count": 0,
|
|
"needs_download": True,
|
|
}
|
|
|
|
|
|
class StubImportUseCase:
|
|
def __init__(self) -> None:
|
|
self.requests: list[web.Request] = []
|
|
self.error: Exception | None = None
|
|
|
|
async def execute(self, request: web.Request) -> dict[str, Any]:
|
|
self.requests.append(request)
|
|
if self.error:
|
|
raise self.error
|
|
return {"success": True}
|
|
|
|
|
|
class StubProcessor:
|
|
def __init__(self) -> None:
|
|
self.delete_calls: list[web.Request] = []
|
|
self.nsfw_calls: list[web.Request] = []
|
|
|
|
async def delete_custom_image(self, request: web.Request) -> web.Response:
|
|
self.delete_calls.append(request)
|
|
return web.json_response({"deleted": True})
|
|
|
|
async def set_example_image_nsfw_level(self, request: web.Request) -> web.Response:
|
|
self.nsfw_calls.append(request)
|
|
return web.json_response({"updated": True})
|
|
|
|
|
|
class StubCleanupService:
|
|
def __init__(self) -> None:
|
|
self.calls = 0
|
|
|
|
async def cleanup_example_image_folders(self) -> dict[str, Any]:
|
|
self.calls += 1
|
|
return {"success": True}
|
|
|
|
|
|
class StubFileManager:
|
|
async def open_folder(self, request: web.Request) -> web.Response:
|
|
return web.json_response({"opened": True})
|
|
|
|
async def get_files(self, request: web.Request) -> web.Response:
|
|
return web.json_response({"files": []})
|
|
|
|
async def has_images(self, request: web.Request) -> web.Response:
|
|
return web.json_response({"has": False})
|
|
|
|
|
|
@dataclass
|
|
class RegistrarHarness:
|
|
client: TestClient
|
|
download_use_case: StubDownloadUseCase
|
|
download_manager: StubDownloadManager
|
|
import_use_case: StubImportUseCase
|
|
|
|
|
|
@asynccontextmanager
|
|
async def registrar_app() -> RegistrarHarness:
|
|
app = web.Application()
|
|
|
|
download_use_case = StubDownloadUseCase()
|
|
download_manager = StubDownloadManager()
|
|
import_use_case = StubImportUseCase()
|
|
processor = StubProcessor()
|
|
cleanup_service = StubCleanupService()
|
|
file_manager = StubFileManager()
|
|
|
|
handler_set = ExampleImagesHandlerSet(
|
|
download=ExampleImagesDownloadHandler(download_use_case, download_manager),
|
|
management=ExampleImagesManagementHandler(import_use_case, processor, cleanup_service),
|
|
files=ExampleImagesFileHandler(file_manager),
|
|
)
|
|
|
|
registrar = ExampleImagesRouteRegistrar(app)
|
|
registrar.register_routes(handler_set.to_route_mapping())
|
|
|
|
server = TestServer(app)
|
|
client = TestClient(server)
|
|
await client.start_server()
|
|
|
|
try:
|
|
yield RegistrarHarness(
|
|
client=client,
|
|
download_use_case=download_use_case,
|
|
download_manager=download_manager,
|
|
import_use_case=import_use_case,
|
|
)
|
|
finally:
|
|
await client.close()
|
|
|
|
|
|
async def _json(response: web.StreamResponse) -> Dict[str, Any]:
|
|
text = await response.text()
|
|
return json.loads(text) if text else {}
|
|
|
|
|
|
async def test_download_route_surfaces_in_progress_error():
|
|
async with registrar_app() as harness:
|
|
progress = {"status": "running"}
|
|
harness.download_use_case.error = DownloadExampleImagesInProgressError(progress)
|
|
|
|
response = await harness.client.post(
|
|
"/api/lm/download-example-images",
|
|
json={"model_types": ["lora"]},
|
|
)
|
|
|
|
assert response.status == 400
|
|
body = await _json(response)
|
|
assert body["status"] == progress
|
|
assert body["error"] == "Download already in progress"
|
|
|
|
|
|
async def test_force_download_translates_manager_errors():
|
|
async with registrar_app() as harness:
|
|
snapshot = {"status": "running"}
|
|
harness.download_manager.force_error = DownloadInProgressError(snapshot)
|
|
|
|
response = await harness.client.post(
|
|
"/api/lm/force-download-example-images",
|
|
json={"model_hashes": ["abc"]},
|
|
)
|
|
|
|
assert response.status == 400
|
|
body = await _json(response)
|
|
assert body["status"] == snapshot
|
|
assert body["error"] == "Download already in progress"
|
|
|
|
|
|
async def test_pause_and_resume_return_client_errors_when_not_running():
|
|
async with registrar_app() as harness:
|
|
harness.download_manager.pause_error = DownloadNotRunningError()
|
|
harness.download_manager.resume_error = DownloadNotRunningError("Stopped")
|
|
harness.download_manager.stop_error = DownloadNotRunningError("Not running")
|
|
|
|
pause_response = await harness.client.post("/api/lm/pause-example-images")
|
|
resume_response = await harness.client.post("/api/lm/resume-example-images")
|
|
stop_response = await harness.client.post("/api/lm/stop-example-images")
|
|
|
|
assert pause_response.status == 400
|
|
assert resume_response.status == 400
|
|
assert stop_response.status == 400
|
|
|
|
pause_body = await _json(pause_response)
|
|
resume_body = await _json(resume_response)
|
|
stop_body = await _json(stop_response)
|
|
assert pause_body == {"success": False, "error": "No download in progress"}
|
|
assert resume_body == {"success": False, "error": "Stopped"}
|
|
assert stop_body == {"success": False, "error": "Not running"}
|
|
|
|
|
|
async def test_import_route_returns_validation_errors():
|
|
async with registrar_app() as harness:
|
|
harness.import_use_case.error = ImportExampleImagesValidationError("bad payload")
|
|
|
|
response = await harness.client.post(
|
|
"/api/lm/import-example-images",
|
|
json={"model_hash": "missing"},
|
|
)
|
|
|
|
assert response.status == 400
|
|
body = await _json(response)
|
|
assert body == {"success": False, "error": "bad payload"}
|
|
|
|
|
|
async def test_check_example_images_needed_returns_pending_counts():
|
|
"""Test that check_example_images_needed endpoint returns pending model counts."""
|
|
async with registrar_app() as harness:
|
|
harness.download_manager.check_pending_result = {
|
|
"success": True,
|
|
"is_downloading": False,
|
|
"total_models": 5500,
|
|
"pending_count": 12,
|
|
"processed_count": 5488,
|
|
"failed_count": 45,
|
|
"needs_download": True,
|
|
}
|
|
|
|
response = await harness.client.post(
|
|
"/api/lm/check-example-images-needed",
|
|
json={"model_types": ["lora", "checkpoint"]},
|
|
)
|
|
|
|
assert response.status == 200
|
|
body = await _json(response)
|
|
assert body["success"] is True
|
|
assert body["total_models"] == 5500
|
|
assert body["pending_count"] == 12
|
|
assert body["processed_count"] == 5488
|
|
assert body["failed_count"] == 45
|
|
assert body["needs_download"] is True
|
|
assert body["is_downloading"] is False
|
|
|
|
# Verify the manager was called with correct model types
|
|
assert harness.download_manager.check_pending_calls == [["lora", "checkpoint"]]
|
|
|
|
|
|
async def test_check_example_images_needed_handles_download_in_progress():
|
|
"""Test that check_example_images_needed returns correct status when download is running."""
|
|
async with registrar_app() as harness:
|
|
harness.download_manager.check_pending_result = {
|
|
"success": True,
|
|
"is_downloading": True,
|
|
"total_models": 0,
|
|
"pending_count": 0,
|
|
"processed_count": 0,
|
|
"failed_count": 0,
|
|
"needs_download": False,
|
|
"message": "Download already in progress",
|
|
}
|
|
|
|
response = await harness.client.post(
|
|
"/api/lm/check-example-images-needed",
|
|
json={"model_types": ["lora"]},
|
|
)
|
|
|
|
assert response.status == 200
|
|
body = await _json(response)
|
|
assert body["success"] is True
|
|
assert body["is_downloading"] is True
|
|
assert body["needs_download"] is False
|
|
|
|
|
|
async def test_check_example_images_needed_handles_no_pending_models():
|
|
"""Test that check_example_images_needed returns correct status when no work is needed."""
|
|
async with registrar_app() as harness:
|
|
harness.download_manager.check_pending_result = {
|
|
"success": True,
|
|
"is_downloading": False,
|
|
"total_models": 5500,
|
|
"pending_count": 0,
|
|
"processed_count": 5500,
|
|
"failed_count": 0,
|
|
"needs_download": False,
|
|
}
|
|
|
|
response = await harness.client.post(
|
|
"/api/lm/check-example-images-needed",
|
|
json={"model_types": ["lora", "checkpoint", "embedding"]},
|
|
)
|
|
|
|
assert response.status == 200
|
|
body = await _json(response)
|
|
assert body["success"] is True
|
|
assert body["pending_count"] == 0
|
|
assert body["needs_download"] is False
|
|
assert body["processed_count"] == 5500
|
|
|
|
|
|
async def test_check_example_images_needed_uses_default_model_types():
|
|
"""Test that check_example_images_needed uses default model types when not specified."""
|
|
async with registrar_app() as harness:
|
|
response = await harness.client.post(
|
|
"/api/lm/check-example-images-needed",
|
|
json={}, # No model_types specified
|
|
)
|
|
|
|
assert response.status == 200
|
|
# Should use default model types
|
|
assert harness.download_manager.check_pending_calls == [["lora", "checkpoint", "embedding"]]
|
|
|
|
|
|
async def test_check_example_images_needed_returns_error_on_exception():
|
|
"""Test that check_example_images_needed returns 500 on internal error."""
|
|
async with registrar_app() as harness:
|
|
# Simulate an error by setting result to an error state
|
|
# Actually, we need to make the method raise an exception
|
|
original_method = harness.download_manager.check_pending_models
|
|
|
|
async def failing_check(_model_types):
|
|
raise RuntimeError("Database connection failed")
|
|
|
|
harness.download_manager.check_pending_models = failing_check
|
|
|
|
response = await harness.client.post(
|
|
"/api/lm/check-example-images-needed",
|
|
json={"model_types": ["lora"]},
|
|
)
|
|
|
|
assert response.status == 500
|
|
body = await _json(response)
|
|
assert body["success"] is False
|
|
assert "Database connection failed" in body["error"]
|