mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-25 23:25:43 -03:00
- Add set_version_update_ignore endpoint to toggle ignore status for specific versions.
- Add get_model_versions endpoint to retrieve version details with optional refresh.
- Update serialization to include version-specific data and preview overrides.
- Modify database schema to support version-level ignore tracking.
- Improve error handling for rate limiting and missing models.

These changes enable granular control over version updates and provide better visibility into model version status.
780 lines
29 KiB
Python
780 lines
29 KiB
Python
"""Service for tracking remote model version updates."""
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import logging
|
|
import os
|
|
import sqlite3
|
|
import time
|
|
from dataclasses import dataclass, replace
|
|
from typing import Dict, Iterable, List, Mapping, Optional, Sequence
|
|
|
|
from .errors import RateLimitError
|
|
from .settings_manager import get_settings_manager
|
|
from ..utils.civitai_utils import rewrite_preview_url
|
|
from ..utils.preview_selection import select_preview_media
|
|
|
|
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
@dataclass
class ModelVersionRecord:
    """Metadata stored for one version of a remote model.

    Field order is part of the constructor signature and must not change.
    """

    # Identifier of the version.
    version_id: int
    # Display name of the version, when known.
    name: Optional[str]
    # Base model label, when known.
    base_model: Optional[str]
    # Release timestamp kept as the string it was received in, when known.
    released_at: Optional[str]
    # Size of the version in bytes, when known.
    size_bytes: Optional[int]
    # URL of the preview media chosen for this version, when known.
    preview_url: Optional[str]
    # True when this version is present in the local library.
    is_in_library: bool
    # True when update notifications for this version are suppressed.
    should_ignore: bool
    # Position of the version within its model's ordered version list.
    sort_index: int = 0
|
|
|
|
|
|
@dataclass
class ModelUpdateRecord:
    """Update-tracking state stored for one model and all of its known versions."""

    # Model category string supplied by the caller.
    model_type: str
    # Identifier of the model.
    model_id: int
    # Every known version of the model.
    versions: List[ModelVersionRecord]
    # Timestamp recorded when remote data was last applied, or None if never.
    last_checked_at: Optional[float]
    # True when update notifications are suppressed for the whole model.
    should_ignore_model: bool

    @property
    def largest_version_id(self) -> Optional[int]:
        """Return the highest known version id, or ``None`` when no versions exist."""

        return max((entry.version_id for entry in self.versions), default=None)

    @property
    def version_ids(self) -> List[int]:
        """Return the ids of all known versions, in stored order."""

        return [entry.version_id for entry in self.versions]

    @property
    def in_library_version_ids(self) -> List[int]:
        """Return the ids of the versions flagged as present in the local library."""

        return [entry.version_id for entry in self.versions if entry.is_in_library]

    def has_update(self) -> bool:
        """Return True when some version is neither in the library nor ignored.

        Always returns False when the model itself is flagged as ignored.
        """

        if self.should_ignore_model:
            return False
        for entry in self.versions:
            if not (entry.is_in_library or entry.should_ignore):
                return True
        return False
