diff --git a/py/services/model_metadata_provider.py b/py/services/model_metadata_provider.py index 55091118..1a36d353 100644 --- a/py/services/model_metadata_provider.py +++ b/py/services/model_metadata_provider.py @@ -1,11 +1,41 @@ from abc import ABC, abstractmethod import json -import aiosqlite import logging -from bs4 import BeautifulSoup -from typing import Optional, Dict, Tuple +from typing import Optional, Dict, Tuple, Any from .downloader import get_downloader +try: + from bs4 import BeautifulSoup +except ImportError as exc: + BeautifulSoup = None # type: ignore[assignment] + _BS4_IMPORT_ERROR = exc +else: + _BS4_IMPORT_ERROR = None + +try: + import aiosqlite +except ImportError as exc: + aiosqlite = None # type: ignore[assignment] + _AIOSQLITE_IMPORT_ERROR = exc +else: + _AIOSQLITE_IMPORT_ERROR = None + +def _require_beautifulsoup() -> Any: + if BeautifulSoup is None: + raise RuntimeError( + "BeautifulSoup (bs4) is required for CivArchiveModelMetadataProvider. " + "Install it with 'pip install beautifulsoup4'." + ) from _BS4_IMPORT_ERROR + return BeautifulSoup + +def _require_aiosqlite() -> Any: + if aiosqlite is None: + raise RuntimeError( + "aiosqlite is required for SQLiteModelMetadataProvider. " + "Install it with 'pip install aiosqlite'." + ) from _AIOSQLITE_IMPORT_ERROR + return aiosqlite + logger = logging.getLogger(__name__) class ModelMetadataProvider(ABC): @@ -78,7 +108,8 @@ class CivArchiveModelMetadataProvider(ModelMetadataProvider): html_content = await response.text() # Parse HTML to extract JSON data - soup = BeautifulSoup(html_content, 'html.parser') + soup_parser = _require_beautifulsoup() + soup = soup_parser(html_content, 'html.parser') script_tag = soup.find('script', {'id': '__NEXT_DATA__', 'type': 'application/json'}) if not script_tag: @@ -171,10 +202,11 @@ class SQLiteModelMetadataProvider(ModelMetadataProvider): def __init__(self, db_path: str): self.db_path = db_path + self._aiosqlite = _require_aiosqlite() async def get_model_by_hash(self, model_hash: str) -> Tuple[Optional[Dict], Optional[str]]: """Find model by hash value from SQLite database""" - async with aiosqlite.connect(self.db_path) as db: + async with self._aiosqlite.connect(self.db_path) as db: # Look up in model_files table to get model_id and version_id query = """ SELECT model_id, version_id @@ -182,7 +214,7 @@ class SQLiteModelMetadataProvider(ModelMetadataProvider): WHERE sha256 = ? LIMIT 1 """ - db.row_factory = aiosqlite.Row + db.row_factory = self._aiosqlite.Row cursor = await db.execute(query, (model_hash.upper(),)) file_row = await cursor.fetchone() @@ -199,8 +231,8 @@ class SQLiteModelMetadataProvider(ModelMetadataProvider): async def get_model_versions(self, model_id: str) -> Optional[Dict]: """Get all versions of a model from SQLite database""" - async with aiosqlite.connect(self.db_path) as db: - db.row_factory = aiosqlite.Row + async with self._aiosqlite.connect(self.db_path) as db: + db.row_factory = self._aiosqlite.Row # First check if model exists model_query = "SELECT * FROM models WHERE id = ?" @@ -258,8 +290,8 @@ class SQLiteModelMetadataProvider(ModelMetadataProvider): if not model_id and not version_id: return None - async with aiosqlite.connect(self.db_path) as db: - db.row_factory = aiosqlite.Row + async with self._aiosqlite.connect(self.db_path) as db: + db.row_factory = self._aiosqlite.Row # Case 1: Only version_id is provided if model_id is None and version_id is not None: @@ -295,8 +327,8 @@ class SQLiteModelMetadataProvider(ModelMetadataProvider): async def get_model_version_info(self, version_id: str) -> Tuple[Optional[Dict], Optional[str]]: """Fetch model version metadata from SQLite database""" - async with aiosqlite.connect(self.db_path) as db: - db.row_factory = aiosqlite.Row + async with self._aiosqlite.connect(self.db_path) as db: + db.row_factory = self._aiosqlite.Row # Get version details version_query = "SELECT model_id FROM model_versions WHERE id = ?" diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 00000000..44f4dc04 --- /dev/null +++ b/pytest.ini @@ -0,0 +1,8 @@ +[pytest] +addopts = -v --import-mode=importlib +testpaths = tests +python_files = test_*.py +python_classes = Test* +python_functions = test_* +# Skip problematic directories to avoid import conflicts +norecursedirs = .git .tox dist build *.egg __pycache__ py \ No newline at end of file diff --git a/run_tests.py b/run_tests.py new file mode 100644 index 00000000..627c77bc --- /dev/null +++ b/run_tests.py @@ -0,0 +1,46 @@ +#!/usr/bin/env python3 +""" +Test runner script for ComfyUI-Lora-Manager. + +This script runs pytest from the tests directory to avoid import issues +with the root __init__.py file. +""" +import subprocess +import sys +import os +from pathlib import Path + +def main(): + """Run pytest from the tests directory to avoid import issues.""" + # Get the script directory + script_dir = Path(__file__).parent.absolute() + tests_dir = script_dir / "tests" + + if not tests_dir.exists(): + print(f"Error: Tests directory not found at {tests_dir}") + return 1 + + # Change to tests directory + original_cwd = os.getcwd() + os.chdir(tests_dir) + + try: + # Build pytest command + cmd = [ + sys.executable, "-m", "pytest", + "-v", + "--rootdir=.", + ] + sys.argv[1:] # Pass any additional arguments + + print(f"Running: {' '.join(cmd)}") + print(f"Working directory: {tests_dir}") + + # Run pytest + result = subprocess.run(cmd, cwd=tests_dir) + return result.returncode + finally: + # Restore original working directory + os.chdir(original_cwd) + +if __name__ == "__main__": + sys.exit(main()) diff --git a/tests/conftest.py b/tests/conftest.py index 0ec41c73..dfe99691 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,8 +1,43 @@ +import types from dataclasses import dataclass, field from typing import Any, Dict, List, Optional, Sequence +from unittest import mock +import sys import pytest +# Mock ComfyUI modules before any imports from the main project +server_mock = types.SimpleNamespace() +server_mock.PromptServer = mock.MagicMock() +sys.modules['server'] = server_mock + +folder_paths_mock = types.SimpleNamespace() +folder_paths_mock.get_folder_paths = mock.MagicMock(return_value=[]) +folder_paths_mock.folder_names_and_paths = {} +sys.modules['folder_paths'] = folder_paths_mock + +# Mock other ComfyUI modules that might be imported +comfy_mock = types.SimpleNamespace() +comfy_mock.utils = types.SimpleNamespace() +comfy_mock.model_management = types.SimpleNamespace() +comfy_mock.comfy_types = types.SimpleNamespace() +comfy_mock.comfy_types.IO = mock.MagicMock() +sys.modules['comfy'] = comfy_mock +sys.modules['comfy.utils'] = comfy_mock.utils +sys.modules['comfy.model_management'] = comfy_mock.model_management +sys.modules['comfy.comfy_types'] = comfy_mock.comfy_types + +execution_mock = types.SimpleNamespace() +execution_mock.PromptExecutor = mock.MagicMock() +sys.modules['execution'] = execution_mock + +# Mock ComfyUI nodes module +nodes_mock = types.SimpleNamespace() +nodes_mock.LoraLoader = mock.MagicMock() +nodes_mock.SaveImage = mock.MagicMock() +nodes_mock.NODE_CLASS_MAPPINGS = {} +sys.modules['nodes'] = nodes_mock + @dataclass class MockHashIndex: