mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-25 07:05:43 -03:00
Merge pull request #572 from willmiao/codex/design-ui-for-model-update-notifications
refactor: tighten civitai update endpoints
This commit is contained in:
@@ -2,7 +2,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Callable, Dict, Mapping
|
from typing import TYPE_CHECKING, Callable, Dict, Mapping
|
||||||
|
|
||||||
import jinja2
|
import jinja2
|
||||||
from aiohttp import web
|
from aiohttp import web
|
||||||
@@ -42,8 +42,12 @@ from .handlers.model_handlers import (
|
|||||||
ModelMoveHandler,
|
ModelMoveHandler,
|
||||||
ModelPageView,
|
ModelPageView,
|
||||||
ModelQueryHandler,
|
ModelQueryHandler,
|
||||||
|
ModelUpdateHandler,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from ..services.model_update_service import ModelUpdateService
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@@ -99,10 +103,18 @@ class BaseModelRoutes(ABC):
|
|||||||
ws_manager=self._ws_manager,
|
ws_manager=self._ws_manager,
|
||||||
download_manager_factory=ServiceRegistry.get_download_manager,
|
download_manager_factory=ServiceRegistry.get_download_manager,
|
||||||
)
|
)
|
||||||
|
self._model_update_service: ModelUpdateService | None = None
|
||||||
|
|
||||||
if service is not None:
|
if service is not None:
|
||||||
self.attach_service(service)
|
self.attach_service(service)
|
||||||
|
|
||||||
|
def set_model_update_service(self, service: "ModelUpdateService") -> None:
|
||||||
|
"""Attach the model update tracking service."""
|
||||||
|
|
||||||
|
self._model_update_service = service
|
||||||
|
self._handler_set = None
|
||||||
|
self._handler_mapping = None
|
||||||
|
|
||||||
def attach_service(self, service) -> None:
|
def attach_service(self, service) -> None:
|
||||||
"""Attach a model service and rebuild handler dependencies."""
|
"""Attach a model service and rebuild handler dependencies."""
|
||||||
self.service = service
|
self.service = service
|
||||||
@@ -127,6 +139,7 @@ class BaseModelRoutes(ABC):
|
|||||||
|
|
||||||
def _create_handler_set(self) -> ModelHandlerSet:
|
def _create_handler_set(self) -> ModelHandlerSet:
|
||||||
service = self._ensure_service()
|
service = self._ensure_service()
|
||||||
|
update_service = self._ensure_model_update_service()
|
||||||
page_view = ModelPageView(
|
page_view = ModelPageView(
|
||||||
template_env=self.template_env,
|
template_env=self.template_env,
|
||||||
template_name=self.template_name or "",
|
template_name=self.template_name or "",
|
||||||
@@ -186,6 +199,12 @@ class BaseModelRoutes(ABC):
|
|||||||
ws_manager=self._ws_manager,
|
ws_manager=self._ws_manager,
|
||||||
logger=logger,
|
logger=logger,
|
||||||
)
|
)
|
||||||
|
updates = ModelUpdateHandler(
|
||||||
|
service=service,
|
||||||
|
update_service=update_service,
|
||||||
|
metadata_provider_selector=get_metadata_provider,
|
||||||
|
logger=logger,
|
||||||
|
)
|
||||||
return ModelHandlerSet(
|
return ModelHandlerSet(
|
||||||
page_view=page_view,
|
page_view=page_view,
|
||||||
listing=listing,
|
listing=listing,
|
||||||
@@ -195,6 +214,7 @@ class BaseModelRoutes(ABC):
|
|||||||
civitai=civitai,
|
civitai=civitai,
|
||||||
move=move,
|
move=move,
|
||||||
auto_organize=auto_organize,
|
auto_organize=auto_organize,
|
||||||
|
updates=updates,
|
||||||
)
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@@ -273,3 +293,8 @@ class BaseModelRoutes(ABC):
|
|||||||
|
|
||||||
return proxy
|
return proxy
|
||||||
|
|
||||||
|
def _ensure_model_update_service(self) -> "ModelUpdateService":
|
||||||
|
if self._model_update_service is None:
|
||||||
|
raise RuntimeError("Model update service has not been attached")
|
||||||
|
return self._model_update_service
|
||||||
|
|
||||||
|
|||||||
@@ -21,6 +21,8 @@ class CheckpointRoutes(BaseModelRoutes):
|
|||||||
"""Initialize services from ServiceRegistry"""
|
"""Initialize services from ServiceRegistry"""
|
||||||
checkpoint_scanner = await ServiceRegistry.get_checkpoint_scanner()
|
checkpoint_scanner = await ServiceRegistry.get_checkpoint_scanner()
|
||||||
self.service = CheckpointService(checkpoint_scanner)
|
self.service = CheckpointService(checkpoint_scanner)
|
||||||
|
update_service = await ServiceRegistry.get_model_update_service()
|
||||||
|
self.set_model_update_service(update_service)
|
||||||
|
|
||||||
# Attach service dependencies
|
# Attach service dependencies
|
||||||
self.attach_service(self.service)
|
self.attach_service(self.service)
|
||||||
|
|||||||
@@ -20,6 +20,8 @@ class EmbeddingRoutes(BaseModelRoutes):
|
|||||||
"""Initialize services from ServiceRegistry"""
|
"""Initialize services from ServiceRegistry"""
|
||||||
embedding_scanner = await ServiceRegistry.get_embedding_scanner()
|
embedding_scanner = await ServiceRegistry.get_embedding_scanner()
|
||||||
self.service = EmbeddingService(embedding_scanner)
|
self.service = EmbeddingService(embedding_scanner)
|
||||||
|
update_service = await ServiceRegistry.get_model_update_service()
|
||||||
|
self.set_model_update_service(update_service)
|
||||||
|
|
||||||
# Attach service dependencies
|
# Attach service dependencies
|
||||||
self.attach_service(self.service)
|
self.attach_service(self.service)
|
||||||
|
|||||||
@@ -29,6 +29,7 @@ from ...services.use_cases import (
|
|||||||
)
|
)
|
||||||
from ...services.websocket_manager import WebSocketManager
|
from ...services.websocket_manager import WebSocketManager
|
||||||
from ...services.websocket_progress_callback import WebSocketProgressCallback
|
from ...services.websocket_progress_callback import WebSocketProgressCallback
|
||||||
|
from ...services.errors import RateLimitError
|
||||||
from ...utils.file_utils import calculate_sha256
|
from ...utils.file_utils import calculate_sha256
|
||||||
from ...utils.metadata_manager import MetadataManager
|
from ...utils.metadata_manager import MetadataManager
|
||||||
|
|
||||||
@@ -1017,6 +1018,156 @@ class ModelAutoOrganizeHandler:
|
|||||||
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
||||||
|
|
||||||
|
|
||||||
|
class ModelUpdateHandler:
|
||||||
|
"""Handle update tracking requests."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
service,
|
||||||
|
update_service,
|
||||||
|
metadata_provider_selector,
|
||||||
|
logger: logging.Logger,
|
||||||
|
) -> None:
|
||||||
|
self._service = service
|
||||||
|
self._update_service = update_service
|
||||||
|
self._metadata_provider_selector = metadata_provider_selector
|
||||||
|
self._logger = logger
|
||||||
|
|
||||||
|
async def refresh_model_updates(self, request: web.Request) -> web.Response:
|
||||||
|
payload = await self._read_json(request)
|
||||||
|
force_refresh = self._parse_bool(request.query.get("force")) or self._parse_bool(
|
||||||
|
payload.get("force")
|
||||||
|
)
|
||||||
|
provider = await self._get_civitai_provider()
|
||||||
|
if provider is None:
|
||||||
|
return web.json_response(
|
||||||
|
{"success": False, "error": "Civitai provider not available"}, status=503
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
records = await self._update_service.refresh_for_model_type(
|
||||||
|
self._service.model_type,
|
||||||
|
self._service.scanner,
|
||||||
|
provider,
|
||||||
|
force_refresh=force_refresh,
|
||||||
|
)
|
||||||
|
except RateLimitError as exc:
|
||||||
|
return web.json_response(
|
||||||
|
{"success": False, "error": str(exc) or "Rate limited"}, status=429
|
||||||
|
)
|
||||||
|
except Exception as exc: # pragma: no cover - defensive logging
|
||||||
|
self._logger.error("Failed to refresh model updates: %s", exc, exc_info=True)
|
||||||
|
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
||||||
|
|
||||||
|
return web.json_response(
|
||||||
|
{
|
||||||
|
"success": True,
|
||||||
|
"records": [self._serialize_record(record) for record in records.values()],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
async def set_model_update_ignore(self, request: web.Request) -> web.Response:
|
||||||
|
payload = await self._read_json(request)
|
||||||
|
model_id = self._normalize_model_id(payload.get("modelId"))
|
||||||
|
if model_id is None:
|
||||||
|
return web.json_response({"success": False, "error": "modelId is required"}, status=400)
|
||||||
|
|
||||||
|
should_ignore = self._parse_bool(payload.get("shouldIgnore"))
|
||||||
|
record = await self._update_service.set_should_ignore(
|
||||||
|
self._service.model_type, model_id, should_ignore
|
||||||
|
)
|
||||||
|
return web.json_response({"success": True, "record": self._serialize_record(record)})
|
||||||
|
|
||||||
|
async def get_model_update_status(self, request: web.Request) -> web.Response:
|
||||||
|
model_id = self._normalize_model_id(request.match_info.get("model_id"))
|
||||||
|
if model_id is None:
|
||||||
|
return web.json_response({"success": False, "error": "model_id must be an integer"}, status=400)
|
||||||
|
|
||||||
|
refresh = self._parse_bool(request.query.get("refresh"))
|
||||||
|
force = self._parse_bool(request.query.get("force"))
|
||||||
|
|
||||||
|
try:
|
||||||
|
record = await self._get_or_refresh_record(model_id, refresh=refresh, force=force)
|
||||||
|
except RateLimitError as exc:
|
||||||
|
return web.json_response(
|
||||||
|
{"success": False, "error": str(exc) or "Rate limited"}, status=429
|
||||||
|
)
|
||||||
|
|
||||||
|
if record is None:
|
||||||
|
return web.json_response(
|
||||||
|
{"success": False, "error": "Model not tracked"}, status=404
|
||||||
|
)
|
||||||
|
|
||||||
|
return web.json_response({"success": True, "record": self._serialize_record(record)})
|
||||||
|
|
||||||
|
async def _get_or_refresh_record(
|
||||||
|
self, model_id: int, *, refresh: bool, force: bool
|
||||||
|
) -> Optional[object]:
|
||||||
|
record = await self._update_service.get_record(self._service.model_type, model_id)
|
||||||
|
if record and not refresh and not force:
|
||||||
|
return record
|
||||||
|
|
||||||
|
provider = await self._get_civitai_provider()
|
||||||
|
if provider is None:
|
||||||
|
return record
|
||||||
|
|
||||||
|
return await self._update_service.refresh_single_model(
|
||||||
|
self._service.model_type,
|
||||||
|
model_id,
|
||||||
|
self._service.scanner,
|
||||||
|
provider,
|
||||||
|
force_refresh=force or refresh,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _get_civitai_provider(self):
|
||||||
|
try:
|
||||||
|
return await self._metadata_provider_selector("civitai_api")
|
||||||
|
except Exception as exc: # pragma: no cover - defensive log
|
||||||
|
self._logger.error("Failed to acquire civitai provider: %s", exc, exc_info=True)
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def _read_json(self, request: web.Request) -> Dict:
|
||||||
|
if not request.can_read_body:
|
||||||
|
return {}
|
||||||
|
try:
|
||||||
|
return await request.json()
|
||||||
|
except Exception:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _parse_bool(value) -> bool:
|
||||||
|
if isinstance(value, bool):
|
||||||
|
return value
|
||||||
|
if isinstance(value, str):
|
||||||
|
return value.lower() in {"1", "true", "yes"}
|
||||||
|
if isinstance(value, (int, float)):
|
||||||
|
return bool(value)
|
||||||
|
return False
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _normalize_model_id(value) -> Optional[int]:
|
||||||
|
try:
|
||||||
|
if value is None:
|
||||||
|
return None
|
||||||
|
return int(value)
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
return None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _serialize_record(record) -> Dict:
|
||||||
|
return {
|
||||||
|
"modelType": record.model_type,
|
||||||
|
"modelId": record.model_id,
|
||||||
|
"largestVersionId": record.largest_version_id,
|
||||||
|
"versionIds": record.version_ids,
|
||||||
|
"inLibraryVersionIds": record.in_library_version_ids,
|
||||||
|
"lastCheckedAt": record.last_checked_at,
|
||||||
|
"shouldIgnore": record.should_ignore,
|
||||||
|
"hasUpdate": record.has_update(),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ModelHandlerSet:
|
class ModelHandlerSet:
|
||||||
"""Aggregate concrete handlers into a flat mapping."""
|
"""Aggregate concrete handlers into a flat mapping."""
|
||||||
@@ -1029,6 +1180,7 @@ class ModelHandlerSet:
|
|||||||
civitai: ModelCivitaiHandler
|
civitai: ModelCivitaiHandler
|
||||||
move: ModelMoveHandler
|
move: ModelMoveHandler
|
||||||
auto_organize: ModelAutoOrganizeHandler
|
auto_organize: ModelAutoOrganizeHandler
|
||||||
|
updates: ModelUpdateHandler
|
||||||
|
|
||||||
def to_route_mapping(self) -> Dict[str, Callable[[web.Request], Awaitable[web.Response]]]:
|
def to_route_mapping(self) -> Dict[str, Callable[[web.Request], Awaitable[web.Response]]]:
|
||||||
return {
|
return {
|
||||||
@@ -1073,5 +1225,8 @@ class ModelHandlerSet:
|
|||||||
"get_model_metadata": self.query.get_model_metadata,
|
"get_model_metadata": self.query.get_model_metadata,
|
||||||
"get_model_description": self.query.get_model_description,
|
"get_model_description": self.query.get_model_description,
|
||||||
"get_relative_paths": self.query.get_relative_paths,
|
"get_relative_paths": self.query.get_relative_paths,
|
||||||
|
"refresh_model_updates": self.updates.refresh_model_updates,
|
||||||
|
"set_model_update_ignore": self.updates.set_model_update_ignore,
|
||||||
|
"get_model_update_status": self.updates.get_model_update_status,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -24,6 +24,8 @@ class LoraRoutes(BaseModelRoutes):
|
|||||||
"""Initialize services from ServiceRegistry"""
|
"""Initialize services from ServiceRegistry"""
|
||||||
lora_scanner = await ServiceRegistry.get_lora_scanner()
|
lora_scanner = await ServiceRegistry.get_lora_scanner()
|
||||||
self.service = LoraService(lora_scanner)
|
self.service = LoraService(lora_scanner)
|
||||||
|
update_service = await ServiceRegistry.get_model_update_service()
|
||||||
|
self.set_model_update_service(update_service)
|
||||||
|
|
||||||
# Attach service dependencies
|
# Attach service dependencies
|
||||||
self.attach_service(self.service)
|
self.attach_service(self.service)
|
||||||
|
|||||||
@@ -55,6 +55,9 @@ COMMON_ROUTE_DEFINITIONS: tuple[RouteDefinition, ...] = (
|
|||||||
RouteDefinition("GET", "/api/lm/{prefix}/civitai/versions/{model_id}", "get_civitai_versions"),
|
RouteDefinition("GET", "/api/lm/{prefix}/civitai/versions/{model_id}", "get_civitai_versions"),
|
||||||
RouteDefinition("GET", "/api/lm/{prefix}/civitai/model/version/{modelVersionId}", "get_civitai_model_by_version"),
|
RouteDefinition("GET", "/api/lm/{prefix}/civitai/model/version/{modelVersionId}", "get_civitai_model_by_version"),
|
||||||
RouteDefinition("GET", "/api/lm/{prefix}/civitai/model/hash/{hash}", "get_civitai_model_by_hash"),
|
RouteDefinition("GET", "/api/lm/{prefix}/civitai/model/hash/{hash}", "get_civitai_model_by_hash"),
|
||||||
|
RouteDefinition("POST", "/api/lm/{prefix}/updates/refresh", "refresh_model_updates"),
|
||||||
|
RouteDefinition("POST", "/api/lm/{prefix}/updates/ignore", "set_model_update_ignore"),
|
||||||
|
RouteDefinition("GET", "/api/lm/{prefix}/updates/status/{model_id}", "get_model_update_status"),
|
||||||
RouteDefinition("POST", "/api/lm/download-model", "download_model"),
|
RouteDefinition("POST", "/api/lm/download-model", "download_model"),
|
||||||
RouteDefinition("GET", "/api/lm/download-model-get", "download_model_get"),
|
RouteDefinition("GET", "/api/lm/download-model-get", "download_model_get"),
|
||||||
RouteDefinition("GET", "/api/lm/cancel-download-get", "cancel_download_get"),
|
RouteDefinition("GET", "/api/lm/cancel-download-get", "cancel_download_get"),
|
||||||
|
|||||||
411
py/services/model_update_service.py
Normal file
411
py/services/model_update_service.py
Normal file
@@ -0,0 +1,411 @@
|
|||||||
|
"""Service for tracking remote model version updates."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import sqlite3
|
||||||
|
import time
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Dict, Iterable, List, Mapping, Optional, Sequence
|
||||||
|
|
||||||
|
from .errors import RateLimitError
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ModelUpdateRecord:
|
||||||
|
"""Representation of a persisted update record."""
|
||||||
|
|
||||||
|
model_type: str
|
||||||
|
model_id: int
|
||||||
|
largest_version_id: Optional[int]
|
||||||
|
version_ids: List[int]
|
||||||
|
in_library_version_ids: List[int]
|
||||||
|
last_checked_at: Optional[float]
|
||||||
|
should_ignore: bool
|
||||||
|
|
||||||
|
def has_update(self) -> bool:
|
||||||
|
"""Return True when remote versions exceed the local library."""
|
||||||
|
|
||||||
|
if self.should_ignore or not self.version_ids:
|
||||||
|
return False
|
||||||
|
local_versions = set(self.in_library_version_ids)
|
||||||
|
return any(version_id not in local_versions for version_id in self.version_ids)
|
||||||
|
|
||||||
|
|
||||||
|
class ModelUpdateService:
|
||||||
|
"""Persist and query remote model version metadata."""
|
||||||
|
|
||||||
|
_SCHEMA = """
|
||||||
|
CREATE TABLE IF NOT EXISTS model_update_status (
|
||||||
|
model_type TEXT NOT NULL,
|
||||||
|
model_id INTEGER NOT NULL,
|
||||||
|
largest_version_id INTEGER,
|
||||||
|
version_ids TEXT,
|
||||||
|
in_library_version_ids TEXT,
|
||||||
|
last_checked_at REAL,
|
||||||
|
should_ignore INTEGER DEFAULT 0,
|
||||||
|
PRIMARY KEY (model_type, model_id)
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, db_path: str, *, ttl_seconds: int = 24 * 60 * 60) -> None:
|
||||||
|
self._db_path = db_path
|
||||||
|
self._ttl_seconds = ttl_seconds
|
||||||
|
self._lock = asyncio.Lock()
|
||||||
|
self._schema_initialized = False
|
||||||
|
self._ensure_directory()
|
||||||
|
self._initialize_schema()
|
||||||
|
|
||||||
|
def _ensure_directory(self) -> None:
|
||||||
|
directory = os.path.dirname(self._db_path)
|
||||||
|
if directory:
|
||||||
|
os.makedirs(directory, exist_ok=True)
|
||||||
|
|
||||||
|
def _connect(self) -> sqlite3.Connection:
|
||||||
|
conn = sqlite3.connect(self._db_path, check_same_thread=False)
|
||||||
|
conn.row_factory = sqlite3.Row
|
||||||
|
return conn
|
||||||
|
|
||||||
|
def _initialize_schema(self) -> None:
|
||||||
|
if self._schema_initialized:
|
||||||
|
return
|
||||||
|
try:
|
||||||
|
with self._connect() as conn:
|
||||||
|
conn.execute("PRAGMA journal_mode=WAL")
|
||||||
|
conn.execute("PRAGMA foreign_keys = ON")
|
||||||
|
conn.executescript(self._SCHEMA)
|
||||||
|
self._schema_initialized = True
|
||||||
|
except Exception as exc: # pragma: no cover - defensive guard
|
||||||
|
logger.error("Failed to initialize update schema: %s", exc, exc_info=True)
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def refresh_for_model_type(
|
||||||
|
self,
|
||||||
|
model_type: str,
|
||||||
|
scanner,
|
||||||
|
metadata_provider,
|
||||||
|
*,
|
||||||
|
force_refresh: bool = False,
|
||||||
|
) -> Dict[int, ModelUpdateRecord]:
|
||||||
|
"""Refresh update information for every model present in the cache."""
|
||||||
|
|
||||||
|
local_versions = await self._collect_local_versions(scanner)
|
||||||
|
results: Dict[int, ModelUpdateRecord] = {}
|
||||||
|
for model_id, version_ids in local_versions.items():
|
||||||
|
record = await self._refresh_single_model(
|
||||||
|
model_type,
|
||||||
|
model_id,
|
||||||
|
version_ids,
|
||||||
|
metadata_provider,
|
||||||
|
force_refresh=force_refresh,
|
||||||
|
)
|
||||||
|
if record:
|
||||||
|
results[model_id] = record
|
||||||
|
return results
|
||||||
|
|
||||||
|
async def refresh_single_model(
|
||||||
|
self,
|
||||||
|
model_type: str,
|
||||||
|
model_id: int,
|
||||||
|
scanner,
|
||||||
|
metadata_provider,
|
||||||
|
*,
|
||||||
|
force_refresh: bool = False,
|
||||||
|
) -> Optional[ModelUpdateRecord]:
|
||||||
|
"""Refresh update information for a specific model id."""
|
||||||
|
|
||||||
|
local_versions = await self._collect_local_versions(scanner)
|
||||||
|
version_ids = local_versions.get(model_id, [])
|
||||||
|
return await self._refresh_single_model(
|
||||||
|
model_type,
|
||||||
|
model_id,
|
||||||
|
version_ids,
|
||||||
|
metadata_provider,
|
||||||
|
force_refresh=force_refresh,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def update_in_library_versions(
|
||||||
|
self,
|
||||||
|
model_type: str,
|
||||||
|
model_id: int,
|
||||||
|
version_ids: Sequence[int],
|
||||||
|
) -> ModelUpdateRecord:
|
||||||
|
"""Persist a new set of in-library version identifiers."""
|
||||||
|
|
||||||
|
normalized_versions = self._normalize_sequence(version_ids)
|
||||||
|
async with self._lock:
|
||||||
|
existing = self._get_record(model_type, model_id)
|
||||||
|
record = ModelUpdateRecord(
|
||||||
|
model_type=model_type,
|
||||||
|
model_id=model_id,
|
||||||
|
largest_version_id=existing.largest_version_id if existing else None,
|
||||||
|
version_ids=list(existing.version_ids) if existing else [],
|
||||||
|
in_library_version_ids=normalized_versions,
|
||||||
|
last_checked_at=existing.last_checked_at if existing else None,
|
||||||
|
should_ignore=existing.should_ignore if existing else False,
|
||||||
|
)
|
||||||
|
self._upsert_record(record)
|
||||||
|
return record
|
||||||
|
|
||||||
|
async def set_should_ignore(
|
||||||
|
self, model_type: str, model_id: int, should_ignore: bool
|
||||||
|
) -> ModelUpdateRecord:
|
||||||
|
"""Toggle the ignore flag for a model."""
|
||||||
|
|
||||||
|
async with self._lock:
|
||||||
|
existing = self._get_record(model_type, model_id)
|
||||||
|
if existing:
|
||||||
|
record = ModelUpdateRecord(
|
||||||
|
model_type=model_type,
|
||||||
|
model_id=model_id,
|
||||||
|
largest_version_id=existing.largest_version_id,
|
||||||
|
version_ids=list(existing.version_ids),
|
||||||
|
in_library_version_ids=list(existing.in_library_version_ids),
|
||||||
|
last_checked_at=existing.last_checked_at,
|
||||||
|
should_ignore=should_ignore,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
record = ModelUpdateRecord(
|
||||||
|
model_type=model_type,
|
||||||
|
model_id=model_id,
|
||||||
|
largest_version_id=None,
|
||||||
|
version_ids=[],
|
||||||
|
in_library_version_ids=[],
|
||||||
|
last_checked_at=None,
|
||||||
|
should_ignore=should_ignore,
|
||||||
|
)
|
||||||
|
self._upsert_record(record)
|
||||||
|
return record
|
||||||
|
|
||||||
|
async def get_record(self, model_type: str, model_id: int) -> Optional[ModelUpdateRecord]:
|
||||||
|
"""Return a cached record without triggering remote fetches."""
|
||||||
|
|
||||||
|
async with self._lock:
|
||||||
|
return self._get_record(model_type, model_id)
|
||||||
|
|
||||||
|
async def has_update(self, model_type: str, model_id: int) -> bool:
|
||||||
|
"""Determine if a model has updates pending."""
|
||||||
|
|
||||||
|
record = await self.get_record(model_type, model_id)
|
||||||
|
return record.has_update() if record else False
|
||||||
|
|
||||||
|
async def _refresh_single_model(
|
||||||
|
self,
|
||||||
|
model_type: str,
|
||||||
|
model_id: int,
|
||||||
|
local_versions: Sequence[int],
|
||||||
|
metadata_provider,
|
||||||
|
*,
|
||||||
|
force_refresh: bool = False,
|
||||||
|
) -> Optional[ModelUpdateRecord]:
|
||||||
|
normalized_local = self._normalize_sequence(local_versions)
|
||||||
|
now = time.time()
|
||||||
|
async with self._lock:
|
||||||
|
existing = self._get_record(model_type, model_id)
|
||||||
|
if existing and existing.should_ignore and not force_refresh:
|
||||||
|
record = ModelUpdateRecord(
|
||||||
|
model_type=model_type,
|
||||||
|
model_id=model_id,
|
||||||
|
largest_version_id=existing.largest_version_id,
|
||||||
|
version_ids=list(existing.version_ids),
|
||||||
|
in_library_version_ids=normalized_local,
|
||||||
|
last_checked_at=existing.last_checked_at,
|
||||||
|
should_ignore=True,
|
||||||
|
)
|
||||||
|
self._upsert_record(record)
|
||||||
|
return record
|
||||||
|
|
||||||
|
should_fetch = force_refresh or not existing or self._is_stale(existing, now)
|
||||||
|
# release lock during network request
|
||||||
|
fetched_versions: List[int] | None = None
|
||||||
|
refresh_succeeded = False
|
||||||
|
if metadata_provider and should_fetch:
|
||||||
|
try:
|
||||||
|
response = await metadata_provider.get_model_versions(model_id)
|
||||||
|
except RateLimitError:
|
||||||
|
raise
|
||||||
|
except Exception as exc: # pragma: no cover - defensive log
|
||||||
|
logger.error(
|
||||||
|
"Failed to fetch versions for model %s (%s): %s",
|
||||||
|
model_id,
|
||||||
|
model_type,
|
||||||
|
exc,
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
if response is not None:
|
||||||
|
extracted = self._extract_version_ids(response)
|
||||||
|
if extracted is not None:
|
||||||
|
fetched_versions = extracted
|
||||||
|
refresh_succeeded = True
|
||||||
|
|
||||||
|
async with self._lock:
|
||||||
|
existing = self._get_record(model_type, model_id)
|
||||||
|
if existing and existing.should_ignore and not force_refresh:
|
||||||
|
# Ignore state could have flipped while awaiting provider
|
||||||
|
record = ModelUpdateRecord(
|
||||||
|
model_type=model_type,
|
||||||
|
model_id=model_id,
|
||||||
|
largest_version_id=existing.largest_version_id,
|
||||||
|
version_ids=list(existing.version_ids),
|
||||||
|
in_library_version_ids=normalized_local,
|
||||||
|
last_checked_at=existing.last_checked_at,
|
||||||
|
should_ignore=True,
|
||||||
|
)
|
||||||
|
self._upsert_record(record)
|
||||||
|
return record
|
||||||
|
|
||||||
|
version_ids = (
|
||||||
|
fetched_versions
|
||||||
|
if refresh_succeeded
|
||||||
|
else (list(existing.version_ids) if existing else [])
|
||||||
|
)
|
||||||
|
largest = max(version_ids) if version_ids else None
|
||||||
|
last_checked = now if refresh_succeeded else (
|
||||||
|
existing.last_checked_at if existing else None
|
||||||
|
)
|
||||||
|
record = ModelUpdateRecord(
|
||||||
|
model_type=model_type,
|
||||||
|
model_id=model_id,
|
||||||
|
largest_version_id=largest,
|
||||||
|
version_ids=version_ids,
|
||||||
|
in_library_version_ids=normalized_local,
|
||||||
|
last_checked_at=last_checked,
|
||||||
|
should_ignore=existing.should_ignore if existing else False,
|
||||||
|
)
|
||||||
|
self._upsert_record(record)
|
||||||
|
return record
|
||||||
|
|
||||||
|
async def _collect_local_versions(self, scanner) -> Dict[int, List[int]]:
|
||||||
|
cache = await scanner.get_cached_data()
|
||||||
|
mapping: Dict[int, set[int]] = {}
|
||||||
|
if not cache or not getattr(cache, "raw_data", None):
|
||||||
|
return {}
|
||||||
|
|
||||||
|
for item in cache.raw_data:
|
||||||
|
civitai = item.get("civitai") if isinstance(item, dict) else None
|
||||||
|
if not isinstance(civitai, dict):
|
||||||
|
continue
|
||||||
|
model_id = self._normalize_int(civitai.get("modelId"))
|
||||||
|
version_id = self._normalize_int(civitai.get("id"))
|
||||||
|
if model_id is None or version_id is None:
|
||||||
|
continue
|
||||||
|
mapping.setdefault(model_id, set()).add(version_id)
|
||||||
|
|
||||||
|
return {model_id: sorted(ids) for model_id, ids in mapping.items()}
|
||||||
|
|
||||||
|
def _is_stale(self, record: ModelUpdateRecord, now: float) -> bool:
|
||||||
|
if record.last_checked_at is None:
|
||||||
|
return True
|
||||||
|
return (now - record.last_checked_at) >= self._ttl_seconds
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _normalize_int(value) -> Optional[int]:
|
||||||
|
try:
|
||||||
|
if value is None:
|
||||||
|
return None
|
||||||
|
return int(value)
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _normalize_sequence(self, values: Sequence[int]) -> List[int]:
|
||||||
|
normalized = [
|
||||||
|
item
|
||||||
|
for item in (self._normalize_int(value) for value in values)
|
||||||
|
if item is not None
|
||||||
|
]
|
||||||
|
return sorted(dict.fromkeys(normalized))
|
||||||
|
|
||||||
|
def _extract_version_ids(self, response) -> Optional[List[int]]:
|
||||||
|
if not isinstance(response, Mapping):
|
||||||
|
return None
|
||||||
|
versions = response.get("modelVersions")
|
||||||
|
if versions is None:
|
||||||
|
return []
|
||||||
|
if not isinstance(versions, Iterable):
|
||||||
|
return None
|
||||||
|
normalized = []
|
||||||
|
for entry in versions:
|
||||||
|
if isinstance(entry, Mapping):
|
||||||
|
normalized_id = self._normalize_int(entry.get("id"))
|
||||||
|
else:
|
||||||
|
normalized_id = self._normalize_int(entry)
|
||||||
|
if normalized_id is not None:
|
||||||
|
normalized.append(normalized_id)
|
||||||
|
return sorted(dict.fromkeys(normalized))
|
||||||
|
|
||||||
|
def _get_record(self, model_type: str, model_id: int) -> Optional[ModelUpdateRecord]:
|
||||||
|
with self._connect() as conn:
|
||||||
|
row = conn.execute(
|
||||||
|
"""
|
||||||
|
SELECT model_type, model_id, largest_version_id, version_ids,
|
||||||
|
in_library_version_ids, last_checked_at, should_ignore
|
||||||
|
FROM model_update_status
|
||||||
|
WHERE model_type = ? AND model_id = ?
|
||||||
|
""",
|
||||||
|
(model_type, model_id),
|
||||||
|
).fetchone()
|
||||||
|
if not row:
|
||||||
|
return None
|
||||||
|
return ModelUpdateRecord(
|
||||||
|
model_type=row["model_type"],
|
||||||
|
model_id=int(row["model_id"]),
|
||||||
|
largest_version_id=self._normalize_int(row["largest_version_id"]),
|
||||||
|
version_ids=self._deserialize_json_array(row["version_ids"]),
|
||||||
|
in_library_version_ids=self._deserialize_json_array(
|
||||||
|
row["in_library_version_ids"]
|
||||||
|
),
|
||||||
|
last_checked_at=row["last_checked_at"],
|
||||||
|
should_ignore=bool(row["should_ignore"]),
|
||||||
|
)
|
||||||
|
|
||||||
|
def _upsert_record(self, record: ModelUpdateRecord) -> None:
|
||||||
|
payload = (
|
||||||
|
record.model_type,
|
||||||
|
record.model_id,
|
||||||
|
record.largest_version_id,
|
||||||
|
json.dumps(record.version_ids),
|
||||||
|
json.dumps(record.in_library_version_ids),
|
||||||
|
record.last_checked_at,
|
||||||
|
1 if record.should_ignore else 0,
|
||||||
|
)
|
||||||
|
with self._connect() as conn:
|
||||||
|
conn.execute(
|
||||||
|
"""
|
||||||
|
INSERT INTO model_update_status (
|
||||||
|
model_type, model_id, largest_version_id, version_ids,
|
||||||
|
in_library_version_ids, last_checked_at, should_ignore
|
||||||
|
) VALUES (?, ?, ?, ?, ?, ?, ?)
|
||||||
|
ON CONFLICT(model_type, model_id) DO UPDATE SET
|
||||||
|
largest_version_id = excluded.largest_version_id,
|
||||||
|
version_ids = excluded.version_ids,
|
||||||
|
in_library_version_ids = excluded.in_library_version_ids,
|
||||||
|
last_checked_at = excluded.last_checked_at,
|
||||||
|
should_ignore = excluded.should_ignore
|
||||||
|
""",
|
||||||
|
payload,
|
||||||
|
)
|
||||||
|
conn.commit()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _deserialize_json_array(value) -> List[int]:
|
||||||
|
if not value:
|
||||||
|
return []
|
||||||
|
try:
|
||||||
|
data = json.loads(value)
|
||||||
|
except (TypeError, json.JSONDecodeError):
|
||||||
|
return []
|
||||||
|
if isinstance(data, list):
|
||||||
|
normalized = []
|
||||||
|
for entry in data:
|
||||||
|
try:
|
||||||
|
normalized.append(int(entry))
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
continue
|
||||||
|
return sorted(dict.fromkeys(normalized))
|
||||||
|
return []
|
||||||
|
|
||||||
@@ -81,6 +81,11 @@ class PersistentModelCache:
|
|||||||
def is_enabled(self) -> bool:
|
def is_enabled(self) -> bool:
|
||||||
return os.environ.get("LORA_MANAGER_DISABLE_PERSISTENT_CACHE", "0") != "1"
|
return os.environ.get("LORA_MANAGER_DISABLE_PERSISTENT_CACHE", "0") != "1"
|
||||||
|
|
||||||
|
def get_database_path(self) -> str:
|
||||||
|
"""Expose the resolved SQLite database path."""
|
||||||
|
|
||||||
|
return self._db_path
|
||||||
|
|
||||||
def load_cache(self, model_type: str) -> Optional[PersistedCacheData]:
|
def load_cache(self, model_type: str) -> Optional[PersistedCacheData]:
|
||||||
if not self.is_enabled():
|
if not self.is_enabled():
|
||||||
return None
|
return None
|
||||||
|
|||||||
@@ -145,6 +145,28 @@ class ServiceRegistry:
|
|||||||
logger.debug(f"Created and registered {service_name}")
|
logger.debug(f"Created and registered {service_name}")
|
||||||
return client
|
return client
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def get_model_update_service(cls):
|
||||||
|
"""Get or create the model update tracking service."""
|
||||||
|
|
||||||
|
service_name = "model_update_service"
|
||||||
|
|
||||||
|
if service_name in cls._services:
|
||||||
|
return cls._services[service_name]
|
||||||
|
|
||||||
|
async with cls._get_lock(service_name):
|
||||||
|
if service_name in cls._services:
|
||||||
|
return cls._services[service_name]
|
||||||
|
|
||||||
|
from .model_update_service import ModelUpdateService
|
||||||
|
from .persistent_model_cache import get_persistent_cache
|
||||||
|
|
||||||
|
cache = get_persistent_cache()
|
||||||
|
service = ModelUpdateService(cache.get_database_path())
|
||||||
|
cls._services[service_name] = service
|
||||||
|
logger.debug(f"Created and registered {service_name}")
|
||||||
|
return service
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def get_civarchive_client(cls):
|
async def get_civarchive_client(cls):
|
||||||
"""Get or create CivArchive client instance"""
|
"""Get or create CivArchive client instance"""
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import sys
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import types
|
import types
|
||||||
|
from dataclasses import dataclass, field
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
folder_paths_stub = types.SimpleNamespace(get_folder_paths=lambda *_: [])
|
folder_paths_stub = types.SimpleNamespace(get_folder_paths=lambda *_: [])
|
||||||
@@ -32,6 +33,41 @@ class DummyRoutes(BaseModelRoutes):
|
|||||||
def setup_specific_routes(self, registrar, prefix: str) -> None: # pragma: no cover - no extra routes in smoke tests
|
def setup_specific_routes(self, registrar, prefix: str) -> None: # pragma: no cover - no extra routes in smoke tests
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def __init__(self, service=None):
|
||||||
|
super().__init__(service)
|
||||||
|
self.set_model_update_service(NullModelUpdateService())
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class NullUpdateRecord:
|
||||||
|
model_type: str
|
||||||
|
model_id: int
|
||||||
|
largest_version_id: int | None = None
|
||||||
|
version_ids: list[int] = field(default_factory=list)
|
||||||
|
in_library_version_ids: list[int] = field(default_factory=list)
|
||||||
|
last_checked_at: float | None = None
|
||||||
|
should_ignore: bool = False
|
||||||
|
|
||||||
|
def has_update(self) -> bool:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
class NullModelUpdateService:
|
||||||
|
async def refresh_for_model_type(self, *args, **kwargs):
|
||||||
|
return {}
|
||||||
|
|
||||||
|
async def refresh_single_model(self, *args, **kwargs):
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def update_in_library_versions(self, model_type, model_id, version_ids):
|
||||||
|
return NullUpdateRecord(model_type=model_type, model_id=model_id, in_library_version_ids=list(version_ids))
|
||||||
|
|
||||||
|
async def set_should_ignore(self, model_type, model_id, should_ignore):
|
||||||
|
return NullUpdateRecord(model_type=model_type, model_id=model_id, should_ignore=should_ignore)
|
||||||
|
|
||||||
|
async def get_record(self, *args, **kwargs):
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
async def create_test_client(service) -> TestClient:
|
async def create_test_client(service) -> TestClient:
|
||||||
routes = DummyRoutes(service)
|
routes = DummyRoutes(service)
|
||||||
|
|||||||
85
tests/services/test_model_update_service.py
Normal file
85
tests/services/test_model_update_service.py
Normal file
@@ -0,0 +1,85 @@
|
|||||||
|
import asyncio
|
||||||
|
from types import SimpleNamespace
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from py.services.model_update_service import ModelUpdateService
|
||||||
|
|
||||||
|
|
||||||
|
class DummyScanner:
|
||||||
|
def __init__(self, raw_data):
|
||||||
|
self._cache = SimpleNamespace(raw_data=raw_data)
|
||||||
|
|
||||||
|
async def get_cached_data(self, *args, **kwargs):
|
||||||
|
return self._cache
|
||||||
|
|
||||||
|
|
||||||
|
class DummyProvider:
|
||||||
|
def __init__(self, response):
|
||||||
|
self.response = response
|
||||||
|
self.calls: int = 0
|
||||||
|
|
||||||
|
async def get_model_versions(self, model_id):
|
||||||
|
self.calls += 1
|
||||||
|
return self.response
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_refresh_persists_versions_and_uses_cache(tmp_path):
|
||||||
|
db_path = tmp_path / "updates.sqlite"
|
||||||
|
service = ModelUpdateService(str(db_path), ttl_seconds=3600)
|
||||||
|
raw_data = [
|
||||||
|
{"civitai": {"modelId": 1, "id": 11}},
|
||||||
|
{"civitai": {"modelId": 1, "id": 15}},
|
||||||
|
]
|
||||||
|
scanner = DummyScanner(raw_data)
|
||||||
|
provider = DummyProvider({"modelVersions": [{"id": 11}, {"id": 15}]})
|
||||||
|
|
||||||
|
await service.refresh_for_model_type("lora", scanner, provider)
|
||||||
|
record = await service.get_record("lora", 1)
|
||||||
|
|
||||||
|
assert provider.calls == 1
|
||||||
|
assert record is not None
|
||||||
|
assert record.version_ids == [11, 15]
|
||||||
|
assert record.in_library_version_ids == [11, 15]
|
||||||
|
assert record.has_update() is False
|
||||||
|
|
||||||
|
await service.refresh_for_model_type("lora", scanner, provider)
|
||||||
|
assert provider.calls == 1, "provider should not be called again within TTL"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_refresh_respects_ignore_flag(tmp_path):
|
||||||
|
db_path = tmp_path / "updates.sqlite"
|
||||||
|
service = ModelUpdateService(str(db_path), ttl_seconds=3600)
|
||||||
|
raw_data = [{"civitai": {"modelId": 2, "id": 21}}]
|
||||||
|
scanner = DummyScanner(raw_data)
|
||||||
|
provider = DummyProvider({"modelVersions": [{"id": 21}, {"id": 22}]})
|
||||||
|
|
||||||
|
await service.refresh_for_model_type("lora", scanner, provider)
|
||||||
|
await service.set_should_ignore("lora", 2, True)
|
||||||
|
|
||||||
|
provider.calls = 0
|
||||||
|
await service.refresh_for_model_type("lora", scanner, provider)
|
||||||
|
assert provider.calls == 0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_update_in_library_versions_changes_update_state(tmp_path):
|
||||||
|
db_path = tmp_path / "updates.sqlite"
|
||||||
|
service = ModelUpdateService(str(db_path), ttl_seconds=1)
|
||||||
|
raw_data = [{"civitai": {"modelId": 3, "id": 31}}]
|
||||||
|
scanner = DummyScanner(raw_data)
|
||||||
|
provider = DummyProvider({"modelVersions": [{"id": 31}, {"id": 35}]})
|
||||||
|
|
||||||
|
await service.refresh_for_model_type("lora", scanner, provider)
|
||||||
|
await service.update_in_library_versions("lora", 3, [31])
|
||||||
|
record = await service.get_record("lora", 3)
|
||||||
|
|
||||||
|
assert record is not None
|
||||||
|
assert record.has_update() is True
|
||||||
|
|
||||||
|
await service.update_in_library_versions("lora", 3, [31, 35])
|
||||||
|
record = await service.get_record("lora", 3)
|
||||||
|
|
||||||
|
assert record.has_update() is False
|
||||||
Reference in New Issue
Block a user