mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-21 21:22:11 -03:00
When force=true is passed via the API, models in the failed_models set are re-downloaded instead of being skipped. On a successful download, the model is removed from the failed_models set. This gives users a manual batch-repair mechanism for cases where the CivitAI media server is temporarily down and leaves empty folders behind.
Changes:
- Backend: Add force parameter to start_download(), _download_all_example_images(), _process_model()
- Backend: Skip the failed_models check when force=true
- Backend: Remove the model from failed_models on a successful force retry
- Frontend: GlobalContextMenu now calls the API with force=true directly
- Tests: Update the mock to accept the force parameter
206 lines
6.7 KiB
Python
206 lines
6.7 KiB
Python
from __future__ import annotations

import asyncio
import copy
import time
from typing import Any, Dict, Iterator

import pytest

from py.services.settings_manager import get_settings_manager
from py.utils import example_images_download_manager as download_module
|
|
|
|
|
|
class RecordingWebSocketManager:
    """Test double that records every payload handed to ``broadcast``.

    Nothing is sent anywhere; payloads are kept in call order in
    ``self.payloads`` so tests can inspect what would have been broadcast.
    """

    def __init__(self) -> None:
        """Start with an empty, ordered log of broadcast payloads."""
        recorded: list[Dict[str, Any]] = []
        self.payloads = recorded

    async def broadcast(self, payload: Dict[str, Any]) -> None:
        """Store ``payload`` (the same object, not a copy) at the end of the log."""
        self.payloads += [payload]
|
|
|
|
|
|
@pytest.fixture(autouse=True)
def restore_settings() -> Iterator[None]:
    """Snapshot the shared settings before each test and restore them after.

    The fixture is ``autouse`` so every test in this module may freely edit
    ``get_settings_manager().settings``; whatever it changes is rolled back
    once the test finishes, whether it passed or failed.

    Yields:
        Nothing; control is simply handed to the test body.
    """
    manager = get_settings_manager()
    # Deep copy rather than ``.copy()``: a shallow snapshot shares nested
    # containers (e.g. the ``libraries`` mapping) with the live settings, so
    # in-place mutation during a test would leak into the "restored" state.
    snapshot = copy.deepcopy(manager.settings)
    try:
        yield
    finally:
        # Restore in place so the settings mapping object itself is preserved
        # for anything that already holds a reference to it.
        manager.settings.clear()
        manager.settings.update(snapshot)
