feat(download): add experimental aria2 backend

This commit is contained in:
Will Miao
2026-04-19 21:46:09 +08:00
parent 0ced53c059
commit 1c530ea013
21 changed files with 1867 additions and 28 deletions

View File

@@ -179,6 +179,7 @@ async def test_successful_download_uses_defaults(
progress_callback,
model_type,
download_id,
transfer_backend=None,
):
captured.update(
{
@@ -268,6 +269,7 @@ async def test_download_uses_active_mirrors(
progress_callback,
model_type,
download_id,
transfer_backend=None,
):
captured["download_urls"] = download_urls
return {"success": True}
@@ -288,6 +290,214 @@ async def test_download_uses_active_mirrors(
assert captured["download_urls"] == ["https://mirror.example/file.safetensors"]
@pytest.mark.asyncio
async def test_pause_resume_cancel_delegate_to_aria2_backend(monkeypatch):
manager = DownloadManager()
task = asyncio.create_task(asyncio.sleep(60))
manager._download_tasks["download-1"] = task
manager._pause_events["download-1"] = download_manager.DownloadStreamControl()
manager._active_downloads["download-1"] = {
"transfer_backend": "aria2",
"status": "downloading",
}
class DummyAria2Downloader:
def __init__(self):
self.calls = []
async def pause_download(self, download_id):
self.calls.append(("pause", download_id))
return {"success": True, "message": "paused"}
async def resume_download(self, download_id):
self.calls.append(("resume", download_id))
return {"success": True, "message": "resumed"}
async def cancel_download(self, download_id):
self.calls.append(("cancel", download_id))
return {"success": True, "message": "cancelled"}
async def has_transfer(self, download_id):
self.calls.append(("has_transfer", download_id))
return True
dummy_aria2 = DummyAria2Downloader()
monkeypatch.setattr(
download_manager,
"get_aria2_downloader",
AsyncMock(return_value=dummy_aria2),
)
pause_result = await manager.pause_download("download-1")
assert pause_result["success"] is True
assert manager._active_downloads["download-1"]["status"] == "paused"
resume_result = await manager.resume_download("download-1")
assert resume_result["success"] is True
assert manager._active_downloads["download-1"]["status"] == "downloading"
cancel_result = await manager.cancel_download("download-1")
assert cancel_result["success"] is True
assert task.cancelled() or task.done()
assert dummy_aria2.calls == [
("has_transfer", "download-1"),
("pause", "download-1"),
("has_transfer", "download-1"),
("resume", "download-1"),
("cancel", "download-1"),
]
@pytest.mark.asyncio
async def test_cancel_allows_queued_aria2_task_without_transfer(monkeypatch):
manager = DownloadManager()
started = asyncio.Event()
async def blocked_task():
started.set()
await asyncio.sleep(60)
task = asyncio.create_task(blocked_task())
await started.wait()
manager._download_tasks["download-queued"] = task
manager._active_downloads["download-queued"] = {
"transfer_backend": "aria2",
"status": "queued",
}
class DummyAria2Downloader:
async def cancel_download(self, download_id):
return {"success": False, "error": "Download task not found"}
monkeypatch.setattr(
download_manager,
"get_aria2_downloader",
AsyncMock(return_value=DummyAria2Downloader()),
)
result = await manager.cancel_download("download-queued")
assert result["success"] is True
assert task.cancelled() or task.done()
@pytest.mark.asyncio
async def test_pause_resume_queued_aria2_task_without_transfer(monkeypatch):
manager = DownloadManager()
task = asyncio.create_task(asyncio.sleep(60))
manager._download_tasks["download-queued"] = task
manager._pause_events["download-queued"] = download_manager.DownloadStreamControl()
manager._active_downloads["download-queued"] = {
"transfer_backend": "aria2",
"status": "waiting",
"bytes_per_second": 12.0,
}
class DummyAria2Downloader:
def __init__(self):
self.calls = []
async def has_transfer(self, download_id):
self.calls.append(("has_transfer", download_id))
return False
async def pause_download(self, download_id):
self.calls.append(("pause", download_id))
return {"success": True, "message": "paused"}
async def resume_download(self, download_id):
self.calls.append(("resume", download_id))
return {"success": True, "message": "resumed"}
dummy_aria2 = DummyAria2Downloader()
monkeypatch.setattr(
download_manager,
"get_aria2_downloader",
AsyncMock(return_value=dummy_aria2),
)
pause_result = await manager.pause_download("download-queued")
assert pause_result == {"success": True, "message": "Download paused successfully"}
assert manager._active_downloads["download-queued"]["status"] == "paused"
assert manager._pause_events["download-queued"].is_paused() is True
resume_result = await manager.resume_download("download-queued")
assert resume_result == {"success": True, "message": "Download resumed successfully"}
assert manager._active_downloads["download-queued"]["status"] == "downloading"
assert manager._pause_events["download-queued"].is_set() is True
assert dummy_aria2.calls == [
("has_transfer", "download-queued"),
("has_transfer", "download-queued"),
]
task.cancel()
with pytest.raises(asyncio.CancelledError):
await task
@pytest.mark.asyncio
async def test_download_uses_captured_backend_when_settings_change(
monkeypatch, scanners, metadata_provider, tmp_path
):
manager = DownloadManager()
settings = get_settings_manager()
settings.settings["download_backend"] = "aria2"
semaphore = asyncio.Semaphore(0)
manager._download_semaphore = semaphore
captured = {}
async def fake_execute_original_download(
self,
model_id,
model_version_id,
save_dir,
relative_path,
progress_callback,
use_default_paths,
download_id=None,
transfer_backend="python",
source=None,
file_params=None,
):
captured["transfer_backend"] = transfer_backend
return {"success": True}
monkeypatch.setattr(
DownloadManager,
"_execute_original_download",
fake_execute_original_download,
)
download_task = asyncio.create_task(
manager.download_from_civitai(
model_version_id=99,
save_dir=str(tmp_path),
use_default_paths=True,
progress_callback=None,
source=None,
)
)
await asyncio.sleep(0)
assert len(manager._active_downloads) == 1
download_id = next(iter(manager._active_downloads))
assert manager._active_downloads[download_id]["transfer_backend"] == "aria2"
settings.settings["download_backend"] = "python"
semaphore.release()
result = await download_task
assert result["success"] is True
assert captured["transfer_backend"] == "aria2"
@pytest.mark.asyncio
async def test_download_aborts_when_version_exists(
monkeypatch, scanners, metadata_provider