refactor(routes): limit update endpoints to essentials

This commit is contained in:
pixelpaws
2025-10-15 15:37:35 +08:00
parent 321ff72953
commit ee0d241c75
11 changed files with 749 additions and 8 deletions

View File

@@ -2,7 +2,7 @@ from __future__ import annotations
import logging
from abc import ABC, abstractmethod
from typing import Callable, Dict, Mapping
from typing import TYPE_CHECKING, Callable, Dict, Mapping
import jinja2
from aiohttp import web
@@ -42,8 +42,12 @@ from .handlers.model_handlers import (
ModelMoveHandler,
ModelPageView,
ModelQueryHandler,
ModelUpdateHandler,
)
if TYPE_CHECKING:
from ..services.model_update_service import ModelUpdateService
logger = logging.getLogger(__name__)
@@ -99,10 +103,18 @@ class BaseModelRoutes(ABC):
ws_manager=self._ws_manager,
download_manager_factory=ServiceRegistry.get_download_manager,
)
self._model_update_service: ModelUpdateService | None = None
if service is not None:
self.attach_service(service)
def set_model_update_service(self, service: "ModelUpdateService") -> None:
"""Attach the model update tracking service."""
self._model_update_service = service
self._handler_set = None
self._handler_mapping = None
def attach_service(self, service) -> None:
"""Attach a model service and rebuild handler dependencies."""
self.service = service
@@ -127,6 +139,7 @@ class BaseModelRoutes(ABC):
def _create_handler_set(self) -> ModelHandlerSet:
service = self._ensure_service()
update_service = self._ensure_model_update_service()
page_view = ModelPageView(
template_env=self.template_env,
template_name=self.template_name or "",
@@ -186,6 +199,12 @@ class BaseModelRoutes(ABC):
ws_manager=self._ws_manager,
logger=logger,
)
updates = ModelUpdateHandler(
service=service,
update_service=update_service,
metadata_provider_selector=get_metadata_provider,
logger=logger,
)
return ModelHandlerSet(
page_view=page_view,
listing=listing,
@@ -195,6 +214,7 @@ class BaseModelRoutes(ABC):
civitai=civitai,
move=move,
auto_organize=auto_organize,
updates=updates,
)
@property
@@ -273,3 +293,8 @@ class BaseModelRoutes(ABC):
return proxy
def _ensure_model_update_service(self) -> "ModelUpdateService":
if self._model_update_service is None:
raise RuntimeError("Model update service has not been attached")
return self._model_update_service

View File

@@ -21,7 +21,9 @@ class CheckpointRoutes(BaseModelRoutes):
"""Initialize services from ServiceRegistry"""
checkpoint_scanner = await ServiceRegistry.get_checkpoint_scanner()
self.service = CheckpointService(checkpoint_scanner)
update_service = await ServiceRegistry.get_model_update_service()
self.set_model_update_service(update_service)
# Attach service dependencies
self.attach_service(self.service)

View File

@@ -20,7 +20,9 @@ class EmbeddingRoutes(BaseModelRoutes):
"""Initialize services from ServiceRegistry"""
embedding_scanner = await ServiceRegistry.get_embedding_scanner()
self.service = EmbeddingService(embedding_scanner)
update_service = await ServiceRegistry.get_model_update_service()
self.set_model_update_service(update_service)
# Attach service dependencies
self.attach_service(self.service)

View File

