Merge branch 'main' into codex/github-mention-fixnetwork-add-connectivityguard-to-short

This commit is contained in:
pixelpaws
2026-04-20 15:54:30 +08:00
committed by GitHub
28 changed files with 4469 additions and 194 deletions

View File

@@ -0,0 +1,354 @@
from __future__ import annotations
import asyncio
from pathlib import Path
from unittest.mock import AsyncMock
import pytest
from py.services.aria2_downloader import Aria2Downloader, Aria2Error
from py.services.aria2_transfer_state import Aria2TransferStateStore
from py.services import aria2_transfer_state
@pytest.fixture(autouse=True)
def isolate_aria2_state(monkeypatch, tmp_path):
state_path = tmp_path / "cache" / "aria2" / "downloads.json"
monkeypatch.setattr(
aria2_transfer_state,
"get_aria2_state_path",
lambda: str(state_path),
)
@pytest.mark.asyncio
async def test_download_file_polls_until_complete(tmp_path, monkeypatch):
downloader = Aria2Downloader()
downloader._rpc_url = "http://127.0.0.1/jsonrpc"
downloader._rpc_secret = "secret"
save_path = tmp_path / "downloads" / "model.safetensors"
progress_events = []
rpc_calls = []
statuses = iter(
[
{
"gid": "gid-1",
"status": "active",
"completedLength": "5",
"totalLength": "10",
"downloadSpeed": "25",
},
{
"gid": "gid-1",
"status": "complete",
"completedLength": "10",
"totalLength": "10",
"downloadSpeed": "0",
"files": [{"path": str(save_path)}],
},
]
)
async def fake_rpc_call(method, params):
rpc_calls.append((method, params))
if method == "aria2.addUri":
return "gid-1"
if method == "aria2.tellStatus":
return next(statuses)
raise AssertionError(f"Unexpected RPC method: {method}")
monkeypatch.setattr(downloader, "_ensure_process", AsyncMock())
monkeypatch.setattr(
downloader,
"_resolve_authenticated_redirect_url",
AsyncMock(
return_value="https://signed.example.com/model.safetensors?token=abc"
),
)
monkeypatch.setattr(downloader, "_rpc_call", fake_rpc_call)
monkeypatch.setattr("py.services.aria2_downloader.asyncio.sleep", AsyncMock())
async def progress_callback(progress, snapshot=None):
progress_events.append(snapshot.percent_complete if snapshot else progress)
success, result = await downloader.download_file(
"https://civitai.com/api/download/models/123",
str(save_path),
download_id="download-1",
progress_callback=progress_callback,
headers={"Authorization": "Bearer token"},
)
assert success is True
assert result == str(save_path)
assert progress_events == [50.0, 100.0]
assert downloader._transfers == {}
assert rpc_calls[0][0] == "aria2.addUri"
assert rpc_calls[0][1][0] == [
"https://signed.example.com/model.safetensors?token=abc"
]
assert rpc_calls[0][1][1]["out"] == "model.safetensors"
assert "header" not in rpc_calls[0][1][1]
@pytest.mark.asyncio
async def test_transfer_state_store_shares_lock_and_preserves_concurrent_updates(tmp_path):
state_path = tmp_path / "cache" / "aria2" / "downloads.json"
store_a = Aria2TransferStateStore(str(state_path))
store_b = Aria2TransferStateStore(str(state_path))
assert store_a._lock is store_b._lock
await asyncio.gather(
store_a.upsert("download-1", {"status": "downloading", "gid": "gid-1"}),
store_b.upsert("download-2", {"status": "paused", "gid": "gid-2"}),
)
assert await store_a.get("download-1") == {"status": "downloading", "gid": "gid-1"}
assert await store_b.get("download-2") == {"status": "paused", "gid": "gid-2"}
@pytest.mark.asyncio
async def test_download_file_keeps_auth_headers_when_civitai_does_not_redirect(
tmp_path, monkeypatch
):
downloader = Aria2Downloader()
downloader._rpc_url = "http://127.0.0.1/jsonrpc"
downloader._rpc_secret = "secret"
save_path = tmp_path / "downloads" / "model.safetensors"
rpc_calls = []
statuses = iter(
[
{
"gid": "gid-1",
"status": "complete",
"completedLength": "10",
"totalLength": "10",
"downloadSpeed": "0",
"files": [{"path": str(save_path)}],
},
]
)
async def fake_rpc_call(method, params):
rpc_calls.append((method, params))
if method == "aria2.addUri":
return "gid-1"
if method == "aria2.tellStatus":
return next(statuses)
raise AssertionError(f"Unexpected RPC method: {method}")
monkeypatch.setattr(downloader, "_ensure_process", AsyncMock())
monkeypatch.setattr(
downloader,
"_resolve_authenticated_redirect_url",
AsyncMock(return_value="https://civitai.com/api/download/models/123"),
)
monkeypatch.setattr(downloader, "_rpc_call", fake_rpc_call)
monkeypatch.setattr("py.services.aria2_downloader.asyncio.sleep", AsyncMock())
success, result = await downloader.download_file(
"https://civitai.com/api/download/models/123",
str(save_path),
download_id="download-1",
headers={"Authorization": "Bearer token"},
)
assert success is True
assert result == str(save_path)
assert rpc_calls[0][1][0] == ["https://civitai.com/api/download/models/123"]
assert rpc_calls[0][1][1]["header"] == ["Authorization: Bearer token"]
@pytest.mark.asyncio
async def test_pause_resume_cancel_forward_to_rpc(monkeypatch):
downloader = Aria2Downloader()
downloader._transfers["download-1"] = type(
"Transfer", (), {"gid": "gid-1", "save_path": "/tmp/model.safetensors"}
)()
calls = []
async def fake_rpc_call(method, params):
calls.append((method, params))
return "gid-1"
monkeypatch.setattr(downloader, "_rpc_call", fake_rpc_call)
pause_result = await downloader.pause_download("download-1")
resume_result = await downloader.resume_download("download-1")
cancel_result = await downloader.cancel_download("download-1")
assert pause_result["success"] is True
assert resume_result["success"] is True
assert cancel_result["success"] is True
assert calls == [
("aria2.forcePause", ["gid-1"]),
("aria2.unpause", ["gid-1"]),
("aria2.forceRemove", ["gid-1"]),
]
@pytest.mark.asyncio
async def test_download_file_reuses_existing_transfer_without_add_uri(
tmp_path, monkeypatch
):
downloader = Aria2Downloader()
downloader._rpc_url = "http://127.0.0.1/jsonrpc"
downloader._rpc_secret = "secret"
save_path = tmp_path / "downloads" / "model.safetensors"
downloader._transfers["download-1"] = type(
"Transfer", (), {"gid": "gid-1", "save_path": str(save_path)}
)()
rpc_calls = []
statuses = iter(
[
{
"gid": "gid-1",
"status": "active",
"completedLength": "5",
"totalLength": "10",
"downloadSpeed": "25",
},
{
"gid": "gid-1",
"status": "complete",
"completedLength": "10",
"totalLength": "10",
"downloadSpeed": "0",
"files": [{"path": str(save_path)}],
},
]
)
async def fake_rpc_call(method, params):
rpc_calls.append((method, params))
if method == "aria2.tellStatus":
return next(statuses)
raise AssertionError(f"Unexpected RPC method: {method}")
monkeypatch.setattr(downloader, "_ensure_process", AsyncMock())
monkeypatch.setattr(downloader, "_rpc_call", fake_rpc_call)
monkeypatch.setattr("py.services.aria2_downloader.asyncio.sleep", AsyncMock())
success, result = await downloader.download_file(
"https://example.com/model.safetensors",
str(save_path),
download_id="download-1",
)
assert success is True
assert result == str(save_path)
assert [call[0] for call in rpc_calls] == ["aria2.tellStatus", "aria2.tellStatus"]
def test_build_progress_snapshot_normalizes_numeric_fields():
downloader = Aria2Downloader()
snapshot = downloader._build_progress_snapshot(
{
"completedLength": "75",
"totalLength": "100",
"downloadSpeed": "512",
}
)
assert snapshot.percent_complete == 75.0
assert snapshot.bytes_downloaded == 75
assert snapshot.total_bytes == 100
assert snapshot.bytes_per_second == 512.0
def test_resolve_executable_raises_when_binary_missing(monkeypatch):
downloader = Aria2Downloader()
settings = type("Settings", (), {"get": lambda self, key, default=None: ""})()
monkeypatch.setattr("py.services.aria2_downloader.get_settings_manager", lambda: settings)
monkeypatch.setattr("py.services.aria2_downloader.shutil.which", lambda _: None)
with pytest.raises(Aria2Error):
downloader._resolve_executable()
@pytest.mark.asyncio
async def test_rpc_call_surfaces_json_error_on_non_200(monkeypatch):
downloader = Aria2Downloader()
downloader._rpc_url = "http://127.0.0.1:6800/jsonrpc"
downloader._rpc_secret = "secret"
class FakeResponse:
status = 400
async def text(self):
return (
'{"jsonrpc":"2.0","id":"x","error":{"code":1,"message":"Unauthorized"}}'
)
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc, tb):
return False
class FakeSession:
def post(self, _url, json=None):
return FakeResponse()
monkeypatch.setattr(downloader, "_get_rpc_session", AsyncMock(return_value=FakeSession()))
with pytest.raises(Aria2Error) as exc_info:
await downloader._rpc_call("aria2.addUri", [["https://example.com/file"]])
assert "Unauthorized" in str(exc_info.value)
assert "aria2.addUri" in str(exc_info.value)
@pytest.mark.asyncio
async def test_resolve_authenticated_redirect_url_returns_location(monkeypatch):
downloader = Aria2Downloader()
class FakeResponse:
status = 307
headers = {"Location": "https://signed.example.com/file.safetensors"}
async def text(self):
return ""
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc, tb):
return False
class FakeSession:
def get(self, _url, headers=None, allow_redirects=False, proxy=None):
return FakeResponse()
class FakeDownloader:
default_headers = {"User-Agent": "ComfyUI-LoRA-Manager/1.0"}
proxy_url = None
@property
def session(self):
async def _session():
return FakeSession()
return _session()
fake_downloader = FakeDownloader()
monkeypatch.setattr(
"py.services.aria2_downloader.get_downloader",
AsyncMock(return_value=fake_downloader),
)
result = await downloader._resolve_authenticated_redirect_url(
"https://civitai.com/api/download/models/123",
{"Authorization": "Bearer token"},
)
assert result == "https://signed.example.com/file.safetensors"

