mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-24 22:52:12 -03:00
refactor(routes): limit update endpoints to essentials
This commit is contained in:
@@ -2,7 +2,7 @@ from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Callable, Dict, Mapping
|
||||
from typing import TYPE_CHECKING, Callable, Dict, Mapping
|
||||
|
||||
import jinja2
|
||||
from aiohttp import web
|
||||
@@ -42,8 +42,12 @@ from .handlers.model_handlers import (
|
||||
ModelMoveHandler,
|
||||
ModelPageView,
|
||||
ModelQueryHandler,
|
||||
ModelUpdateHandler,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..services.model_update_service import ModelUpdateService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -99,10 +103,18 @@ class BaseModelRoutes(ABC):
|
||||
ws_manager=self._ws_manager,
|
||||
download_manager_factory=ServiceRegistry.get_download_manager,
|
||||
)
|
||||
self._model_update_service: ModelUpdateService | None = None
|
||||
|
||||
if service is not None:
|
||||
self.attach_service(service)
|
||||
|
||||
def set_model_update_service(self, service: "ModelUpdateService") -> None:
|
||||
"""Attach the model update tracking service."""
|
||||
|
||||
self._model_update_service = service
|
||||
self._handler_set = None
|
||||
self._handler_mapping = None
|
||||
|
||||
def attach_service(self, service) -> None:
|
||||
"""Attach a model service and rebuild handler dependencies."""
|
||||
self.service = service
|
||||
@@ -127,6 +139,7 @@ class BaseModelRoutes(ABC):
|
||||
|
||||
def _create_handler_set(self) -> ModelHandlerSet:
|
||||
service = self._ensure_service()
|
||||
update_service = self._ensure_model_update_service()
|
||||
page_view = ModelPageView(
|
||||
template_env=self.template_env,
|
||||
template_name=self.template_name or "",
|
||||
@@ -186,6 +199,12 @@ class BaseModelRoutes(ABC):
|
||||
ws_manager=self._ws_manager,
|
||||
logger=logger,
|
||||
)
|
||||
updates = ModelUpdateHandler(
|
||||
service=service,
|
||||
update_service=update_service,
|
||||
metadata_provider_selector=get_metadata_provider,
|
||||
logger=logger,
|
||||
)
|
||||
return ModelHandlerSet(
|
||||
page_view=page_view,
|
||||
listing=listing,
|
||||
@@ -195,6 +214,7 @@ class BaseModelRoutes(ABC):
|
||||
civitai=civitai,
|
||||
move=move,
|
||||
auto_organize=auto_organize,
|
||||
updates=updates,
|
||||
)
|
||||
|
||||
@property
|
||||
@@ -273,3 +293,8 @@ class BaseModelRoutes(ABC):
|
||||
|
||||
return proxy
|
||||
|
||||
def _ensure_model_update_service(self) -> "ModelUpdateService":
|
||||
if self._model_update_service is None:
|
||||
raise RuntimeError("Model update service has not been attached")
|
||||
return self._model_update_service
|
||||
|
||||
|
||||
@@ -21,7 +21,9 @@ class CheckpointRoutes(BaseModelRoutes):
|
||||
"""Initialize services from ServiceRegistry"""
|
||||
checkpoint_scanner = await ServiceRegistry.get_checkpoint_scanner()
|
||||
self.service = CheckpointService(checkpoint_scanner)
|
||||
|
||||
update_service = await ServiceRegistry.get_model_update_service()
|
||||
self.set_model_update_service(update_service)
|
||||
|
||||
# Attach service dependencies
|
||||
self.attach_service(self.service)
|
||||
|
||||
|
||||
@@ -20,7 +20,9 @@ class EmbeddingRoutes(BaseModelRoutes):
|
||||
"""Initialize services from ServiceRegistry"""
|
||||
embedding_scanner = await ServiceRegistry.get_embedding_scanner()
|
||||
self.service = EmbeddingService(embedding_scanner)
|
||||
|
||||
update_service = await ServiceRegistry.get_model_update_service()
|
||||
self.set_model_update_service(update_service)
|
||||
|
||||
# Attach service dependencies
|
||||
self.attach_service(self.service)
|
||||
|
||||
|
||||
@@ -29,6 +29,7 @@ from ...services.use_cases import (
|
||||
)
|
||||
from ...services.websocket_manager import WebSocketManager
|
||||
from ...services.websocket_progress_callback import WebSocketProgressCallback
|
||||
from ...services.errors import RateLimitError
|
||||
from ...utils.file_utils import calculate_sha256
|
||||
from ...utils.metadata_manager import MetadataManager
|
||||
|
||||
@@ -1017,6 +1018,156 @@ class ModelAutoOrganizeHandler:
|
||||
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
||||
|
||||
|
||||
class ModelUpdateHandler:
|
||||
"""Handle update tracking requests."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
service,
|
||||
update_service,
|
||||
metadata_provider_selector,
|
||||
logger: logging.Logger,
|
||||
) -> None:
|
||||
self._service = service
|
||||
self._update_service = update_service
|
||||
self._metadata_provider_selector = metadata_provider_selector
|
||||
self._logger = logger
|
||||
|
||||
async def refresh_model_updates(self, request: web.Request) -> web.Response:
|
||||
payload = await self._read_json(request)
|
||||
force_refresh = self._parse_bool(request.query.get("force")) or self._parse_bool(
|
||||
payload.get("force")
|
||||
)
|
||||
provider = await self._get_civitai_provider()
|
||||
if provider is None:
|
||||
return web.json_response(
|
||||
{"success": False, "error": "Civitai provider not available"}, status=503
|
||||
)
|
||||
|
||||
try:
|
||||
records = await self._update_service.refresh_for_model_type(
|
||||
self._service.model_type,
|
||||
self._service.scanner,
|
||||
provider,
|
||||
force_refresh=force_refresh,
|
||||
)
|
||||
except RateLimitError as exc:
|
||||
return web.json_response(
|
||||
{"success": False, "error": str(exc) or "Rate limited"}, status=429
|
||||
)
|
||||
except Exception as exc: # pragma: no cover - defensive logging
|
||||
self._logger.error("Failed to refresh model updates: %s", exc, exc_info=True)
|
||||
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
||||
|
||||
return web.json_response(
|
||||
{
|
||||
"success": True,
|
||||
"records": [self._serialize_record(record) for record in records.values()],
|
||||
}
|
||||
)
|
||||
|
||||
async def set_model_update_ignore(self, request: web.Request) -> web.Response:
|
||||
payload = await self._read_json(request)
|
||||
model_id = self._normalize_model_id(payload.get("modelId"))
|
||||
if model_id is None:
|
||||
return web.json_response({"success": False, "error": "modelId is required"}, status=400)
|
||||
|
||||
should_ignore = self._parse_bool(payload.get("shouldIgnore"))
|
||||
record = await self._update_service.set_should_ignore(
|
||||
self._service.model_type, model_id, should_ignore
|
||||
)
|
||||
return web.json_response({"success": True, "record": self._serialize_record(record)})
|
||||
|
||||
async def get_model_update_status(self, request: web.Request) -> web.Response:
|
||||
model_id = self._normalize_model_id(request.match_info.get("model_id"))
|
||||
if model_id is None:
|
||||
return web.json_response({"success": False, "error": "model_id must be an integer"}, status=400)
|
||||
|
||||
refresh = self._parse_bool(request.query.get("refresh"))
|
||||
force = self._parse_bool(request.query.get("force"))
|
||||
|
||||
try:
|
||||
record = await self._get_or_refresh_record(model_id, refresh=refresh, force=force)
|
||||
except RateLimitError as exc:
|
||||
return web.json_response(
|
||||
{"success": False, "error": str(exc) or "Rate limited"}, status=429
|
||||
)
|
||||
|
||||
if record is None:
|
||||
return web.json_response(
|
||||
{"success": False, "error": "Model not tracked"}, status=404
|
||||
)
|
||||
|
||||
return web.json_response({"success": True, "record": self._serialize_record(record)})
|
||||
|
||||
async def _get_or_refresh_record(
|
||||
self, model_id: int, *, refresh: bool, force: bool
|
||||
) -> Optional[object]:
|
||||
record = await self._update_service.get_record(self._service.model_type, model_id)
|
||||
if record and not refresh and not force:
|
||||
return record
|
||||
|
||||
provider = await self._get_civitai_provider()
|
||||
if provider is None:
|
||||
return record
|
||||
|
||||
return await self._update_service.refresh_single_model(
|
||||
self._service.model_type,
|
||||
model_id,
|
||||
self._service.scanner,
|
||||
provider,
|
||||
force_refresh=force or refresh,
|
||||
)
|
||||
|
||||
async def _get_civitai_provider(self):
|
||||
try:
|
||||
return await self._metadata_provider_selector("civitai_api")
|
||||
except Exception as exc: # pragma: no cover - defensive log
|
||||
self._logger.error("Failed to acquire civitai provider: %s", exc, exc_info=True)
|
||||
return None
|
||||
|
||||
async def _read_json(self, request: web.Request) -> Dict:
|
||||
if not request.can_read_body:
|
||||
return {}
|
||||
try:
|
||||
return await request.json()
|
||||
except Exception:
|
||||
return {}
|
||||
|
||||
@staticmethod
|
||||
def _parse_bool(value) -> bool:
|
||||
if isinstance(value, bool):
|
||||
return value
|
||||
if isinstance(value, str):
|
||||
return value.lower() in {"1", "true", "yes"}
|
||||
if isinstance(value, (int, float)):
|
||||
return bool(value)
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def _normalize_model_id(value) -> Optional[int]:
|
||||
try:
|
||||
if value is None:
|
||||
return None
|
||||
return int(value)
|
||||
except (TypeError, ValueError):
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _serialize_record(record) -> Dict:
|
||||
return {
|
||||
"modelType": record.model_type,
|
||||
"modelId": record.model_id,
|
||||
"largestVersionId": record.largest_version_id,
|
||||
"versionIds": record.version_ids,
|
||||
"inLibraryVersionIds": record.in_library_version_ids,
|
||||
"lastCheckedAt": record.last_checked_at,
|
||||
"shouldIgnore": record.should_ignore,
|
||||
"hasUpdate": record.has_update(),
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelHandlerSet:
|
||||
"""Aggregate concrete handlers into a flat mapping."""
|
||||
@@ -1029,6 +1180,7 @@ class ModelHandlerSet:
|
||||
civitai: ModelCivitaiHandler
|
||||
move: ModelMoveHandler
|
||||
auto_organize: ModelAutoOrganizeHandler
|
||||
updates: ModelUpdateHandler
|
||||
|
||||
def to_route_mapping(self) -> Dict[str, Callable[[web.Request], Awaitable[web.Response]]]:
|
||||
return {
|
||||
@@ -1073,5 +1225,8 @@ class ModelHandlerSet:
|
||||
"get_model_metadata": self.query.get_model_metadata,
|
||||
"get_model_description": self.query.get_model_description,
|
||||
"get_relative_paths": self.query.get_relative_paths,
|
||||
"refresh_model_updates": self.updates.refresh_model_updates,
|
||||
"set_model_update_ignore": self.updates.set_model_update_ignore,
|
||||
"get_model_update_status": self.updates.get_model_update_status,
|
||||
}
|
||||
|
||||
|
||||
@@ -24,7 +24,9 @@ class LoraRoutes(BaseModelRoutes):
|
||||
"""Initialize services from ServiceRegistry"""
|
||||
lora_scanner = await ServiceRegistry.get_lora_scanner()
|
||||
self.service = LoraService(lora_scanner)
|
||||
|
||||
update_service = await ServiceRegistry.get_model_update_service()
|
||||
self.set_model_update_service(update_service)
|
||||
|
||||
# Attach service dependencies
|
||||
self.attach_service(self.service)
|
||||
|
||||
|
||||
@@ -55,6 +55,9 @@ COMMON_ROUTE_DEFINITIONS: tuple[RouteDefinition, ...] = (
|
||||
RouteDefinition("GET", "/api/lm/{prefix}/civitai/versions/{model_id}", "get_civitai_versions"),
|
||||
RouteDefinition("GET", "/api/lm/{prefix}/civitai/model/version/{modelVersionId}", "get_civitai_model_by_version"),
|
||||
RouteDefinition("GET", "/api/lm/{prefix}/civitai/model/hash/{hash}", "get_civitai_model_by_hash"),
|
||||
RouteDefinition("POST", "/api/lm/{prefix}/updates/refresh", "refresh_model_updates"),
|
||||
RouteDefinition("POST", "/api/lm/{prefix}/updates/ignore", "set_model_update_ignore"),
|
||||
RouteDefinition("GET", "/api/lm/{prefix}/updates/status/{model_id}", "get_model_update_status"),
|
||||
RouteDefinition("POST", "/api/lm/download-model", "download_model"),
|
||||
RouteDefinition("GET", "/api/lm/download-model-get", "download_model_get"),
|
||||
RouteDefinition("GET", "/api/lm/cancel-download-get", "cancel_download_get"),
|
||||
|
||||
Reference in New Issue
Block a user