mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-21 13:12:12 -03:00
- Add missing mocks for comfy.sd and comfy.utils modules in conftest.py - Fix i18n translation keys: use .help instead of .description for tooltip keys
369 lines
12 KiB
Python
369 lines
12 KiB
Python
import asyncio
|
|
import importlib.util
|
|
import inspect
|
|
import sys
|
|
import types
|
|
from dataclasses import dataclass, field
|
|
from pathlib import Path
|
|
from typing import Any, Dict, List, Optional, Sequence
|
|
from unittest import mock
|
|
|
|
import pytest
|
|
|
|
|
|
# Repository root: the parent of the directory that contains this file.
REPO_ROOT = Path(__file__).resolve().parent.parent
# Entry point of the repository's ``py`` package.
PY_INIT = REPO_ROOT.joinpath("py", "__init__.py")
|
|
|
|
|
|
class MockModule(types.ModuleType):
    """Hashable stand-in module used to stub out external packages.

    Instances hash and compare by module name, so they can be stored in
    sets/dicts without causing issues with tools like Hypothesis that
    iterate over ``sys.modules`` (unlike ``SimpleNamespace``).
    """

    def __init__(self, name: str, **kwargs):
        """Create the module *name* and attach every keyword as an attribute."""
        super().__init__(name)
        for attr_name in kwargs:
            setattr(self, attr_name, kwargs[attr_name])

    def __hash__(self):
        """Hash on the module name only."""
        return hash(self.__name__)

    def __eq__(self, other):
        """Two mock modules are equal when they share a name."""
        if not isinstance(other, MockModule):
            # Let Python fall back to the other operand / identity comparison.
            return NotImplemented
        return self.__name__ == other.__name__
|
|
|
|
|
|
def _load_repo_package(name: str) -> types.ModuleType:
    """Ensure the repository's ``py`` package is importable under *name*.

    If ``sys.modules[name]`` already points at the module loaded from
    ``PY_INIT`` it is returned unchanged.  Otherwise the package is loaded
    from ``PY_INIT`` and registered in ``sys.modules`` under *name*.

    Args:
        name: The module name to register the package under.

    Returns:
        The loaded (or previously loaded) package module.

    Raises:
        ImportError: If no import spec/loader can be built for ``PY_INIT``.
    """
    had_previous = name in sys.modules
    previous = sys.modules.get(name)
    if previous is not None and getattr(previous, "__file__", None) == str(PY_INIT):
        return previous

    package_dir = str(PY_INIT.parent)
    spec = importlib.util.spec_from_file_location(
        name,
        PY_INIT,
        submodule_search_locations=[package_dir],
    )
    if spec is None or spec.loader is None:  # pragma: no cover - initialization guard
        raise ImportError(f"Unable to load repository package for alias '{name}'")

    package = importlib.util.module_from_spec(spec)

    # Register the package *before* executing it so that any import of
    # ``name`` (or of its submodules) triggered while ``__init__.py`` runs
    # resolves to this module rather than to an unrelated module that
    # happens to share the same top-level name.
    sys.modules[name] = package
    try:
        spec.loader.exec_module(package)  # type: ignore[attr-defined]
    except BaseException:
        # Do not leave a half-initialised package behind; restore whatever
        # was registered under ``name`` before the attempt.
        if had_previous:
            sys.modules[name] = previous  # type: ignore[assignment]
        else:
            sys.modules.pop(name, None)
        raise

    # Re-apply the search path and registration after execution so the final
    # state does not depend on anything ``__init__.py`` did to either one.
    package.__path__ = [package_dir]  # type: ignore[attr-defined]
    sys.modules[name] = package
    return package
|
|
|
|
|
|
# Load the repository package as ``py`` and expose it as ``py_local`` too,
# unless something is already registered under that alias.
_repo_package = _load_repo_package("py")
if "py_local" not in sys.modules:
    sys.modules["py_local"] = _repo_package
