"""Unit tests for ``DownloadManager.download_from_civitai``.

Collaborators are replaced with in-memory stubs via autouse/explicit
fixtures:

* settings persistence;
* the model scanners;
* the metadata provider;
* the metadata classes;
* download-record cleanup.

NOTE(review): the ``async def`` tests carry no explicit asyncio marker;
presumably the project runs pytest-asyncio in auto mode -- confirm.
"""

from pathlib import Path
from types import SimpleNamespace
from unittest.mock import AsyncMock

import pytest

from py.services import download_manager
from py.services.download_manager import DownloadManager
from py.services.service_registry import ServiceRegistry
from py.services.settings_manager import settings


@pytest.fixture(autouse=True)
def reset_download_manager():
    """Ensure each test operates on a fresh ``DownloadManager`` singleton."""
    DownloadManager._instance = None
    yield
    DownloadManager._instance = None


@pytest.fixture(autouse=True)
def isolate_settings(monkeypatch, tmp_path):
    """Swap in default settings rooted at ``tmp_path`` and disable saving.

    The download roots point into the per-test temporary directory. Path
    templates resolve to ``{base_model}/{first_tag}``, and ``BaseModel`` is
    mapped to ``MappedModel``. Together these let the path-resolution test
    predict the exact target directory.
    """
    default_settings = settings._get_default_settings()
    default_settings.update(
        {
            "default_lora_root": str(tmp_path),
            "default_checkpoint_root": str(tmp_path / "checkpoints"),
            "default_embedding_root": str(tmp_path / "embeddings"),
            "download_path_templates": {
                "lora": "{base_model}/{first_tag}",
                "checkpoint": "{base_model}/{first_tag}",
                "embedding": "{base_model}/{first_tag}",
            },
            "base_model_path_mappings": {"BaseModel": "MappedModel"},
        }
    )
    monkeypatch.setattr(settings, "settings", default_settings)
    # Turn persistence into a no-op so the real settings file is never written.
    monkeypatch.setattr(type(settings), "_save_settings", lambda self: None)


@pytest.fixture(autouse=True)
def stub_metadata(monkeypatch):
    """Replace the Lora/Checkpoint/Embedding metadata classes with a stub.

    The stub's ``from_civitai_info`` ignores the version and file info. It
    returns a minimal object exposing ``file_path``, ``sha256`` and
    ``file_name`` derived from ``save_path``.
    """

    class _StubMetadata:
        def __init__(self, save_path: str):
            self.file_path = save_path
            self.sha256 = "sha256"
            self.file_name = Path(save_path).stem

    # Kept separate from _StubMetadata (as in the original fixture), so the
    # returned objects are not instances of the patched class itself.
    class StubMetadata:
        @staticmethod
        def from_civitai_info(_version_info, _file_info, save_path):
            return _StubMetadata(save_path)

    for class_name in ("LoraMetadata", "CheckpointMetadata", "EmbeddingMetadata"):
        monkeypatch.setattr(download_manager, class_name, StubMetadata)


class DummyScanner:
    """Scanner stub that returns a fixed existence answer and records queries."""

    def __init__(self, exists: bool = False):
        self.exists = exists
        self.calls = []

    async def check_model_version_exists(self, version_id):
        """Record ``version_id`` and return the configured ``exists`` flag."""
        self.calls.append(version_id)
        return self.exists


@pytest.fixture
def scanners(monkeypatch):
    """Patch ``ServiceRegistry`` to hand out one ``DummyScanner`` per model type.

    Returns a namespace (``lora``, ``checkpoint``, ``embedding``) so tests can
    tweak or inspect the individual scanners.
    """
    lora_scanner = DummyScanner()
    checkpoint_scanner = DummyScanner()
    embedding_scanner = DummyScanner()

    monkeypatch.setattr(ServiceRegistry, "get_lora_scanner", AsyncMock(return_value=lora_scanner))
    monkeypatch.setattr(ServiceRegistry, "get_checkpoint_scanner", AsyncMock(return_value=checkpoint_scanner))
    monkeypatch.setattr(ServiceRegistry, "get_embedding_scanner", AsyncMock(return_value=embedding_scanner))

    return SimpleNamespace(
        lora=lora_scanner,
        checkpoint=checkpoint_scanner,
        embedding=embedding_scanner,
    )


