mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-21 21:22:11 -03:00
feat(context-menu): refresh missing license metadata
This commit is contained in:
@@ -44,7 +44,10 @@ const refreshSingleModelMetadataMock = vi.fn();
|
||||
const resetAndReloadMock = vi.fn();
|
||||
const getCompleteApiConfigMock = vi.fn(() => ({
|
||||
config: { displayName: 'LoRA' },
|
||||
endpoints: { refreshUpdates: '/api/lm/loras/updates/refresh' },
|
||||
endpoints: {
|
||||
refreshUpdates: '/api/lm/loras/updates/refresh',
|
||||
fetchMissingLicenses: '/api/lm/loras/updates/fetch-missing-license',
|
||||
},
|
||||
}));
|
||||
const getCurrentModelTypeMock = vi.fn(() => 'loras');
|
||||
|
||||
@@ -150,7 +153,10 @@ describe('Interaction-level regression coverage', () => {
|
||||
resetAndReloadMock.mockResolvedValue(undefined);
|
||||
getCompleteApiConfigMock.mockReturnValue({
|
||||
config: { displayName: 'LoRA' },
|
||||
endpoints: { refreshUpdates: '/api/lm/loras/updates/refresh' },
|
||||
endpoints: {
|
||||
refreshUpdates: '/api/lm/loras/updates/refresh',
|
||||
fetchMissingLicenses: '/api/lm/loras/updates/fetch-missing-license',
|
||||
},
|
||||
});
|
||||
getCurrentModelTypeMock.mockReturnValue('loras');
|
||||
translateMock.mockImplementation((key, params, fallback) => (typeof fallback === 'string' ? fallback : key));
|
||||
@@ -322,8 +328,9 @@ describe('Interaction-level regression coverage', () => {
|
||||
document.body.innerHTML = `
|
||||
<div id="globalContextMenu" class="context-menu">
|
||||
<div class="context-menu-item" data-action="download-example-images"></div>
|
||||
<div class="context-menu-item" data-action="cleanup-example-images-folders"></div>
|
||||
<div class="context-menu-item" data-action="check-model-updates"></div>
|
||||
<div class="context-menu-item" data-action="fetch-missing-licenses"></div>
|
||||
<div class="context-menu-item" data-action="cleanup-example-images-folders"></div>
|
||||
</div>
|
||||
`;
|
||||
|
||||
@@ -354,6 +361,10 @@ describe('Interaction-level regression coverage', () => {
|
||||
.mockResolvedValueOnce({
|
||||
ok: true,
|
||||
json: async () => ({ success: true, records: [{ id: 1 }] }),
|
||||
})
|
||||
.mockResolvedValueOnce({
|
||||
ok: true,
|
||||
json: async () => ({ success: true, updated: [{ modelId: 42 }] }),
|
||||
});
|
||||
|
||||
menu.showMenu(240, 320);
|
||||
@@ -379,7 +390,7 @@ describe('Interaction-level regression coverage', () => {
|
||||
|
||||
await flushAsyncTasks();
|
||||
|
||||
expect(global.fetch).toHaveBeenLastCalledWith('/api/lm/loras/updates/refresh', {
|
||||
expect(global.fetch).toHaveBeenNthCalledWith(2, '/api/lm/loras/updates/refresh', {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({ force: false }),
|
||||
@@ -398,5 +409,28 @@ describe('Interaction-level regression coverage', () => {
|
||||
expect(loadingManagerStub.hide).toHaveBeenCalled();
|
||||
expect(resetAndReloadMock).toHaveBeenCalledWith(false);
|
||||
expect(checkUpdatesItem.classList.contains('disabled')).toBe(false);
|
||||
|
||||
menu.showMenu(480, 520);
|
||||
const fetchMissingItem = document.querySelector('[data-action="fetch-missing-licenses"]');
|
||||
fetchMissingItem.dispatchEvent(new Event('click', { bubbles: true }));
|
||||
expect(fetchMissingItem.classList.contains('disabled')).toBe(true);
|
||||
|
||||
const fetchMissingResponse = await global.fetch.mock.results[2].value;
|
||||
await fetchMissingResponse.json();
|
||||
await flushAsyncTasks();
|
||||
|
||||
expect(global.fetch).toHaveBeenNthCalledWith(3, '/api/lm/loras/updates/fetch-missing-license', {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({}),
|
||||
});
|
||||
|
||||
expect(showToastMock).toHaveBeenCalledWith(
|
||||
'globalContextMenu.fetchMissingLicenses.success',
|
||||
{ count: 1, type: 'LoRA', typePlural: 'LoRAs' },
|
||||
'success'
|
||||
);
|
||||
expect(loadingManagerStub.showSimpleLoading).toHaveBeenNthCalledWith(2, 'Refreshing license metadata for LoRAs...');
|
||||
expect(fetchMissingItem.classList.contains('disabled')).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import copy
|
||||
import json
|
||||
import logging
|
||||
from types import SimpleNamespace
|
||||
@@ -6,6 +7,7 @@ import pytest
|
||||
|
||||
from py.config import config
|
||||
from py.routes.handlers.model_handlers import ModelUpdateHandler
|
||||
from py.utils.metadata_manager import MetadataManager
|
||||
from py.services.model_update_service import ModelUpdateRecord, ModelVersionRecord
|
||||
|
||||
|
||||
@@ -264,3 +266,171 @@ async def test_refresh_model_updates_accepts_snake_case_ids():
|
||||
|
||||
call = update_service.calls[0]
|
||||
assert call["target_model_ids"] == [3, 4]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fetch_missing_license_data_updates_metadata(monkeypatch):
|
||||
cache = SimpleNamespace(
|
||||
raw_data=[
|
||||
{"file_path": "/tmp/model1.safetensors", "civitai": {"modelId": 10}},
|
||||
{"file_path": "/tmp/model2.safetensors", "civitai": {"modelId": 10}},
|
||||
{"file_path": "/tmp/model3.safetensors", "civitai": {"modelId": 20}},
|
||||
],
|
||||
version_index={},
|
||||
)
|
||||
|
||||
metadata_store = {
|
||||
"/tmp/model1.safetensors": {"civitai": {"model": {}}},
|
||||
"/tmp/model2.safetensors": {"civitai": {"model": {}}},
|
||||
"/tmp/model3.safetensors": {"civitai": {"model": {}}},
|
||||
}
|
||||
|
||||
async def fake_load(path: str):
|
||||
data = metadata_store.get(path)
|
||||
if data is None:
|
||||
return None, False
|
||||
return SimpleNamespace(to_dict=lambda: copy.deepcopy(data)), False
|
||||
|
||||
saved: list[tuple[str, dict]] = []
|
||||
|
||||
async def fake_save(path: str, metadata: dict):
|
||||
saved.append((path, copy.deepcopy(metadata)))
|
||||
return True
|
||||
|
||||
monkeypatch.setattr(MetadataManager, "load_metadata", staticmethod(fake_load))
|
||||
monkeypatch.setattr(MetadataManager, "save_metadata", staticmethod(fake_save))
|
||||
|
||||
provider_calls: list[list[int]] = []
|
||||
|
||||
async def fake_bulk(model_ids):
|
||||
provider_calls.append(list(model_ids))
|
||||
return {
|
||||
10: {
|
||||
"allowNoCredit": True,
|
||||
"allowCommercialUse": ["Sell"],
|
||||
"allowDerivatives": True,
|
||||
"allowDifferentLicense": True,
|
||||
},
|
||||
20: {
|
||||
"allowNoCredit": False,
|
||||
"allowCommercialUse": ["Image"],
|
||||
"allowDerivatives": False,
|
||||
"allowDifferentLicense": False,
|
||||
},
|
||||
}
|
||||
|
||||
provider = SimpleNamespace()
|
||||
provider.get_model_versions_bulk = fake_bulk
|
||||
|
||||
async def metadata_selector(name):
|
||||
assert name == "civitai_api"
|
||||
return provider
|
||||
|
||||
handler = ModelUpdateHandler(
|
||||
service=DummyService(cache),
|
||||
update_service=SimpleNamespace(),
|
||||
metadata_provider_selector=metadata_selector,
|
||||
logger=logging.getLogger(__name__),
|
||||
)
|
||||
|
||||
class DummyRequest:
|
||||
can_read_body = True
|
||||
query = {}
|
||||
|
||||
async def json(self):
|
||||
return {}
|
||||
|
||||
response = await handler.fetch_missing_civitai_license_data(DummyRequest())
|
||||
assert response.status == 200
|
||||
|
||||
payload = json.loads(response.text)
|
||||
assert payload["success"] is True
|
||||
assert len(payload["updated"]) == 3
|
||||
assert provider_calls == [[10, 20]]
|
||||
assert len(saved) == 3
|
||||
|
||||
first_metadata = saved[0][1]
|
||||
assert first_metadata["civitai"]["model"]["allowNoCredit"] is True
|
||||
assert first_metadata["civitai"]["model"]["allowCommercialUse"] == ["Sell"]
|
||||
assert "missingModelIds" not in payload
|
||||
assert "errors" not in payload
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fetch_missing_license_data_filters_model_ids(monkeypatch):
|
||||
cache = SimpleNamespace(
|
||||
raw_data=[
|
||||
{"file_path": "/tmp/model1.safetensors", "civitai": {"modelId": 10}},
|
||||
{"file_path": "/tmp/model2.safetensors", "civitai": {"modelId": 20}},
|
||||
],
|
||||
version_index={},
|
||||
)
|
||||
|
||||
metadata_store = {
|
||||
"/tmp/model1.safetensors": {"civitai": {"model": {}}},
|
||||
"/tmp/model2.safetensors": {"civitai": {"model": {}}},
|
||||
}
|
||||
|
||||
async def fake_load(path: str):
|
||||
data = metadata_store.get(path)
|
||||
if data is None:
|
||||
return None, False
|
||||
return SimpleNamespace(to_dict=lambda: copy.deepcopy(data)), False
|
||||
|
||||
saved: list[tuple[str, dict]] = []
|
||||
|
||||
async def fake_save(path: str, metadata: dict):
|
||||
saved.append((path, copy.deepcopy(metadata)))
|
||||
return True
|
||||
|
||||
monkeypatch.setattr(MetadataManager, "load_metadata", staticmethod(fake_load))
|
||||
monkeypatch.setattr(MetadataManager, "save_metadata", staticmethod(fake_save))
|
||||
|
||||
provider_calls: list[list[int]] = []
|
||||
|
||||
async def fake_bulk(model_ids):
|
||||
provider_calls.append(list(model_ids))
|
||||
return {
|
||||
10: {
|
||||
"allowNoCredit": True,
|
||||
"allowCommercialUse": ["Sell"],
|
||||
"allowDerivatives": True,
|
||||
"allowDifferentLicense": True,
|
||||
},
|
||||
20: {
|
||||
"allowNoCredit": False,
|
||||
"allowCommercialUse": ["Image"],
|
||||
"allowDerivatives": False,
|
||||
"allowDifferentLicense": False,
|
||||
},
|
||||
}
|
||||
|
||||
provider = SimpleNamespace()
|
||||
provider.get_model_versions_bulk = fake_bulk
|
||||
|
||||
async def metadata_selector(name):
|
||||
assert name == "civitai_api"
|
||||
return provider
|
||||
|
||||
handler = ModelUpdateHandler(
|
||||
service=DummyService(cache),
|
||||
update_service=SimpleNamespace(),
|
||||
metadata_provider_selector=metadata_selector,
|
||||
logger=logging.getLogger(__name__),
|
||||
)
|
||||
|
||||
class DummyRequest:
|
||||
can_read_body = True
|
||||
query = {}
|
||||
|
||||
async def json(self):
|
||||
return {"modelIds": [20]}
|
||||
|
||||
response = await handler.fetch_missing_civitai_license_data(DummyRequest())
|
||||
assert response.status == 200
|
||||
|
||||
payload = json.loads(response.text)
|
||||
assert payload["success"] is True
|
||||
assert len(payload["updated"]) == 1
|
||||
assert provider_calls == [[20]]
|
||||
assert len(saved) == 1
|
||||
|
||||
@@ -204,8 +204,26 @@ async def test_get_model_versions_bulk_success(monkeypatch, downloader):
|
||||
assert kwargs.get("params") == {"ids": "1,2"}
|
||||
return True, {
|
||||
"items": [
|
||||
{"id": 1, "modelVersions": [{"id": 11}], "type": "LORA", "name": "One"},
|
||||
{"id": 2, "modelVersions": [], "type": "Checkpoint", "name": "Two"},
|
||||
{
|
||||
"id": 1,
|
||||
"modelVersions": [{"id": 11}],
|
||||
"type": "LORA",
|
||||
"name": "One",
|
||||
"allowNoCredit": True,
|
||||
"allowCommercialUse": ["Sell"],
|
||||
"allowDerivatives": True,
|
||||
"allowDifferentLicense": True,
|
||||
},
|
||||
{
|
||||
"id": 2,
|
||||
"modelVersions": [],
|
||||
"type": "Checkpoint",
|
||||
"name": "Two",
|
||||
"allowNoCredit": False,
|
||||
"allowCommercialUse": ["Image"],
|
||||
"allowDerivatives": False,
|
||||
"allowDifferentLicense": False,
|
||||
},
|
||||
]
|
||||
}
|
||||
|
||||
@@ -216,8 +234,24 @@ async def test_get_model_versions_bulk_success(monkeypatch, downloader):
|
||||
result = await client.get_model_versions_bulk([1, "2", 2])
|
||||
|
||||
assert result == {
|
||||
1: {"modelVersions": [{"id": 11}], "type": "LORA", "name": "One"},
|
||||
2: {"modelVersions": [], "type": "Checkpoint", "name": "Two"},
|
||||
1: {
|
||||
"modelVersions": [{"id": 11}],
|
||||
"type": "LORA",
|
||||
"name": "One",
|
||||
"allowNoCredit": True,
|
||||
"allowCommercialUse": ["Sell"],
|
||||
"allowDerivatives": True,
|
||||
"allowDifferentLicense": True,
|
||||
},
|
||||
2: {
|
||||
"modelVersions": [],
|
||||
"type": "Checkpoint",
|
||||
"name": "Two",
|
||||
"allowNoCredit": False,
|
||||
"allowCommercialUse": ["Image"],
|
||||
"allowDerivatives": False,
|
||||
"allowDifferentLicense": False,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user