mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-25 15:15:44 -03:00
feat: add model version management endpoints
- Add set_version_update_ignore endpoint to toggle ignore status for specific versions - Add get_model_versions endpoint to retrieve version details with optional refresh - Update serialization to include version-specific data and preview overrides - Modify database schema to support version-level ignore tracking - Improve error handling for rate limiting and missing models These changes enable granular control over version updates and provide better visibility into model version status.
This commit is contained in:
@@ -1085,6 +1085,28 @@ class ModelUpdateHandler:
|
||||
)
|
||||
return web.json_response({"success": True, "record": self._serialize_record(record)})
|
||||
|
||||
async def set_version_update_ignore(self, request: web.Request) -> web.Response:
|
||||
payload = await self._read_json(request)
|
||||
model_id = self._normalize_model_id(payload.get("modelId"))
|
||||
version_id = self._normalize_model_id(payload.get("versionId"))
|
||||
if model_id is None or version_id is None:
|
||||
return web.json_response(
|
||||
{"success": False, "error": "modelId and versionId are required"},
|
||||
status=400,
|
||||
)
|
||||
|
||||
should_ignore = self._parse_bool(payload.get("shouldIgnore"))
|
||||
record = await self._update_service.set_version_should_ignore(
|
||||
self._service.model_type,
|
||||
model_id,
|
||||
version_id,
|
||||
should_ignore,
|
||||
)
|
||||
overrides = await self._build_preview_overrides(record)
|
||||
return web.json_response(
|
||||
{"success": True, "record": self._serialize_record(record, preview_overrides=overrides)}
|
||||
)
|
||||
|
||||
async def get_model_update_status(self, request: web.Request) -> web.Response:
|
||||
model_id = self._normalize_model_id(request.match_info.get("model_id"))
|
||||
if model_id is None:
|
||||
@@ -1107,6 +1129,33 @@ class ModelUpdateHandler:
|
||||
|
||||
return web.json_response({"success": True, "record": self._serialize_record(record)})
|
||||
|
||||
async def get_model_versions(self, request: web.Request) -> web.Response:
|
||||
model_id = self._normalize_model_id(request.match_info.get("model_id"))
|
||||
if model_id is None:
|
||||
return web.json_response(
|
||||
{"success": False, "error": "model_id must be an integer"}, status=400
|
||||
)
|
||||
|
||||
refresh = self._parse_bool(request.query.get("refresh"))
|
||||
force = self._parse_bool(request.query.get("force"))
|
||||
|
||||
try:
|
||||
record = await self._get_or_refresh_record(model_id, refresh=refresh, force=force)
|
||||
except RateLimitError as exc:
|
||||
return web.json_response(
|
||||
{"success": False, "error": str(exc) or "Rate limited"}, status=429
|
||||
)
|
||||
|
||||
if record is None:
|
||||
return web.json_response(
|
||||
{"success": False, "error": "Model not tracked"}, status=404
|
||||
)
|
||||
|
||||
overrides = await self._build_preview_overrides(record)
|
||||
return web.json_response(
|
||||
{"success": True, "record": self._serialize_record(record, preview_overrides=overrides)}
|
||||
)
|
||||
|
||||
async def _get_or_refresh_record(
|
||||
self, model_id: int, *, refresh: bool, force: bool
|
||||
) -> Optional[object]:
|
||||
@@ -1160,8 +1209,13 @@ class ModelUpdateHandler:
|
||||
except (TypeError, ValueError):
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _serialize_record(record) -> Dict:
|
||||
def _serialize_record(
|
||||
self,
|
||||
record,
|
||||
*,
|
||||
preview_overrides: Optional[Dict[int, Optional[str]]] = None,
|
||||
) -> Dict:
|
||||
overrides = preview_overrides or {}
|
||||
return {
|
||||
"modelType": record.model_type,
|
||||
"modelId": record.model_id,
|
||||
@@ -1169,10 +1223,50 @@ class ModelUpdateHandler:
|
||||
"versionIds": record.version_ids,
|
||||
"inLibraryVersionIds": record.in_library_version_ids,
|
||||
"lastCheckedAt": record.last_checked_at,
|
||||
"shouldIgnore": record.should_ignore,
|
||||
"shouldIgnore": record.should_ignore_model,
|
||||
"hasUpdate": record.has_update(),
|
||||
"versions": [
|
||||
self._serialize_version(version, overrides.get(version.version_id))
|
||||
for version in record.versions
|
||||
],
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _serialize_version(version, override_preview: Optional[str]) -> Dict:
|
||||
preview_url = override_preview if override_preview is not None else version.preview_url
|
||||
return {
|
||||
"versionId": version.version_id,
|
||||
"name": version.name,
|
||||
"baseModel": version.base_model,
|
||||
"releasedAt": version.released_at,
|
||||
"sizeBytes": version.size_bytes,
|
||||
"previewUrl": preview_url,
|
||||
"isInLibrary": version.is_in_library,
|
||||
"shouldIgnore": version.should_ignore,
|
||||
}
|
||||
|
||||
async def _build_preview_overrides(self, record) -> Dict[int, Optional[str]]:
|
||||
overrides: Dict[int, Optional[str]] = {}
|
||||
try:
|
||||
cache = await self._service.scanner.get_cached_data()
|
||||
except Exception as exc: # pragma: no cover - defensive logging
|
||||
self._logger.debug("Failed to load cache while building preview overrides: %s", exc)
|
||||
return overrides
|
||||
|
||||
version_index = getattr(cache, "version_index", None)
|
||||
if not version_index:
|
||||
return overrides
|
||||
|
||||
for version in record.versions:
|
||||
if not version.is_in_library:
|
||||
continue
|
||||
cache_entry = version_index.get(version.version_id)
|
||||
if isinstance(cache_entry, Mapping):
|
||||
preview = cache_entry.get("preview_url")
|
||||
if isinstance(preview, str) and preview:
|
||||
overrides[version.version_id] = config.get_preview_static_url(preview)
|
||||
return overrides
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelHandlerSet:
|
||||
@@ -1233,6 +1327,8 @@ class ModelHandlerSet:
|
||||
"get_relative_paths": self.query.get_relative_paths,
|
||||
"refresh_model_updates": self.updates.refresh_model_updates,
|
||||
"set_model_update_ignore": self.updates.set_model_update_ignore,
|
||||
"set_version_update_ignore": self.updates.set_version_update_ignore,
|
||||
"get_model_update_status": self.updates.get_model_update_status,
|
||||
"get_model_versions": self.updates.get_model_versions,
|
||||
}
|
||||
|
||||
|
||||
@@ -57,7 +57,9 @@ COMMON_ROUTE_DEFINITIONS: tuple[RouteDefinition, ...] = (
|
||||
RouteDefinition("GET", "/api/lm/{prefix}/civitai/model/hash/{hash}", "get_civitai_model_by_hash"),
|
||||
RouteDefinition("POST", "/api/lm/{prefix}/updates/refresh", "refresh_model_updates"),
|
||||
RouteDefinition("POST", "/api/lm/{prefix}/updates/ignore", "set_model_update_ignore"),
|
||||
RouteDefinition("POST", "/api/lm/{prefix}/updates/ignore-version", "set_version_update_ignore"),
|
||||
RouteDefinition("GET", "/api/lm/{prefix}/updates/status/{model_id}", "get_model_update_status"),
|
||||
RouteDefinition("GET", "/api/lm/{prefix}/updates/versions/{model_id}", "get_model_versions"),
|
||||
RouteDefinition("POST", "/api/lm/download-model", "download_model"),
|
||||
RouteDefinition("GET", "/api/lm/download-model-get", "download_model_get"),
|
||||
RouteDefinition("GET", "/api/lm/cancel-download-get", "cancel_download_get"),
|
||||
|
||||
Reference in New Issue
Block a user