feat: add model version management endpoints

- Add set_version_update_ignore endpoint to toggle ignore status for specific versions
- Add get_model_versions endpoint to retrieve version details with optional refresh
- Update serialization to include version-specific data and preview overrides
- Modify database schema to support version-level ignore tracking
- Improve error handling for rate limiting and missing models

These changes enable granular control over version updates and provide better visibility into model version status.
This commit is contained in:
Will Miao
2025-10-25 14:54:23 +08:00
parent 0594e278b6
commit 58cafdb713
6 changed files with 753 additions and 148 deletions

View File

@@ -1085,6 +1085,28 @@ class ModelUpdateHandler:
)
return web.json_response({"success": True, "record": self._serialize_record(record)})
async def set_version_update_ignore(self, request: web.Request) -> web.Response:
    """Set or clear the ignore flag for a single model version.

    Expects a JSON body with ``modelId``, ``versionId`` and ``shouldIgnore``;
    responds with the updated, fully serialized record.
    """
    body = await self._read_json(request)
    model_id = self._normalize_model_id(body.get("modelId"))
    version_id = self._normalize_model_id(body.get("versionId"))
    if model_id is None or version_id is None:
        failure = {"success": False, "error": "modelId and versionId are required"}
        return web.json_response(failure, status=400)
    record = await self._update_service.set_version_should_ignore(
        self._service.model_type,
        model_id,
        version_id,
        self._parse_bool(body.get("shouldIgnore")),
    )
    preview_overrides = await self._build_preview_overrides(record)
    serialized = self._serialize_record(record, preview_overrides=preview_overrides)
    return web.json_response({"success": True, "record": serialized})
async def get_model_update_status(self, request: web.Request) -> web.Response:
model_id = self._normalize_model_id(request.match_info.get("model_id"))
if model_id is None:
@@ -1107,6 +1129,33 @@ class ModelUpdateHandler:
return web.json_response({"success": True, "record": self._serialize_record(record)})
async def get_model_versions(self, request: web.Request) -> web.Response:
    """Return version details for a tracked model.

    Query params ``refresh`` and ``force`` control whether the record is
    re-fetched from the metadata provider before serialization.
    """
    model_id = self._normalize_model_id(request.match_info.get("model_id"))
    if model_id is None:
        failure = {"success": False, "error": "model_id must be an integer"}
        return web.json_response(failure, status=400)
    wants_refresh = self._parse_bool(request.query.get("refresh"))
    wants_force = self._parse_bool(request.query.get("force"))
    try:
        record = await self._get_or_refresh_record(
            model_id, refresh=wants_refresh, force=wants_force
        )
    except RateLimitError as exc:
        message = str(exc) or "Rate limited"
        return web.json_response({"success": False, "error": message}, status=429)
    if record is None:
        return web.json_response({"success": False, "error": "Model not tracked"}, status=404)
    preview_overrides = await self._build_preview_overrides(record)
    serialized = self._serialize_record(record, preview_overrides=preview_overrides)
    return web.json_response({"success": True, "record": serialized})
async def _get_or_refresh_record(
self, model_id: int, *, refresh: bool, force: bool
) -> Optional[object]:
@@ -1160,8 +1209,13 @@ class ModelUpdateHandler:
except (TypeError, ValueError):
return None
@staticmethod
def _serialize_record(record) -> Dict:
def _serialize_record(
self,
record,
*,
preview_overrides: Optional[Dict[int, Optional[str]]] = None,
) -> Dict:
overrides = preview_overrides or {}
return {
"modelType": record.model_type,
"modelId": record.model_id,
@@ -1169,10 +1223,50 @@ class ModelUpdateHandler:
"versionIds": record.version_ids,
"inLibraryVersionIds": record.in_library_version_ids,
"lastCheckedAt": record.last_checked_at,
"shouldIgnore": record.should_ignore,
"shouldIgnore": record.should_ignore_model,
"hasUpdate": record.has_update(),
"versions": [
self._serialize_version(version, overrides.get(version.version_id))
for version in record.versions
],
}
@staticmethod
def _serialize_version(version, override_preview: Optional[str]) -> Dict:
preview_url = override_preview if override_preview is not None else version.preview_url
return {
"versionId": version.version_id,
"name": version.name,
"baseModel": version.base_model,
"releasedAt": version.released_at,
"sizeBytes": version.size_bytes,
"previewUrl": preview_url,
"isInLibrary": version.is_in_library,
"shouldIgnore": version.should_ignore,
}
async def _build_preview_overrides(self, record) -> Dict[int, Optional[str]]:
    """Map version ids to static preview URLs taken from the local scanner cache.

    Only versions flagged as in-library are considered; versions without a
    cached preview are omitted so callers fall back to the stored preview URL.
    Returns an empty mapping on any cache failure.
    """
    overrides: Dict[int, Optional[str]] = {}
    try:
        cache = await self._service.scanner.get_cached_data()
    except Exception as exc:  # pragma: no cover - defensive logging
        # A cache failure must not break the response; serve without overrides.
        self._logger.debug("Failed to load cache while building preview overrides: %s", exc)
        return overrides
    # version_index is assumed to map version_id -> cache entry dict — confirm
    # against the scanner implementation.
    version_index = getattr(cache, "version_index", None)
    if not version_index:
        return overrides
    for version in record.versions:
        if not version.is_in_library:
            continue
        cache_entry = version_index.get(version.version_id)
        if isinstance(cache_entry, Mapping):
            preview = cache_entry.get("preview_url")
            if isinstance(preview, str) and preview:
                # Rewrite the on-disk preview path to a URL the frontend can load.
                overrides[version.version_id] = config.get_preview_static_url(preview)
    return overrides
