mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-04-02 10:48:51 -03:00
fix(recipe): show checkpoint-linked recipes in model modal (#851)
This commit is contained in:
@@ -82,7 +82,7 @@ vi.mock(MODEL_VERSIONS_MODULE, () => ({
|
||||
}));
|
||||
|
||||
vi.mock(RECIPE_TAB_MODULE, () => ({
|
||||
loadRecipesForLora: vi.fn(),
|
||||
loadRecipesForModel: vi.fn(),
|
||||
}));
|
||||
|
||||
vi.mock(I18N_HELPERS_MODULE, () => ({
|
||||
@@ -103,11 +103,14 @@ vi.mock(API_FACTORY, () => ({
|
||||
|
||||
describe('Model metadata interactions keep file path in sync', () => {
|
||||
let getModelApiClient;
|
||||
let loadRecipesForModel;
|
||||
|
||||
beforeEach(async () => {
|
||||
document.body.innerHTML = '';
|
||||
({ getModelApiClient } = await import(API_FACTORY));
|
||||
({ loadRecipesForModel } = await import(RECIPE_TAB_MODULE));
|
||||
getModelApiClient.mockReset();
|
||||
loadRecipesForModel.mockReset();
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
@@ -206,4 +209,33 @@ describe('Model metadata interactions keep file path in sync', () => {
|
||||
expect(saveModelMetadata).toHaveBeenCalledWith('models/Qwen.testing.safetensors', { notes: 'Updated notes' });
|
||||
});
|
||||
});
|
||||
|
||||
it('shows recipes tab for checkpoint modals and loads linked recipes by hash', async () => {
|
||||
const fetchModelMetadata = vi.fn().mockResolvedValue(null);
|
||||
|
||||
getModelApiClient.mockReturnValue({
|
||||
fetchModelMetadata,
|
||||
saveModelMetadata: vi.fn(),
|
||||
});
|
||||
|
||||
const { showModelModal } = await import(MODAL_MODULE);
|
||||
|
||||
await showModelModal(
|
||||
{
|
||||
model_name: 'Flux Base',
|
||||
file_path: 'models/checkpoints/flux-base.safetensors',
|
||||
file_name: 'flux-base.safetensors',
|
||||
sha256: 'ABC123',
|
||||
civitai: {},
|
||||
},
|
||||
'checkpoints',
|
||||
);
|
||||
|
||||
expect(document.querySelector('.tab-btn[data-tab="recipes"]')).not.toBeNull();
|
||||
expect(loadRecipesForModel).toHaveBeenCalledWith({
|
||||
modelKind: 'checkpoint',
|
||||
displayName: 'Flux Base',
|
||||
sha256: 'ABC123',
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -80,7 +80,7 @@ vi.mock(MODEL_VERSIONS_MODULE, () => ({
|
||||
}));
|
||||
|
||||
vi.mock(RECIPE_TAB_MODULE, () => ({
|
||||
loadRecipesForLora: vi.fn(),
|
||||
loadRecipesForModel: vi.fn(),
|
||||
}));
|
||||
|
||||
vi.mock(I18N_HELPERS_MODULE, () => ({
|
||||
|
||||
@@ -6,6 +6,7 @@ const initializePageFeaturesMock = vi.fn();
|
||||
const getCurrentPageStateMock = vi.fn();
|
||||
const getSessionItemMock = vi.fn();
|
||||
const removeSessionItemMock = vi.fn();
|
||||
const getStorageItemMock = vi.fn();
|
||||
const RecipeContextMenuMock = vi.fn();
|
||||
const refreshVirtualScrollMock = vi.fn();
|
||||
const refreshRecipesMock = vi.fn();
|
||||
@@ -51,6 +52,7 @@ vi.mock('../../../static/js/state/index.js', () => ({
|
||||
vi.mock('../../../static/js/utils/storageHelpers.js', () => ({
|
||||
getSessionItem: getSessionItemMock,
|
||||
removeSessionItem: removeSessionItemMock,
|
||||
getStorageItem: getStorageItemMock,
|
||||
}));
|
||||
|
||||
vi.mock('../../../static/js/components/ContextMenu/index.js', () => ({
|
||||
@@ -117,11 +119,14 @@ describe('RecipeManager', () => {
|
||||
const map = {
|
||||
lora_to_recipe_filterLoraName: 'Flux Dream',
|
||||
lora_to_recipe_filterLoraHash: 'abc123',
|
||||
checkpoint_to_recipe_filterCheckpointName: null,
|
||||
checkpoint_to_recipe_filterCheckpointHash: null,
|
||||
viewRecipeId: '42',
|
||||
};
|
||||
return map[key] ?? null;
|
||||
});
|
||||
removeSessionItemMock.mockImplementation(() => { });
|
||||
getStorageItemMock.mockImplementation((_, defaultValue = null) => defaultValue);
|
||||
|
||||
renderRecipesPage();
|
||||
|
||||
@@ -166,6 +171,8 @@ describe('RecipeManager', () => {
|
||||
active: true,
|
||||
loraName: 'Flux Dream',
|
||||
loraHash: 'abc123',
|
||||
checkpointName: null,
|
||||
checkpointHash: null,
|
||||
recipeId: '42',
|
||||
});
|
||||
|
||||
@@ -177,6 +184,8 @@ describe('RecipeManager', () => {
|
||||
|
||||
expect(removeSessionItemMock).toHaveBeenCalledWith('lora_to_recipe_filterLoraName');
|
||||
expect(removeSessionItemMock).toHaveBeenCalledWith('lora_to_recipe_filterLoraHash');
|
||||
expect(removeSessionItemMock).toHaveBeenCalledWith('checkpoint_to_recipe_filterCheckpointName');
|
||||
expect(removeSessionItemMock).toHaveBeenCalledWith('checkpoint_to_recipe_filterCheckpointHash');
|
||||
expect(removeSessionItemMock).toHaveBeenCalledWith('viewRecipeId');
|
||||
expect(pageState.customFilter.active).toBe(false);
|
||||
expect(indicator.classList.contains('hidden')).toBe(true);
|
||||
@@ -227,4 +236,36 @@ describe('RecipeManager', () => {
|
||||
await manager.refreshRecipes();
|
||||
expect(refreshRecipesMock).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
|
||||
it('restores checkpoint recipe filter state and indicator text', async () => {
|
||||
getSessionItemMock.mockImplementation((key) => {
|
||||
const map = {
|
||||
lora_to_recipe_filterLoraName: null,
|
||||
lora_to_recipe_filterLoraHash: null,
|
||||
checkpoint_to_recipe_filterCheckpointName: 'Flux Base',
|
||||
checkpoint_to_recipe_filterCheckpointHash: 'ckpt123',
|
||||
viewRecipeId: null,
|
||||
};
|
||||
return map[key] ?? null;
|
||||
});
|
||||
|
||||
const manager = new RecipeManager();
|
||||
await manager.initialize();
|
||||
|
||||
expect(pageState.customFilter).toEqual({
|
||||
active: true,
|
||||
loraName: null,
|
||||
loraHash: null,
|
||||
checkpointName: 'Flux Base',
|
||||
checkpointHash: 'ckpt123',
|
||||
recipeId: null,
|
||||
});
|
||||
|
||||
const indicator = document.getElementById('customFilterIndicator');
|
||||
const filterText = indicator.querySelector('#customFilterText');
|
||||
|
||||
expect(filterText.innerHTML).toContain('Recipes using checkpoint:');
|
||||
expect(filterText.innerHTML).toContain('Flux Base');
|
||||
expect(filterText.getAttribute('title')).toBe('Flux Base');
|
||||
});
|
||||
});
|
||||
|
||||
@@ -43,6 +43,9 @@ class StubRecipeScanner:
|
||||
self.cached_raw: List[Dict[str, Any]] = []
|
||||
self.recipes: Dict[str, Dict[str, Any]] = {}
|
||||
self.removed: List[str] = []
|
||||
self.last_paginated_params: Dict[str, Any] | None = None
|
||||
self.lora_lookup: Dict[str, List[Dict[str, Any]]] = {}
|
||||
self.checkpoint_lookup: Dict[str, List[Dict[str, Any]]] = {}
|
||||
|
||||
async def _noop_get_cached_data(force_refresh: bool = False) -> None: # noqa: ARG001 - signature mirrors real scanner
|
||||
return None
|
||||
@@ -56,6 +59,7 @@ class StubRecipeScanner:
|
||||
return SimpleNamespace(raw_data=list(self.cached_raw))
|
||||
|
||||
async def get_paginated_data(self, **params: Any) -> Dict[str, Any]:
|
||||
self.last_paginated_params = params
|
||||
items = [dict(item) for item in self.listing_items]
|
||||
page = int(params.get("page", 1))
|
||||
page_size = int(params.get("page_size", 20))
|
||||
@@ -70,6 +74,14 @@ class StubRecipeScanner:
|
||||
async def get_recipe_by_id(self, recipe_id: str) -> Optional[Dict[str, Any]]:
|
||||
return self.recipes.get(recipe_id)
|
||||
|
||||
async def get_recipes_for_lora(self, lora_hash: str) -> List[Dict[str, Any]]:
|
||||
return list(self.lora_lookup.get(lora_hash.lower(), []))
|
||||
|
||||
async def get_recipes_for_checkpoint(
|
||||
self, checkpoint_hash: str
|
||||
) -> List[Dict[str, Any]]:
|
||||
return list(self.checkpoint_lookup.get(checkpoint_hash.lower(), []))
|
||||
|
||||
async def get_recipe_json_path(self, recipe_id: str) -> Optional[str]:
|
||||
candidate = Path(self.recipes_dir) / f"{recipe_id}.recipe.json"
|
||||
return str(candidate) if candidate.exists() else None
|
||||
@@ -350,6 +362,47 @@ async def test_list_recipes_provides_file_urls(monkeypatch, tmp_path: Path) -> N
|
||||
assert payload["items"][0]["loras"] == []
|
||||
|
||||
|
||||
async def test_list_recipes_passes_checkpoint_hash_filter(
|
||||
monkeypatch, tmp_path: Path
|
||||
) -> None:
|
||||
async with recipe_harness(monkeypatch, tmp_path) as harness:
|
||||
response = await harness.client.get("/api/lm/recipes?checkpoint_hash=ckpt123")
|
||||
payload = await response.json()
|
||||
|
||||
assert response.status == 200
|
||||
assert payload["items"] == []
|
||||
assert harness.scanner.last_paginated_params["checkpoint_hash"] == "ckpt123"
|
||||
|
||||
|
||||
async def test_get_recipes_for_checkpoint(monkeypatch, tmp_path: Path) -> None:
|
||||
async with recipe_harness(monkeypatch, tmp_path) as harness:
|
||||
harness.scanner.checkpoint_lookup["abc123"] = [
|
||||
{"id": "recipe-1", "title": "Linked recipe"}
|
||||
]
|
||||
|
||||
response = await harness.client.get(
|
||||
"/api/lm/recipes/for-checkpoint?hash=ABC123"
|
||||
)
|
||||
payload = await response.json()
|
||||
|
||||
assert response.status == 200
|
||||
assert payload == {
|
||||
"success": True,
|
||||
"recipes": [{"id": "recipe-1", "title": "Linked recipe"}],
|
||||
}
|
||||
|
||||
|
||||
async def test_get_recipes_for_checkpoint_requires_hash(
|
||||
monkeypatch, tmp_path: Path
|
||||
) -> None:
|
||||
async with recipe_harness(monkeypatch, tmp_path) as harness:
|
||||
response = await harness.client.get("/api/lm/recipes/for-checkpoint")
|
||||
payload = await response.json()
|
||||
|
||||
assert response.status == 400
|
||||
assert payload["success"] is False
|
||||
|
||||
|
||||
async def test_save_and_delete_recipe_round_trip(monkeypatch, tmp_path: Path) -> None:
|
||||
async with recipe_harness(monkeypatch, tmp_path) as harness:
|
||||
form = FormData()
|
||||
|
||||
@@ -313,6 +313,75 @@ async def test_get_recipe_by_id_handles_non_dict_checkpoint(recipe_scanner):
|
||||
assert recipe["checkpoint"]["file_name"] == "by-id"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_paginated_data_filters_by_checkpoint_hash(recipe_scanner):
|
||||
scanner, _ = recipe_scanner
|
||||
image_path = Path(config.loras_roots[0]) / "checkpoint-filter.webp"
|
||||
await scanner.add_recipe(
|
||||
{
|
||||
"id": "checkpoint-match",
|
||||
"file_path": str(image_path),
|
||||
"title": "Checkpoint Match",
|
||||
"modified": 0.0,
|
||||
"created_date": 0.0,
|
||||
"loras": [],
|
||||
"checkpoint": {
|
||||
"name": "flux-base.safetensors",
|
||||
"hash": "ABC123",
|
||||
},
|
||||
}
|
||||
)
|
||||
await scanner.add_recipe(
|
||||
{
|
||||
"id": "checkpoint-miss",
|
||||
"file_path": str(Path(config.loras_roots[0]) / "checkpoint-miss.webp"),
|
||||
"title": "Checkpoint Miss",
|
||||
"modified": 1.0,
|
||||
"created_date": 1.0,
|
||||
"loras": [],
|
||||
"checkpoint": {
|
||||
"name": "other.safetensors",
|
||||
"hash": "zzz999",
|
||||
},
|
||||
}
|
||||
)
|
||||
await asyncio.sleep(0)
|
||||
|
||||
result = await scanner.get_paginated_data(
|
||||
page=1,
|
||||
page_size=10,
|
||||
checkpoint_hash="abc123",
|
||||
)
|
||||
|
||||
assert [item["id"] for item in result["items"]] == ["checkpoint-match"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_recipes_for_checkpoint_matches_hash_case_insensitively(recipe_scanner):
|
||||
scanner, _ = recipe_scanner
|
||||
image_path = Path(config.loras_roots[0]) / "checkpoint-linked.webp"
|
||||
await scanner.add_recipe(
|
||||
{
|
||||
"id": "checkpoint-linked",
|
||||
"file_path": str(image_path),
|
||||
"title": "Checkpoint Linked",
|
||||
"modified": 0.0,
|
||||
"created_date": 0.0,
|
||||
"loras": [],
|
||||
"checkpoint": {
|
||||
"name": "flux-base.safetensors",
|
||||
"hash": "ABC123",
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
recipes = await scanner.get_recipes_for_checkpoint("abc123")
|
||||
|
||||
assert len(recipes) == 1
|
||||
assert recipes[0]["id"] == "checkpoint-linked"
|
||||
assert recipes[0]["checkpoint"]["hash"] == "ABC123"
|
||||
|
||||
|
||||
def test_enrich_uses_version_index_when_hash_missing(recipe_scanner):
|
||||
scanner, stub = recipe_scanner
|
||||
version_id = 77
|
||||
|
||||
Reference in New Issue
Block a user