|
|
|
|
# Mock ComfyUI modules before any imports from the main project
# Stub ``server`` module exposing a ``PromptServer`` attribute.
server_mock = MockModule("server", PromptServer=mock.MagicMock())
sys.modules[server_mock.__name__] = server_mock
|
|
|
|
# Stub ``folder_paths`` module: no folders are reported and the registry is empty.
folder_paths_mock = MockModule(
    "folder_paths",
    get_folder_paths=mock.MagicMock(return_value=[]),
    folder_names_and_paths={},
)
sys.modules[folder_paths_mock.__name__] = folder_paths_mock
|
|
|
|
# Mock other ComfyUI modules that might be imported
# The ``comfy`` stub carries its submodules as attributes, and each submodule
# is also registered under its dotted name so both import forms resolve.
comfy_mock = MockModule(
    "comfy",
    utils=MockModule(
        "comfy.utils",
        load_torch_file=mock.MagicMock(return_value={}),
    ),
    sd=MockModule(
        "comfy.sd",
        load_lora_for_models=mock.MagicMock(return_value=(None, None)),
    ),
    model_management=MockModule("comfy.model_management"),
    comfy_types=MockModule("comfy.comfy_types", IO=mock.MagicMock()),
)
sys.modules.update(
    {
        "comfy": comfy_mock,
        "comfy.utils": comfy_mock.utils,
        "comfy.sd": comfy_mock.sd,
        "comfy.model_management": comfy_mock.model_management,
        "comfy.comfy_types": comfy_mock.comfy_types,
    }
)
|
|
|
|
# Stub ``execution`` module exposing a ``PromptExecutor`` attribute.
execution_mock = MockModule("execution", PromptExecutor=mock.MagicMock())
sys.modules[execution_mock.__name__] = execution_mock
|
|
|
|
# Mock ComfyUI nodes module
# Provides the two node classes as mocks and an empty class-mapping registry.
nodes_mock = MockModule(
    "nodes",
    LoraLoader=mock.MagicMock(),
    SaveImage=mock.MagicMock(),
    NODE_CLASS_MAPPINGS={},
)
sys.modules[nodes_mock.__name__] = nodes_mock
|
|
|
|
|
|
@pytest.fixture(autouse=True)
def _isolate_settings_dir(tmp_path_factory, monkeypatch, request):
    """Redirect settings.json into a temporary directory for each test.

    Tests marked ``no_settings_dir_isolation`` skip the redirection.  In both
    cases the settings manager is reset before the test and again afterwards.
    """
    opted_out = request.node.get_closest_marker("no_settings_dir_isolation")

    if not opted_out:
        isolated_dir = tmp_path_factory.mktemp("settings_dir")

        def fake_get_settings_dir(create: bool = True) -> str:
            # Mirror the ``create`` flag: only touch the filesystem on request.
            if create:
                isolated_dir.mkdir(exist_ok=True)
            return str(isolated_dir)

        def fake_user_config_dir(*_args, **_kwargs) -> str:
            return str(isolated_dir)

        monkeypatch.setattr("py.utils.settings_paths.get_settings_dir", fake_get_settings_dir)
        monkeypatch.setattr("py.utils.settings_paths.user_config_dir", fake_user_config_dir)

    # Imported only after the patches above are in place (when they apply),
    # keeping the original ordering of patching and import.
    from py.services import settings_manager as settings_manager_module

    settings_manager_module.reset_settings_manager()
    yield
    settings_manager_module.reset_settings_manager()
|
|
|
|
|
|
@dataclass
class MockHashIndex:
    """Minimal hash index stub mirroring the scanner contract.

    Rather than maintaining a real index, it records every path it is asked
    to remove so tests can assert on the calls.
    """

    # Paths passed to ``remove_by_path``, in call order.
    removed_paths: List[str] = field(default_factory=list)

    def remove_by_path(self, path: str) -> None:
        """Record *path* as removed."""
        self.removed_paths.append(path)
|
|
|
|
|
|
class MockCache:
    """Cache stub exposing ``raw_data`` and a call-counting ``resort``."""

    def __init__(self, items: Optional[Sequence[Dict[str, Any]]] = None):
        """Store a list copy of *items* (empty when omitted) and zero the counter."""
        self.resort_calls = 0
        self.raw_data: List[Dict[str, Any]] = list(items) if items else []

    async def resort(self) -> None:
        """Count the call; nothing is actually sorted."""
        # Only the coroutine interface matters here: the call is awaited,
        # and no return value is provided or relied upon.
        self.resort_calls += 1
