refactor(routes): limit update endpoints to essentials

This commit is contained in:
pixelpaws
2025-10-15 15:37:35 +08:00
parent 321ff72953
commit ee0d241c75
11 changed files with 749 additions and 8 deletions

View File

@@ -0,0 +1,404 @@
"""Service for tracking remote model version updates."""
from __future__ import annotations
import asyncio
import json
import logging
import os
import sqlite3
import time
from dataclasses import dataclass
from typing import Dict, Iterable, List, Mapping, Optional, Sequence
from .errors import RateLimitError
logger = logging.getLogger(__name__)
@dataclass
class ModelUpdateRecord:
    """One persisted row of remote-version bookkeeping for a single model."""

    model_type: str
    model_id: int
    largest_version_id: Optional[int]
    version_ids: List[int]
    in_library_version_ids: List[int]
    last_checked_at: Optional[float]
    should_ignore: bool

    def has_update(self) -> bool:
        """Return True when remote versions exceed the local library."""
        if self.should_ignore or not self.version_ids:
            return False
        # An update exists when the remote id set is not fully covered locally.
        missing = set(self.version_ids) - set(self.in_library_version_ids)
        return bool(missing)
class ModelUpdateService:
    """Persist and query remote model version metadata.

    Backed by a single SQLite table keyed on (model_type, model_id);
    version-id collections are serialized into JSON text columns.
    """

    # One row per model. `last_checked_at` drives TTL-based re-fetching and
    # `should_ignore` suppresses update notifications for that model.
    _SCHEMA = """
    CREATE TABLE IF NOT EXISTS model_update_status (
        model_type TEXT NOT NULL,
        model_id INTEGER NOT NULL,
        largest_version_id INTEGER,
        version_ids TEXT,
        in_library_version_ids TEXT,
        last_checked_at REAL,
        should_ignore INTEGER DEFAULT 0,
        PRIMARY KEY (model_type, model_id)
    )
    """

    def __init__(self, db_path: str, *, ttl_seconds: int = 24 * 60 * 60) -> None:
        """Create the service and eagerly prepare the backing database.

        Args:
            db_path: Filesystem path of the SQLite database file.
            ttl_seconds: Age after which a cached check is considered stale
                (defaults to 24 hours).
        """
        self._db_path = db_path
        self._ttl_seconds = ttl_seconds
        # Serializes read-modify-write cycles across coroutines.
        self._lock = asyncio.Lock()
        self._schema_initialized = False
        self._ensure_directory()
        self._initialize_schema()
def _ensure_directory(self) -> None:
directory = os.path.dirname(self._db_path)
if directory:
os.makedirs(directory, exist_ok=True)
def _connect(self) -> sqlite3.Connection:
conn = sqlite3.connect(self._db_path, check_same_thread=False)
conn.row_factory = sqlite3.Row
return conn
def _initialize_schema(self) -> None:
if self._schema_initialized:
return
try:
with self._connect() as conn:
conn.execute("PRAGMA journal_mode=WAL")
conn.execute("PRAGMA foreign_keys = ON")
conn.executescript(self._SCHEMA)
self._schema_initialized = True
except Exception as exc: # pragma: no cover - defensive guard
logger.error("Failed to initialize update schema: %s", exc, exc_info=True)
raise
async def refresh_for_model_type(
    self,
    model_type: str,
    scanner,
    metadata_provider,
    *,
    force_refresh: bool = False,
) -> Dict[int, ModelUpdateRecord]:
    """Refresh update information for every model present in the cache.

    Args:
        model_type: Category key used to partition stored records.
        scanner: Object exposing async ``get_cached_data()``; used to
            discover which model/version ids exist locally.
        metadata_provider: Object exposing async
            ``get_model_versions(model_id)``; queried for remote versions.
        force_refresh: When True, bypass the TTL and ignore flags.

    Returns:
        Mapping of model id to its refreshed ``ModelUpdateRecord``.
    """
    local_versions = await self._collect_local_versions(scanner)
    results: Dict[int, ModelUpdateRecord] = {}
    # Refresh sequentially; each call does its own locking and TTL checks.
    for model_id, version_ids in local_versions.items():
        record = await self._refresh_single_model(
            model_type,
            model_id,
            version_ids,
            metadata_provider,
            force_refresh=force_refresh,
        )
        if record:
            results[model_id] = record
    return results