@@ -29,6 +29,7 @@ from ...services.use_cases import (
)
from ...services.websocket_manager import WebSocketManager
from ...services.websocket_progress_callback import WebSocketProgressCallback
from ...services.errors import RateLimitError
from ...utils.file_utils import calculate_sha256
from ...utils.metadata_manager import MetadataManager
@@ -1017,6 +1018,156 @@ class ModelAutoOrganizeHandler:
return web.json_response({"success": False, "error": str(exc)}, status=500)
class ModelUpdateHandler:
"""Handle update tracking requests."""
def __init__(
self,
*,
service,
update_service,
metadata_provider_selector,
logger: logging.Logger,
) -> None:
self._service = service
self._update_service = update_service
self._metadata_provider_selector = metadata_provider_selector
self._logger = logger
async def refresh_model_updates(self, request: web.Request) -> web.Response:
payload = await self._read_json(request)
force_refresh = self._parse_bool(request.query.get("force")) or self._parse_bool(
payload.get("force")
)
provider = await self._get_civitai_provider()
if provider is None:
return web.json_response(
{"success": False, "error": "Civitai provider not available"}, status=503
)
try:
records = await self._update_service.refresh_for_model_type(
self._service.model_type,
self._service.scanner,
provider,
force_refresh=force_refresh,
)
except RateLimitError as exc:
return web.json_response(
{"success": False, "error": str(exc) or "Rate limited"}, status=429
)
except Exception as exc: # pragma: no cover - defensive logging
self._logger.error("Failed to refresh model updates: %s", exc, exc_info=True)
return web.json_response({"success": False, "error": str(exc)}, status=500)
return web.json_response(
{
"success": True,
"records": [self._serialize_record(record) for record in records.values()],
}
)
async def set_model_update_ignore(self, request: web.Request) -> web.Response:
payload = await self._read_json(request)
model_id = self._normalize_model_id(payload.get("modelId"))
if model_id is None:
return web.json_response({"success": False, "error": "modelId is required"}, status=400)
should_ignore = self._parse_bool(payload.get("shouldIgnore"))
record = await self._update_service.set_should_ignore(
self._service.model_type, model_id, should_ignore
)
return web.json_response({"success": True, "record": self._serialize_record(record)})
async def get_model_update_status(self, request: web.Request) -> web.Response:
model_id = self._normalize_model_id(request.match_info.get("model_id"))
if model_id is None:
return web.json_response({"success": False, "error": "model_id must be an integer"}, status=400)
refresh = self._parse_bool(request.query.get("refresh"))
force = self._parse_bool(request.query.get("force"))
try:
record = await self._get_or_refresh_record(model_id, refresh=refresh, force=force)
except RateLimitError as exc:
return web.json_response(
{"success": False, "error": str(exc) or "Rate limited"}, status=429
)
if record is None:
return web.json_response(
{"success": False, "error": "Model not tracked"}, status=404
)
return web.json_response({"success": True, "record": self._serialize_record(record)})
async def _get_or_refresh_record(
self, model_id: int, *, refresh: bool, force: bool
) -> Optional[object]:
record = await self._update_service.get_record(self._service.model_type, model_id)
if record and not refresh and not force:
return record
provider = await self._get_civitai_provider()
if provider is None:
return record
return await self._update_service.refresh_single_model(
self._service.model_type,
model_id,
self._service.scanner,
provider,
force_refresh=force or refresh,
)
async def _get_civitai_provider(self):
try:
return await self._metadata_provider_selector("civitai_api")
except Exception as exc: # pragma: no cover - defensive log
self._logger.error("Failed to acquire civitai provider: %s", exc, exc_info=True)
return None
async def _read_json(self, request: web.Request) -> Dict:
if not request.can_read_body:
return {}
try:
return await request.json()
except Exception:
return {}
@staticmethod
def _parse_bool(value) -> bool:
if isinstance(value, bool):
return value
if isinstance(value, str):
return value.lower() in {"1", "true", "yes"}
if isinstance(value, (int, float)):
return bool(value)
return False
@staticmethod
def _normalize_model_id(value) -> Optional[int]:
try:
if value is None:
return None
return int(value)
except (TypeError, ValueError):
return None
@staticmethod
def _serialize_record(record) -> Dict:
return {
"modelType": record.model_type,
"modelId": record.model_id,
"largestVersionId": record.largest_version_id,
"versionIds": record.version_ids,
"inLibraryVersionIds": record.in_library_version_ids,
"lastCheckedAt": record.last_checked_at,
"shouldIgnore": record.should_ignore,
"hasUpdate": record.has_update(),
}
@dataclass
class ModelHandlerSet:
"""Aggregate concrete handlers into a flat mapping."""
@@ -1029,6 +1180,7 @@ class ModelHandlerSet:
civitai: ModelCivitaiHandler
move: ModelMoveHandler
auto_organize: ModelAutoOrganizeHandler
updates: ModelUpdateHandler
def to_route_mapping(self) -> Dict[str, Callable[[web.Request], Awaitable[web.Response]]]:
return {
@@ -1073,5 +1225,8 @@ class ModelHandlerSet:
"get_model_metadata": self.query.get_model_metadata,
"get_model_description": self.query.get_model_description,
"get_relative_paths": self.query.get_relative_paths,
"refresh_model_updates": self.updates.refresh_model_updates,
"set_model_update_ignore": self.updates.set_model_update_ignore,
"get_model_update_status": self.updates.get_model_update_status,
}

View File

@@ -24,7 +24,9 @@ class LoraRoutes(BaseModelRoutes):
"""Initialize services from ServiceRegistry"""
lora_scanner = await ServiceRegistry.get_lora_scanner()
self.service = LoraService(lora_scanner)
update_service = await ServiceRegistry.get_model_update_service()
self.set_model_update_service(update_service)
# Attach service dependencies
self.attach_service(self.service)

View File

@@ -55,6 +55,9 @@ COMMON_ROUTE_DEFINITIONS: tuple[RouteDefinition, ...] = (
RouteDefinition("GET", "/api/lm/{prefix}/civitai/versions/{model_id}", "get_civitai_versions"),
RouteDefinition("GET", "/api/lm/{prefix}/civitai/model/version/{modelVersionId}", "get_civitai_model_by_version"),
RouteDefinition("GET", "/api/lm/{prefix}/civitai/model/hash/{hash}", "get_civitai_model_by_hash"),
RouteDefinition("POST", "/api/lm/{prefix}/updates/refresh", "refresh_model_updates"),
RouteDefinition("POST", "/api/lm/{prefix}/updates/ignore", "set_model_update_ignore"),
RouteDefinition("GET", "/api/lm/{prefix}/updates/status/{model_id}", "get_model_update_status"),
RouteDefinition("POST", "/api/lm/download-model", "download_model"),
RouteDefinition("GET", "/api/lm/download-model-get", "download_model_get"),
RouteDefinition("GET", "/api/lm/cancel-download-get", "cancel_download_get"),