refactor(routes): limit update endpoints to essentials

This commit is contained in:
pixelpaws
2025-10-15 15:37:35 +08:00
parent 321ff72953
commit ee0d241c75
11 changed files with 749 additions and 8 deletions

View File

@@ -2,7 +2,7 @@ from __future__ import annotations
import logging
from abc import ABC, abstractmethod
from typing import Callable, Dict, Mapping
from typing import TYPE_CHECKING, Callable, Dict, Mapping
import jinja2
from aiohttp import web
@@ -42,8 +42,12 @@ from .handlers.model_handlers import (
ModelMoveHandler,
ModelPageView,
ModelQueryHandler,
ModelUpdateHandler,
)
if TYPE_CHECKING:
from ..services.model_update_service import ModelUpdateService
logger = logging.getLogger(__name__)
@@ -99,10 +103,18 @@ class BaseModelRoutes(ABC):
ws_manager=self._ws_manager,
download_manager_factory=ServiceRegistry.get_download_manager,
)
self._model_update_service: ModelUpdateService | None = None
if service is not None:
self.attach_service(service)
def set_model_update_service(self, service: "ModelUpdateService") -> None:
"""Attach the model update tracking service."""
self._model_update_service = service
self._handler_set = None
self._handler_mapping = None
def attach_service(self, service) -> None:
"""Attach a model service and rebuild handler dependencies."""
self.service = service
@@ -127,6 +139,7 @@ class BaseModelRoutes(ABC):
def _create_handler_set(self) -> ModelHandlerSet:
service = self._ensure_service()
update_service = self._ensure_model_update_service()
page_view = ModelPageView(
template_env=self.template_env,
template_name=self.template_name or "",
@@ -186,6 +199,12 @@ class BaseModelRoutes(ABC):
ws_manager=self._ws_manager,
logger=logger,
)
updates = ModelUpdateHandler(
service=service,
update_service=update_service,
metadata_provider_selector=get_metadata_provider,
logger=logger,
)
return ModelHandlerSet(
page_view=page_view,
listing=listing,
@@ -195,6 +214,7 @@ class BaseModelRoutes(ABC):
civitai=civitai,
move=move,
auto_organize=auto_organize,
updates=updates,
)
@property
@@ -273,3 +293,8 @@ class BaseModelRoutes(ABC):
return proxy
def _ensure_model_update_service(self) -> "ModelUpdateService":
if self._model_update_service is None:
raise RuntimeError("Model update service has not been attached")
return self._model_update_service

View File

@@ -21,7 +21,9 @@ class CheckpointRoutes(BaseModelRoutes):
"""Initialize services from ServiceRegistry"""
checkpoint_scanner = await ServiceRegistry.get_checkpoint_scanner()
self.service = CheckpointService(checkpoint_scanner)
update_service = await ServiceRegistry.get_model_update_service()
self.set_model_update_service(update_service)
# Attach service dependencies
self.attach_service(self.service)

View File

@@ -20,7 +20,9 @@ class EmbeddingRoutes(BaseModelRoutes):
"""Initialize services from ServiceRegistry"""
embedding_scanner = await ServiceRegistry.get_embedding_scanner()
self.service = EmbeddingService(embedding_scanner)
update_service = await ServiceRegistry.get_model_update_service()
self.set_model_update_service(update_service)
# Attach service dependencies
self.attach_service(self.service)

View File

