Files
ComfyUI-Lora-Manager/tests/conftest.py
Will Miao 85e511d81c feat(testing): implement Phase 4 advanced testing
- Add Hypothesis property-based tests (19 tests)
- Add Syrupy snapshot tests (7 tests)
- Add pytest-benchmark performance tests (11 tests)
- Fix Hypothesis plugin compatibility by creating MockModule class
- Update pytest.ini to exclude .hypothesis directory
- Add .hypothesis/ to .gitignore
- Update requirements-dev.txt with testing dependencies
- Mark Phase 4 complete in backend-testing-improvement-plan.md

All 947 tests passing.
2026-02-11 11:58:28 +08:00

365 lines
12 KiB
Python

import asyncio
import importlib.util
import inspect
import sys
import types
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, List, Optional, Sequence
from unittest import mock
import pytest
# Repository root: two levels up from this file (this file lives in ``tests/``).
REPO_ROOT = Path(__file__).resolve().parents[1]
# Entry point of the repository's ``py`` package, loaded under an alias below.
PY_INIT = REPO_ROOT.joinpath("py", "__init__.py")
class MockModule(types.ModuleType):
    """Module stand-in that is hashable and compares equal by module name.

    ``types.SimpleNamespace`` is the obvious lightweight stand-in, but it is
    unhashable. Tools such as Hypothesis iterate over ``sys.modules`` and may
    put its values into sets/dicts, so the mocks registered there must hash.

    Any keyword arguments become attributes of the new module.
    """

    def __init__(self, name: str, **kwargs):
        super().__init__(name)
        for attr_name, attr_value in kwargs.items():
            setattr(self, attr_name, attr_value)

    def __hash__(self):
        # Must agree with ``__eq__``: equal mocks share a name, hence a hash.
        return hash(self.__name__)

    def __eq__(self, other):
        if not isinstance(other, MockModule):
            # Defer to the other operand / default identity comparison.
            return NotImplemented
        return self.__name__ == other.__name__
def _load_repo_package(name: str) -> types.ModuleType:
    """Ensure the repository's ``py`` package is importable under *name*.

    Executes ``PY_INIT`` (``py/__init__.py``) as a package called *name* and
    registers it in :data:`sys.modules`. If *name* is already registered and
    its ``__file__`` is that same file, the cached module is returned without
    re-executing it.

    Args:
        name: Module name to register the package under (e.g. ``"py"``).

    Returns:
        The freshly loaded, or previously cached, package module.

    Raises:
        ImportError: If no import spec/loader can be built for ``PY_INIT``.
        Exception: Anything raised while executing the package ``__init__``
            propagates unchanged, after the previous ``sys.modules`` entry for
            *name* (if any) has been restored.
    """
    existing = sys.modules.get(name)
    if existing is not None and getattr(existing, "__file__", None) == str(PY_INIT):
        return existing

    spec = importlib.util.spec_from_file_location(
        name,
        PY_INIT,
        # Marks the module as a package; ``module_from_spec`` derives
        # ``__path__`` from this, so submodules are found next to PY_INIT.
        submodule_search_locations=[str(PY_INIT.parent)],
    )
    if spec is None or spec.loader is None:  # pragma: no cover - initialization guard
        raise ImportError(f"Unable to load repository package for alias '{name}'")

    package = importlib.util.module_from_spec(spec)
    # Register *before* executing. The package ``__init__`` may import its own
    # submodules (``from . import x`` / ``import <name>.x``), and the import
    # system resolves the parent package through ``sys.modules``. Registering
    # afterwards makes such imports fail, or resolve against whatever
    # unrelated module already owns the name.
    sys.modules[name] = package
    try:
        spec.loader.exec_module(package)  # type: ignore[attr-defined]
    except BaseException:
        # Never leave a half-initialised package behind; put back what was there.
        if existing is None:
            sys.modules.pop(name, None)
        else:
            sys.modules[name] = existing
        raise
    return package