|
|
|
|
|
|
class MockScanner:
    """Scanner double that exposes the attributes used by route utilities.

    Every mutating call is recorded on the instance (``updated_models``,
    ``preview_updates``, ``bulk_deleted``) so tests can inspect what was
    requested, and the backing cache's ``raw_data`` is updated to match.
    """

    def __init__(self, cache: Optional[MockCache] = None, hash_index: Optional[MockHashIndex] = None):
        """Wire up the collaborators, creating default stubs when omitted."""
        self._cache = cache if cache else MockCache()
        self._hash_index = hash_index if hash_index else MockHashIndex()
        self.model_type = "test-model"
        self._cancelled = False
        self._tags_count: Dict[str, int] = {}
        self._excluded_models: List[str] = []
        # Call logs, one entry per invocation of the matching method.
        self.updated_models: List[Dict[str, Any]] = []
        self.preview_updates: List[Dict[str, Any]] = []
        self.bulk_deleted: List[Sequence[str]] = []

    def is_cancelled(self) -> bool:
        """Return the current cancellation flag."""
        return self._cancelled

    def cancel_task(self) -> None:
        """Raise the cancellation flag."""
        self._cancelled = True

    def reset_cancellation(self) -> None:
        """Lower the cancellation flag."""
        self._cancelled = False

    async def get_cached_data(self, force_refresh: bool = False):
        """Return the backing cache; *force_refresh* is accepted but ignored."""
        return self._cache

    async def update_single_model_cache(self, original_path: str, new_path: str, metadata: Dict[str, Any]) -> bool:
        """Log the update and merge *metadata* into entries at *original_path*.

        Always returns ``True``.
        """
        call_record = {
            "original_path": original_path,
            "new_path": new_path,
            "metadata": metadata,
        }
        self.updated_models.append(call_record)
        for entry in self._cache.raw_data:
            if entry.get("file_path") != original_path:
                continue
            entry.update(metadata)
        return True

    async def update_preview_in_cache(self, model_path: str, preview_path: str, nsfw_level: int) -> bool:
        """Log the update and set preview fields on entries at *model_path*.

        Always returns ``True``.
        """
        call_record = {
            "model_path": model_path,
            "preview_path": preview_path,
            "nsfw_level": nsfw_level,
        }
        self.preview_updates.append(call_record)
        for entry in self._cache.raw_data:
            if entry.get("file_path") != model_path:
                continue
            entry["preview_url"] = preview_path
            entry["preview_nsfw_level"] = nsfw_level
        return True

    async def bulk_delete_models(self, file_paths: Sequence[str]) -> Dict[str, Any]:
        """Drop the given paths from the cache and the hash index.

        The request is logged as a tuple, the cache is re-sorted after the
        entries are removed, and a success payload listing the paths is
        returned.
        """
        self.bulk_deleted.append(tuple(file_paths))

        survivors = []
        for entry in self._cache.raw_data:
            if entry.get("file_path") in file_paths:
                continue
            survivors.append(entry)
        self._cache.raw_data = survivors
        await self._cache.resort()

        for deleted_path in file_paths:
            self._hash_index.remove_by_path(deleted_path)
        return {"success": True, "deleted": list(file_paths)}
