mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-05-06 08:26:45 -03:00
feat(download): add experimental aria2 backend
This commit is contained in:
@@ -305,4 +305,39 @@ describe('SettingsManager library controls', () => {
|
||||
'success',
|
||||
);
|
||||
});
|
||||
|
||||
it('loads download backend settings and toggles the aria2 path field', () => {
|
||||
const manager = createManager();
|
||||
document.body.innerHTML = `
|
||||
<select id="downloadBackend">
|
||||
<option value="python">Python</option>
|
||||
<option value="aria2">aria2</option>
|
||||
</select>
|
||||
<div id="aria2PathSetting" style="display: none;"></div>
|
||||
<input id="aria2cPath" />
|
||||
`;
|
||||
|
||||
state.global.settings = {
|
||||
download_backend: 'aria2',
|
||||
aria2c_path: '/usr/bin/aria2c',
|
||||
};
|
||||
|
||||
const saveSpy = vi.spyOn(manager, 'saveSelectSetting').mockResolvedValue();
|
||||
|
||||
manager.loadDownloadBackendSettings();
|
||||
|
||||
const backendSelect = document.getElementById('downloadBackend');
|
||||
const aria2PathSetting = document.getElementById('aria2PathSetting');
|
||||
const aria2cPath = document.getElementById('aria2cPath');
|
||||
|
||||
expect(backendSelect.value).toBe('aria2');
|
||||
expect(aria2cPath.value).toBe('/usr/bin/aria2c');
|
||||
expect(aria2PathSetting.style.display).toBe('block');
|
||||
|
||||
backendSelect.value = 'python';
|
||||
backendSelect.onchange();
|
||||
|
||||
expect(aria2PathSetting.style.display).toBe('none');
|
||||
expect(saveSpy).toHaveBeenCalledWith('downloadBackend', 'download_backend');
|
||||
});
|
||||
});
|
||||
|
||||
269
tests/services/test_aria2_downloader.py
Normal file
269
tests/services/test_aria2_downloader.py
Normal file
@@ -0,0 +1,269 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
|
||||
from py.services.aria2_downloader import Aria2Downloader, Aria2Error
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_download_file_polls_until_complete(tmp_path, monkeypatch):
|
||||
downloader = Aria2Downloader()
|
||||
downloader._rpc_url = "http://127.0.0.1/jsonrpc"
|
||||
downloader._rpc_secret = "secret"
|
||||
|
||||
save_path = tmp_path / "downloads" / "model.safetensors"
|
||||
progress_events = []
|
||||
rpc_calls = []
|
||||
statuses = iter(
|
||||
[
|
||||
{
|
||||
"gid": "gid-1",
|
||||
"status": "active",
|
||||
"completedLength": "5",
|
||||
"totalLength": "10",
|
||||
"downloadSpeed": "25",
|
||||
},
|
||||
{
|
||||
"gid": "gid-1",
|
||||
"status": "complete",
|
||||
"completedLength": "10",
|
||||
"totalLength": "10",
|
||||
"downloadSpeed": "0",
|
||||
"files": [{"path": str(save_path)}],
|
||||
},
|
||||
]
|
||||
)
|
||||
|
||||
async def fake_rpc_call(method, params):
|
||||
rpc_calls.append((method, params))
|
||||
if method == "aria2.addUri":
|
||||
return "gid-1"
|
||||
if method == "aria2.tellStatus":
|
||||
return next(statuses)
|
||||
raise AssertionError(f"Unexpected RPC method: {method}")
|
||||
|
||||
monkeypatch.setattr(downloader, "_ensure_process", AsyncMock())
|
||||
monkeypatch.setattr(
|
||||
downloader,
|
||||
"_resolve_authenticated_redirect_url",
|
||||
AsyncMock(
|
||||
return_value="https://signed.example.com/model.safetensors?token=abc"
|
||||
),
|
||||
)
|
||||
monkeypatch.setattr(downloader, "_rpc_call", fake_rpc_call)
|
||||
monkeypatch.setattr("py.services.aria2_downloader.asyncio.sleep", AsyncMock())
|
||||
|
||||
async def progress_callback(progress, snapshot=None):
|
||||
progress_events.append(snapshot.percent_complete if snapshot else progress)
|
||||
|
||||
success, result = await downloader.download_file(
|
||||
"https://civitai.com/api/download/models/123",
|
||||
str(save_path),
|
||||
download_id="download-1",
|
||||
progress_callback=progress_callback,
|
||||
headers={"Authorization": "Bearer token"},
|
||||
)
|
||||
|
||||
assert success is True
|
||||
assert result == str(save_path)
|
||||
assert progress_events == [50.0, 100.0]
|
||||
assert downloader._transfers == {}
|
||||
assert rpc_calls[0][0] == "aria2.addUri"
|
||||
assert rpc_calls[0][1][0] == [
|
||||
"https://signed.example.com/model.safetensors?token=abc"
|
||||
]
|
||||
assert rpc_calls[0][1][1]["out"] == "model.safetensors"
|
||||
assert "header" not in rpc_calls[0][1][1]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_download_file_keeps_auth_headers_when_civitai_does_not_redirect(
|
||||
tmp_path, monkeypatch
|
||||
):
|
||||
downloader = Aria2Downloader()
|
||||
downloader._rpc_url = "http://127.0.0.1/jsonrpc"
|
||||
downloader._rpc_secret = "secret"
|
||||
|
||||
save_path = tmp_path / "downloads" / "model.safetensors"
|
||||
rpc_calls = []
|
||||
statuses = iter(
|
||||
[
|
||||
{
|
||||
"gid": "gid-1",
|
||||
"status": "complete",
|
||||
"completedLength": "10",
|
||||
"totalLength": "10",
|
||||
"downloadSpeed": "0",
|
||||
"files": [{"path": str(save_path)}],
|
||||
},
|
||||
]
|
||||
)
|
||||
|
||||
async def fake_rpc_call(method, params):
|
||||
rpc_calls.append((method, params))
|
||||
if method == "aria2.addUri":
|
||||
return "gid-1"
|
||||
if method == "aria2.tellStatus":
|
||||
return next(statuses)
|
||||
raise AssertionError(f"Unexpected RPC method: {method}")
|
||||
|
||||
monkeypatch.setattr(downloader, "_ensure_process", AsyncMock())
|
||||
monkeypatch.setattr(
|
||||
downloader,
|
||||
"_resolve_authenticated_redirect_url",
|
||||
AsyncMock(return_value="https://civitai.com/api/download/models/123"),
|
||||
)
|
||||
monkeypatch.setattr(downloader, "_rpc_call", fake_rpc_call)
|
||||
monkeypatch.setattr("py.services.aria2_downloader.asyncio.sleep", AsyncMock())
|
||||
|
||||
success, result = await downloader.download_file(
|
||||
"https://civitai.com/api/download/models/123",
|
||||
str(save_path),
|
||||
download_id="download-1",
|
||||
headers={"Authorization": "Bearer token"},
|
||||
)
|
||||
|
||||
assert success is True
|
||||
assert result == str(save_path)
|
||||
assert rpc_calls[0][1][0] == ["https://civitai.com/api/download/models/123"]
|
||||
assert rpc_calls[0][1][1]["header"] == ["Authorization: Bearer token"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pause_resume_cancel_forward_to_rpc(monkeypatch):
|
||||
downloader = Aria2Downloader()
|
||||
downloader._transfers["download-1"] = type(
|
||||
"Transfer", (), {"gid": "gid-1", "save_path": "/tmp/model.safetensors"}
|
||||
)()
|
||||
|
||||
calls = []
|
||||
|
||||
async def fake_rpc_call(method, params):
|
||||
calls.append((method, params))
|
||||
return "gid-1"
|
||||
|
||||
monkeypatch.setattr(downloader, "_rpc_call", fake_rpc_call)
|
||||
|
||||
pause_result = await downloader.pause_download("download-1")
|
||||
resume_result = await downloader.resume_download("download-1")
|
||||
cancel_result = await downloader.cancel_download("download-1")
|
||||
|
||||
assert pause_result["success"] is True
|
||||
assert resume_result["success"] is True
|
||||
assert cancel_result["success"] is True
|
||||
assert calls == [
|
||||
("aria2.forcePause", ["gid-1"]),
|
||||
("aria2.unpause", ["gid-1"]),
|
||||
("aria2.forceRemove", ["gid-1"]),
|
||||
]
|
||||
|
||||
|
||||
def test_build_progress_snapshot_normalizes_numeric_fields():
|
||||
downloader = Aria2Downloader()
|
||||
|
||||
snapshot = downloader._build_progress_snapshot(
|
||||
{
|
||||
"completedLength": "75",
|
||||
"totalLength": "100",
|
||||
"downloadSpeed": "512",
|
||||
}
|
||||
)
|
||||
|
||||
assert snapshot.percent_complete == 75.0
|
||||
assert snapshot.bytes_downloaded == 75
|
||||
assert snapshot.total_bytes == 100
|
||||
assert snapshot.bytes_per_second == 512.0
|
||||
|
||||
|
||||
def test_resolve_executable_raises_when_binary_missing(monkeypatch):
|
||||
downloader = Aria2Downloader()
|
||||
settings = type("Settings", (), {"get": lambda self, key, default=None: ""})()
|
||||
|
||||
monkeypatch.setattr("py.services.aria2_downloader.get_settings_manager", lambda: settings)
|
||||
monkeypatch.setattr("py.services.aria2_downloader.shutil.which", lambda _: None)
|
||||
|
||||
with pytest.raises(Aria2Error):
|
||||
downloader._resolve_executable()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rpc_call_surfaces_json_error_on_non_200(monkeypatch):
|
||||
downloader = Aria2Downloader()
|
||||
downloader._rpc_url = "http://127.0.0.1:6800/jsonrpc"
|
||||
downloader._rpc_secret = "secret"
|
||||
|
||||
class FakeResponse:
|
||||
status = 400
|
||||
|
||||
async def text(self):
|
||||
return (
|
||||
'{"jsonrpc":"2.0","id":"x","error":{"code":1,"message":"Unauthorized"}}'
|
||||
)
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
class FakeSession:
|
||||
def post(self, _url, json=None):
|
||||
return FakeResponse()
|
||||
|
||||
monkeypatch.setattr(downloader, "_get_rpc_session", AsyncMock(return_value=FakeSession()))
|
||||
|
||||
with pytest.raises(Aria2Error) as exc_info:
|
||||
await downloader._rpc_call("aria2.addUri", [["https://example.com/file"]])
|
||||
|
||||
assert "Unauthorized" in str(exc_info.value)
|
||||
assert "aria2.addUri" in str(exc_info.value)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resolve_authenticated_redirect_url_returns_location(monkeypatch):
|
||||
downloader = Aria2Downloader()
|
||||
|
||||
class FakeResponse:
|
||||
status = 307
|
||||
headers = {"Location": "https://signed.example.com/file.safetensors"}
|
||||
|
||||
async def text(self):
|
||||
return ""
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
class FakeSession:
|
||||
def get(self, _url, headers=None, allow_redirects=False, proxy=None):
|
||||
return FakeResponse()
|
||||
|
||||
class FakeDownloader:
|
||||
default_headers = {"User-Agent": "ComfyUI-LoRA-Manager/1.0"}
|
||||
proxy_url = None
|
||||
|
||||
@property
|
||||
def session(self):
|
||||
async def _session():
|
||||
return FakeSession()
|
||||
|
||||
return _session()
|
||||
|
||||
fake_downloader = FakeDownloader()
|
||||
|
||||
monkeypatch.setattr(
|
||||
"py.services.aria2_downloader.get_downloader",
|
||||
AsyncMock(return_value=fake_downloader),
|
||||
)
|
||||
|
||||
result = await downloader._resolve_authenticated_redirect_url(
|
||||
"https://civitai.com/api/download/models/123",
|
||||
{"Authorization": "Bearer token"},
|
||||
)
|
||||
|
||||
assert result == "https://signed.example.com/file.safetensors"
|
||||
@@ -179,6 +179,7 @@ async def test_successful_download_uses_defaults(
|
||||
progress_callback,
|
||||
model_type,
|
||||
download_id,
|
||||
transfer_backend=None,
|
||||
):
|
||||
captured.update(
|
||||
{
|
||||
@@ -268,6 +269,7 @@ async def test_download_uses_active_mirrors(
|
||||
progress_callback,
|
||||
model_type,
|
||||
download_id,
|
||||
transfer_backend=None,
|
||||
):
|
||||
captured["download_urls"] = download_urls
|
||||
return {"success": True}
|
||||
@@ -288,6 +290,214 @@ async def test_download_uses_active_mirrors(
|
||||
assert captured["download_urls"] == ["https://mirror.example/file.safetensors"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pause_resume_cancel_delegate_to_aria2_backend(monkeypatch):
|
||||
manager = DownloadManager()
|
||||
|
||||
task = asyncio.create_task(asyncio.sleep(60))
|
||||
manager._download_tasks["download-1"] = task
|
||||
manager._pause_events["download-1"] = download_manager.DownloadStreamControl()
|
||||
manager._active_downloads["download-1"] = {
|
||||
"transfer_backend": "aria2",
|
||||
"status": "downloading",
|
||||
}
|
||||
|
||||
class DummyAria2Downloader:
|
||||
def __init__(self):
|
||||
self.calls = []
|
||||
|
||||
async def pause_download(self, download_id):
|
||||
self.calls.append(("pause", download_id))
|
||||
return {"success": True, "message": "paused"}
|
||||
|
||||
async def resume_download(self, download_id):
|
||||
self.calls.append(("resume", download_id))
|
||||
return {"success": True, "message": "resumed"}
|
||||
|
||||
async def cancel_download(self, download_id):
|
||||
self.calls.append(("cancel", download_id))
|
||||
return {"success": True, "message": "cancelled"}
|
||||
|
||||
async def has_transfer(self, download_id):
|
||||
self.calls.append(("has_transfer", download_id))
|
||||
return True
|
||||
|
||||
dummy_aria2 = DummyAria2Downloader()
|
||||
monkeypatch.setattr(
|
||||
download_manager,
|
||||
"get_aria2_downloader",
|
||||
AsyncMock(return_value=dummy_aria2),
|
||||
)
|
||||
|
||||
pause_result = await manager.pause_download("download-1")
|
||||
assert pause_result["success"] is True
|
||||
assert manager._active_downloads["download-1"]["status"] == "paused"
|
||||
|
||||
resume_result = await manager.resume_download("download-1")
|
||||
assert resume_result["success"] is True
|
||||
assert manager._active_downloads["download-1"]["status"] == "downloading"
|
||||
|
||||
cancel_result = await manager.cancel_download("download-1")
|
||||
assert cancel_result["success"] is True
|
||||
assert task.cancelled() or task.done()
|
||||
assert dummy_aria2.calls == [
|
||||
("has_transfer", "download-1"),
|
||||
("pause", "download-1"),
|
||||
("has_transfer", "download-1"),
|
||||
("resume", "download-1"),
|
||||
("cancel", "download-1"),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cancel_allows_queued_aria2_task_without_transfer(monkeypatch):
|
||||
manager = DownloadManager()
|
||||
|
||||
started = asyncio.Event()
|
||||
|
||||
async def blocked_task():
|
||||
started.set()
|
||||
await asyncio.sleep(60)
|
||||
|
||||
task = asyncio.create_task(blocked_task())
|
||||
await started.wait()
|
||||
|
||||
manager._download_tasks["download-queued"] = task
|
||||
manager._active_downloads["download-queued"] = {
|
||||
"transfer_backend": "aria2",
|
||||
"status": "queued",
|
||||
}
|
||||
|
||||
class DummyAria2Downloader:
|
||||
async def cancel_download(self, download_id):
|
||||
return {"success": False, "error": "Download task not found"}
|
||||
|
||||
monkeypatch.setattr(
|
||||
download_manager,
|
||||
"get_aria2_downloader",
|
||||
AsyncMock(return_value=DummyAria2Downloader()),
|
||||
)
|
||||
|
||||
result = await manager.cancel_download("download-queued")
|
||||
|
||||
assert result["success"] is True
|
||||
assert task.cancelled() or task.done()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pause_resume_queued_aria2_task_without_transfer(monkeypatch):
|
||||
manager = DownloadManager()
|
||||
|
||||
task = asyncio.create_task(asyncio.sleep(60))
|
||||
manager._download_tasks["download-queued"] = task
|
||||
manager._pause_events["download-queued"] = download_manager.DownloadStreamControl()
|
||||
manager._active_downloads["download-queued"] = {
|
||||
"transfer_backend": "aria2",
|
||||
"status": "waiting",
|
||||
"bytes_per_second": 12.0,
|
||||
}
|
||||
|
||||
class DummyAria2Downloader:
|
||||
def __init__(self):
|
||||
self.calls = []
|
||||
|
||||
async def has_transfer(self, download_id):
|
||||
self.calls.append(("has_transfer", download_id))
|
||||
return False
|
||||
|
||||
async def pause_download(self, download_id):
|
||||
self.calls.append(("pause", download_id))
|
||||
return {"success": True, "message": "paused"}
|
||||
|
||||
async def resume_download(self, download_id):
|
||||
self.calls.append(("resume", download_id))
|
||||
return {"success": True, "message": "resumed"}
|
||||
|
||||
dummy_aria2 = DummyAria2Downloader()
|
||||
monkeypatch.setattr(
|
||||
download_manager,
|
||||
"get_aria2_downloader",
|
||||
AsyncMock(return_value=dummy_aria2),
|
||||
)
|
||||
|
||||
pause_result = await manager.pause_download("download-queued")
|
||||
assert pause_result == {"success": True, "message": "Download paused successfully"}
|
||||
assert manager._active_downloads["download-queued"]["status"] == "paused"
|
||||
assert manager._pause_events["download-queued"].is_paused() is True
|
||||
|
||||
resume_result = await manager.resume_download("download-queued")
|
||||
assert resume_result == {"success": True, "message": "Download resumed successfully"}
|
||||
assert manager._active_downloads["download-queued"]["status"] == "downloading"
|
||||
assert manager._pause_events["download-queued"].is_set() is True
|
||||
assert dummy_aria2.calls == [
|
||||
("has_transfer", "download-queued"),
|
||||
("has_transfer", "download-queued"),
|
||||
]
|
||||
|
||||
task.cancel()
|
||||
with pytest.raises(asyncio.CancelledError):
|
||||
await task
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_download_uses_captured_backend_when_settings_change(
|
||||
monkeypatch, scanners, metadata_provider, tmp_path
|
||||
):
|
||||
manager = DownloadManager()
|
||||
settings = get_settings_manager()
|
||||
settings.settings["download_backend"] = "aria2"
|
||||
|
||||
semaphore = asyncio.Semaphore(0)
|
||||
manager._download_semaphore = semaphore
|
||||
|
||||
captured = {}
|
||||
|
||||
async def fake_execute_original_download(
|
||||
self,
|
||||
model_id,
|
||||
model_version_id,
|
||||
save_dir,
|
||||
relative_path,
|
||||
progress_callback,
|
||||
use_default_paths,
|
||||
download_id=None,
|
||||
transfer_backend="python",
|
||||
source=None,
|
||||
file_params=None,
|
||||
):
|
||||
captured["transfer_backend"] = transfer_backend
|
||||
return {"success": True}
|
||||
|
||||
monkeypatch.setattr(
|
||||
DownloadManager,
|
||||
"_execute_original_download",
|
||||
fake_execute_original_download,
|
||||
)
|
||||
|
||||
download_task = asyncio.create_task(
|
||||
manager.download_from_civitai(
|
||||
model_version_id=99,
|
||||
save_dir=str(tmp_path),
|
||||
use_default_paths=True,
|
||||
progress_callback=None,
|
||||
source=None,
|
||||
)
|
||||
)
|
||||
|
||||
await asyncio.sleep(0)
|
||||
assert len(manager._active_downloads) == 1
|
||||
download_id = next(iter(manager._active_downloads))
|
||||
assert manager._active_downloads[download_id]["transfer_backend"] == "aria2"
|
||||
|
||||
settings.settings["download_backend"] = "python"
|
||||
semaphore.release()
|
||||
|
||||
result = await download_task
|
||||
|
||||
assert result["success"] is True
|
||||
assert captured["transfer_backend"] == "aria2"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_download_aborts_when_version_exists(
|
||||
monkeypatch, scanners, metadata_provider
|
||||
|
||||
@@ -136,6 +136,190 @@ async def test_execute_download_retries_urls(monkeypatch, tmp_path):
|
||||
assert dummy_scanner.calls # ensure cache updated
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_download_uses_aria2_backend_for_model_files(monkeypatch, tmp_path):
|
||||
manager = DownloadManager()
|
||||
settings = get_settings_manager()
|
||||
settings.settings["download_backend"] = "aria2"
|
||||
settings.settings["civitai_api_key"] = "secret-key"
|
||||
|
||||
save_dir = tmp_path / "downloads"
|
||||
save_dir.mkdir()
|
||||
target_path = save_dir / "file.safetensors"
|
||||
|
||||
class DummyMetadata:
|
||||
def __init__(self, path: Path):
|
||||
self.file_path = str(path)
|
||||
self.sha256 = "sha256"
|
||||
self.file_name = path.stem
|
||||
self.preview_url = None
|
||||
|
||||
def generate_unique_filename(self, *_args, **_kwargs):
|
||||
return os.path.basename(self.file_path)
|
||||
|
||||
def update_file_info(self, _path):
|
||||
return None
|
||||
|
||||
def to_dict(self):
|
||||
return {"file_path": self.file_path}
|
||||
|
||||
class DummyAria2Downloader:
|
||||
def __init__(self):
|
||||
self.calls = []
|
||||
|
||||
async def download_file(
|
||||
self,
|
||||
url,
|
||||
save_path,
|
||||
*,
|
||||
download_id,
|
||||
progress_callback=None,
|
||||
headers=None,
|
||||
):
|
||||
self.calls.append(
|
||||
{
|
||||
"url": url,
|
||||
"save_path": save_path,
|
||||
"download_id": download_id,
|
||||
"headers": headers,
|
||||
}
|
||||
)
|
||||
Path(save_path).write_text("content")
|
||||
return True, save_path
|
||||
|
||||
dummy_aria2 = DummyAria2Downloader()
|
||||
|
||||
monkeypatch.setattr(
|
||||
download_manager,
|
||||
"get_aria2_downloader",
|
||||
AsyncMock(return_value=dummy_aria2),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
download_manager,
|
||||
"get_downloader",
|
||||
AsyncMock(side_effect=AssertionError("python downloader should not be used")),
|
||||
)
|
||||
|
||||
class DummyScanner:
|
||||
async def add_model_to_cache(self, metadata_dict, relative_path):
|
||||
return {"metadata": metadata_dict, "relative_path": relative_path}
|
||||
|
||||
dummy_scanner = DummyScanner()
|
||||
monkeypatch.setattr(
|
||||
DownloadManager, "_get_lora_scanner", AsyncMock(return_value=dummy_scanner)
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
DownloadManager,
|
||||
"_get_checkpoint_scanner",
|
||||
AsyncMock(return_value=dummy_scanner),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
ServiceRegistry, "get_embedding_scanner", AsyncMock(return_value=dummy_scanner)
|
||||
)
|
||||
monkeypatch.setattr(MetadataManager, "save_metadata", AsyncMock(return_value=True))
|
||||
|
||||
result = await manager._execute_download(
|
||||
download_urls=["https://civitai.com/api/download/models/1"],
|
||||
save_dir=str(save_dir),
|
||||
metadata=DummyMetadata(target_path),
|
||||
version_info={"images": []},
|
||||
relative_path="",
|
||||
progress_callback=None,
|
||||
model_type="lora",
|
||||
download_id="download-1",
|
||||
)
|
||||
|
||||
assert result == {"success": True}
|
||||
assert dummy_aria2.calls == [
|
||||
{
|
||||
"url": "https://civitai.com/api/download/models/1",
|
||||
"save_path": str(target_path),
|
||||
"download_id": "download-1",
|
||||
"headers": {"Authorization": "Bearer secret-key"},
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_download_allows_anonymous_civitai_with_aria2(
|
||||
monkeypatch, tmp_path
|
||||
):
|
||||
manager = DownloadManager()
|
||||
settings = get_settings_manager()
|
||||
settings.settings["download_backend"] = "aria2"
|
||||
settings.settings["civitai_api_key"] = ""
|
||||
|
||||
save_dir = tmp_path / "downloads"
|
||||
save_dir.mkdir()
|
||||
target_path = save_dir / "file.safetensors"
|
||||
|
||||
class DummyMetadata:
|
||||
def __init__(self, path: Path):
|
||||
self.file_path = str(path)
|
||||
self.sha256 = "sha256"
|
||||
self.file_name = path.stem
|
||||
self.preview_url = None
|
||||
|
||||
def generate_unique_filename(self, *_args, **_kwargs):
|
||||
return os.path.basename(self.file_path)
|
||||
|
||||
def update_file_info(self, _path):
|
||||
return None
|
||||
|
||||
def to_dict(self):
|
||||
return {"file_path": self.file_path}
|
||||
|
||||
class DummyAria2Downloader:
|
||||
def __init__(self):
|
||||
self.calls = []
|
||||
|
||||
async def download_file(
|
||||
self,
|
||||
url,
|
||||
save_path,
|
||||
*,
|
||||
download_id,
|
||||
progress_callback=None,
|
||||
headers=None,
|
||||
):
|
||||
self.calls.append({"url": url, "headers": headers, "download_id": download_id})
|
||||
Path(save_path).write_text("content")
|
||||
return True, save_path
|
||||
|
||||
dummy_aria2 = DummyAria2Downloader()
|
||||
monkeypatch.setattr(
|
||||
download_manager,
|
||||
"get_aria2_downloader",
|
||||
AsyncMock(return_value=dummy_aria2),
|
||||
)
|
||||
|
||||
dummy_scanner = SimpleNamespace(add_model_to_cache=AsyncMock(return_value=None))
|
||||
monkeypatch.setattr(
|
||||
DownloadManager, "_get_lora_scanner", AsyncMock(return_value=dummy_scanner)
|
||||
)
|
||||
monkeypatch.setattr(MetadataManager, "save_metadata", AsyncMock(return_value=True))
|
||||
|
||||
result = await manager._execute_download(
|
||||
download_urls=["https://civitai.com/api/download/models/1"],
|
||||
save_dir=str(save_dir),
|
||||
metadata=DummyMetadata(target_path),
|
||||
version_info={"images": []},
|
||||
relative_path="",
|
||||
progress_callback=None,
|
||||
model_type="lora",
|
||||
download_id="download-2",
|
||||
)
|
||||
|
||||
assert result == {"success": True}
|
||||
assert dummy_aria2.calls == [
|
||||
{
|
||||
"url": "https://civitai.com/api/download/models/1",
|
||||
"headers": None,
|
||||
"download_id": "download-2",
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_download_adjusts_checkpoint_sub_type(monkeypatch, tmp_path):
|
||||
"""Test that checkpoint sub_type is adjusted during download."""
|
||||
@@ -276,6 +460,13 @@ async def test_execute_download_extracts_zip_single_model(monkeypatch, tmp_path)
|
||||
monkeypatch.setattr(
|
||||
download_manager, "get_downloader", AsyncMock(return_value=DummyDownloader())
|
||||
)
|
||||
|
||||
class ImmediateLoop:
|
||||
async def run_in_executor(self, executor, func, *args):
|
||||
return func(*args)
|
||||
|
||||
monkeypatch.setattr(download_manager.asyncio, "get_running_loop", lambda: ImmediateLoop())
|
||||
|
||||
dummy_scanner = SimpleNamespace(add_model_to_cache=AsyncMock(return_value=None))
|
||||
monkeypatch.setattr(
|
||||
DownloadManager, "_get_lora_scanner", AsyncMock(return_value=dummy_scanner)
|
||||
@@ -344,6 +535,13 @@ async def test_execute_download_extracts_zip_multiple_models(monkeypatch, tmp_pa
|
||||
monkeypatch.setattr(
|
||||
download_manager, "get_downloader", AsyncMock(return_value=DummyDownloader())
|
||||
)
|
||||
|
||||
class ImmediateLoop:
|
||||
async def run_in_executor(self, executor, func, *args):
|
||||
return func(*args)
|
||||
|
||||
monkeypatch.setattr(download_manager.asyncio, "get_running_loop", lambda: ImmediateLoop())
|
||||
|
||||
dummy_scanner = SimpleNamespace(add_model_to_cache=AsyncMock(return_value=None))
|
||||
monkeypatch.setattr(
|
||||
DownloadManager, "_get_lora_scanner", AsyncMock(return_value=dummy_scanner)
|
||||
@@ -418,6 +616,13 @@ async def test_execute_download_extracts_zip_pt_embedding(monkeypatch, tmp_path)
|
||||
monkeypatch.setattr(
|
||||
download_manager, "get_downloader", AsyncMock(return_value=DummyDownloader())
|
||||
)
|
||||
|
||||
class ImmediateLoop:
|
||||
async def run_in_executor(self, executor, func, *args):
|
||||
return func(*args)
|
||||
|
||||
monkeypatch.setattr(download_manager.asyncio, "get_running_loop", lambda: ImmediateLoop())
|
||||
|
||||
dummy_scanner = SimpleNamespace(add_model_to_cache=AsyncMock(return_value=None))
|
||||
monkeypatch.setattr(
|
||||
ServiceRegistry, "get_embedding_scanner", AsyncMock(return_value=dummy_scanner)
|
||||
@@ -446,6 +651,36 @@ async def test_execute_download_extracts_zip_pt_embedding(monkeypatch, tmp_path)
|
||||
assert dummy_scanner.add_model_to_cache.await_count == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extract_model_files_from_archive_uses_executor(monkeypatch, tmp_path):
|
||||
manager = DownloadManager()
|
||||
archive_path = tmp_path / "bundle.zip"
|
||||
with zipfile.ZipFile(archive_path, "w") as archive:
|
||||
archive.writestr("inner/model.safetensors", b"model")
|
||||
|
||||
captured = {}
|
||||
|
||||
class ImmediateLoop:
|
||||
async def run_in_executor(self, executor, func, *args):
|
||||
captured["executor"] = executor
|
||||
return func(*args)
|
||||
|
||||
monkeypatch.setattr(
|
||||
download_manager.asyncio,
|
||||
"get_running_loop",
|
||||
lambda: ImmediateLoop(),
|
||||
)
|
||||
|
||||
extracted = await manager._extract_model_files_from_archive(
|
||||
str(archive_path),
|
||||
{".safetensors"},
|
||||
)
|
||||
|
||||
assert captured["executor"] is manager._archive_executor
|
||||
assert len(extracted) == 1
|
||||
assert extracted[0].endswith("model.safetensors")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pause_download_updates_state():
|
||||
"""Test that pause_download updates download state correctly."""
|
||||
@@ -469,6 +704,233 @@ async def test_pause_download_updates_state():
|
||||
assert manager._active_downloads[download_id]["bytes_per_second"] == 0.0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pause_download_reverts_local_pause_when_aria2_pause_fails(monkeypatch):
|
||||
manager = DownloadManager()
|
||||
|
||||
download_id = "dl"
|
||||
manager._download_tasks[download_id] = object()
|
||||
pause_control = DownloadStreamControl()
|
||||
manager._pause_events[download_id] = pause_control
|
||||
manager._active_downloads[download_id] = {
|
||||
"transfer_backend": "aria2",
|
||||
"status": "downloading",
|
||||
"bytes_per_second": 42.0,
|
||||
}
|
||||
|
||||
class DummyAria2Downloader:
|
||||
async def has_transfer(self, _download_id):
|
||||
return True
|
||||
|
||||
async def pause_download(self, _download_id):
|
||||
return {"success": False, "error": "rpc failed"}
|
||||
|
||||
monkeypatch.setattr(
|
||||
download_manager,
|
||||
"get_aria2_downloader",
|
||||
AsyncMock(return_value=DummyAria2Downloader()),
|
||||
)
|
||||
|
||||
result = await manager.pause_download(download_id)
|
||||
|
||||
assert result == {"success": False, "error": "rpc failed"}
|
||||
assert pause_control.is_set() is True
|
||||
assert manager._active_downloads[download_id]["status"] == "downloading"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pause_download_reverts_local_pause_when_aria2_probe_raises(monkeypatch):
|
||||
manager = DownloadManager()
|
||||
|
||||
download_id = "dl"
|
||||
manager._download_tasks[download_id] = object()
|
||||
pause_control = DownloadStreamControl()
|
||||
manager._pause_events[download_id] = pause_control
|
||||
manager._active_downloads[download_id] = {
|
||||
"transfer_backend": "aria2",
|
||||
"status": "downloading",
|
||||
"bytes_per_second": 42.0,
|
||||
}
|
||||
|
||||
class DummyAria2Downloader:
|
||||
async def has_transfer(self, _download_id):
|
||||
raise RuntimeError("rpc unavailable")
|
||||
|
||||
monkeypatch.setattr(
|
||||
download_manager,
|
||||
"get_aria2_downloader",
|
||||
AsyncMock(return_value=DummyAria2Downloader()),
|
||||
)
|
||||
|
||||
result = await manager.pause_download(download_id)
|
||||
|
||||
assert result == {"success": False, "error": "rpc unavailable"}
|
||||
assert pause_control.is_set() is True
|
||||
assert manager._active_downloads[download_id]["status"] == "downloading"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resume_download_returns_error_when_aria2_probe_raises(monkeypatch):
|
||||
manager = DownloadManager()
|
||||
|
||||
download_id = "dl"
|
||||
pause_control = DownloadStreamControl()
|
||||
pause_control.pause()
|
||||
manager._pause_events[download_id] = pause_control
|
||||
manager._active_downloads[download_id] = {
|
||||
"transfer_backend": "aria2",
|
||||
"status": "paused",
|
||||
"bytes_per_second": 0.0,
|
||||
}
|
||||
|
||||
class DummyAria2Downloader:
|
||||
async def has_transfer(self, _download_id):
|
||||
raise RuntimeError("rpc unavailable")
|
||||
|
||||
monkeypatch.setattr(
|
||||
download_manager,
|
||||
"get_aria2_downloader",
|
||||
AsyncMock(return_value=DummyAria2Downloader()),
|
||||
)
|
||||
|
||||
result = await manager.resume_download(download_id)
|
||||
|
||||
assert result == {"success": False, "error": "rpc unavailable"}
|
||||
assert pause_control.is_paused() is True
|
||||
assert manager._active_downloads[download_id]["status"] == "paused"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cancel_download_still_cancels_local_task_when_aria2_raises(monkeypatch):
|
||||
manager = DownloadManager()
|
||||
|
||||
started = asyncio.Event()
|
||||
|
||||
async def blocked_task():
|
||||
started.set()
|
||||
await asyncio.sleep(60)
|
||||
|
||||
task = asyncio.create_task(blocked_task())
|
||||
await started.wait()
|
||||
|
||||
download_id = "download-queued"
|
||||
manager._download_tasks[download_id] = task
|
||||
manager._active_downloads[download_id] = {
|
||||
"transfer_backend": "aria2",
|
||||
"status": "queued",
|
||||
}
|
||||
|
||||
class DummyAria2Downloader:
|
||||
async def cancel_download(self, _download_id):
|
||||
raise RuntimeError("rpc unavailable")
|
||||
|
||||
monkeypatch.setattr(
|
||||
download_manager,
|
||||
"get_aria2_downloader",
|
||||
AsyncMock(return_value=DummyAria2Downloader()),
|
||||
)
|
||||
|
||||
result = await manager.cancel_download(download_id)
|
||||
|
||||
assert result["success"] is True
|
||||
assert task.cancelled() or task.done()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_download_waits_for_paused_pre_transfer_gate(monkeypatch, tmp_path):
|
||||
manager = DownloadManager()
|
||||
|
||||
save_dir = tmp_path / "downloads"
|
||||
save_dir.mkdir()
|
||||
target_path = save_dir / "file.safetensors"
|
||||
|
||||
class DummyMetadata:
|
||||
def __init__(self, path: Path):
|
||||
self.file_path = str(path)
|
||||
self.sha256 = "sha256"
|
||||
self.file_name = path.stem
|
||||
self.preview_url = None
|
||||
|
||||
def generate_unique_filename(self, *_args, **_kwargs):
|
||||
return os.path.basename(self.file_path)
|
||||
|
||||
def update_file_info(self, _path):
|
||||
return None
|
||||
|
||||
def to_dict(self):
|
||||
return {"file_path": self.file_path}
|
||||
|
||||
pause_control = DownloadStreamControl()
|
||||
pause_control.pause()
|
||||
manager._pause_events["download-1"] = pause_control
|
||||
manager._active_downloads["download-1"] = {
|
||||
"status": "downloading",
|
||||
"bytes_per_second": 42.0,
|
||||
}
|
||||
|
||||
dummy_scanner = SimpleNamespace(add_model_to_cache=AsyncMock(return_value=None))
|
||||
monkeypatch.setattr(
|
||||
DownloadManager, "_get_lora_scanner", AsyncMock(return_value=dummy_scanner)
|
||||
)
|
||||
monkeypatch.setattr(MetadataManager, "save_metadata", AsyncMock(return_value=True))
|
||||
|
||||
started = asyncio.Event()
|
||||
allow_finish = asyncio.Event()
|
||||
captured = {"calls": 0}
|
||||
|
||||
async def fake_download_model_file(
|
||||
self,
|
||||
download_url,
|
||||
save_path,
|
||||
*,
|
||||
backend,
|
||||
progress_callback,
|
||||
use_auth,
|
||||
download_id,
|
||||
pause_control,
|
||||
):
|
||||
captured["calls"] += 1
|
||||
started.set()
|
||||
await allow_finish.wait()
|
||||
Path(save_path).write_text("content")
|
||||
return True, save_path
|
||||
|
||||
monkeypatch.setattr(
|
||||
DownloadManager,
|
||||
"_download_model_file",
|
||||
fake_download_model_file,
|
||||
)
|
||||
|
||||
task = asyncio.create_task(
|
||||
manager._execute_download(
|
||||
download_urls=["https://civitai.com/api/download/models/1"],
|
||||
save_dir=str(save_dir),
|
||||
metadata=DummyMetadata(target_path),
|
||||
version_info={"images": []},
|
||||
relative_path="",
|
||||
progress_callback=None,
|
||||
model_type="lora",
|
||||
download_id="download-1",
|
||||
transfer_backend="aria2",
|
||||
)
|
||||
)
|
||||
|
||||
await asyncio.sleep(0)
|
||||
assert started.is_set() is False
|
||||
assert captured["calls"] == 0
|
||||
assert manager._active_downloads["download-1"]["status"] == "paused"
|
||||
|
||||
pause_control.resume()
|
||||
await asyncio.wait_for(started.wait(), timeout=1.0)
|
||||
assert captured["calls"] == 1
|
||||
assert manager._active_downloads["download-1"]["status"] == "downloading"
|
||||
|
||||
allow_finish.set()
|
||||
result = await task
|
||||
|
||||
assert result == {"success": True}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pause_download_rejects_unknown_task():
|
||||
"""Test that pause_download rejects unknown download tasks."""
|
||||
|
||||
@@ -147,6 +147,11 @@ def test_environment_variable_overrides_settings(tmp_path, monkeypatch):
|
||||
assert mgr.get("civitai_api_key") == "secret"
|
||||
|
||||
|
||||
def test_default_download_backend_is_python(manager):
|
||||
assert manager.get("download_backend") == "python"
|
||||
assert manager.get("aria2c_path") == ""
|
||||
|
||||
|
||||
def _create_manager_with_settings(
|
||||
tmp_path, monkeypatch, initial_settings, *, save_spy=None
|
||||
):
|
||||
|
||||
Reference in New Issue
Block a user