async def refresh_single_model(
    self,
    model_type: str,
    model_id: int,
    scanner,
    metadata_provider,
    *,
    force_refresh: bool = False,
) -> Optional[ModelUpdateRecord]:
    """Refresh update information for a specific model id.

    Looks up the locally-cached version ids for ``model_id`` (empty list
    when the model is not in the library) and delegates to the internal
    single-model refresh.
    """
    local_versions = await self._collect_local_versions(scanner)
    version_ids = local_versions.get(model_id, [])
    return await self._refresh_single_model(
        model_type,
        model_id,
        version_ids,
        metadata_provider,
        force_refresh=force_refresh,
    )
async def update_in_library_versions(
    self,
    model_type: str,
    model_id: int,
    version_ids: Sequence[int],
) -> ModelUpdateRecord:
    """Persist a new set of in-library version identifiers.

    Only ``in_library_version_ids`` is replaced; remote version data, the
    last-checked timestamp and the ignore flag are carried over from any
    existing record. Returns the stored record.
    """
    normalized_versions = self._normalize_sequence(version_ids)
    async with self._lock:
        existing = self._get_record(model_type, model_id)
        record = ModelUpdateRecord(
            model_type=model_type,
            model_id=model_id,
            largest_version_id=existing.largest_version_id if existing else None,
            version_ids=list(existing.version_ids) if existing else [],
            in_library_version_ids=normalized_versions,
            last_checked_at=existing.last_checked_at if existing else None,
            should_ignore=existing.should_ignore if existing else False,
        )
        self._upsert_record(record)
        return record
async def set_should_ignore(
    self, model_type: str, model_id: int, should_ignore: bool
) -> ModelUpdateRecord:
    """Toggle the ignore flag for a model.

    Preserves all stored fields when a record exists; otherwise creates an
    empty record so the flag can be set before the model was ever checked.
    Returns the persisted record.
    """
    async with self._lock:
        existing = self._get_record(model_type, model_id)
        # Single construction path (mirrors update_in_library_versions)
        # instead of two near-identical branches.
        record = ModelUpdateRecord(
            model_type=model_type,
            model_id=model_id,
            largest_version_id=existing.largest_version_id if existing else None,
            version_ids=list(existing.version_ids) if existing else [],
            in_library_version_ids=(
                list(existing.in_library_version_ids) if existing else []
            ),
            last_checked_at=existing.last_checked_at if existing else None,
            should_ignore=should_ignore,
        )
        self._upsert_record(record)
        return record
async def get_record(self, model_type: str, model_id: int) -> Optional[ModelUpdateRecord]:
    """Return a cached record without triggering remote fetches.

    Returns ``None`` when the model has never been stored. The lock keeps
    the read consistent with concurrent refresh/write cycles.
    """
    async with self._lock:
        return self._get_record(model_type, model_id)
async def has_update(self, model_type: str, model_id: int) -> bool:
    """Report whether the stored record indicates a pending update."""
    cached = await self.get_record(model_type, model_id)
    if cached is None:
        # Never checked / never stored: nothing to report.
        return False
    return cached.has_update()