@@ -29,6 +29,7 @@ from ...services.use_cases import (
)
from ...services.websocket_manager import WebSocketManager
from ...services.websocket_progress_callback import WebSocketProgressCallback
from ...services.errors import RateLimitError
from ...utils.file_utils import calculate_sha256
from ...utils.metadata_manager import MetadataManager
@@ -1017,6 +1018,156 @@ class ModelAutoOrganizeHandler:
return web.json_response({"success": False, "error": str(exc)}, status=500)
class ModelUpdateHandler:
"""Handle update tracking requests."""
def __init__(
self,
*,
service,
update_service,
metadata_provider_selector,
logger: logging.Logger,
) -> None:
self._service = service
self._update_service = update_service
self._metadata_provider_selector = metadata_provider_selector
self._logger = logger
async def refresh_model_updates(self, request: web.Request) -> web.Response:
payload = await self._read_json(request)
force_refresh = self._parse_bool(request.query.get("force")) or self._parse_bool(
payload.get("force")
)
provider = await self._get_civitai_provider()
if provider is None:
return web.json_response(
{"success": False, "error": "Civitai provider not available"}, status=503
)
try:
records = await self._update_service.refresh_for_model_type(
self._service.model_type,
self._service.scanner,
provider,
force_refresh=force_refresh,
)
except RateLimitError as exc:
return web.json_response(
{"success": False, "error": str(exc) or "Rate limited"}, status=429
)
except Exception as exc: # pragma: no cover - defensive logging
self._logger.error("Failed to refresh model updates: %s", exc, exc_info=True)
return web.json_response({"success": False, "error": str(exc)}, status=500)
return web.json_response(
{
"success": True,
"records": [self._serialize_record(record) for record in records.values()],
}
)
async def set_model_update_ignore(self, request: web.Request) -> web.Response:
payload = await self._read_json(request)
model_id = self._normalize_model_id(payload.get("modelId"))
if model_id is None:
return web.json_response({"success": False, "error": "modelId is required"}, status=400)
should_ignore = self._parse_bool(payload.get("shouldIgnore"))
record = await self._update_service.set_should_ignore(
self._service.model_type, model_id, should_ignore
)
return web.json_response({"success": True, "record": self._serialize_record(record)})
async def get_model_update_status(self, request: web.Request) -> web.Response:
model_id = self._normalize_model_id(request.match_info.get("model_id"))
if model_id is None:
return web.json_response({"success": False, "error": "model_id must be an integer"}, status=400)
refresh = self._parse_bool(request.query.get("refresh"))
force = self._parse_bool(request.query.get("force"))
try:
record = await self._get_or_refresh_record(model_id, refresh=refresh, force=force)
except RateLimitError as exc:
return web.json_response(
{"success": False, "error": str(exc) or "Rate limited"}, status=429
)
if record is None:
return web.json_response(
{"success": False, "error": "Model not tracked"}, status=404
)
return web.json_response({"success": True, "record": self._serialize_record(record)})
async def _get_or_refresh_record(
self, model_id: int, *, refresh: bool, force: bool
) -> Optional[object]:
record = await self._update_service.get_record(self._service.model_type, model_id)
if record and not refresh and not force:
return record
provider = await self._get_civitai_provider()
if provider is None:
return record
return await self._update_service.refresh_single_model(
self._service.model_type,
model_id,
self._service.scanner,
provider,
force_refresh=force or refresh,
)
async def _get_civitai_provider(self):
try:
return await self._metadata_provider_selector("civitai_api")
except Exception as exc: # pragma: no cover - defensive log
self._logger.error("Failed to acquire civitai provider: %s", exc, exc_info=True)
return None
async def _read_json(self, request: web.Request) -> Dict:
if not request.can_read_body:
return {}
try:
return await request.json()
except Exception:
return {}
@staticmethod
def _parse_bool(value) -> bool:
if isinstance(value, bool):
return value
if isinstance(value, str):
return value.lower() in {"1", "true", "yes"}
if isinstance(value, (int, float)):
return bool(value)
return False
@staticmethod
def _normalize_model_id(value) -> Optional[int]:
try:
if value is None:
return None
return int(value)
except (TypeError, ValueError):
return None
@staticmethod
def _serialize_record(record) -> Dict:
return {
"modelType": record.model_type,
"modelId": record.model_id,
"largestVersionId": record.largest_version_id,
"versionIds": record.version_ids,
"inLibraryVersionIds": record.in_library_version_ids,
"lastCheckedAt": record.last_checked_at,
"shouldIgnore": record.should_ignore,
"hasUpdate": record.has_update(),
}
@dataclass
class ModelHandlerSet:
"""Aggregate concrete handlers into a flat mapping."""
@@ -1029,6 +1180,7 @@ class ModelHandlerSet:
civitai: ModelCivitaiHandler
move: ModelMoveHandler
auto_organize: ModelAutoOrganizeHandler
updates: ModelUpdateHandler
def to_route_mapping(self) -> Dict[str, Callable[[web.Request], Awaitable[web.Response]]]:
return {
@@ -1073,5 +1225,8 @@ class ModelHandlerSet:
"get_model_metadata": self.query.get_model_metadata,
"get_model_description": self.query.get_model_description,
"get_relative_paths": self.query.get_relative_paths,
"refresh_model_updates": self.updates.refresh_model_updates,
"set_model_update_ignore": self.updates.set_model_update_ignore,
"get_model_update_status": self.updates.get_model_update_status,
}

