mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-25 15:15:44 -03:00
refactor(routes): limit update endpoints to essentials
This commit is contained in:
@@ -29,6 +29,7 @@ from ...services.use_cases import (
|
||||
)
|
||||
from ...services.websocket_manager import WebSocketManager
|
||||
from ...services.websocket_progress_callback import WebSocketProgressCallback
|
||||
from ...services.errors import RateLimitError
|
||||
from ...utils.file_utils import calculate_sha256
|
||||
from ...utils.metadata_manager import MetadataManager
|
||||
|
||||
@@ -1017,6 +1018,156 @@ class ModelAutoOrganizeHandler:
|
||||
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
||||
|
||||
|
||||
class ModelUpdateHandler:
|
||||
"""Handle update tracking requests."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
service,
|
||||
update_service,
|
||||
metadata_provider_selector,
|
||||
logger: logging.Logger,
|
||||
) -> None:
|
||||
self._service = service
|
||||
self._update_service = update_service
|
||||
self._metadata_provider_selector = metadata_provider_selector
|
||||
self._logger = logger
|
||||
|
||||
async def refresh_model_updates(self, request: web.Request) -> web.Response:
|
||||
payload = await self._read_json(request)
|
||||
force_refresh = self._parse_bool(request.query.get("force")) or self._parse_bool(
|
||||
payload.get("force")
|
||||
)
|
||||
provider = await self._get_civitai_provider()
|
||||
if provider is None:
|
||||
return web.json_response(
|
||||
{"success": False, "error": "Civitai provider not available"}, status=503
|
||||
)
|
||||
|
||||
try:
|
||||
records = await self._update_service.refresh_for_model_type(
|
||||
self._service.model_type,
|
||||
self._service.scanner,
|
||||
provider,
|
||||
force_refresh=force_refresh,
|
||||
)
|
||||
except RateLimitError as exc:
|
||||
return web.json_response(
|
||||
{"success": False, "error": str(exc) or "Rate limited"}, status=429
|
||||
)
|
||||
except Exception as exc: # pragma: no cover - defensive logging
|
||||
self._logger.error("Failed to refresh model updates: %s", exc, exc_info=True)
|
||||
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
||||
|
||||
return web.json_response(
|
||||
{
|
||||
"success": True,
|
||||
"records": [self._serialize_record(record) for record in records.values()],
|
||||
}
|
||||
)
|
||||
|
||||
async def set_model_update_ignore(self, request: web.Request) -> web.Response:
|
||||
payload = await self._read_json(request)
|
||||
model_id = self._normalize_model_id(payload.get("modelId"))
|
||||
if model_id is None:
|
||||
return web.json_response({"success": False, "error": "modelId is required"}, status=400)
|
||||
|
||||
should_ignore = self._parse_bool(payload.get("shouldIgnore"))
|
||||
record = await self._update_service.set_should_ignore(
|
||||
self._service.model_type, model_id, should_ignore
|
||||
)
|
||||
return web.json_response({"success": True, "record": self._serialize_record(record)})
|
||||
|
||||
async def get_model_update_status(self, request: web.Request) -> web.Response:
|
||||
model_id = self._normalize_model_id(request.match_info.get("model_id"))
|
||||
if model_id is None:
|
||||
return web.json_response({"success": False, "error": "model_id must be an integer"}, status=400)
|
||||
|
||||
refresh = self._parse_bool(request.query.get("refresh"))
|
||||
force = self._parse_bool(request.query.get("force"))
|
||||
|
||||
try:
|
||||
record = await self._get_or_refresh_record(model_id, refresh=refresh, force=force)
|
||||
except RateLimitError as exc:
|
||||
return web.json_response(
|
||||
{"success": False, "error": str(exc) or "Rate limited"}, status=429
|
||||
)
|
||||
|
||||
if record is None:
|
||||
return web.json_response(
|
||||
{"success": False, "error": "Model not tracked"}, status=404
|
||||
)
|
||||
|
||||
return web.json_response({"success": True, "record": self._serialize_record(record)})
|
||||
|
||||
async def _get_or_refresh_record(
|
||||
self, model_id: int, *, refresh: bool, force: bool
|
||||
) -> Optional[object]:
|
||||
record = await self._update_service.get_record(self._service.model_type, model_id)
|
||||
if record and not refresh and not force:
|
||||
return record
|
||||
|
||||
provider = await self._get_civitai_provider()
|
||||
if provider is None:
|
||||
return record
|
||||
|
||||
return await self._update_service.refresh_single_model(
|
||||
self._service.model_type,
|
||||
model_id,
|
||||
self._service.scanner,
|
||||
provider,
|
||||
force_refresh=force or refresh,
|
||||
)
|
||||
|
||||
async def _get_civitai_provider(self):
|
||||
try:
|
||||
return await self._metadata_provider_selector("civitai_api")
|
||||
except Exception as exc: # pragma: no cover - defensive log
|
||||
self._logger.error("Failed to acquire civitai provider: %s", exc, exc_info=True)
|
||||
return None
|
||||
|
||||
async def _read_json(self, request: web.Request) -> Dict:
|
||||
if not request.can_read_body:
|
||||
return {}
|
||||
try:
|
||||
return await request.json()
|
||||
except Exception:
|
||||
return {}
|
||||
|
||||
@staticmethod
|
||||
def _parse_bool(value) -> bool:
|
||||
if isinstance(value, bool):
|
||||
return value
|
||||
if isinstance(value, str):
|
||||
return value.lower() in {"1", "true", "yes"}
|
||||
if isinstance(value, (int, float)):
|
||||
return bool(value)
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def _normalize_model_id(value) -> Optional[int]:
|
||||
try:
|
||||
if value is None:
|
||||
return None
|
||||
return int(value)
|
||||
except (TypeError, ValueError):
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _serialize_record(record) -> Dict:
|
||||
return {
|
||||
"modelType": record.model_type,
|
||||
"modelId": record.model_id,
|
||||
"largestVersionId": record.largest_version_id,
|
||||
"versionIds": record.version_ids,
|
||||
"inLibraryVersionIds": record.in_library_version_ids,
|
||||
"lastCheckedAt": record.last_checked_at,
|
||||
"shouldIgnore": record.should_ignore,
|
||||
"hasUpdate": record.has_update(),
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelHandlerSet:
|
||||
"""Aggregate concrete handlers into a flat mapping."""
|
||||
@@ -1029,6 +1180,7 @@ class ModelHandlerSet:
|
||||
civitai: ModelCivitaiHandler
|
||||
move: ModelMoveHandler
|
||||
auto_organize: ModelAutoOrganizeHandler
|
||||
updates: ModelUpdateHandler
|
||||
|
||||
def to_route_mapping(self) -> Dict[str, Callable[[web.Request], Awaitable[web.Response]]]:
|
||||
return {
|
||||
@@ -1073,5 +1225,8 @@ class ModelHandlerSet:
|
||||
"get_model_metadata": self.query.get_model_metadata,
|
||||
"get_model_description": self.query.get_model_description,
|
||||
"get_relative_paths": self.query.get_relative_paths,
|
||||
"refresh_model_updates": self.updates.refresh_model_updates,
|
||||
"set_model_update_ignore": self.updates.set_model_update_ignore,
|
||||
"get_model_update_status": self.updates.get_model_update_status,
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user