@pytest.fixture
def metadata_provider(monkeypatch):
    """Install a provider whose ``get_model_version`` returns a fixed LoRA payload.

    The payload has base model ``BaseModel``, first tag ``fantasy`` and a
    single primary file. Each call's ``(model_id, model_version_id)`` is
    recorded on ``provider.calls``.
    """

    class DummyProvider:
        def __init__(self):
            self.calls = []

        async def get_model_version(self, model_id, model_version_id):
            self.calls.append((model_id, model_version_id))
            return {
                "id": 42,
                "model": {"type": "LoRA", "tags": ["fantasy"]},
                "baseModel": "BaseModel",
                "creator": {"username": "Author"},
                "files": [
                    {
                        "primary": True,
                        "downloadUrl": "https://example.invalid/file.safetensors",
                        "name": "file.safetensors",
                    }
                ],
            }

    provider = DummyProvider()
    monkeypatch.setattr(
        download_manager,
        "get_default_metadata_provider",
        AsyncMock(return_value=provider),
    )
    return provider


@pytest.fixture(autouse=True)
def noop_cleanup(monkeypatch):
    """Replace record cleanup with a stub that keeps records in ``_active_downloads``.

    Existing records are only flagged with ``cleaned = True``. This lets tests
    inspect a download's final status after the call returns.
    """

    async def _cleanup(self, task_id):
        if task_id in self._active_downloads:
            self._active_downloads[task_id]["cleaned"] = True

    monkeypatch.setattr(DownloadManager, "_cleanup_download_record", _cleanup)


async def test_download_requires_identifier():
    """Calling without a model id or version id is rejected up front."""
    manager = DownloadManager()
    result = await manager.download_from_civitai()
    assert result == {
        "success": False,
        "error": "Either model_id or model_version_id must be provided",
    }


async def test_successful_download_uses_defaults(monkeypatch, scanners, metadata_provider, tmp_path):
    """With ``use_default_paths`` the target dir follows the configured template."""
    manager = DownloadManager()

    captured = {}

    async def fake_execute_download(
        self,
        *,
        download_url,
        save_dir,
        metadata,
        version_info,
        relative_path,
        progress_callback,
        model_type,
        download_id,
    ):
        captured.update(
            {
                "download_url": download_url,
                "save_dir": Path(save_dir),
                "relative_path": relative_path,
                "progress_callback": progress_callback,
                "model_type": model_type,
                "download_id": download_id,
                "metadata_path": metadata.file_path,
            }
        )
        return {"success": True}

    # No raising=False: if _execute_download is renamed, fail loudly instead
    # of silently adding an attribute the real code never calls.
    monkeypatch.setattr(DownloadManager, "_execute_download", fake_execute_download)

    result = await manager.download_from_civitai(
        model_version_id=99,
        save_dir=str(tmp_path),
        use_default_paths=True,
        progress_callback=None,
        source=None,
    )

    assert result["success"] is True
    assert "download_id" in result
    assert manager._download_tasks == {}
    assert manager._active_downloads[result["download_id"]]["status"] == "completed"

    assert captured["relative_path"] == "MappedModel/fantasy"
    expected_dir = Path(settings.get("default_lora_root")) / "MappedModel" / "fantasy"
    assert captured["save_dir"] == expected_dir
    assert captured["model_type"] == "lora"


async def test_download_aborts_when_version_exists(monkeypatch, scanners, metadata_provider, tmp_path):
    """An already-present version short-circuits before any download runs."""
    scanners.lora.exists = True

    manager = DownloadManager()

    execute_mock = AsyncMock(return_value={"success": True})
    monkeypatch.setattr(DownloadManager, "_execute_download", execute_mock)

    result = await manager.download_from_civitai(model_version_id=101, save_dir=str(tmp_path))

    assert result["success"] is False
    assert result["error"] == "Model version already exists in lora library"
    assert "download_id" in result
    assert execute_mock.await_count == 0


async def test_download_handles_metadata_errors(monkeypatch, scanners, tmp_path):
    """A provider returning no version info yields a metadata-fetch error."""
    monkeypatch.setattr(
        download_manager,
        "get_default_metadata_provider",
        AsyncMock(return_value=SimpleNamespace(get_model_version=AsyncMock(return_value=None))),
    )

    manager = DownloadManager()

    result = await manager.download_from_civitai(model_version_id=5, save_dir=str(tmp_path))

    assert result["success"] is False
    assert result["error"] == "Failed to fetch model metadata"
    assert "download_id" in result


async def test_download_rejects_unsupported_model_type(monkeypatch, scanners, tmp_path):
    """A model type outside the supported set is reported as an error."""

    class Provider:
        async def get_model_version(self, *_args, **_kwargs):
            return {
                "model": {"type": "Unsupported", "tags": []},
                "files": [],
            }

    monkeypatch.setattr(
        download_manager,
        "get_default_metadata_provider",
        AsyncMock(return_value=Provider()),
    )

    manager = DownloadManager()

    result = await manager.download_from_civitai(model_version_id=5, save_dir=str(tmp_path))

    assert result["success"] is False
    assert result["error"].startswith("Model type")