Files
ComfyUI-Lora-Manager/tests/services/test_download_manager.py
2025-09-25 09:40:25 +08:00

248 lines
7.8 KiB
Python

from pathlib import Path
from types import SimpleNamespace
from unittest.mock import AsyncMock
import pytest
from py.services.download_manager import DownloadManager
from py.services import download_manager
from py.services.service_registry import ServiceRegistry
from py.services.settings_manager import settings
@pytest.fixture(autouse=True)
def reset_download_manager():
    """Give every test a pristine ``DownloadManager`` singleton.

    The cached ``_instance`` is cleared before the test body runs and once
    more afterwards, so no state leaks from one test into the next.
    """

    def _drop_cached_instance():
        DownloadManager._instance = None

    _drop_cached_instance()
    yield
    _drop_cached_instance()
@pytest.fixture(autouse=True)
def isolate_settings(monkeypatch, tmp_path):
    """Replace the global settings with defaults rooted in ``tmp_path``.

    The default settings are overlaid with temporary model roots, a
    ``{base_model}/{first_tag}`` path template for every model type and a
    single base-model mapping. ``_save_settings`` is swapped for a function
    that does nothing, so tests never write a real settings file.
    """
    overrides = {
        "default_lora_root": str(tmp_path),
        "default_checkpoint_root": str(tmp_path / "checkpoints"),
        "default_embedding_root": str(tmp_path / "embeddings"),
        "download_path_templates": dict.fromkeys(
            ("lora", "checkpoint", "embedding"),
            "{base_model}/{first_tag}",
        ),
        "base_model_path_mappings": {"BaseModel": "MappedModel"},
    }

    patched_settings = settings._get_default_settings()
    patched_settings.update(overrides)
    monkeypatch.setattr(settings, "settings", patched_settings)

    def _skip_save(self):
        # Persisting is deliberately disabled for the duration of the test.
        return None

    monkeypatch.setattr(type(settings), "_save_settings", _skip_save)
@pytest.fixture(autouse=True)
def stub_metadata(monkeypatch):
    """Swap the metadata classes in ``download_manager`` for one shared stub.

    The stub exposes only a static ``from_civitai_info`` which ignores the
    version and file info and returns an object carrying ``file_path``,
    ``sha256`` and ``file_name`` derived from ``save_path``.
    """

    class _StubMetadata:
        def __init__(self, save_path: str):
            self.file_path = save_path
            self.sha256 = "sha256"
            self.file_name = Path(save_path).stem

    def from_civitai_info(_version_info, _file_info, save_path):
        return _StubMetadata(save_path)

    replacement = type(
        "StubMetadata",
        (),
        {"from_civitai_info": staticmethod(from_civitai_info)},
    )

    # The same stub class stands in for all three metadata types.
    for attribute in ("LoraMetadata", "CheckpointMetadata", "EmbeddingMetadata"):
        monkeypatch.setattr(download_manager, attribute, replacement)
class DummyScanner:
    """Scanner test double with a configurable existence answer.

    ``exists`` is the value returned by every existence check; ``calls``
    records the version ids that were asked about, in order.
    """

    def __init__(self, exists: bool = False):
        self.calls = []
        self.exists = exists

    async def check_model_version_exists(self, version_id):
        """Record ``version_id`` and return the current ``exists`` flag."""
        self.calls.append(version_id)
        answer = self.exists
        return answer
@pytest.fixture
def scanners(monkeypatch):
    """Register a ``DummyScanner`` for each model type and return them.

    Each ``ServiceRegistry`` scanner getter is replaced with an ``AsyncMock``
    resolving to its dummy. The dummies are returned as a namespace with
    ``lora``, ``checkpoint`` and ``embedding`` attributes so tests can tweak
    and inspect them.
    """
    getter_names = {
        "lora": "get_lora_scanner",
        "checkpoint": "get_checkpoint_scanner",
        "embedding": "get_embedding_scanner",
    }
    fakes = {kind: DummyScanner() for kind in getter_names}

    for kind, getter_name in getter_names.items():
        monkeypatch.setattr(
            ServiceRegistry,
            getter_name,
            AsyncMock(return_value=fakes[kind]),
        )

    return SimpleNamespace(**fakes)
@pytest.fixture
def metadata_provider(monkeypatch):
    """Install and return a fake metadata provider that records its calls.

    ``download_manager.get_default_metadata_provider`` is patched with an
    ``AsyncMock`` resolving to the fake. Its ``get_model_version`` always
    answers with the same LoRA version payload containing one primary file.
    """

    class DummyProvider:
        def __init__(self):
            self.calls = []

        async def get_model_version(self, model_id, model_version_id):
            self.calls.append((model_id, model_version_id))
            primary_file = {
                "primary": True,
                "downloadUrl": "https://example.invalid/file.safetensors",
                "name": "file.safetensors",
            }
            # A new payload is built on every call, so callers never share
            # one dict object.
            return {
                "id": 42,
                "model": {"type": "LoRA", "tags": ["fantasy"]},
                "baseModel": "BaseModel",
                "creator": {"username": "Author"},
                "files": [primary_file],
            }

    stub = DummyProvider()
    provider_getter = AsyncMock(return_value=stub)
    monkeypatch.setattr(
        download_manager,
        "get_default_metadata_provider",
        provider_getter,
    )
    return stub