|
|
|
|
|
|
class ModelUpdateService:
    """Persist and query remote model version metadata.

    Data is kept in a SQLite database: one ``model_update_status`` row per
    model and one ``model_update_versions`` row per known version. Every
    public coroutine performs its database work while holding a single
    ``asyncio.Lock``; remote fetches happen with the lock released.
    """

    # NOTE(review): this script drops both tables before recreating them, so
    # each run of ``_initialize_schema`` (once per service instance) discards
    # all previously stored rows, including ignore flags. Presumably a
    # schema-reset shortcut -- confirm it is intended before relying on
    # persistence across restarts.
    _SCHEMA = """
        PRAGMA foreign_keys = ON;
        DROP TABLE IF EXISTS model_update_versions;
        DROP TABLE IF EXISTS model_update_status;
        CREATE TABLE model_update_status (
            model_id INTEGER PRIMARY KEY,
            model_type TEXT NOT NULL,
            last_checked_at REAL,
            should_ignore_model INTEGER NOT NULL DEFAULT 0
        );
        CREATE TABLE model_update_versions (
            version_id INTEGER PRIMARY KEY,
            model_id INTEGER NOT NULL,
            sort_index INTEGER NOT NULL DEFAULT 0,
            name TEXT,
            base_model TEXT,
            released_at TEXT,
            size_bytes INTEGER,
            preview_url TEXT,
            is_in_library INTEGER NOT NULL DEFAULT 0,
            should_ignore INTEGER NOT NULL DEFAULT 0,
            FOREIGN KEY(model_id) REFERENCES model_update_status(model_id) ON DELETE CASCADE
        );
        CREATE INDEX IF NOT EXISTS idx_model_update_versions_model_id
            ON model_update_versions(model_id);
    """

    def __init__(self, db_path: str, *, ttl_seconds: int = 24 * 60 * 60, settings_manager=None) -> None:
        """Create the service and prepare its database.

        Args:
            db_path: Path of the SQLite database file; its parent directory
                is created when missing.
            ttl_seconds: Age in seconds after which a stored record counts as
                stale and is fetched again.
            settings_manager: Object whose ``get`` method is consulted for
                the ``blur_mature_content`` setting. Falls back to
                ``get_settings_manager()`` when falsy.
        """
        self._db_path = db_path
        self._ttl_seconds = ttl_seconds
        self._lock = asyncio.Lock()
        self._schema_initialized = False
        self._settings = settings_manager or get_settings_manager()
        self._ensure_directory()
        self._initialize_schema()

    def _ensure_directory(self) -> None:
        """Create the database's parent directory when the path has one."""
        directory = os.path.dirname(self._db_path)
        if directory:
            os.makedirs(directory, exist_ok=True)

    def _connect(self) -> sqlite3.Connection:
        """Open a new connection whose rows support access by column name.

        The caller owns the returned connection and must close it.
        """
        conn = sqlite3.connect(self._db_path, check_same_thread=False)
        conn.row_factory = sqlite3.Row
        return conn

    def _initialize_schema(self) -> None:
        """Run the schema script once per service instance.

        Raises:
            Exception: Any error raised while preparing the database is
                logged and re-raised.
        """
        if self._schema_initialized:
            return
        try:
            conn = self._connect()
            try:
                # ``with conn`` only scopes the transaction; sqlite3 does not
                # close the connection on exit, hence the explicit close().
                with conn:
                    conn.execute("PRAGMA journal_mode=WAL")
                    conn.execute("PRAGMA foreign_keys = ON")
                    conn.executescript(self._SCHEMA)
            finally:
                conn.close()
            self._schema_initialized = True
        except Exception as exc:  # pragma: no cover - defensive guard
            logger.error("Failed to initialize update schema: %s", exc, exc_info=True)
            raise

    async def refresh_for_model_type(
        self,
        model_type: str,
        scanner,
        metadata_provider,
        *,
        force_refresh: bool = False,
    ) -> Dict[int, ModelUpdateRecord]:
        """Refresh update information for every model present in the cache.

        Models that need remote data (forced, unknown, or stale, and not
        ignored unless forced) are first requested in bulk; any model without
        a bulk answer is fetched individually during its own refresh.

        Args:
            model_type: Model category stored with newly created records.
            scanner: Object exposing an awaitable ``get_cached_data()``.
            metadata_provider: Remote metadata source; when falsy no remote
                fetch is attempted.
            force_refresh: Fetch even when records are fresh or ignored.

        Returns:
            Mapping of model id to its refreshed record.

        Raises:
            RateLimitError: Propagated from the metadata provider.
        """

        local_versions = await self._collect_local_versions(scanner)
        results: Dict[int, ModelUpdateRecord] = {}
        prefetched: Dict[int, Mapping] = {}

        fetch_targets: List[int] = []
        if metadata_provider and local_versions:
            now = time.time()
            async with self._lock:
                for model_id in local_versions:
                    existing = self._get_record(model_type, model_id)
                    if existing and existing.should_ignore_model and not force_refresh:
                        continue
                    if force_refresh or not existing or self._is_stale(existing, now):
                        fetch_targets.append(model_id)

        if fetch_targets:
            try:
                prefetched = await self._fetch_model_versions_bulk(
                    metadata_provider,
                    fetch_targets,
                )
            except NotImplementedError:
                # Provider has no bulk endpoint; fall back to per-model fetches.
                prefetched = {}

        for model_id, version_ids in local_versions.items():
            record = await self._refresh_single_model(
                model_type,
                model_id,
                version_ids,
                metadata_provider,
                force_refresh=force_refresh,
                prefetched_response=prefetched.get(model_id),
            )
            if record:
                results[model_id] = record
        return results

    async def refresh_single_model(
        self,
        model_type: str,
        model_id: int,
        scanner,
        metadata_provider,
        *,
        force_refresh: bool = False,
    ) -> Optional[ModelUpdateRecord]:
        """Refresh update information for a specific model id.

        The local version ids are taken from the scanner cache; a model that
        is absent from the cache is refreshed with no local versions.

        Raises:
            RateLimitError: Propagated from the metadata provider.
        """

        local_versions = await self._collect_local_versions(scanner)
        return await self._refresh_single_model(
            model_type,
            model_id,
            local_versions.get(model_id, []),
            metadata_provider,
            force_refresh=force_refresh,
        )

    async def update_in_library_versions(
        self,
        model_type: str,
        model_id: int,
        version_ids: Sequence[int],
    ) -> ModelUpdateRecord:
        """Persist a new set of in-library version identifiers.

        Existing versions get their ``is_in_library`` flag recomputed and
        unknown ids are added as placeholder versions.

        Returns:
            The record as stored.
        """

        normalized_versions = self._normalize_sequence(version_ids)
        async with self._lock:
            existing = self._get_record(model_type, model_id)
            record = self._merge_with_local_versions(
                existing,
                normalized_versions,
                model_type=model_type,
                model_id=model_id,
            )
            self._upsert_record(record)
            return record

    async def set_should_ignore(
        self, model_type: str, model_id: int, should_ignore: bool
    ) -> ModelUpdateRecord:
        """Toggle the ignore flag for a model.

        A record with no versions is created when the model is unknown.

        Returns:
            The record as stored.
        """

        async with self._lock:
            existing = self._get_record(model_type, model_id)
            if existing:
                record = replace(
                    existing,
                    versions=list(existing.versions),
                    should_ignore_model=should_ignore,
                )
            else:
                record = ModelUpdateRecord(
                    model_type=model_type,
                    model_id=model_id,
                    versions=[],
                    last_checked_at=None,
                    should_ignore_model=should_ignore,
                )
            self._upsert_record(record)
            return record

    async def set_version_should_ignore(
        self,
        model_type: str,
        model_id: int,
        version_id: int,
        should_ignore: bool,
    ) -> ModelUpdateRecord:
        """Toggle the ignore flag for an individual version.

        When the version (or the whole model) is not stored yet, a
        placeholder version carrying only the id and the flag is appended.

        Returns:
            The record as stored.
        """

        async with self._lock:
            existing = self._get_record(model_type, model_id)
            known = existing.versions if existing else []
            versions: List[ModelVersionRecord] = [
                replace(entry, should_ignore=should_ignore)
                if entry.version_id == version_id
                else entry
                for entry in known
            ]
            if not any(entry.version_id == version_id for entry in known):
                versions.append(
                    ModelVersionRecord(
                        version_id=version_id,
                        name=None,
                        base_model=None,
                        released_at=None,
                        size_bytes=None,
                        preview_url=None,
                        is_in_library=False,
                        should_ignore=should_ignore,
                        sort_index=len(versions),
                    )
                )

            record = ModelUpdateRecord(
                model_type=existing.model_type if existing else model_type,
                model_id=existing.model_id if existing else model_id,
                versions=self._sorted_versions(versions),
                last_checked_at=existing.last_checked_at if existing else None,
                should_ignore_model=existing.should_ignore_model if existing else False,
            )
            self._upsert_record(record)
            return record

    async def get_record(self, model_type: str, model_id: int) -> Optional[ModelUpdateRecord]:
        """Return a cached record without triggering remote fetches."""

        async with self._lock:
            return self._get_record(model_type, model_id)

    async def has_update(self, model_type: str, model_id: int) -> bool:
        """Determine if a model has updates pending; False for unknown models."""

        record = await self.get_record(model_type, model_id)
        return record.has_update() if record else False

    async def _refresh_single_model(
        self,
        model_type: str,
        model_id: int,
        local_versions: Sequence[int],
        metadata_provider,
        *,
        force_refresh: bool = False,
        prefetched_response: Optional[Mapping] = None,
    ) -> Optional[ModelUpdateRecord]:
        """Refresh and persist the record of one model.

        Args:
            model_type: Model category used when a record must be created.
            model_id: Identifier of the model.
            local_versions: Version ids currently present locally.
            metadata_provider: Remote metadata source; skipped when falsy.
            force_refresh: Fetch even when the record is fresh or ignored.
            prefetched_response: Provider response obtained earlier (for
                example from a bulk request); used instead of a new request.

        Returns:
            The record as stored.

        Raises:
            RateLimitError: Propagated from the metadata provider. Other
                provider errors are logged and the stored data is kept.
        """
        normalized_local = self._normalize_sequence(local_versions)
        now = time.time()

        async with self._lock:
            existing = self._get_record(model_type, model_id)
            if existing and existing.should_ignore_model and not force_refresh:
                # Ignored models only get their local flags refreshed.
                record = self._merge_with_local_versions(existing, normalized_local)
                self._upsert_record(record)
                return record
            should_fetch = force_refresh or not existing or self._is_stale(existing, now)

        # The lock is released during the network request.
        fetched_versions: Optional[List[ModelVersionRecord]] = None
        if metadata_provider and should_fetch:
            response = prefetched_response
            if response is None:
                try:
                    response = await metadata_provider.get_model_versions(model_id)
                except RateLimitError:
                    raise
                except Exception as exc:  # pragma: no cover - defensive log
                    logger.error(
                        "Failed to fetch versions for model %s (%s): %s",
                        model_id,
                        model_type,
                        exc,
                        exc_info=True,
                    )
            if response is not None:
                # None here means the response could not be interpreted.
                fetched_versions = self._extract_versions(response)

        async with self._lock:
            # Re-read: the record may have changed while the lock was released.
            existing = self._get_record(model_type, model_id)
            if existing and existing.should_ignore_model and not force_refresh:
                record = self._merge_with_local_versions(existing, normalized_local)
                self._upsert_record(record)
                return record

            if isinstance(fetched_versions, list):
                record = self._build_record_from_remote(
                    model_type,
                    model_id,
                    normalized_local,
                    fetched_versions,
                    existing,
                    now,
                )
            else:
                # No usable remote data: keep the previous check timestamp.
                record = self._merge_with_local_versions(
                    existing,
                    normalized_local,
                    model_type=model_type,
                    model_id=model_id,
                    last_checked_at=existing.last_checked_at if existing else None,
                )
            self._upsert_record(record)
            return record

    async def _fetch_model_versions_bulk(
        self,
        metadata_provider,
        model_ids: Sequence[int],
    ) -> Dict[int, Mapping]:
        """Fetch model metadata in batches of up to 100 ids.

        Returns:
            Mapping of integer model id to the provider's mapping for that
            model. Entries with a non-integer key or a non-mapping value are
            dropped, as are batches whose response is ``None`` or not a
            mapping.

        Raises:
            Exception: Anything raised by the provider (including
                ``RateLimitError``) propagates unchanged.
        """

        BATCH_SIZE = 100
        normalized = self._normalize_sequence(model_ids)
        if not normalized:
            return {}

        aggregated: Dict[int, Mapping] = {}
        for start in range(0, len(normalized), BATCH_SIZE):
            chunk = normalized[start : start + BATCH_SIZE]
            response = await metadata_provider.get_model_versions_bulk(chunk)
            if response is None:
                continue
            if not isinstance(response, Mapping):
                logger.debug(
                    "Unexpected bulk response type %s from provider %s", type(response), metadata_provider
                )
                continue
            for key, value in response.items():
                normalized_key = self._normalize_int(key)
                if normalized_key is None:
                    continue
                if isinstance(value, Mapping):
                    aggregated[normalized_key] = value
        return aggregated

    async def _collect_local_versions(self, scanner) -> Dict[int, List[int]]:
        """Group the version ids found in the scanner cache by model id.

        Reads ``civitai["modelId"]`` and ``civitai["id"]`` from each dict item
        of ``cache.raw_data``; items lacking either value are skipped.

        Returns:
            Mapping of model id to its sorted, de-duplicated version ids.
        """
        cache = await scanner.get_cached_data()
        if not cache or not getattr(cache, "raw_data", None):
            return {}

        mapping: Dict[int, set[int]] = {}
        for item in cache.raw_data:
            civitai = item.get("civitai") if isinstance(item, dict) else None
            if not isinstance(civitai, dict):
                continue
            model_id = self._normalize_int(civitai.get("modelId"))
            version_id = self._normalize_int(civitai.get("id"))
            if model_id is None or version_id is None:
                continue
            mapping.setdefault(model_id, set()).add(version_id)

        return {model_id: sorted(ids) for model_id, ids in mapping.items()}

    def _merge_with_local_versions(
        self,
        existing: Optional[ModelUpdateRecord],
        normalized_local: Sequence[int],
        *,
        model_type: Optional[str] = None,
        model_id: Optional[int] = None,
        last_checked_at: Optional[float] = None,
    ) -> ModelUpdateRecord:
        """Build a record from stored data plus the current local version ids.

        Stored versions keep their metadata and get ``is_in_library``
        recomputed; local ids that are not stored are appended as
        placeholders. When ``existing`` is given, its model type and id
        override the keyword arguments, and its check timestamp is used
        unless ``last_checked_at`` is provided.

        Raises:
            ValueError: If ``existing`` is falsy and ``model_type`` or
                ``model_id`` is missing.
        """
        local_set = set(normalized_local)
        versions: List[ModelVersionRecord] = []
        if existing:
            model_type = existing.model_type
            model_id = existing.model_id
            if last_checked_at is None:
                last_checked_at = existing.last_checked_at
            versions.extend(
                replace(version, is_in_library=version.version_id in local_set)
                for version in existing.versions
            )
        elif model_type is None or model_id is None:
            raise ValueError("model_type and model_id are required when creating a new record")

        seen_ids = {version.version_id for version in versions}
        for missing_id in sorted(local_set - seen_ids):
            versions.append(
                ModelVersionRecord(
                    version_id=missing_id,
                    name=None,
                    base_model=None,
                    released_at=None,
                    size_bytes=None,
                    preview_url=None,
                    is_in_library=True,
                    # Never stored before, so no ignore flag can exist for it.
                    should_ignore=False,
                    sort_index=len(versions),
                )
            )

        return ModelUpdateRecord(
            model_type=model_type,
            model_id=model_id,
            versions=self._sorted_versions(versions),
            last_checked_at=last_checked_at,
            should_ignore_model=existing.should_ignore_model if existing else False,
        )

    def _build_record_from_remote(
        self,
        model_type: str,
        model_id: int,
        local_versions: Sequence[int],
        remote_versions: Sequence[ModelVersionRecord],
        existing: Optional[ModelUpdateRecord],
        timestamp: float,
    ) -> ModelUpdateRecord:
        """Build a record from freshly fetched versions.

        For each remote version, the stored ignore flag and sort index win
        over the remote values, and the stored preview URL is used only when
        the remote one is empty. Local versions the remote list does not
        mention are kept (from storage when available, otherwise as
        placeholders). ``timestamp`` becomes the new ``last_checked_at``.
        """
        local_set = set(local_versions)
        existing_map: Dict[int, ModelVersionRecord] = (
            {version.version_id: version for version in existing.versions} if existing else {}
        )

        versions: List[ModelVersionRecord] = []
        seen_ids: set[int] = set()
        for index, remote_version in enumerate(remote_versions):
            version_id = remote_version.version_id
            seen_ids.add(version_id)
            previous = existing_map.get(version_id)
            has_previous = previous is not None
            versions.append(
                ModelVersionRecord(
                    version_id=version_id,
                    name=remote_version.name,
                    base_model=remote_version.base_model,
                    released_at=remote_version.released_at,
                    size_bytes=remote_version.size_bytes,
                    preview_url=remote_version.preview_url
                    or (previous.preview_url if has_previous else None),
                    is_in_library=version_id in local_set,
                    should_ignore=previous.should_ignore
                    if has_previous
                    else remote_version.should_ignore,
                    sort_index=previous.sort_index if has_previous else index,
                )
            )

        for version_id in sorted(local_set - seen_ids):
            previous = existing_map.get(version_id)
            if previous is not None:
                versions.append(replace(previous, is_in_library=True))
            else:
                versions.append(
                    ModelVersionRecord(
                        version_id=version_id,
                        name=None,
                        base_model=None,
                        released_at=None,
                        size_bytes=None,
                        preview_url=None,
                        is_in_library=True,
                        # Not stored before, so there is no ignore flag to keep.
                        should_ignore=False,
                        sort_index=len(versions),
                    )
                )

        return ModelUpdateRecord(
            model_type=model_type,
            model_id=model_id,
            versions=self._sorted_versions(versions),
            last_checked_at=timestamp,
            should_ignore_model=existing.should_ignore_model if existing else False,
        )

    def _sorted_versions(self, versions: Sequence[ModelVersionRecord]) -> List[ModelVersionRecord]:
        """Return copies ordered by ``(sort_index, version_id)`` and renumbered from 0."""
        ordered = sorted(versions, key=lambda version: (version.sort_index, version.version_id))
        return [replace(version, sort_index=index) for index, version in enumerate(ordered)]

    def _is_stale(self, record: ModelUpdateRecord, now: float) -> bool:
        """Return True when the record was never checked or its TTL has elapsed."""
        if record.last_checked_at is None:
            return True
        return (now - record.last_checked_at) >= self._ttl_seconds

    @staticmethod
    def _normalize_int(value) -> Optional[int]:
        """Convert ``value`` with ``int()``; return ``None`` when that is impossible.

        ``None``, unparsable values, NaN and infinities all yield ``None``.
        """
        if value is None:
            return None
        try:
            return int(value)
        except (TypeError, ValueError, OverflowError):
            # OverflowError is what int() raises for float infinities.
            return None

    def _normalize_sequence(self, values: Sequence[int]) -> List[int]:
        """Return the sorted, de-duplicated integers obtainable from ``values``."""
        converted = (self._normalize_int(value) for value in values)
        return sorted({item for item in converted if item is not None})

    @staticmethod
    def _normalize_string(value) -> Optional[str]:
        """Return a stripped string, or ``None`` for ``None`` and blank strings.

        Non-string values are converted with ``str()``.
        """
        if value is None:
            return None
        if isinstance(value, str):
            return value.strip() or None
        try:
            return str(value)
        except Exception:  # pragma: no cover - defensive conversion
            return None

    def _extract_versions(self, response) -> Optional[List[ModelVersionRecord]]:
        """Parse the ``modelVersions`` entries of a provider response.

        Returns:
            ``None`` when the response is not a mapping or ``modelVersions``
            is not iterable; an empty list when the key is absent or ``None``;
            otherwise one record per mapping entry that has an integer ``id``,
            with ``sort_index`` set to the entry's position in the response.
        """
        if not isinstance(response, Mapping):
            return None
        versions = response.get("modelVersions")
        if versions is None:
            return []
        if not isinstance(versions, Iterable):
            return None

        extracted: List[ModelVersionRecord] = []
        for index, entry in enumerate(versions):
            if not isinstance(entry, Mapping):
                continue
            version_id = self._normalize_int(entry.get("id"))
            if version_id is None:
                continue
            extracted.append(
                ModelVersionRecord(
                    version_id=version_id,
                    name=self._normalize_string(entry.get("name")),
                    base_model=self._normalize_string(entry.get("baseModel")),
                    released_at=self._normalize_string(
                        entry.get("publishedAt") or entry.get("createdAt")
                    ),
                    size_bytes=self._extract_size_bytes(entry.get("files")),
                    preview_url=self._extract_preview_url(entry.get("images")),
                    is_in_library=False,
                    should_ignore=False,
                    sort_index=index,
                )
            )
        return extracted

    def _extract_size_bytes(self, files) -> Optional[int]:
        """Return the first usable ``sizeKB`` value, multiplied by 1024, as an int.

        Entries that are not mappings, have no ``sizeKB``, or whose value is
        not a finite number are skipped. Returns ``None`` when nothing usable
        is found or ``files`` is not iterable.
        """
        if not isinstance(files, Iterable):
            return None
        for entry in files:
            if not isinstance(entry, Mapping):
                continue
            size_kb = entry.get("sizeKB")
            if size_kb is None:
                continue
            try:
                # int() raises ValueError for NaN and OverflowError for
                # infinities; both are treated like any other bad value.
                return int(float(size_kb) * 1024)
            except (TypeError, ValueError, OverflowError):
                continue
        return None

    def _extract_preview_url(self, images) -> Optional[str]:
        """Pick a preview URL from a version's image entries.

        The mapping entries of ``images`` are passed to
        ``select_preview_media`` together with the ``blur_mature_content``
        setting (treated as True when it cannot be read); the selected URL is
        then passed through ``rewrite_preview_url``.

        Returns:
            The rewritten URL, the selected URL when the rewrite is empty, or
            ``None`` when no entry with a non-empty string URL is selected.
        """
        if not isinstance(images, Iterable):
            return None

        candidates = [entry for entry in images if isinstance(entry, Mapping)]
        if not candidates:
            return None

        blur_mature_content = True
        settings = getattr(self, "_settings", None)
        if settings is not None and hasattr(settings, "get"):
            try:
                blur_mature_content = bool(settings.get("blur_mature_content", True))
            except Exception:  # pragma: no cover - defensive guard
                blur_mature_content = True

        selected, _ = select_preview_media(candidates, blur_mature_content=blur_mature_content)
        if not selected:
            return None

        url = selected.get("url")
        if not isinstance(url, str) or not url:
            return None

        media_type = selected.get("type")
        if not isinstance(media_type, str):
            media_type = None

        rewritten, _ = rewrite_preview_url(url, media_type)
        return rewritten or url

    def _get_record(self, model_type: str, model_id: int) -> Optional[ModelUpdateRecord]:
        """Load the stored record for ``model_id``, or ``None`` when absent.

        The lookup is by ``model_id`` only. When the stored model type
        differs from ``model_type`` the mismatch is logged at debug level and
        the stored type is returned.
        """
        conn = self._connect()
        try:
            status_row = conn.execute(
                """
                SELECT model_id, model_type, last_checked_at, should_ignore_model
                FROM model_update_status
                WHERE model_id = ?
                """,
                (model_id,),
            ).fetchone()
            if not status_row:
                return None
            stored_type = status_row["model_type"]
            if stored_type and stored_type != model_type:
                logger.debug(
                    "Model id %s requested as %s but stored as %s", model_id, model_type, stored_type
                )
            version_rows = conn.execute(
                """
                SELECT version_id, sort_index, name, base_model, released_at, size_bytes,
                       preview_url, is_in_library, should_ignore
                FROM model_update_versions
                WHERE model_id = ?
                ORDER BY sort_index ASC, version_id ASC
                """,
                (model_id,),
            ).fetchall()
        finally:
            # sqlite3's context manager would not close the connection.
            conn.close()

        versions: List[ModelVersionRecord] = [
            ModelVersionRecord(
                version_id=int(row["version_id"]),
                name=row["name"],
                base_model=row["base_model"],
                released_at=row["released_at"],
                size_bytes=self._normalize_int(row["size_bytes"]),
                preview_url=row["preview_url"],
                is_in_library=bool(row["is_in_library"]),
                should_ignore=bool(row["should_ignore"]),
                sort_index=self._normalize_int(row["sort_index"]) or 0,
            )
            for row in version_rows
        ]

        return ModelUpdateRecord(
            model_type=stored_type or model_type,
            model_id=int(status_row["model_id"]),
            versions=self._sorted_versions(versions),
            last_checked_at=status_row["last_checked_at"],
            should_ignore_model=bool(status_row["should_ignore_model"]),
        )

    def _upsert_record(self, record: ModelUpdateRecord) -> None:
        """Write ``record`` to the database, replacing its stored version rows.

        The status row is inserted or updated, all version rows of the model
        are deleted, and the record's versions are inserted again, all within
        one transaction.
        """
        status_payload = (
            record.model_id,
            record.model_type,
            record.last_checked_at,
            1 if record.should_ignore_model else 0,
        )
        version_payloads = [
            (
                version.version_id,
                record.model_id,
                version.sort_index,
                version.name,
                version.base_model,
                version.released_at,
                version.size_bytes,
                version.preview_url,
                1 if version.is_in_library else 0,
                1 if version.should_ignore else 0,
            )
            for version in record.versions
        ]
        conn = self._connect()
        try:
            # ``with conn`` commits on success and rolls back on error; it
            # does not close the connection, hence the ``finally`` below.
            with conn:
                conn.execute(
                    """
                    INSERT INTO model_update_status (
                        model_id, model_type, last_checked_at, should_ignore_model
                    ) VALUES (?, ?, ?, ?)
                    ON CONFLICT(model_id) DO UPDATE SET
                        model_type = excluded.model_type,
                        last_checked_at = excluded.last_checked_at,
                        should_ignore_model = excluded.should_ignore_model
                    """,
                    status_payload,
                )
                conn.execute(
                    "DELETE FROM model_update_versions WHERE model_id = ?",
                    (record.model_id,),
                )
                conn.executemany(
                    """
                    INSERT INTO model_update_versions (
                        version_id, model_id, sort_index, name, base_model, released_at,
                        size_bytes, preview_url, is_in_library, should_ignore
                    ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                    """,
                    version_payloads,
                )
        finally:
            conn.close()
|