View File

@@ -24,7 +24,9 @@ class LoraRoutes(BaseModelRoutes):
"""Initialize services from ServiceRegistry"""
lora_scanner = await ServiceRegistry.get_lora_scanner()
self.service = LoraService(lora_scanner)
update_service = await ServiceRegistry.get_model_update_service()
self.set_model_update_service(update_service)
# Attach service dependencies
self.attach_service(self.service)

View File

@@ -55,6 +55,9 @@ COMMON_ROUTE_DEFINITIONS: tuple[RouteDefinition, ...] = (
RouteDefinition("GET", "/api/lm/{prefix}/civitai/versions/{model_id}", "get_civitai_versions"),
RouteDefinition("GET", "/api/lm/{prefix}/civitai/model/version/{modelVersionId}", "get_civitai_model_by_version"),
RouteDefinition("GET", "/api/lm/{prefix}/civitai/model/hash/{hash}", "get_civitai_model_by_hash"),
RouteDefinition("POST", "/api/lm/{prefix}/updates/refresh", "refresh_model_updates"),
RouteDefinition("POST", "/api/lm/{prefix}/updates/ignore", "set_model_update_ignore"),
RouteDefinition("GET", "/api/lm/{prefix}/updates/status/{model_id}", "get_model_update_status"),
RouteDefinition("POST", "/api/lm/download-model", "download_model"),
RouteDefinition("GET", "/api/lm/download-model-get", "download_model_get"),
RouteDefinition("GET", "/api/lm/cancel-download-get", "cancel_download_get"),

View File

@@ -0,0 +1,404 @@
"""Service for tracking remote model version updates."""
from __future__ import annotations
import asyncio
import json
import logging
import os
import sqlite3
import time
from dataclasses import dataclass
from typing import Dict, Iterable, List, Mapping, Optional, Sequence
from .errors import RateLimitError
logger = logging.getLogger(__name__)
@dataclass
class ModelUpdateRecord:
"""Representation of a persisted update record."""
model_type: str
model_id: int
largest_version_id: Optional[int]
version_ids: List[int]
in_library_version_ids: List[int]
last_checked_at: Optional[float]
should_ignore: bool
def has_update(self) -> bool:
"""Return True when remote versions exceed the local library."""
if self.should_ignore or not self.version_ids:
return False
local_versions = set(self.in_library_version_ids)
return any(version_id not in local_versions for version_id in self.version_ids)
class ModelUpdateService:
"""Persist and query remote model version metadata."""
_SCHEMA = """
CREATE TABLE IF NOT EXISTS model_update_status (
model_type TEXT NOT NULL,
model_id INTEGER NOT NULL,
largest_version_id INTEGER,
version_ids TEXT,
in_library_version_ids TEXT,
last_checked_at REAL,
should_ignore INTEGER DEFAULT 0,
PRIMARY KEY (model_type, model_id)
)
"""
def __init__(self, db_path: str, *, ttl_seconds: int = 24 * 60 * 60) -> None:
self._db_path = db_path
self._ttl_seconds = ttl_seconds
self._lock = asyncio.Lock()
self._schema_initialized = False
self._ensure_directory()
self._initialize_schema()
def _ensure_directory(self) -> None:
directory = os.path.dirname(self._db_path)
if directory:
os.makedirs(directory, exist_ok=True)
def _connect(self) -> sqlite3.Connection:
conn = sqlite3.connect(self._db_path, check_same_thread=False)
conn.row_factory = sqlite3.Row
return conn
def _initialize_schema(self) -> None:
if self._schema_initialized:
return
try:
with self._connect() as conn:
conn.execute("PRAGMA journal_mode=WAL")
conn.execute("PRAGMA foreign_keys = ON")
conn.executescript(self._SCHEMA)
self._schema_initialized = True
except Exception as exc: # pragma: no cover - defensive guard
logger.error("Failed to initialize update schema: %s", exc, exc_info=True)
raise
async def refresh_for_model_type(
self,
model_type: str,
scanner,
metadata_provider,
*,
force_refresh: bool = False,
) -> Dict[int, ModelUpdateRecord]:
"""Refresh update information for every model present in the cache."""
local_versions = await self._collect_local_versions(scanner)
results: Dict[int, ModelUpdateRecord] = {}
for model_id, version_ids in local_versions.items():
record = await self._refresh_single_model(
model_type,
model_id,
version_ids,
metadata_provider,
force_refresh=force_refresh,
)
if record:
results[model_id] = record
return results
async def refresh_single_model(
self,
model_type: str,
model_id: int,
scanner,
metadata_provider,
*,
force_refresh: bool = False,
) -> Optional[ModelUpdateRecord]:
"""Refresh update information for a specific model id."""
local_versions = await self._collect_local_versions(scanner)
version_ids = local_versions.get(model_id, [])
return await self._refresh_single_model(
model_type,
model_id,
version_ids,
metadata_provider,
force_refresh=force_refresh,
)
async def update_in_library_versions(
self,
model_type: str,
model_id: int,
version_ids: Sequence[int],
) -> ModelUpdateRecord:
"""Persist a new set of in-library version identifiers."""
normalized_versions = self._normalize_sequence(version_ids)
async with self._lock:
existing = self._get_record(model_type, model_id)
record = ModelUpdateRecord(
model_type=model_type,
model_id=model_id,
largest_version_id=existing.largest_version_id if existing else None,
version_ids=list(existing.version_ids) if existing else [],
in_library_version_ids=normalized_versions,
last_checked_at=existing.last_checked_at if existing else None,
should_ignore=existing.should_ignore if existing else False,
)
self._upsert_record(record)
return record
async def set_should_ignore(
self, model_type: str, model_id: int, should_ignore: bool
) -> ModelUpdateRecord:
"""Toggle the ignore flag for a model."""
async with self._lock:
existing = self._get_record(model_type, model_id)
if existing:
record = ModelUpdateRecord(
model_type=model_type,
model_id=model_id,
largest_version_id=existing.largest_version_id,
version_ids=list(existing.version_ids),
in_library_version_ids=list(existing.in_library_version_ids),
last_checked_at=existing.last_checked_at,
should_ignore=should_ignore,
)
else:
record = ModelUpdateRecord(
model_type=model_type,
model_id=model_id,
largest_version_id=None,
version_ids=[],
in_library_version_ids=[],
last_checked_at=None,
should_ignore=should_ignore,
)
self._upsert_record(record)
return record
async def get_record(self, model_type: str, model_id: int) -> Optional[ModelUpdateRecord]:
"""Return a cached record without triggering remote fetches."""
async with self._lock:
return self._get_record(model_type, model_id)
async def has_update(self, model_type: str, model_id: int) -> bool:
"""Determine if a model has updates pending."""
record = await self.get_record(model_type, model_id)
return record.has_update() if record else False
async def _refresh_single_model(
self,
model_type: str,
model_id: int,
local_versions: Sequence[int],
metadata_provider,
*,
force_refresh: bool = False,
) -> Optional[ModelUpdateRecord]:
normalized_local = self._normalize_sequence(local_versions)
now = time.time()
async with self._lock:
existing = self._get_record(model_type, model_id)
if existing and existing.should_ignore and not force_refresh:
record = ModelUpdateRecord(
model_type=model_type,
model_id=model_id,
largest_version_id=existing.largest_version_id,
version_ids=list(existing.version_ids),
in_library_version_ids=normalized_local,
last_checked_at=existing.last_checked_at,
should_ignore=True,
)
self._upsert_record(record)
return record
should_fetch = force_refresh or not existing or self._is_stale(existing, now)
# release lock during network request
fetched_versions: List[int] | None = None
if metadata_provider and should_fetch:
try:
response = await metadata_provider.get_model_versions(model_id)
except RateLimitError:
raise
except Exception as exc: # pragma: no cover - defensive log
logger.error(
"Failed to fetch versions for model %s (%s): %s",
model_id,
model_type,
exc,
exc_info=True,
)
else:
fetched_versions = self._extract_version_ids(response)
async with self._lock:
existing = self._get_record(model_type, model_id)
if existing and existing.should_ignore and not force_refresh:
# Ignore state could have flipped while awaiting provider
record = ModelUpdateRecord(
model_type=model_type,
model_id=model_id,
largest_version_id=existing.largest_version_id,
version_ids=list(existing.version_ids),
in_library_version_ids=normalized_local,
last_checked_at=existing.last_checked_at,
should_ignore=True,
)
self._upsert_record(record)
return record
version_ids = (
fetched_versions
if fetched_versions is not None
else (list(existing.version_ids) if existing else [])
)
largest = max(version_ids) if version_ids else None
last_checked = now if fetched_versions is not None else (
existing.last_checked_at if existing else None
)
record = ModelUpdateRecord(
model_type=model_type,
model_id=model_id,
largest_version_id=largest,
version_ids=version_ids,
in_library_version_ids=normalized_local,
last_checked_at=last_checked,
should_ignore=existing.should_ignore if existing else False,
)
self._upsert_record(record)
return record
async def _collect_local_versions(self, scanner) -> Dict[int, List[int]]:
cache = await scanner.get_cached_data()
mapping: Dict[int, set[int]] = {}
if not cache or not getattr(cache, "raw_data", None):
return {}
for item in cache.raw_data:
civitai = item.get("civitai") if isinstance(item, dict) else None
if not isinstance(civitai, dict):
continue
model_id = self._normalize_int(civitai.get("modelId"))
version_id = self._normalize_int(civitai.get("id"))
if model_id is None or version_id is None:
continue
mapping.setdefault(model_id, set()).add(version_id)
return {model_id: sorted(ids) for model_id, ids in mapping.items()}
def _is_stale(self, record: ModelUpdateRecord, now: float) -> bool:
if record.last_checked_at is None:
return True
return (now - record.last_checked_at) >= self._ttl_seconds
@staticmethod
def _normalize_int(value) -> Optional[int]:
try:
if value is None:
return None
return int(value)
except (TypeError, ValueError):
return None
def _normalize_sequence(self, values: Sequence[int]) -> List[int]:
normalized = [
item
for item in (self._normalize_int(value) for value in values)
if item is not None
]
return sorted(dict.fromkeys(normalized))
def _extract_version_ids(self, response) -> List[int]:
if not isinstance(response, Mapping):
return []
versions = response.get("modelVersions")
if not isinstance(versions, Iterable):
return []
normalized = []
for entry in versions:
if isinstance(entry, Mapping):
normalized_id = self._normalize_int(entry.get("id"))
else:
normalized_id = self._normalize_int(entry)
if normalized_id is not None:
normalized.append(normalized_id)
return sorted(dict.fromkeys(normalized))
def _get_record(self, model_type: str, model_id: int) -> Optional[ModelUpdateRecord]:
with self._connect() as conn:
row = conn.execute(
"""
SELECT model_type, model_id, largest_version_id, version_ids,
in_library_version_ids, last_checked_at, should_ignore
FROM model_update_status
WHERE model_type = ? AND model_id = ?
""",
(model_type, model_id),
).fetchone()
if not row:
return None
return ModelUpdateRecord(
model_type=row["model_type"],
model_id=int(row["model_id"]),
largest_version_id=self._normalize_int(row["largest_version_id"]),
version_ids=self._deserialize_json_array(row["version_ids"]),
in_library_version_ids=self._deserialize_json_array(
row["in_library_version_ids"]
),
last_checked_at=row["last_checked_at"],
should_ignore=bool(row["should_ignore"]),
)
def _upsert_record(self, record: ModelUpdateRecord) -> None:
payload = (
record.model_type,
record.model_id,
record.largest_version_id,
json.dumps(record.version_ids),
json.dumps(record.in_library_version_ids),
record.last_checked_at,
1 if record.should_ignore else 0,
)
with self._connect() as conn:
conn.execute(
"""
INSERT INTO model_update_status (
model_type, model_id, largest_version_id, version_ids,
in_library_version_ids, last_checked_at, should_ignore
) VALUES (?, ?, ?, ?, ?, ?, ?)
ON CONFLICT(model_type, model_id) DO UPDATE SET
largest_version_id = excluded.largest_version_id,
version_ids = excluded.version_ids,
in_library_version_ids = excluded.in_library_version_ids,
last_checked_at = excluded.last_checked_at,
should_ignore = excluded.should_ignore
""",
payload,
)
conn.commit()
@staticmethod
def _deserialize_json_array(value) -> List[int]:
if not value:
return []
try:
data = json.loads(value)
except (TypeError, json.JSONDecodeError):
return []
if isinstance(data, list):
normalized = []
for entry in data:
try:
normalized.append(int(entry))
except (TypeError, ValueError):
continue
return sorted(dict.fromkeys(normalized))
return []

View File

@@ -81,6 +81,11 @@ class PersistentModelCache:
def is_enabled(self) -> bool:
return os.environ.get("LORA_MANAGER_DISABLE_PERSISTENT_CACHE", "0") != "1"
def get_database_path(self) -> str:
"""Expose the resolved SQLite database path."""
return self._db_path
def load_cache(self, model_type: str) -> Optional[PersistedCacheData]:
if not self.is_enabled():
return None

View File

@@ -128,23 +128,45 @@ class ServiceRegistry:
async def get_civitai_client(cls):
"""Get or create CivitAI client instance"""
service_name = "civitai_client"
if service_name in cls._services:
return cls._services[service_name]
async with cls._get_lock(service_name):
# Double-check after acquiring lock
if service_name in cls._services:
return cls._services[service_name]
# Import here to avoid circular imports
from .civitai_client import CivitaiClient
client = await CivitaiClient.get_instance()
cls._services[service_name] = client
logger.debug(f"Created and registered {service_name}")
return client
@classmethod
async def get_model_update_service(cls):
"""Get or create the model update tracking service."""
service_name = "model_update_service"
if service_name in cls._services:
return cls._services[service_name]
async with cls._get_lock(service_name):
if service_name in cls._services:
return cls._services[service_name]
from .model_update_service import ModelUpdateService
from .persistent_model_cache import get_persistent_cache
cache = get_persistent_cache()
service = ModelUpdateService(cache.get_database_path())
cls._services[service_name] = service
logger.debug(f"Created and registered {service_name}")
return service
@classmethod
async def get_civarchive_client(cls):
"""Get or create CivArchive client instance"""

View File

@@ -5,6 +5,7 @@ import sys
from pathlib import Path
import types
from dataclasses import dataclass, field
from typing import Optional
folder_paths_stub = types.SimpleNamespace(get_folder_paths=lambda *_: [])
@@ -32,6 +33,41 @@ class DummyRoutes(BaseModelRoutes):
def setup_specific_routes(self, registrar, prefix: str) -> None: # pragma: no cover - no extra routes in smoke tests
return None
def __init__(self, service=None):
super().__init__(service)
self.set_model_update_service(NullModelUpdateService())
@dataclass
class NullUpdateRecord:
model_type: str
model_id: int
largest_version_id: int | None = None
version_ids: list[int] = field(default_factory=list)
in_library_version_ids: list[int] = field(default_factory=list)
last_checked_at: float | None = None
should_ignore: bool = False
def has_update(self) -> bool:
return False
class NullModelUpdateService:
async def refresh_for_model_type(self, *args, **kwargs):
return {}
async def refresh_single_model(self, *args, **kwargs):
return None
async def update_in_library_versions(self, model_type, model_id, version_ids):
return NullUpdateRecord(model_type=model_type, model_id=model_id, in_library_version_ids=list(version_ids))
async def set_should_ignore(self, model_type, model_id, should_ignore):
return NullUpdateRecord(model_type=model_type, model_id=model_id, should_ignore=should_ignore)
async def get_record(self, *args, **kwargs):
return None
async def create_test_client(service) -> TestClient:
routes = DummyRoutes(service)

View File

@@ -0,0 +1,85 @@
import asyncio
from types import SimpleNamespace
import pytest
from py.services.model_update_service import ModelUpdateService
class DummyScanner:
def __init__(self, raw_data):
self._cache = SimpleNamespace(raw_data=raw_data)
async def get_cached_data(self, *args, **kwargs):
return self._cache
class DummyProvider:
def __init__(self, response):
self.response = response
self.calls: int = 0
async def get_model_versions(self, model_id):
self.calls += 1
return self.response
@pytest.mark.asyncio
async def test_refresh_persists_versions_and_uses_cache(tmp_path):
db_path = tmp_path / "updates.sqlite"
service = ModelUpdateService(str(db_path), ttl_seconds=3600)
raw_data = [
{"civitai": {"modelId": 1, "id": 11}},
{"civitai": {"modelId": 1, "id": 15}},
]
scanner = DummyScanner(raw_data)
provider = DummyProvider({"modelVersions": [{"id": 11}, {"id": 15}]})
await service.refresh_for_model_type("lora", scanner, provider)
record = await service.get_record("lora", 1)
assert provider.calls == 1
assert record is not None
assert record.version_ids == [11, 15]
assert record.in_library_version_ids == [11, 15]
assert record.has_update() is False
await service.refresh_for_model_type("lora", scanner, provider)
assert provider.calls == 1, "provider should not be called again within TTL"
@pytest.mark.asyncio
async def test_refresh_respects_ignore_flag(tmp_path):
db_path = tmp_path / "updates.sqlite"
service = ModelUpdateService(str(db_path), ttl_seconds=3600)
raw_data = [{"civitai": {"modelId": 2, "id": 21}}]
scanner = DummyScanner(raw_data)
provider = DummyProvider({"modelVersions": [{"id": 21}, {"id": 22}]})
await service.refresh_for_model_type("lora", scanner, provider)
await service.set_should_ignore("lora", 2, True)
provider.calls = 0
await service.refresh_for_model_type("lora", scanner, provider)
assert provider.calls == 0
@pytest.mark.asyncio
async def test_update_in_library_versions_changes_update_state(tmp_path):
db_path = tmp_path / "updates.sqlite"
service = ModelUpdateService(str(db_path), ttl_seconds=1)
raw_data = [{"civitai": {"modelId": 3, "id": 31}}]
scanner = DummyScanner(raw_data)
provider = DummyProvider({"modelVersions": [{"id": 31}, {"id": 35}]})
await service.refresh_for_model_type("lora", scanner, provider)
await service.update_in_library_versions("lora", 3, [31])
record = await service.get_record("lora", 3)
assert record is not None
assert record.has_update() is True
await service.update_in_library_versions("lora", 3, [31, 35])
record = await service.get_record("lora", 3)
assert record.has_update() is False