Files
ComfyUI-Lora-Manager/tests/services/test_aria2_downloader.py
2026-04-20 09:52:48 +08:00

355 lines
11 KiB
Python

from __future__ import annotations
import asyncio
from pathlib import Path
from unittest.mock import AsyncMock
import pytest
from py.services.aria2_downloader import Aria2Downloader, Aria2Error
from py.services.aria2_transfer_state import Aria2TransferStateStore
from py.services import aria2_transfer_state
@pytest.fixture(autouse=True)
def isolate_aria2_state(monkeypatch, tmp_path):
state_path = tmp_path / "cache" / "aria2" / "downloads.json"
monkeypatch.setattr(
aria2_transfer_state,
"get_aria2_state_path",
lambda: str(state_path),
)
@pytest.mark.asyncio
async def test_download_file_polls_until_complete(tmp_path, monkeypatch):
downloader = Aria2Downloader()
downloader._rpc_url = "http://127.0.0.1/jsonrpc"
downloader._rpc_secret = "secret"
save_path = tmp_path / "downloads" / "model.safetensors"
progress_events = []
rpc_calls = []
statuses = iter(
[
{
"gid": "gid-1",
"status": "active",
"completedLength": "5",
"totalLength": "10",
"downloadSpeed": "25",
},
{
"gid": "gid-1",
"status": "complete",
"completedLength": "10",
"totalLength": "10",
"downloadSpeed": "0",
"files": [{"path": str(save_path)}],
},
]
)
async def fake_rpc_call(method, params):
rpc_calls.append((method, params))
if method == "aria2.addUri":
return "gid-1"
if method == "aria2.tellStatus":
return next(statuses)
raise AssertionError(f"Unexpected RPC method: {method}")
monkeypatch.setattr(downloader, "_ensure_process", AsyncMock())
monkeypatch.setattr(
downloader,
"_resolve_authenticated_redirect_url",
AsyncMock(
return_value="https://signed.example.com/model.safetensors?token=abc"
),
)
monkeypatch.setattr(downloader, "_rpc_call", fake_rpc_call)
monkeypatch.setattr("py.services.aria2_downloader.asyncio.sleep", AsyncMock())
async def progress_callback(progress, snapshot=None):
progress_events.append(snapshot.percent_complete if snapshot else progress)
success, result = await downloader.download_file(
"https://civitai.com/api/download/models/123",
str(save_path),
download_id="download-1",
progress_callback=progress_callback,
headers={"Authorization": "Bearer token"},
)
assert success is True
assert result == str(save_path)
assert progress_events == [50.0, 100.0]
assert downloader._transfers == {}
assert rpc_calls[0][0] == "aria2.addUri"
assert rpc_calls[0][1][0] == [
"https://signed.example.com/model.safetensors?token=abc"
]
assert rpc_calls[0][1][1]["out"] == "model.safetensors"
assert "header" not in rpc_calls[0][1][1]
@pytest.mark.asyncio
async def test_transfer_state_store_shares_lock_and_preserves_concurrent_updates(tmp_path):
state_path = tmp_path / "cache" / "aria2" / "downloads.json"
store_a = Aria2TransferStateStore(str(state_path))
store_b = Aria2TransferStateStore(str(state_path))
assert store_a._lock is store_b._lock
await asyncio.gather(
store_a.upsert("download-1", {"status": "downloading", "gid": "gid-1"}),
store_b.upsert("download-2", {"status": "paused", "gid": "gid-2"}),
)
assert await store_a.get("download-1") == {"status": "downloading", "gid": "gid-1"}
assert await store_b.get("download-2") == {"status": "paused", "gid": "gid-2"}
@pytest.mark.asyncio
async def test_download_file_keeps_auth_headers_when_civitai_does_not_redirect(
tmp_path, monkeypatch
):
downloader = Aria2Downloader()
downloader._rpc_url = "http://127.0.0.1/jsonrpc"
downloader._rpc_secret = "secret"
save_path = tmp_path / "downloads" / "model.safetensors"
rpc_calls = []
statuses = iter(
[
{
"gid": "gid-1",
"status": "complete",
"completedLength": "10",
"totalLength": "10",
"downloadSpeed": "0",
"files": [{"path": str(save_path)}],
},
]
)
async def fake_rpc_call(method, params):
rpc_calls.append((method, params))
if method == "aria2.addUri":
return "gid-1"
if method == "aria2.tellStatus":
return next(statuses)
raise AssertionError(f"Unexpected RPC method: {method}")
monkeypatch.setattr(downloader, "_ensure_process", AsyncMock())
monkeypatch.setattr(
downloader,
"_resolve_authenticated_redirect_url",
AsyncMock(return_value="https://civitai.com/api/download/models/123"),
)
monkeypatch.setattr(downloader, "_rpc_call", fake_rpc_call)
monkeypatch.setattr("py.services.aria2_downloader.asyncio.sleep", AsyncMock())
success, result = await downloader.download_file(
"https://civitai.com/api/download/models/123",
str(save_path),
download_id="download-1",
headers={"Authorization": "Bearer token"},
)
assert success is True
assert result == str(save_path)
assert rpc_calls[0][1][0] == ["https://civitai.com/api/download/models/123"]
assert rpc_calls[0][1][1]["header"] == ["Authorization: Bearer token"]
@pytest.mark.asyncio
async def test_pause_resume_cancel_forward_to_rpc(monkeypatch):
downloader = Aria2Downloader()
downloader._transfers["download-1"] = type(
"Transfer", (), {"gid": "gid-1", "save_path": "/tmp/model.safetensors"}
)()
calls = []
async def fake_rpc_call(method, params):
calls.append((method, params))
return "gid-1"
monkeypatch.setattr(downloader, "_rpc_call", fake_rpc_call)
pause_result = await downloader.pause_download("download-1")
resume_result = await downloader.resume_download("download-1")
cancel_result = await downloader.cancel_download("download-1")
assert pause_result["success"] is True
assert resume_result["success"] is True
assert cancel_result["success"] is True
assert calls == [
("aria2.forcePause", ["gid-1"]),
("aria2.unpause", ["gid-1"]),
("aria2.forceRemove", ["gid-1"]),
]
@pytest.mark.asyncio
async def test_download_file_reuses_existing_transfer_without_add_uri(
tmp_path, monkeypatch
):
downloader = Aria2Downloader()
downloader._rpc_url = "http://127.0.0.1/jsonrpc"
downloader._rpc_secret = "secret"
save_path = tmp_path / "downloads" / "model.safetensors"
downloader._transfers["download-1"] = type(
"Transfer", (), {"gid": "gid-1", "save_path": str(save_path)}
)()
rpc_calls = []
statuses = iter(
[
{
"gid": "gid-1",
"status": "active",
"completedLength": "5",
"totalLength": "10",
"downloadSpeed": "25",
},
{
"gid": "gid-1",
"status": "complete",
"completedLength": "10",
"totalLength": "10",
"downloadSpeed": "0",
"files": [{"path": str(save_path)}],
},
]
)
async def fake_rpc_call(method, params):
rpc_calls.append((method, params))
if method == "aria2.tellStatus":
return next(statuses)
raise AssertionError(f"Unexpected RPC method: {method}")
monkeypatch.setattr(downloader, "_ensure_process", AsyncMock())
monkeypatch.setattr(downloader, "_rpc_call", fake_rpc_call)
monkeypatch.setattr("py.services.aria2_downloader.asyncio.sleep", AsyncMock())
success, result = await downloader.download_file(
"https://example.com/model.safetensors",
str(save_path),
download_id="download-1",
)
assert success is True
assert result == str(save_path)
assert [call[0] for call in rpc_calls] == ["aria2.tellStatus", "aria2.tellStatus"]
def test_build_progress_snapshot_normalizes_numeric_fields():
downloader = Aria2Downloader()
snapshot = downloader._build_progress_snapshot(
{
"completedLength": "75",
"totalLength": "100",
"downloadSpeed": "512",
}
)
assert snapshot.percent_complete == 75.0
assert snapshot.bytes_downloaded == 75
assert snapshot.total_bytes == 100
assert snapshot.bytes_per_second == 512.0
def test_resolve_executable_raises_when_binary_missing(monkeypatch):
downloader = Aria2Downloader()
settings = type("Settings", (), {"get": lambda self, key, default=None: ""})()
monkeypatch.setattr("py.services.aria2_downloader.get_settings_manager", lambda: settings)
monkeypatch.setattr("py.services.aria2_downloader.shutil.which", lambda _: None)
with pytest.raises(Aria2Error):
downloader._resolve_executable()
@pytest.mark.asyncio
async def test_rpc_call_surfaces_json_error_on_non_200(monkeypatch):
downloader = Aria2Downloader()
downloader._rpc_url = "http://127.0.0.1:6800/jsonrpc"
downloader._rpc_secret = "secret"
class FakeResponse:
status = 400
async def text(self):
return (
'{"jsonrpc":"2.0","id":"x","error":{"code":1,"message":"Unauthorized"}}'
)
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc, tb):
return False
class FakeSession:
def post(self, _url, json=None):
return FakeResponse()
monkeypatch.setattr(downloader, "_get_rpc_session", AsyncMock(return_value=FakeSession()))
with pytest.raises(Aria2Error) as exc_info:
await downloader._rpc_call("aria2.addUri", [["https://example.com/file"]])
assert "Unauthorized" in str(exc_info.value)
assert "aria2.addUri" in str(exc_info.value)
@pytest.mark.asyncio
async def test_resolve_authenticated_redirect_url_returns_location(monkeypatch):
downloader = Aria2Downloader()
class FakeResponse:
status = 307
headers = {"Location": "https://signed.example.com/file.safetensors"}
async def text(self):
return ""
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc, tb):
return False
class FakeSession:
def get(self, _url, headers=None, allow_redirects=False, proxy=None):
return FakeResponse()
class FakeDownloader:
default_headers = {"User-Agent": "ComfyUI-LoRA-Manager/1.0"}
proxy_url = None
@property
def session(self):
async def _session():
return FakeSession()
return _session()
fake_downloader = FakeDownloader()
monkeypatch.setattr(
"py.services.aria2_downloader.get_downloader",
AsyncMock(return_value=fake_downloader),
)
result = await downloader._resolve_authenticated_redirect_url(
"https://civitai.com/api/download/models/123",
{"Authorization": "Bearer token"},
)
assert result == "https://signed.example.com/file.safetensors"