fix(download): restore aria2 resume lifecycle

This commit is contained in:
Will Miao
2026-04-20 09:52:48 +08:00
parent 24dd3a777c
commit 761108bfd1
6 changed files with 2123 additions and 120 deletions

View File

@@ -1,11 +1,24 @@
from __future__ import annotations
import asyncio
from pathlib import Path
from unittest.mock import AsyncMock
import pytest
from py.services.aria2_downloader import Aria2Downloader, Aria2Error
from py.services.aria2_transfer_state import Aria2TransferStateStore
from py.services import aria2_transfer_state
@pytest.fixture(autouse=True)
def isolate_aria2_state(monkeypatch, tmp_path):
state_path = tmp_path / "cache" / "aria2" / "downloads.json"
monkeypatch.setattr(
aria2_transfer_state,
"get_aria2_state_path",
lambda: str(state_path),
)
@pytest.mark.asyncio
@@ -79,6 +92,23 @@ async def test_download_file_polls_until_complete(tmp_path, monkeypatch):
assert "header" not in rpc_calls[0][1][1]
@pytest.mark.asyncio
async def test_transfer_state_store_shares_lock_and_preserves_concurrent_updates(tmp_path):
state_path = tmp_path / "cache" / "aria2" / "downloads.json"
store_a = Aria2TransferStateStore(str(state_path))
store_b = Aria2TransferStateStore(str(state_path))
assert store_a._lock is store_b._lock
await asyncio.gather(
store_a.upsert("download-1", {"status": "downloading", "gid": "gid-1"}),
store_b.upsert("download-2", {"status": "paused", "gid": "gid-2"}),
)
assert await store_a.get("download-1") == {"status": "downloading", "gid": "gid-1"}
assert await store_b.get("download-2") == {"status": "paused", "gid": "gid-2"}
@pytest.mark.asyncio
async def test_download_file_keeps_auth_headers_when_civitai_does_not_redirect(
tmp_path, monkeypatch
@@ -161,6 +191,61 @@ async def test_pause_resume_cancel_forward_to_rpc(monkeypatch):
]
@pytest.mark.asyncio
async def test_download_file_reuses_existing_transfer_without_add_uri(
tmp_path, monkeypatch
):
downloader = Aria2Downloader()
downloader._rpc_url = "http://127.0.0.1/jsonrpc"
downloader._rpc_secret = "secret"
save_path = tmp_path / "downloads" / "model.safetensors"
downloader._transfers["download-1"] = type(
"Transfer", (), {"gid": "gid-1", "save_path": str(save_path)}
)()
rpc_calls = []
statuses = iter(
[
{
"gid": "gid-1",
"status": "active",
"completedLength": "5",
"totalLength": "10",
"downloadSpeed": "25",
},
{
"gid": "gid-1",
"status": "complete",
"completedLength": "10",
"totalLength": "10",
"downloadSpeed": "0",
"files": [{"path": str(save_path)}],
},
]
)
async def fake_rpc_call(method, params):
rpc_calls.append((method, params))
if method == "aria2.tellStatus":
return next(statuses)
raise AssertionError(f"Unexpected RPC method: {method}")
monkeypatch.setattr(downloader, "_ensure_process", AsyncMock())
monkeypatch.setattr(downloader, "_rpc_call", fake_rpc_call)
monkeypatch.setattr("py.services.aria2_downloader.asyncio.sleep", AsyncMock())
success, result = await downloader.download_file(
"https://example.com/model.safetensors",
str(save_path),
download_id="download-1",
)
assert success is True
assert result == str(save_path)
assert [call[0] for call in rpc_calls] == ["aria2.tellStatus", "aria2.tellStatus"]
def test_build_progress_snapshot_normalizes_numeric_fields():
downloader = Aria2Downloader()

View File

