mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-24 14:42:11 -03:00
feat(example-images): add stop control for download panel
This commit is contained in:
@@ -41,9 +41,11 @@ class StubDownloadManager:
|
||||
def __init__(self) -> None:
|
||||
self.pause_calls = 0
|
||||
self.resume_calls = 0
|
||||
self.stop_calls = 0
|
||||
self.force_payloads: list[dict[str, Any]] = []
|
||||
self.pause_error: Exception | None = None
|
||||
self.resume_error: Exception | None = None
|
||||
self.stop_error: Exception | None = None
|
||||
self.force_error: Exception | None = None
|
||||
|
||||
async def get_status(self, request: web.Request) -> dict[str, Any]:
|
||||
@@ -61,6 +63,12 @@ class StubDownloadManager:
|
||||
raise self.resume_error
|
||||
return {"success": True, "message": "resumed"}
|
||||
|
||||
async def stop_download(self, request: web.Request) -> dict[str, Any]:
|
||||
self.stop_calls += 1
|
||||
if self.stop_error:
|
||||
raise self.stop_error
|
||||
return {"success": True, "message": "stopping"}
|
||||
|
||||
async def start_force_download(self, payload: dict[str, Any]) -> dict[str, Any]:
|
||||
self.force_payloads.append(payload)
|
||||
if self.force_error:
|
||||
@@ -193,17 +201,22 @@ async def test_pause_and_resume_return_client_errors_when_not_running():
|
||||
async with registrar_app() as harness:
|
||||
harness.download_manager.pause_error = DownloadNotRunningError()
|
||||
harness.download_manager.resume_error = DownloadNotRunningError("Stopped")
|
||||
harness.download_manager.stop_error = DownloadNotRunningError("Not running")
|
||||
|
||||
pause_response = await harness.client.post("/api/lm/pause-example-images")
|
||||
resume_response = await harness.client.post("/api/lm/resume-example-images")
|
||||
stop_response = await harness.client.post("/api/lm/stop-example-images")
|
||||
|
||||
assert pause_response.status == 400
|
||||
assert resume_response.status == 400
|
||||
assert stop_response.status == 400
|
||||
|
||||
pause_body = await _json(pause_response)
|
||||
resume_body = await _json(resume_response)
|
||||
stop_body = await _json(stop_response)
|
||||
assert pause_body == {"success": False, "error": "No download in progress"}
|
||||
assert resume_body == {"success": False, "error": "Stopped"}
|
||||
assert stop_body == {"success": False, "error": "Not running"}
|
||||
|
||||
|
||||
async def test_import_route_returns_validation_errors():
|
||||
|
||||
@@ -51,6 +51,10 @@ class StubDownloadManager:
|
||||
self.calls.append(("resume_download", None))
|
||||
return {"operation": "resume_download"}
|
||||
|
||||
async def stop_download(self, request: web.Request) -> dict:
|
||||
self.calls.append(("stop_download", None))
|
||||
return {"operation": "stop_download"}
|
||||
|
||||
async def start_force_download(self, payload: Any) -> dict:
|
||||
self.calls.append(("start_force_download", payload))
|
||||
return {"operation": "start_force_download", "payload": payload}
|
||||
@@ -195,19 +199,23 @@ async def test_status_route_returns_manager_payload():
|
||||
assert harness.download_manager.calls == [("get_status", {"detail": "true"})]
|
||||
|
||||
|
||||
async def test_pause_and_resume_routes_delegate():
|
||||
async def test_pause_resume_and_stop_routes_delegate():
|
||||
async with example_images_app() as harness:
|
||||
pause_response = await harness.client.post("/api/lm/pause-example-images")
|
||||
resume_response = await harness.client.post("/api/lm/resume-example-images")
|
||||
stop_response = await harness.client.post("/api/lm/stop-example-images")
|
||||
|
||||
assert pause_response.status == 200
|
||||
assert await pause_response.json() == {"operation": "pause_download"}
|
||||
assert resume_response.status == 200
|
||||
assert await resume_response.json() == {"operation": "resume_download"}
|
||||
assert stop_response.status == 200
|
||||
assert await stop_response.json() == {"operation": "stop_download"}
|
||||
|
||||
assert harness.download_manager.calls[-2:] == [
|
||||
assert harness.download_manager.calls[-3:] == [
|
||||
("pause_download", None),
|
||||
("resume_download", None),
|
||||
("stop_download", None),
|
||||
]
|
||||
|
||||
|
||||
@@ -309,6 +317,10 @@ async def test_download_handler_methods_delegate() -> None:
|
||||
self.calls.append(("resume_download", request))
|
||||
return {"status": "running"}
|
||||
|
||||
async def stop_download(self, request) -> dict:
|
||||
self.calls.append(("stop_download", request))
|
||||
return {"status": "stopping"}
|
||||
|
||||
async def start_force_download(self, payload) -> dict:
|
||||
self.calls.append(("start_force_download", payload))
|
||||
return {"status": "force", "payload": payload}
|
||||
@@ -342,6 +354,8 @@ async def test_download_handler_methods_delegate() -> None:
|
||||
assert json.loads(pause_response.text) == {"status": "paused"}
|
||||
resume_response = await handler.resume_example_images(request)
|
||||
assert json.loads(resume_response.text) == {"status": "running"}
|
||||
stop_response = await handler.stop_example_images(request)
|
||||
assert json.loads(stop_response.text) == {"status": "stopping"}
|
||||
force_response = await handler.force_download_example_images(request)
|
||||
assert json.loads(force_response.text) == {"status": "force", "payload": {"foo": "bar"}}
|
||||
|
||||
@@ -350,6 +364,7 @@ async def test_download_handler_methods_delegate() -> None:
|
||||
("get_status", request),
|
||||
("pause_download", request),
|
||||
("resume_download", request),
|
||||
("stop_download", request),
|
||||
("start_force_download", {"foo": "bar"}),
|
||||
]
|
||||
|
||||
@@ -460,6 +475,7 @@ def test_handler_set_route_mapping_includes_all_handlers() -> None:
|
||||
"get_example_images_status",
|
||||
"pause_example_images",
|
||||
"resume_example_images",
|
||||
"stop_example_images",
|
||||
"force_download_example_images",
|
||||
"import_example_images",
|
||||
"delete_example_image",
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from typing import Any, Dict
|
||||
|
||||
import pytest
|
||||
@@ -128,6 +129,59 @@ async def test_pause_and_resume_flow(monkeypatch: pytest.MonkeyPatch, tmp_path)
|
||||
await asyncio.wait_for(task, timeout=1)
|
||||
|
||||
|
||||
async def test_stop_download_transitions_to_stopped(monkeypatch: pytest.MonkeyPatch, tmp_path) -> None:
|
||||
settings_manager = get_settings_manager()
|
||||
settings_manager.settings["example_images_path"] = str(tmp_path)
|
||||
settings_manager.settings["libraries"] = {"default": {}}
|
||||
settings_manager.settings["active_library"] = "default"
|
||||
|
||||
ws_manager = RecordingWebSocketManager()
|
||||
manager = download_module.DownloadManager(ws_manager=ws_manager)
|
||||
|
||||
started = asyncio.Event()
|
||||
release = asyncio.Event()
|
||||
|
||||
async def fake_download(self, *_args):
|
||||
started.set()
|
||||
await release.wait()
|
||||
async with self._state_lock:
|
||||
if self._stop_requested and self._progress['status'] == 'stopping':
|
||||
self._progress['status'] = 'stopped'
|
||||
else:
|
||||
self._progress['status'] = 'completed'
|
||||
self._progress['end_time'] = time.time()
|
||||
self._stop_requested = False
|
||||
await self._broadcast_progress(status=self._progress['status'])
|
||||
async with self._state_lock:
|
||||
self._is_downloading = False
|
||||
self._download_task = None
|
||||
|
||||
monkeypatch.setattr(
|
||||
download_module.DownloadManager,
|
||||
"_download_all_example_images",
|
||||
fake_download,
|
||||
)
|
||||
|
||||
await manager.start_download({})
|
||||
await asyncio.wait_for(started.wait(), timeout=1)
|
||||
|
||||
stop_response = await manager.stop_download(object())
|
||||
assert stop_response == {"success": True, "message": "Download stopping"}
|
||||
assert manager._progress["status"] == "stopping"
|
||||
|
||||
task = manager._download_task
|
||||
assert task is not None
|
||||
release.set()
|
||||
await asyncio.wait_for(task, timeout=1)
|
||||
|
||||
assert manager._progress["status"] == "stopped"
|
||||
assert manager._is_downloading is False
|
||||
assert manager._stop_requested is False
|
||||
statuses = [payload["status"] for payload in ws_manager.payloads]
|
||||
assert "stopping" in statuses
|
||||
assert "stopped" in statuses
|
||||
|
||||
|
||||
async def test_pause_or_resume_without_running_download(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
manager = download_module.DownloadManager(ws_manager=RecordingWebSocketManager())
|
||||
|
||||
@@ -136,3 +190,6 @@ async def test_pause_or_resume_without_running_download(monkeypatch: pytest.Monk
|
||||
|
||||
with pytest.raises(download_module.DownloadNotRunningError):
|
||||
await manager.resume_download(object())
|
||||
|
||||
with pytest.raises(download_module.DownloadNotRunningError):
|
||||
await manager.stop_download(object())
|
||||
|
||||
Reference in New Issue
Block a user