mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-21 21:22:11 -03:00
feat: enhance model metadata provider with import error handling and mock setup for tests
This commit is contained in:
@@ -1,11 +1,41 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
import json
|
import json
|
||||||
import aiosqlite
|
|
||||||
import logging
|
import logging
|
||||||
from bs4 import BeautifulSoup
|
from typing import Optional, Dict, Tuple, Any
|
||||||
from typing import Optional, Dict, Tuple
|
|
||||||
from .downloader import get_downloader
|
from .downloader import get_downloader
|
||||||
|
|
||||||
|
try:
|
||||||
|
from bs4 import BeautifulSoup
|
||||||
|
except ImportError as exc:
|
||||||
|
BeautifulSoup = None # type: ignore[assignment]
|
||||||
|
_BS4_IMPORT_ERROR = exc
|
||||||
|
else:
|
||||||
|
_BS4_IMPORT_ERROR = None
|
||||||
|
|
||||||
|
try:
|
||||||
|
import aiosqlite
|
||||||
|
except ImportError as exc:
|
||||||
|
aiosqlite = None # type: ignore[assignment]
|
||||||
|
_AIOSQLITE_IMPORT_ERROR = exc
|
||||||
|
else:
|
||||||
|
_AIOSQLITE_IMPORT_ERROR = None
|
||||||
|
|
||||||
|
def _require_beautifulsoup() -> Any:
|
||||||
|
if BeautifulSoup is None:
|
||||||
|
raise RuntimeError(
|
||||||
|
"BeautifulSoup (bs4) is required for CivArchiveModelMetadataProvider. "
|
||||||
|
"Install it with 'pip install beautifulsoup4'."
|
||||||
|
) from _BS4_IMPORT_ERROR
|
||||||
|
return BeautifulSoup
|
||||||
|
|
||||||
|
def _require_aiosqlite() -> Any:
|
||||||
|
if aiosqlite is None:
|
||||||
|
raise RuntimeError(
|
||||||
|
"aiosqlite is required for SQLiteModelMetadataProvider. "
|
||||||
|
"Install it with 'pip install aiosqlite'."
|
||||||
|
) from _AIOSQLITE_IMPORT_ERROR
|
||||||
|
return aiosqlite
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
class ModelMetadataProvider(ABC):
|
class ModelMetadataProvider(ABC):
|
||||||
@@ -78,7 +108,8 @@ class CivArchiveModelMetadataProvider(ModelMetadataProvider):
|
|||||||
html_content = await response.text()
|
html_content = await response.text()
|
||||||
|
|
||||||
# Parse HTML to extract JSON data
|
# Parse HTML to extract JSON data
|
||||||
soup = BeautifulSoup(html_content, 'html.parser')
|
soup_parser = _require_beautifulsoup()
|
||||||
|
soup = soup_parser(html_content, 'html.parser')
|
||||||
script_tag = soup.find('script', {'id': '__NEXT_DATA__', 'type': 'application/json'})
|
script_tag = soup.find('script', {'id': '__NEXT_DATA__', 'type': 'application/json'})
|
||||||
|
|
||||||
if not script_tag:
|
if not script_tag:
|
||||||
@@ -171,10 +202,11 @@ class SQLiteModelMetadataProvider(ModelMetadataProvider):
|
|||||||
|
|
||||||
def __init__(self, db_path: str):
|
def __init__(self, db_path: str):
|
||||||
self.db_path = db_path
|
self.db_path = db_path
|
||||||
|
self._aiosqlite = _require_aiosqlite()
|
||||||
|
|
||||||
async def get_model_by_hash(self, model_hash: str) -> Tuple[Optional[Dict], Optional[str]]:
|
async def get_model_by_hash(self, model_hash: str) -> Tuple[Optional[Dict], Optional[str]]:
|
||||||
"""Find model by hash value from SQLite database"""
|
"""Find model by hash value from SQLite database"""
|
||||||
async with aiosqlite.connect(self.db_path) as db:
|
async with self._aiosqlite.connect(self.db_path) as db:
|
||||||
# Look up in model_files table to get model_id and version_id
|
# Look up in model_files table to get model_id and version_id
|
||||||
query = """
|
query = """
|
||||||
SELECT model_id, version_id
|
SELECT model_id, version_id
|
||||||
@@ -182,7 +214,7 @@ class SQLiteModelMetadataProvider(ModelMetadataProvider):
|
|||||||
WHERE sha256 = ?
|
WHERE sha256 = ?
|
||||||
LIMIT 1
|
LIMIT 1
|
||||||
"""
|
"""
|
||||||
db.row_factory = aiosqlite.Row
|
db.row_factory = self._aiosqlite.Row
|
||||||
cursor = await db.execute(query, (model_hash.upper(),))
|
cursor = await db.execute(query, (model_hash.upper(),))
|
||||||
file_row = await cursor.fetchone()
|
file_row = await cursor.fetchone()
|
||||||
|
|
||||||
@@ -199,8 +231,8 @@ class SQLiteModelMetadataProvider(ModelMetadataProvider):
|
|||||||
|
|
||||||
async def get_model_versions(self, model_id: str) -> Optional[Dict]:
|
async def get_model_versions(self, model_id: str) -> Optional[Dict]:
|
||||||
"""Get all versions of a model from SQLite database"""
|
"""Get all versions of a model from SQLite database"""
|
||||||
async with aiosqlite.connect(self.db_path) as db:
|
async with self._aiosqlite.connect(self.db_path) as db:
|
||||||
db.row_factory = aiosqlite.Row
|
db.row_factory = self._aiosqlite.Row
|
||||||
|
|
||||||
# First check if model exists
|
# First check if model exists
|
||||||
model_query = "SELECT * FROM models WHERE id = ?"
|
model_query = "SELECT * FROM models WHERE id = ?"
|
||||||
@@ -258,8 +290,8 @@ class SQLiteModelMetadataProvider(ModelMetadataProvider):
|
|||||||
if not model_id and not version_id:
|
if not model_id and not version_id:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async with aiosqlite.connect(self.db_path) as db:
|
async with self._aiosqlite.connect(self.db_path) as db:
|
||||||
db.row_factory = aiosqlite.Row
|
db.row_factory = self._aiosqlite.Row
|
||||||
|
|
||||||
# Case 1: Only version_id is provided
|
# Case 1: Only version_id is provided
|
||||||
if model_id is None and version_id is not None:
|
if model_id is None and version_id is not None:
|
||||||
@@ -295,8 +327,8 @@ class SQLiteModelMetadataProvider(ModelMetadataProvider):
|
|||||||
|
|
||||||
async def get_model_version_info(self, version_id: str) -> Tuple[Optional[Dict], Optional[str]]:
|
async def get_model_version_info(self, version_id: str) -> Tuple[Optional[Dict], Optional[str]]:
|
||||||
"""Fetch model version metadata from SQLite database"""
|
"""Fetch model version metadata from SQLite database"""
|
||||||
async with aiosqlite.connect(self.db_path) as db:
|
async with self._aiosqlite.connect(self.db_path) as db:
|
||||||
db.row_factory = aiosqlite.Row
|
db.row_factory = self._aiosqlite.Row
|
||||||
|
|
||||||
# Get version details
|
# Get version details
|
||||||
version_query = "SELECT model_id FROM model_versions WHERE id = ?"
|
version_query = "SELECT model_id FROM model_versions WHERE id = ?"
|
||||||
|
|||||||
8
pytest.ini
Normal file
8
pytest.ini
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
[pytest]
|
||||||
|
addopts = -v --import-mode=importlib
|
||||||
|
testpaths = tests
|
||||||
|
python_files = test_*.py
|
||||||
|
python_classes = Test*
|
||||||
|
python_functions = test_*
|
||||||
|
# Skip problematic directories to avoid import conflicts
|
||||||
|
norecursedirs = .git .tox dist build *.egg __pycache__ py
|
||||||
46
run_tests.py
Normal file
46
run_tests.py
Normal file
@@ -0,0 +1,46 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Test runner script for ComfyUI-Lora-Manager.
|
||||||
|
|
||||||
|
This script runs pytest from the tests directory to avoid import issues
|
||||||
|
with the root __init__.py file.
|
||||||
|
"""
|
||||||
|
import subprocess
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
def main():
|
||||||
|
"""Run pytest from the tests directory to avoid import issues."""
|
||||||
|
# Get the script directory
|
||||||
|
script_dir = Path(__file__).parent.absolute()
|
||||||
|
tests_dir = script_dir / "tests"
|
||||||
|
|
||||||
|
if not tests_dir.exists():
|
||||||
|
print(f"Error: Tests directory not found at {tests_dir}")
|
||||||
|
return 1
|
||||||
|
|
||||||
|
# Change to tests directory
|
||||||
|
original_cwd = os.getcwd()
|
||||||
|
os.chdir(tests_dir)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Build pytest command
|
||||||
|
cmd = [
|
||||||
|
sys.executable, "-m", "pytest",
|
||||||
|
"-v",
|
||||||
|
"--rootdir=.",
|
||||||
|
] + sys.argv[1:] # Pass any additional arguments
|
||||||
|
|
||||||
|
print(f"Running: {' '.join(cmd)}")
|
||||||
|
print(f"Working directory: {tests_dir}")
|
||||||
|
|
||||||
|
# Run pytest
|
||||||
|
result = subprocess.run(cmd, cwd=tests_dir)
|
||||||
|
return result.returncode
|
||||||
|
finally:
|
||||||
|
# Restore original working directory
|
||||||
|
os.chdir(original_cwd)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
sys.exit(main())
|
||||||
@@ -1,8 +1,43 @@
|
|||||||
|
import types
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Any, Dict, List, Optional, Sequence
|
from typing import Any, Dict, List, Optional, Sequence
|
||||||
|
from unittest import mock
|
||||||
|
import sys
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
# Mock ComfyUI modules before any imports from the main project
|
||||||
|
server_mock = types.SimpleNamespace()
|
||||||
|
server_mock.PromptServer = mock.MagicMock()
|
||||||
|
sys.modules['server'] = server_mock
|
||||||
|
|
||||||
|
folder_paths_mock = types.SimpleNamespace()
|
||||||
|
folder_paths_mock.get_folder_paths = mock.MagicMock(return_value=[])
|
||||||
|
folder_paths_mock.folder_names_and_paths = {}
|
||||||
|
sys.modules['folder_paths'] = folder_paths_mock
|
||||||
|
|
||||||
|
# Mock other ComfyUI modules that might be imported
|
||||||
|
comfy_mock = types.SimpleNamespace()
|
||||||
|
comfy_mock.utils = types.SimpleNamespace()
|
||||||
|
comfy_mock.model_management = types.SimpleNamespace()
|
||||||
|
comfy_mock.comfy_types = types.SimpleNamespace()
|
||||||
|
comfy_mock.comfy_types.IO = mock.MagicMock()
|
||||||
|
sys.modules['comfy'] = comfy_mock
|
||||||
|
sys.modules['comfy.utils'] = comfy_mock.utils
|
||||||
|
sys.modules['comfy.model_management'] = comfy_mock.model_management
|
||||||
|
sys.modules['comfy.comfy_types'] = comfy_mock.comfy_types
|
||||||
|
|
||||||
|
execution_mock = types.SimpleNamespace()
|
||||||
|
execution_mock.PromptExecutor = mock.MagicMock()
|
||||||
|
sys.modules['execution'] = execution_mock
|
||||||
|
|
||||||
|
# Mock ComfyUI nodes module
|
||||||
|
nodes_mock = types.SimpleNamespace()
|
||||||
|
nodes_mock.LoraLoader = mock.MagicMock()
|
||||||
|
nodes_mock.SaveImage = mock.MagicMock()
|
||||||
|
nodes_mock.NODE_CLASS_MAPPINGS = {}
|
||||||
|
sys.modules['nodes'] = nodes_mock
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class MockHashIndex:
|
class MockHashIndex:
|
||||||
|
|||||||
Reference in New Issue
Block a user