|
|
|
|
|
|
async def test_start_download_requires_configured_path(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Check ``start_download`` behaviour when no example-images path is set.

    An explicit start must raise ``DownloadConfigurationError`` whose message
    mentions "not configured"; a start with ``auto_mode`` enabled must instead
    report success with a message saying the auto download is being skipped.
    """
    manager = download_module.DownloadManager(ws_manager=RecordingWebSocketManager())

    # Make sure the path setting is absent, whatever the ambient configuration.
    get_settings_manager().settings.pop("example_images_path", None)

    # ``match`` searches the stringified exception, which is the same check as
    # testing for the substring in ``str(exc)``.
    with pytest.raises(
        download_module.DownloadConfigurationError, match="not configured"
    ):
        await manager.start_download({})

    auto_result = await manager.start_download({"auto_mode": True})
    assert auto_result["success"] is True
    assert "skipping auto download" in auto_result["message"]
|
|
|
|
|
|
async def test_start_download_bootstraps_progress_and_task(
    monkeypatch: pytest.MonkeyPatch, tmp_path
) -> None:
    """Check that ``start_download`` marks the manager busy and spawns a task.

    The real download coroutine is replaced with a stub that blocks on an
    event, so the test can observe the "running" state, then let the stub
    finish and verify the manager ends up idle with status "completed".
    """
    get_settings_manager().settings.update(
        {
            "example_images_path": str(tmp_path),
            "libraries": {"default": {}},
            "active_library": "default",
        }
    )

    recorder = RecordingWebSocketManager()
    manager = download_module.DownloadManager(ws_manager=recorder)

    entered = asyncio.Event()  # set once the stub has begun executing
    may_finish = asyncio.Event()  # set by the test to let the stub complete

    async def blocking_download(
        self, output_dir, optimize, model_types, delay, library_name, force=False
    ):
        # Stand-in for the real download: signal entry, wait for permission,
        # then reset the manager's state the way this test expects on completion.
        entered.set()
        await may_finish.wait()
        async with self._state_lock:
            self._is_downloading = False
            self._download_task = None
            self._progress["status"] = "completed"

    monkeypatch.setattr(
        download_module.DownloadManager,
        "_download_all_example_images",
        blocking_download,
    )

    start_result = await manager.start_download({"model_types": ["lora"], "delay": 0})
    assert start_result["success"] is True
    assert manager._is_downloading is True

    await asyncio.wait_for(entered.wait(), timeout=1)
    first_payload = recorder.payloads[0]
    assert first_payload["status"] == "running"

    background = manager._download_task
    assert background is not None
    may_finish.set()
    await asyncio.wait_for(background, timeout=1)
    assert manager._is_downloading is False
    assert manager._progress["status"] == "completed"
|
|
|
|
|
|
async def test_pause_and_resume_flow(monkeypatch: pytest.MonkeyPatch, tmp_path) -> None:
    """Check that a running download can be paused and then resumed.

    The real download coroutine is replaced with a stub that blocks on an
    event so the download stays "in flight" while ``pause_download`` and
    ``resume_download`` are exercised. Each call must return its success
    payload and move ``_progress["status"]`` to "paused" / "running".
    """
    settings_manager = get_settings_manager()
    settings_manager.settings["example_images_path"] = str(tmp_path)
    settings_manager.settings["libraries"] = {"default": {}}
    settings_manager.settings["active_library"] = "default"

    ws_manager = RecordingWebSocketManager()
    manager = download_module.DownloadManager(ws_manager=ws_manager)

    started = asyncio.Event()  # set once the stub has begun executing
    release = asyncio.Event()  # set by the test to let the stub complete

    # Accept keyword arguments as well as positional ones: the sibling mock in
    # this module takes ``force`` by name, and a positional-only ``*_args``
    # stub would raise TypeError inside the background task if the manager
    # ever passes ``force=...`` as a keyword.
    async def fake_download(self, *_args, **_kwargs):
        started.set()
        await release.wait()
        async with self._state_lock:
            self._is_downloading = False
            self._download_task = None

    monkeypatch.setattr(
        download_module.DownloadManager,
        "_download_all_example_images",
        fake_download,
    )

    await manager.start_download({})
    await asyncio.wait_for(started.wait(), timeout=1)

    # A bare ``object()`` is passed as the argument; presumably it stands in
    # for a request object that these methods do not inspect.
    pause_response = await manager.pause_download(object())
    assert pause_response == {"success": True, "message": "Download paused"}
    assert manager._progress["status"] == "paused"

    resume_response = await manager.resume_download(object())
    assert resume_response == {"success": True, "message": "Download resumed"}
    assert manager._progress["status"] == "running"

    # Let the stub finish so no task is left pending when the test ends.
    task = manager._download_task
    assert task is not None
    release.set()
    await asyncio.wait_for(task, timeout=1)
|
|
|
|
|
|
async def test_stop_download_transitions_to_stopped(
    monkeypatch: pytest.MonkeyPatch, tmp_path
) -> None:
    """Check the stop flow: "stopping" on request, "stopped" once the task ends.

    The real download coroutine is replaced with a stub that blocks on an
    event and, when released, finalises the progress state: "stopped" if a
    stop was requested while the status was "stopping", otherwise
    "completed". The test asserts both the final manager state and that the
    "stopping" and "stopped" statuses were broadcast.
    """
    settings_manager = get_settings_manager()
    settings_manager.settings["example_images_path"] = str(tmp_path)
    settings_manager.settings["libraries"] = {"default": {}}
    settings_manager.settings["active_library"] = "default"

    ws_manager = RecordingWebSocketManager()
    manager = download_module.DownloadManager(ws_manager=ws_manager)

    started = asyncio.Event()  # set once the stub has begun executing
    release = asyncio.Event()  # set by the test to let the stub complete

    # Accept keyword arguments as well as positional ones: the sibling mock in
    # this module takes ``force`` by name, and a positional-only ``*_args``
    # stub would raise TypeError inside the background task if the manager
    # ever passes ``force=...`` as a keyword.
    async def fake_download(self, *_args, **_kwargs):
        started.set()
        await release.wait()
        async with self._state_lock:
            if self._stop_requested and self._progress["status"] == "stopping":
                self._progress["status"] = "stopped"
            else:
                self._progress["status"] = "completed"
            self._progress["end_time"] = time.time()
            self._stop_requested = False
        # Broadcast after leaving the first ``async with`` block, then take
        # the lock again to clear the running flags.
        await self._broadcast_progress(status=self._progress["status"])
        async with self._state_lock:
            self._is_downloading = False
            self._download_task = None

    monkeypatch.setattr(
        download_module.DownloadManager,
        "_download_all_example_images",
        fake_download,
    )

    await manager.start_download({})
    await asyncio.wait_for(started.wait(), timeout=1)

    # A bare ``object()`` is passed as the argument; presumably it stands in
    # for a request object that ``stop_download`` does not inspect.
    stop_response = await manager.stop_download(object())
    assert stop_response == {"success": True, "message": "Download stopping"}
    assert manager._progress["status"] == "stopping"

    # Let the stub run to completion so it can finalise the state.
    task = manager._download_task
    assert task is not None
    release.set()
    await asyncio.wait_for(task, timeout=1)

    assert manager._progress["status"] == "stopped"
    assert manager._is_downloading is False
    assert manager._stop_requested is False
    statuses = [payload["status"] for payload in ws_manager.payloads]
    assert "stopping" in statuses
    assert "stopped" in statuses
|
|
|
|
|
|
async def test_pause_or_resume_without_running_download(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Check that control calls fail when no download has been started.

    On a freshly constructed manager, ``pause_download``, ``resume_download``
    and ``stop_download`` must each raise ``DownloadNotRunningError``.
    """
    manager = download_module.DownloadManager(ws_manager=RecordingWebSocketManager())

    # Same order as the individual checks: pause, resume, then stop.
    for control_name in ("pause_download", "resume_download", "stop_download"):
        with pytest.raises(download_module.DownloadNotRunningError):
            await getattr(manager, control_name)(object())
|