import types from dataclasses import dataclass, field from typing import Any, Dict, List, Optional, Sequence import asyncio import inspect from unittest import mock import sys import pytest # Mock ComfyUI modules before any imports from the main project server_mock = types.SimpleNamespace() server_mock.PromptServer = mock.MagicMock() sys.modules['server'] = server_mock folder_paths_mock = types.SimpleNamespace() folder_paths_mock.get_folder_paths = mock.MagicMock(return_value=[]) folder_paths_mock.folder_names_and_paths = {} sys.modules['folder_paths'] = folder_paths_mock # Mock other ComfyUI modules that might be imported comfy_mock = types.SimpleNamespace() comfy_mock.utils = types.SimpleNamespace() comfy_mock.model_management = types.SimpleNamespace() comfy_mock.comfy_types = types.SimpleNamespace() comfy_mock.comfy_types.IO = mock.MagicMock() sys.modules['comfy'] = comfy_mock sys.modules['comfy.utils'] = comfy_mock.utils sys.modules['comfy.model_management'] = comfy_mock.model_management sys.modules['comfy.comfy_types'] = comfy_mock.comfy_types execution_mock = types.SimpleNamespace() execution_mock.PromptExecutor = mock.MagicMock() sys.modules['execution'] = execution_mock # Mock ComfyUI nodes module nodes_mock = types.SimpleNamespace() nodes_mock.LoraLoader = mock.MagicMock() nodes_mock.SaveImage = mock.MagicMock() nodes_mock.NODE_CLASS_MAPPINGS = {} sys.modules['nodes'] = nodes_mock def pytest_pyfunc_call(pyfuncitem): if inspect.iscoroutinefunction(pyfuncitem.function): asyncio.run(pyfuncitem.obj(**pyfuncitem.funcargs)) return True return None @dataclass class MockHashIndex: """Minimal hash index stub mirroring the scanner contract.""" removed_paths: List[str] = field(default_factory=list) def remove_by_path(self, path: str) -> None: self.removed_paths.append(path) class MockCache: """Cache object with the attributes consumed by ``ModelRouteUtils``.""" def __init__(self, items: Optional[Sequence[Dict[str, Any]]] = None): self.raw_data: List[Dict[str, Any]] = list(items or []) self.resort_calls = 0 async def resort(self) -> None: self.resort_calls += 1 # ``ModelRouteUtils`` expects the coroutine interface but does not # rely on the return value. class MockScanner: """Scanner double that exposes the attributes used by route utilities.""" def __init__(self, cache: Optional[MockCache] = None, hash_index: Optional[MockHashIndex] = None): self._cache = cache or MockCache() self._hash_index = hash_index or MockHashIndex() self._tags_count: Dict[str, int] = {} self._excluded_models: List[str] = [] self.updated_models: List[Dict[str, Any]] = [] self.preview_updates: List[Dict[str, Any]] = [] self.bulk_deleted: List[Sequence[str]] = [] async def get_cached_data(self, force_refresh: bool = False): return self._cache async def update_single_model_cache(self, original_path: str, new_path: str, metadata: Dict[str, Any]) -> bool: self.updated_models.append({ "original_path": original_path, "new_path": new_path, "metadata": metadata, }) for item in self._cache.raw_data: if item.get("file_path") == original_path: item.update(metadata) return True async def update_preview_in_cache(self, model_path: str, preview_path: str, nsfw_level: int) -> bool: self.preview_updates.append({ "model_path": model_path, "preview_path": preview_path, "nsfw_level": nsfw_level, }) for item in self._cache.raw_data: if item.get("file_path") == model_path: item["preview_url"] = preview_path item["preview_nsfw_level"] = nsfw_level return True async def bulk_delete_models(self, file_paths: Sequence[str]) -> Dict[str, Any]: self.bulk_deleted.append(tuple(file_paths)) self._cache.raw_data = [item for item in self._cache.raw_data if item.get("file_path") not in file_paths] await self._cache.resort() for path in file_paths: self._hash_index.remove_by_path(path) return {"success": True, "deleted": list(file_paths)} class MockModelService: """Service stub consumed by the shared routes.""" def __init__(self, scanner: MockScanner): self.scanner = scanner self.model_type = "test-model" self.paginated_items: List[Dict[str, Any]] = [] self.formatted: List[Dict[str, Any]] = [] async def get_paginated_data(self, **params: Any) -> Dict[str, Any]: items = [dict(item) for item in self.paginated_items] total = len(items) page = params.get("page", 1) page_size = params.get("page_size", 20) return { "items": items, "total": total, "page": page, "page_size": page_size, "total_pages": max(1, (total + page_size - 1) // page_size), } async def format_response(self, item: Dict[str, Any]) -> Dict[str, Any]: formatted = {**item, "formatted": True} self.formatted.append(formatted) return formatted # Convenience helpers used by assorted routes. They are no-ops for the # smoke tests but document the expected surface area of the real services. def get_model_roots(self) -> List[str]: return ["."] async def scan_models(self, *_, **__): # pragma: no cover - behaviour exercised via mocks return None async def get_model_notes(self, *_args, **_kwargs): # pragma: no cover return None async def get_model_preview_url(self, *_args, **_kwargs): # pragma: no cover return "" async def get_model_civitai_url(self, *_args, **_kwargs): # pragma: no cover return {"civitai_url": ""} async def get_model_metadata(self, *_args, **_kwargs): # pragma: no cover return {} async def get_model_description(self, *_args, **_kwargs): # pragma: no cover return "" async def get_relative_paths(self, *_args, **_kwargs): # pragma: no cover return [] def has_hash(self, *_args, **_kwargs): # pragma: no cover return False def get_path_by_hash(self, *_args, **_kwargs): # pragma: no cover return "" @pytest.fixture def mock_hash_index() -> MockHashIndex: return MockHashIndex() @pytest.fixture def mock_cache() -> MockCache: return MockCache() @pytest.fixture def mock_scanner(mock_cache: MockCache, mock_hash_index: MockHashIndex) -> MockScanner: return MockScanner(cache=mock_cache, hash_index=mock_hash_index) @pytest.fixture def mock_service(mock_scanner: MockScanner) -> MockModelService: return MockModelService(scanner=mock_scanner)