View File

@@ -39,6 +39,26 @@ async def test_connectivity_guard_enters_cooldown_after_threshold():
assert guard.cooldown_remaining_seconds() > 0
async def test_connectivity_guard_scopes_cooldown_to_destination():
guard = await ConnectivityGuard.get_instance()
destination_a = "civitai.com"
destination_b = "api.github.com"
guard.register_network_failure(
OSError(errno.ENETUNREACH, "unreachable"),
destination_a,
)
guard.register_network_failure(asyncio.TimeoutError("timeout"), destination_a)
guard.register_network_failure(ConnectionRefusedError("refused"), destination_a)
assert guard.should_block_request(destination_a) is True
assert guard.should_block_request(destination_b) is False
guard.register_success(destination_a)
assert guard.should_block_request(destination_a) is False
async def test_connectivity_guard_recovers_after_success():
guard = await ConnectivityGuard.get_instance()
guard.online = False
@@ -55,21 +75,51 @@ async def test_connectivity_guard_recovers_after_success():
async def test_downloader_short_circuits_all_request_helpers_during_cooldown():
guard = await ConnectivityGuard.get_instance()
guard.cooldown_until = datetime.now() + timedelta(seconds=30)
guard.online = False
guard.failure_count = 3
destination = "example.invalid"
guard.register_network_failure(
OSError(errno.ENETUNREACH, "unreachable"),
destination,
)
guard.register_network_failure(asyncio.TimeoutError("timeout"), destination)
guard.register_network_failure(
ConnectionRefusedError("refused"),
destination,
)
downloader = Downloader()
ok, payload = await downloader.make_request("GET", "https://example.invalid")
ok, payload = await downloader.make_request("GET", f"https://{destination}")
assert ok is False
assert payload == OFFLINE_COOLDOWN_ERROR
ok, payload, headers = await downloader.download_to_memory("https://example.invalid")
ok, payload, headers = await downloader.download_to_memory(f"https://{destination}")
assert ok is False
assert payload == OFFLINE_FRIENDLY_MESSAGE
assert headers is None
ok, payload = await downloader.get_response_headers("https://example.invalid")
ok, payload = await downloader.get_response_headers(f"https://{destination}")
assert ok is False
assert payload == OFFLINE_COOLDOWN_ERROR
async def test_downloader_only_short_circuits_requests_for_same_destination():
guard = await ConnectivityGuard.get_instance()
guard.register_network_failure(
OSError(errno.ENETUNREACH, "unreachable"),
"example.invalid",
)
guard.register_network_failure(asyncio.TimeoutError("timeout"), "example.invalid")
guard.register_network_failure(
ConnectionRefusedError("refused"),
"example.invalid",
)
downloader = Downloader()
ok, payload = await downloader.make_request("GET", "https://example.invalid")
assert ok is False
assert payload == OFFLINE_COOLDOWN_ERROR
assert (
guard.should_block_request(downloader._guard_destination("https://example.com"))
is False
)