async def _refresh_single_model(
    self,
    model_type: str,
    model_id: int,
    local_versions: Sequence[int],
    metadata_provider,
    *,
    force_refresh: bool = False,
) -> Optional[ModelUpdateRecord]:
    """Refresh one model's record, fetching remote versions when stale.

    Lock discipline: the fetch decision is made under the lock, the (slow)
    provider call runs with the lock released, and the result is persisted
    under the lock again after re-reading state that may have changed.

    Raises:
        RateLimitError: propagated from the metadata provider untouched.
    """
    normalized_local = self._normalize_sequence(local_versions)
    now = time.time()
    async with self._lock:
        existing = self._get_record(model_type, model_id)
        if existing and existing.should_ignore and not force_refresh:
            # Ignored model: skip fetching, but keep the stored local
            # version list in sync with what the caller reported.
            record = ModelUpdateRecord(
                model_type=model_type,
                model_id=model_id,
                largest_version_id=existing.largest_version_id,
                version_ids=list(existing.version_ids),
                in_library_version_ids=normalized_local,
                last_checked_at=existing.last_checked_at,
                should_ignore=True,
            )
            self._upsert_record(record)
            return record
        should_fetch = force_refresh or not existing or self._is_stale(existing, now)
    # release lock during network request
    fetched_versions: List[int] | None = None
    if metadata_provider and should_fetch:
        try:
            response = await metadata_provider.get_model_versions(model_id)
        except RateLimitError:
            # Let rate limiting bubble up so callers can back off.
            raise
        except Exception as exc:  # pragma: no cover - defensive log
            # Any other provider failure degrades to "no new data":
            # fetched_versions stays None and stored versions are reused.
            logger.error(
                "Failed to fetch versions for model %s (%s): %s",
                model_id,
                model_type,
                exc,
                exc_info=True,
            )
        else:
            fetched_versions = self._extract_version_ids(response)
    async with self._lock:
        existing = self._get_record(model_type, model_id)
        if existing and existing.should_ignore and not force_refresh:
            # Ignore state could have flipped while awaiting provider
            record = ModelUpdateRecord(
                model_type=model_type,
                model_id=model_id,
                largest_version_id=existing.largest_version_id,
                version_ids=list(existing.version_ids),
                in_library_version_ids=normalized_local,
                last_checked_at=existing.last_checked_at,
                should_ignore=True,
            )
            self._upsert_record(record)
            return record
        # Fall back to previously stored versions when no fetch happened
        # (or the fetch failed).
        version_ids = (
            fetched_versions
            if fetched_versions is not None
            else (list(existing.version_ids) if existing else [])
        )
        largest = max(version_ids) if version_ids else None
        # Advance the timestamp only when fresh data actually arrived.
        last_checked = now if fetched_versions is not None else (
            existing.last_checked_at if existing else None
        )
        record = ModelUpdateRecord(
            model_type=model_type,
            model_id=model_id,
            largest_version_id=largest,
            version_ids=version_ids,
            in_library_version_ids=normalized_local,
            last_checked_at=last_checked,
            should_ignore=existing.should_ignore if existing else False,
        )
        self._upsert_record(record)
        return record
async def _collect_local_versions(self, scanner) -> Dict[int, List[int]]:
cache = await scanner.get_cached_data()
mapping: Dict[int, set[int]] = {}
if not cache or not getattr(cache, "raw_data", None):
return {}
for item in cache.raw_data:
civitai = item.get("civitai") if isinstance(item, dict) else None
if not isinstance(civitai, dict):
continue
model_id = self._normalize_int(civitai.get("modelId"))
version_id = self._normalize_int(civitai.get("id"))
if model_id is None or version_id is None:
continue
mapping.setdefault(model_id, set()).add(version_id)
return {model_id: sorted(ids) for model_id, ids in mapping.items()}
def _is_stale(self, record: ModelUpdateRecord, now: float) -> bool:
if record.last_checked_at is None:
return True
return (now - record.last_checked_at) >= self._ttl_seconds
@staticmethod
def _normalize_int(value) -> Optional[int]:
try:
if value is None:
return None
return int(value)
except (TypeError, ValueError):
return None
def _normalize_sequence(self, values: Sequence[int]) -> List[int]:
normalized = [
item
for item in (self._normalize_int(value) for value in values)
if item is not None
]
return sorted(dict.fromkeys(normalized))
def _extract_version_ids(self, response) -> List[int]:
if not isinstance(response, Mapping):
return []
versions = response.get("modelVersions")
if not isinstance(versions, Iterable):
return []
normalized = []
for entry in versions:
if isinstance(entry, Mapping):
normalized_id = self._normalize_int(entry.get("id"))
else:
normalized_id = self._normalize_int(entry)
if normalized_id is not None:
normalized.append(normalized_id)
return sorted(dict.fromkeys(normalized))
def _get_record(self, model_type: str, model_id: int) -> Optional[ModelUpdateRecord]:
with self._connect() as conn:
row = conn.execute(
"""
SELECT model_type, model_id, largest_version_id, version_ids,
in_library_version_ids, last_checked_at, should_ignore
FROM model_update_status
WHERE model_type = ? AND model_id = ?
""",
(model_type, model_id),
).fetchone()
if not row:
return None
return ModelUpdateRecord(
model_type=row["model_type"],
model_id=int(row["model_id"]),
largest_version_id=self._normalize_int(row["largest_version_id"]),
version_ids=self._deserialize_json_array(row["version_ids"]),
in_library_version_ids=self._deserialize_json_array(
row["in_library_version_ids"]
),
last_checked_at=row["last_checked_at"],
should_ignore=bool(row["should_ignore"]),
)
def _upsert_record(self, record: ModelUpdateRecord) -> None:
payload = (
record.model_type,
record.model_id,
record.largest_version_id,
json.dumps(record.version_ids),
json.dumps(record.in_library_version_ids),
record.last_checked_at,
1 if record.should_ignore else 0,
)
with self._connect() as conn:
conn.execute(
"""
INSERT INTO model_update_status (
model_type, model_id, largest_version_id, version_ids,
in_library_version_ids, last_checked_at, should_ignore
) VALUES (?, ?, ?, ?, ?, ?, ?)
ON CONFLICT(model_type, model_id) DO UPDATE SET
largest_version_id = excluded.largest_version_id,
version_ids = excluded.version_ids,
in_library_version_ids = excluded.in_library_version_ids,
last_checked_at = excluded.last_checked_at,
should_ignore = excluded.should_ignore
""",
payload,
)
conn.commit()
@staticmethod
def _deserialize_json_array(value) -> List[int]:
if not value:
return []
try:
data = json.loads(value)
except (TypeError, json.JSONDecodeError):
return []
if isinstance(data, list):
normalized = []
for entry in data:
try:
normalized.append(int(entry))
except (TypeError, ValueError):
continue
return sorted(dict.fromkeys(normalized))
return []

