mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-05-06 16:36:45 -03:00
feat(download): add experimental aria2 backend
This commit is contained in:
269
tests/services/test_aria2_downloader.py
Normal file
269
tests/services/test_aria2_downloader.py
Normal file
@@ -0,0 +1,269 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
|
||||
from py.services.aria2_downloader import Aria2Downloader, Aria2Error
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_download_file_polls_until_complete(tmp_path, monkeypatch):
|
||||
downloader = Aria2Downloader()
|
||||
downloader._rpc_url = "http://127.0.0.1/jsonrpc"
|
||||
downloader._rpc_secret = "secret"
|
||||
|
||||
save_path = tmp_path / "downloads" / "model.safetensors"
|
||||
progress_events = []
|
||||
rpc_calls = []
|
||||
statuses = iter(
|
||||
[
|
||||
{
|
||||
"gid": "gid-1",
|
||||
"status": "active",
|
||||
"completedLength": "5",
|
||||
"totalLength": "10",
|
||||
"downloadSpeed": "25",
|
||||
},
|
||||
{
|
||||
"gid": "gid-1",
|
||||
"status": "complete",
|
||||
"completedLength": "10",
|
||||
"totalLength": "10",
|
||||
"downloadSpeed": "0",
|
||||
"files": [{"path": str(save_path)}],
|
||||
},
|
||||
]
|
||||
)
|
||||
|
||||
async def fake_rpc_call(method, params):
|
||||
rpc_calls.append((method, params))
|
||||
if method == "aria2.addUri":
|
||||
return "gid-1"
|
||||
if method == "aria2.tellStatus":
|
||||
return next(statuses)
|
||||
raise AssertionError(f"Unexpected RPC method: {method}")
|
||||
|
||||
monkeypatch.setattr(downloader, "_ensure_process", AsyncMock())
|
||||
monkeypatch.setattr(
|
||||
downloader,
|
||||
"_resolve_authenticated_redirect_url",
|
||||
AsyncMock(
|
||||
return_value="https://signed.example.com/model.safetensors?token=abc"
|
||||
),
|
||||
)
|
||||
monkeypatch.setattr(downloader, "_rpc_call", fake_rpc_call)
|
||||
monkeypatch.setattr("py.services.aria2_downloader.asyncio.sleep", AsyncMock())
|
||||
|
||||
async def progress_callback(progress, snapshot=None):
|
||||
progress_events.append(snapshot.percent_complete if snapshot else progress)
|
||||
|
||||
success, result = await downloader.download_file(
|
||||
"https://civitai.com/api/download/models/123",
|
||||
str(save_path),
|
||||
download_id="download-1",
|
||||
progress_callback=progress_callback,
|
||||
headers={"Authorization": "Bearer token"},
|
||||
)
|
||||
|
||||
assert success is True
|
||||
assert result == str(save_path)
|
||||
assert progress_events == [50.0, 100.0]
|
||||
assert downloader._transfers == {}
|
||||
assert rpc_calls[0][0] == "aria2.addUri"
|
||||
assert rpc_calls[0][1][0] == [
|
||||
"https://signed.example.com/model.safetensors?token=abc"
|
||||
]
|
||||
assert rpc_calls[0][1][1]["out"] == "model.safetensors"
|
||||
assert "header" not in rpc_calls[0][1][1]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_download_file_keeps_auth_headers_when_civitai_does_not_redirect(
|
||||
tmp_path, monkeypatch
|
||||
):
|
||||
downloader = Aria2Downloader()
|
||||
downloader._rpc_url = "http://127.0.0.1/jsonrpc"
|
||||
downloader._rpc_secret = "secret"
|
||||
|
||||
save_path = tmp_path / "downloads" / "model.safetensors"
|
||||
rpc_calls = []
|
||||
statuses = iter(
|
||||
[
|
||||
{
|
||||
"gid": "gid-1",
|
||||
"status": "complete",
|
||||
"completedLength": "10",
|
||||
"totalLength": "10",
|
||||
"downloadSpeed": "0",
|
||||
"files": [{"path": str(save_path)}],
|
||||
},
|
||||
]
|
||||
)
|
||||
|
||||
async def fake_rpc_call(method, params):
|
||||
rpc_calls.append((method, params))
|
||||
if method == "aria2.addUri":
|
||||
return "gid-1"
|
||||
if method == "aria2.tellStatus":
|
||||
return next(statuses)
|
||||
raise AssertionError(f"Unexpected RPC method: {method}")
|
||||
|
||||
monkeypatch.setattr(downloader, "_ensure_process", AsyncMock())
|
||||
monkeypatch.setattr(
|
||||
downloader,
|
||||
"_resolve_authenticated_redirect_url",
|
||||
AsyncMock(return_value="https://civitai.com/api/download/models/123"),
|
||||
)
|
||||
monkeypatch.setattr(downloader, "_rpc_call", fake_rpc_call)
|
||||
monkeypatch.setattr("py.services.aria2_downloader.asyncio.sleep", AsyncMock())
|
||||
|
||||
success, result = await downloader.download_file(
|
||||
"https://civitai.com/api/download/models/123",
|
||||
str(save_path),
|
||||
download_id="download-1",
|
||||
headers={"Authorization": "Bearer token"},
|
||||
)
|
||||
|
||||
assert success is True
|
||||
assert result == str(save_path)
|
||||
assert rpc_calls[0][1][0] == ["https://civitai.com/api/download/models/123"]
|
||||
assert rpc_calls[0][1][1]["header"] == ["Authorization: Bearer token"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pause_resume_cancel_forward_to_rpc(monkeypatch):
|
||||
downloader = Aria2Downloader()
|
||||
downloader._transfers["download-1"] = type(
|
||||
"Transfer", (), {"gid": "gid-1", "save_path": "/tmp/model.safetensors"}
|
||||
)()
|
||||
|
||||
calls = []
|
||||
|
||||
async def fake_rpc_call(method, params):
|
||||
calls.append((method, params))
|
||||
return "gid-1"
|
||||
|
||||
monkeypatch.setattr(downloader, "_rpc_call", fake_rpc_call)
|
||||
|
||||
pause_result = await downloader.pause_download("download-1")
|
||||
resume_result = await downloader.resume_download("download-1")
|
||||
cancel_result = await downloader.cancel_download("download-1")
|
||||
|
||||
assert pause_result["success"] is True
|
||||
assert resume_result["success"] is True
|
||||
assert cancel_result["success"] is True
|
||||
assert calls == [
|
||||
("aria2.forcePause", ["gid-1"]),
|
||||
("aria2.unpause", ["gid-1"]),
|
||||
("aria2.forceRemove", ["gid-1"]),
|
||||
]
|
||||
|
||||
|
||||
def test_build_progress_snapshot_normalizes_numeric_fields():
|
||||
downloader = Aria2Downloader()
|
||||
|
||||
snapshot = downloader._build_progress_snapshot(
|
||||
{
|
||||
"completedLength": "75",
|
||||
"totalLength": "100",
|
||||
"downloadSpeed": "512",
|
||||
}
|
||||
)
|
||||
|
||||
assert snapshot.percent_complete == 75.0
|
||||
assert snapshot.bytes_downloaded == 75
|
||||
assert snapshot.total_bytes == 100
|
||||
assert snapshot.bytes_per_second == 512.0
|
||||
|
||||
|
||||
def test_resolve_executable_raises_when_binary_missing(monkeypatch):
|
||||
downloader = Aria2Downloader()
|
||||
settings = type("Settings", (), {"get": lambda self, key, default=None: ""})()
|
||||
|
||||
monkeypatch.setattr("py.services.aria2_downloader.get_settings_manager", lambda: settings)
|
||||
monkeypatch.setattr("py.services.aria2_downloader.shutil.which", lambda _: None)
|
||||
|
||||
with pytest.raises(Aria2Error):
|
||||
downloader._resolve_executable()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rpc_call_surfaces_json_error_on_non_200(monkeypatch):
|
||||
downloader = Aria2Downloader()
|
||||
downloader._rpc_url = "http://127.0.0.1:6800/jsonrpc"
|
||||
downloader._rpc_secret = "secret"
|
||||
|
||||
class FakeResponse:
|
||||
status = 400
|
||||
|
||||
async def text(self):
|
||||
return (
|
||||
'{"jsonrpc":"2.0","id":"x","error":{"code":1,"message":"Unauthorized"}}'
|
||||
)
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
class FakeSession:
|
||||
def post(self, _url, json=None):
|
||||
return FakeResponse()
|
||||
|
||||
monkeypatch.setattr(downloader, "_get_rpc_session", AsyncMock(return_value=FakeSession()))
|
||||
|
||||
with pytest.raises(Aria2Error) as exc_info:
|
||||
await downloader._rpc_call("aria2.addUri", [["https://example.com/file"]])
|
||||
|
||||
assert "Unauthorized" in str(exc_info.value)
|
||||
assert "aria2.addUri" in str(exc_info.value)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resolve_authenticated_redirect_url_returns_location(monkeypatch):
|
||||
downloader = Aria2Downloader()
|
||||
|
||||
class FakeResponse:
|
||||
status = 307
|
||||
headers = {"Location": "https://signed.example.com/file.safetensors"}
|
||||
|
||||
async def text(self):
|
||||
return ""
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
class FakeSession:
|
||||
def get(self, _url, headers=None, allow_redirects=False, proxy=None):
|
||||
return FakeResponse()
|
||||
|
||||
class FakeDownloader:
|
||||
default_headers = {"User-Agent": "ComfyUI-LoRA-Manager/1.0"}
|
||||
proxy_url = None
|
||||
|
||||
@property
|
||||
def session(self):
|
||||
async def _session():
|
||||
return FakeSession()
|
||||
|
||||
return _session()
|
||||
|
||||
fake_downloader = FakeDownloader()
|
||||
|
||||
monkeypatch.setattr(
|
||||
"py.services.aria2_downloader.get_downloader",
|
||||
AsyncMock(return_value=fake_downloader),
|
||||
)
|
||||
|
||||
result = await downloader._resolve_authenticated_redirect_url(
|
||||
"https://civitai.com/api/download/models/123",
|
||||
{"Authorization": "Bearer token"},
|
||||
)
|
||||
|
||||
assert result == "https://signed.example.com/file.safetensors"
|
||||
Reference in New Issue
Block a user