feat: enhance model metadata provider with import error handling and mock setup for tests

This commit is contained in:
Will Miao
2025-09-21 19:56:38 +08:00
parent 4faf912c6f
commit 1022b07f64
4 changed files with 133 additions and 12 deletions

View File

@@ -1,11 +1,41 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
import json import json
import aiosqlite
import logging import logging
from bs4 import BeautifulSoup from typing import Optional, Dict, Tuple, Any
from typing import Optional, Dict, Tuple
from .downloader import get_downloader from .downloader import get_downloader
try:
from bs4 import BeautifulSoup
except ImportError as exc:
BeautifulSoup = None # type: ignore[assignment]
_BS4_IMPORT_ERROR = exc
else:
_BS4_IMPORT_ERROR = None
try:
import aiosqlite
except ImportError as exc:
aiosqlite = None # type: ignore[assignment]
_AIOSQLITE_IMPORT_ERROR = exc
else:
_AIOSQLITE_IMPORT_ERROR = None
def _require_beautifulsoup() -> Any:
if BeautifulSoup is None:
raise RuntimeError(
"BeautifulSoup (bs4) is required for CivArchiveModelMetadataProvider. "
"Install it with 'pip install beautifulsoup4'."
) from _BS4_IMPORT_ERROR
return BeautifulSoup
def _require_aiosqlite() -> Any:
if aiosqlite is None:
raise RuntimeError(
"aiosqlite is required for SQLiteModelMetadataProvider. "
"Install it with 'pip install aiosqlite'."
) from _AIOSQLITE_IMPORT_ERROR
return aiosqlite
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class ModelMetadataProvider(ABC): class ModelMetadataProvider(ABC):
@@ -78,7 +108,8 @@ class CivArchiveModelMetadataProvider(ModelMetadataProvider):
html_content = await response.text() html_content = await response.text()
# Parse HTML to extract JSON data # Parse HTML to extract JSON data
soup = BeautifulSoup(html_content, 'html.parser') soup_parser = _require_beautifulsoup()
soup = soup_parser(html_content, 'html.parser')
script_tag = soup.find('script', {'id': '__NEXT_DATA__', 'type': 'application/json'}) script_tag = soup.find('script', {'id': '__NEXT_DATA__', 'type': 'application/json'})
if not script_tag: if not script_tag:
@@ -171,10 +202,11 @@ class SQLiteModelMetadataProvider(ModelMetadataProvider):
def __init__(self, db_path: str): def __init__(self, db_path: str):
self.db_path = db_path self.db_path = db_path
self._aiosqlite = _require_aiosqlite()
async def get_model_by_hash(self, model_hash: str) -> Tuple[Optional[Dict], Optional[str]]: async def get_model_by_hash(self, model_hash: str) -> Tuple[Optional[Dict], Optional[str]]:
"""Find model by hash value from SQLite database""" """Find model by hash value from SQLite database"""
async with aiosqlite.connect(self.db_path) as db: async with self._aiosqlite.connect(self.db_path) as db:
# Look up in model_files table to get model_id and version_id # Look up in model_files table to get model_id and version_id
query = """ query = """
SELECT model_id, version_id SELECT model_id, version_id
@@ -182,7 +214,7 @@ class SQLiteModelMetadataProvider(ModelMetadataProvider):
WHERE sha256 = ? WHERE sha256 = ?
LIMIT 1 LIMIT 1
""" """
db.row_factory = aiosqlite.Row db.row_factory = self._aiosqlite.Row
cursor = await db.execute(query, (model_hash.upper(),)) cursor = await db.execute(query, (model_hash.upper(),))
file_row = await cursor.fetchone() file_row = await cursor.fetchone()
@@ -199,8 +231,8 @@ class SQLiteModelMetadataProvider(ModelMetadataProvider):
async def get_model_versions(self, model_id: str) -> Optional[Dict]: async def get_model_versions(self, model_id: str) -> Optional[Dict]:
"""Get all versions of a model from SQLite database""" """Get all versions of a model from SQLite database"""
async with aiosqlite.connect(self.db_path) as db: async with self._aiosqlite.connect(self.db_path) as db:
db.row_factory = aiosqlite.Row db.row_factory = self._aiosqlite.Row
# First check if model exists # First check if model exists
model_query = "SELECT * FROM models WHERE id = ?" model_query = "SELECT * FROM models WHERE id = ?"
@@ -258,8 +290,8 @@ class SQLiteModelMetadataProvider(ModelMetadataProvider):
if not model_id and not version_id: if not model_id and not version_id:
return None return None
async with aiosqlite.connect(self.db_path) as db: async with self._aiosqlite.connect(self.db_path) as db:
db.row_factory = aiosqlite.Row db.row_factory = self._aiosqlite.Row
# Case 1: Only version_id is provided # Case 1: Only version_id is provided
if model_id is None and version_id is not None: if model_id is None and version_id is not None:
@@ -295,8 +327,8 @@ class SQLiteModelMetadataProvider(ModelMetadataProvider):
async def get_model_version_info(self, version_id: str) -> Tuple[Optional[Dict], Optional[str]]: async def get_model_version_info(self, version_id: str) -> Tuple[Optional[Dict], Optional[str]]:
"""Fetch model version metadata from SQLite database""" """Fetch model version metadata from SQLite database"""
async with aiosqlite.connect(self.db_path) as db: async with self._aiosqlite.connect(self.db_path) as db:
db.row_factory = aiosqlite.Row db.row_factory = self._aiosqlite.Row
# Get version details # Get version details
version_query = "SELECT model_id FROM model_versions WHERE id = ?" version_query = "SELECT model_id FROM model_versions WHERE id = ?"

