mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-21 21:22:11 -03:00
feat(context-menu): refresh missing license metadata
This commit is contained in:
@@ -6,7 +6,7 @@ import json
|
||||
import logging
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import Awaitable, Callable, Dict, Iterable, List, Mapping, Optional
|
||||
from typing import Any, Awaitable, Callable, Dict, Iterable, List, Mapping, Optional
|
||||
|
||||
from aiohttp import web
|
||||
import jinja2
|
||||
@@ -30,9 +30,17 @@ from ...services.use_cases import (
|
||||
from ...services.websocket_manager import WebSocketManager
|
||||
from ...services.websocket_progress_callback import WebSocketProgressCallback
|
||||
from ...services.errors import RateLimitError, ResourceNotFoundError
|
||||
from ...utils.civitai_utils import resolve_license_payload
|
||||
from ...utils.file_utils import calculate_sha256
|
||||
from ...utils.metadata_manager import MetadataManager
|
||||
|
||||
LICENSE_FIELDS = (
|
||||
"allowNoCredit",
|
||||
"allowCommercialUse",
|
||||
"allowDerivatives",
|
||||
"allowDifferentLicense",
|
||||
)
|
||||
|
||||
|
||||
class ModelPageView:
|
||||
"""Render the HTML view for model listings."""
|
||||
@@ -1083,6 +1091,77 @@ class ModelUpdateHandler:
|
||||
self._metadata_provider_selector = metadata_provider_selector
|
||||
self._logger = logger
|
||||
|
||||
async def fetch_missing_civitai_license_data(self, request: web.Request) -> web.Response:
|
||||
payload = await self._read_json(request)
|
||||
target_model_ids = self._extract_target_model_ids(payload)
|
||||
|
||||
provider = await self._get_civitai_provider()
|
||||
if provider is None:
|
||||
return web.json_response(
|
||||
{"success": False, "error": "Civitai provider not available"},
|
||||
status=503,
|
||||
)
|
||||
|
||||
try:
|
||||
cache = await self._service.scanner.get_cached_data()
|
||||
except Exception as exc:
|
||||
self._logger.error("Failed to load cache for license refresh: %s", exc, exc_info=True)
|
||||
cache = None
|
||||
|
||||
target_set = set(target_model_ids) if target_model_ids is not None else None
|
||||
candidates = await self._collect_models_missing_license(cache, target_set)
|
||||
if not candidates:
|
||||
return web.json_response({"success": True, "updated": []})
|
||||
|
||||
model_ids = sorted(candidates.keys())
|
||||
try:
|
||||
license_map = await self._fetch_license_info(provider, model_ids)
|
||||
except RateLimitError as exc:
|
||||
return web.json_response(
|
||||
{"success": False, "error": str(exc) or "Rate limited"},
|
||||
status=429,
|
||||
)
|
||||
except Exception as exc: # pragma: no cover - defensive log
|
||||
self._logger.error("Failed to fetch license info: %s", exc, exc_info=True)
|
||||
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
||||
|
||||
updated: List[Dict[str, str]] = []
|
||||
errors: List[Dict[str, str]] = []
|
||||
for model_id in model_ids:
|
||||
license_payload = license_map.get(model_id)
|
||||
if not license_payload:
|
||||
continue
|
||||
resolved_payload = resolve_license_payload(license_payload)
|
||||
for context in candidates.get(model_id, []):
|
||||
metadata_path = context["file_path"]
|
||||
metadata_payload = context["metadata"]
|
||||
civitai_section = metadata_payload.setdefault("civitai", {})
|
||||
model_section = civitai_section.get("model")
|
||||
if not isinstance(model_section, Mapping):
|
||||
model_section = {}
|
||||
model_section.update(resolved_payload)
|
||||
civitai_section["model"] = model_section
|
||||
metadata_payload["civitai"] = civitai_section
|
||||
try:
|
||||
await MetadataManager.save_metadata(metadata_path, metadata_payload)
|
||||
updated.append({"modelId": model_id, "filePath": metadata_path})
|
||||
except Exception as exc:
|
||||
self._logger.error(
|
||||
"Failed to save metadata for %s: %s",
|
||||
metadata_path,
|
||||
exc,
|
||||
exc_info=True,
|
||||
)
|
||||
errors.append({"filePath": metadata_path, "error": str(exc)})
|
||||
|
||||
response_payload = {"success": True, "updated": updated}
|
||||
missing_model_ids = [mid for mid in model_ids if mid not in license_map]
|
||||
if missing_model_ids:
|
||||
response_payload["missingModelIds"] = missing_model_ids
|
||||
if errors:
|
||||
response_payload["errors"] = errors
|
||||
return web.json_response(response_payload)
|
||||
|
||||
async def refresh_model_updates(self, request: web.Request) -> web.Response:
|
||||
payload = await self._read_json(request)
|
||||
force_refresh = self._parse_bool(request.query.get("force")) or self._parse_bool(
|
||||
@@ -1247,6 +1326,128 @@ class ModelUpdateHandler:
|
||||
self._logger.error("Failed to acquire civitai provider: %s", exc, exc_info=True)
|
||||
return None
|
||||
|
||||
async def _collect_models_missing_license(
|
||||
self,
|
||||
cache,
|
||||
target_model_ids: Optional[set[int]],
|
||||
) -> Dict[int, List[Dict[str, Any]]]:
|
||||
entries: Dict[int, List[Dict[str, Any]]] = {}
|
||||
if cache is None:
|
||||
return entries
|
||||
|
||||
raw_data = getattr(cache, "raw_data", None) or []
|
||||
seen_paths: set[str] = set()
|
||||
target_set = target_model_ids
|
||||
|
||||
for item in raw_data:
|
||||
if not isinstance(item, Mapping):
|
||||
continue
|
||||
file_path = item.get("file_path")
|
||||
if not isinstance(file_path, str) or not file_path or file_path in seen_paths:
|
||||
continue
|
||||
seen_paths.add(file_path)
|
||||
|
||||
civitai_entry = item.get("civitai")
|
||||
if not isinstance(civitai_entry, Mapping):
|
||||
continue
|
||||
|
||||
model_id = self._normalize_model_id(civitai_entry.get("modelId"))
|
||||
if model_id is None:
|
||||
continue
|
||||
if target_set is not None and model_id not in target_set:
|
||||
continue
|
||||
|
||||
try:
|
||||
metadata_obj, should_skip = await MetadataManager.load_metadata(file_path)
|
||||
except Exception as exc:
|
||||
self._logger.debug("Failed to load metadata for %s: %s", file_path, exc)
|
||||
continue
|
||||
if metadata_obj is None or should_skip:
|
||||
continue
|
||||
|
||||
metadata_payload = self._convert_metadata_to_dict(metadata_obj)
|
||||
civitai_payload = metadata_payload.get("civitai")
|
||||
if not isinstance(civitai_payload, Mapping):
|
||||
civitai_payload = {}
|
||||
|
||||
model_payload = civitai_payload.get("model")
|
||||
if not isinstance(model_payload, Mapping):
|
||||
model_payload = {}
|
||||
|
||||
missing = [key for key in LICENSE_FIELDS if key not in model_payload]
|
||||
if not missing:
|
||||
continue
|
||||
|
||||
civitai_payload["model"] = model_payload
|
||||
metadata_payload["civitai"] = civitai_payload
|
||||
entries.setdefault(model_id, []).append(
|
||||
{"file_path": file_path, "metadata": metadata_payload}
|
||||
)
|
||||
|
||||
return entries
|
||||
|
||||
async def _fetch_license_info(
|
||||
self,
|
||||
provider,
|
||||
model_ids: List[int],
|
||||
) -> Dict[int, Dict[str, Any]]:
|
||||
if not model_ids:
|
||||
return {}
|
||||
|
||||
response = await provider.get_model_versions_bulk(model_ids)
|
||||
if not isinstance(response, Mapping):
|
||||
return {}
|
||||
|
||||
license_map: Dict[int, Dict[str, Any]] = {}
|
||||
for raw_id, payload in response.items():
|
||||
normalized_id = self._normalize_model_id(raw_id)
|
||||
if normalized_id is None or not isinstance(payload, Mapping):
|
||||
continue
|
||||
license_data: Dict[str, Any] = {}
|
||||
for field in LICENSE_FIELDS:
|
||||
license_data[field] = payload.get(field)
|
||||
license_map[normalized_id] = license_data
|
||||
return license_map
|
||||
|
||||
def _extract_target_model_ids(self, payload: Dict) -> Optional[List[int]]:
|
||||
if not isinstance(payload, Mapping):
|
||||
return None
|
||||
|
||||
raw_ids = payload.get("modelIds")
|
||||
if raw_ids is None:
|
||||
raw_ids = payload.get("model_ids")
|
||||
|
||||
if not isinstance(raw_ids, (list, tuple, set)):
|
||||
return None
|
||||
|
||||
normalized: List[int] = []
|
||||
for candidate in raw_ids:
|
||||
model_id = self._normalize_model_id(candidate)
|
||||
if model_id is not None:
|
||||
normalized.append(model_id)
|
||||
|
||||
if not normalized:
|
||||
return None
|
||||
|
||||
return sorted(set(normalized))
|
||||
|
||||
@staticmethod
|
||||
def _convert_metadata_to_dict(metadata: Any) -> Dict[str, Any]:
|
||||
if metadata is None:
|
||||
return {}
|
||||
|
||||
to_dict = getattr(metadata, "to_dict", None)
|
||||
if callable(to_dict):
|
||||
try:
|
||||
return to_dict()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if isinstance(metadata, Mapping):
|
||||
return dict(metadata)
|
||||
|
||||
return {}
|
||||
|
||||
async def _read_json(self, request: web.Request) -> Dict:
|
||||
if not request.can_read_body:
|
||||
return {}
|
||||
@@ -1401,6 +1602,7 @@ class ModelHandlerSet:
|
||||
"get_model_description": self.query.get_model_description,
|
||||
"get_relative_paths": self.query.get_relative_paths,
|
||||
"refresh_model_updates": self.updates.refresh_model_updates,
|
||||
"fetch_missing_civitai_license_data": self.updates.fetch_missing_civitai_license_data,
|
||||
"set_model_update_ignore": self.updates.set_model_update_ignore,
|
||||
"set_version_update_ignore": self.updates.set_version_update_ignore,
|
||||
"get_model_update_status": self.updates.get_model_update_status,
|
||||
|
||||
@@ -56,6 +56,7 @@ COMMON_ROUTE_DEFINITIONS: tuple[RouteDefinition, ...] = (
|
||||
RouteDefinition("GET", "/api/lm/{prefix}/civitai/model/version/{modelVersionId}", "get_civitai_model_by_version"),
|
||||
RouteDefinition("GET", "/api/lm/{prefix}/civitai/model/hash/{hash}", "get_civitai_model_by_hash"),
|
||||
RouteDefinition("POST", "/api/lm/{prefix}/updates/refresh", "refresh_model_updates"),
|
||||
RouteDefinition("POST", "/api/lm/{prefix}/updates/fetch-missing-license", "fetch_missing_civitai_license_data"),
|
||||
RouteDefinition("POST", "/api/lm/{prefix}/updates/ignore", "set_model_update_ignore"),
|
||||
RouteDefinition("POST", "/api/lm/{prefix}/updates/ignore-version", "set_version_update_ignore"),
|
||||
RouteDefinition("GET", "/api/lm/{prefix}/updates/status/{model_id}", "get_model_update_status"),
|
||||
@@ -103,4 +104,3 @@ class ModelRouteRegistrar:
|
||||
add_method_name = self._METHOD_MAP[method.upper()]
|
||||
add_method = getattr(self._app.router, add_method_name)
|
||||
add_method(path, handler)
|
||||
|
||||
|
||||
@@ -246,6 +246,10 @@ class CivitaiClient:
|
||||
'modelVersions': item.get('modelVersions', []),
|
||||
'type': item.get('type', ''),
|
||||
'name': item.get('name', ''),
|
||||
'allowNoCredit': item.get('allowNoCredit'),
|
||||
'allowCommercialUse': item.get('allowCommercialUse'),
|
||||
'allowDerivatives': item.get('allowDerivatives'),
|
||||
'allowDifferentLicense': item.get('allowDifferentLicense'),
|
||||
}
|
||||
return payload
|
||||
except RateLimitError:
|
||||
|
||||
Reference in New Issue
Block a user