@dataclass
class ModelHandlerSet:
@@ -1233,6 +1327,8 @@ class ModelHandlerSet:
"get_relative_paths": self.query.get_relative_paths,
"refresh_model_updates": self.updates.refresh_model_updates,
"set_model_update_ignore": self.updates.set_model_update_ignore,
"set_version_update_ignore": self.updates.set_version_update_ignore,
"get_model_update_status": self.updates.get_model_update_status,
"get_model_versions": self.updates.get_model_versions,
}

View File

@@ -57,7 +57,9 @@ COMMON_ROUTE_DEFINITIONS: tuple[RouteDefinition, ...] = (
RouteDefinition("GET", "/api/lm/{prefix}/civitai/model/hash/{hash}", "get_civitai_model_by_hash"),
RouteDefinition("POST", "/api/lm/{prefix}/updates/refresh", "refresh_model_updates"),
RouteDefinition("POST", "/api/lm/{prefix}/updates/ignore", "set_model_update_ignore"),
RouteDefinition("POST", "/api/lm/{prefix}/updates/ignore-version", "set_version_update_ignore"),
RouteDefinition("GET", "/api/lm/{prefix}/updates/status/{model_id}", "get_model_update_status"),
RouteDefinition("GET", "/api/lm/{prefix}/updates/versions/{model_id}", "get_model_versions"),
RouteDefinition("POST", "/api/lm/download-model", "download_model"),
RouteDefinition("GET", "/api/lm/download-model-get", "download_model_get"),
RouteDefinition("GET", "/api/lm/cancel-download-get", "cancel_download_get"),

View File

@@ -2,61 +2,112 @@
from __future__ import annotations
import asyncio
import json
import logging
import os
import sqlite3
import time
from dataclasses import dataclass
from dataclasses import dataclass, replace
from typing import Dict, Iterable, List, Mapping, Optional, Sequence
from .errors import RateLimitError
from .settings_manager import get_settings_manager
from ..utils.civitai_utils import rewrite_preview_url
from ..utils.preview_selection import select_preview_media
logger = logging.getLogger(__name__)
@dataclass
class ModelVersionRecord:
    """Persisted metadata for a single model version."""

    # Provider-side version identifier (primary key in model_update_versions).
    version_id: int
    # Human-readable version name; None when unknown.
    name: Optional[str]
    # Base model tag reported by the provider (e.g. "SD15"); None when unknown.
    base_model: Optional[str]
    # Release timestamp string as reported by the provider; None when unknown.
    released_at: Optional[str]
    # File size in bytes (derived from the provider's sizeKB field).
    size_bytes: Optional[int]
    # Preview media URL (remote, possibly rewritten); None when unavailable.
    preview_url: Optional[str]
    # True when this version exists in the local library.
    is_in_library: bool
    # True when the user chose to ignore updates for this specific version.
    should_ignore: bool
    # Stable display ordering; renumbered to 0..n-1 whenever records are stored.
    sort_index: int = 0
@dataclass
class ModelUpdateRecord:
"""Representation of a persisted update record."""
model_type: str
model_id: int
largest_version_id: Optional[int]
version_ids: List[int]
in_library_version_ids: List[int]
versions: List[ModelVersionRecord]
last_checked_at: Optional[float]
should_ignore: bool
should_ignore_model: bool
@property
def largest_version_id(self) -> Optional[int]:
"""Return the highest known version identifier for the model."""
if not self.versions:
return None
return max(version.version_id for version in self.versions)
@property
def version_ids(self) -> List[int]:
"""Return all known version identifiers."""
return [version.version_id for version in self.versions]
@property
def in_library_version_ids(self) -> List[int]:
"""Return the subset of version identifiers present in the local library."""
return [version.version_id for version in self.versions if version.is_in_library]
def has_update(self) -> bool:
"""Return True when remote versions exceed the local library."""
"""Return True when a non-ignored remote version is missing locally."""
if self.should_ignore or not self.version_ids:
if self.should_ignore_model:
return False
local_versions = set(self.in_library_version_ids)
return any(version_id not in local_versions for version_id in self.version_ids)
return any(
not version.is_in_library and not version.should_ignore for version in self.versions
)
class ModelUpdateService:
"""Persist and query remote model version metadata."""
_SCHEMA = """
CREATE TABLE IF NOT EXISTS model_update_status (
PRAGMA foreign_keys = ON;
DROP TABLE IF EXISTS model_update_versions;
DROP TABLE IF EXISTS model_update_status;
CREATE TABLE model_update_status (
model_id INTEGER PRIMARY KEY,
model_type TEXT NOT NULL,
model_id INTEGER NOT NULL,
largest_version_id INTEGER,
version_ids TEXT,
in_library_version_ids TEXT,
last_checked_at REAL,
should_ignore INTEGER DEFAULT 0,
PRIMARY KEY (model_type, model_id)
)
should_ignore_model INTEGER NOT NULL DEFAULT 0
);
CREATE TABLE model_update_versions (
version_id INTEGER PRIMARY KEY,
model_id INTEGER NOT NULL,
sort_index INTEGER NOT NULL DEFAULT 0,
name TEXT,
base_model TEXT,
released_at TEXT,
size_bytes INTEGER,
preview_url TEXT,
is_in_library INTEGER NOT NULL DEFAULT 0,
should_ignore INTEGER NOT NULL DEFAULT 0,
FOREIGN KEY(model_id) REFERENCES model_update_status(model_id) ON DELETE CASCADE
);
CREATE INDEX IF NOT EXISTS idx_model_update_versions_model_id
ON model_update_versions(model_id);
"""
def __init__(self, db_path: str, *, ttl_seconds: int = 24 * 60 * 60) -> None:
def __init__(self, db_path: str, *, ttl_seconds: int = 24 * 60 * 60, settings_manager=None) -> None:
self._db_path = db_path
self._ttl_seconds = ttl_seconds
self._lock = asyncio.Lock()
self._schema_initialized = False
self._settings = settings_manager or get_settings_manager()
self._ensure_directory()
self._initialize_schema()
@@ -103,7 +154,7 @@ class ModelUpdateService:
async with self._lock:
for model_id in local_versions.keys():
existing = self._get_record(model_type, model_id)
if existing and existing.should_ignore and not force_refresh:
if existing and existing.should_ignore_model and not force_refresh:
continue
if force_refresh or not existing or self._is_stale(existing, now):
fetch_targets.append(model_id)
@@ -162,14 +213,11 @@ class ModelUpdateService:
normalized_versions = self._normalize_sequence(version_ids)
async with self._lock:
existing = self._get_record(model_type, model_id)
record = ModelUpdateRecord(
record = self._merge_with_local_versions(
existing,
normalized_versions,
model_type=model_type,
model_id=model_id,
largest_version_id=existing.largest_version_id if existing else None,
version_ids=list(existing.version_ids) if existing else [],
in_library_version_ids=normalized_versions,
last_checked_at=existing.last_checked_at if existing else None,
should_ignore=existing.should_ignore if existing else False,
)
self._upsert_record(record)
return record
@@ -183,27 +231,70 @@ class ModelUpdateService:
existing = self._get_record(model_type, model_id)
if existing:
record = ModelUpdateRecord(
model_type=model_type,
model_id=model_id,
largest_version_id=existing.largest_version_id,
version_ids=list(existing.version_ids),
in_library_version_ids=list(existing.in_library_version_ids),
model_type=existing.model_type,
model_id=existing.model_id,
versions=list(existing.versions),
last_checked_at=existing.last_checked_at,
should_ignore=should_ignore,
should_ignore_model=should_ignore,
)
else:
record = ModelUpdateRecord(
model_type=model_type,
model_id=model_id,
largest_version_id=None,
version_ids=[],
in_library_version_ids=[],
versions=[],
last_checked_at=None,
should_ignore=should_ignore,
should_ignore_model=should_ignore,
)
self._upsert_record(record)
return record
async def set_version_should_ignore(
    self,
    model_type: str,
    model_id: int,
    version_id: int,
    should_ignore: bool,
) -> ModelUpdateRecord:
    """Toggle the ignore flag for an individual version.

    Rebuilds and persists the whole record under the service lock. If the
    version is not yet known, a stub entry is created so the choice survives
    until the next refresh fills in the remote details.
    """
    async with self._lock:
        existing = self._get_record(model_type, model_id)
        versions: List[ModelVersionRecord] = []
        found = False
        if existing:
            # Copy every version, flipping the flag only on the targeted one.
            for record_version in existing.versions:
                if record_version.version_id == version_id:
                    versions.append(
                        replace(record_version, should_ignore=should_ignore)
                    )
                    found = True
                else:
                    versions.append(record_version)
        if not found:
            # Unknown version: persist a placeholder carrying just the flag.
            versions.append(
                ModelVersionRecord(
                    version_id=version_id,
                    name=None,
                    base_model=None,
                    released_at=None,
                    size_bytes=None,
                    preview_url=None,
                    is_in_library=False,
                    should_ignore=should_ignore,
                    sort_index=len(versions),
                )
            )
        record = ModelUpdateRecord(
            model_type=existing.model_type if existing else model_type,
            model_id=existing.model_id if existing else model_id,
            versions=self._sorted_versions(versions),
            last_checked_at=existing.last_checked_at if existing else None,
            should_ignore_model=existing.should_ignore_model if existing else False,
        )
        self._upsert_record(record)
        return record
async def get_record(self, model_type: str, model_id: int) -> Optional[ModelUpdateRecord]:
"""Return a cached record without triggering remote fetches."""
@@ -230,22 +321,17 @@ class ModelUpdateService:
now = time.time()
async with self._lock:
existing = self._get_record(model_type, model_id)
if existing and existing.should_ignore and not force_refresh:
record = ModelUpdateRecord(
model_type=model_type,
model_id=model_id,
largest_version_id=existing.largest_version_id,
version_ids=list(existing.version_ids),
in_library_version_ids=normalized_local,
last_checked_at=existing.last_checked_at,
should_ignore=True,
if existing and existing.should_ignore_model and not force_refresh:
record = self._merge_with_local_versions(
existing,
normalized_local,
)
self._upsert_record(record)
return record
should_fetch = force_refresh or not existing or self._is_stale(existing, now)
# release lock during network request
fetched_versions: List[int] | None = None
fetched_versions: List[ModelVersionRecord] | None = None
refresh_succeeded = False
response: Optional[Mapping] = None
if metadata_provider and should_fetch:
@@ -264,45 +350,38 @@ class ModelUpdateService:
exc_info=True,
)
if response is not None:
extracted = self._extract_version_ids(response)
extracted = self._extract_versions(response)
if extracted is not None:
fetched_versions = extracted
refresh_succeeded = True
async with self._lock:
existing = self._get_record(model_type, model_id)
if existing and existing.should_ignore and not force_refresh:
# Ignore state could have flipped while awaiting provider
record = ModelUpdateRecord(
model_type=model_type,
model_id=model_id,
largest_version_id=existing.largest_version_id,
version_ids=list(existing.version_ids),
in_library_version_ids=normalized_local,
last_checked_at=existing.last_checked_at,
should_ignore=True,
if existing and existing.should_ignore_model and not force_refresh:
record = self._merge_with_local_versions(
existing,
normalized_local,
)
self._upsert_record(record)
return record
version_ids = (
fetched_versions
if refresh_succeeded
else (list(existing.version_ids) if existing else [])
)
largest = max(version_ids) if version_ids else None
last_checked = now if refresh_succeeded else (
existing.last_checked_at if existing else None
)
record = ModelUpdateRecord(
model_type=model_type,
model_id=model_id,
largest_version_id=largest,
version_ids=version_ids,
in_library_version_ids=normalized_local,
last_checked_at=last_checked,
should_ignore=existing.should_ignore if existing else False,
)
if refresh_succeeded and isinstance(fetched_versions, list):
record = self._build_record_from_remote(
model_type,
model_id,
normalized_local,
fetched_versions,
existing,
now,
)
else:
record = self._merge_with_local_versions(
existing,
normalized_local,
model_type=model_type,
model_id=model_id,
last_checked_at=existing.last_checked_at if existing else None,
)
self._upsert_record(record)
return record
@@ -358,6 +437,132 @@ class ModelUpdateService:
return {model_id: sorted(ids) for model_id, ids in mapping.items()}
def _merge_with_local_versions(
    self,
    existing: Optional[ModelUpdateRecord],
    normalized_local: Sequence[int],
    *,
    model_type: Optional[str] = None,
    model_id: Optional[int] = None,
    last_checked_at: Optional[float] = None,
) -> ModelUpdateRecord:
    """Rebuild a record so its in-library flags match ``normalized_local``.

    Existing version metadata (names, previews, ignore flags) is preserved;
    local versions the record has never seen get stub entries. When ``existing``
    is None, ``model_type`` and ``model_id`` must be supplied.

    Raises:
        ValueError: when creating a brand-new record without identity kwargs.
    """
    local_set = set(normalized_local)
    versions: List[ModelVersionRecord] = []
    ignore_map: Dict[int, bool] = {}
    if existing:
        # Identity and bookkeeping come from the stored record when present.
        model_type = existing.model_type
        model_id = existing.model_id
        last_checked_at = existing.last_checked_at if last_checked_at is None else last_checked_at
        ignore_map = {version.version_id: version.should_ignore for version in existing.versions}
        for version in existing.versions:
            versions.append(
                replace(
                    version,
                    is_in_library=version.version_id in local_set,
                )
            )
    elif model_type is None or model_id is None:
        raise ValueError("model_type and model_id are required when creating a new record")
    seen_ids = {version.version_id for version in versions}
    for missing_id in sorted(local_set - seen_ids):
        # Local files the provider never reported still get stub entries.
        versions.append(
            ModelVersionRecord(
                version_id=missing_id,
                name=None,
                base_model=None,
                released_at=None,
                size_bytes=None,
                preview_url=None,
                is_in_library=True,
                should_ignore=ignore_map.get(missing_id, False),
                sort_index=len(versions),
            )
        )
    return ModelUpdateRecord(
        model_type=model_type,
        model_id=model_id,
        versions=self._sorted_versions(versions),
        last_checked_at=last_checked_at,
        should_ignore_model=existing.should_ignore_model if existing else False,
    )
def _build_record_from_remote(
    self,
    model_type: str,
    model_id: int,
    local_versions: Sequence[int],
    remote_versions: Sequence[ModelVersionRecord],
    existing: Optional[ModelUpdateRecord],
    timestamp: float,
) -> ModelUpdateRecord:
    """Combine freshly fetched remote versions with persisted local state.

    Per-version state that must survive a refresh — user ignore flags, stored
    preview URLs and the stable sort order — is carried over from ``existing``.
    Versions present locally but no longer reported remotely are retained.
    ``timestamp`` becomes the record's last_checked_at.
    """
    local_set = set(local_versions)
    ignore_map = {version.version_id: version.should_ignore for version in existing.versions} if existing else {}
    preview_map = {version.version_id: version.preview_url for version in existing.versions} if existing else {}
    sort_map = {version.version_id: version.sort_index for version in existing.versions} if existing else {}
    existing_map = {version.version_id: version for version in existing.versions} if existing else {}
    versions: List[ModelVersionRecord] = []
    seen_ids: set[int] = set()
    for index, remote_version in enumerate(remote_versions):
        version_id = remote_version.version_id
        seen_ids.add(version_id)
        versions.append(
            ModelVersionRecord(
                version_id=version_id,
                name=remote_version.name,
                base_model=remote_version.base_model,
                released_at=remote_version.released_at,
                size_bytes=remote_version.size_bytes,
                # Keep the previously stored preview when the remote omits one.
                preview_url=remote_version.preview_url or preview_map.get(version_id),
                is_in_library=version_id in local_set,
                # A user-set ignore flag always wins over the remote default.
                should_ignore=ignore_map.get(version_id, remote_version.should_ignore),
                sort_index=sort_map.get(version_id, index),
            )
        )
    missing_local = local_set - seen_ids
    if missing_local:
        # Versions on disk that the provider no longer lists.
        for version_id in sorted(missing_local):
            existing_version = existing_map.get(version_id)
            if existing_version:
                versions.append(
                    replace(
                        existing_version,
                        is_in_library=True,
                    )
                )
            else:
                versions.append(
                    ModelVersionRecord(
                        version_id=version_id,
                        name=None,
                        base_model=None,
                        released_at=None,
                        size_bytes=None,
                        preview_url=None,
                        is_in_library=True,
                        should_ignore=ignore_map.get(version_id, False),
                        sort_index=len(versions),
                    )
                )
    return ModelUpdateRecord(
        model_type=model_type,
        model_id=model_id,
        versions=self._sorted_versions(versions),
        last_checked_at=timestamp,
        should_ignore_model=existing.should_ignore_model if existing else False,
    )
def _sorted_versions(self, versions: Sequence[ModelVersionRecord]) -> List[ModelVersionRecord]:
ordered = sorted(versions, key=lambda version: (version.sort_index, version.version_id))
normalized: List[ModelVersionRecord] = []
for index, version in enumerate(ordered):
normalized.append(replace(version, sort_index=index))
return normalized
def _is_stale(self, record: ModelUpdateRecord, now: float) -> bool:
if record.last_checked_at is None:
return True
@@ -380,7 +585,19 @@ class ModelUpdateService:
]
return sorted(dict.fromkeys(normalized))
def _extract_version_ids(self, response) -> Optional[List[int]]:
@staticmethod
def _normalize_string(value) -> Optional[str]:
if value is None:
return None
if isinstance(value, str):
stripped = value.strip()
return stripped or None
try:
return str(value)
except Exception: # pragma: no cover - defensive conversion
return None
def _extract_versions(self, response) -> Optional[List[ModelVersionRecord]]:
if not isinstance(response, Mapping):
return None
versions = response.get("modelVersions")
@@ -388,84 +605,175 @@ class ModelUpdateService:
return []
if not isinstance(versions, Iterable):
return None
normalized = []
for entry in versions:
if isinstance(entry, Mapping):
normalized_id = self._normalize_int(entry.get("id"))
else:
normalized_id = self._normalize_int(entry)
if normalized_id is not None:
normalized.append(normalized_id)
return sorted(dict.fromkeys(normalized))
extracted: List[ModelVersionRecord] = []
for index, entry in enumerate(versions):
if not isinstance(entry, Mapping):
continue
version_id = self._normalize_int(entry.get("id"))
if version_id is None:
continue
name = self._normalize_string(entry.get("name"))
base_model = self._normalize_string(entry.get("baseModel"))
released_at = self._normalize_string(entry.get("publishedAt") or entry.get("createdAt"))
size_bytes = self._extract_size_bytes(entry.get("files"))
preview_url = self._extract_preview_url(entry.get("images"))
extracted.append(
ModelVersionRecord(
version_id=version_id,
name=name,
base_model=base_model,
released_at=released_at,
size_bytes=size_bytes,
preview_url=preview_url,
is_in_library=False,
should_ignore=False,
sort_index=index,
)
)
return extracted
def _extract_size_bytes(self, files) -> Optional[int]:
if not isinstance(files, Iterable):
return None
for entry in files:
if not isinstance(entry, Mapping):
continue
size_kb = entry.get("sizeKB")
if size_kb is None:
continue
try:
size_float = float(size_kb)
except (TypeError, ValueError):
continue
return int(size_float * 1024)
return None
def _extract_preview_url(self, images) -> Optional[str]:
    """Select a preview URL from provider image entries.

    Delegates media selection to select_preview_media (honoring the user's
    blur_mature_content setting, defaulting to True) and passes the chosen URL
    through rewrite_preview_url. Returns None when no usable entry exists.
    """
    if not isinstance(images, Iterable):
        return None
    candidates = [entry for entry in images if isinstance(entry, Mapping)]
    if not candidates:
        return None
    blur_mature_content = True
    # Settings access is best-effort; any failure falls back to blurring.
    settings = getattr(self, "_settings", None)
    if settings is not None and hasattr(settings, "get"):
        try:
            blur_mature_content = bool(settings.get("blur_mature_content", True))
        except Exception:  # pragma: no cover - defensive guard
            blur_mature_content = True
    selected, _ = select_preview_media(candidates, blur_mature_content=blur_mature_content)
    if not selected:
        return None
    url = selected.get("url")
    if not isinstance(url, str) or not url:
        return None
    media_type = selected.get("type")
    if not isinstance(media_type, str):
        media_type = None
    rewritten, _ = rewrite_preview_url(url, media_type)
    # Fall back to the original URL when the rewrite produced nothing.
    return rewritten or url
def _get_record(self, model_type: str, model_id: int) -> Optional[ModelUpdateRecord]:
with self._connect() as conn:
row = conn.execute(
status_row = conn.execute(
"""
SELECT model_type, model_id, largest_version_id, version_ids,
in_library_version_ids, last_checked_at, should_ignore
SELECT model_id, model_type, last_checked_at, should_ignore_model
FROM model_update_status
WHERE model_type = ? AND model_id = ?
WHERE model_id = ?
""",
(model_type, model_id),
(model_id,),
).fetchone()
if not row:
return None
if not status_row:
return None
stored_type = status_row["model_type"]
if stored_type and stored_type != model_type:
logger.debug(
"Model id %s requested as %s but stored as %s", model_id, model_type, stored_type
)
version_rows = conn.execute(
"""
SELECT version_id, sort_index, name, base_model, released_at, size_bytes,
preview_url, is_in_library, should_ignore
FROM model_update_versions
WHERE model_id = ?
ORDER BY sort_index ASC, version_id ASC
""",
(model_id,),
).fetchall()
versions: List[ModelVersionRecord] = []
for row in version_rows:
versions.append(
ModelVersionRecord(
version_id=int(row["version_id"]),
name=row["name"],
base_model=row["base_model"],
released_at=row["released_at"],
size_bytes=self._normalize_int(row["size_bytes"]),
preview_url=row["preview_url"],
is_in_library=bool(row["is_in_library"]),
should_ignore=bool(row["should_ignore"]),
sort_index=self._normalize_int(row["sort_index"]) or 0,
)
)
return ModelUpdateRecord(
model_type=row["model_type"],
model_id=int(row["model_id"]),
largest_version_id=self._normalize_int(row["largest_version_id"]),
version_ids=self._deserialize_json_array(row["version_ids"]),
in_library_version_ids=self._deserialize_json_array(
row["in_library_version_ids"]
),
last_checked_at=row["last_checked_at"],
should_ignore=bool(row["should_ignore"]),
model_type=stored_type or model_type,
model_id=int(status_row["model_id"]),
versions=self._sorted_versions(versions),
last_checked_at=status_row["last_checked_at"],
should_ignore_model=bool(status_row["should_ignore_model"]),
)
def _upsert_record(self, record: ModelUpdateRecord) -> None:
payload = (
record.model_type,
record.model_id,
record.largest_version_id,
json.dumps(record.version_ids),
json.dumps(record.in_library_version_ids),
record.model_type,
record.last_checked_at,
1 if record.should_ignore else 0,
1 if record.should_ignore_model else 0,
)
with self._connect() as conn:
conn.execute(
"""
INSERT INTO model_update_status (
model_type, model_id, largest_version_id, version_ids,
in_library_version_ids, last_checked_at, should_ignore
) VALUES (?, ?, ?, ?, ?, ?, ?)
ON CONFLICT(model_type, model_id) DO UPDATE SET
largest_version_id = excluded.largest_version_id,
version_ids = excluded.version_ids,
in_library_version_ids = excluded.in_library_version_ids,
model_id, model_type, last_checked_at, should_ignore_model
) VALUES (?, ?, ?, ?)
ON CONFLICT(model_id) DO UPDATE SET
model_type = excluded.model_type,
last_checked_at = excluded.last_checked_at,
should_ignore = excluded.should_ignore
should_ignore_model = excluded.should_ignore_model
""",
payload,
)
conn.execute(
"DELETE FROM model_update_versions WHERE model_id = ?",
(record.model_id,),
)
for version in record.versions:
conn.execute(
"""
INSERT INTO model_update_versions (
version_id, model_id, sort_index, name, base_model, released_at,
size_bytes, preview_url, is_in_library, should_ignore
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""",
(
version.version_id,
record.model_id,
version.sort_index,
version.name,
version.base_model,
version.released_at,
version.size_bytes,
version.preview_url,
1 if version.is_in_library else 0,
1 if version.should_ignore else 0,
),
)
conn.commit()
@staticmethod
def _deserialize_json_array(value) -> List[int]:
if not value:
return []
try:
data = json.loads(value)
except (TypeError, json.JSONDecodeError):
return []
if isinstance(data, list):
normalized = []
for entry in data:
try:
normalized.append(int(entry))
except (TypeError, ValueError):
continue
return sorted(dict.fromkeys(normalized))
return []

View File

@@ -21,6 +21,7 @@ from py.services import model_file_service
from py.services.downloader import DownloadProgress
from py.services.metadata_sync_service import MetadataSyncService
from py.services.model_file_service import AutoOrganizeResult
from py.services.model_update_service import ModelVersionRecord
from py.services.service_registry import ServiceRegistry
from py.services.websocket_manager import ws_manager
from py.utils.exif_utils import ExifUtils
@@ -42,11 +43,23 @@ class DummyRoutes(BaseModelRoutes):
class NullUpdateRecord:
model_type: str
model_id: int
largest_version_id: int | None = None
version_ids: list[int] = field(default_factory=list)
in_library_version_ids: list[int] = field(default_factory=list)
versions: list[ModelVersionRecord] = field(default_factory=list)
last_checked_at: float | None = None
should_ignore: bool = False
should_ignore_model: bool = False
@property
def largest_version_id(self) -> int | None:
if not self.versions:
return None
return max(version.version_id for version in self.versions)
@property
def version_ids(self) -> list[int]:
return [version.version_id for version in self.versions]
@property
def in_library_version_ids(self) -> list[int]:
return [version.version_id for version in self.versions if version.is_in_library]
def has_update(self) -> bool:
return False
@@ -60,10 +73,30 @@ class NullModelUpdateService:
return None
async def update_in_library_versions(self, model_type, model_id, version_ids):
return NullUpdateRecord(model_type=model_type, model_id=model_id, in_library_version_ids=list(version_ids))
versions = [
ModelVersionRecord(
version_id=version_id,
name=None,
base_model=None,
released_at=None,
size_bytes=None,
preview_url=None,
is_in_library=True,
should_ignore=False,
)
for version_id in version_ids
]
return NullUpdateRecord(model_type=model_type, model_id=model_id, versions=versions)
async def set_should_ignore(self, model_type, model_id, should_ignore):
return NullUpdateRecord(model_type=model_type, model_id=model_id, should_ignore=should_ignore)
return NullUpdateRecord(
model_type=model_type,
model_id=model_id,
should_ignore_model=should_ignore,
)
async def set_version_should_ignore(self, model_type, model_id, version_id, should_ignore):
return await self.set_should_ignore(model_type, model_id, should_ignore)
async def get_record(self, *args, **kwargs):
return None

View File

@@ -0,0 +1,57 @@
import logging
from types import SimpleNamespace
import pytest
from py.config import config
from py.routes.handlers.model_handlers import ModelUpdateHandler
from py.services.model_update_service import ModelUpdateRecord, ModelVersionRecord
class DummyScanner:
    """Scanner stand-in that serves a pre-built cache object."""

    def __init__(self, cache):
        # The object to hand back from get_cached_data (e.g. a SimpleNamespace
        # carrying a version_index mapping).
        self._cache = cache

    async def get_cached_data(self):
        # Mirrors the async scanner API the handler calls.
        return self._cache
class DummyService:
    """Model-service stand-in exposing only the attributes the handler reads."""

    def __init__(self, cache):
        # Fixed model type; the handler forwards this to the update service.
        self.model_type = "lora"
        # Scanner whose cache feeds preview-override lookups.
        self.scanner = DummyScanner(cache)
@pytest.mark.asyncio
async def test_build_preview_overrides_uses_static_urls():
    """An in-library version with a cached preview maps to a static preview URL."""
    # Cache knows a local preview for version 123.
    cache = SimpleNamespace(version_index={123: {"preview_url": "/tmp/previews/example.png"}})
    service = DummyService(cache)
    handler = ModelUpdateHandler(
        service=service,
        update_service=SimpleNamespace(),
        metadata_provider_selector=lambda *_: None,
        logger=logging.getLogger(__name__),
    )
    record = ModelUpdateRecord(
        model_type="lora",
        model_id=42,
        versions=[
            ModelVersionRecord(
                version_id=123,
                name=None,
                base_model=None,
                released_at=None,
                size_bytes=None,
                preview_url=None,
                is_in_library=True,
                should_ignore=False,
            )
        ],
        last_checked_at=None,
        should_ignore_model=False,
    )
    overrides = await handler._build_preview_overrides(record)
    # The override must be the config-rewritten static URL, not the raw path.
    expected = config.get_preview_static_url("/tmp/previews/example.png")
    assert overrides == {123: expected}

View File

@@ -1,4 +1,3 @@
import asyncio
from types import SimpleNamespace
import pytest
@@ -8,7 +7,7 @@ from py.services.model_update_service import ModelUpdateService
class DummyScanner:
def __init__(self, raw_data):
self._cache = SimpleNamespace(raw_data=raw_data)
self._cache = SimpleNamespace(raw_data=raw_data, version_index={})
async def get_cached_data(self, *args, **kwargs):
return self._cache
@@ -41,7 +40,28 @@ async def test_refresh_persists_versions_and_uses_cache(tmp_path):
{"civitai": {"modelId": 1, "id": 15}},
]
scanner = DummyScanner(raw_data)
provider = DummyProvider({"modelVersions": [{"id": 11}, {"id": 15}]})
provider = DummyProvider(
{
"modelVersions": [
{
"id": 11,
"name": "v1",
"baseModel": "SD15",
"publishedAt": "2024-01-01T00:00:00Z",
"files": [{"sizeKB": 1024}],
"images": [{"url": "https://example.com/1.png"}],
},
{
"id": 15,
"name": "v1.5",
"baseModel": "SD15",
"publishedAt": "2024-02-01T00:00:00Z",
"files": [{"sizeKB": 512}],
"images": [{"url": "https://example.com/2.png"}],
},
]
}
)
await service.refresh_for_model_type("lora", scanner, provider)
record = await service.get_record("lora", 1)
@@ -51,6 +71,8 @@ async def test_refresh_persists_versions_and_uses_cache(tmp_path):
assert record is not None
assert record.version_ids == [11, 15]
assert record.in_library_version_ids == [11, 15]
assert [version.name for version in record.versions] == ["v1", "v1.5"]
assert record.should_ignore_model is False
assert record.has_update() is False
await service.refresh_for_model_type("lora", scanner, provider)
@@ -64,7 +86,14 @@ async def test_refresh_respects_ignore_flag(tmp_path):
service = ModelUpdateService(str(db_path), ttl_seconds=3600)
raw_data = [{"civitai": {"modelId": 2, "id": 21}}]
scanner = DummyScanner(raw_data)
provider = DummyProvider({"modelVersions": [{"id": 21}, {"id": 22}]})
provider = DummyProvider(
{
"modelVersions": [
{"id": 21, "files": [], "images": []},
{"id": 22, "files": [], "images": []},
]
}
)
await service.refresh_for_model_type("lora", scanner, provider)
await service.set_should_ignore("lora", 2, True)
@@ -74,6 +103,9 @@ async def test_refresh_respects_ignore_flag(tmp_path):
await service.refresh_for_model_type("lora", scanner, provider)
assert provider.calls == 0
assert provider.bulk_calls == []
record = await service.get_record("lora", 2)
assert record is not None
assert record.should_ignore_model is True
@pytest.mark.asyncio
@@ -82,7 +114,10 @@ async def test_refresh_falls_back_when_bulk_not_supported(tmp_path):
service = ModelUpdateService(str(db_path), ttl_seconds=3600)
raw_data = [{"civitai": {"modelId": 4, "id": 41}}]
scanner = DummyScanner(raw_data)
provider = DummyProvider({"modelVersions": [{"id": 41}]}, support_bulk=False)
provider = DummyProvider(
{"modelVersions": [{"id": 41, "files": [], "images": []}]},
support_bulk=False,
)
await service.refresh_for_model_type("lora", scanner, provider)
record = await service.get_record("lora", 4)
@@ -117,7 +152,14 @@ async def test_update_in_library_versions_changes_update_state(tmp_path):
service = ModelUpdateService(str(db_path), ttl_seconds=1)
raw_data = [{"civitai": {"modelId": 3, "id": 31}}]
scanner = DummyScanner(raw_data)
provider = DummyProvider({"modelVersions": [{"id": 31}, {"id": 35}]})
provider = DummyProvider(
{
"modelVersions": [
{"id": 31, "files": [], "images": []},
{"id": 35, "files": [], "images": []},
]
}
)
await service.refresh_for_model_type("lora", scanner, provider)
await service.update_in_library_versions("lora", 3, [31])
@@ -130,3 +172,70 @@ async def test_update_in_library_versions_changes_update_state(tmp_path):
record = await service.get_record("lora", 3)
assert record.has_update() is False
@pytest.mark.asyncio
async def test_version_ignore_blocks_update_flag(tmp_path):
    """Ignoring the only missing remote version clears has_update()."""
    db_path = tmp_path / "updates.sqlite"
    service = ModelUpdateService(str(db_path), ttl_seconds=1)
    # Library holds version 51 of model 5; the provider also reports 55.
    raw_data = [{"civitai": {"modelId": 5, "id": 51}}]
    scanner = DummyScanner(raw_data)
    provider = DummyProvider(
        {
            "modelVersions": [
                {"id": 51, "files": [], "images": []},
                {"id": 55, "files": [], "images": []},
            ]
        }
    )
    await service.refresh_for_model_type("lora", scanner, provider)
    record = await service.get_record("lora", 5)
    assert record is not None
    # Version 55 is remote-only, so an update is initially reported.
    assert record.has_update() is True
    await service.set_version_should_ignore("lora", 5, 55, True)
    record = await service.get_record("lora", 5)
    assert record is not None
    # With 55 ignored, no non-ignored missing version remains.
    assert record.has_update() is False
@pytest.mark.asyncio
async def test_refresh_rewrites_remote_preview_urls(tmp_path):
    """Refresh stores the selected image's URL in its rewritten form."""
    db_path = tmp_path / "updates.sqlite"
    service = ModelUpdateService(str(db_path), ttl_seconds=1)
    raw_data = [{"civitai": {"modelId": 7, "id": 71}}]
    scanner = DummyScanner(raw_data)
    provider = DummyProvider(
        {
            "modelVersions": [
                {
                    "id": 71,
                    "files": [],
                    "images": [
                        # Higher nsfwLevel entry — expected to be passed over by
                        # select_preview_media since blur_mature_content defaults on.
                        {
                            "url": "https://image.civitai.com/high/original=true/sample.png",
                            "nsfwLevel": 6,
                            "type": "image",
                        },
                        {
                            "url": "https://image.civitai.com/safe/original=true/preview.png",
                            "nsfwLevel": 1,
                            "type": "image",
                        },
                    ],
                }
            ]
        }
    )
    await service.refresh_for_model_type("lora", scanner, provider)
    record = await service.get_record("lora", 7)
    assert record is not None
    assert record.versions
    preview_url = record.versions[0].preview_url
    # The safe image wins and its path segment is rewritten to the optimized form.
    assert (
        preview_url
        == "https://image.civitai.com/safe/width=450,optimized=true/preview.png"
    )