8
pytest.ini Normal file
View File

@@ -0,0 +1,8 @@
[pytest]
addopts = -v --import-mode=importlib
testpaths = tests
python_files = test_*.py
python_classes = Test*
python_functions = test_*
# Skip problematic directories to avoid import conflicts
norecursedirs = .git .tox dist build *.egg __pycache__ py

46
run_tests.py Normal file
View File

@@ -0,0 +1,46 @@
#!/usr/bin/env python3
"""
Test runner script for ComfyUI-Lora-Manager.
This script runs pytest from the tests directory to avoid import issues
with the root __init__.py file.
"""
import subprocess
import sys
import os
from pathlib import Path
def main():
"""Run pytest from the tests directory to avoid import issues."""
# Get the script directory
script_dir = Path(__file__).parent.absolute()
tests_dir = script_dir / "tests"
if not tests_dir.exists():
print(f"Error: Tests directory not found at {tests_dir}")
return 1
# Change to tests directory
original_cwd = os.getcwd()
os.chdir(tests_dir)
try:
# Build pytest command
cmd = [
sys.executable, "-m", "pytest",
"-v",
"--rootdir=.",
] + sys.argv[1:] # Pass any additional arguments
print(f"Running: {' '.join(cmd)}")
print(f"Working directory: {tests_dir}")
# Run pytest
result = subprocess.run(cmd, cwd=tests_dir)
return result.returncode
finally:
# Restore original working directory
os.chdir(original_cwd)
if __name__ == "__main__":
sys.exit(main())

View File

@@ -1,8 +1,43 @@
import types
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Sequence from typing import Any, Dict, List, Optional, Sequence
from unittest import mock
import sys
import pytest import pytest
# Mock ComfyUI modules before any imports from the main project
server_mock = types.SimpleNamespace()
server_mock.PromptServer = mock.MagicMock()
sys.modules['server'] = server_mock
folder_paths_mock = types.SimpleNamespace()
folder_paths_mock.get_folder_paths = mock.MagicMock(return_value=[])
folder_paths_mock.folder_names_and_paths = {}
sys.modules['folder_paths'] = folder_paths_mock
# Mock other ComfyUI modules that might be imported
comfy_mock = types.SimpleNamespace()
comfy_mock.utils = types.SimpleNamespace()
comfy_mock.model_management = types.SimpleNamespace()
comfy_mock.comfy_types = types.SimpleNamespace()
comfy_mock.comfy_types.IO = mock.MagicMock()
sys.modules['comfy'] = comfy_mock
sys.modules['comfy.utils'] = comfy_mock.utils
sys.modules['comfy.model_management'] = comfy_mock.model_management
sys.modules['comfy.comfy_types'] = comfy_mock.comfy_types
execution_mock = types.SimpleNamespace()
execution_mock.PromptExecutor = mock.MagicMock()
sys.modules['execution'] = execution_mock
# Mock ComfyUI nodes module
nodes_mock = types.SimpleNamespace()
nodes_mock.LoraLoader = mock.MagicMock()
nodes_mock.SaveImage = mock.MagicMock()
nodes_mock.NODE_CLASS_MAPPINGS = {}
sys.modules['nodes'] = nodes_mock
@dataclass @dataclass
class MockHashIndex: class MockHashIndex: