mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-24 06:32:12 -03:00
feat: enhance model metadata provider with import error handling and mock setup for tests
This commit is contained in:
@@ -1,11 +1,41 @@
|
||||
from abc import ABC, abstractmethod
|
||||
import json
|
||||
import aiosqlite
|
||||
import logging
|
||||
from bs4 import BeautifulSoup
|
||||
from typing import Optional, Dict, Tuple
|
||||
from typing import Optional, Dict, Tuple, Any
|
||||
from .downloader import get_downloader
|
||||
|
||||
try:
|
||||
from bs4 import BeautifulSoup
|
||||
except ImportError as exc:
|
||||
BeautifulSoup = None # type: ignore[assignment]
|
||||
_BS4_IMPORT_ERROR = exc
|
||||
else:
|
||||
_BS4_IMPORT_ERROR = None
|
||||
|
||||
try:
|
||||
import aiosqlite
|
||||
except ImportError as exc:
|
||||
aiosqlite = None # type: ignore[assignment]
|
||||
_AIOSQLITE_IMPORT_ERROR = exc
|
||||
else:
|
||||
_AIOSQLITE_IMPORT_ERROR = None
|
||||
|
||||
def _require_beautifulsoup() -> Any:
|
||||
if BeautifulSoup is None:
|
||||
raise RuntimeError(
|
||||
"BeautifulSoup (bs4) is required for CivArchiveModelMetadataProvider. "
|
||||
"Install it with 'pip install beautifulsoup4'."
|
||||
) from _BS4_IMPORT_ERROR
|
||||
return BeautifulSoup
|
||||
|
||||
def _require_aiosqlite() -> Any:
|
||||
if aiosqlite is None:
|
||||
raise RuntimeError(
|
||||
"aiosqlite is required for SQLiteModelMetadataProvider. "
|
||||
"Install it with 'pip install aiosqlite'."
|
||||
) from _AIOSQLITE_IMPORT_ERROR
|
||||
return aiosqlite
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class ModelMetadataProvider(ABC):
|
||||
@@ -78,7 +108,8 @@ class CivArchiveModelMetadataProvider(ModelMetadataProvider):
|
||||
html_content = await response.text()
|
||||
|
||||
# Parse HTML to extract JSON data
|
||||
soup = BeautifulSoup(html_content, 'html.parser')
|
||||
soup_parser = _require_beautifulsoup()
|
||||
soup = soup_parser(html_content, 'html.parser')
|
||||
script_tag = soup.find('script', {'id': '__NEXT_DATA__', 'type': 'application/json'})
|
||||
|
||||
if not script_tag:
|
||||
@@ -171,10 +202,11 @@ class SQLiteModelMetadataProvider(ModelMetadataProvider):
|
||||
|
||||
def __init__(self, db_path: str):
|
||||
self.db_path = db_path
|
||||
self._aiosqlite = _require_aiosqlite()
|
||||
|
||||
async def get_model_by_hash(self, model_hash: str) -> Tuple[Optional[Dict], Optional[str]]:
|
||||
"""Find model by hash value from SQLite database"""
|
||||
async with aiosqlite.connect(self.db_path) as db:
|
||||
async with self._aiosqlite.connect(self.db_path) as db:
|
||||
# Look up in model_files table to get model_id and version_id
|
||||
query = """
|
||||
SELECT model_id, version_id
|
||||
@@ -182,7 +214,7 @@ class SQLiteModelMetadataProvider(ModelMetadataProvider):
|
||||
WHERE sha256 = ?
|
||||
LIMIT 1
|
||||
"""
|
||||
db.row_factory = aiosqlite.Row
|
||||
db.row_factory = self._aiosqlite.Row
|
||||
cursor = await db.execute(query, (model_hash.upper(),))
|
||||
file_row = await cursor.fetchone()
|
||||
|
||||
@@ -199,8 +231,8 @@ class SQLiteModelMetadataProvider(ModelMetadataProvider):
|
||||
|
||||
async def get_model_versions(self, model_id: str) -> Optional[Dict]:
|
||||
"""Get all versions of a model from SQLite database"""
|
||||
async with aiosqlite.connect(self.db_path) as db:
|
||||
db.row_factory = aiosqlite.Row
|
||||
async with self._aiosqlite.connect(self.db_path) as db:
|
||||
db.row_factory = self._aiosqlite.Row
|
||||
|
||||
# First check if model exists
|
||||
model_query = "SELECT * FROM models WHERE id = ?"
|
||||
@@ -258,8 +290,8 @@ class SQLiteModelMetadataProvider(ModelMetadataProvider):
|
||||
if not model_id and not version_id:
|
||||
return None
|
||||
|
||||
async with aiosqlite.connect(self.db_path) as db:
|
||||
db.row_factory = aiosqlite.Row
|
||||
async with self._aiosqlite.connect(self.db_path) as db:
|
||||
db.row_factory = self._aiosqlite.Row
|
||||
|
||||
# Case 1: Only version_id is provided
|
||||
if model_id is None and version_id is not None:
|
||||
@@ -295,8 +327,8 @@ class SQLiteModelMetadataProvider(ModelMetadataProvider):
|
||||
|
||||
async def get_model_version_info(self, version_id: str) -> Tuple[Optional[Dict], Optional[str]]:
|
||||
"""Fetch model version metadata from SQLite database"""
|
||||
async with aiosqlite.connect(self.db_path) as db:
|
||||
db.row_factory = aiosqlite.Row
|
||||
async with self._aiosqlite.connect(self.db_path) as db:
|
||||
db.row_factory = self._aiosqlite.Row
|
||||
|
||||
# Get version details
|
||||
version_query = "SELECT model_id FROM model_versions WHERE id = ?"
|
||||
|
||||
Reference in New Issue
Block a user