@pytest.fixture(autouse=True)
def noop_cleanup(monkeypatch):
    """Replace ``DownloadManager._cleanup_download_record`` with a marker.

    The replacement only flags a known download record as ``"cleaned"``.
    Unknown task ids are ignored.
    """

    async def _mark_cleaned(self, task_id):
        if task_id not in self._active_downloads:
            return
        self._active_downloads[task_id]["cleaned"] = True

    monkeypatch.setattr(DownloadManager, "_cleanup_download_record", _mark_cleaned)
async def test_download_requires_identifier():
    """Calling ``download_from_civitai`` with no ids yields the error result."""
    expected = {
        "success": False,
        "error": "Either model_id or model_version_id must be provided",
    }

    manager = DownloadManager()
    outcome = await manager.download_from_civitai()

    assert outcome == expected
async def test_successful_download_uses_defaults(monkeypatch, scanners, metadata_provider, tmp_path):
    """A default-path download resolves its target from settings templates.

    ``_execute_download`` is replaced by a recorder so the test can check
    the relative path, save directory and model type that were passed to it,
    along with the bookkeeping left on the manager afterwards.
    """
    manager = DownloadManager()
    captured = {}

    async def fake_execute_download(
        self,
        *,
        download_url,
        save_dir,
        metadata,
        version_info,
        relative_path,
        progress_callback,
        model_type,
        download_id,
    ):
        # Record the arguments this test (or a debugging session) may inspect.
        captured["download_url"] = download_url
        captured["save_dir"] = Path(save_dir)
        captured["relative_path"] = relative_path
        captured["progress_callback"] = progress_callback
        captured["model_type"] = model_type
        captured["download_id"] = download_id
        captured["metadata_path"] = metadata.file_path
        return {"success": True}

    monkeypatch.setattr(
        DownloadManager,
        "_execute_download",
        fake_execute_download,
        raising=False,
    )

    request = {
        "model_version_id": 99,
        "save_dir": str(tmp_path),
        "use_default_paths": True,
        "progress_callback": None,
        "source": None,
    }
    result = await manager.download_from_civitai(**request)

    assert result["success"] is True
    assert "download_id" in result
    assert manager._download_tasks == {}
    record = manager._active_downloads[result["download_id"]]
    assert record["status"] == "completed"

    assert captured["relative_path"] == "MappedModel/fantasy"
    expected_dir = Path(settings.get("default_lora_root"), "MappedModel", "fantasy")
    assert captured["save_dir"] == expected_dir
    assert captured["model_type"] == "lora"
async def test_download_aborts_when_version_exists(monkeypatch, scanners, metadata_provider):
    """A version the lora scanner reports as existing is never downloaded."""
    scanners.lora.exists = True
    execute_mock = AsyncMock(return_value={"success": True})
    manager = DownloadManager()
    monkeypatch.setattr(DownloadManager, "_execute_download", execute_mock)

    outcome = await manager.download_from_civitai(model_version_id=101, save_dir="/tmp")

    assert outcome["success"] is False
    assert outcome["error"] == "Model version already exists in lora library"
    assert "download_id" in outcome
    # The download step must not have been reached.
    execute_mock.assert_not_awaited()
async def test_download_handles_metadata_errors(monkeypatch, scanners):
    """A provider returning no version info produces a metadata error result.

    The default metadata provider is replaced by an object whose
    ``get_model_version`` resolves to ``None``. The manager is expected to
    report failure while still returning a ``download_id``.
    """
    empty_provider = SimpleNamespace(get_model_version=AsyncMock(return_value=None))
    monkeypatch.setattr(
        download_manager,
        "get_default_metadata_provider",
        AsyncMock(return_value=empty_provider),
    )

    manager = DownloadManager()

    result = await manager.download_from_civitai(model_version_id=5, save_dir="/tmp")

    assert result["success"] is False
    assert result["error"] == "Failed to fetch model metadata"
    assert "download_id" in result
async def test_download_rejects_unsupported_model_type(monkeypatch, scanners):
    """Version info with an unrecognised model type yields a failed result."""

    class Provider:
        async def get_model_version(self, *_args, **_kwargs):
            unsupported_model = {"type": "Unsupported", "tags": []}
            return {"model": unsupported_model, "files": []}

    provider_getter = AsyncMock(return_value=Provider())
    monkeypatch.setattr(
        download_manager,
        "get_default_metadata_provider",
        provider_getter,
    )

    manager = DownloadManager()
    outcome = await manager.download_from_civitai(model_version_id=5, save_dir="/tmp")

    assert outcome["success"] is False
    assert outcome["error"].startswith("Model type")