@@ -10,6 +10,7 @@ import pytest
from py.services.download_manager import DownloadManager
from py.services import download_manager
from py.services import aria2_transfer_state
from py.services.service_registry import ServiceRegistry
from py.services.settings_manager import SettingsManager, get_settings_manager
@@ -46,6 +47,16 @@ def isolate_settings(monkeypatch, tmp_path):
monkeypatch.setattr(SettingsManager, "_save_settings", lambda self: None)
@pytest.fixture(autouse=True)
def isolate_aria2_state(monkeypatch, tmp_path):
state_path = tmp_path / "cache" / "aria2" / "downloads.json"
monkeypatch.setattr(
aria2_transfer_state,
"get_aria2_state_path",
lambda: str(state_path),
)
@pytest.fixture(autouse=True)
def stub_metadata(monkeypatch):
class _StubMetadata:
@@ -439,6 +450,436 @@ async def test_pause_resume_queued_aria2_task_without_transfer(monkeypatch):
await task
@pytest.mark.asyncio
async def test_resume_download_restores_persisted_aria2_task(monkeypatch, tmp_path):
manager = DownloadManager()
save_dir = tmp_path / "downloads"
save_dir.mkdir()
save_path = save_dir / "file.safetensors"
save_path.write_text("partial")
(save_dir / "file.safetensors.aria2").write_text("control")
await manager._aria2_state_store.upsert(
"download-1",
{
"download_id": "download-1",
"transfer_backend": "aria2",
"status": "paused",
"save_dir": str(save_dir),
"relative_path": "",
"use_default_paths": False,
"save_path": str(save_path),
"file_path": str(save_path),
"model_id": 12,
"model_version_id": 34,
},
)
created = {}
async def fake_download_with_semaphore(
self,
task_id,
model_id,
model_version_id,
save_dir,
relative_path,
progress_callback=None,
use_default_paths=False,
source=None,
file_params=None,
):
created.update(
{
"task_id": task_id,
"model_id": model_id,
"model_version_id": model_version_id,
"save_dir": save_dir,
}
)
return {"success": True}
class DummyAria2Downloader:
def __init__(self):
self.calls = []
async def get_status_by_gid(self, gid):
return None
async def has_transfer(self, download_id):
self.calls.append(("has_transfer", download_id))
return False
async def resume_download(self, download_id):
self.calls.append(("resume", download_id))
return {"success": True, "message": "resumed"}
async def restore_transfer(self, download_id, gid, save_path):
self.calls.append(("restore_transfer", download_id, gid, save_path))
dummy_aria2 = DummyAria2Downloader()
monkeypatch.setattr(
download_manager, "_download_with_semaphore", None, raising=False
)
monkeypatch.setattr(
DownloadManager,
"_download_with_semaphore",
fake_download_with_semaphore,
)
monkeypatch.setattr(
download_manager,
"get_aria2_downloader",
AsyncMock(return_value=dummy_aria2),
)
result = await manager.resume_download("download-1")
await asyncio.sleep(0)
assert result == {"success": True, "message": "Download resumed successfully"}
assert created["task_id"] == "download-1"
assert created["model_version_id"] == 34
assert manager._active_downloads["download-1"]["status"] == "downloading"
assert manager._pause_events["download-1"].is_set() is True
@pytest.mark.asyncio
async def test_get_active_downloads_restores_persisted_aria2_entries(monkeypatch, tmp_path):
manager = DownloadManager()
save_dir = tmp_path / "downloads"
save_dir.mkdir()
save_path = save_dir / "file.safetensors"
save_path.write_text("partial")
(save_dir / "file.safetensors.aria2").write_text("control")
await manager._aria2_state_store.upsert(
"download-1",
{
"download_id": "download-1",
"transfer_backend": "aria2",
"status": "paused",
"save_path": str(save_path),
"file_path": str(save_path),
"model_id": 12,
"model_version_id": 34,
},
)
class DummyAria2Downloader:
async def get_status_by_gid(self, gid):
return None
monkeypatch.setattr(
download_manager,
"get_aria2_downloader",
AsyncMock(return_value=DummyAria2Downloader()),
)
downloads = await manager.get_active_downloads()
assert downloads["downloads"] == [
{
"download_id": "download-1",
"model_id": 12,
"model_version_id": 34,
"progress": 0,
"status": "paused",
"error": None,
"bytes_downloaded": 0,
"total_bytes": None,
"bytes_per_second": 0.0,
}
]
@pytest.mark.asyncio
async def test_get_active_downloads_restores_orphaned_aria2_partial_as_paused(
monkeypatch, tmp_path
):
manager = DownloadManager()
save_dir = tmp_path / "downloads"
save_dir.mkdir()
save_path = save_dir / "file.safetensors"
save_path.write_text("partial")
(save_dir / "file.safetensors.aria2").write_text("control")
await manager._aria2_state_store.upsert(
"download-1",
{
"download_id": "download-1",
"transfer_backend": "aria2",
"status": "downloading",
"save_path": str(save_path),
"file_path": str(save_path),
"model_id": 12,
"model_version_id": 34,
"gid": "missing-gid",
},
)
class DummyAria2Downloader:
async def get_status_by_gid(self, gid):
return None
monkeypatch.setattr(
download_manager,
"get_aria2_downloader",
AsyncMock(return_value=DummyAria2Downloader()),
)
downloads = await manager.get_active_downloads()
persisted = await manager._aria2_state_store.get("download-1")
assert downloads["downloads"] == [
{
"download_id": "download-1",
"model_id": 12,
"model_version_id": 34,
"progress": 0,
"status": "paused",
"error": None,
"bytes_downloaded": 0,
"total_bytes": None,
"bytes_per_second": 0.0,
}
]
assert manager._pause_events["download-1"].is_paused() is True
assert persisted["status"] == "paused"
@pytest.mark.asyncio
async def test_get_active_downloads_restarts_from_resume_context_for_active_restored_aria2(
monkeypatch, tmp_path
):
manager = DownloadManager()
save_dir = tmp_path / "downloads"
save_dir.mkdir()
save_path = save_dir / "file.safetensors"
save_path.write_text("partial")
await manager._aria2_state_store.upsert(
"download-1",
{
"download_id": "download-1",
"transfer_backend": "aria2",
"status": "downloading",
"save_path": str(save_path),
"file_path": str(save_path),
"model_id": 12,
"model_version_id": 34,
"gid": "gid-1",
"resume_context": {
"version_info": {
"id": 34,
"modelId": 12,
"model": {"id": 12, "type": "LoRA", "tags": ["fantasy"]},
"images": [],
},
"file_info": {
"name": "file.safetensors",
"type": "Model",
"primary": True,
"downloadUrl": "https://example.com/file.safetensors",
},
"model_type": "lora",
"relative_path": "",
"save_dir": str(save_dir),
"download_urls": ["https://example.com/file.safetensors"],
},
},
)
restarted = {}
class DummyAria2Downloader:
async def get_status_by_gid(self, gid):
return {"gid": gid, "status": "active"}
async def restore_transfer(self, download_id, gid, restored_path):
return None
monkeypatch.setattr(
download_manager,
"get_aria2_downloader",
AsyncMock(return_value=DummyAria2Downloader()),
)
async def fake_resume_restored_aria2_download(self, download_id, record):
restarted.update(
{
"download_id": download_id,
"model_id": record.get("model_id"),
"model_version_id": record.get("model_version_id"),
"save_dir": record.get("save_dir"),
"resume_context": record.get("resume_context"),
}
)
return {"success": True}
monkeypatch.setattr(
DownloadManager,
"_resume_restored_aria2_download",
fake_resume_restored_aria2_download,
)
execute_original = AsyncMock(side_effect=AssertionError("should not refetch metadata"))
monkeypatch.setattr(
DownloadManager,
"_execute_original_download",
execute_original,
)
downloads = await manager.get_active_downloads()
assert downloads["downloads"][0]["status"] == "downloading"
restarted_task = manager._download_tasks["download-1"]
await restarted_task
assert restarted["download_id"] == "download-1"
assert restarted["model_id"] == 12
assert restarted["model_version_id"] == 34
assert restarted["save_dir"] is None
assert restarted["resume_context"]["model_type"] == "lora"
assert execute_original.await_count == 0
@pytest.mark.asyncio
async def test_get_active_downloads_restores_persisted_aria2_without_initial_save_path(
monkeypatch, tmp_path
):
manager = DownloadManager()
save_dir = tmp_path / "downloads"
save_dir.mkdir()
save_path = save_dir / "file.safetensors"
save_path.write_text("partial")
(save_dir / "file.safetensors.aria2").write_text("control")
await manager._aria2_state_store.upsert(
"download-1",
{
"download_id": "download-1",
"transfer_backend": "aria2",
"status": "paused",
"model_id": 12,
"model_version_id": 34,
"resume_context": {
"version_info": {
"id": 34,
"modelId": 12,
"model": {"id": 12, "type": "LoRA"},
"images": [],
},
"file_info": {
"name": "file.safetensors",
"type": "Model",
"primary": True,
"downloadUrl": "https://example.com/file.safetensors",
},
"model_type": "lora",
"relative_path": "",
"save_dir": str(save_dir),
"download_urls": ["https://example.com/file.safetensors"],
},
},
)
class DummyAria2Downloader:
async def get_status_by_gid(self, gid):
return None
monkeypatch.setattr(
download_manager,
"get_aria2_downloader",
AsyncMock(return_value=DummyAria2Downloader()),
)
downloads = await manager.get_active_downloads()
persisted = await manager._aria2_state_store.get("download-1")
assert downloads["downloads"] == [
{
"download_id": "download-1",
"model_id": 12,
"model_version_id": 34,
"progress": 0,
"status": "paused",
"error": None,
"bytes_downloaded": 0,
"total_bytes": None,
"bytes_per_second": 0.0,
}
]
assert manager._active_downloads["download-1"]["file_path"] == str(save_path)
assert persisted["save_path"] == str(save_path)
assert persisted["file_path"] == str(save_path)
@pytest.mark.asyncio
async def test_resume_restored_aria2_download_updates_terminal_status_and_cleanup(monkeypatch):
manager = DownloadManager()
manager._active_downloads["download-1"] = {
"transfer_backend": "aria2",
"status": "paused",
"model_id": 12,
"model_version_id": 34,
"bytes_per_second": 10.0,
}
persist_state = AsyncMock()
cleanup_record = AsyncMock(return_value=None)
execute_download = AsyncMock(return_value={"success": True})
record_history = AsyncMock(return_value=None)
sync_version = AsyncMock(return_value=None)
monkeypatch.setattr(manager, "_persist_aria2_state", persist_state)
monkeypatch.setattr(manager, "_cleanup_download_record", cleanup_record)
monkeypatch.setattr(manager, "_execute_download", execute_download)
monkeypatch.setattr(manager, "_record_downloaded_version_history", record_history)
monkeypatch.setattr(manager, "_sync_downloaded_version", sync_version)
scheduled_tasks = []
original_create_task = asyncio.create_task
def tracking_create_task(coro):
task = original_create_task(coro)
scheduled_tasks.append(task)
return task
monkeypatch.setattr(download_manager.asyncio, "create_task", tracking_create_task)
result = await manager._resume_restored_aria2_download(
"download-1",
{
"download_id": "download-1",
"save_path": "/tmp/file.safetensors",
"file_path": "/tmp/file.safetensors",
"model_id": 12,
"model_version_id": 34,
"resume_context": {
"version_info": {
"id": 34,
"modelId": 12,
"model": {"id": 12},
"images": [],
},
"file_info": {
"name": "file.safetensors",
"downloadUrl": "https://example.com/file.safetensors",
},
"model_type": "lora",
"relative_path": "",
"save_dir": "/tmp",
"download_urls": ["https://example.com/file.safetensors"],
},
},
)
assert result == {"success": True}
assert manager._active_downloads["download-1"]["status"] == "completed"
assert manager._active_downloads["download-1"]["bytes_per_second"] == 0.0
assert persist_state.await_count == 2
assert len(scheduled_tasks) == 1
await asyncio.gather(*scheduled_tasks)
cleanup_record.assert_awaited_once_with("download-1")
@pytest.mark.asyncio
async def test_download_uses_captured_backend_when_settings_change(
monkeypatch, scanners, metadata_provider, tmp_path

View File

@@ -14,6 +14,7 @@ import pytest
from py.services.download_manager import DownloadManager
from py.services.downloader import DownloadStreamControl
from py.services import download_manager
from py.services import aria2_transfer_state
from py.services.service_registry import ServiceRegistry
from py.services.settings_manager import SettingsManager, get_settings_manager
from py.utils.metadata_manager import MetadataManager
@@ -49,6 +50,16 @@ def isolate_settings(monkeypatch, tmp_path):
monkeypatch.setattr(SettingsManager, "_save_settings", lambda self: None)
@pytest.fixture(autouse=True)
def isolate_aria2_state(monkeypatch, tmp_path):
state_path = tmp_path / "cache" / "aria2" / "downloads.json"
monkeypatch.setattr(
aria2_transfer_state,
"get_aria2_state_path",
lambda: str(state_path),
)
@pytest.mark.asyncio
async def test_execute_download_retries_urls(monkeypatch, tmp_path):
"""Test that download retries multiple URLs on failure."""
@@ -800,6 +811,89 @@ async def test_resume_download_returns_error_when_aria2_probe_raises(monkeypatch
assert manager._active_downloads[download_id]["status"] == "paused"
@pytest.mark.asyncio
async def test_resume_download_does_not_spawn_restored_worker_when_aria2_resume_fails(
monkeypatch, tmp_path
):
manager = DownloadManager()
download_id = "dl"
save_path = tmp_path / "file.safetensors"
pause_control = DownloadStreamControl()
pause_control.pause()
manager._pause_events[download_id] = pause_control
manager._active_downloads[download_id] = {
"transfer_backend": "aria2",
"status": "paused",
"bytes_per_second": 0.0,
}
await manager._aria2_state_store.upsert(
download_id,
{
"download_id": download_id,
"transfer_backend": "aria2",
"status": "paused",
"save_path": str(save_path),
"file_path": str(save_path),
"model_id": 12,
"model_version_id": 34,
"resume_context": {
"version_info": {"id": 34, "modelId": 12, "model": {"id": 12}},
"file_info": {
"name": "file.safetensors",
"downloadUrl": "https://example.com/file.safetensors",
},
"model_type": "lora",
"relative_path": "",
"save_dir": str(tmp_path),
"download_urls": ["https://example.com/file.safetensors"],
},
},
)
resume_restored = AsyncMock(return_value={"success": True})
monkeypatch.setattr(manager, "_resume_restored_aria2_download", resume_restored)
class DummyAria2Downloader:
async def has_transfer(self, _download_id):
return True
async def resume_download(self, _download_id):
return {"success": False, "error": "rpc unavailable"}
monkeypatch.setattr(
download_manager,
"get_aria2_downloader",
AsyncMock(return_value=DummyAria2Downloader()),
)
result = await manager.resume_download(download_id)
assert result == {"success": False, "error": "rpc unavailable"}
assert download_id not in manager._download_tasks
assert resume_restored.await_count == 0
assert pause_control.is_paused() is True
assert manager._active_downloads[download_id]["status"] == "paused"
@pytest.mark.asyncio
async def test_start_background_download_task_cleans_up_finished_restore_task():
manager = DownloadManager()
download_id = "download-1"
manager._pause_events[download_id] = DownloadStreamControl()
async def finished_restore():
return {"success": True}
task = manager._start_background_download_task(download_id, finished_restore())
await task
await asyncio.sleep(0)
assert download_id not in manager._download_tasks
assert download_id not in manager._pause_events
@pytest.mark.asyncio
async def test_cancel_download_still_cancels_local_task_when_aria2_raises(monkeypatch):
manager = DownloadManager()
@@ -836,6 +930,217 @@ async def test_cancel_download_still_cancels_local_task_when_aria2_raises(monkey
assert task.cancelled() or task.done()
@pytest.mark.asyncio
async def test_cancel_download_preserves_tracking_when_aria2_returns_error(monkeypatch, tmp_path):
manager = DownloadManager()
download_id = "download-queued"
save_path = tmp_path / "file.safetensors"
save_path.write_text("partial")
(tmp_path / "file.safetensors.aria2").write_text("control")
pause_control = DownloadStreamControl()
manager._pause_events[download_id] = pause_control
manager._download_tasks[download_id] = object()
manager._active_downloads[download_id] = {
"transfer_backend": "aria2",
"status": "downloading",
"file_path": str(save_path),
}
await manager._aria2_state_store.upsert(
download_id,
{
"download_id": download_id,
"transfer_backend": "aria2",
"status": "downloading",
"save_path": str(save_path),
"file_path": str(save_path),
},
)
cleanup_files = AsyncMock(return_value=None)
monkeypatch.setattr(manager, "_cleanup_cancelled_download_files", cleanup_files)
class DummyAria2Downloader:
async def cancel_download(self, _download_id):
return {"success": False, "error": "rpc unavailable"}
monkeypatch.setattr(
download_manager,
"get_aria2_downloader",
AsyncMock(return_value=DummyAria2Downloader()),
)
result = await manager.cancel_download(download_id)
assert result == {"success": False, "error": "rpc unavailable"}
assert download_id in manager._download_tasks
assert download_id in manager._pause_events
assert await manager._aria2_state_store.get(download_id) is not None
assert cleanup_files.await_count == 0
@pytest.mark.asyncio
async def test_cancel_download_rejects_completed_history_entry(tmp_path):
manager = DownloadManager()
download_id = "completed-download"
save_path = tmp_path / "file.safetensors"
metadata_path = tmp_path / "file.metadata.json"
preview_path = tmp_path / "file.jpeg"
save_path.write_text("complete")
metadata_path.write_text("{}")
preview_path.write_text("preview")
manager._active_downloads[download_id] = {
"transfer_backend": "aria2",
"status": "completed",
"file_path": str(save_path),
"preview_path": str(preview_path),
}
result = await manager.cancel_download(download_id)
assert result == {"success": False, "error": "Download task not found"}
assert save_path.exists()
assert metadata_path.exists()
assert preview_path.exists()
@pytest.mark.asyncio
async def test_cancel_download_removes_preview_and_aria2_control_files(monkeypatch, tmp_path):
manager = DownloadManager()
started = asyncio.Event()
async def blocked_task():
started.set()
await asyncio.sleep(60)
task = asyncio.create_task(blocked_task())
await started.wait()
save_path = tmp_path / "file.safetensors"
save_path.write_text("partial")
aria2_path = tmp_path / "file.safetensors.aria2"
aria2_path.write_text("control")
preview_path = tmp_path / "file.jpeg"
preview_path.write_text("preview")
download_id = "download-queued"
manager._download_tasks[download_id] = task
manager._active_downloads[download_id] = {
"transfer_backend": "aria2",
"status": "queued",
"file_path": str(save_path),
"aria2_control_path": str(aria2_path),
"preview_path": str(preview_path),
}
class DummyAria2Downloader:
async def cancel_download(self, _download_id):
return {"success": True, "message": "cancelled"}
monkeypatch.setattr(
download_manager,
"get_aria2_downloader",
AsyncMock(return_value=DummyAria2Downloader()),
)
result = await manager.cancel_download(download_id)
assert result["success"] is True
assert not save_path.exists()
assert not aria2_path.exists()
assert not preview_path.exists()
@pytest.mark.asyncio
async def test_cancel_download_does_not_delete_untracked_same_basename_preview(
monkeypatch, tmp_path
):
manager = DownloadManager()
started = asyncio.Event()
async def blocked_task():
started.set()
await asyncio.sleep(60)
task = asyncio.create_task(blocked_task())
await started.wait()
save_path = tmp_path / "file.safetensors"
save_path.write_text("partial")
aria2_path = tmp_path / "file.safetensors.aria2"
aria2_path.write_text("control")
manual_preview_path = tmp_path / "file.jpg"
manual_preview_path.write_text("manual")
download_id = "download-queued"
manager._download_tasks[download_id] = task
manager._active_downloads[download_id] = {
"transfer_backend": "aria2",
"status": "queued",
"file_path": str(save_path),
"aria2_control_path": str(aria2_path),
}
class DummyAria2Downloader:
async def cancel_download(self, _download_id):
return {"success": True, "message": "cancelled"}
monkeypatch.setattr(
download_manager,
"get_aria2_downloader",
AsyncMock(return_value=DummyAria2Downloader()),
)
result = await manager.cancel_download(download_id)
assert result["success"] is True
assert not save_path.exists()
assert not aria2_path.exists()
assert manual_preview_path.exists()
@pytest.mark.asyncio
async def test_cleanup_cancelled_download_files_retries_aria2_control_deletion(
monkeypatch, tmp_path
):
manager = DownloadManager()
download_id = "download-1"
save_path = tmp_path / "file.safetensors"
aria2_path = tmp_path / "file.safetensors.aria2"
save_path.write_text("partial")
aria2_path.write_text("control")
original_unlink = os.unlink
attempts = {"count": 0}
def flaky_unlink(path):
if path == str(aria2_path) and attempts["count"] == 0:
attempts["count"] += 1
raise PermissionError("still locked")
return original_unlink(path)
monkeypatch.setattr(download_manager.os, "unlink", flaky_unlink)
monkeypatch.setattr("py.services.download_manager.asyncio.sleep", AsyncMock())
await manager._cleanup_cancelled_download_files(
download_id,
{
"file_path": str(save_path),
"aria2_control_path": str(aria2_path),
"transfer_backend": "aria2",
},
)
assert attempts["count"] == 1
assert not save_path.exists()
assert not aria2_path.exists()
@pytest.mark.asyncio
async def test_execute_download_waits_for_paused_pre_transfer_gate(monkeypatch, tmp_path):
manager = DownloadManager()
@@ -931,6 +1236,311 @@ async def test_execute_download_waits_for_paused_pre_transfer_gate(monkeypatch,
assert result == {"success": True}
@pytest.mark.asyncio
async def test_execute_download_reuses_existing_aria2_partial_path(monkeypatch, tmp_path):
manager = DownloadManager()
save_dir = tmp_path / "downloads"
save_dir.mkdir()
target_path = save_dir / "file.safetensors"
target_path.write_text("partial")
control_path = save_dir / "file.safetensors.aria2"
control_path.write_text("control")
await manager._aria2_state_store.upsert(
"download-1",
{
"download_id": "download-1",
"transfer_backend": "aria2",
"save_path": str(target_path),
"file_path": str(target_path),
"status": "paused",
},
)
class DummyMetadata:
def __init__(self, path: Path):
self.file_path = str(path)
self.sha256 = "sha256"
self.file_name = path.stem
self.preview_url = None
def generate_unique_filename(self, *_args, **_kwargs):
return "renamed.safetensors"
def update_file_info(self, _path):
return None
def to_dict(self):
return {"file_path": self.file_path}
manager._active_downloads["download-1"] = {"transfer_backend": "aria2"}
dummy_scanner = SimpleNamespace(add_model_to_cache=AsyncMock(return_value=None))
monkeypatch.setattr(
DownloadManager, "_get_lora_scanner", AsyncMock(return_value=dummy_scanner)
)
monkeypatch.setattr(MetadataManager, "save_metadata", AsyncMock(return_value=True))
async def fake_download_model_file(
self,
download_url,
save_path,
*,
backend,
progress_callback,
use_auth,
download_id,
pause_control,
):
Path(save_path).write_text("content")
return True, save_path
monkeypatch.setattr(DownloadManager, "_download_model_file", fake_download_model_file)
result = await manager._execute_download(
download_urls=["https://example.com/file.safetensors"],
save_dir=str(save_dir),
metadata=DummyMetadata(target_path),
version_info={"images": []},
relative_path="",
progress_callback=None,
model_type="lora",
download_id="download-1",
transfer_backend="aria2",
)
assert result == {"success": True}
assert manager._active_downloads["download-1"]["file_path"] == str(target_path)
assert not (save_dir / "renamed.safetensors").exists()
@pytest.mark.asyncio
async def test_execute_download_rejects_conflicting_aria2_partial_path(tmp_path):
manager = DownloadManager()
save_dir = tmp_path / "downloads"
save_dir.mkdir()
target_path = save_dir / "file.safetensors"
target_path.write_text("partial")
(save_dir / "file.safetensors.aria2").write_text("control")
await manager._aria2_state_store.upsert(
"other-download",
{
"download_id": "other-download",
"transfer_backend": "aria2",
"save_path": str(target_path),
"file_path": str(target_path),
"status": "paused",
},
)
class DummyMetadata:
def __init__(self, path: Path):
self.file_path = str(path)
self.sha256 = "sha256"
self.file_name = path.stem
self.preview_url = None
def generate_unique_filename(self, *_args, **_kwargs):
raise AssertionError("should not rename")
result = await manager._execute_download(
download_urls=["https://example.com/file.safetensors"],
save_dir=str(save_dir),
metadata=DummyMetadata(target_path),
version_info={"images": []},
relative_path="",
progress_callback=None,
model_type="lora",
download_id="download-1",
transfer_backend="aria2",
)
assert result["success"] is False
assert "already using" in result["error"]
@pytest.mark.asyncio
async def test_execute_download_reassigns_same_aria2_partial_to_new_download_id(
monkeypatch, tmp_path
):
manager = DownloadManager()
save_dir = tmp_path / "downloads"
save_dir.mkdir()
target_path = save_dir / "file.safetensors"
target_path.write_text("partial")
(save_dir / "file.safetensors.aria2").write_text("control")
await manager._aria2_state_store.upsert(
"old-download",
{
"download_id": "old-download",
"transfer_backend": "aria2",
"save_path": str(target_path),
"file_path": str(target_path),
"status": "paused",
"model_id": 11,
"model_version_id": 22,
},
)
class DummyMetadata:
def __init__(self, path: Path):
self.file_path = str(path)
self.sha256 = "sha256"
self.file_name = path.stem
self.preview_url = None
def generate_unique_filename(self, *_args, **_kwargs):
raise AssertionError("should not rename")
def update_file_info(self, _path):
return None
def to_dict(self):
return {"file_path": self.file_path}
class DummyAria2Downloader:
def __init__(self):
self.calls = []
async def reassign_transfer(self, previous_download_id, new_download_id):
self.calls.append(("reassign_transfer", previous_download_id, new_download_id))
return None
dummy_aria2 = DummyAria2Downloader()
monkeypatch.setattr(
download_manager,
"get_aria2_downloader",
AsyncMock(return_value=dummy_aria2),
)
manager._active_downloads["old-download"] = {
"transfer_backend": "aria2",
"model_id": 11,
"model_version_id": 22,
"status": "paused",
}
manager._active_downloads["new-download"] = {
"transfer_backend": "aria2",
"model_id": 11,
"model_version_id": 22,
"status": "queued",
}
resolved, path = await manager._resolve_download_target_path(
str(save_dir),
DummyMetadata(target_path),
transfer_backend="aria2",
download_id="new-download",
)
assert resolved is True
assert path == str(target_path)
assert "old-download" not in manager._active_downloads
assert manager._active_downloads["new-download"]["file_path"] == str(target_path)
assert dummy_aria2.calls == [("reassign_transfer", "old-download", "new-download")]
assert await manager._aria2_state_store.get("old-download") is None
assert (await manager._aria2_state_store.get("new-download"))["save_path"] == str(
target_path
)
def test_is_same_aria2_download_request_requires_version_id_match():
manager = DownloadManager()
assert (
manager._is_same_aria2_download_request(
{"model_id": 1, "model_version_id": None},
{"model_id": 1, "model_version_id": 2},
)
is False
)
assert (
manager._is_same_aria2_download_request(
{"model_id": 1, "model_version_id": 3},
{"model_id": 1, "model_version_id": None},
)
is False
)
@pytest.mark.asyncio
async def test_adopt_existing_aria2_download_cancels_old_running_task(monkeypatch, tmp_path):
manager = DownloadManager()
save_path = tmp_path / "file.safetensors"
started = asyncio.Event()
cancelled = asyncio.Event()
call_order = []
async def old_download():
started.set()
try:
await asyncio.sleep(60)
except asyncio.CancelledError:
call_order.append("old-task-cancelled")
cancelled.set()
raise
old_task = asyncio.create_task(old_download())
await started.wait()
manager._download_tasks["old-download"] = old_task
old_pause_control = DownloadStreamControl()
old_pause_control.pause()
manager._pause_events["old-download"] = old_pause_control
manager._active_downloads["old-download"] = {
"transfer_backend": "aria2",
"model_id": 11,
"model_version_id": 22,
"status": "downloading",
}
manager._active_downloads["new-download"] = {
"transfer_backend": "aria2",
"model_id": 11,
"model_version_id": 22,
"status": "queued",
}
await manager._aria2_state_store.upsert(
"old-download",
{
"download_id": "old-download",
"transfer_backend": "aria2",
"save_path": str(save_path),
"file_path": str(save_path),
"status": "downloading",
"model_id": 11,
"model_version_id": 22,
},
)
class DummyAria2Downloader:
async def reassign_transfer(self, previous_download_id, new_download_id):
call_order.append("reassign-transfer")
return None
monkeypatch.setattr(
download_manager,
"get_aria2_downloader",
AsyncMock(return_value=DummyAria2Downloader()),
)
await manager._adopt_existing_aria2_download(
"old-download",
"new-download",
{"model_id": 11, "model_version_id": 22},
str(save_path),
)
assert cancelled.is_set() is True
assert "old-download" not in manager._download_tasks
assert call_order == ["reassign-transfer", "old-task-cancelled"]
@pytest.mark.asyncio
async def test_pause_download_rejects_unknown_task():
"""Test that pause_download rejects unknown download tasks."""