|
|
|
|
|
|
class MockModelService:
    """Service stub consumed by the shared routes.

    Tests seed ``paginated_items`` and ``model_types``; the service echoes
    them back and records every formatted item in ``formatted``.
    """

    def __init__(self, scanner: MockScanner):
        """Keep a reference to *scanner* and start with empty seed data."""
        self.scanner = scanner
        self.model_type = "test-model"
        self.paginated_items: List[Dict[str, Any]] = []
        self.model_types: List[Dict[str, Any]] = []
        self.formatted: List[Dict[str, Any]] = []

    async def get_paginated_data(self, **params: Any) -> Dict[str, Any]:
        """Return a pagination envelope around copies of all seeded items.

        ``page`` (default 1) and ``page_size`` (default 20) are read from
        *params* and echoed back; the items themselves are not sliced.
        """
        page_size = params.get("page_size", 20)
        page = params.get("page", 1)
        # Shallow-copy each item so callers cannot mutate the seed data.
        copies = list(map(dict, self.paginated_items))
        count = len(copies)
        return {
            "items": copies,
            "total": count,
            "page": page,
            "page_size": page_size,
            "total_pages": max(1, (count + page_size - 1) // page_size),
        }

    async def format_response(self, item: Dict[str, Any]) -> Dict[str, Any]:
        """Return a copy of *item* flagged ``formatted`` and remember it."""
        result = dict(item)
        result["formatted"] = True
        self.formatted.append(result)
        return result

    # Convenience helpers used by assorted routes. They are no-ops for the
    # smoke tests but document the expected surface area of the real services.
    def get_model_roots(self) -> List[str]:
        """Return a single root: the current directory."""
        return ["."]

    async def scan_models(self, *_, **__):  # pragma: no cover - behaviour exercised via mocks
        """Accept any arguments and do nothing."""
        return None

    async def get_model_notes(self, *_args, **_kwargs):  # pragma: no cover
        """Return ``None`` for any input."""
        return None

    async def get_model_preview_url(self, *_args, **_kwargs):  # pragma: no cover
        """Return an empty string for any input."""
        return ""

    async def get_model_civitai_url(self, *_args, **_kwargs):  # pragma: no cover
        """Return a payload with an empty ``civitai_url``."""
        return {"civitai_url": ""}

    async def get_model_metadata(self, *_args, **_kwargs):  # pragma: no cover
        """Return an empty dict for any input."""
        return {}

    async def get_model_description(self, *_args, **_kwargs):  # pragma: no cover
        """Return an empty string for any input."""
        return ""

    async def get_relative_paths(self, *_args, **_kwargs):  # pragma: no cover
        """Return an empty list for any input."""
        return []

    async def get_model_types(self, limit: int = 20):
        """Return a copy of the first *limit* seeded model types."""
        seeded = list(self.model_types)
        return seeded[:limit]

    def has_hash(self, *_args, **_kwargs):  # pragma: no cover
        """Report that no hash is known."""
        return False

    def get_path_by_hash(self, *_args, **_kwargs):  # pragma: no cover
        """Return an empty path for any hash."""
        return ""
|
|
|
|
|
|
@pytest.fixture
def mock_hash_index() -> MockHashIndex:
    """Provide a fresh, empty ``MockHashIndex``."""
    hash_index = MockHashIndex()
    return hash_index
|
|
|
|
|
|
@pytest.fixture
def mock_cache() -> MockCache:
    """Provide a fresh ``MockCache`` with no items."""
    cache = MockCache()
    return cache
|
|
|
|
|
|
@pytest.fixture
def mock_scanner(mock_cache: MockCache, mock_hash_index: MockHashIndex) -> MockScanner:
    """Provide a ``MockScanner`` built on the cache and hash-index fixtures."""
    scanner = MockScanner(hash_index=mock_hash_index, cache=mock_cache)
    return scanner
|
|
|
|
|
|
@pytest.fixture
def mock_service(mock_scanner: MockScanner) -> MockModelService:
    """Provide a ``MockModelService`` backed by the scanner fixture."""
    service = MockModelService(scanner=mock_scanner)
    return service
|
|
|
|
|
|
@pytest.fixture
def mock_downloader():
    """Provide a configurable mock downloader.

    Each ``download_file`` call is appended to ``download_calls``.  The
    result is ``return_value`` unless ``should_fail`` is set, in which case
    a failure tuple is returned.
    """

    class MockDownloader:
        def __init__(self):
            self.return_value = (True, "success")
            self.should_fail = False
            self.download_calls = []

        async def download_file(self, url, target_path, **kwargs):
            call_record = {"url": url, "target_path": target_path, "kwargs": kwargs}
            self.download_calls.append(call_record)
            if not self.should_fail:
                return self.return_value
            return False, "Download failed"

    downloader = MockDownloader()
    return downloader
|
|
|
|
|
|
@pytest.fixture
def mock_websocket_manager():
    """Provide a recording WebSocket manager.

    Broadcast payloads are stored in order in ``payloads`` and counted in
    ``broadcast_count``.
    """

    class RecordingWebSocketManager:
        def __init__(self):
            self.broadcast_count = 0
            self.payloads = []

        async def broadcast(self, payload):
            self.payloads.append(payload)
            self.broadcast_count += 1

        def get_payloads_by_type(self, msg_type: str):
            """Get all payloads of a specific message type."""

            def has_requested_type(payload) -> bool:
                return payload.get("type") == msg_type

            return list(filter(has_requested_type, self.payloads))

    manager = RecordingWebSocketManager()
    return manager
|
|
|
|
|
|
@pytest.fixture(autouse=True)
def reset_singletons():
    """Reset shared singletons around every test to ensure isolation.

    The download manager, service registry and model scanner state are
    cleared both before and after the test; the settings manager is reset
    (when it offers ``_reset``) before the test only.
    """
    # Import here to avoid circular imports
    from py.services.download_manager import DownloadManager
    from py.services.service_registry import ServiceRegistry
    from py.services.model_scanner import ModelScanner
    from py.services.settings_manager import get_settings_manager

    def _clear_shared_state() -> None:
        # Class-level state shared by setup and teardown.
        DownloadManager._instance = None
        ServiceRegistry._services = {}
        ServiceRegistry._initialized = False
        if hasattr(ModelScanner, '_instances'):
            ModelScanner._instances.clear()

    # Setup: clear class-level state, then reset the settings manager.
    _clear_shared_state()
    settings_manager = get_settings_manager()
    if hasattr(settings_manager, '_reset'):
        settings_manager._reset()

    yield

    # Cleanup after test
    _clear_shared_state()
|
|
|