# Load the repository's ``py`` package under the name ``py`` and also expose it
# as ``py_local``; ``setdefault`` leaves any pre-existing ``py_local`` entry as-is.
# NOTE(review): this executes ``py/__init__.py`` *before* the ComfyUI mocks
# below are registered, so that file presumably must not (transitively) import
# ``server``, ``folder_paths``, ``comfy``, ``execution`` or ``nodes`` — confirm
# if it ever grows such imports.
_repo_package = _load_repo_package("py")
sys.modules.setdefault("py_local", _repo_package)
# Register stand-ins for the ComfyUI host modules in ``sys.modules`` so that
# project modules imported later by the tests can import them without a
# ComfyUI install. ``MockModule`` keeps these entries hashable.
server_mock = MockModule("server")
server_mock.PromptServer = mock.MagicMock()
sys.modules['server'] = server_mock
# ``folder_paths``: no folders configured by default.
folder_paths_mock = MockModule("folder_paths")
folder_paths_mock.get_folder_paths = mock.MagicMock(return_value=[])
folder_paths_mock.folder_names_and_paths = {}
sys.modules['folder_paths'] = folder_paths_mock
# Mock other ComfyUI modules that might be imported. Each ``comfy`` submodule
# is set both as an attribute of the parent mock and as its own ``sys.modules``
# entry, so ``import comfy.utils`` and ``from comfy import utils`` yield the
# same object.
comfy_mock = MockModule("comfy")
comfy_mock.utils = MockModule("comfy.utils")
comfy_mock.model_management = MockModule("comfy.model_management")
comfy_mock.comfy_types = MockModule("comfy.comfy_types")
comfy_mock.comfy_types.IO = mock.MagicMock()
sys.modules['comfy'] = comfy_mock
sys.modules['comfy.utils'] = comfy_mock.utils
sys.modules['comfy.model_management'] = comfy_mock.model_management
sys.modules['comfy.comfy_types'] = comfy_mock.comfy_types
execution_mock = MockModule("execution")
execution_mock.PromptExecutor = mock.MagicMock()
sys.modules['execution'] = execution_mock
# Mock ComfyUI nodes module (node classes as MagicMocks, empty class mapping).
nodes_mock = MockModule("nodes")
nodes_mock.LoraLoader = mock.MagicMock()
nodes_mock.SaveImage = mock.MagicMock()
nodes_mock.NODE_CLASS_MAPPINGS = {}
sys.modules['nodes'] = nodes_mock
@pytest.fixture(autouse=True)
def _isolate_settings_dir(tmp_path_factory, monkeypatch, request):
    """Redirect the settings directory into a per-test temp dir (autouse).

    Tests marked ``no_settings_dir_isolation`` keep the real settings-path
    helpers. Every test, marked or not, gets the settings manager reset both
    before and after it runs.
    """
    opted_out = request.node.get_closest_marker("no_settings_dir_isolation")
    if not opted_out:
        temp_settings_dir = tmp_path_factory.mktemp("settings_dir")

        def fake_get_settings_dir(create: bool = True) -> str:
            # Only (re)create the directory when the caller asks for it.
            if create:
                temp_settings_dir.mkdir(exist_ok=True)
            return str(temp_settings_dir)

        monkeypatch.setattr("py.utils.settings_paths.get_settings_dir", fake_get_settings_dir)
        monkeypatch.setattr(
            "py.utils.settings_paths.user_config_dir",
            lambda *_args, **_kwargs: str(temp_settings_dir),
        )

    # Imported here, after any patching above, and reset on both sides of the
    # test so no settings manager leaks between tests.
    from py.services import settings_manager as settings_manager_module

    settings_manager_module.reset_settings_manager()
    yield
    settings_manager_module.reset_settings_manager()
@dataclass
class MockHashIndex:
    """Hash-index double that only remembers which paths were removed.

    ``removed_paths`` lists every path given to :meth:`remove_by_path`, in
    call order; each instance starts with its own empty list.
    """

    removed_paths: List[str] = field(default_factory=list)

    def remove_by_path(self, path: str) -> None:
        """Record *path* as removed; nothing else happens."""
        self.removed_paths.append(path)
class MockCache:
    """In-memory stand-in for a scanner cache.

    Holds the item dicts in ``raw_data`` and offers an awaitable
    :meth:`resort` whose invocations are counted in ``resort_calls``.
    """

    def __init__(self, items: Optional[Sequence[Dict[str, Any]]] = None):
        # Copy the input so the caller's sequence is never mutated through us.
        self.raw_data: List[Dict[str, Any]] = list(items) if items else []
        self.resort_calls = 0

    async def resort(self) -> None:
        """Count a resort request.

        Callers need the coroutine interface but do not rely on a return value.
        """
        self.resort_calls = self.resort_calls + 1
