feat(download): add experimental aria2 backend

This commit is contained in:
Will Miao
2026-04-19 21:46:09 +08:00
parent 0ced53c059
commit 1c530ea013
21 changed files with 1867 additions and 28 deletions

View File

@@ -136,6 +136,190 @@ async def test_execute_download_retries_urls(monkeypatch, tmp_path):
assert dummy_scanner.calls # ensure cache updated
@pytest.mark.asyncio
async def test_execute_download_uses_aria2_backend_for_model_files(monkeypatch, tmp_path):
manager = DownloadManager()
settings = get_settings_manager()
settings.settings["download_backend"] = "aria2"
settings.settings["civitai_api_key"] = "secret-key"
save_dir = tmp_path / "downloads"
save_dir.mkdir()
target_path = save_dir / "file.safetensors"
class DummyMetadata:
def __init__(self, path: Path):
self.file_path = str(path)
self.sha256 = "sha256"
self.file_name = path.stem
self.preview_url = None
def generate_unique_filename(self, *_args, **_kwargs):
return os.path.basename(self.file_path)
def update_file_info(self, _path):
return None
def to_dict(self):
return {"file_path": self.file_path}
class DummyAria2Downloader:
def __init__(self):
self.calls = []
async def download_file(
self,
url,
save_path,
*,
download_id,
progress_callback=None,
headers=None,
):
self.calls.append(
{
"url": url,
"save_path": save_path,
"download_id": download_id,
"headers": headers,
}
)
Path(save_path).write_text("content")
return True, save_path
dummy_aria2 = DummyAria2Downloader()
monkeypatch.setattr(
download_manager,
"get_aria2_downloader",
AsyncMock(return_value=dummy_aria2),
)
monkeypatch.setattr(
download_manager,
"get_downloader",
AsyncMock(side_effect=AssertionError("python downloader should not be used")),
)
class DummyScanner:
async def add_model_to_cache(self, metadata_dict, relative_path):
return {"metadata": metadata_dict, "relative_path": relative_path}
dummy_scanner = DummyScanner()
monkeypatch.setattr(
DownloadManager, "_get_lora_scanner", AsyncMock(return_value=dummy_scanner)
)
monkeypatch.setattr(
DownloadManager,
"_get_checkpoint_scanner",
AsyncMock(return_value=dummy_scanner),
)
monkeypatch.setattr(
ServiceRegistry, "get_embedding_scanner", AsyncMock(return_value=dummy_scanner)
)
monkeypatch.setattr(MetadataManager, "save_metadata", AsyncMock(return_value=True))
result = await manager._execute_download(
download_urls=["https://civitai.com/api/download/models/1"],
save_dir=str(save_dir),
metadata=DummyMetadata(target_path),
version_info={"images": []},
relative_path="",
progress_callback=None,
model_type="lora",
download_id="download-1",
)
assert result == {"success": True}
assert dummy_aria2.calls == [
{
"url": "https://civitai.com/api/download/models/1",
"save_path": str(target_path),
"download_id": "download-1",
"headers": {"Authorization": "Bearer secret-key"},
}
]
@pytest.mark.asyncio
async def test_execute_download_allows_anonymous_civitai_with_aria2(
monkeypatch, tmp_path
):
manager = DownloadManager()
settings = get_settings_manager()
settings.settings["download_backend"] = "aria2"
settings.settings["civitai_api_key"] = ""
save_dir = tmp_path / "downloads"
save_dir.mkdir()
target_path = save_dir / "file.safetensors"
class DummyMetadata:
def __init__(self, path: Path):
self.file_path = str(path)
self.sha256 = "sha256"
self.file_name = path.stem
self.preview_url = None
def generate_unique_filename(self, *_args, **_kwargs):
return os.path.basename(self.file_path)
def update_file_info(self, _path):
return None
def to_dict(self):
return {"file_path": self.file_path}
class DummyAria2Downloader:
def __init__(self):
self.calls = []
async def download_file(
self,
url,
save_path,
*,
download_id,
progress_callback=None,
headers=None,
):
self.calls.append({"url": url, "headers": headers, "download_id": download_id})
Path(save_path).write_text("content")
return True, save_path
dummy_aria2 = DummyAria2Downloader()
monkeypatch.setattr(
download_manager,
"get_aria2_downloader",
AsyncMock(return_value=dummy_aria2),
)
dummy_scanner = SimpleNamespace(add_model_to_cache=AsyncMock(return_value=None))
monkeypatch.setattr(
DownloadManager, "_get_lora_scanner", AsyncMock(return_value=dummy_scanner)
)
monkeypatch.setattr(MetadataManager, "save_metadata", AsyncMock(return_value=True))
result = await manager._execute_download(
download_urls=["https://civitai.com/api/download/models/1"],
save_dir=str(save_dir),
metadata=DummyMetadata(target_path),
version_info={"images": []},
relative_path="",
progress_callback=None,
model_type="lora",
download_id="download-2",
)
assert result == {"success": True}
assert dummy_aria2.calls == [
{
"url": "https://civitai.com/api/download/models/1",
"headers": None,
"download_id": "download-2",
}
]
@pytest.mark.asyncio
async def test_execute_download_adjusts_checkpoint_sub_type(monkeypatch, tmp_path):
"""Test that checkpoint sub_type is adjusted during download."""
@@ -276,6 +460,13 @@ async def test_execute_download_extracts_zip_single_model(monkeypatch, tmp_path)
monkeypatch.setattr(
download_manager, "get_downloader", AsyncMock(return_value=DummyDownloader())
)
class ImmediateLoop:
async def run_in_executor(self, executor, func, *args):
return func(*args)
monkeypatch.setattr(download_manager.asyncio, "get_running_loop", lambda: ImmediateLoop())
dummy_scanner = SimpleNamespace(add_model_to_cache=AsyncMock(return_value=None))
monkeypatch.setattr(
DownloadManager, "_get_lora_scanner", AsyncMock(return_value=dummy_scanner)
@@ -344,6 +535,13 @@ async def test_execute_download_extracts_zip_multiple_models(monkeypatch, tmp_pa
monkeypatch.setattr(
download_manager, "get_downloader", AsyncMock(return_value=DummyDownloader())
)
class ImmediateLoop:
async def run_in_executor(self, executor, func, *args):
return func(*args)
monkeypatch.setattr(download_manager.asyncio, "get_running_loop", lambda: ImmediateLoop())
dummy_scanner = SimpleNamespace(add_model_to_cache=AsyncMock(return_value=None))
monkeypatch.setattr(
DownloadManager, "_get_lora_scanner", AsyncMock(return_value=dummy_scanner)
@@ -418,6 +616,13 @@ async def test_execute_download_extracts_zip_pt_embedding(monkeypatch, tmp_path)
monkeypatch.setattr(
download_manager, "get_downloader", AsyncMock(return_value=DummyDownloader())
)
class ImmediateLoop:
async def run_in_executor(self, executor, func, *args):
return func(*args)
monkeypatch.setattr(download_manager.asyncio, "get_running_loop", lambda: ImmediateLoop())
dummy_scanner = SimpleNamespace(add_model_to_cache=AsyncMock(return_value=None))
monkeypatch.setattr(
ServiceRegistry, "get_embedding_scanner", AsyncMock(return_value=dummy_scanner)
@@ -446,6 +651,36 @@ async def test_execute_download_extracts_zip_pt_embedding(monkeypatch, tmp_path)
assert dummy_scanner.add_model_to_cache.await_count == 1
@pytest.mark.asyncio
async def test_extract_model_files_from_archive_uses_executor(monkeypatch, tmp_path):
manager = DownloadManager()
archive_path = tmp_path / "bundle.zip"
with zipfile.ZipFile(archive_path, "w") as archive:
archive.writestr("inner/model.safetensors", b"model")
captured = {}
class ImmediateLoop:
async def run_in_executor(self, executor, func, *args):
captured["executor"] = executor
return func(*args)
monkeypatch.setattr(
download_manager.asyncio,
"get_running_loop",
lambda: ImmediateLoop(),
)
extracted = await manager._extract_model_files_from_archive(
str(archive_path),
{".safetensors"},
)
assert captured["executor"] is manager._archive_executor
assert len(extracted) == 1
assert extracted[0].endswith("model.safetensors")
@pytest.mark.asyncio
async def test_pause_download_updates_state():
"""Test that pause_download updates download state correctly."""
@@ -469,6 +704,233 @@ async def test_pause_download_updates_state():
assert manager._active_downloads[download_id]["bytes_per_second"] == 0.0
@pytest.mark.asyncio
async def test_pause_download_reverts_local_pause_when_aria2_pause_fails(monkeypatch):
manager = DownloadManager()
download_id = "dl"
manager._download_tasks[download_id] = object()
pause_control = DownloadStreamControl()
manager._pause_events[download_id] = pause_control
manager._active_downloads[download_id] = {
"transfer_backend": "aria2",
"status": "downloading",
"bytes_per_second": 42.0,
}
class DummyAria2Downloader:
async def has_transfer(self, _download_id):
return True
async def pause_download(self, _download_id):
return {"success": False, "error": "rpc failed"}
monkeypatch.setattr(
download_manager,
"get_aria2_downloader",
AsyncMock(return_value=DummyAria2Downloader()),
)
result = await manager.pause_download(download_id)
assert result == {"success": False, "error": "rpc failed"}
assert pause_control.is_set() is True
assert manager._active_downloads[download_id]["status"] == "downloading"
@pytest.mark.asyncio
async def test_pause_download_reverts_local_pause_when_aria2_probe_raises(monkeypatch):
manager = DownloadManager()
download_id = "dl"
manager._download_tasks[download_id] = object()
pause_control = DownloadStreamControl()
manager._pause_events[download_id] = pause_control
manager._active_downloads[download_id] = {
"transfer_backend": "aria2",
"status": "downloading",
"bytes_per_second": 42.0,
}
class DummyAria2Downloader:
async def has_transfer(self, _download_id):
raise RuntimeError("rpc unavailable")
monkeypatch.setattr(
download_manager,
"get_aria2_downloader",
AsyncMock(return_value=DummyAria2Downloader()),
)
result = await manager.pause_download(download_id)
assert result == {"success": False, "error": "rpc unavailable"}
assert pause_control.is_set() is True
assert manager._active_downloads[download_id]["status"] == "downloading"
@pytest.mark.asyncio
async def test_resume_download_returns_error_when_aria2_probe_raises(monkeypatch):
manager = DownloadManager()
download_id = "dl"
pause_control = DownloadStreamControl()
pause_control.pause()
manager._pause_events[download_id] = pause_control
manager._active_downloads[download_id] = {
"transfer_backend": "aria2",
"status": "paused",
"bytes_per_second": 0.0,
}
class DummyAria2Downloader:
async def has_transfer(self, _download_id):
raise RuntimeError("rpc unavailable")
monkeypatch.setattr(
download_manager,
"get_aria2_downloader",
AsyncMock(return_value=DummyAria2Downloader()),
)
result = await manager.resume_download(download_id)
assert result == {"success": False, "error": "rpc unavailable"}
assert pause_control.is_paused() is True
assert manager._active_downloads[download_id]["status"] == "paused"
@pytest.mark.asyncio
async def test_cancel_download_still_cancels_local_task_when_aria2_raises(monkeypatch):
manager = DownloadManager()
started = asyncio.Event()
async def blocked_task():
started.set()
await asyncio.sleep(60)
task = asyncio.create_task(blocked_task())
await started.wait()
download_id = "download-queued"
manager._download_tasks[download_id] = task
manager._active_downloads[download_id] = {
"transfer_backend": "aria2",
"status": "queued",
}
class DummyAria2Downloader:
async def cancel_download(self, _download_id):
raise RuntimeError("rpc unavailable")
monkeypatch.setattr(
download_manager,
"get_aria2_downloader",
AsyncMock(return_value=DummyAria2Downloader()),
)
result = await manager.cancel_download(download_id)
assert result["success"] is True
assert task.cancelled() or task.done()
@pytest.mark.asyncio
async def test_execute_download_waits_for_paused_pre_transfer_gate(monkeypatch, tmp_path):
manager = DownloadManager()
save_dir = tmp_path / "downloads"
save_dir.mkdir()
target_path = save_dir / "file.safetensors"
class DummyMetadata:
def __init__(self, path: Path):
self.file_path = str(path)
self.sha256 = "sha256"
self.file_name = path.stem
self.preview_url = None
def generate_unique_filename(self, *_args, **_kwargs):
return os.path.basename(self.file_path)
def update_file_info(self, _path):
return None
def to_dict(self):
return {"file_path": self.file_path}
pause_control = DownloadStreamControl()
pause_control.pause()
manager._pause_events["download-1"] = pause_control
manager._active_downloads["download-1"] = {
"status": "downloading",
"bytes_per_second": 42.0,
}
dummy_scanner = SimpleNamespace(add_model_to_cache=AsyncMock(return_value=None))
monkeypatch.setattr(
DownloadManager, "_get_lora_scanner", AsyncMock(return_value=dummy_scanner)
)
monkeypatch.setattr(MetadataManager, "save_metadata", AsyncMock(return_value=True))
started = asyncio.Event()
allow_finish = asyncio.Event()
captured = {"calls": 0}
async def fake_download_model_file(
self,
download_url,
save_path,
*,
backend,
progress_callback,
use_auth,
download_id,
pause_control,
):
captured["calls"] += 1
started.set()
await allow_finish.wait()
Path(save_path).write_text("content")
return True, save_path
monkeypatch.setattr(
DownloadManager,
"_download_model_file",
fake_download_model_file,
)
task = asyncio.create_task(
manager._execute_download(
download_urls=["https://civitai.com/api/download/models/1"],
save_dir=str(save_dir),
metadata=DummyMetadata(target_path),
version_info={"images": []},
relative_path="",
progress_callback=None,
model_type="lora",
download_id="download-1",
transfer_backend="aria2",
)
)
await asyncio.sleep(0)
assert started.is_set() is False
assert captured["calls"] == 0
assert manager._active_downloads["download-1"]["status"] == "paused"
pause_control.resume()
await asyncio.wait_for(started.wait(), timeout=1.0)
assert captured["calls"] == 1
assert manager._active_downloads["download-1"]["status"] == "downloading"
allow_finish.set()
result = await task
assert result == {"success": True}
@pytest.mark.asyncio
async def test_pause_download_rejects_unknown_task():
"""Test that pause_download rejects unknown download tasks."""