View File

@@ -10,6 +10,7 @@ import pytest
from py.services.download_manager import DownloadManager
from py.services import download_manager
from py.services import aria2_transfer_state
from py.services.service_registry import ServiceRegistry
from py.services.settings_manager import SettingsManager, get_settings_manager
@@ -46,6 +47,16 @@ def isolate_settings(monkeypatch, tmp_path):
monkeypatch.setattr(SettingsManager, "_save_settings", lambda self: None)
@pytest.fixture(autouse=True)
def isolate_aria2_state(monkeypatch, tmp_path):
state_path = tmp_path / "cache" / "aria2" / "downloads.json"
monkeypatch.setattr(
aria2_transfer_state,
"get_aria2_state_path",
lambda: str(state_path),
)
@pytest.fixture(autouse=True)
def stub_metadata(monkeypatch):
class _StubMetadata:
@@ -179,6 +190,7 @@ async def test_successful_download_uses_defaults(
progress_callback,
model_type,
download_id,
transfer_backend=None,
):
captured.update(
{
@@ -268,6 +280,7 @@ async def test_download_uses_active_mirrors(
progress_callback,
model_type,
download_id,
transfer_backend=None,
):
captured["download_urls"] = download_urls
return {"success": True}
@@ -288,6 +301,644 @@ async def test_download_uses_active_mirrors(
assert captured["download_urls"] == ["https://mirror.example/file.safetensors"]
@pytest.mark.asyncio
async def test_pause_resume_cancel_delegate_to_aria2_backend(monkeypatch):
manager = DownloadManager()
task = asyncio.create_task(asyncio.sleep(60))
manager._download_tasks["download-1"] = task
manager._pause_events["download-1"] = download_manager.DownloadStreamControl()
manager._active_downloads["download-1"] = {
"transfer_backend": "aria2",
"status": "downloading",
}
class DummyAria2Downloader:
def __init__(self):
self.calls = []
async def pause_download(self, download_id):
self.calls.append(("pause", download_id))
return {"success": True, "message": "paused"}
async def resume_download(self, download_id):
self.calls.append(("resume", download_id))
return {"success": True, "message": "resumed"}
async def cancel_download(self, download_id):
self.calls.append(("cancel", download_id))
return {"success": True, "message": "cancelled"}
async def has_transfer(self, download_id):
self.calls.append(("has_transfer", download_id))
return True
dummy_aria2 = DummyAria2Downloader()
monkeypatch.setattr(
download_manager,
"get_aria2_downloader",
AsyncMock(return_value=dummy_aria2),
)
pause_result = await manager.pause_download("download-1")
assert pause_result["success"] is True
assert manager._active_downloads["download-1"]["status"] == "paused"
resume_result = await manager.resume_download("download-1")
assert resume_result["success"] is True
assert manager._active_downloads["download-1"]["status"] == "downloading"
cancel_result = await manager.cancel_download("download-1")
assert cancel_result["success"] is True
assert task.cancelled() or task.done()
assert dummy_aria2.calls == [
("has_transfer", "download-1"),
("pause", "download-1"),
("has_transfer", "download-1"),
("resume", "download-1"),
("cancel", "download-1"),
]
@pytest.mark.asyncio
async def test_cancel_allows_queued_aria2_task_without_transfer(monkeypatch):
manager = DownloadManager()
started = asyncio.Event()
async def blocked_task():
started.set()
await asyncio.sleep(60)
task = asyncio.create_task(blocked_task())
await started.wait()
manager._download_tasks["download-queued"] = task
manager._active_downloads["download-queued"] = {
"transfer_backend": "aria2",
"status": "queued",
}
class DummyAria2Downloader:
async def cancel_download(self, download_id):
return {"success": False, "error": "Download task not found"}
monkeypatch.setattr(
download_manager,
"get_aria2_downloader",
AsyncMock(return_value=DummyAria2Downloader()),
)
result = await manager.cancel_download("download-queued")
assert result["success"] is True
assert task.cancelled() or task.done()
@pytest.mark.asyncio
async def test_pause_resume_queued_aria2_task_without_transfer(monkeypatch):
manager = DownloadManager()
task = asyncio.create_task(asyncio.sleep(60))
manager._download_tasks["download-queued"] = task
manager._pause_events["download-queued"] = download_manager.DownloadStreamControl()
manager._active_downloads["download-queued"] = {
"transfer_backend": "aria2",
"status": "waiting",
"bytes_per_second": 12.0,
}
class DummyAria2Downloader:
def __init__(self):
self.calls = []
async def has_transfer(self, download_id):
self.calls.append(("has_transfer", download_id))
return False
async def pause_download(self, download_id):
self.calls.append(("pause", download_id))
return {"success": True, "message": "paused"}
async def resume_download(self, download_id):
self.calls.append(("resume", download_id))
return {"success": True, "message": "resumed"}
dummy_aria2 = DummyAria2Downloader()
monkeypatch.setattr(
download_manager,
"get_aria2_downloader",
AsyncMock(return_value=dummy_aria2),
)
pause_result = await manager.pause_download("download-queued")
assert pause_result == {"success": True, "message": "Download paused successfully"}
assert manager._active_downloads["download-queued"]["status"] == "paused"
assert manager._pause_events["download-queued"].is_paused() is True
resume_result = await manager.resume_download("download-queued")
assert resume_result == {"success": True, "message": "Download resumed successfully"}
assert manager._active_downloads["download-queued"]["status"] == "downloading"
assert manager._pause_events["download-queued"].is_set() is True
assert dummy_aria2.calls == [
("has_transfer", "download-queued"),
("has_transfer", "download-queued"),
]
task.cancel()
with pytest.raises(asyncio.CancelledError):
await task
@pytest.mark.asyncio
async def test_resume_download_restores_persisted_aria2_task(monkeypatch, tmp_path):
manager = DownloadManager()
save_dir = tmp_path / "downloads"
save_dir.mkdir()
save_path = save_dir / "file.safetensors"
save_path.write_text("partial")
(save_dir / "file.safetensors.aria2").write_text("control")
await manager._aria2_state_store.upsert(
"download-1",
{
"download_id": "download-1",
"transfer_backend": "aria2",
"status": "paused",
"save_dir": str(save_dir),
"relative_path": "",
"use_default_paths": False,
"save_path": str(save_path),
"file_path": str(save_path),
"model_id": 12,
"model_version_id": 34,
},
)
created = {}
async def fake_download_with_semaphore(
self,
task_id,
model_id,
model_version_id,
save_dir,
relative_path,
progress_callback=None,
use_default_paths=False,
source=None,
file_params=None,
):
created.update(
{
"task_id": task_id,
"model_id": model_id,
"model_version_id": model_version_id,
"save_dir": save_dir,
}
)
return {"success": True}
class DummyAria2Downloader:
def __init__(self):
self.calls = []
async def get_status_by_gid(self, gid):
return None
async def has_transfer(self, download_id):
self.calls.append(("has_transfer", download_id))
return False
async def resume_download(self, download_id):
self.calls.append(("resume", download_id))
return {"success": True, "message": "resumed"}
async def restore_transfer(self, download_id, gid, save_path):
self.calls.append(("restore_transfer", download_id, gid, save_path))
dummy_aria2 = DummyAria2Downloader()
monkeypatch.setattr(
download_manager, "_download_with_semaphore", None, raising=False
)
monkeypatch.setattr(
DownloadManager,
"_download_with_semaphore",
fake_download_with_semaphore,
)
monkeypatch.setattr(
download_manager,
"get_aria2_downloader",
AsyncMock(return_value=dummy_aria2),
)
result = await manager.resume_download("download-1")
await asyncio.sleep(0)
assert result == {"success": True, "message": "Download resumed successfully"}
assert created["task_id"] == "download-1"
assert created["model_version_id"] == 34
assert manager._active_downloads["download-1"]["status"] == "downloading"
assert manager._pause_events["download-1"].is_set() is True
@pytest.mark.asyncio
async def test_get_active_downloads_restores_persisted_aria2_entries(monkeypatch, tmp_path):
manager = DownloadManager()
save_dir = tmp_path / "downloads"
save_dir.mkdir()
save_path = save_dir / "file.safetensors"
save_path.write_text("partial")
(save_dir / "file.safetensors.aria2").write_text("control")
await manager._aria2_state_store.upsert(
"download-1",
{
"download_id": "download-1",
"transfer_backend": "aria2",
"status": "paused",
"save_path": str(save_path),
"file_path": str(save_path),
"model_id": 12,
"model_version_id": 34,
},
)
class DummyAria2Downloader:
async def get_status_by_gid(self, gid):
return None
monkeypatch.setattr(
download_manager,
"get_aria2_downloader",
AsyncMock(return_value=DummyAria2Downloader()),
)
downloads = await manager.get_active_downloads()
assert downloads["downloads"] == [
{
"download_id": "download-1",
"model_id": 12,
"model_version_id": 34,
"progress": 0,
"status": "paused",
"error": None,
"bytes_downloaded": 0,
"total_bytes": None,
"bytes_per_second": 0.0,
}
]
@pytest.mark.asyncio
async def test_get_active_downloads_restores_orphaned_aria2_partial_as_paused(
monkeypatch, tmp_path
):
manager = DownloadManager()
save_dir = tmp_path / "downloads"
save_dir.mkdir()
save_path = save_dir / "file.safetensors"
save_path.write_text("partial")
(save_dir / "file.safetensors.aria2").write_text("control")
await manager._aria2_state_store.upsert(
"download-1",
{
"download_id": "download-1",
"transfer_backend": "aria2",
"status": "downloading",
"save_path": str(save_path),
"file_path": str(save_path),
"model_id": 12,
"model_version_id": 34,
"gid": "missing-gid",
},
)
class DummyAria2Downloader:
async def get_status_by_gid(self, gid):
return None
monkeypatch.setattr(
download_manager,
"get_aria2_downloader",
AsyncMock(return_value=DummyAria2Downloader()),
)
downloads = await manager.get_active_downloads()
persisted = await manager._aria2_state_store.get("download-1")
assert downloads["downloads"] == [
{
"download_id": "download-1",
"model_id": 12,
"model_version_id": 34,
"progress": 0,
"status": "paused",
"error": None,
"bytes_downloaded": 0,
"total_bytes": None,
"bytes_per_second": 0.0,
}
]
assert manager._pause_events["download-1"].is_paused() is True
assert persisted["status"] == "paused"
@pytest.mark.asyncio
async def test_get_active_downloads_restarts_from_resume_context_for_active_restored_aria2(
monkeypatch, tmp_path
):
manager = DownloadManager()
save_dir = tmp_path / "downloads"
save_dir.mkdir()
save_path = save_dir / "file.safetensors"
save_path.write_text("partial")
await manager._aria2_state_store.upsert(
"download-1",
{
"download_id": "download-1",
"transfer_backend": "aria2",
"status": "downloading",
"save_path": str(save_path),
"file_path": str(save_path),
"model_id": 12,
"model_version_id": 34,
"gid": "gid-1",
"resume_context": {
"version_info": {
"id": 34,
"modelId": 12,
"model": {"id": 12, "type": "LoRA", "tags": ["fantasy"]},
"images": [],
},
"file_info": {
"name": "file.safetensors",
"type": "Model",
"primary": True,
"downloadUrl": "https://example.com/file.safetensors",
},
"model_type": "lora",
"relative_path": "",
"save_dir": str(save_dir),
"download_urls": ["https://example.com/file.safetensors"],
},
},
)
restarted = {}
class DummyAria2Downloader:
async def get_status_by_gid(self, gid):
return {"gid": gid, "status": "active"}
async def restore_transfer(self, download_id, gid, restored_path):
return None
monkeypatch.setattr(
download_manager,
"get_aria2_downloader",
AsyncMock(return_value=DummyAria2Downloader()),
)
async def fake_resume_restored_aria2_download(self, download_id, record):
restarted.update(
{
"download_id": download_id,
"model_id": record.get("model_id"),
"model_version_id": record.get("model_version_id"),
"save_dir": record.get("save_dir"),
"resume_context": record.get("resume_context"),
}
)
return {"success": True}
monkeypatch.setattr(
DownloadManager,
"_resume_restored_aria2_download",
fake_resume_restored_aria2_download,
)
execute_original = AsyncMock(side_effect=AssertionError("should not refetch metadata"))
monkeypatch.setattr(
DownloadManager,
"_execute_original_download",
execute_original,
)
downloads = await manager.get_active_downloads()
assert downloads["downloads"][0]["status"] == "downloading"
restarted_task = manager._download_tasks["download-1"]
await restarted_task
assert restarted["download_id"] == "download-1"
assert restarted["model_id"] == 12
assert restarted["model_version_id"] == 34
assert restarted["save_dir"] is None
assert restarted["resume_context"]["model_type"] == "lora"
assert execute_original.await_count == 0
@pytest.mark.asyncio
async def test_get_active_downloads_restores_persisted_aria2_without_initial_save_path(
monkeypatch, tmp_path
):
manager = DownloadManager()
save_dir = tmp_path / "downloads"
save_dir.mkdir()
save_path = save_dir / "file.safetensors"
save_path.write_text("partial")
(save_dir / "file.safetensors.aria2").write_text("control")
await manager._aria2_state_store.upsert(
"download-1",
{
"download_id": "download-1",
"transfer_backend": "aria2",
"status": "paused",
"model_id": 12,
"model_version_id": 34,
"resume_context": {
"version_info": {
"id": 34,
"modelId": 12,
"model": {"id": 12, "type": "LoRA"},
"images": [],
},
"file_info": {
"name": "file.safetensors",
"type": "Model",
"primary": True,
"downloadUrl": "https://example.com/file.safetensors",
},
"model_type": "lora",
"relative_path": "",
"save_dir": str(save_dir),
"download_urls": ["https://example.com/file.safetensors"],
},
},
)
class DummyAria2Downloader:
async def get_status_by_gid(self, gid):
return None
monkeypatch.setattr(
download_manager,
"get_aria2_downloader",
AsyncMock(return_value=DummyAria2Downloader()),
)
downloads = await manager.get_active_downloads()
persisted = await manager._aria2_state_store.get("download-1")
assert downloads["downloads"] == [
{
"download_id": "download-1",
"model_id": 12,
"model_version_id": 34,
"progress": 0,
"status": "paused",
"error": None,
"bytes_downloaded": 0,
"total_bytes": None,
"bytes_per_second": 0.0,
}
]
assert manager._active_downloads["download-1"]["file_path"] == str(save_path)
assert persisted["save_path"] == str(save_path)
assert persisted["file_path"] == str(save_path)
@pytest.mark.asyncio
async def test_resume_restored_aria2_download_updates_terminal_status_and_cleanup(monkeypatch):
manager = DownloadManager()
manager._active_downloads["download-1"] = {
"transfer_backend": "aria2",
"status": "paused",
"model_id": 12,
"model_version_id": 34,
"bytes_per_second": 10.0,
}
persist_state = AsyncMock()
cleanup_record = AsyncMock(return_value=None)
execute_download = AsyncMock(return_value={"success": True})
record_history = AsyncMock(return_value=None)
sync_version = AsyncMock(return_value=None)
monkeypatch.setattr(manager, "_persist_aria2_state", persist_state)
monkeypatch.setattr(manager, "_cleanup_download_record", cleanup_record)
monkeypatch.setattr(manager, "_execute_download", execute_download)
monkeypatch.setattr(manager, "_record_downloaded_version_history", record_history)
monkeypatch.setattr(manager, "_sync_downloaded_version", sync_version)
scheduled_tasks = []
original_create_task = asyncio.create_task
def tracking_create_task(coro):
task = original_create_task(coro)
scheduled_tasks.append(task)
return task
monkeypatch.setattr(download_manager.asyncio, "create_task", tracking_create_task)
result = await manager._resume_restored_aria2_download(
"download-1",
{
"download_id": "download-1",
"save_path": "/tmp/file.safetensors",
"file_path": "/tmp/file.safetensors",
"model_id": 12,
"model_version_id": 34,
"resume_context": {
"version_info": {
"id": 34,
"modelId": 12,
"model": {"id": 12},
"images": [],
},
"file_info": {
"name": "file.safetensors",
"downloadUrl": "https://example.com/file.safetensors",
},
"model_type": "lora",
"relative_path": "",
"save_dir": "/tmp",
"download_urls": ["https://example.com/file.safetensors"],
},
},
)
assert result == {"success": True}
assert manager._active_downloads["download-1"]["status"] == "completed"
assert manager._active_downloads["download-1"]["bytes_per_second"] == 0.0
assert persist_state.await_count == 2
assert len(scheduled_tasks) == 1
await asyncio.gather(*scheduled_tasks)
cleanup_record.assert_awaited_once_with("download-1")
@pytest.mark.asyncio
async def test_download_uses_captured_backend_when_settings_change(
monkeypatch, scanners, metadata_provider, tmp_path
):
manager = DownloadManager()
settings = get_settings_manager()
settings.settings["download_backend"] = "aria2"
semaphore = asyncio.Semaphore(0)
manager._download_semaphore = semaphore
captured = {}
async def fake_execute_original_download(
self,
model_id,
model_version_id,
save_dir,
relative_path,
progress_callback,
use_default_paths,
download_id=None,
transfer_backend="python",
source=None,
file_params=None,
):
captured["transfer_backend"] = transfer_backend
return {"success": True}
monkeypatch.setattr(
DownloadManager,
"_execute_original_download",
fake_execute_original_download,
)
download_task = asyncio.create_task(
manager.download_from_civitai(
model_version_id=99,
save_dir=str(tmp_path),
use_default_paths=True,
progress_callback=None,
source=None,
)
)
await asyncio.sleep(0)
assert len(manager._active_downloads) == 1
download_id = next(iter(manager._active_downloads))
assert manager._active_downloads[download_id]["transfer_backend"] == "aria2"
settings.settings["download_backend"] = "python"
semaphore.release()
result = await download_task
assert result["success"] is True
assert captured["transfer_backend"] == "aria2"
@pytest.mark.asyncio
async def test_download_aborts_when_version_exists(
monkeypatch, scanners, metadata_provider

File diff suppressed because it is too large Load Diff

View File

@@ -147,6 +147,11 @@ def test_environment_variable_overrides_settings(tmp_path, monkeypatch):
assert mgr.get("civitai_api_key") == "secret"
def test_default_download_backend_is_python(manager):
assert manager.get("download_backend") == "python"
assert manager.get("aria2c_path") == ""
def _create_manager_with_settings(
tmp_path, monkeypatch, initial_settings, *, save_spy=None
):
@@ -327,6 +332,43 @@ def test_auto_set_default_roots_keeps_valid_values(manager):
assert manager.get("default_embedding_root") == "/embeddings"
def test_auto_set_default_roots_keeps_valid_extra_values(manager):
manager.settings["default_lora_root"] = "/extra-loras"
manager.settings["default_checkpoint_root"] = "/extra-checkpoints"
manager.settings["default_embedding_root"] = "/extra-embeddings"
manager.settings["default_unet_root"] = "/extra-unet"
manager.settings["folder_paths"] = {
"loras": ["/loras"],
"checkpoints": ["/checkpoints"],
"unet": ["/unet"],
"embeddings": ["/embeddings"],
}
manager.settings["extra_folder_paths"] = {
"loras": ["/extra-loras"],
"checkpoints": ["/extra-checkpoints"],
"unet": ["/extra-unet"],
"embeddings": ["/extra-embeddings"],
}
manager._auto_set_default_roots()
assert manager.get("default_lora_root") == "/extra-loras"
assert manager.get("default_checkpoint_root") == "/extra-checkpoints"
assert manager.get("default_unet_root") == "/extra-unet"
assert manager.get("default_embedding_root") == "/extra-embeddings"
def test_auto_set_default_roots_falls_back_to_extra_when_primary_missing(manager):
manager.settings["default_lora_root"] = ""
manager.settings["folder_paths"] = {"loras": []}
manager.settings["extra_folder_paths"] = {"loras": ["/extra-loras"]}
manager._auto_set_default_roots()
assert manager.get("default_lora_root") == "/extra-loras"
def test_delete_setting(manager):
manager.set("example", 1)
manager.delete("example")