import asyncio
import importlib.util
import inspect
import sys
import types
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, List, Optional, Sequence
from unittest import mock

import pytest

REPO_ROOT = Path(__file__).resolve().parents[1]
PY_INIT = REPO_ROOT / "py" / "__init__.py"


class MockModule(types.ModuleType):
    """A mock module class that is hashable (unlike SimpleNamespace).

    This allows the module to be stored in sets/dicts without causing issues
    with tools like Hypothesis that iterate over sys.modules. Two MockModule
    instances compare equal (and hash identically) when their ``__name__``
    matches.
    """

    def __init__(self, name: str, **kwargs: Any):
        super().__init__(name)
        # Every keyword argument becomes an attribute of the fake module.
        for key, value in kwargs.items():
            setattr(self, key, value)

    # __hash__ and __eq__ are defined together so they stay consistent:
    # both are based solely on the module name.
    def __hash__(self):
        return hash(self.__name__)

    def __eq__(self, other):
        if isinstance(other, MockModule):
            return self.__name__ == other.__name__
        # Defer to the other operand / identity for non-MockModule objects.
        return NotImplemented


def _load_repo_package(name: str) -> types.ModuleType:
    """Ensure the repository's ``py`` package is importable under *name*.

    If ``sys.modules[name]`` is already the module loaded from ``PY_INIT`` it
    is returned unchanged; any other module registered under *name* is
    replaced by a fresh load of the repository package.

    Args:
        name: The alias under which the package is registered in
            ``sys.modules``.

    Returns:
        The loaded package module.

    Raises:
        ImportError: If no import spec/loader can be created for ``PY_INIT``.
        Any exception raised while executing the package's ``__init__`` is
        propagated after ``sys.modules`` has been restored.
    """
    existing = sys.modules.get(name)
    if existing is not None and getattr(existing, "__file__", None) == str(PY_INIT):
        return existing

    spec = importlib.util.spec_from_file_location(
        name,
        PY_INIT,
        submodule_search_locations=[str(PY_INIT.parent)],
    )
    if spec is None or spec.loader is None:  # pragma: no cover - initialization guard
        raise ImportError(f"Unable to load repository package for alias '{name}'")

    package = importlib.util.module_from_spec(spec)
    # Register *before* executing the package body so that imports performed
    # by its __init__ (relative ``from .x import y`` or ``import <name>.x``)
    # resolve against this very module instead of re-importing it or picking
    # up whatever else was registered under *name*.
    sys.modules[name] = package
    try:
        spec.loader.exec_module(package)  # type: ignore[attr-defined]
    except BaseException:
        # Roll back so a failed load does not leave a half-initialised module
        # registered; restore the previous entry if there was one.
        if existing is None:
            sys.modules.pop(name, None)
        else:
            sys.modules[name] = existing
        raise
    # Pin the search path to the repository directory, overriding anything the
    # package body may have set.
    package.__path__ = [str(PY_INIT.parent)]  # type: ignore[attr-defined]
    return package


_repo_package = _load_repo_package("py")
sys.modules.setdefault("py_local", _repo_package)

# Mock ComfyUI modules before any imports from the main project
server_mock = MockModule("server")
server_mock.PromptServer = mock.MagicMock()
sys.modules['server'] = server_mock

folder_paths_mock = MockModule("folder_paths")
folder_paths_mock.get_folder_paths = mock.MagicMock(return_value=[])
folder_paths_mock.folder_names_and_paths = {}
sys.modules['folder_paths'] = folder_paths_mock

# Mock other ComfyUI modules that might be imported
comfy_mock = MockModule("comfy")
comfy_mock.utils = MockModule("comfy.utils")
comfy_mock.model_management = MockModule("comfy.model_management")
comfy_mock.comfy_types = MockModule("comfy.comfy_types")
comfy_mock.comfy_types.IO = mock.MagicMock()
sys.modules['comfy'] = comfy_mock
sys.modules['comfy.utils'] = comfy_mock.utils
sys.modules['comfy.model_management'] = comfy_mock.model_management
sys.modules['comfy.comfy_types'] = comfy_mock.comfy_types

execution_mock = MockModule("execution")
execution_mock.PromptExecutor = mock.MagicMock()
sys.modules['execution'] = execution_mock

# Mock ComfyUI nodes module
nodes_mock = MockModule("nodes")
nodes_mock.LoraLoader = mock.MagicMock()
nodes_mock.SaveImage = mock.MagicMock()
nodes_mock.NODE_CLASS_MAPPINGS = {}
sys.modules['nodes'] = nodes_mock


