diff --git a/locales/de.json b/locales/de.json index 9e683df9..e2b0bcf5 100644 --- a/locales/de.json +++ b/locales/de.json @@ -152,6 +152,13 @@ "none": "Keine Beispielbild-Ordner mussten bereinigt werden", "partial": "Bereinigung abgeschlossen, {failures} Ordner übersprungen", "error": "Fehler beim Bereinigen der Beispielbild-Ordner: {message}" + }, + "fetchMissingLicenses": { + "label": "Refresh license metadata", + "loading": "Refreshing license metadata for {typePlural}...", + "success": "Updated license metadata for {count} {typePlural}", + "none": "All {typePlural} already have license metadata", + "error": "Failed to refresh license metadata for {typePlural}: {message}" } }, "header": { diff --git a/locales/en.json b/locales/en.json index c53e6b3d..5298e2e8 100644 --- a/locales/en.json +++ b/locales/en.json @@ -152,6 +152,13 @@ "none": "No example image folders needed cleanup", "partial": "Cleanup completed with {failures} folder(s) skipped", "error": "Failed to clean example image folders: {message}" + }, + "fetchMissingLicenses": { + "label": "Refresh license metadata", + "loading": "Refreshing license metadata for {typePlural}...", + "success": "Updated license metadata for {count} {typePlural}", + "none": "All {typePlural} already have license metadata", + "error": "Failed to refresh license metadata for {typePlural}: {message}" } }, "header": { diff --git a/locales/es.json b/locales/es.json index 29f28544..9387ff48 100644 --- a/locales/es.json +++ b/locales/es.json @@ -152,6 +152,13 @@ "none": "No hay carpetas de imágenes de ejemplo que necesiten limpieza", "partial": "Limpieza completada con {failures} carpeta(s) omitidas", "error": "No se pudieron limpiar las carpetas de imágenes de ejemplo: {message}" + }, + "fetchMissingLicenses": { + "label": "Refresh license metadata", + "loading": "Refreshing license metadata for {typePlural}...", + "success": "Updated license metadata for {count} {typePlural}", + "none": "All {typePlural} already have license metadata", + "error": "Failed to refresh license metadata for {typePlural}: {message}" } }, "header": { diff --git a/locales/fr.json b/locales/fr.json index 7af2d2a2..d8448828 100644 --- a/locales/fr.json +++ b/locales/fr.json @@ -152,6 +152,13 @@ "none": "Aucun dossier d'images d'exemple à nettoyer", "partial": "Nettoyage terminé avec {failures} dossier(s) ignoré(s)", "error": "Échec du nettoyage des dossiers d'images d'exemple : {message}" + }, + "fetchMissingLicenses": { + "label": "Refresh license metadata", + "loading": "Refreshing license metadata for {typePlural}...", + "success": "Updated license metadata for {count} {typePlural}", + "none": "All {typePlural} already have license metadata", + "error": "Failed to refresh license metadata for {typePlural}: {message}" } }, "header": { diff --git a/locales/he.json b/locales/he.json index e9a6f43d..9afc8b76 100644 --- a/locales/he.json +++ b/locales/he.json @@ -152,6 +152,13 @@ "none": "אין תיקיות תמונות דוגמה שזקוקות לניקוי", "partial": "הניקוי הושלם עם דילוג על {failures} תיקיות", "error": "ניקוי תיקיות תמונות הדוגמה נכשל: {message}" + }, + "fetchMissingLicenses": { + "label": "Refresh license metadata", + "loading": "Refreshing license metadata for {typePlural}...", + "success": "Updated license metadata for {count} {typePlural}", + "none": "All {typePlural} already have license metadata", + "error": "Failed to refresh license metadata for {typePlural}: {message}" } }, "header": { diff --git a/locales/ja.json b/locales/ja.json index 182ec174..d2fdc676 100644 --- a/locales/ja.json +++ b/locales/ja.json @@ -152,6 +152,13 @@ "none": "クリーンアップが必要な例画像フォルダはありません", "partial": "クリーンアップが完了しましたが、{failures} 個のフォルダはスキップされました", "error": "例画像フォルダのクリーンアップに失敗しました:{message}" + }, + "fetchMissingLicenses": { + "label": "Refresh license metadata", + "loading": "Refreshing license metadata for {typePlural}...", + "success": "Updated license metadata for {count} {typePlural}", + "none": "All {typePlural} already have license metadata", + "error": "Failed to refresh license metadata for {typePlural}: {message}" } }, "header": { diff --git a/locales/ko.json b/locales/ko.json index 24b8ba45..0fa01db3 100644 --- a/locales/ko.json +++ b/locales/ko.json @@ -152,6 +152,13 @@ "none": "정리가 필요한 예시 이미지 폴더가 없습니다", "partial": "정리가 완료되었으나 {failures}개의 폴더가 건너뛰어졌습니다", "error": "예시 이미지 폴더 정리에 실패했습니다: {message}" + }, + "fetchMissingLicenses": { + "label": "Refresh license metadata", + "loading": "Refreshing license metadata for {typePlural}...", + "success": "Updated license metadata for {count} {typePlural}", + "none": "All {typePlural} already have license metadata", + "error": "Failed to refresh license metadata for {typePlural}: {message}" } }, "header": { diff --git a/locales/ru.json b/locales/ru.json index 9630fc2b..6e6e2615 100644 --- a/locales/ru.json +++ b/locales/ru.json @@ -152,6 +152,13 @@ "none": "Нет папок с примерами изображений, требующих очистки", "partial": "Очистка завершена, пропущено {failures} папок", "error": "Не удалось очистить папки с примерами изображений: {message}" + }, + "fetchMissingLicenses": { + "label": "Refresh license metadata", + "loading": "Refreshing license metadata for {typePlural}...", + "success": "Updated license metadata for {count} {typePlural}", + "none": "All {typePlural} already have license metadata", + "error": "Failed to refresh license metadata for {typePlural}: {message}" } }, "header": { diff --git a/locales/zh-CN.json b/locales/zh-CN.json index cea7096f..29e56f35 100644 --- a/locales/zh-CN.json +++ b/locales/zh-CN.json @@ -152,6 +152,13 @@ "none": "没有需要清理的示例图片文件夹", "partial": "清理完成,有 {failures} 个文件夹跳过", "error": "清理示例图片文件夹失败:{message}" + }, + "fetchMissingLicenses": { + "label": "Refresh license metadata", + "loading": "Refreshing license metadata for {typePlural}...", + "success": "Updated license metadata for {count} {typePlural}", + "none": "All {typePlural} already have license metadata", + "error": "Failed to refresh license metadata for {typePlural}: {message}" } }, "header": { diff --git a/locales/zh-TW.json b/locales/zh-TW.json index 97820dcd..cc0b67a7 100644 --- a/locales/zh-TW.json +++ b/locales/zh-TW.json @@ -152,6 +152,13 @@ "none": "沒有需要清理的範例圖片資料夾", "partial": "清理完成,有 {failures} 個資料夾略過", "error": "清理範例圖片資料夾失敗:{message}" + }, + "fetchMissingLicenses": { + "label": "Refresh license metadata", + "loading": "Refreshing license metadata for {typePlural}...", + "success": "Updated license metadata for {count} {typePlural}", + "none": "All {typePlural} already have license metadata", + "error": "Failed to refresh license metadata for {typePlural}: {message}" } }, "header": { diff --git a/package-lock.json b/package-lock.json index 875f7d6a..62fa3fe9 100644 --- a/package-lock.json +++ b/package-lock.json @@ -114,6 +114,7 @@ } ], "license": "MIT", + "peer": true, "engines": { "node": ">=18" }, @@ -137,6 +138,7 @@ } ], "license": "MIT", + "peer": true, "engines": { "node": ">=18" } @@ -1611,6 +1613,7 @@ "integrity": "sha512-MyL55p3Ut3cXbeBEG7Hcv0mVM8pp8PBNWxRqchZnSfAiES1v1mRnMeFfaHWIPULpwsYfvO+ZmMZz5tGCnjzDUQ==", "dev": true, "license": "MIT", + "peer": true, "dependencies": { "cssstyle": "^4.0.1", "data-urls": "^5.0.0", diff --git a/py/routes/handlers/model_handlers.py b/py/routes/handlers/model_handlers.py index dd79e099..fee783dd 100644 --- a/py/routes/handlers/model_handlers.py +++ b/py/routes/handlers/model_handlers.py @@ -6,7 +6,7 @@ import json import logging import os from dataclasses import dataclass -from typing import Awaitable, Callable, Dict, Iterable, List, Mapping, Optional +from typing import Any, Awaitable, Callable, Dict, Iterable, List, Mapping, Optional from aiohttp import web import jinja2 @@ -30,9 +30,17 @@ from ...services.use_cases import ( from ...services.websocket_manager import WebSocketManager from ...services.websocket_progress_callback import WebSocketProgressCallback from ...services.errors import RateLimitError, ResourceNotFoundError +from ...utils.civitai_utils import resolve_license_payload from ...utils.file_utils import calculate_sha256 from ...utils.metadata_manager import MetadataManager +LICENSE_FIELDS = ( + "allowNoCredit", + "allowCommercialUse", + "allowDerivatives", + "allowDifferentLicense", +) + class ModelPageView: """Render the HTML view for model listings.""" @@ -1083,6 +1091,77 @@ class ModelUpdateHandler: self._metadata_provider_selector = metadata_provider_selector self._logger = logger + async def fetch_missing_civitai_license_data(self, request: web.Request) -> web.Response: + payload = await self._read_json(request) + target_model_ids = self._extract_target_model_ids(payload) + + provider = await self._get_civitai_provider() + if provider is None: + return web.json_response( + {"success": False, "error": "Civitai provider not available"}, + status=503, + ) + + try: + cache = await self._service.scanner.get_cached_data() + except Exception as exc: + self._logger.error("Failed to load cache for license refresh: %s", exc, exc_info=True) + cache = None + + target_set = set(target_model_ids) if target_model_ids is not None else None + candidates = await self._collect_models_missing_license(cache, target_set) + if not candidates: + return web.json_response({"success": True, "updated": []}) + + model_ids = sorted(candidates.keys()) + try: + license_map = await self._fetch_license_info(provider, model_ids) + except RateLimitError as exc: + return web.json_response( + {"success": False, "error": str(exc) or "Rate limited"}, + status=429, + ) + except Exception as exc: # pragma: no cover - defensive log + self._logger.error("Failed to fetch license info: %s", exc, exc_info=True) + return web.json_response({"success": False, "error": str(exc)}, status=500) + + updated: List[Dict[str, str]] = [] + errors: List[Dict[str, str]] = [] + for model_id in model_ids: + license_payload = license_map.get(model_id) + if not license_payload: + continue + resolved_payload = resolve_license_payload(license_payload) + for context in candidates.get(model_id, []): + metadata_path = context["file_path"] + metadata_payload = context["metadata"] + civitai_section = metadata_payload.setdefault("civitai", {}) + model_section = civitai_section.get("model") + if not isinstance(model_section, Mapping): + model_section = {} + model_section.update(resolved_payload) + civitai_section["model"] = model_section + metadata_payload["civitai"] = civitai_section + try: + await MetadataManager.save_metadata(metadata_path, metadata_payload) + updated.append({"modelId": model_id, "filePath": metadata_path}) + except Exception as exc: + self._logger.error( + "Failed to save metadata for %s: %s", + metadata_path, + exc, + exc_info=True, + ) + errors.append({"filePath": metadata_path, "error": str(exc)}) + + response_payload = {"success": True, "updated": updated} + missing_model_ids = [mid for mid in model_ids if mid not in license_map] + if missing_model_ids: + response_payload["missingModelIds"] = missing_model_ids + if errors: + response_payload["errors"] = errors + return web.json_response(response_payload) + async def refresh_model_updates(self, request: web.Request) -> web.Response: payload = await self._read_json(request) force_refresh = self._parse_bool(request.query.get("force")) or self._parse_bool( @@ -1247,6 +1326,132 @@ class ModelUpdateHandler: self._logger.error("Failed to acquire civitai provider: %s", exc, exc_info=True) return None + async def _collect_models_missing_license( + self, + cache, + target_model_ids: Optional[set[int]], + ) -> Dict[int, List[Dict[str, Any]]]: + entries: Dict[int, List[Dict[str, Any]]] = {} + if cache is None: + return entries + + raw_data = getattr(cache, "raw_data", None) or [] + seen_paths: set[str] = set() + target_set = target_model_ids + + for item in raw_data: + if not isinstance(item, Mapping): + continue + file_path = item.get("file_path") + if not isinstance(file_path, str) or not file_path or file_path in seen_paths: + continue + seen_paths.add(file_path) + + civitai_entry = item.get("civitai") + if not isinstance(civitai_entry, Mapping): + continue + + model_id = self._normalize_model_id(civitai_entry.get("modelId")) + if model_id is None: + continue + if target_set is not None and model_id not in target_set: + continue + + try: + metadata_obj, should_skip = await MetadataManager.load_metadata(file_path) + except Exception as exc: + self._logger.debug("Failed to load metadata for %s: %s", file_path, exc) + continue + if metadata_obj is None or should_skip: + continue + + metadata_payload = self._convert_metadata_to_dict(metadata_obj) + civitai_payload = metadata_payload.get("civitai") + if not isinstance(civitai_payload, Mapping): + civitai_payload = {} + + model_payload = civitai_payload.get("model") + if not isinstance(model_payload, Mapping): + model_payload = {} + + missing = [key for key in LICENSE_FIELDS if key not in model_payload] + if not missing: + continue + + civitai_payload["model"] = model_payload + metadata_payload["civitai"] = civitai_payload + entries.setdefault(model_id, []).append( + {"file_path": file_path, "metadata": metadata_payload} + ) + + return entries + + async def _fetch_license_info( + self, + provider, + model_ids: List[int], + ) -> Dict[int, Dict[str, Any]]: + if not model_ids: + return {} + + BATCH_SIZE = 100 + aggregated: Dict[int, Dict[str, Any]] = {} + for start in range(0, len(model_ids), BATCH_SIZE): + chunk = model_ids[start : start + BATCH_SIZE] + response = await provider.get_model_versions_bulk(chunk) + if not isinstance(response, Mapping): + continue + + for raw_id, payload in response.items(): + normalized_id = self._normalize_model_id(raw_id) + if normalized_id is None or not isinstance(payload, Mapping): + continue + license_data: Dict[str, Any] = {} + for field in LICENSE_FIELDS: + license_data[field] = payload.get(field) + aggregated[normalized_id] = license_data + + return aggregated + + def _extract_target_model_ids(self, payload: Dict) -> Optional[List[int]]: + if not isinstance(payload, Mapping): + return None + + raw_ids = payload.get("modelIds") + if raw_ids is None: + raw_ids = payload.get("model_ids") + + if not isinstance(raw_ids, (list, tuple, set)): + return None + + normalized: List[int] = [] + for candidate in raw_ids: + model_id = self._normalize_model_id(candidate) + if model_id is not None: + normalized.append(model_id) + + if not normalized: + return None + + return sorted(set(normalized)) + + @staticmethod + def _convert_metadata_to_dict(metadata: Any) -> Dict[str, Any]: + if metadata is None: + return {} + + to_dict = getattr(metadata, "to_dict", None) + if callable(to_dict): + try: + return to_dict() + except Exception: + pass + + if isinstance(metadata, Mapping): + return dict(metadata) + + return {} + async def _read_json(self, request: web.Request) -> Dict: if not request.can_read_body: return {} @@ -1401,6 +1606,7 @@ class ModelHandlerSet: "get_model_description": self.query.get_model_description, "get_relative_paths": self.query.get_relative_paths, "refresh_model_updates": self.updates.refresh_model_updates, + "fetch_missing_civitai_license_data": self.updates.fetch_missing_civitai_license_data, "set_model_update_ignore": self.updates.set_model_update_ignore, "set_version_update_ignore": self.updates.set_version_update_ignore, "get_model_update_status": self.updates.get_model_update_status, diff --git a/py/routes/model_route_registrar.py b/py/routes/model_route_registrar.py index 12b36850..ce7a75ba 100644 --- a/py/routes/model_route_registrar.py +++ b/py/routes/model_route_registrar.py @@ -56,6 +56,7 @@ COMMON_ROUTE_DEFINITIONS: tuple[RouteDefinition, ...] = ( RouteDefinition("GET", "/api/lm/{prefix}/civitai/model/version/{modelVersionId}", "get_civitai_model_by_version"), RouteDefinition("GET", "/api/lm/{prefix}/civitai/model/hash/{hash}", "get_civitai_model_by_hash"), RouteDefinition("POST", "/api/lm/{prefix}/updates/refresh", "refresh_model_updates"), + RouteDefinition("POST", "/api/lm/{prefix}/updates/fetch-missing-license", "fetch_missing_civitai_license_data"), RouteDefinition("POST", "/api/lm/{prefix}/updates/ignore", "set_model_update_ignore"), RouteDefinition("POST", "/api/lm/{prefix}/updates/ignore-version", "set_version_update_ignore"), RouteDefinition("GET", "/api/lm/{prefix}/updates/status/{model_id}", "get_model_update_status"), @@ -103,4 +104,3 @@ class ModelRouteRegistrar: add_method_name = self._METHOD_MAP[method.upper()] add_method = getattr(self._app.router, add_method_name) add_method(path, handler) - diff --git a/py/services/civitai_client.py b/py/services/civitai_client.py index a2ec8ed9..503988b7 100644 --- a/py/services/civitai_client.py +++ b/py/services/civitai_client.py @@ -246,6 +246,10 @@ class CivitaiClient: 'modelVersions': item.get('modelVersions', []), 'type': item.get('type', ''), 'name': item.get('name', ''), + 'allowNoCredit': item.get('allowNoCredit'), + 'allowCommercialUse': item.get('allowCommercialUse'), + 'allowDerivatives': item.get('allowDerivatives'), + 'allowDifferentLicense': item.get('allowDifferentLicense'), } return payload except RateLimitError: diff --git a/py/utils/metadata_manager.py b/py/utils/metadata_manager.py index e0b3d3c1..841284ac 100644 --- a/py/utils/metadata_manager.py +++ b/py/utils/metadata_manager.py @@ -22,7 +22,7 @@ class MetadataManager: """ @staticmethod - async def load_metadata(file_path: str, model_class: Type[BaseModelMetadata] = LoraMetadata) -> Optional[BaseModelMetadata]: + async def load_metadata(file_path: str, model_class: Type[BaseModelMetadata] = LoraMetadata) -> tuple[Optional[BaseModelMetadata], bool]: """ Load metadata safely. diff --git a/static/js/api/apiConfig.js b/static/js/api/apiConfig.js index bac64710..04210615 100644 --- a/static/js/api/apiConfig.js +++ b/static/js/api/apiConfig.js @@ -77,6 +77,7 @@ export function getApiEndpoints(modelType) { relinkCivitai: `/api/lm/${modelType}/relink-civitai`, civitaiVersions: `/api/lm/${modelType}/civitai/versions`, refreshUpdates: `/api/lm/${modelType}/updates/refresh`, + fetchMissingLicenses: `/api/lm/${modelType}/updates/fetch-missing-license`, modelUpdateStatus: `/api/lm/${modelType}/updates/status`, modelUpdateVersions: `/api/lm/${modelType}/updates/versions`, ignoreModelUpdate: `/api/lm/${modelType}/updates/ignore`, diff --git a/static/js/components/ContextMenu/GlobalContextMenu.js b/static/js/components/ContextMenu/GlobalContextMenu.js index 2ae3bb5d..ec4f794a 100644 --- a/static/js/components/ContextMenu/GlobalContextMenu.js +++ b/static/js/components/ContextMenu/GlobalContextMenu.js @@ -1,6 +1,8 @@ import { BaseContextMenu } from './BaseContextMenu.js'; import { showToast } from '../../utils/uiHelpers.js'; +import { translate } from '../../utils/i18nHelpers.js'; import { state } from '../../state/index.js'; +import { getCompleteApiConfig, getCurrentModelType } from '../../api/apiConfig.js'; import { performModelUpdateCheck } from '../../utils/updateCheckHelpers.js'; export class GlobalContextMenu extends BaseContextMenu { @@ -8,6 +10,7 @@ export class GlobalContextMenu extends BaseContextMenu { super('globalContextMenu'); this._cleanupInProgress = false; this._updateCheckInProgress = false; + this._licenseRefreshInProgress = false; } showMenu(x, y, origin = null) { @@ -32,6 +35,11 @@ export class GlobalContextMenu extends BaseContextMenu { console.error('Failed to check model updates:', error); }); break; + case 'fetch-missing-licenses': + this.fetchMissingLicenses(menuItem).catch((error) => { + console.error('Failed to refresh missing license metadata:', error); + }); + break; default: console.warn(`Unhandled global context menu action: ${action}`); break; @@ -133,4 +141,98 @@ export class GlobalContextMenu extends BaseContextMenu { } } } + + async fetchMissingLicenses(menuItem) { + if (this._licenseRefreshInProgress) { + return; + } + + const modelType = getCurrentModelType(); + const apiConfig = getCompleteApiConfig(modelType); + const displayName = apiConfig?.config?.displayName ?? 'Model'; + const typePlural = this._buildTypePlural(displayName); + const loadingMessage = translate( + 'globalContextMenu.fetchMissingLicenses.loading', + { type: displayName, typePlural }, + `Refreshing license metadata for ${typePlural}...` + ); + + const endpoint = apiConfig?.endpoints?.fetchMissingLicenses; + if (!endpoint) { + console.warn('Fetch missing license endpoint not configured for model type:', modelType); + showToast( + 'globalContextMenu.fetchMissingLicenses.error', + { message: 'Endpoint unavailable', type: displayName, typePlural }, + 'warning' + ); + return; + } + + this._licenseRefreshInProgress = true; + menuItem?.classList?.add('disabled'); + state.loadingManager?.showSimpleLoading?.(loadingMessage); + + try { + const response = await fetch(endpoint, { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({}), + }); + + let payload = {}; + try { + payload = await response.json(); + } catch { + payload = {}; + } + + if (!response.ok || payload.success !== true) { + const errorMessage = payload?.error || response.statusText || 'Unknown error'; + throw new Error(errorMessage); + } + + const updated = Array.isArray(payload.updated) ? payload.updated : []; + if (updated.length > 0) { + showToast( + 'globalContextMenu.fetchMissingLicenses.success', + { count: updated.length, type: displayName, typePlural }, + 'success' + ); + } else { + showToast( + 'globalContextMenu.fetchMissingLicenses.none', + { type: displayName, typePlural }, + 'info' + ); + } + } catch (error) { + console.error('Failed to refresh missing license metadata:', error); + showToast( + 'globalContextMenu.fetchMissingLicenses.error', + { message: error?.message ?? 'Unknown error', type: displayName, typePlural }, + 'error' + ); + } finally { + state.loadingManager?.hide?.(); + if (typeof state.loadingManager?.restoreProgressBar === 'function') { + state.loadingManager.restoreProgressBar(); + } + + this._licenseRefreshInProgress = false; + menuItem?.classList?.remove('disabled'); + } + } + + _buildTypePlural(displayName) { + if (!displayName) { + return 'models'; + } + + const lower = displayName.toLowerCase(); + if (lower.endsWith('s')) { + return displayName; + } + + return `${displayName}s`; + } } diff --git a/templates/components/context_menu.html b/templates/components/context_menu.html index d0f967e5..fd5103fc 100644 --- a/templates/components/context_menu.html +++ b/templates/components/context_menu.html @@ -93,6 +93,9 @@
{{ t('globalContextMenu.checkModelUpdates.label') }}
+
+ {{ t('globalContextMenu.fetchMissingLicenses.label') }} +
{{ t('globalContextMenu.cleanupExampleImages.label') }}
diff --git a/tests/frontend/components/contextMenu.interactions.test.js b/tests/frontend/components/contextMenu.interactions.test.js index f6230b27..a7d6af9f 100644 --- a/tests/frontend/components/contextMenu.interactions.test.js +++ b/tests/frontend/components/contextMenu.interactions.test.js @@ -44,7 +44,10 @@ const refreshSingleModelMetadataMock = vi.fn(); const resetAndReloadMock = vi.fn(); const getCompleteApiConfigMock = vi.fn(() => ({ config: { displayName: 'LoRA' }, - endpoints: { refreshUpdates: '/api/lm/loras/updates/refresh' }, + endpoints: { + refreshUpdates: '/api/lm/loras/updates/refresh', + fetchMissingLicenses: '/api/lm/loras/updates/fetch-missing-license', + }, })); const getCurrentModelTypeMock = vi.fn(() => 'loras'); @@ -150,7 +153,10 @@ describe('Interaction-level regression coverage', () => { resetAndReloadMock.mockResolvedValue(undefined); getCompleteApiConfigMock.mockReturnValue({ config: { displayName: 'LoRA' }, - endpoints: { refreshUpdates: '/api/lm/loras/updates/refresh' }, + endpoints: { + refreshUpdates: '/api/lm/loras/updates/refresh', + fetchMissingLicenses: '/api/lm/loras/updates/fetch-missing-license', + }, }); getCurrentModelTypeMock.mockReturnValue('loras'); translateMock.mockImplementation((key, params, fallback) => (typeof fallback === 'string' ? fallback : key)); @@ -322,8 +328,9 @@ describe('Interaction-level regression coverage', () => { document.body.innerHTML = `
-
+
+
`; @@ -354,6 +361,10 @@ describe('Interaction-level regression coverage', () => { .mockResolvedValueOnce({ ok: true, json: async () => ({ success: true, records: [{ id: 1 }] }), + }) + .mockResolvedValueOnce({ + ok: true, + json: async () => ({ success: true, updated: [{ modelId: 42 }] }), }); menu.showMenu(240, 320); @@ -379,7 +390,7 @@ describe('Interaction-level regression coverage', () => { await flushAsyncTasks(); - expect(global.fetch).toHaveBeenLastCalledWith('/api/lm/loras/updates/refresh', { + expect(global.fetch).toHaveBeenNthCalledWith(2, '/api/lm/loras/updates/refresh', { method: 'POST', headers: { 'Content-Type': 'application/json' }, body: JSON.stringify({ force: false }), @@ -398,5 +409,28 @@ describe('Interaction-level regression coverage', () => { expect(loadingManagerStub.hide).toHaveBeenCalled(); expect(resetAndReloadMock).toHaveBeenCalledWith(false); expect(checkUpdatesItem.classList.contains('disabled')).toBe(false); + + menu.showMenu(480, 520); + const fetchMissingItem = document.querySelector('[data-action="fetch-missing-licenses"]'); + fetchMissingItem.dispatchEvent(new Event('click', { bubbles: true })); + expect(fetchMissingItem.classList.contains('disabled')).toBe(true); + + const fetchMissingResponse = await global.fetch.mock.results[2].value; + await fetchMissingResponse.json(); + await flushAsyncTasks(); + + expect(global.fetch).toHaveBeenNthCalledWith(3, '/api/lm/loras/updates/fetch-missing-license', { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({}), + }); + + expect(showToastMock).toHaveBeenCalledWith( + 'globalContextMenu.fetchMissingLicenses.success', + { count: 1, type: 'LoRA', typePlural: 'LoRAs' }, + 'success' + ); + expect(loadingManagerStub.showSimpleLoading).toHaveBeenNthCalledWith(2, 'Refreshing license metadata for LoRAs...'); + expect(fetchMissingItem.classList.contains('disabled')).toBe(false); }); }); diff --git a/tests/routes/test_model_update_handler.py b/tests/routes/test_model_update_handler.py index 79278ef7..e0285081 100644 --- a/tests/routes/test_model_update_handler.py +++ b/tests/routes/test_model_update_handler.py @@ -1,3 +1,4 @@ +import copy import json import logging from types import SimpleNamespace @@ -6,6 +7,7 @@ import pytest from py.config import config from py.routes.handlers.model_handlers import ModelUpdateHandler +from py.utils.metadata_manager import MetadataManager from py.services.model_update_service import ModelUpdateRecord, ModelVersionRecord @@ -264,3 +266,171 @@ async def test_refresh_model_updates_accepts_snake_case_ids(): call = update_service.calls[0] assert call["target_model_ids"] == [3, 4] + + +@pytest.mark.asyncio +async def test_fetch_missing_license_data_updates_metadata(monkeypatch): + cache = SimpleNamespace( + raw_data=[ + {"file_path": "/tmp/model1.safetensors", "civitai": {"modelId": 10}}, + {"file_path": "/tmp/model2.safetensors", "civitai": {"modelId": 10}}, + {"file_path": "/tmp/model3.safetensors", "civitai": {"modelId": 20}}, + ], + version_index={}, + ) + + metadata_store = { + "/tmp/model1.safetensors": {"civitai": {"model": {}}}, + "/tmp/model2.safetensors": {"civitai": {"model": {}}}, + "/tmp/model3.safetensors": {"civitai": {"model": {}}}, + } + + async def fake_load(path: str): + data = metadata_store.get(path) + if data is None: + return None, False + return SimpleNamespace(to_dict=lambda: copy.deepcopy(data)), False + + saved: list[tuple[str, dict]] = [] + + async def fake_save(path: str, metadata: dict): + saved.append((path, copy.deepcopy(metadata))) + return True + + monkeypatch.setattr(MetadataManager, "load_metadata", staticmethod(fake_load)) + monkeypatch.setattr(MetadataManager, "save_metadata", staticmethod(fake_save)) + + provider_calls: list[list[int]] = [] + + async def fake_bulk(model_ids): + provider_calls.append(list(model_ids)) + return { + 10: { + "allowNoCredit": True, + "allowCommercialUse": ["Sell"], + "allowDerivatives": True, + "allowDifferentLicense": True, + }, + 20: { + "allowNoCredit": False, + "allowCommercialUse": ["Image"], + "allowDerivatives": False, + "allowDifferentLicense": False, + }, + } + + provider = SimpleNamespace() + provider.get_model_versions_bulk = fake_bulk + + async def metadata_selector(name): + assert name == "civitai_api" + return provider + + handler = ModelUpdateHandler( + service=DummyService(cache), + update_service=SimpleNamespace(), + metadata_provider_selector=metadata_selector, + logger=logging.getLogger(__name__), + ) + + class DummyRequest: + can_read_body = True + query = {} + + async def json(self): + return {} + + response = await handler.fetch_missing_civitai_license_data(DummyRequest()) + assert response.status == 200 + + payload = json.loads(response.text) + assert payload["success"] is True + assert len(payload["updated"]) == 3 + assert provider_calls == [[10, 20]] + assert len(saved) == 3 + + first_metadata = saved[0][1] + assert first_metadata["civitai"]["model"]["allowNoCredit"] is True + assert first_metadata["civitai"]["model"]["allowCommercialUse"] == ["Sell"] + assert "missingModelIds" not in payload + assert "errors" not in payload + + +@pytest.mark.asyncio +async def test_fetch_missing_license_data_filters_model_ids(monkeypatch): + cache = SimpleNamespace( + raw_data=[ + {"file_path": "/tmp/model1.safetensors", "civitai": {"modelId": 10}}, + {"file_path": "/tmp/model2.safetensors", "civitai": {"modelId": 20}}, + ], + version_index={}, + ) + + metadata_store = { + "/tmp/model1.safetensors": {"civitai": {"model": {}}}, + "/tmp/model2.safetensors": {"civitai": {"model": {}}}, + } + + async def fake_load(path: str): + data = metadata_store.get(path) + if data is None: + return None, False + return SimpleNamespace(to_dict=lambda: copy.deepcopy(data)), False + + saved: list[tuple[str, dict]] = [] + + async def fake_save(path: str, metadata: dict): + saved.append((path, copy.deepcopy(metadata))) + return True + + monkeypatch.setattr(MetadataManager, "load_metadata", staticmethod(fake_load)) + monkeypatch.setattr(MetadataManager, "save_metadata", staticmethod(fake_save)) + + provider_calls: list[list[int]] = [] + + async def fake_bulk(model_ids): + provider_calls.append(list(model_ids)) + return { + 10: { + "allowNoCredit": True, + "allowCommercialUse": ["Sell"], + "allowDerivatives": True, + "allowDifferentLicense": True, + }, + 20: { + "allowNoCredit": False, + "allowCommercialUse": ["Image"], + "allowDerivatives": False, + "allowDifferentLicense": False, + }, + } + + provider = SimpleNamespace() + provider.get_model_versions_bulk = fake_bulk + + async def metadata_selector(name): + assert name == "civitai_api" + return provider + + handler = ModelUpdateHandler( + service=DummyService(cache), + update_service=SimpleNamespace(), + metadata_provider_selector=metadata_selector, + logger=logging.getLogger(__name__), + ) + + class DummyRequest: + can_read_body = True + query = {} + + async def json(self): + return {"modelIds": [20]} + + response = await handler.fetch_missing_civitai_license_data(DummyRequest()) + assert response.status == 200 + + payload = json.loads(response.text) + assert payload["success"] is True + assert len(payload["updated"]) == 1 + assert provider_calls == [[20]] + assert len(saved) == 1 diff --git a/tests/services/test_civitai_client.py b/tests/services/test_civitai_client.py index 8edb6fd5..ab3a5ef9 100644 --- a/tests/services/test_civitai_client.py +++ b/tests/services/test_civitai_client.py @@ -204,8 +204,26 @@ async def test_get_model_versions_bulk_success(monkeypatch, downloader): assert kwargs.get("params") == {"ids": "1,2"} return True, { "items": [ - {"id": 1, "modelVersions": [{"id": 11}], "type": "LORA", "name": "One"}, - {"id": 2, "modelVersions": [], "type": "Checkpoint", "name": "Two"}, + { + "id": 1, + "modelVersions": [{"id": 11}], + "type": "LORA", + "name": "One", + "allowNoCredit": True, + "allowCommercialUse": ["Sell"], + "allowDerivatives": True, + "allowDifferentLicense": True, + }, + { + "id": 2, + "modelVersions": [], + "type": "Checkpoint", + "name": "Two", + "allowNoCredit": False, + "allowCommercialUse": ["Image"], + "allowDerivatives": False, + "allowDifferentLicense": False, + }, ] } @@ -216,8 +234,24 @@ async def test_get_model_versions_bulk_success(monkeypatch, downloader): result = await client.get_model_versions_bulk([1, "2", 2]) assert result == { - 1: {"modelVersions": [{"id": 11}], "type": "LORA", "name": "One"}, - 2: {"modelVersions": [], "type": "Checkpoint", "name": "Two"}, + 1: { + "modelVersions": [{"id": 11}], + "type": "LORA", + "name": "One", + "allowNoCredit": True, + "allowCommercialUse": ["Sell"], + "allowDerivatives": True, + "allowDifferentLicense": True, + }, + 2: { + "modelVersions": [], + "type": "Checkpoint", + "name": "Two", + "allowNoCredit": False, + "allowCommercialUse": ["Image"], + "allowDerivatives": False, + "allowDifferentLicense": False, + }, }