View File

@@ -81,6 +81,11 @@ class PersistentModelCache:
def is_enabled(self) -> bool:
    """Report whether the persistent cache is not disabled via environment."""
    disabled = os.environ.get("LORA_MANAGER_DISABLE_PERSISTENT_CACHE", "0")
    return disabled != "1"
def get_database_path(self) -> str:
    """Expose the resolved SQLite database path.

    Lets other services (e.g. the model update service) share this cache's
    database file instead of resolving their own path.
    """
    return self._db_path
def load_cache(self, model_type: str) -> Optional[PersistedCacheData]:
if not self.is_enabled():
return None

View File

@@ -128,23 +128,45 @@ class ServiceRegistry:
async def get_civitai_client(cls):
    """Get or create the shared CivitAI client instance.

    Uses double-checked locking: a lock-free fast path for the common
    already-registered case, then a re-check under the per-service lock.
    """
    service_name = "civitai_client"
    if service_name in cls._services:
        return cls._services[service_name]
    async with cls._get_lock(service_name):
        # Double-check after acquiring lock
        if service_name in cls._services:
            return cls._services[service_name]
        # Import here to avoid circular imports
        from .civitai_client import CivitaiClient
        client = await CivitaiClient.get_instance()
        cls._services[service_name] = client
        # Lazy %-args: message is only built when DEBUG logging is active.
        logger.debug("Created and registered %s", service_name)
        return client
@classmethod
async def get_model_update_service(cls):
    """Get or create the model update tracking service.

    Lazily constructs a ModelUpdateService bound to the persistent cache's
    SQLite database, using the registry's double-checked locking pattern so
    concurrent callers end up sharing a single instance.
    """
    service_name = "model_update_service"
    if service_name in cls._services:
        return cls._services[service_name]
    async with cls._get_lock(service_name):
        if service_name in cls._services:
            return cls._services[service_name]
        # Imported here to avoid circular imports at module load time.
        from .model_update_service import ModelUpdateService
        from .persistent_model_cache import get_persistent_cache
        cache = get_persistent_cache()
        service = ModelUpdateService(cache.get_database_path())
        cls._services[service_name] = service
        # Lazy %-args: message is only built when DEBUG logging is active.
        logger.debug("Created and registered %s", service_name)
        return service
@classmethod
async def get_civarchive_client(cls):
"""Get or create CivArchive client instance"""