Merge pull request #658 from willmiao/feature/global-license-refresh

Feature/global license refresh
This commit is contained in:
pixelpaws
2025-11-11 14:54:37 +08:00
committed by GitHub
21 changed files with 638 additions and 11 deletions

View File

@@ -152,6 +152,13 @@
"none": "Keine Beispielbild-Ordner mussten bereinigt werden",
"partial": "Bereinigung abgeschlossen, {failures} Ordner übersprungen",
"error": "Fehler beim Bereinigen der Beispielbild-Ordner: {message}"
},
"fetchMissingLicenses": {
"label": "Refresh license metadata",
"loading": "Refreshing license metadata for {typePlural}...",
"success": "Updated license metadata for {count} {typePlural}",
"none": "All {typePlural} already have license metadata",
"error": "Failed to refresh license metadata for {typePlural}: {message}"
}
},
"header": {

View File

@@ -152,6 +152,13 @@
"none": "No example image folders needed cleanup",
"partial": "Cleanup completed with {failures} folder(s) skipped",
"error": "Failed to clean example image folders: {message}"
},
"fetchMissingLicenses": {
"label": "Refresh license metadata",
"loading": "Refreshing license metadata for {typePlural}...",
"success": "Updated license metadata for {count} {typePlural}",
"none": "All {typePlural} already have license metadata",
"error": "Failed to refresh license metadata for {typePlural}: {message}"
}
},
"header": {

View File

@@ -152,6 +152,13 @@
"none": "No hay carpetas de imágenes de ejemplo que necesiten limpieza",
"partial": "Limpieza completada con {failures} carpeta(s) omitidas",
"error": "No se pudieron limpiar las carpetas de imágenes de ejemplo: {message}"
},
"fetchMissingLicenses": {
"label": "Refresh license metadata",
"loading": "Refreshing license metadata for {typePlural}...",
"success": "Updated license metadata for {count} {typePlural}",
"none": "All {typePlural} already have license metadata",
"error": "Failed to refresh license metadata for {typePlural}: {message}"
}
},
"header": {

View File

@@ -152,6 +152,13 @@
"none": "Aucun dossier d'images d'exemple à nettoyer",
"partial": "Nettoyage terminé avec {failures} dossier(s) ignoré(s)",
"error": "Échec du nettoyage des dossiers d'images d'exemple : {message}"
},
"fetchMissingLicenses": {
"label": "Refresh license metadata",
"loading": "Refreshing license metadata for {typePlural}...",
"success": "Updated license metadata for {count} {typePlural}",
"none": "All {typePlural} already have license metadata",
"error": "Failed to refresh license metadata for {typePlural}: {message}"
}
},
"header": {

View File

@@ -152,6 +152,13 @@
"none": "אין תיקיות תמונות דוגמה שזקוקות לניקוי",
"partial": "הניקוי הושלם עם דילוג על {failures} תיקיות",
"error": "ניקוי תיקיות תמונות הדוגמה נכשל: {message}"
},
"fetchMissingLicenses": {
"label": "Refresh license metadata",
"loading": "Refreshing license metadata for {typePlural}...",
"success": "Updated license metadata for {count} {typePlural}",
"none": "All {typePlural} already have license metadata",
"error": "Failed to refresh license metadata for {typePlural}: {message}"
}
},
"header": {

View File

@@ -152,6 +152,13 @@
"none": "クリーンアップが必要な例画像フォルダはありません",
"partial": "クリーンアップが完了しましたが、{failures} 個のフォルダはスキップされました",
"error": "例画像フォルダのクリーンアップに失敗しました:{message}"
},
"fetchMissingLicenses": {
"label": "Refresh license metadata",
"loading": "Refreshing license metadata for {typePlural}...",
"success": "Updated license metadata for {count} {typePlural}",
"none": "All {typePlural} already have license metadata",
"error": "Failed to refresh license metadata for {typePlural}: {message}"
}
},
"header": {

View File

@@ -152,6 +152,13 @@
"none": "정리가 필요한 예시 이미지 폴더가 없습니다",
"partial": "정리가 완료되었으나 {failures}개의 폴더가 건너뛰어졌습니다",
"error": "예시 이미지 폴더 정리에 실패했습니다: {message}"
},
"fetchMissingLicenses": {
"label": "Refresh license metadata",
"loading": "Refreshing license metadata for {typePlural}...",
"success": "Updated license metadata for {count} {typePlural}",
"none": "All {typePlural} already have license metadata",
"error": "Failed to refresh license metadata for {typePlural}: {message}"
}
},
"header": {

View File

@@ -152,6 +152,13 @@
"none": "Нет папок с примерами изображений, требующих очистки",
"partial": "Очистка завершена, пропущено {failures} папок",
"error": "Не удалось очистить папки с примерами изображений: {message}"
},
"fetchMissingLicenses": {
"label": "Refresh license metadata",
"loading": "Refreshing license metadata for {typePlural}...",
"success": "Updated license metadata for {count} {typePlural}",
"none": "All {typePlural} already have license metadata",
"error": "Failed to refresh license metadata for {typePlural}: {message}"
}
},
"header": {

View File

@@ -152,6 +152,13 @@
"none": "没有需要清理的示例图片文件夹",
"partial": "清理完成,有 {failures} 个文件夹跳过",
"error": "清理示例图片文件夹失败:{message}"
},
"fetchMissingLicenses": {
"label": "Refresh license metadata",
"loading": "Refreshing license metadata for {typePlural}...",
"success": "Updated license metadata for {count} {typePlural}",
"none": "All {typePlural} already have license metadata",
"error": "Failed to refresh license metadata for {typePlural}: {message}"
}
},
"header": {

View File

@@ -152,6 +152,13 @@
"none": "沒有需要清理的範例圖片資料夾",
"partial": "清理完成,有 {failures} 個資料夾略過",
"error": "清理範例圖片資料夾失敗:{message}"
},
"fetchMissingLicenses": {
"label": "Refresh license metadata",
"loading": "Refreshing license metadata for {typePlural}...",
"success": "Updated license metadata for {count} {typePlural}",
"none": "All {typePlural} already have license metadata",
"error": "Failed to refresh license metadata for {typePlural}: {message}"
}
},
"header": {

3
package-lock.json generated
View File

@@ -114,6 +114,7 @@
}
],
"license": "MIT",
"peer": true,
"engines": {
"node": ">=18"
},
@@ -137,6 +138,7 @@
}
],
"license": "MIT",
"peer": true,
"engines": {
"node": ">=18"
}
@@ -1611,6 +1613,7 @@
"integrity": "sha512-MyL55p3Ut3cXbeBEG7Hcv0mVM8pp8PBNWxRqchZnSfAiES1v1mRnMeFfaHWIPULpwsYfvO+ZmMZz5tGCnjzDUQ==",
"dev": true,
"license": "MIT",
"peer": true,
"dependencies": {
"cssstyle": "^4.0.1",
"data-urls": "^5.0.0",

View File

@@ -6,7 +6,7 @@ import json
import logging
import os
from dataclasses import dataclass
from typing import Awaitable, Callable, Dict, Iterable, List, Mapping, Optional
from typing import Any, Awaitable, Callable, Dict, Iterable, List, Mapping, Optional
from aiohttp import web
import jinja2
@@ -30,9 +30,17 @@ from ...services.use_cases import (
from ...services.websocket_manager import WebSocketManager
from ...services.websocket_progress_callback import WebSocketProgressCallback
from ...services.errors import RateLimitError, ResourceNotFoundError
from ...utils.civitai_utils import resolve_license_payload
from ...utils.file_utils import calculate_sha256
from ...utils.metadata_manager import MetadataManager
LICENSE_FIELDS = (
"allowNoCredit",
"allowCommercialUse",
"allowDerivatives",
"allowDifferentLicense",
)
class ModelPageView:
"""Render the HTML view for model listings."""
@@ -1083,6 +1091,77 @@ class ModelUpdateHandler:
self._metadata_provider_selector = metadata_provider_selector
self._logger = logger
async def fetch_missing_civitai_license_data(self, request: web.Request) -> web.Response:
payload = await self._read_json(request)
target_model_ids = self._extract_target_model_ids(payload)
provider = await self._get_civitai_provider()
if provider is None:
return web.json_response(
{"success": False, "error": "Civitai provider not available"},
status=503,
)
try:
cache = await self._service.scanner.get_cached_data()
except Exception as exc:
self._logger.error("Failed to load cache for license refresh: %s", exc, exc_info=True)
cache = None
target_set = set(target_model_ids) if target_model_ids is not None else None
candidates = await self._collect_models_missing_license(cache, target_set)
if not candidates:
return web.json_response({"success": True, "updated": []})
model_ids = sorted(candidates.keys())
try:
license_map = await self._fetch_license_info(provider, model_ids)
except RateLimitError as exc:
return web.json_response(
{"success": False, "error": str(exc) or "Rate limited"},
status=429,
)
except Exception as exc: # pragma: no cover - defensive log
self._logger.error("Failed to fetch license info: %s", exc, exc_info=True)
return web.json_response({"success": False, "error": str(exc)}, status=500)
updated: List[Dict[str, str]] = []
errors: List[Dict[str, str]] = []
for model_id in model_ids:
license_payload = license_map.get(model_id)
if not license_payload:
continue
resolved_payload = resolve_license_payload(license_payload)
for context in candidates.get(model_id, []):
metadata_path = context["file_path"]
metadata_payload = context["metadata"]
civitai_section = metadata_payload.setdefault("civitai", {})
model_section = civitai_section.get("model")
if not isinstance(model_section, Mapping):
model_section = {}
model_section.update(resolved_payload)
civitai_section["model"] = model_section
metadata_payload["civitai"] = civitai_section
try:
await MetadataManager.save_metadata(metadata_path, metadata_payload)
updated.append({"modelId": model_id, "filePath": metadata_path})
except Exception as exc:
self._logger.error(
"Failed to save metadata for %s: %s",
metadata_path,
exc,
exc_info=True,
)
errors.append({"filePath": metadata_path, "error": str(exc)})
response_payload = {"success": True, "updated": updated}
missing_model_ids = [mid for mid in model_ids if mid not in license_map]
if missing_model_ids:
response_payload["missingModelIds"] = missing_model_ids
if errors:
response_payload["errors"] = errors
return web.json_response(response_payload)
async def refresh_model_updates(self, request: web.Request) -> web.Response:
payload = await self._read_json(request)
force_refresh = self._parse_bool(request.query.get("force")) or self._parse_bool(
@@ -1247,6 +1326,132 @@ class ModelUpdateHandler:
self._logger.error("Failed to acquire civitai provider: %s", exc, exc_info=True)
return None
async def _collect_models_missing_license(
self,
cache,
target_model_ids: Optional[set[int]],
) -> Dict[int, List[Dict[str, Any]]]:
entries: Dict[int, List[Dict[str, Any]]] = {}
if cache is None:
return entries
raw_data = getattr(cache, "raw_data", None) or []
seen_paths: set[str] = set()
target_set = target_model_ids
for item in raw_data:
if not isinstance(item, Mapping):
continue
file_path = item.get("file_path")
if not isinstance(file_path, str) or not file_path or file_path in seen_paths:
continue
seen_paths.add(file_path)
civitai_entry = item.get("civitai")
if not isinstance(civitai_entry, Mapping):
continue
model_id = self._normalize_model_id(civitai_entry.get("modelId"))
if model_id is None:
continue
if target_set is not None and model_id not in target_set:
continue
try:
metadata_obj, should_skip = await MetadataManager.load_metadata(file_path)
except Exception as exc:
self._logger.debug("Failed to load metadata for %s: %s", file_path, exc)
continue
if metadata_obj is None or should_skip:
continue
metadata_payload = self._convert_metadata_to_dict(metadata_obj)
civitai_payload = metadata_payload.get("civitai")
if not isinstance(civitai_payload, Mapping):
civitai_payload = {}
model_payload = civitai_payload.get("model")
if not isinstance(model_payload, Mapping):
model_payload = {}
missing = [key for key in LICENSE_FIELDS if key not in model_payload]
if not missing:
continue
civitai_payload["model"] = model_payload
metadata_payload["civitai"] = civitai_payload
entries.setdefault(model_id, []).append(
{"file_path": file_path, "metadata": metadata_payload}
)
return entries
async def _fetch_license_info(
self,
provider,
model_ids: List[int],
) -> Dict[int, Dict[str, Any]]:
if not model_ids:
return {}
BATCH_SIZE = 100
aggregated: Dict[int, Dict[str, Any]] = {}
for start in range(0, len(model_ids), BATCH_SIZE):
chunk = model_ids[start : start + BATCH_SIZE]
response = await provider.get_model_versions_bulk(chunk)
if not isinstance(response, Mapping):
continue
for raw_id, payload in response.items():
normalized_id = self._normalize_model_id(raw_id)
if normalized_id is None or not isinstance(payload, Mapping):
continue
license_data: Dict[str, Any] = {}
for field in LICENSE_FIELDS:
license_data[field] = payload.get(field)
aggregated[normalized_id] = license_data
return aggregated
def _extract_target_model_ids(self, payload: Dict) -> Optional[List[int]]:
if not isinstance(payload, Mapping):
return None
raw_ids = payload.get("modelIds")
if raw_ids is None:
raw_ids = payload.get("model_ids")
if not isinstance(raw_ids, (list, tuple, set)):
return None
normalized: List[int] = []
for candidate in raw_ids:
model_id = self._normalize_model_id(candidate)
if model_id is not None:
normalized.append(model_id)
if not normalized:
return None
return sorted(set(normalized))
@staticmethod
def _convert_metadata_to_dict(metadata: Any) -> Dict[str, Any]:
if metadata is None:
return {}
to_dict = getattr(metadata, "to_dict", None)
if callable(to_dict):
try:
return to_dict()
except Exception:
pass
if isinstance(metadata, Mapping):
return dict(metadata)
return {}
async def _read_json(self, request: web.Request) -> Dict:
if not request.can_read_body:
return {}
@@ -1401,6 +1606,7 @@ class ModelHandlerSet:
"get_model_description": self.query.get_model_description,
"get_relative_paths": self.query.get_relative_paths,
"refresh_model_updates": self.updates.refresh_model_updates,
"fetch_missing_civitai_license_data": self.updates.fetch_missing_civitai_license_data,
"set_model_update_ignore": self.updates.set_model_update_ignore,
"set_version_update_ignore": self.updates.set_version_update_ignore,
"get_model_update_status": self.updates.get_model_update_status,

View File

@@ -56,6 +56,7 @@ COMMON_ROUTE_DEFINITIONS: tuple[RouteDefinition, ...] = (
RouteDefinition("GET", "/api/lm/{prefix}/civitai/model/version/{modelVersionId}", "get_civitai_model_by_version"),
RouteDefinition("GET", "/api/lm/{prefix}/civitai/model/hash/{hash}", "get_civitai_model_by_hash"),
RouteDefinition("POST", "/api/lm/{prefix}/updates/refresh", "refresh_model_updates"),
RouteDefinition("POST", "/api/lm/{prefix}/updates/fetch-missing-license", "fetch_missing_civitai_license_data"),
RouteDefinition("POST", "/api/lm/{prefix}/updates/ignore", "set_model_update_ignore"),
RouteDefinition("POST", "/api/lm/{prefix}/updates/ignore-version", "set_version_update_ignore"),
RouteDefinition("GET", "/api/lm/{prefix}/updates/status/{model_id}", "get_model_update_status"),
@@ -103,4 +104,3 @@ class ModelRouteRegistrar:
add_method_name = self._METHOD_MAP[method.upper()]
add_method = getattr(self._app.router, add_method_name)
add_method(path, handler)

View File

@@ -246,6 +246,10 @@ class CivitaiClient:
'modelVersions': item.get('modelVersions', []),
'type': item.get('type', ''),
'name': item.get('name', ''),
'allowNoCredit': item.get('allowNoCredit'),
'allowCommercialUse': item.get('allowCommercialUse'),
'allowDerivatives': item.get('allowDerivatives'),
'allowDifferentLicense': item.get('allowDifferentLicense'),
}
return payload
except RateLimitError:

View File

@@ -22,7 +22,7 @@ class MetadataManager:
"""
@staticmethod
async def load_metadata(file_path: str, model_class: Type[BaseModelMetadata] = LoraMetadata) -> Optional[BaseModelMetadata]:
async def load_metadata(file_path: str, model_class: Type[BaseModelMetadata] = LoraMetadata) -> tuple[Optional[BaseModelMetadata], bool]:
"""
Load metadata safely.

View File

@@ -77,6 +77,7 @@ export function getApiEndpoints(modelType) {
relinkCivitai: `/api/lm/${modelType}/relink-civitai`,
civitaiVersions: `/api/lm/${modelType}/civitai/versions`,
refreshUpdates: `/api/lm/${modelType}/updates/refresh`,
fetchMissingLicenses: `/api/lm/${modelType}/updates/fetch-missing-license`,
modelUpdateStatus: `/api/lm/${modelType}/updates/status`,
modelUpdateVersions: `/api/lm/${modelType}/updates/versions`,
ignoreModelUpdate: `/api/lm/${modelType}/updates/ignore`,

View File

@@ -1,6 +1,8 @@
import { BaseContextMenu } from './BaseContextMenu.js';
import { showToast } from '../../utils/uiHelpers.js';
import { translate } from '../../utils/i18nHelpers.js';
import { state } from '../../state/index.js';
import { getCompleteApiConfig, getCurrentModelType } from '../../api/apiConfig.js';
import { performModelUpdateCheck } from '../../utils/updateCheckHelpers.js';
export class GlobalContextMenu extends BaseContextMenu {
@@ -8,6 +10,7 @@ export class GlobalContextMenu extends BaseContextMenu {
super('globalContextMenu');
this._cleanupInProgress = false;
this._updateCheckInProgress = false;
this._licenseRefreshInProgress = false;
}
showMenu(x, y, origin = null) {
@@ -32,6 +35,11 @@ export class GlobalContextMenu extends BaseContextMenu {
console.error('Failed to check model updates:', error);
});
break;
case 'fetch-missing-licenses':
this.fetchMissingLicenses(menuItem).catch((error) => {
console.error('Failed to refresh missing license metadata:', error);
});
break;
default:
console.warn(`Unhandled global context menu action: ${action}`);
break;
@@ -133,4 +141,98 @@ export class GlobalContextMenu extends BaseContextMenu {
}
}
}
async fetchMissingLicenses(menuItem) {
if (this._licenseRefreshInProgress) {
return;
}
const modelType = getCurrentModelType();
const apiConfig = getCompleteApiConfig(modelType);
const displayName = apiConfig?.config?.displayName ?? 'Model';
const typePlural = this._buildTypePlural(displayName);
const loadingMessage = translate(
'globalContextMenu.fetchMissingLicenses.loading',
{ type: displayName, typePlural },
`Refreshing license metadata for ${typePlural}...`
);
const endpoint = apiConfig?.endpoints?.fetchMissingLicenses;
if (!endpoint) {
console.warn('Fetch missing license endpoint not configured for model type:', modelType);
showToast(
'globalContextMenu.fetchMissingLicenses.error',
{ message: 'Endpoint unavailable', type: displayName, typePlural },
'warning'
);
return;
}
this._licenseRefreshInProgress = true;
menuItem?.classList?.add('disabled');
state.loadingManager?.showSimpleLoading?.(loadingMessage);
try {
const response = await fetch(endpoint, {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({}),
});
let payload = {};
try {
payload = await response.json();
} catch {
payload = {};
}
if (!response.ok || payload.success !== true) {
const errorMessage = payload?.error || response.statusText || 'Unknown error';
throw new Error(errorMessage);
}
const updated = Array.isArray(payload.updated) ? payload.updated : [];
if (updated.length > 0) {
showToast(
'globalContextMenu.fetchMissingLicenses.success',
{ count: updated.length, type: displayName, typePlural },
'success'
);
} else {
showToast(
'globalContextMenu.fetchMissingLicenses.none',
{ type: displayName, typePlural },
'info'
);
}
} catch (error) {
console.error('Failed to refresh missing license metadata:', error);
showToast(
'globalContextMenu.fetchMissingLicenses.error',
{ message: error?.message ?? 'Unknown error', type: displayName, typePlural },
'error'
);
} finally {
state.loadingManager?.hide?.();
if (typeof state.loadingManager?.restoreProgressBar === 'function') {
state.loadingManager.restoreProgressBar();
}
this._licenseRefreshInProgress = false;
menuItem?.classList?.remove('disabled');
}
}
_buildTypePlural(displayName) {
if (!displayName) {
return 'models';
}
const lower = displayName.toLowerCase();
if (lower.endsWith('s')) {
return displayName;
}
return `${displayName}s`;
}
}

View File

@@ -93,6 +93,9 @@
<div class="context-menu-item" data-action="check-model-updates">
<i class="fas fa-sync-alt"></i> <span>{{ t('globalContextMenu.checkModelUpdates.label') }}</span>
</div>
<div class="context-menu-item" data-action="fetch-missing-licenses">
<i class="fas fa-shield-alt"></i> <span>{{ t('globalContextMenu.fetchMissingLicenses.label') }}</span>
</div>
<div class="context-menu-item" data-action="cleanup-example-images-folders">
<i class="fas fa-trash-restore"></i> <span>{{ t('globalContextMenu.cleanupExampleImages.label') }}</span>
</div>

View File

@@ -44,7 +44,10 @@ const refreshSingleModelMetadataMock = vi.fn();
const resetAndReloadMock = vi.fn();
const getCompleteApiConfigMock = vi.fn(() => ({
config: { displayName: 'LoRA' },
endpoints: { refreshUpdates: '/api/lm/loras/updates/refresh' },
endpoints: {
refreshUpdates: '/api/lm/loras/updates/refresh',
fetchMissingLicenses: '/api/lm/loras/updates/fetch-missing-license',
},
}));
const getCurrentModelTypeMock = vi.fn(() => 'loras');
@@ -150,7 +153,10 @@ describe('Interaction-level regression coverage', () => {
resetAndReloadMock.mockResolvedValue(undefined);
getCompleteApiConfigMock.mockReturnValue({
config: { displayName: 'LoRA' },
endpoints: { refreshUpdates: '/api/lm/loras/updates/refresh' },
endpoints: {
refreshUpdates: '/api/lm/loras/updates/refresh',
fetchMissingLicenses: '/api/lm/loras/updates/fetch-missing-license',
},
});
getCurrentModelTypeMock.mockReturnValue('loras');
translateMock.mockImplementation((key, params, fallback) => (typeof fallback === 'string' ? fallback : key));
@@ -322,8 +328,9 @@ describe('Interaction-level regression coverage', () => {
document.body.innerHTML = `
<div id="globalContextMenu" class="context-menu">
<div class="context-menu-item" data-action="download-example-images"></div>
<div class="context-menu-item" data-action="cleanup-example-images-folders"></div>
<div class="context-menu-item" data-action="check-model-updates"></div>
<div class="context-menu-item" data-action="fetch-missing-licenses"></div>
<div class="context-menu-item" data-action="cleanup-example-images-folders"></div>
</div>
`;
@@ -354,6 +361,10 @@ describe('Interaction-level regression coverage', () => {
.mockResolvedValueOnce({
ok: true,
json: async () => ({ success: true, records: [{ id: 1 }] }),
})
.mockResolvedValueOnce({
ok: true,
json: async () => ({ success: true, updated: [{ modelId: 42 }] }),
});
menu.showMenu(240, 320);
@@ -379,7 +390,7 @@ describe('Interaction-level regression coverage', () => {
await flushAsyncTasks();
expect(global.fetch).toHaveBeenLastCalledWith('/api/lm/loras/updates/refresh', {
expect(global.fetch).toHaveBeenNthCalledWith(2, '/api/lm/loras/updates/refresh', {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({ force: false }),
@@ -398,5 +409,28 @@ describe('Interaction-level regression coverage', () => {
expect(loadingManagerStub.hide).toHaveBeenCalled();
expect(resetAndReloadMock).toHaveBeenCalledWith(false);
expect(checkUpdatesItem.classList.contains('disabled')).toBe(false);
menu.showMenu(480, 520);
const fetchMissingItem = document.querySelector('[data-action="fetch-missing-licenses"]');
fetchMissingItem.dispatchEvent(new Event('click', { bubbles: true }));
expect(fetchMissingItem.classList.contains('disabled')).toBe(true);
const fetchMissingResponse = await global.fetch.mock.results[2].value;
await fetchMissingResponse.json();
await flushAsyncTasks();
expect(global.fetch).toHaveBeenNthCalledWith(3, '/api/lm/loras/updates/fetch-missing-license', {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({}),
});
expect(showToastMock).toHaveBeenCalledWith(
'globalContextMenu.fetchMissingLicenses.success',
{ count: 1, type: 'LoRA', typePlural: 'LoRAs' },
'success'
);
expect(loadingManagerStub.showSimpleLoading).toHaveBeenNthCalledWith(2, 'Refreshing license metadata for LoRAs...');
expect(fetchMissingItem.classList.contains('disabled')).toBe(false);
});
});

View File

@@ -1,3 +1,4 @@
import copy
import json
import logging
from types import SimpleNamespace
@@ -6,6 +7,7 @@ import pytest
from py.config import config
from py.routes.handlers.model_handlers import ModelUpdateHandler
from py.utils.metadata_manager import MetadataManager
from py.services.model_update_service import ModelUpdateRecord, ModelVersionRecord
@@ -264,3 +266,171 @@ async def test_refresh_model_updates_accepts_snake_case_ids():
call = update_service.calls[0]
assert call["target_model_ids"] == [3, 4]
@pytest.mark.asyncio
async def test_fetch_missing_license_data_updates_metadata(monkeypatch):
cache = SimpleNamespace(
raw_data=[
{"file_path": "/tmp/model1.safetensors", "civitai": {"modelId": 10}},
{"file_path": "/tmp/model2.safetensors", "civitai": {"modelId": 10}},
{"file_path": "/tmp/model3.safetensors", "civitai": {"modelId": 20}},
],
version_index={},
)
metadata_store = {
"/tmp/model1.safetensors": {"civitai": {"model": {}}},
"/tmp/model2.safetensors": {"civitai": {"model": {}}},
"/tmp/model3.safetensors": {"civitai": {"model": {}}},
}
async def fake_load(path: str):
data = metadata_store.get(path)
if data is None:
return None, False
return SimpleNamespace(to_dict=lambda: copy.deepcopy(data)), False
saved: list[tuple[str, dict]] = []
async def fake_save(path: str, metadata: dict):
saved.append((path, copy.deepcopy(metadata)))
return True
monkeypatch.setattr(MetadataManager, "load_metadata", staticmethod(fake_load))
monkeypatch.setattr(MetadataManager, "save_metadata", staticmethod(fake_save))
provider_calls: list[list[int]] = []
async def fake_bulk(model_ids):
provider_calls.append(list(model_ids))
return {
10: {
"allowNoCredit": True,
"allowCommercialUse": ["Sell"],
"allowDerivatives": True,
"allowDifferentLicense": True,
},
20: {
"allowNoCredit": False,
"allowCommercialUse": ["Image"],
"allowDerivatives": False,
"allowDifferentLicense": False,
},
}
provider = SimpleNamespace()
provider.get_model_versions_bulk = fake_bulk
async def metadata_selector(name):
assert name == "civitai_api"
return provider
handler = ModelUpdateHandler(
service=DummyService(cache),
update_service=SimpleNamespace(),
metadata_provider_selector=metadata_selector,
logger=logging.getLogger(__name__),
)
class DummyRequest:
can_read_body = True
query = {}
async def json(self):
return {}
response = await handler.fetch_missing_civitai_license_data(DummyRequest())
assert response.status == 200
payload = json.loads(response.text)
assert payload["success"] is True
assert len(payload["updated"]) == 3
assert provider_calls == [[10, 20]]
assert len(saved) == 3
first_metadata = saved[0][1]
assert first_metadata["civitai"]["model"]["allowNoCredit"] is True
assert first_metadata["civitai"]["model"]["allowCommercialUse"] == ["Sell"]
assert "missingModelIds" not in payload
assert "errors" not in payload
@pytest.mark.asyncio
async def test_fetch_missing_license_data_filters_model_ids(monkeypatch):
cache = SimpleNamespace(
raw_data=[
{"file_path": "/tmp/model1.safetensors", "civitai": {"modelId": 10}},
{"file_path": "/tmp/model2.safetensors", "civitai": {"modelId": 20}},
],
version_index={},
)
metadata_store = {
"/tmp/model1.safetensors": {"civitai": {"model": {}}},
"/tmp/model2.safetensors": {"civitai": {"model": {}}},
}
async def fake_load(path: str):
data = metadata_store.get(path)
if data is None:
return None, False
return SimpleNamespace(to_dict=lambda: copy.deepcopy(data)), False
saved: list[tuple[str, dict]] = []
async def fake_save(path: str, metadata: dict):
saved.append((path, copy.deepcopy(metadata)))
return True
monkeypatch.setattr(MetadataManager, "load_metadata", staticmethod(fake_load))
monkeypatch.setattr(MetadataManager, "save_metadata", staticmethod(fake_save))
provider_calls: list[list[int]] = []
async def fake_bulk(model_ids):
provider_calls.append(list(model_ids))
return {
10: {
"allowNoCredit": True,
"allowCommercialUse": ["Sell"],
"allowDerivatives": True,
"allowDifferentLicense": True,
},
20: {
"allowNoCredit": False,
"allowCommercialUse": ["Image"],
"allowDerivatives": False,
"allowDifferentLicense": False,
},
}
provider = SimpleNamespace()
provider.get_model_versions_bulk = fake_bulk
async def metadata_selector(name):
assert name == "civitai_api"
return provider
handler = ModelUpdateHandler(
service=DummyService(cache),
update_service=SimpleNamespace(),
metadata_provider_selector=metadata_selector,
logger=logging.getLogger(__name__),
)
class DummyRequest:
can_read_body = True
query = {}
async def json(self):
return {"modelIds": [20]}
response = await handler.fetch_missing_civitai_license_data(DummyRequest())
assert response.status == 200
payload = json.loads(response.text)
assert payload["success"] is True
assert len(payload["updated"]) == 1
assert provider_calls == [[20]]
assert len(saved) == 1

View File

@@ -204,8 +204,26 @@ async def test_get_model_versions_bulk_success(monkeypatch, downloader):
assert kwargs.get("params") == {"ids": "1,2"}
return True, {
"items": [
{"id": 1, "modelVersions": [{"id": 11}], "type": "LORA", "name": "One"},
{"id": 2, "modelVersions": [], "type": "Checkpoint", "name": "Two"},
{
"id": 1,
"modelVersions": [{"id": 11}],
"type": "LORA",
"name": "One",
"allowNoCredit": True,
"allowCommercialUse": ["Sell"],
"allowDerivatives": True,
"allowDifferentLicense": True,
},
{
"id": 2,
"modelVersions": [],
"type": "Checkpoint",
"name": "Two",
"allowNoCredit": False,
"allowCommercialUse": ["Image"],
"allowDerivatives": False,
"allowDifferentLicense": False,
},
]
}
@@ -216,8 +234,24 @@ async def test_get_model_versions_bulk_success(monkeypatch, downloader):
result = await client.get_model_versions_bulk([1, "2", 2])
assert result == {
1: {"modelVersions": [{"id": 11}], "type": "LORA", "name": "One"},
2: {"modelVersions": [], "type": "Checkpoint", "name": "Two"},
1: {
"modelVersions": [{"id": 11}],
"type": "LORA",
"name": "One",
"allowNoCredit": True,
"allowCommercialUse": ["Sell"],
"allowDerivatives": True,
"allowDifferentLicense": True,
},
2: {
"modelVersions": [],
"type": "Checkpoint",
"name": "Two",
"allowNoCredit": False,
"allowCommercialUse": ["Image"],
"allowDerivatives": False,
"allowDifferentLicense": False,
},
}