@pytest.fixture(autouse=True)
def _isolate_settings_dir(tmp_path_factory, monkeypatch, request):
    """Redirect settings.json into a temporary directory for each test.

    Tests marked ``no_settings_dir_isolation`` keep the real settings
    location. In both cases the settings manager singleton is reset before
    and after the test.
    """
    if not request.node.get_closest_marker("no_settings_dir_isolation"):
        settings_dir = tmp_path_factory.mktemp("settings_dir")

        def fake_get_settings_dir(create: bool = True) -> str:
            if create:
                settings_dir.mkdir(exist_ok=True)
            return str(settings_dir)

        monkeypatch.setattr("py.utils.settings_paths.get_settings_dir", fake_get_settings_dir)
        monkeypatch.setattr(
            "py.utils.settings_paths.user_config_dir",
            lambda *_args, **_kwargs: str(settings_dir),
        )

    # Imported only after the (optional) monkeypatching above, preserving the
    # original import order relative to the patch.
    from py.services import settings_manager as settings_manager_module

    settings_manager_module.reset_settings_manager()
    yield
    settings_manager_module.reset_settings_manager()


@dataclass
class MockHashIndex:
    """Minimal hash index stub mirroring the scanner contract."""

    # Every path passed to remove_by_path, in call order.
    removed_paths: List[str] = field(default_factory=list)

    def remove_by_path(self, path: str) -> None:
        """Record *path* as removed."""
        self.removed_paths.append(path)


class MockCache:
    """In-memory cache double exposing ``raw_data`` and an async ``resort``."""

    def __init__(self, items: Optional[Sequence[Dict[str, Any]]] = None):
        # Shallow copy: the list is private, the item dicts are shared.
        self.raw_data: List[Dict[str, Any]] = list(items or [])
        # Number of times resort() has been awaited.
        self.resort_calls = 0

    async def resort(self) -> None:
        """Count the call.

        Callers await this coroutine but do not use its return value.
        """
        self.resort_calls += 1


class MockScanner:
    """Scanner double that exposes the attributes used by route utilities."""

    def __init__(self, cache: Optional[MockCache] = None, hash_index: Optional[MockHashIndex] = None):
        self._cache = cache or MockCache()
        self._hash_index = hash_index or MockHashIndex()
        self._tags_count: Dict[str, int] = {}
        self._excluded_models: List[str] = []
        # Call logs inspected by tests.
        self.updated_models: List[Dict[str, Any]] = []
        self.preview_updates: List[Dict[str, Any]] = []
        self.bulk_deleted: List[Sequence[str]] = []
        self._cancelled = False
        self.model_type = "test-model"

    def is_cancelled(self) -> bool:
        return self._cancelled

    def cancel_task(self) -> None:
        self._cancelled = True

    def reset_cancellation(self) -> None:
        self._cancelled = False

    async def get_cached_data(self, force_refresh: bool = False):
        """Return the backing cache; *force_refresh* is ignored."""
        return self._cache

    async def update_single_model_cache(self, original_path: str, new_path: str, metadata: Dict[str, Any]) -> bool:
        """Log the update and merge *metadata* into items at *original_path*."""
        self.updated_models.append({
            "original_path": original_path,
            "new_path": new_path,
            "metadata": metadata,
        })
        for item in self._cache.raw_data:
            if item.get("file_path") == original_path:
                item.update(metadata)
        return True

    async def update_preview_in_cache(self, model_path: str, preview_path: str, nsfw_level: int) -> bool:
        """Log the update and set preview fields on items at *model_path*."""
        self.preview_updates.append({
            "model_path": model_path,
            "preview_path": preview_path,
            "nsfw_level": nsfw_level,
        })
        for item in self._cache.raw_data:
            if item.get("file_path") == model_path:
                item["preview_url"] = preview_path
                item["preview_nsfw_level"] = nsfw_level
        return True

    async def bulk_delete_models(self, file_paths: Sequence[str]) -> Dict[str, Any]:
        """Drop *file_paths* from the cache and hash index, then resort.

        *file_paths* is materialised once up front so that one-shot iterables
        (generators, iterators) are honoured by every step below instead of
        being exhausted by the first.
        """
        paths = tuple(file_paths)
        self.bulk_deleted.append(paths)
        self._cache.raw_data = [item for item in self._cache.raw_data if item.get("file_path") not in paths]
        await self._cache.resort()
        for path in paths:
            self._hash_index.remove_by_path(path)
        return {"success": True, "deleted": list(paths)}


