mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-21 21:22:11 -03:00
feat: enhance model metadata provider with import error handling and mock setup for tests
This commit is contained in:
@@ -1,11 +1,41 @@
|
||||
from abc import ABC, abstractmethod
|
||||
import json
|
||||
import aiosqlite
|
||||
import logging
|
||||
from bs4 import BeautifulSoup
|
||||
from typing import Optional, Dict, Tuple
|
||||
from typing import Optional, Dict, Tuple, Any
|
||||
from .downloader import get_downloader
|
||||
|
||||
try:
|
||||
from bs4 import BeautifulSoup
|
||||
except ImportError as exc:
|
||||
BeautifulSoup = None # type: ignore[assignment]
|
||||
_BS4_IMPORT_ERROR = exc
|
||||
else:
|
||||
_BS4_IMPORT_ERROR = None
|
||||
|
||||
try:
|
||||
import aiosqlite
|
||||
except ImportError as exc:
|
||||
aiosqlite = None # type: ignore[assignment]
|
||||
_AIOSQLITE_IMPORT_ERROR = exc
|
||||
else:
|
||||
_AIOSQLITE_IMPORT_ERROR = None
|
||||
|
||||
def _require_beautifulsoup() -> Any:
|
||||
if BeautifulSoup is None:
|
||||
raise RuntimeError(
|
||||
"BeautifulSoup (bs4) is required for CivArchiveModelMetadataProvider. "
|
||||
"Install it with 'pip install beautifulsoup4'."
|
||||
) from _BS4_IMPORT_ERROR
|
||||
return BeautifulSoup
|
||||
|
||||
def _require_aiosqlite() -> Any:
|
||||
if aiosqlite is None:
|
||||
raise RuntimeError(
|
||||
"aiosqlite is required for SQLiteModelMetadataProvider. "
|
||||
"Install it with 'pip install aiosqlite'."
|
||||
) from _AIOSQLITE_IMPORT_ERROR
|
||||
return aiosqlite
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class ModelMetadataProvider(ABC):
|
||||
@@ -78,7 +108,8 @@ class CivArchiveModelMetadataProvider(ModelMetadataProvider):
|
||||
html_content = await response.text()
|
||||
|
||||
# Parse HTML to extract JSON data
|
||||
soup = BeautifulSoup(html_content, 'html.parser')
|
||||
soup_parser = _require_beautifulsoup()
|
||||
soup = soup_parser(html_content, 'html.parser')
|
||||
script_tag = soup.find('script', {'id': '__NEXT_DATA__', 'type': 'application/json'})
|
||||
|
||||
if not script_tag:
|
||||
@@ -171,10 +202,11 @@ class SQLiteModelMetadataProvider(ModelMetadataProvider):
|
||||
|
||||
def __init__(self, db_path: str):
|
||||
self.db_path = db_path
|
||||
self._aiosqlite = _require_aiosqlite()
|
||||
|
||||
async def get_model_by_hash(self, model_hash: str) -> Tuple[Optional[Dict], Optional[str]]:
|
||||
"""Find model by hash value from SQLite database"""
|
||||
async with aiosqlite.connect(self.db_path) as db:
|
||||
async with self._aiosqlite.connect(self.db_path) as db:
|
||||
# Look up in model_files table to get model_id and version_id
|
||||
query = """
|
||||
SELECT model_id, version_id
|
||||
@@ -182,7 +214,7 @@ class SQLiteModelMetadataProvider(ModelMetadataProvider):
|
||||
WHERE sha256 = ?
|
||||
LIMIT 1
|
||||
"""
|
||||
db.row_factory = aiosqlite.Row
|
||||
db.row_factory = self._aiosqlite.Row
|
||||
cursor = await db.execute(query, (model_hash.upper(),))
|
||||
file_row = await cursor.fetchone()
|
||||
|
||||
@@ -199,8 +231,8 @@ class SQLiteModelMetadataProvider(ModelMetadataProvider):
|
||||
|
||||
async def get_model_versions(self, model_id: str) -> Optional[Dict]:
|
||||
"""Get all versions of a model from SQLite database"""
|
||||
async with aiosqlite.connect(self.db_path) as db:
|
||||
db.row_factory = aiosqlite.Row
|
||||
async with self._aiosqlite.connect(self.db_path) as db:
|
||||
db.row_factory = self._aiosqlite.Row
|
||||
|
||||
# First check if model exists
|
||||
model_query = "SELECT * FROM models WHERE id = ?"
|
||||
@@ -258,8 +290,8 @@ class SQLiteModelMetadataProvider(ModelMetadataProvider):
|
||||
if not model_id and not version_id:
|
||||
return None
|
||||
|
||||
async with aiosqlite.connect(self.db_path) as db:
|
||||
db.row_factory = aiosqlite.Row
|
||||
async with self._aiosqlite.connect(self.db_path) as db:
|
||||
db.row_factory = self._aiosqlite.Row
|
||||
|
||||
# Case 1: Only version_id is provided
|
||||
if model_id is None and version_id is not None:
|
||||
@@ -295,8 +327,8 @@ class SQLiteModelMetadataProvider(ModelMetadataProvider):
|
||||
|
||||
async def get_model_version_info(self, version_id: str) -> Tuple[Optional[Dict], Optional[str]]:
|
||||
"""Fetch model version metadata from SQLite database"""
|
||||
async with aiosqlite.connect(self.db_path) as db:
|
||||
db.row_factory = aiosqlite.Row
|
||||
async with self._aiosqlite.connect(self.db_path) as db:
|
||||
db.row_factory = self._aiosqlite.Row
|
||||
|
||||
# Get version details
|
||||
version_query = "SELECT model_id FROM model_versions WHERE id = ?"
|
||||
|
||||
8
pytest.ini
Normal file
8
pytest.ini
Normal file
@@ -0,0 +1,8 @@
|
||||
[pytest]
|
||||
addopts = -v --import-mode=importlib
|
||||
testpaths = tests
|
||||
python_files = test_*.py
|
||||
python_classes = Test*
|
||||
python_functions = test_*
|
||||
# Skip problematic directories to avoid import conflicts
|
||||
norecursedirs = .git .tox dist build *.egg __pycache__ py
|
||||
46
run_tests.py
Normal file
46
run_tests.py
Normal file
@@ -0,0 +1,46 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test runner script for ComfyUI-Lora-Manager.
|
||||
|
||||
This script runs pytest from the tests directory to avoid import issues
|
||||
with the root __init__.py file.
|
||||
"""
|
||||
import subprocess
|
||||
import sys
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
def main():
|
||||
"""Run pytest from the tests directory to avoid import issues."""
|
||||
# Get the script directory
|
||||
script_dir = Path(__file__).parent.absolute()
|
||||
tests_dir = script_dir / "tests"
|
||||
|
||||
if not tests_dir.exists():
|
||||
print(f"Error: Tests directory not found at {tests_dir}")
|
||||
return 1
|
||||
|
||||
# Change to tests directory
|
||||
original_cwd = os.getcwd()
|
||||
os.chdir(tests_dir)
|
||||
|
||||
try:
|
||||
# Build pytest command
|
||||
cmd = [
|
||||
sys.executable, "-m", "pytest",
|
||||
"-v",
|
||||
"--rootdir=.",
|
||||
] + sys.argv[1:] # Pass any additional arguments
|
||||
|
||||
print(f"Running: {' '.join(cmd)}")
|
||||
print(f"Working directory: {tests_dir}")
|
||||
|
||||
# Run pytest
|
||||
result = subprocess.run(cmd, cwd=tests_dir)
|
||||
return result.returncode
|
||||
finally:
|
||||
# Restore original working directory
|
||||
os.chdir(original_cwd)
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
@@ -1,8 +1,43 @@
|
||||
import types
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Optional, Sequence
|
||||
from unittest import mock
|
||||
import sys
|
||||
|
||||
import pytest
|
||||
|
||||
# Mock ComfyUI modules before any imports from the main project
|
||||
server_mock = types.SimpleNamespace()
|
||||
server_mock.PromptServer = mock.MagicMock()
|
||||
sys.modules['server'] = server_mock
|
||||
|
||||
folder_paths_mock = types.SimpleNamespace()
|
||||
folder_paths_mock.get_folder_paths = mock.MagicMock(return_value=[])
|
||||
folder_paths_mock.folder_names_and_paths = {}
|
||||
sys.modules['folder_paths'] = folder_paths_mock
|
||||
|
||||
# Mock other ComfyUI modules that might be imported
|
||||
comfy_mock = types.SimpleNamespace()
|
||||
comfy_mock.utils = types.SimpleNamespace()
|
||||
comfy_mock.model_management = types.SimpleNamespace()
|
||||
comfy_mock.comfy_types = types.SimpleNamespace()
|
||||
comfy_mock.comfy_types.IO = mock.MagicMock()
|
||||
sys.modules['comfy'] = comfy_mock
|
||||
sys.modules['comfy.utils'] = comfy_mock.utils
|
||||
sys.modules['comfy.model_management'] = comfy_mock.model_management
|
||||
sys.modules['comfy.comfy_types'] = comfy_mock.comfy_types
|
||||
|
||||
execution_mock = types.SimpleNamespace()
|
||||
execution_mock.PromptExecutor = mock.MagicMock()
|
||||
sys.modules['execution'] = execution_mock
|
||||
|
||||
# Mock ComfyUI nodes module
|
||||
nodes_mock = types.SimpleNamespace()
|
||||
nodes_mock.LoraLoader = mock.MagicMock()
|
||||
nodes_mock.SaveImage = mock.MagicMock()
|
||||
nodes_mock.NODE_CLASS_MAPPINGS = {}
|
||||
sys.modules['nodes'] = nodes_mock
|
||||
|
||||
|
||||
@dataclass
|
||||
class MockHashIndex:
|
||||
|
||||
Reference in New Issue
Block a user