Files
ComfyUI-Lora-Manager/tests/services/test_download_manager_error.py
2026-04-20 09:52:48 +08:00

1610 lines
50 KiB
Python

"""Error handling and execution tests for DownloadManager."""
import asyncio
import os
import zipfile
from datetime import datetime
from pathlib import Path
from types import SimpleNamespace
from typing import Optional
from unittest.mock import AsyncMock
import pytest
from py.services.download_manager import DownloadManager
from py.services.downloader import DownloadStreamControl
from py.services import download_manager
from py.services import aria2_transfer_state
from py.services.service_registry import ServiceRegistry
from py.services.settings_manager import SettingsManager, get_settings_manager
from py.utils.metadata_manager import MetadataManager
@pytest.fixture(autouse=True)
def reset_download_manager():
"""Ensure each test operates on a fresh singleton."""
DownloadManager._instance = None
yield
DownloadManager._instance = None
@pytest.fixture(autouse=True)
def isolate_settings(monkeypatch, tmp_path):
"""Point settings writes at a temporary directory to avoid touching real files."""
manager = get_settings_manager()
default_settings = manager._get_default_settings()
default_settings.update(
{
"default_lora_root": str(tmp_path),
"default_checkpoint_root": str(tmp_path / "checkpoints"),
"default_embedding_root": str(tmp_path / "embeddings"),
"download_path_templates": {
"lora": "{base_model}/{first_tag}",
"checkpoint": "{base_model}/{first_tag}",
"embedding": "{base_model}/{first_tag}",
},
"base_model_path_mappings": {"BaseModel": "MappedModel"},
}
)
monkeypatch.setattr(manager, "settings", default_settings)
monkeypatch.setattr(SettingsManager, "_save_settings", lambda self: None)
@pytest.fixture(autouse=True)
def isolate_aria2_state(monkeypatch, tmp_path):
state_path = tmp_path / "cache" / "aria2" / "downloads.json"
monkeypatch.setattr(
aria2_transfer_state,
"get_aria2_state_path",
lambda: str(state_path),
)
@pytest.mark.asyncio
async def test_execute_download_retries_urls(monkeypatch, tmp_path):
"""Test that download retries multiple URLs on failure."""
manager = DownloadManager()
save_dir = tmp_path / "downloads"
save_dir.mkdir()
initial_path = save_dir / "file.safetensors"
class DummyMetadata:
def __init__(self, path: Path):
self.file_path = str(path)
self.sha256 = "sha256"
self.file_name = path.stem
self.preview_url = None
def generate_unique_filename(self, *_args, **_kwargs):
return os.path.basename(self.file_path)
def update_file_info(self, _path):
return None
def to_dict(self):
return {"file_path": self.file_path}
metadata = DummyMetadata(initial_path)
version_info = {"images": []}
download_urls = [
"https://first.example/file.safetensors",
"https://second.example/file.safetensors",
]
class DummyDownloader:
def __init__(self):
self.calls = []
async def download_file(self, url, path, progress_callback=None, use_auth=None):
self.calls.append((url, path, use_auth))
if len(self.calls) == 1:
return False, "first failed"
# Create the target file to simulate a successful download
Path(path).write_text("content")
return True, "second success"
dummy_downloader = DummyDownloader()
monkeypatch.setattr(
download_manager, "get_downloader", AsyncMock(return_value=dummy_downloader)
)
class DummyScanner:
def __init__(self):
self.calls = []
async def add_model_to_cache(self, metadata_dict, relative_path):
self.calls.append((metadata_dict, relative_path))
dummy_scanner = DummyScanner()
monkeypatch.setattr(
DownloadManager, "_get_lora_scanner", AsyncMock(return_value=dummy_scanner)
)
monkeypatch.setattr(
DownloadManager,
"_get_checkpoint_scanner",
AsyncMock(return_value=dummy_scanner),
)
monkeypatch.setattr(
ServiceRegistry, "get_embedding_scanner", AsyncMock(return_value=dummy_scanner)
)
monkeypatch.setattr(MetadataManager, "save_metadata", AsyncMock(return_value=True))
result = await manager._execute_download(
download_urls=download_urls,
save_dir=str(save_dir),
metadata=metadata,
version_info=version_info,
relative_path="",
progress_callback=None,
model_type="lora",
download_id=None,
)
assert result == {"success": True}
assert [url for url, *_ in dummy_downloader.calls] == download_urls
assert dummy_scanner.calls # ensure cache updated
@pytest.mark.asyncio
async def test_execute_download_uses_aria2_backend_for_model_files(monkeypatch, tmp_path):
manager = DownloadManager()
settings = get_settings_manager()
settings.settings["download_backend"] = "aria2"
settings.settings["civitai_api_key"] = "secret-key"
save_dir = tmp_path / "downloads"
save_dir.mkdir()
target_path = save_dir / "file.safetensors"
class DummyMetadata:
def __init__(self, path: Path):
self.file_path = str(path)
self.sha256 = "sha256"
self.file_name = path.stem
self.preview_url = None
def generate_unique_filename(self, *_args, **_kwargs):
return os.path.basename(self.file_path)
def update_file_info(self, _path):
return None
def to_dict(self):
return {"file_path": self.file_path}
class DummyAria2Downloader:
def __init__(self):
self.calls = []
async def download_file(
self,
url,
save_path,
*,
download_id,
progress_callback=None,
headers=None,
):
self.calls.append(
{
"url": url,
"save_path": save_path,
"download_id": download_id,
"headers": headers,
}
)
Path(save_path).write_text("content")
return True, save_path
dummy_aria2 = DummyAria2Downloader()
monkeypatch.setattr(
download_manager,
"get_aria2_downloader",
AsyncMock(return_value=dummy_aria2),
)
monkeypatch.setattr(
download_manager,
"get_downloader",
AsyncMock(side_effect=AssertionError("python downloader should not be used")),
)
class DummyScanner:
async def add_model_to_cache(self, metadata_dict, relative_path):
return {"metadata": metadata_dict, "relative_path": relative_path}
dummy_scanner = DummyScanner()
monkeypatch.setattr(
DownloadManager, "_get_lora_scanner", AsyncMock(return_value=dummy_scanner)
)
monkeypatch.setattr(
DownloadManager,
"_get_checkpoint_scanner",
AsyncMock(return_value=dummy_scanner),
)
monkeypatch.setattr(
ServiceRegistry, "get_embedding_scanner", AsyncMock(return_value=dummy_scanner)
)
monkeypatch.setattr(MetadataManager, "save_metadata", AsyncMock(return_value=True))
result = await manager._execute_download(
download_urls=["https://civitai.com/api/download/models/1"],
save_dir=str(save_dir),
metadata=DummyMetadata(target_path),
version_info={"images": []},
relative_path="",
progress_callback=None,
model_type="lora",
download_id="download-1",
)
assert result == {"success": True}
assert dummy_aria2.calls == [
{
"url": "https://civitai.com/api/download/models/1",
"save_path": str(target_path),
"download_id": "download-1",
"headers": {"Authorization": "Bearer secret-key"},
}
]
@pytest.mark.asyncio
async def test_execute_download_allows_anonymous_civitai_with_aria2(
monkeypatch, tmp_path
):
manager = DownloadManager()
settings = get_settings_manager()
settings.settings["download_backend"] = "aria2"
settings.settings["civitai_api_key"] = ""
save_dir = tmp_path / "downloads"
save_dir.mkdir()
target_path = save_dir / "file.safetensors"
class DummyMetadata:
def __init__(self, path: Path):
self.file_path = str(path)
self.sha256 = "sha256"
self.file_name = path.stem
self.preview_url = None
def generate_unique_filename(self, *_args, **_kwargs):
return os.path.basename(self.file_path)
def update_file_info(self, _path):
return None
def to_dict(self):
return {"file_path": self.file_path}
class DummyAria2Downloader:
def __init__(self):
self.calls = []
async def download_file(
self,
url,
save_path,
*,
download_id,
progress_callback=None,
headers=None,
):
self.calls.append({"url": url, "headers": headers, "download_id": download_id})
Path(save_path).write_text("content")
return True, save_path
dummy_aria2 = DummyAria2Downloader()
monkeypatch.setattr(
download_manager,
"get_aria2_downloader",
AsyncMock(return_value=dummy_aria2),
)
dummy_scanner = SimpleNamespace(add_model_to_cache=AsyncMock(return_value=None))
monkeypatch.setattr(
DownloadManager, "_get_lora_scanner", AsyncMock(return_value=dummy_scanner)
)
monkeypatch.setattr(MetadataManager, "save_metadata", AsyncMock(return_value=True))
result = await manager._execute_download(
download_urls=["https://civitai.com/api/download/models/1"],
save_dir=str(save_dir),
metadata=DummyMetadata(target_path),
version_info={"images": []},
relative_path="",
progress_callback=None,
model_type="lora",
download_id="download-2",
)
assert result == {"success": True}
assert dummy_aria2.calls == [
{
"url": "https://civitai.com/api/download/models/1",
"headers": None,
"download_id": "download-2",
}
]
@pytest.mark.asyncio
async def test_execute_download_adjusts_checkpoint_sub_type(monkeypatch, tmp_path):
"""Test that checkpoint sub_type is adjusted during download."""
manager = DownloadManager()
root_dir = tmp_path / "checkpoints"
root_dir.mkdir()
save_dir = root_dir
target_path = save_dir / "model.safetensors"
class DummyMetadata:
def __init__(self, path: Path):
self.file_path = path.as_posix()
self.sha256 = "sha256"
self.file_name = path.stem
self.preview_url = None
self.preview_nsfw_level = 0
self.sub_type = "checkpoint"
def generate_unique_filename(self, *_args, **_kwargs):
return os.path.basename(self.file_path)
def update_file_info(self, updated_path):
self.file_path = Path(updated_path).as_posix()
def to_dict(self):
return {
"file_path": self.file_path,
"sub_type": self.sub_type,
"sha256": self.sha256,
}
metadata = DummyMetadata(target_path)
version_info = {"images": []}
download_urls = ["https://example.invalid/model.safetensors"]
class DummyDownloader:
async def download_file(
self, _url, path, progress_callback=None, use_auth=None
):
Path(path).write_text("content")
return True, "ok"
monkeypatch.setattr(
download_manager,
"get_downloader",
AsyncMock(return_value=DummyDownloader()),
)
class DummyCheckpointScanner:
def __init__(self, root: Path):
self.root = root.as_posix()
self.add_calls = []
def _find_root_for_file(self, file_path: str):
return self.root if file_path.startswith(self.root) else None
def adjust_metadata(
self, metadata_obj, _file_path: str, root_path: Optional[str]
):
if root_path:
metadata_obj.sub_type = "diffusion_model"
return metadata_obj
def adjust_cached_entry(self, entry):
if entry.get("file_path", "").startswith(self.root):
entry["sub_type"] = "diffusion_model"
return entry
async def add_model_to_cache(self, metadata_dict, relative_path):
self.add_calls.append((metadata_dict, relative_path))
return True
dummy_scanner = DummyCheckpointScanner(root_dir)
monkeypatch.setattr(
DownloadManager,
"_get_checkpoint_scanner",
AsyncMock(return_value=dummy_scanner),
)
monkeypatch.setattr(MetadataManager, "save_metadata", AsyncMock(return_value=True))
result = await manager._execute_download(
download_urls=download_urls,
save_dir=str(save_dir),
metadata=metadata,
version_info=version_info,
relative_path="",
progress_callback=None,
model_type="checkpoint",
download_id=None,
)
assert result == {"success": True}
assert metadata.sub_type == "diffusion_model"
saved_metadata = MetadataManager.save_metadata.await_args.args[1]
assert saved_metadata.sub_type == "diffusion_model"
assert dummy_scanner.add_calls
cached_entry, _ = dummy_scanner.add_calls[0]
assert cached_entry["sub_type"] == "diffusion_model"
@pytest.mark.asyncio
async def test_execute_download_extracts_zip_single_model(monkeypatch, tmp_path):
"""Test extraction of single model from ZIP file."""
manager = DownloadManager()
save_dir = tmp_path / "downloads"
save_dir.mkdir()
zip_path = save_dir / "bundle.zip"
class DummyMetadata:
def __init__(self, path: Path):
self.file_path = str(path)
self.sha256 = "sha256"
self.file_name = path.stem
self.preview_url = None
def generate_unique_filename(self, *_args, **_kwargs):
return os.path.basename(self.file_path)
def update_file_info(self, updated_path):
self.file_path = str(updated_path)
self.file_name = Path(updated_path).stem
def to_dict(self):
return {"file_path": self.file_path}
metadata = DummyMetadata(zip_path)
version_info = {"images": []}
download_urls = ["https://example.invalid/model.zip"]
class DummyDownloader:
async def download_file(self, *_args, **_kwargs):
with zipfile.ZipFile(str(zip_path), "w") as archive:
archive.writestr("inner/model.safetensors", b"model")
archive.writestr("docs/readme.txt", b"ignore")
return True, "ok"
monkeypatch.setattr(
download_manager, "get_downloader", AsyncMock(return_value=DummyDownloader())
)
class ImmediateLoop:
async def run_in_executor(self, executor, func, *args):
return func(*args)
monkeypatch.setattr(download_manager.asyncio, "get_running_loop", lambda: ImmediateLoop())
dummy_scanner = SimpleNamespace(add_model_to_cache=AsyncMock(return_value=None))
monkeypatch.setattr(
DownloadManager, "_get_lora_scanner", AsyncMock(return_value=dummy_scanner)
)
monkeypatch.setattr(MetadataManager, "save_metadata", AsyncMock(return_value=True))
result = await manager._execute_download(
download_urls=download_urls,
save_dir=str(save_dir),
metadata=metadata,
version_info=version_info,
relative_path="",
progress_callback=None,
model_type="lora",
download_id=None,
)
assert result == {"success": True}
assert not zip_path.exists()
extracted = save_dir / "model.safetensors"
assert extracted.exists()
saved_call = MetadataManager.save_metadata.await_args
assert saved_call.args[0] == str(extracted)
# SHA256 comes from metadata (API value), not recalculated
assert saved_call.args[1].sha256 == "sha256"
assert dummy_scanner.add_model_to_cache.await_count == 1
@pytest.mark.asyncio
async def test_execute_download_extracts_zip_multiple_models(monkeypatch, tmp_path):
"""Test extraction of multiple models from ZIP file."""
manager = DownloadManager()
save_dir = tmp_path / "downloads"
save_dir.mkdir()
zip_path = save_dir / "bundle.zip"
class DummyMetadata:
def __init__(self, path: Path):
self.file_path = str(path)
self.sha256 = "sha256"
self.file_name = path.stem
self.preview_url = None
def generate_unique_filename(self, *_args, **_kwargs):
return os.path.basename(self.file_path)
def update_file_info(self, updated_path):
self.file_path = str(updated_path)
self.file_name = Path(updated_path).stem
def to_dict(self):
return {"file_path": self.file_path}
metadata = DummyMetadata(zip_path)
version_info = {"images": []}
download_urls = ["https://example.invalid/model.zip"]
class DummyDownloader:
async def download_file(self, *_args, **_kwargs):
with zipfile.ZipFile(str(zip_path), "w") as archive:
archive.writestr("first/model-one.safetensors", b"one")
archive.writestr("second/model-two.safetensors", b"two")
archive.writestr("readme.md", b"ignore")
return True, "ok"
monkeypatch.setattr(
download_manager, "get_downloader", AsyncMock(return_value=DummyDownloader())
)
class ImmediateLoop:
async def run_in_executor(self, executor, func, *args):
return func(*args)
monkeypatch.setattr(download_manager.asyncio, "get_running_loop", lambda: ImmediateLoop())
dummy_scanner = SimpleNamespace(add_model_to_cache=AsyncMock(return_value=None))
monkeypatch.setattr(
DownloadManager, "_get_lora_scanner", AsyncMock(return_value=dummy_scanner)
)
monkeypatch.setattr(MetadataManager, "save_metadata", AsyncMock(return_value=True))
result = await manager._execute_download(
download_urls=download_urls,
save_dir=str(save_dir),
metadata=metadata,
version_info=version_info,
relative_path="",
progress_callback=None,
model_type="lora",
download_id=None,
)
assert result == {"success": True}
assert not zip_path.exists()
extracted_one = save_dir / "model-one.safetensors"
extracted_two = save_dir / "model-two.safetensors"
assert extracted_one.exists()
assert extracted_two.exists()
assert MetadataManager.save_metadata.await_count == 2
assert dummy_scanner.add_model_to_cache.await_count == 2
metadata_calls = MetadataManager.save_metadata.await_args_list
assert metadata_calls[0].args[0] == str(extracted_one)
# SHA256 comes from metadata (API value), not recalculated
assert metadata_calls[0].args[1].sha256 == "sha256"
assert metadata_calls[1].args[0] == str(extracted_two)
assert metadata_calls[1].args[1].sha256 == "sha256"
@pytest.mark.asyncio
async def test_execute_download_extracts_zip_pt_embedding(monkeypatch, tmp_path):
"""Test extraction of .pt embedding files from ZIP."""
manager = DownloadManager()
save_dir = tmp_path / "downloads"
save_dir.mkdir()
zip_path = save_dir / "bundle.zip"
class DummyMetadata:
def __init__(self, path: Path):
self.file_path = str(path)
self.sha256 = "sha256"
self.file_name = path.stem
self.preview_url = None
def generate_unique_filename(self, *_args, **_kwargs):
return os.path.basename(self.file_path)
def update_file_info(self, updated_path):
self.file_path = str(updated_path)
self.file_name = Path(updated_path).stem
def to_dict(self):
return {"file_path": self.file_path}
metadata = DummyMetadata(zip_path)
version_info = {"images": []}
download_urls = ["https://example.invalid/model.zip"]
class DummyDownloader:
async def download_file(self, *_args, **_kwargs):
with zipfile.ZipFile(str(zip_path), "w") as archive:
archive.writestr("inner/embedding.pt", b"embedding")
archive.writestr("docs/readme.txt", b"ignore")
return True, "ok"
monkeypatch.setattr(
download_manager, "get_downloader", AsyncMock(return_value=DummyDownloader())
)
class ImmediateLoop:
async def run_in_executor(self, executor, func, *args):
return func(*args)
monkeypatch.setattr(download_manager.asyncio, "get_running_loop", lambda: ImmediateLoop())
dummy_scanner = SimpleNamespace(add_model_to_cache=AsyncMock(return_value=None))
monkeypatch.setattr(
ServiceRegistry, "get_embedding_scanner", AsyncMock(return_value=dummy_scanner)
)
monkeypatch.setattr(MetadataManager, "save_metadata", AsyncMock(return_value=True))
result = await manager._execute_download(
download_urls=download_urls,
save_dir=str(save_dir),
metadata=metadata,
version_info=version_info,
relative_path="",
progress_callback=None,
model_type="embedding",
download_id=None,
)
assert result == {"success": True}
assert not zip_path.exists()
extracted = save_dir / "embedding.pt"
assert extracted.exists()
saved_call = MetadataManager.save_metadata.await_args
assert saved_call.args[0] == str(extracted)
# SHA256 comes from metadata (API value), not recalculated
assert saved_call.args[1].sha256 == "sha256"
assert dummy_scanner.add_model_to_cache.await_count == 1
@pytest.mark.asyncio
async def test_extract_model_files_from_archive_uses_executor(monkeypatch, tmp_path):
manager = DownloadManager()
archive_path = tmp_path / "bundle.zip"
with zipfile.ZipFile(archive_path, "w") as archive:
archive.writestr("inner/model.safetensors", b"model")
captured = {}
class ImmediateLoop:
async def run_in_executor(self, executor, func, *args):
captured["executor"] = executor
return func(*args)
monkeypatch.setattr(
download_manager.asyncio,
"get_running_loop",
lambda: ImmediateLoop(),
)
extracted = await manager._extract_model_files_from_archive(
str(archive_path),
{".safetensors"},
)
assert captured["executor"] is manager._archive_executor
assert len(extracted) == 1
assert extracted[0].endswith("model.safetensors")
@pytest.mark.asyncio
async def test_pause_download_updates_state():
"""Test that pause_download updates download state correctly."""
manager = DownloadManager()
download_id = "dl"
manager._download_tasks[download_id] = object()
pause_control = DownloadStreamControl()
manager._pause_events[download_id] = pause_control
manager._active_downloads[download_id] = {
"status": "downloading",
"bytes_per_second": 42.0,
}
result = await manager.pause_download(download_id)
assert result == {"success": True, "message": "Download paused successfully"}
assert download_id in manager._pause_events
assert manager._pause_events[download_id].is_set() is False
assert manager._active_downloads[download_id]["status"] == "paused"
assert manager._active_downloads[download_id]["bytes_per_second"] == 0.0
@pytest.mark.asyncio
async def test_pause_download_reverts_local_pause_when_aria2_pause_fails(monkeypatch):
manager = DownloadManager()
download_id = "dl"
manager._download_tasks[download_id] = object()
pause_control = DownloadStreamControl()
manager._pause_events[download_id] = pause_control
manager._active_downloads[download_id] = {
"transfer_backend": "aria2",
"status": "downloading",
"bytes_per_second": 42.0,
}
class DummyAria2Downloader:
async def has_transfer(self, _download_id):
return True
async def pause_download(self, _download_id):
return {"success": False, "error": "rpc failed"}
monkeypatch.setattr(
download_manager,
"get_aria2_downloader",
AsyncMock(return_value=DummyAria2Downloader()),
)
result = await manager.pause_download(download_id)
assert result == {"success": False, "error": "rpc failed"}
assert pause_control.is_set() is True
assert manager._active_downloads[download_id]["status"] == "downloading"
@pytest.mark.asyncio
async def test_pause_download_reverts_local_pause_when_aria2_probe_raises(monkeypatch):
manager = DownloadManager()
download_id = "dl"
manager._download_tasks[download_id] = object()
pause_control = DownloadStreamControl()
manager._pause_events[download_id] = pause_control
manager._active_downloads[download_id] = {
"transfer_backend": "aria2",
"status": "downloading",
"bytes_per_second": 42.0,
}
class DummyAria2Downloader:
async def has_transfer(self, _download_id):
raise RuntimeError("rpc unavailable")
monkeypatch.setattr(
download_manager,
"get_aria2_downloader",
AsyncMock(return_value=DummyAria2Downloader()),
)
result = await manager.pause_download(download_id)
assert result == {"success": False, "error": "rpc unavailable"}
assert pause_control.is_set() is True
assert manager._active_downloads[download_id]["status"] == "downloading"
@pytest.mark.asyncio
async def test_resume_download_returns_error_when_aria2_probe_raises(monkeypatch):
manager = DownloadManager()
download_id = "dl"
pause_control = DownloadStreamControl()
pause_control.pause()
manager._pause_events[download_id] = pause_control
manager._active_downloads[download_id] = {
"transfer_backend": "aria2",
"status": "paused",
"bytes_per_second": 0.0,
}
class DummyAria2Downloader:
async def has_transfer(self, _download_id):
raise RuntimeError("rpc unavailable")
monkeypatch.setattr(
download_manager,
"get_aria2_downloader",
AsyncMock(return_value=DummyAria2Downloader()),
)
result = await manager.resume_download(download_id)
assert result == {"success": False, "error": "rpc unavailable"}
assert pause_control.is_paused() is True
assert manager._active_downloads[download_id]["status"] == "paused"
@pytest.mark.asyncio
async def test_resume_download_does_not_spawn_restored_worker_when_aria2_resume_fails(
monkeypatch, tmp_path
):
manager = DownloadManager()
download_id = "dl"
save_path = tmp_path / "file.safetensors"
pause_control = DownloadStreamControl()
pause_control.pause()
manager._pause_events[download_id] = pause_control
manager._active_downloads[download_id] = {
"transfer_backend": "aria2",
"status": "paused",
"bytes_per_second": 0.0,
}
await manager._aria2_state_store.upsert(
download_id,
{
"download_id": download_id,
"transfer_backend": "aria2",
"status": "paused",
"save_path": str(save_path),
"file_path": str(save_path),
"model_id": 12,
"model_version_id": 34,
"resume_context": {
"version_info": {"id": 34, "modelId": 12, "model": {"id": 12}},
"file_info": {
"name": "file.safetensors",
"downloadUrl": "https://example.com/file.safetensors",
},
"model_type": "lora",
"relative_path": "",
"save_dir": str(tmp_path),
"download_urls": ["https://example.com/file.safetensors"],
},
},
)
resume_restored = AsyncMock(return_value={"success": True})
monkeypatch.setattr(manager, "_resume_restored_aria2_download", resume_restored)
class DummyAria2Downloader:
async def has_transfer(self, _download_id):
return True
async def resume_download(self, _download_id):
return {"success": False, "error": "rpc unavailable"}
monkeypatch.setattr(
download_manager,
"get_aria2_downloader",
AsyncMock(return_value=DummyAria2Downloader()),
)
result = await manager.resume_download(download_id)
assert result == {"success": False, "error": "rpc unavailable"}
assert download_id not in manager._download_tasks
assert resume_restored.await_count == 0
assert pause_control.is_paused() is True
assert manager._active_downloads[download_id]["status"] == "paused"
@pytest.mark.asyncio
async def test_start_background_download_task_cleans_up_finished_restore_task():
manager = DownloadManager()
download_id = "download-1"
manager._pause_events[download_id] = DownloadStreamControl()
async def finished_restore():
return {"success": True}
task = manager._start_background_download_task(download_id, finished_restore())
await task
await asyncio.sleep(0)
assert download_id not in manager._download_tasks
assert download_id not in manager._pause_events
@pytest.mark.asyncio
async def test_cancel_download_still_cancels_local_task_when_aria2_raises(monkeypatch):
manager = DownloadManager()
started = asyncio.Event()
async def blocked_task():
started.set()
await asyncio.sleep(60)
task = asyncio.create_task(blocked_task())
await started.wait()
download_id = "download-queued"
manager._download_tasks[download_id] = task
manager._active_downloads[download_id] = {
"transfer_backend": "aria2",
"status": "queued",
}
class DummyAria2Downloader:
async def cancel_download(self, _download_id):
raise RuntimeError("rpc unavailable")
monkeypatch.setattr(
download_manager,
"get_aria2_downloader",
AsyncMock(return_value=DummyAria2Downloader()),
)
result = await manager.cancel_download(download_id)
assert result["success"] is True
assert task.cancelled() or task.done()
@pytest.mark.asyncio
async def test_cancel_download_preserves_tracking_when_aria2_returns_error(monkeypatch, tmp_path):
manager = DownloadManager()
download_id = "download-queued"
save_path = tmp_path / "file.safetensors"
save_path.write_text("partial")
(tmp_path / "file.safetensors.aria2").write_text("control")
pause_control = DownloadStreamControl()
manager._pause_events[download_id] = pause_control
manager._download_tasks[download_id] = object()
manager._active_downloads[download_id] = {
"transfer_backend": "aria2",
"status": "downloading",
"file_path": str(save_path),
}
await manager._aria2_state_store.upsert(
download_id,
{
"download_id": download_id,
"transfer_backend": "aria2",
"status": "downloading",
"save_path": str(save_path),
"file_path": str(save_path),
},
)
cleanup_files = AsyncMock(return_value=None)
monkeypatch.setattr(manager, "_cleanup_cancelled_download_files", cleanup_files)
class DummyAria2Downloader:
async def cancel_download(self, _download_id):
return {"success": False, "error": "rpc unavailable"}
monkeypatch.setattr(
download_manager,
"get_aria2_downloader",
AsyncMock(return_value=DummyAria2Downloader()),
)
result = await manager.cancel_download(download_id)
assert result == {"success": False, "error": "rpc unavailable"}
assert download_id in manager._download_tasks
assert download_id in manager._pause_events
assert await manager._aria2_state_store.get(download_id) is not None
assert cleanup_files.await_count == 0
@pytest.mark.asyncio
async def test_cancel_download_rejects_completed_history_entry(tmp_path):
manager = DownloadManager()
download_id = "completed-download"
save_path = tmp_path / "file.safetensors"
metadata_path = tmp_path / "file.metadata.json"
preview_path = tmp_path / "file.jpeg"
save_path.write_text("complete")
metadata_path.write_text("{}")
preview_path.write_text("preview")
manager._active_downloads[download_id] = {
"transfer_backend": "aria2",
"status": "completed",
"file_path": str(save_path),
"preview_path": str(preview_path),
}
result = await manager.cancel_download(download_id)
assert result == {"success": False, "error": "Download task not found"}
assert save_path.exists()
assert metadata_path.exists()
assert preview_path.exists()
@pytest.mark.asyncio
async def test_cancel_download_removes_preview_and_aria2_control_files(monkeypatch, tmp_path):
manager = DownloadManager()
started = asyncio.Event()
async def blocked_task():
started.set()
await asyncio.sleep(60)
task = asyncio.create_task(blocked_task())
await started.wait()
save_path = tmp_path / "file.safetensors"
save_path.write_text("partial")
aria2_path = tmp_path / "file.safetensors.aria2"
aria2_path.write_text("control")
preview_path = tmp_path / "file.jpeg"
preview_path.write_text("preview")
download_id = "download-queued"
manager._download_tasks[download_id] = task
manager._active_downloads[download_id] = {
"transfer_backend": "aria2",
"status": "queued",
"file_path": str(save_path),
"aria2_control_path": str(aria2_path),
"preview_path": str(preview_path),
}
class DummyAria2Downloader:
async def cancel_download(self, _download_id):
return {"success": True, "message": "cancelled"}
monkeypatch.setattr(
download_manager,
"get_aria2_downloader",
AsyncMock(return_value=DummyAria2Downloader()),
)
result = await manager.cancel_download(download_id)
assert result["success"] is True
assert not save_path.exists()
assert not aria2_path.exists()
assert not preview_path.exists()
@pytest.mark.asyncio
async def test_cancel_download_does_not_delete_untracked_same_basename_preview(
monkeypatch, tmp_path
):
manager = DownloadManager()
started = asyncio.Event()
async def blocked_task():
started.set()
await asyncio.sleep(60)
task = asyncio.create_task(blocked_task())
await started.wait()
save_path = tmp_path / "file.safetensors"
save_path.write_text("partial")
aria2_path = tmp_path / "file.safetensors.aria2"
aria2_path.write_text("control")
manual_preview_path = tmp_path / "file.jpg"
manual_preview_path.write_text("manual")
download_id = "download-queued"
manager._download_tasks[download_id] = task
manager._active_downloads[download_id] = {
"transfer_backend": "aria2",
"status": "queued",
"file_path": str(save_path),
"aria2_control_path": str(aria2_path),
}
class DummyAria2Downloader:
async def cancel_download(self, _download_id):
return {"success": True, "message": "cancelled"}
monkeypatch.setattr(
download_manager,
"get_aria2_downloader",
AsyncMock(return_value=DummyAria2Downloader()),
)
result = await manager.cancel_download(download_id)
assert result["success"] is True
assert not save_path.exists()
assert not aria2_path.exists()
assert manual_preview_path.exists()
@pytest.mark.asyncio
async def test_cleanup_cancelled_download_files_retries_aria2_control_deletion(
monkeypatch, tmp_path
):
manager = DownloadManager()
download_id = "download-1"
save_path = tmp_path / "file.safetensors"
aria2_path = tmp_path / "file.safetensors.aria2"
save_path.write_text("partial")
aria2_path.write_text("control")
original_unlink = os.unlink
attempts = {"count": 0}
def flaky_unlink(path):
if path == str(aria2_path) and attempts["count"] == 0:
attempts["count"] += 1
raise PermissionError("still locked")
return original_unlink(path)
monkeypatch.setattr(download_manager.os, "unlink", flaky_unlink)
monkeypatch.setattr("py.services.download_manager.asyncio.sleep", AsyncMock())
await manager._cleanup_cancelled_download_files(
download_id,
{
"file_path": str(save_path),
"aria2_control_path": str(aria2_path),
"transfer_backend": "aria2",
},
)
assert attempts["count"] == 1
assert not save_path.exists()
assert not aria2_path.exists()
@pytest.mark.asyncio
async def test_execute_download_waits_for_paused_pre_transfer_gate(monkeypatch, tmp_path):
manager = DownloadManager()
save_dir = tmp_path / "downloads"
save_dir.mkdir()
target_path = save_dir / "file.safetensors"
class DummyMetadata:
def __init__(self, path: Path):
self.file_path = str(path)
self.sha256 = "sha256"
self.file_name = path.stem
self.preview_url = None
def generate_unique_filename(self, *_args, **_kwargs):
return os.path.basename(self.file_path)
def update_file_info(self, _path):
return None
def to_dict(self):
return {"file_path": self.file_path}
pause_control = DownloadStreamControl()
pause_control.pause()
manager._pause_events["download-1"] = pause_control
manager._active_downloads["download-1"] = {
"status": "downloading",
"bytes_per_second": 42.0,
}
dummy_scanner = SimpleNamespace(add_model_to_cache=AsyncMock(return_value=None))
monkeypatch.setattr(
DownloadManager, "_get_lora_scanner", AsyncMock(return_value=dummy_scanner)
)
monkeypatch.setattr(MetadataManager, "save_metadata", AsyncMock(return_value=True))
started = asyncio.Event()
allow_finish = asyncio.Event()
captured = {"calls": 0}
async def fake_download_model_file(
self,
download_url,
save_path,
*,
backend,
progress_callback,
use_auth,
download_id,
pause_control,
):
captured["calls"] += 1
started.set()
await allow_finish.wait()
Path(save_path).write_text("content")
return True, save_path
monkeypatch.setattr(
DownloadManager,
"_download_model_file",
fake_download_model_file,
)
task = asyncio.create_task(
manager._execute_download(
download_urls=["https://civitai.com/api/download/models/1"],
save_dir=str(save_dir),
metadata=DummyMetadata(target_path),
version_info={"images": []},
relative_path="",
progress_callback=None,
model_type="lora",
download_id="download-1",
transfer_backend="aria2",
)
)
await asyncio.sleep(0)
assert started.is_set() is False
assert captured["calls"] == 0
assert manager._active_downloads["download-1"]["status"] == "paused"
pause_control.resume()
await asyncio.wait_for(started.wait(), timeout=1.0)
assert captured["calls"] == 1
assert manager._active_downloads["download-1"]["status"] == "downloading"
allow_finish.set()
result = await task
assert result == {"success": True}
@pytest.mark.asyncio
async def test_execute_download_reuses_existing_aria2_partial_path(monkeypatch, tmp_path):
manager = DownloadManager()
save_dir = tmp_path / "downloads"
save_dir.mkdir()
target_path = save_dir / "file.safetensors"
target_path.write_text("partial")
control_path = save_dir / "file.safetensors.aria2"
control_path.write_text("control")
await manager._aria2_state_store.upsert(
"download-1",
{
"download_id": "download-1",
"transfer_backend": "aria2",
"save_path": str(target_path),
"file_path": str(target_path),
"status": "paused",
},
)
class DummyMetadata:
def __init__(self, path: Path):
self.file_path = str(path)
self.sha256 = "sha256"
self.file_name = path.stem
self.preview_url = None
def generate_unique_filename(self, *_args, **_kwargs):
return "renamed.safetensors"
def update_file_info(self, _path):
return None
def to_dict(self):
return {"file_path": self.file_path}
manager._active_downloads["download-1"] = {"transfer_backend": "aria2"}
dummy_scanner = SimpleNamespace(add_model_to_cache=AsyncMock(return_value=None))
monkeypatch.setattr(
DownloadManager, "_get_lora_scanner", AsyncMock(return_value=dummy_scanner)
)
monkeypatch.setattr(MetadataManager, "save_metadata", AsyncMock(return_value=True))
async def fake_download_model_file(
self,
download_url,
save_path,
*,
backend,
progress_callback,
use_auth,
download_id,
pause_control,
):
Path(save_path).write_text("content")
return True, save_path
monkeypatch.setattr(DownloadManager, "_download_model_file", fake_download_model_file)
result = await manager._execute_download(
download_urls=["https://example.com/file.safetensors"],
save_dir=str(save_dir),
metadata=DummyMetadata(target_path),
version_info={"images": []},
relative_path="",
progress_callback=None,
model_type="lora",
download_id="download-1",
transfer_backend="aria2",
)
assert result == {"success": True}
assert manager._active_downloads["download-1"]["file_path"] == str(target_path)
assert not (save_dir / "renamed.safetensors").exists()
@pytest.mark.asyncio
async def test_execute_download_rejects_conflicting_aria2_partial_path(tmp_path):
manager = DownloadManager()
save_dir = tmp_path / "downloads"
save_dir.mkdir()
target_path = save_dir / "file.safetensors"
target_path.write_text("partial")
(save_dir / "file.safetensors.aria2").write_text("control")
await manager._aria2_state_store.upsert(
"other-download",
{
"download_id": "other-download",
"transfer_backend": "aria2",
"save_path": str(target_path),
"file_path": str(target_path),
"status": "paused",
},
)
class DummyMetadata:
def __init__(self, path: Path):
self.file_path = str(path)
self.sha256 = "sha256"
self.file_name = path.stem
self.preview_url = None
def generate_unique_filename(self, *_args, **_kwargs):
raise AssertionError("should not rename")
result = await manager._execute_download(
download_urls=["https://example.com/file.safetensors"],
save_dir=str(save_dir),
metadata=DummyMetadata(target_path),
version_info={"images": []},
relative_path="",
progress_callback=None,
model_type="lora",
download_id="download-1",
transfer_backend="aria2",
)
assert result["success"] is False
assert "already using" in result["error"]
@pytest.mark.asyncio
async def test_execute_download_reassigns_same_aria2_partial_to_new_download_id(
monkeypatch, tmp_path
):
manager = DownloadManager()
save_dir = tmp_path / "downloads"
save_dir.mkdir()
target_path = save_dir / "file.safetensors"
target_path.write_text("partial")
(save_dir / "file.safetensors.aria2").write_text("control")
await manager._aria2_state_store.upsert(
"old-download",
{
"download_id": "old-download",
"transfer_backend": "aria2",
"save_path": str(target_path),
"file_path": str(target_path),
"status": "paused",
"model_id": 11,
"model_version_id": 22,
},
)
class DummyMetadata:
def __init__(self, path: Path):
self.file_path = str(path)
self.sha256 = "sha256"
self.file_name = path.stem
self.preview_url = None
def generate_unique_filename(self, *_args, **_kwargs):
raise AssertionError("should not rename")
def update_file_info(self, _path):
return None
def to_dict(self):
return {"file_path": self.file_path}
class DummyAria2Downloader:
def __init__(self):
self.calls = []
async def reassign_transfer(self, previous_download_id, new_download_id):
self.calls.append(("reassign_transfer", previous_download_id, new_download_id))
return None
dummy_aria2 = DummyAria2Downloader()
monkeypatch.setattr(
download_manager,
"get_aria2_downloader",
AsyncMock(return_value=dummy_aria2),
)
manager._active_downloads["old-download"] = {
"transfer_backend": "aria2",
"model_id": 11,
"model_version_id": 22,
"status": "paused",
}
manager._active_downloads["new-download"] = {
"transfer_backend": "aria2",
"model_id": 11,
"model_version_id": 22,
"status": "queued",
}
resolved, path = await manager._resolve_download_target_path(
str(save_dir),
DummyMetadata(target_path),
transfer_backend="aria2",
download_id="new-download",
)
assert resolved is True
assert path == str(target_path)
assert "old-download" not in manager._active_downloads
assert manager._active_downloads["new-download"]["file_path"] == str(target_path)
assert dummy_aria2.calls == [("reassign_transfer", "old-download", "new-download")]
assert await manager._aria2_state_store.get("old-download") is None
assert (await manager._aria2_state_store.get("new-download"))["save_path"] == str(
target_path
)
def test_is_same_aria2_download_request_requires_version_id_match():
manager = DownloadManager()
assert (
manager._is_same_aria2_download_request(
{"model_id": 1, "model_version_id": None},
{"model_id": 1, "model_version_id": 2},
)
is False
)
assert (
manager._is_same_aria2_download_request(
{"model_id": 1, "model_version_id": 3},
{"model_id": 1, "model_version_id": None},
)
is False
)
@pytest.mark.asyncio
async def test_adopt_existing_aria2_download_cancels_old_running_task(monkeypatch, tmp_path):
manager = DownloadManager()
save_path = tmp_path / "file.safetensors"
started = asyncio.Event()
cancelled = asyncio.Event()
call_order = []
async def old_download():
started.set()
try:
await asyncio.sleep(60)
except asyncio.CancelledError:
call_order.append("old-task-cancelled")
cancelled.set()
raise
old_task = asyncio.create_task(old_download())
await started.wait()
manager._download_tasks["old-download"] = old_task
old_pause_control = DownloadStreamControl()
old_pause_control.pause()
manager._pause_events["old-download"] = old_pause_control
manager._active_downloads["old-download"] = {
"transfer_backend": "aria2",
"model_id": 11,
"model_version_id": 22,
"status": "downloading",
}
manager._active_downloads["new-download"] = {
"transfer_backend": "aria2",
"model_id": 11,
"model_version_id": 22,
"status": "queued",
}
await manager._aria2_state_store.upsert(
"old-download",
{
"download_id": "old-download",
"transfer_backend": "aria2",
"save_path": str(save_path),
"file_path": str(save_path),
"status": "downloading",
"model_id": 11,
"model_version_id": 22,
},
)
class DummyAria2Downloader:
async def reassign_transfer(self, previous_download_id, new_download_id):
call_order.append("reassign-transfer")
return None
monkeypatch.setattr(
download_manager,
"get_aria2_downloader",
AsyncMock(return_value=DummyAria2Downloader()),
)
await manager._adopt_existing_aria2_download(
"old-download",
"new-download",
{"model_id": 11, "model_version_id": 22},
str(save_path),
)
assert cancelled.is_set() is True
assert "old-download" not in manager._download_tasks
assert call_order == ["reassign-transfer", "old-task-cancelled"]
@pytest.mark.asyncio
async def test_pause_download_rejects_unknown_task():
"""Test that pause_download rejects unknown download tasks."""
manager = DownloadManager()
result = await manager.pause_download("missing")
assert result == {"success": False, "error": "Download task not found"}
@pytest.mark.asyncio
async def test_resume_download_sets_event_and_status():
"""Test that resume_download sets event and updates status."""
manager = DownloadManager()
download_id = "dl"
pause_control = DownloadStreamControl()
pause_control.pause()
pause_control.mark_progress()
manager._pause_events[download_id] = pause_control
manager._active_downloads[download_id] = {
"status": "paused",
"bytes_per_second": 0.0,
}
result = await manager.resume_download(download_id)
assert result == {"success": True, "message": "Download resumed successfully"}
assert manager._pause_events[download_id].is_set() is True
assert manager._active_downloads[download_id]["status"] == "downloading"
@pytest.mark.asyncio
async def test_resume_download_requests_reconnect_for_stalled_stream():
"""Test that resume_download requests reconnect for stalled streams."""
manager = DownloadManager()
download_id = "dl"
pause_control = DownloadStreamControl(stall_timeout=40)
pause_control.pause()
pause_control.last_progress_timestamp = datetime.now().timestamp() - 120
manager._pause_events[download_id] = pause_control
manager._active_downloads[download_id] = {
"status": "paused",
"bytes_per_second": 0.0,
}
result = await manager.resume_download(download_id)
assert result == {"success": True, "message": "Download resumed successfully"}
assert pause_control.is_set() is True
assert pause_control.has_reconnect_request() is True
@pytest.mark.asyncio
async def test_resume_download_rejects_when_not_paused():
"""Test that resume_download rejects when download is not paused."""
manager = DownloadManager()
download_id = "dl"
pause_control = DownloadStreamControl()
manager._pause_events[download_id] = pause_control
result = await manager.resume_download(download_id)
assert result == {"success": False, "error": "Download is not paused"}