class MockModelService:
    """Service stub consumed by the shared routes."""

    def __init__(self, scanner: MockScanner):
        self.scanner = scanner
        self.model_type = "test-model"
        # Test-controlled data returned by the stubbed queries.
        self.paginated_items: List[Dict[str, Any]] = []
        self.formatted: List[Dict[str, Any]] = []
        self.model_types: List[Dict[str, Any]] = []

    async def get_paginated_data(self, **params: Any) -> Dict[str, Any]:
        """Return copies of every ``paginated_items`` entry plus paging metadata.

        No slicing is performed: all items are returned regardless of
        ``page``. ``total_pages`` is at least 1; a non-positive ``page_size``
        yields 1 instead of raising ``ZeroDivisionError``.
        """
        items = [dict(item) for item in self.paginated_items]
        total = len(items)
        page = params.get("page", 1)
        page_size = params.get("page_size", 20)
        if page_size > 0:
            total_pages = max(1, (total + page_size - 1) // page_size)
        else:
            total_pages = 1
        return {
            "items": items,
            "total": total,
            "page": page,
            "page_size": page_size,
            "total_pages": total_pages,
        }

    async def format_response(self, item: Dict[str, Any]) -> Dict[str, Any]:
        """Return a copy of *item* tagged ``formatted=True`` and log it."""
        formatted = {**item, "formatted": True}
        self.formatted.append(formatted)
        return formatted

    # Convenience helpers used by assorted routes. They are no-ops for the
    # smoke tests but document the expected surface area of the real services.
    def get_model_roots(self) -> List[str]:
        return ["."]

    async def scan_models(self, *_, **__):  # pragma: no cover - behaviour exercised via mocks
        return None

    async def get_model_notes(self, *_args, **_kwargs):  # pragma: no cover
        return None

    async def get_model_preview_url(self, *_args, **_kwargs):  # pragma: no cover
        return ""

    async def get_model_civitai_url(self, *_args, **_kwargs):  # pragma: no cover
        return {"civitai_url": ""}

    async def get_model_metadata(self, *_args, **_kwargs):  # pragma: no cover
        return {}

    async def get_model_description(self, *_args, **_kwargs):  # pragma: no cover
        return ""

    async def get_relative_paths(self, *_args, **_kwargs):  # pragma: no cover
        return []

    async def get_model_types(self, limit: int = 20):
        return list(self.model_types)[:limit]

    def has_hash(self, *_args, **_kwargs):  # pragma: no cover
        return False

    def get_path_by_hash(self, *_args, **_kwargs):  # pragma: no cover
        return ""


@pytest.fixture
def mock_hash_index() -> MockHashIndex:
    return MockHashIndex()


@pytest.fixture
def mock_cache() -> MockCache:
    return MockCache()


@pytest.fixture
def mock_scanner(mock_cache: MockCache, mock_hash_index: MockHashIndex) -> MockScanner:
    return MockScanner(cache=mock_cache, hash_index=mock_hash_index)


@pytest.fixture
def mock_service(mock_scanner: MockScanner) -> MockModelService:
    return MockModelService(scanner=mock_scanner)


@pytest.fixture
def mock_downloader():
    """Provide a configurable mock downloader."""

    class MockDownloader:
        def __init__(self):
            self.download_calls = []
            # Flip to True to make download_file report failure.
            self.should_fail = False
            self.return_value = (True, "success")

        async def download_file(self, url, target_path, **kwargs):
            self.download_calls.append({"url": url, "target_path": target_path, "kwargs": kwargs})
            if self.should_fail:
                return False, "Download failed"
            return self.return_value

    return MockDownloader()


@pytest.fixture
def mock_websocket_manager():
    """Provide a recording WebSocket manager."""

    class RecordingWebSocketManager:
        def __init__(self):
            self.payloads = []
            self.broadcast_count = 0

        async def broadcast(self, payload):
            self.payloads.append(payload)
            self.broadcast_count += 1

        def get_payloads_by_type(self, msg_type: str):
            """Get all payloads of a specific message type."""
            return [p for p in self.payloads if p.get("type") == msg_type]

    return RecordingWebSocketManager()


@pytest.fixture(autouse=True)
def reset_singletons():
    """Reset service singletons before and after each test to ensure isolation.

    The settings manager is reset (via its ``_reset`` hook, if present) only
    during setup.
    """
    # Import here to avoid circular imports
    from py.services.download_manager import DownloadManager
    from py.services.service_registry import ServiceRegistry
    from py.services.model_scanner import ModelScanner
    from py.services.settings_manager import get_settings_manager

    def clear_service_state() -> None:
        # Shared by setup and teardown so the two cannot drift apart.
        DownloadManager._instance = None
        ServiceRegistry._services = {}
        ServiceRegistry._initialized = False
        if hasattr(ModelScanner, '_instances'):
            ModelScanner._instances.clear()

    clear_service_state()

    settings_manager = get_settings_manager()
    if hasattr(settings_manager, '_reset'):
        settings_manager._reset()

    yield

    # Cleanup after test
    clear_service_state()