feat(context-menu): refresh missing license metadata

This commit is contained in:
Will Miao
2025-11-11 14:24:59 +08:00
parent 4557da8b63
commit 29bb85359e
20 changed files with 633 additions and 10 deletions

View File

@@ -6,7 +6,7 @@ import json
import logging
import os
from dataclasses import dataclass
from typing import Awaitable, Callable, Dict, Iterable, List, Mapping, Optional
from typing import Any, Awaitable, Callable, Dict, Iterable, List, Mapping, Optional
from aiohttp import web
import jinja2
@@ -30,9 +30,17 @@ from ...services.use_cases import (
from ...services.websocket_manager import WebSocketManager
from ...services.websocket_progress_callback import WebSocketProgressCallback
from ...services.errors import RateLimitError, ResourceNotFoundError
from ...utils.civitai_utils import resolve_license_payload
from ...utils.file_utils import calculate_sha256
from ...utils.metadata_manager import MetadataManager
LICENSE_FIELDS = (
"allowNoCredit",
"allowCommercialUse",
"allowDerivatives",
"allowDifferentLicense",
)
class ModelPageView:
"""Render the HTML view for model listings."""
@@ -1083,6 +1091,77 @@ class ModelUpdateHandler:
self._metadata_provider_selector = metadata_provider_selector
self._logger = logger
async def fetch_missing_civitai_license_data(self, request: web.Request) -> web.Response:
payload = await self._read_json(request)
target_model_ids = self._extract_target_model_ids(payload)
provider = await self._get_civitai_provider()
if provider is None:
return web.json_response(
{"success": False, "error": "Civitai provider not available"},
status=503,
)
try:
cache = await self._service.scanner.get_cached_data()
except Exception as exc:
self._logger.error("Failed to load cache for license refresh: %s", exc, exc_info=True)
cache = None
target_set = set(target_model_ids) if target_model_ids is not None else None
candidates = await self._collect_models_missing_license(cache, target_set)
if not candidates:
return web.json_response({"success": True, "updated": []})
model_ids = sorted(candidates.keys())
try:
license_map = await self._fetch_license_info(provider, model_ids)
except RateLimitError as exc:
return web.json_response(
{"success": False, "error": str(exc) or "Rate limited"},
status=429,
)
except Exception as exc: # pragma: no cover - defensive log
self._logger.error("Failed to fetch license info: %s", exc, exc_info=True)
return web.json_response({"success": False, "error": str(exc)}, status=500)
updated: List[Dict[str, str]] = []
errors: List[Dict[str, str]] = []
for model_id in model_ids:
license_payload = license_map.get(model_id)
if not license_payload:
continue
resolved_payload = resolve_license_payload(license_payload)
for context in candidates.get(model_id, []):
metadata_path = context["file_path"]
metadata_payload = context["metadata"]
civitai_section = metadata_payload.setdefault("civitai", {})
model_section = civitai_section.get("model")
if not isinstance(model_section, Mapping):
model_section = {}
model_section.update(resolved_payload)
civitai_section["model"] = model_section
metadata_payload["civitai"] = civitai_section
try:
await MetadataManager.save_metadata(metadata_path, metadata_payload)
updated.append({"modelId": model_id, "filePath": metadata_path})
except Exception as exc:
self._logger.error(
"Failed to save metadata for %s: %s",
metadata_path,
exc,
exc_info=True,
)
errors.append({"filePath": metadata_path, "error": str(exc)})
response_payload = {"success": True, "updated": updated}
missing_model_ids = [mid for mid in model_ids if mid not in license_map]
if missing_model_ids:
response_payload["missingModelIds"] = missing_model_ids
if errors:
response_payload["errors"] = errors
return web.json_response(response_payload)
async def refresh_model_updates(self, request: web.Request) -> web.Response:
payload = await self._read_json(request)
force_refresh = self._parse_bool(request.query.get("force")) or self._parse_bool(
@@ -1247,6 +1326,128 @@ class ModelUpdateHandler:
self._logger.error("Failed to acquire civitai provider: %s", exc, exc_info=True)
return None
async def _collect_models_missing_license(
self,
cache,
target_model_ids: Optional[set[int]],
) -> Dict[int, List[Dict[str, Any]]]:
entries: Dict[int, List[Dict[str, Any]]] = {}
if cache is None:
return entries
raw_data = getattr(cache, "raw_data", None) or []
seen_paths: set[str] = set()
target_set = target_model_ids
for item in raw_data:
if not isinstance(item, Mapping):
continue
file_path = item.get("file_path")
if not isinstance(file_path, str) or not file_path or file_path in seen_paths:
continue
seen_paths.add(file_path)
civitai_entry = item.get("civitai")
if not isinstance(civitai_entry, Mapping):
continue
model_id = self._normalize_model_id(civitai_entry.get("modelId"))
if model_id is None:
continue
if target_set is not None and model_id not in target_set:
continue
try:
metadata_obj, should_skip = await MetadataManager.load_metadata(file_path)
except Exception as exc:
self._logger.debug("Failed to load metadata for %s: %s", file_path, exc)
continue
if metadata_obj is None or should_skip:
continue
metadata_payload = self._convert_metadata_to_dict(metadata_obj)
civitai_payload = metadata_payload.get("civitai")
if not isinstance(civitai_payload, Mapping):
civitai_payload = {}
model_payload = civitai_payload.get("model")
if not isinstance(model_payload, Mapping):
model_payload = {}
missing = [key for key in LICENSE_FIELDS if key not in model_payload]
if not missing:
continue
civitai_payload["model"] = model_payload
metadata_payload["civitai"] = civitai_payload
entries.setdefault(model_id, []).append(
{"file_path": file_path, "metadata": metadata_payload}
)
return entries
async def _fetch_license_info(
self,
provider,
model_ids: List[int],
) -> Dict[int, Dict[str, Any]]:
if not model_ids:
return {}
response = await provider.get_model_versions_bulk(model_ids)
if not isinstance(response, Mapping):
return {}
license_map: Dict[int, Dict[str, Any]] = {}
for raw_id, payload in response.items():
normalized_id = self._normalize_model_id(raw_id)
if normalized_id is None or not isinstance(payload, Mapping):
continue
license_data: Dict[str, Any] = {}
for field in LICENSE_FIELDS:
license_data[field] = payload.get(field)
license_map[normalized_id] = license_data
return license_map
def _extract_target_model_ids(self, payload: Dict) -> Optional[List[int]]:
if not isinstance(payload, Mapping):
return None
raw_ids = payload.get("modelIds")
if raw_ids is None:
raw_ids = payload.get("model_ids")
if not isinstance(raw_ids, (list, tuple, set)):
return None
normalized: List[int] = []
for candidate in raw_ids:
model_id = self._normalize_model_id(candidate)
if model_id is not None:
normalized.append(model_id)
if not normalized:
return None
return sorted(set(normalized))
@staticmethod
def _convert_metadata_to_dict(metadata: Any) -> Dict[str, Any]:
if metadata is None:
return {}
to_dict = getattr(metadata, "to_dict", None)
if callable(to_dict):
try:
return to_dict()
except Exception:
pass
if isinstance(metadata, Mapping):
return dict(metadata)
return {}
async def _read_json(self, request: web.Request) -> Dict:
if not request.can_read_body:
return {}
@@ -1401,6 +1602,7 @@ class ModelHandlerSet:
"get_model_description": self.query.get_model_description,
"get_relative_paths": self.query.get_relative_paths,
"refresh_model_updates": self.updates.refresh_model_updates,
"fetch_missing_civitai_license_data": self.updates.fetch_missing_civitai_license_data,
"set_model_update_ignore": self.updates.set_model_update_ignore,
"set_version_update_ignore": self.updates.set_version_update_ignore,
"get_model_update_status": self.updates.get_model_update_status,

View File

@@ -56,6 +56,7 @@ COMMON_ROUTE_DEFINITIONS: tuple[RouteDefinition, ...] = (
RouteDefinition("GET", "/api/lm/{prefix}/civitai/model/version/{modelVersionId}", "get_civitai_model_by_version"),
RouteDefinition("GET", "/api/lm/{prefix}/civitai/model/hash/{hash}", "get_civitai_model_by_hash"),
RouteDefinition("POST", "/api/lm/{prefix}/updates/refresh", "refresh_model_updates"),
RouteDefinition("POST", "/api/lm/{prefix}/updates/fetch-missing-license", "fetch_missing_civitai_license_data"),
RouteDefinition("POST", "/api/lm/{prefix}/updates/ignore", "set_model_update_ignore"),
RouteDefinition("POST", "/api/lm/{prefix}/updates/ignore-version", "set_version_update_ignore"),
RouteDefinition("GET", "/api/lm/{prefix}/updates/status/{model_id}", "get_model_update_status"),
@@ -103,4 +104,3 @@ class ModelRouteRegistrar:
add_method_name = self._METHOD_MAP[method.upper()]
add_method = getattr(self._app.router, add_method_name)
add_method(path, handler)

View File

@@ -246,6 +246,10 @@ class CivitaiClient:
'modelVersions': item.get('modelVersions', []),
'type': item.get('type', ''),
'name': item.get('name', ''),
'allowNoCredit': item.get('allowNoCredit'),
'allowCommercialUse': item.get('allowCommercialUse'),
'allowDerivatives': item.get('allowDerivatives'),
'allowDifferentLicense': item.get('allowDifferentLicense'),
}
return payload
except RateLimitError: