mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-05-06 16:36:45 -03:00
feat(download): add experimental aria2 backend
This commit is contained in:
@@ -179,6 +179,7 @@ async def test_successful_download_uses_defaults(
|
||||
progress_callback,
|
||||
model_type,
|
||||
download_id,
|
||||
transfer_backend=None,
|
||||
):
|
||||
captured.update(
|
||||
{
|
||||
@@ -268,6 +269,7 @@ async def test_download_uses_active_mirrors(
|
||||
progress_callback,
|
||||
model_type,
|
||||
download_id,
|
||||
transfer_backend=None,
|
||||
):
|
||||
captured["download_urls"] = download_urls
|
||||
return {"success": True}
|
||||
@@ -288,6 +290,214 @@ async def test_download_uses_active_mirrors(
|
||||
assert captured["download_urls"] == ["https://mirror.example/file.safetensors"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pause_resume_cancel_delegate_to_aria2_backend(monkeypatch):
|
||||
manager = DownloadManager()
|
||||
|
||||
task = asyncio.create_task(asyncio.sleep(60))
|
||||
manager._download_tasks["download-1"] = task
|
||||
manager._pause_events["download-1"] = download_manager.DownloadStreamControl()
|
||||
manager._active_downloads["download-1"] = {
|
||||
"transfer_backend": "aria2",
|
||||
"status": "downloading",
|
||||
}
|
||||
|
||||
class DummyAria2Downloader:
|
||||
def __init__(self):
|
||||
self.calls = []
|
||||
|
||||
async def pause_download(self, download_id):
|
||||
self.calls.append(("pause", download_id))
|
||||
return {"success": True, "message": "paused"}
|
||||
|
||||
async def resume_download(self, download_id):
|
||||
self.calls.append(("resume", download_id))
|
||||
return {"success": True, "message": "resumed"}
|
||||
|
||||
async def cancel_download(self, download_id):
|
||||
self.calls.append(("cancel", download_id))
|
||||
return {"success": True, "message": "cancelled"}
|
||||
|
||||
async def has_transfer(self, download_id):
|
||||
self.calls.append(("has_transfer", download_id))
|
||||
return True
|
||||
|
||||
dummy_aria2 = DummyAria2Downloader()
|
||||
monkeypatch.setattr(
|
||||
download_manager,
|
||||
"get_aria2_downloader",
|
||||
AsyncMock(return_value=dummy_aria2),
|
||||
)
|
||||
|
||||
pause_result = await manager.pause_download("download-1")
|
||||
assert pause_result["success"] is True
|
||||
assert manager._active_downloads["download-1"]["status"] == "paused"
|
||||
|
||||
resume_result = await manager.resume_download("download-1")
|
||||
assert resume_result["success"] is True
|
||||
assert manager._active_downloads["download-1"]["status"] == "downloading"
|
||||
|
||||
cancel_result = await manager.cancel_download("download-1")
|
||||
assert cancel_result["success"] is True
|
||||
assert task.cancelled() or task.done()
|
||||
assert dummy_aria2.calls == [
|
||||
("has_transfer", "download-1"),
|
||||
("pause", "download-1"),
|
||||
("has_transfer", "download-1"),
|
||||
("resume", "download-1"),
|
||||
("cancel", "download-1"),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cancel_allows_queued_aria2_task_without_transfer(monkeypatch):
|
||||
manager = DownloadManager()
|
||||
|
||||
started = asyncio.Event()
|
||||
|
||||
async def blocked_task():
|
||||
started.set()
|
||||
await asyncio.sleep(60)
|
||||
|
||||
task = asyncio.create_task(blocked_task())
|
||||
await started.wait()
|
||||
|
||||
manager._download_tasks["download-queued"] = task
|
||||
manager._active_downloads["download-queued"] = {
|
||||
"transfer_backend": "aria2",
|
||||
"status": "queued",
|
||||
}
|
||||
|
||||
class DummyAria2Downloader:
|
||||
async def cancel_download(self, download_id):
|
||||
return {"success": False, "error": "Download task not found"}
|
||||
|
||||
monkeypatch.setattr(
|
||||
download_manager,
|
||||
"get_aria2_downloader",
|
||||
AsyncMock(return_value=DummyAria2Downloader()),
|
||||
)
|
||||
|
||||
result = await manager.cancel_download("download-queued")
|
||||
|
||||
assert result["success"] is True
|
||||
assert task.cancelled() or task.done()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pause_resume_queued_aria2_task_without_transfer(monkeypatch):
|
||||
manager = DownloadManager()
|
||||
|
||||
task = asyncio.create_task(asyncio.sleep(60))
|
||||
manager._download_tasks["download-queued"] = task
|
||||
manager._pause_events["download-queued"] = download_manager.DownloadStreamControl()
|
||||
manager._active_downloads["download-queued"] = {
|
||||
"transfer_backend": "aria2",
|
||||
"status": "waiting",
|
||||
"bytes_per_second": 12.0,
|
||||
}
|
||||
|
||||
class DummyAria2Downloader:
|
||||
def __init__(self):
|
||||
self.calls = []
|
||||
|
||||
async def has_transfer(self, download_id):
|
||||
self.calls.append(("has_transfer", download_id))
|
||||
return False
|
||||
|
||||
async def pause_download(self, download_id):
|
||||
self.calls.append(("pause", download_id))
|
||||
return {"success": True, "message": "paused"}
|
||||
|
||||
async def resume_download(self, download_id):
|
||||
self.calls.append(("resume", download_id))
|
||||
return {"success": True, "message": "resumed"}
|
||||
|
||||
dummy_aria2 = DummyAria2Downloader()
|
||||
monkeypatch.setattr(
|
||||
download_manager,
|
||||
"get_aria2_downloader",
|
||||
AsyncMock(return_value=dummy_aria2),
|
||||
)
|
||||
|
||||
pause_result = await manager.pause_download("download-queued")
|
||||
assert pause_result == {"success": True, "message": "Download paused successfully"}
|
||||
assert manager._active_downloads["download-queued"]["status"] == "paused"
|
||||
assert manager._pause_events["download-queued"].is_paused() is True
|
||||
|
||||
resume_result = await manager.resume_download("download-queued")
|
||||
assert resume_result == {"success": True, "message": "Download resumed successfully"}
|
||||
assert manager._active_downloads["download-queued"]["status"] == "downloading"
|
||||
assert manager._pause_events["download-queued"].is_set() is True
|
||||
assert dummy_aria2.calls == [
|
||||
("has_transfer", "download-queued"),
|
||||
("has_transfer", "download-queued"),
|
||||
]
|
||||
|
||||
task.cancel()
|
||||
with pytest.raises(asyncio.CancelledError):
|
||||
await task
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_download_uses_captured_backend_when_settings_change(
|
||||
monkeypatch, scanners, metadata_provider, tmp_path
|
||||
):
|
||||
manager = DownloadManager()
|
||||
settings = get_settings_manager()
|
||||
settings.settings["download_backend"] = "aria2"
|
||||
|
||||
semaphore = asyncio.Semaphore(0)
|
||||
manager._download_semaphore = semaphore
|
||||
|
||||
captured = {}
|
||||
|
||||
async def fake_execute_original_download(
|
||||
self,
|
||||
model_id,
|
||||
model_version_id,
|
||||
save_dir,
|
||||
relative_path,
|
||||
progress_callback,
|
||||
use_default_paths,
|
||||
download_id=None,
|
||||
transfer_backend="python",
|
||||
source=None,
|
||||
file_params=None,
|
||||
):
|
||||
captured["transfer_backend"] = transfer_backend
|
||||
return {"success": True}
|
||||
|
||||
monkeypatch.setattr(
|
||||
DownloadManager,
|
||||
"_execute_original_download",
|
||||
fake_execute_original_download,
|
||||
)
|
||||
|
||||
download_task = asyncio.create_task(
|
||||
manager.download_from_civitai(
|
||||
model_version_id=99,
|
||||
save_dir=str(tmp_path),
|
||||
use_default_paths=True,
|
||||
progress_callback=None,
|
||||
source=None,
|
||||
)
|
||||
)
|
||||
|
||||
await asyncio.sleep(0)
|
||||
assert len(manager._active_downloads) == 1
|
||||
download_id = next(iter(manager._active_downloads))
|
||||
assert manager._active_downloads[download_id]["transfer_backend"] == "aria2"
|
||||
|
||||
settings.settings["download_backend"] = "python"
|
||||
semaphore.release()
|
||||
|
||||
result = await download_task
|
||||
|
||||
assert result["success"] is True
|
||||
assert captured["transfer_backend"] == "aria2"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_download_aborts_when_version_exists(
|
||||
monkeypatch, scanners, metadata_provider
|
||||
|
||||
Reference in New Issue
Block a user