class MockScanner:
    """Scanner double that exposes the attributes used by route utilities.

    Every mutating call is recorded (``updated_models``, ``preview_updates``,
    ``bulk_deleted``) and also applied to the wrapped cache's ``raw_data``, so
    tests can assert on both the calls and the resulting cache contents.
    """

    def __init__(
        self,
        cache: Optional["MockCache"] = None,
        hash_index: Optional["MockHashIndex"] = None,
    ):
        # String annotations keep this class importable independently of
        # definition order. Explicit ``None`` checks mean a caller-supplied
        # cache/index is never swapped out merely because it is falsy.
        self._cache = cache if cache is not None else MockCache()
        self._hash_index = hash_index if hash_index is not None else MockHashIndex()
        self._tags_count: Dict[str, int] = {}
        self._excluded_models: List[str] = []
        self.updated_models: List[Dict[str, Any]] = []
        self.preview_updates: List[Dict[str, Any]] = []
        self.bulk_deleted: List[Sequence[str]] = []
        self._cancelled = False
        self.model_type = "test-model"

    def is_cancelled(self) -> bool:
        """Return whether :meth:`cancel_task` was called since the last reset."""
        return self._cancelled

    def cancel_task(self) -> None:
        """Set the cancellation flag."""
        self._cancelled = True

    def reset_cancellation(self) -> None:
        """Clear the cancellation flag."""
        self._cancelled = False

    async def get_cached_data(self, force_refresh: bool = False):
        """Return the wrapped cache; ``force_refresh`` is accepted but ignored."""
        return self._cache

    async def update_single_model_cache(self, original_path: str, new_path: str, metadata: Dict[str, Any]) -> bool:
        """Record the update and merge *metadata* into matching cached items.

        Items whose ``file_path`` equals *original_path* are updated in place.
        *new_path* is only recorded; an item's ``file_path`` changes only if
        *metadata* itself carries one. Always returns ``True``.
        """
        self.updated_models.append({
            "original_path": original_path,
            "new_path": new_path,
            "metadata": metadata,
        })
        for item in self._cache.raw_data:
            if item.get("file_path") == original_path:
                item.update(metadata)
        return True

    async def update_preview_in_cache(self, model_path: str, preview_path: str, nsfw_level: int) -> bool:
        """Record the preview change and apply it to matching cached items.

        Sets ``preview_url`` and ``preview_nsfw_level`` on every item whose
        ``file_path`` equals *model_path*. Always returns ``True``.
        """
        self.preview_updates.append({
            "model_path": model_path,
            "preview_path": preview_path,
            "nsfw_level": nsfw_level,
        })
        for item in self._cache.raw_data:
            if item.get("file_path") == model_path:
                item["preview_url"] = preview_path
                item["preview_nsfw_level"] = nsfw_level
        return True

    async def bulk_delete_models(self, file_paths: Sequence[str]) -> Dict[str, Any]:
        """Drop the given paths from the cache and the hash index.

        The call is recorded in ``bulk_deleted`` (as a tuple), matching items
        are removed from ``raw_data``, the cache is resorted, and each path is
        removed from the hash index.

        Returns:
            ``{"success": True, "deleted": [...]}`` listing the requested paths.
        """
        # Materialise once: ``file_paths`` is consumed several times below, and
        # a one-shot iterable (e.g. a generator) would otherwise be exhausted
        # by the first pass, silently deleting nothing.
        paths = tuple(file_paths)
        self.bulk_deleted.append(paths)
        self._cache.raw_data = [item for item in self._cache.raw_data if item.get("file_path") not in paths]
        await self._cache.resort()
        for path in paths:
            self._hash_index.remove_by_path(path)
        return {"success": True, "deleted": list(paths)}
