feat: enhance model metadata provider with import error handling and mock setup for tests

This commit is contained in:
Will Miao
2025-09-21 19:56:38 +08:00
parent 4faf912c6f
commit 1022b07f64
4 changed files with 133 additions and 12 deletions

View File

@@ -1,11 +1,41 @@
from abc import ABC, abstractmethod
import json
import aiosqlite
import logging
from bs4 import BeautifulSoup
from typing import Optional, Dict, Tuple
from typing import Optional, Dict, Tuple, Any
from .downloader import get_downloader
try:
from bs4 import BeautifulSoup
except ImportError as exc:
BeautifulSoup = None # type: ignore[assignment]
_BS4_IMPORT_ERROR = exc
else:
_BS4_IMPORT_ERROR = None
try:
import aiosqlite
except ImportError as exc:
aiosqlite = None # type: ignore[assignment]
_AIOSQLITE_IMPORT_ERROR = exc
else:
_AIOSQLITE_IMPORT_ERROR = None
def _require_beautifulsoup() -> Any:
if BeautifulSoup is None:
raise RuntimeError(
"BeautifulSoup (bs4) is required for CivArchiveModelMetadataProvider. "
"Install it with 'pip install beautifulsoup4'."
) from _BS4_IMPORT_ERROR
return BeautifulSoup
def _require_aiosqlite() -> Any:
if aiosqlite is None:
raise RuntimeError(
"aiosqlite is required for SQLiteModelMetadataProvider. "
"Install it with 'pip install aiosqlite'."
) from _AIOSQLITE_IMPORT_ERROR
return aiosqlite
logger = logging.getLogger(__name__)
class ModelMetadataProvider(ABC):
@@ -78,7 +108,8 @@ class CivArchiveModelMetadataProvider(ModelMetadataProvider):
html_content = await response.text()
# Parse HTML to extract JSON data
soup = BeautifulSoup(html_content, 'html.parser')
soup_parser = _require_beautifulsoup()
soup = soup_parser(html_content, 'html.parser')
script_tag = soup.find('script', {'id': '__NEXT_DATA__', 'type': 'application/json'})
if not script_tag:
@@ -171,10 +202,11 @@ class SQLiteModelMetadataProvider(ModelMetadataProvider):
def __init__(self, db_path: str):
self.db_path = db_path
self._aiosqlite = _require_aiosqlite()
async def get_model_by_hash(self, model_hash: str) -> Tuple[Optional[Dict], Optional[str]]:
"""Find model by hash value from SQLite database"""
async with aiosqlite.connect(self.db_path) as db:
async with self._aiosqlite.connect(self.db_path) as db:
# Look up in model_files table to get model_id and version_id
query = """
SELECT model_id, version_id
@@ -182,7 +214,7 @@ class SQLiteModelMetadataProvider(ModelMetadataProvider):
WHERE sha256 = ?
LIMIT 1
"""
db.row_factory = aiosqlite.Row
db.row_factory = self._aiosqlite.Row
cursor = await db.execute(query, (model_hash.upper(),))
file_row = await cursor.fetchone()
@@ -199,8 +231,8 @@ class SQLiteModelMetadataProvider(ModelMetadataProvider):
async def get_model_versions(self, model_id: str) -> Optional[Dict]:
"""Get all versions of a model from SQLite database"""
async with aiosqlite.connect(self.db_path) as db:
db.row_factory = aiosqlite.Row
async with self._aiosqlite.connect(self.db_path) as db:
db.row_factory = self._aiosqlite.Row
# First check if model exists
model_query = "SELECT * FROM models WHERE id = ?"
@@ -258,8 +290,8 @@ class SQLiteModelMetadataProvider(ModelMetadataProvider):
if not model_id and not version_id:
return None
async with aiosqlite.connect(self.db_path) as db:
db.row_factory = aiosqlite.Row
async with self._aiosqlite.connect(self.db_path) as db:
db.row_factory = self._aiosqlite.Row
# Case 1: Only version_id is provided
if model_id is None and version_id is not None:
@@ -295,8 +327,8 @@ class SQLiteModelMetadataProvider(ModelMetadataProvider):
async def get_model_version_info(self, version_id: str) -> Tuple[Optional[Dict], Optional[str]]:
"""Fetch model version metadata from SQLite database"""
async with aiosqlite.connect(self.db_path) as db:
db.row_factory = aiosqlite.Row
async with self._aiosqlite.connect(self.db_path) as db:
db.row_factory = self._aiosqlite.Row
# Get version details
version_query = "SELECT model_id FROM model_versions WHERE id = ?"