fix(recipe): show checkpoint-linked recipes in model modal (#851)

This commit is contained in:
Will Miao
2026-03-31 16:45:01 +08:00
parent 316f17dd46
commit 8dc2a2f76b
12 changed files with 393 additions and 51 deletions

View File

@@ -82,7 +82,7 @@ vi.mock(MODEL_VERSIONS_MODULE, () => ({
}));
vi.mock(RECIPE_TAB_MODULE, () => ({
loadRecipesForLora: vi.fn(),
loadRecipesForModel: vi.fn(),
}));
vi.mock(I18N_HELPERS_MODULE, () => ({
@@ -103,11 +103,14 @@ vi.mock(API_FACTORY, () => ({
describe('Model metadata interactions keep file path in sync', () => {
let getModelApiClient;
let loadRecipesForModel;
beforeEach(async () => {
document.body.innerHTML = '';
({ getModelApiClient } = await import(API_FACTORY));
({ loadRecipesForModel } = await import(RECIPE_TAB_MODULE));
getModelApiClient.mockReset();
loadRecipesForModel.mockReset();
});
afterEach(() => {
@@ -206,4 +209,33 @@ describe('Model metadata interactions keep file path in sync', () => {
expect(saveModelMetadata).toHaveBeenCalledWith('models/Qwen.testing.safetensors', { notes: 'Updated notes' });
});
});
it('shows recipes tab for checkpoint modals and loads linked recipes by hash', async () => {
const fetchModelMetadata = vi.fn().mockResolvedValue(null);
getModelApiClient.mockReturnValue({
fetchModelMetadata,
saveModelMetadata: vi.fn(),
});
const { showModelModal } = await import(MODAL_MODULE);
await showModelModal(
{
model_name: 'Flux Base',
file_path: 'models/checkpoints/flux-base.safetensors',
file_name: 'flux-base.safetensors',
sha256: 'ABC123',
civitai: {},
},
'checkpoints',
);
expect(document.querySelector('.tab-btn[data-tab="recipes"]')).not.toBeNull();
expect(loadRecipesForModel).toHaveBeenCalledWith({
modelKind: 'checkpoint',
displayName: 'Flux Base',
sha256: 'ABC123',
});
});
});

View File

@@ -80,7 +80,7 @@ vi.mock(MODEL_VERSIONS_MODULE, () => ({
}));
vi.mock(RECIPE_TAB_MODULE, () => ({
loadRecipesForLora: vi.fn(),
loadRecipesForModel: vi.fn(),
}));
vi.mock(I18N_HELPERS_MODULE, () => ({

View File

@@ -6,6 +6,7 @@ const initializePageFeaturesMock = vi.fn();
const getCurrentPageStateMock = vi.fn();
const getSessionItemMock = vi.fn();
const removeSessionItemMock = vi.fn();
const getStorageItemMock = vi.fn();
const RecipeContextMenuMock = vi.fn();
const refreshVirtualScrollMock = vi.fn();
const refreshRecipesMock = vi.fn();
@@ -51,6 +52,7 @@ vi.mock('../../../static/js/state/index.js', () => ({
vi.mock('../../../static/js/utils/storageHelpers.js', () => ({
getSessionItem: getSessionItemMock,
removeSessionItem: removeSessionItemMock,
getStorageItem: getStorageItemMock,
}));
vi.mock('../../../static/js/components/ContextMenu/index.js', () => ({
@@ -117,11 +119,14 @@ describe('RecipeManager', () => {
const map = {
lora_to_recipe_filterLoraName: 'Flux Dream',
lora_to_recipe_filterLoraHash: 'abc123',
checkpoint_to_recipe_filterCheckpointName: null,
checkpoint_to_recipe_filterCheckpointHash: null,
viewRecipeId: '42',
};
return map[key] ?? null;
});
removeSessionItemMock.mockImplementation(() => { });
getStorageItemMock.mockImplementation((_, defaultValue = null) => defaultValue);
renderRecipesPage();
@@ -166,6 +171,8 @@ describe('RecipeManager', () => {
active: true,
loraName: 'Flux Dream',
loraHash: 'abc123',
checkpointName: null,
checkpointHash: null,
recipeId: '42',
});
@@ -177,6 +184,8 @@ describe('RecipeManager', () => {
expect(removeSessionItemMock).toHaveBeenCalledWith('lora_to_recipe_filterLoraName');
expect(removeSessionItemMock).toHaveBeenCalledWith('lora_to_recipe_filterLoraHash');
expect(removeSessionItemMock).toHaveBeenCalledWith('checkpoint_to_recipe_filterCheckpointName');
expect(removeSessionItemMock).toHaveBeenCalledWith('checkpoint_to_recipe_filterCheckpointHash');
expect(removeSessionItemMock).toHaveBeenCalledWith('viewRecipeId');
expect(pageState.customFilter.active).toBe(false);
expect(indicator.classList.contains('hidden')).toBe(true);
@@ -227,4 +236,36 @@ describe('RecipeManager', () => {
await manager.refreshRecipes();
expect(refreshRecipesMock).toHaveBeenCalledTimes(1);
});
it('restores checkpoint recipe filter state and indicator text', async () => {
getSessionItemMock.mockImplementation((key) => {
const map = {
lora_to_recipe_filterLoraName: null,
lora_to_recipe_filterLoraHash: null,
checkpoint_to_recipe_filterCheckpointName: 'Flux Base',
checkpoint_to_recipe_filterCheckpointHash: 'ckpt123',
viewRecipeId: null,
};
return map[key] ?? null;
});
const manager = new RecipeManager();
await manager.initialize();
expect(pageState.customFilter).toEqual({
active: true,
loraName: null,
loraHash: null,
checkpointName: 'Flux Base',
checkpointHash: 'ckpt123',
recipeId: null,
});
const indicator = document.getElementById('customFilterIndicator');
const filterText = indicator.querySelector('#customFilterText');
expect(filterText.innerHTML).toContain('Recipes using checkpoint:');
expect(filterText.innerHTML).toContain('Flux Base');
expect(filterText.getAttribute('title')).toBe('Flux Base');
});
});

View File

@@ -43,6 +43,9 @@ class StubRecipeScanner:
self.cached_raw: List[Dict[str, Any]] = []
self.recipes: Dict[str, Dict[str, Any]] = {}
self.removed: List[str] = []
self.last_paginated_params: Dict[str, Any] | None = None
self.lora_lookup: Dict[str, List[Dict[str, Any]]] = {}
self.checkpoint_lookup: Dict[str, List[Dict[str, Any]]] = {}
async def _noop_get_cached_data(force_refresh: bool = False) -> None: # noqa: ARG001 - signature mirrors real scanner
return None
@@ -56,6 +59,7 @@ class StubRecipeScanner:
return SimpleNamespace(raw_data=list(self.cached_raw))
async def get_paginated_data(self, **params: Any) -> Dict[str, Any]:
self.last_paginated_params = params
items = [dict(item) for item in self.listing_items]
page = int(params.get("page", 1))
page_size = int(params.get("page_size", 20))
@@ -70,6 +74,14 @@ class StubRecipeScanner:
async def get_recipe_by_id(self, recipe_id: str) -> Optional[Dict[str, Any]]:
return self.recipes.get(recipe_id)
async def get_recipes_for_lora(self, lora_hash: str) -> List[Dict[str, Any]]:
return list(self.lora_lookup.get(lora_hash.lower(), []))
async def get_recipes_for_checkpoint(
self, checkpoint_hash: str
) -> List[Dict[str, Any]]:
return list(self.checkpoint_lookup.get(checkpoint_hash.lower(), []))
async def get_recipe_json_path(self, recipe_id: str) -> Optional[str]:
candidate = Path(self.recipes_dir) / f"{recipe_id}.recipe.json"
return str(candidate) if candidate.exists() else None
@@ -350,6 +362,47 @@ async def test_list_recipes_provides_file_urls(monkeypatch, tmp_path: Path) -> N
assert payload["items"][0]["loras"] == []
async def test_list_recipes_passes_checkpoint_hash_filter(
monkeypatch, tmp_path: Path
) -> None:
async with recipe_harness(monkeypatch, tmp_path) as harness:
response = await harness.client.get("/api/lm/recipes?checkpoint_hash=ckpt123")
payload = await response.json()
assert response.status == 200
assert payload["items"] == []
assert harness.scanner.last_paginated_params["checkpoint_hash"] == "ckpt123"
async def test_get_recipes_for_checkpoint(monkeypatch, tmp_path: Path) -> None:
async with recipe_harness(monkeypatch, tmp_path) as harness:
harness.scanner.checkpoint_lookup["abc123"] = [
{"id": "recipe-1", "title": "Linked recipe"}
]
response = await harness.client.get(
"/api/lm/recipes/for-checkpoint?hash=ABC123"
)
payload = await response.json()
assert response.status == 200
assert payload == {
"success": True,
"recipes": [{"id": "recipe-1", "title": "Linked recipe"}],
}
async def test_get_recipes_for_checkpoint_requires_hash(
monkeypatch, tmp_path: Path
) -> None:
async with recipe_harness(monkeypatch, tmp_path) as harness:
response = await harness.client.get("/api/lm/recipes/for-checkpoint")
payload = await response.json()
assert response.status == 400
assert payload["success"] is False
async def test_save_and_delete_recipe_round_trip(monkeypatch, tmp_path: Path) -> None:
async with recipe_harness(monkeypatch, tmp_path) as harness:
form = FormData()

View File

@@ -313,6 +313,75 @@ async def test_get_recipe_by_id_handles_non_dict_checkpoint(recipe_scanner):
assert recipe["checkpoint"]["file_name"] == "by-id"
@pytest.mark.asyncio
async def test_get_paginated_data_filters_by_checkpoint_hash(recipe_scanner):
scanner, _ = recipe_scanner
image_path = Path(config.loras_roots[0]) / "checkpoint-filter.webp"
await scanner.add_recipe(
{
"id": "checkpoint-match",
"file_path": str(image_path),
"title": "Checkpoint Match",
"modified": 0.0,
"created_date": 0.0,
"loras": [],
"checkpoint": {
"name": "flux-base.safetensors",
"hash": "ABC123",
},
}
)
await scanner.add_recipe(
{
"id": "checkpoint-miss",
"file_path": str(Path(config.loras_roots[0]) / "checkpoint-miss.webp"),
"title": "Checkpoint Miss",
"modified": 1.0,
"created_date": 1.0,
"loras": [],
"checkpoint": {
"name": "other.safetensors",
"hash": "zzz999",
},
}
)
await asyncio.sleep(0)
result = await scanner.get_paginated_data(
page=1,
page_size=10,
checkpoint_hash="abc123",
)
assert [item["id"] for item in result["items"]] == ["checkpoint-match"]
@pytest.mark.asyncio
async def test_get_recipes_for_checkpoint_matches_hash_case_insensitively(recipe_scanner):
scanner, _ = recipe_scanner
image_path = Path(config.loras_roots[0]) / "checkpoint-linked.webp"
await scanner.add_recipe(
{
"id": "checkpoint-linked",
"file_path": str(image_path),
"title": "Checkpoint Linked",
"modified": 0.0,
"created_date": 0.0,
"loras": [],
"checkpoint": {
"name": "flux-base.safetensors",
"hash": "ABC123",
},
}
)
recipes = await scanner.get_recipes_for_checkpoint("abc123")
assert len(recipes) == 1
assert recipes[0]["id"] == "checkpoint-linked"
assert recipes[0]["checkpoint"]["hash"] == "ABC123"
def test_enrich_uses_version_index_when_hash_missing(recipe_scanner):
scanner, stub = recipe_scanner
version_id = 77