class MockModelService:
    """Model-service double wired to a scanner, for the shared model routes.

    Tests seed ``paginated_items`` and ``model_types`` with canned data; each
    item passed through :meth:`format_response` is recorded in ``formatted``.
    """

    def __init__(self, scanner: MockScanner):
        self.scanner = scanner
        self.model_type = "test-model"
        # Canned responses and call recordings; empty until a test fills them.
        self.paginated_items: List[Dict[str, Any]] = []
        self.formatted: List[Dict[str, Any]] = []
        self.model_types: List[Dict[str, Any]] = []

    async def get_paginated_data(self, **params: Any) -> Dict[str, Any]:
        """Return every configured item (shallow copies) in a single response.

        No slicing happens: ``page`` and ``page_size`` (defaults 1 and 20) are
        echoed back and only feed the ``total_pages`` figure (at least 1).
        """
        copies = list(map(dict, self.paginated_items))
        count = len(copies)
        current_page = params.get("page", 1)
        per_page = params.get("page_size", 20)
        page_count = max(1, (count + per_page - 1) // per_page)
        return {
            "items": copies,
            "total": count,
            "page": current_page,
            "page_size": per_page,
            "total_pages": page_count,
        }

    async def format_response(self, item: Dict[str, Any]) -> Dict[str, Any]:
        """Return a copy of *item* tagged with ``formatted=True`` and record it."""
        tagged = {**item, "formatted": True}
        self.formatted.append(tagged)
        return tagged

    # Convenience stubs used by assorted routes. They return fixed empty
    # values for the smoke tests and mainly document the surface area the
    # real services expose.
    def get_model_roots(self) -> List[str]:
        return ["."]

    async def scan_models(self, *_, **__):  # pragma: no cover - behaviour exercised via mocks
        return None

    async def get_model_notes(self, *_args, **_kwargs):  # pragma: no cover
        return None

    async def get_model_preview_url(self, *_args, **_kwargs):  # pragma: no cover
        return ""

    async def get_model_civitai_url(self, *_args, **_kwargs):  # pragma: no cover
        return {"civitai_url": ""}

    async def get_model_metadata(self, *_args, **_kwargs):  # pragma: no cover
        return {}

    async def get_model_description(self, *_args, **_kwargs):  # pragma: no cover
        return ""

    async def get_relative_paths(self, *_args, **_kwargs):  # pragma: no cover
        return []

    async def get_model_types(self, limit: int = 20):
        """Return at most *limit* entries of ``model_types`` as a new list."""
        return [*self.model_types][:limit]

    def has_hash(self, *_args, **_kwargs):  # pragma: no cover
        return False

    def get_path_by_hash(self, *_args, **_kwargs):  # pragma: no cover
        return ""
@pytest.fixture
def mock_hash_index() -> MockHashIndex:
    """Provide a fresh, empty :class:`MockHashIndex` for each test."""
    index = MockHashIndex()
    return index
@pytest.fixture
def mock_cache() -> MockCache:
    """Provide a fresh, empty :class:`MockCache` for each test."""
    cache = MockCache()
    return cache
@pytest.fixture
def mock_scanner(mock_cache: MockCache, mock_hash_index: MockHashIndex) -> MockScanner:
    """Provide a :class:`MockScanner` wired to this test's cache and hash index.

    Because the cache and index come from their own fixtures, a test can
    request them too and inspect the very objects the scanner mutates.
    """
    scanner = MockScanner(hash_index=mock_hash_index, cache=mock_cache)
    return scanner
@pytest.fixture
def mock_service(mock_scanner: MockScanner) -> MockModelService:
    """Provide a :class:`MockModelService` backed by this test's scanner."""
    service = MockModelService(scanner=mock_scanner)
    return service
@pytest.fixture
def mock_downloader():
    """Provide a configurable mock downloader.

    Every ``download_file`` call is recorded in ``download_calls``. Setting
    ``should_fail`` makes calls report failure; otherwise ``return_value`` is
    returned (``(True, "success")`` by default).
    """

    class MockDownloader:
        def __init__(self):
            self.download_calls = []
            self.should_fail = False
            self.return_value = (True, "success")

        async def download_file(self, url, target_path, **kwargs):
            self.download_calls.append(dict(url=url, target_path=target_path, kwargs=kwargs))
            return (False, "Download failed") if self.should_fail else self.return_value

    downloader = MockDownloader()
    return downloader
@pytest.fixture
def mock_websocket_manager():
    """Provide a WebSocket manager double that records every broadcast."""

    class RecordingWebSocketManager:
        def __init__(self):
            self.payloads = []
            self.broadcast_count = 0

        async def broadcast(self, payload):
            self.payloads.append(payload)
            self.broadcast_count += 1

        def get_payloads_by_type(self, msg_type: str):
            """Return the recorded payloads whose ``"type"`` equals *msg_type*."""
            return list(filter(lambda payload: payload.get("type") == msg_type, self.payloads))

    manager = RecordingWebSocketManager()
    return manager
@pytest.fixture(autouse=True)
def reset_singletons():
    """Clear process-wide singleton state before and after every test (autouse)."""
    # Imported inside the fixture to avoid circular imports.
    from py.services.download_manager import DownloadManager
    from py.services.service_registry import ServiceRegistry
    from py.services.model_scanner import ModelScanner
    from py.services.settings_manager import get_settings_manager

    def clear_class_level_state():
        # DownloadManager singleton, service registry, and scanner instances.
        DownloadManager._instance = None
        ServiceRegistry._services = {}
        ServiceRegistry._initialized = False
        if hasattr(ModelScanner, '_instances'):
            ModelScanner._instances.clear()

    clear_class_level_state()

    # NOTE(review): get_settings_manager() presumably creates the manager if
    # none exists yet, so this runs one setup per test — confirm that is intended.
    manager = get_settings_manager()
    if hasattr(manager, '_reset'):
        manager._reset()

    yield

    # Teardown: same class-level cleanup (the settings manager is not touched here).
    clear_class_level_state()