diff --git a/py/routes/handlers/recipe_handlers.py b/py/routes/handlers/recipe_handlers.py index 339fdf46..59e354f7 100644 --- a/py/routes/handlers/recipe_handlers.py +++ b/py/routes/handlers/recipe_handlers.py @@ -81,6 +81,7 @@ class RecipeHandlerSet: "bulk_delete": self.management.bulk_delete, "save_recipe_from_widget": self.management.save_recipe_from_widget, "get_recipes_for_lora": self.query.get_recipes_for_lora, + "get_recipes_for_checkpoint": self.query.get_recipes_for_checkpoint, "scan_recipes": self.query.scan_recipes, "move_recipe": self.management.move_recipe, "repair_recipes": self.management.repair_recipes, @@ -218,6 +219,7 @@ class RecipeListingHandler: filters["tags"] = tag_filters lora_hash = request.query.get("lora_hash") + checkpoint_hash = request.query.get("checkpoint_hash") result = await recipe_scanner.get_paginated_data( page=page, @@ -227,6 +229,7 @@ class RecipeListingHandler: filters=filters, search_options=search_options, lora_hash=lora_hash, + checkpoint_hash=checkpoint_hash, folder=folder, recursive=recursive, ) @@ -423,6 +426,28 @@ class RecipeQueryHandler: self._logger.error("Error getting recipes for Lora: %s", exc) return web.json_response({"success": False, "error": str(exc)}, status=500) + async def get_recipes_for_checkpoint(self, request: web.Request) -> web.Response: + try: + await self._ensure_dependencies_ready() + recipe_scanner = self._recipe_scanner_getter() + if recipe_scanner is None: + raise RuntimeError("Recipe scanner unavailable") + + checkpoint_hash = request.query.get("hash") + if not checkpoint_hash: + return web.json_response( + {"success": False, "error": "Checkpoint hash is required"}, + status=400, + ) + + matching_recipes = await recipe_scanner.get_recipes_for_checkpoint( + checkpoint_hash + ) + return web.json_response({"success": True, "recipes": matching_recipes}) + except Exception as exc: + self._logger.error("Error getting recipes for checkpoint: %s", exc) + return web.json_response({"success": False, "error": str(exc)}, status=500) + async def scan_recipes(self, request: web.Request) -> web.Response: try: await self._ensure_dependencies_ready() diff --git a/py/routes/recipe_route_registrar.py b/py/routes/recipe_route_registrar.py index 3fa30834..95aedee5 100644 --- a/py/routes/recipe_route_registrar.py +++ b/py/routes/recipe_route_registrar.py @@ -51,6 +51,9 @@ ROUTE_DEFINITIONS: tuple[RouteDefinition, ...] = ( "POST", "/api/lm/recipes/save-from-widget", "save_recipe_from_widget" ), RouteDefinition("GET", "/api/lm/recipes/for-lora", "get_recipes_for_lora"), + RouteDefinition( + "GET", "/api/lm/recipes/for-checkpoint", "get_recipes_for_checkpoint" + ), RouteDefinition("GET", "/api/lm/recipes/scan", "scan_recipes"), RouteDefinition("POST", "/api/lm/recipes/repair", "repair_recipes"), RouteDefinition("POST", "/api/lm/recipes/cancel-repair", "cancel_repair"), diff --git a/py/services/recipe_scanner.py b/py/services/recipe_scanner.py index 629babf5..0f7b7446 100644 --- a/py/services/recipe_scanner.py +++ b/py/services/recipe_scanner.py @@ -1615,6 +1615,9 @@ class RecipeScanner: ) -> Optional[Dict[str, Any]]: """Coerce legacy or malformed checkpoint entries into a dict.""" + if checkpoint_raw is None: + return None + if isinstance(checkpoint_raw, dict): return dict(checkpoint_raw) @@ -1632,9 +1635,6 @@ class RecipeScanner: "file_name": file_name, } - logger.warning( - "Unexpected checkpoint payload type %s", type(checkpoint_raw).__name__ - ) return None def _enrich_checkpoint_entry(self, checkpoint: Dict[str, Any]) -> Dict[str, Any]: @@ -1790,6 +1790,7 @@ class RecipeScanner: filters: dict = None, search_options: dict = None, lora_hash: str = None, + checkpoint_hash: str = None, bypass_filters: bool = True, folder: str | None = None, recursive: bool = True, @@ -1804,7 +1805,8 @@ class RecipeScanner: filters: Dictionary of filters to apply search_options: Dictionary of search options to apply lora_hash: Optional SHA256 hash of a LoRA to filter recipes by - bypass_filters: If True, ignore other filters when a lora_hash is provided + checkpoint_hash: Optional SHA256 hash of a checkpoint to filter recipes by + bypass_filters: If True, ignore other filters when a hash filter is provided folder: Optional folder filter relative to recipes directory recursive: Whether to include recipes in subfolders of the selected folder """ @@ -1852,9 +1854,23 @@ class RecipeScanner: # Skip other filters if bypass_filters is True pass # Otherwise continue with normal filtering after applying LoRA hash filter + elif checkpoint_hash: + normalized_checkpoint_hash = checkpoint_hash.lower() + filtered_data = [ + item + for item in filtered_data + if isinstance(item.get("checkpoint"), dict) + and (item["checkpoint"].get("hash", "") or "").lower() + == normalized_checkpoint_hash + ] - # Skip further filtering if we're only filtering by LoRA hash with bypass enabled - if not (lora_hash and bypass_filters): + if bypass_filters: + pass + + has_hash_filter = bool(lora_hash or checkpoint_hash) + + # Skip further filtering if we're only filtering by model hash with bypass enabled + if not (has_hash_filter and bypass_filters): # Apply folder filter before other criteria if folder is not None: normalized_folder = folder.strip("/") @@ -2334,6 +2350,38 @@ class RecipeScanner: return matching_recipes + async def get_recipes_for_checkpoint( + self, checkpoint_hash: str + ) -> List[Dict[str, Any]]: + """Return recipes that reference a given checkpoint hash.""" + + if not checkpoint_hash: + return [] + + normalized_hash = checkpoint_hash.lower() + cache = await self.get_cached_data() + matching_recipes: List[Dict[str, Any]] = [] + + for recipe in cache.raw_data: + checkpoint = self._normalize_checkpoint_entry(recipe.get("checkpoint")) + if not checkpoint: + continue + + enriched_checkpoint = self._enrich_checkpoint_entry(dict(checkpoint)) + if (enriched_checkpoint.get("hash") or "").lower() != normalized_hash: + continue + + recipe_copy = {**recipe} + recipe_copy["checkpoint"] = enriched_checkpoint + recipe_copy["loras"] = [ + self._enrich_lora_entry(dict(entry)) + for entry in recipe.get("loras", []) + ] + recipe_copy["file_url"] = self._format_file_url(recipe.get("file_path")) + matching_recipes.append(recipe_copy) + + return matching_recipes + async def get_recipe_syntax_tokens(self, recipe_id: str) -> List[str]: """Build LoRA syntax tokens for a recipe.""" diff --git a/static/js/api/recipeApi.js b/static/js/api/recipeApi.js index 6718bc17..cc579e7c 100644 --- a/static/js/api/recipeApi.js +++ b/static/js/api/recipeApi.js @@ -83,6 +83,9 @@ export async function fetchRecipesPage(page = 1, pageSize = 100) { if (pageState.customFilter?.active && pageState.customFilter?.loraHash) { params.append('lora_hash', pageState.customFilter.loraHash); params.append('bypass_filters', 'true'); + } else if (pageState.customFilter?.active && pageState.customFilter?.checkpointHash) { + params.append('checkpoint_hash', pageState.customFilter.checkpointHash); + params.append('bypass_filters', 'true'); } else { // Normal filtering logic diff --git a/static/js/components/shared/ModelModal.js b/static/js/components/shared/ModelModal.js index 8314ce7e..a4f252a6 100644 --- a/static/js/components/shared/ModelModal.js +++ b/static/js/components/shared/ModelModal.js @@ -19,7 +19,7 @@ import { renderCompactTags, setupTagTooltip, formatFileSize, escapeAttribute, es import { renderTriggerWords, setupTriggerWordsEditMode } from './TriggerWords.js'; import { parsePresets, renderPresetTags } from './PresetTags.js'; import { initVersionsTab } from './ModelVersionsTab.js'; -import { loadRecipesForLora } from './RecipeTab.js'; +import { loadRecipesForModel } from './RecipeTab.js'; import { translate } from '../../utils/i18nHelpers.js'; import { state } from '../../state/index.js'; @@ -355,7 +355,9 @@ export async function showModelModal(model, modelType) { ${versionsTabBadge} `.trim(); - const tabsContent = modelType === 'loras' ? + const supportsRecipesTab = modelType === 'loras' || modelType === 'checkpoints'; + + const tabsContent = supportsRecipesTab ? ` ${versionsTabButton} @@ -385,7 +387,7 @@ export async function showModelModal(model, modelType) { `.trim(); - const tabPanesContent = modelType === 'loras' ? + const tabPanesContent = supportsRecipesTab ? `
${loadingExampleImagesText} @@ -664,14 +666,23 @@ export async function showModelModal(model, modelType) { setupNavigationShortcuts(modelType); updateNavigationControls(); - // LoRA specific setup + // Model-specific setup if (modelType === 'loras' || modelType === 'embeddings') { setupTriggerWordsEditMode(); + } - if (modelType == 'loras') { - // Load recipes for this LoRA - loadRecipesForLora(modelWithFullData.model_name, modelWithFullData.sha256); - } + if (modelType === 'loras') { + loadRecipesForModel({ + modelKind: 'lora', + displayName: modelWithFullData.model_name, + sha256: modelWithFullData.sha256, + }); + } else if (modelType === 'checkpoints') { + loadRecipesForModel({ + modelKind: 'checkpoint', + displayName: modelWithFullData.model_name, + sha256: modelWithFullData.sha256, + }); } // Load example images asynchronously - merge regular and custom images diff --git a/static/js/components/shared/RecipeTab.js b/static/js/components/shared/RecipeTab.js index dc4e56ab..1384de13 100644 --- a/static/js/components/shared/RecipeTab.js +++ b/static/js/components/shared/RecipeTab.js @@ -1,38 +1,47 @@ /** - * RecipeTab - Handles the recipes tab in model modals (LoRA specific functionality) - * Moved to shared directory for consistency + * RecipeTab - Handles the recipes tab in model modals. */ import { showToast, copyToClipboard } from '../../utils/uiHelpers.js'; import { setSessionItem, removeSessionItem } from '../../utils/storageHelpers.js'; /** - * Loads recipes that use the specified Lora and renders them in the tab - * @param {string} loraName - The display name of the Lora - * @param {string} sha256 - The SHA256 hash of the Lora + * Loads recipes that use the specified model and renders them in the tab. + * @param {Object} options + * @param {'lora'|'checkpoint'} options.modelKind - Model kind for copy and endpoint selection + * @param {string} options.displayName - The display name of the model + * @param {string} options.sha256 - The SHA256 hash of the model */ -export function loadRecipesForLora(loraName, sha256) { +export function loadRecipesForModel({ modelKind, displayName, sha256 }) { const recipeTab = document.getElementById('recipes-tab'); if (!recipeTab) return; - + + const normalizedHash = sha256?.toLowerCase?.() || ''; + const modelLabel = getModelLabel(modelKind); + // Show loading state recipeTab.innerHTML = `
Loading recipes...
`; - - // Fetch recipes that use this Lora by hash - fetch(`/api/lm/recipes/for-lora?hash=${encodeURIComponent(sha256.toLowerCase())}`) + + // Fetch recipes that use this model by hash + fetch(`${getRecipesEndpoint(modelKind)}?hash=${encodeURIComponent(normalizedHash)}`) .then(response => response.json()) .then(data => { if (!data.success) { throw new Error(data.error || 'Failed to load recipes'); } - - renderRecipes(recipeTab, data.recipes, loraName, sha256); + + renderRecipes(recipeTab, data.recipes, { + modelKind, + displayName, + modelHash: normalizedHash, + modelLabel, + }); }) .catch(error => { - console.error('Error loading recipes for Lora:', error); + console.error(`Error loading recipes for ${modelLabel}:`, error); recipeTab.innerHTML = `
@@ -46,18 +55,24 @@ export function loadRecipesForLora(loraName, sha256) { * Renders the recipe cards in the tab * @param {HTMLElement} tabElement - The tab element to render into * @param {Array} recipes - Array of recipe objects - * @param {string} loraName - The display name of the Lora - * @param {string} loraHash - The hash of the Lora + * @param {Object} options - Render options */ -function renderRecipes(tabElement, recipes, loraName, loraHash) { +function renderRecipes(tabElement, recipes, options) { + const { + modelKind, + displayName, + modelHash, + modelLabel, + } = options; + if (!recipes || recipes.length === 0) { tabElement.innerHTML = `
-

No recipes found that use this Lora.

+

No recipes found that use this ${modelLabel}.

`; - + return; } @@ -73,13 +88,13 @@ function renderRecipes(tabElement, recipes, loraName, loraHash) { headerText.appendChild(eyebrow); const title = document.createElement('h3'); - title.textContent = `${recipes.length} recipe${recipes.length > 1 ? 's' : ''} using this Lora`; + title.textContent = `${recipes.length} recipe${recipes.length > 1 ? 's' : ''} using this ${modelLabel}`; headerText.appendChild(title); const description = document.createElement('p'); description.className = 'recipes-header__description'; - description.textContent = loraName ? - `Discover workflows crafted for ${loraName}.` : + description.textContent = displayName ? + `Discover workflows crafted for ${displayName}.` : 'Discover workflows crafted for this model.'; headerText.appendChild(description); @@ -101,7 +116,11 @@ function renderRecipes(tabElement, recipes, loraName, loraHash) { headerElement.appendChild(viewAllButton); viewAllButton.addEventListener('click', () => { - navigateToRecipesPage(loraName, loraHash); + navigateToRecipesPage({ + modelKind, + displayName, + modelHash, + }); }); const cardGrid = document.createElement('div'); @@ -280,26 +299,32 @@ function copyRecipeSyntax(recipeId) { } /** - * Navigates to the recipes page with filter for the current Lora - * @param {string} loraName - The Lora display name to filter by - * @param {string} loraHash - The hash of the Lora to filter by - * @param {boolean} createNew - Whether to open the create recipe dialog + * Navigates to the recipes page with filter for the current model + * @param {Object} options - Navigation options */ -function navigateToRecipesPage(loraName, loraHash) { +function navigateToRecipesPage({ modelKind, displayName, modelHash }) { // Close the current modal if (window.modalManager) { modalManager.closeModal('modelModal'); } - + // Clear any previous filters first removeSessionItem('lora_to_recipe_filterLoraName'); removeSessionItem('lora_to_recipe_filterLoraHash'); + removeSessionItem('checkpoint_to_recipe_filterCheckpointName'); + removeSessionItem('checkpoint_to_recipe_filterCheckpointHash'); removeSessionItem('viewRecipeId'); - - // Store the LoRA name and hash filter in sessionStorage - setSessionItem('lora_to_recipe_filterLoraName', loraName); - setSessionItem('lora_to_recipe_filterLoraHash', loraHash); - + + if (modelKind === 'checkpoint') { + // Store the checkpoint name and hash filter in sessionStorage + setSessionItem('checkpoint_to_recipe_filterCheckpointName', displayName); + setSessionItem('checkpoint_to_recipe_filterCheckpointHash', modelHash); + } else { + // Store the LoRA name and hash filter in sessionStorage + setSessionItem('lora_to_recipe_filterLoraName', displayName); + setSessionItem('lora_to_recipe_filterLoraHash', modelHash); + } + // Directly navigate to recipes page window.location.href = '/loras/recipes'; } @@ -321,7 +346,18 @@ function navigateToRecipeDetails(recipeId) { // Store the recipe ID in sessionStorage to load on recipes page setSessionItem('viewRecipeId', recipeId); - + // Directly navigate to recipes page window.location.href = '/loras/recipes'; } + +function getRecipesEndpoint(modelKind) { + if (modelKind === 'checkpoint') { + return '/api/lm/recipes/for-checkpoint'; + } + return '/api/lm/recipes/for-lora'; +} + +function getModelLabel(modelKind) { + return modelKind === 'checkpoint' ? 'checkpoint' : 'LoRA'; +} diff --git a/static/js/recipes.js b/static/js/recipes.js index 6e0cd2c6..7b12e971 100644 --- a/static/js/recipes.js +++ b/static/js/recipes.js @@ -66,6 +66,8 @@ class RecipeManager { active: false, loraName: null, loraHash: null, + checkpointName: null, + checkpointHash: null, recipeId: null }; } @@ -127,16 +129,20 @@ class RecipeManager { // Check for Lora filter const filterLoraName = getSessionItem('lora_to_recipe_filterLoraName'); const filterLoraHash = getSessionItem('lora_to_recipe_filterLoraHash'); + const filterCheckpointName = getSessionItem('checkpoint_to_recipe_filterCheckpointName'); + const filterCheckpointHash = getSessionItem('checkpoint_to_recipe_filterCheckpointHash'); // Check for specific recipe ID const viewRecipeId = getSessionItem('viewRecipeId'); // Set custom filter if any parameter is present - if (filterLoraName || filterLoraHash || viewRecipeId) { + if (filterLoraName || filterLoraHash || filterCheckpointName || filterCheckpointHash || viewRecipeId) { this.pageState.customFilter = { active: true, loraName: filterLoraName, loraHash: filterLoraHash, + checkpointName: filterCheckpointName, + checkpointHash: filterCheckpointHash, recipeId: viewRecipeId }; @@ -164,6 +170,13 @@ class RecipeManager { loraName; filterText = `Recipes using: ${displayName}`; + } else if (this.pageState.customFilter.checkpointName) { + const checkpointName = this.pageState.customFilter.checkpointName; + const displayName = checkpointName.length > 25 ? + checkpointName.substring(0, 22) + '...' : + checkpointName; + + filterText = `Recipes using checkpoint: ${displayName}`; } else { filterText = 'Filtered recipes'; } @@ -173,6 +186,10 @@ class RecipeManager { // Add title attribute to show the lora name as a tooltip if (this.pageState.customFilter.loraName) { textElement.setAttribute('title', this.pageState.customFilter.loraName); + } else if (this.pageState.customFilter.checkpointName) { + textElement.setAttribute('title', this.pageState.customFilter.checkpointName); + } else { + textElement.removeAttribute('title'); } indicator.classList.remove('hidden'); @@ -199,6 +216,8 @@ class RecipeManager { active: false, loraName: null, loraHash: null, + checkpointName: null, + checkpointHash: null, recipeId: null }; @@ -211,6 +230,8 @@ class RecipeManager { // Clear any session storage items removeSessionItem('lora_to_recipe_filterLoraName'); removeSessionItem('lora_to_recipe_filterLoraHash'); + removeSessionItem('checkpoint_to_recipe_filterCheckpointName'); + removeSessionItem('checkpoint_to_recipe_filterCheckpointHash'); removeSessionItem('viewRecipeId'); // Reset and refresh the virtual scroller diff --git a/tests/frontend/components/modelMetadata.renamePath.test.js b/tests/frontend/components/modelMetadata.renamePath.test.js index ef3b6b41..9fac235a 100644 --- a/tests/frontend/components/modelMetadata.renamePath.test.js +++ b/tests/frontend/components/modelMetadata.renamePath.test.js @@ -82,7 +82,7 @@ vi.mock(MODEL_VERSIONS_MODULE, () => ({ })); vi.mock(RECIPE_TAB_MODULE, () => ({ - loadRecipesForLora: vi.fn(), + loadRecipesForModel: vi.fn(), })); vi.mock(I18N_HELPERS_MODULE, () => ({ @@ -103,11 +103,14 @@ vi.mock(API_FACTORY, () => ({ describe('Model metadata interactions keep file path in sync', () => { let getModelApiClient; + let loadRecipesForModel; beforeEach(async () => { document.body.innerHTML = ''; ({ getModelApiClient } = await import(API_FACTORY)); + ({ loadRecipesForModel } = await import(RECIPE_TAB_MODULE)); getModelApiClient.mockReset(); + loadRecipesForModel.mockReset(); }); afterEach(() => { @@ -206,4 +209,33 @@ describe('Model metadata interactions keep file path in sync', () => { expect(saveModelMetadata).toHaveBeenCalledWith('models/Qwen.testing.safetensors', { notes: 'Updated notes' }); }); }); + + it('shows recipes tab for checkpoint modals and loads linked recipes by hash', async () => { + const fetchModelMetadata = vi.fn().mockResolvedValue(null); + + getModelApiClient.mockReturnValue({ + fetchModelMetadata, + saveModelMetadata: vi.fn(), + }); + + const { showModelModal } = await import(MODAL_MODULE); + + await showModelModal( + { + model_name: 'Flux Base', + file_path: 'models/checkpoints/flux-base.safetensors', + file_name: 'flux-base.safetensors', + sha256: 'ABC123', + civitai: {}, + }, + 'checkpoints', + ); + + expect(document.querySelector('.tab-btn[data-tab="recipes"]')).not.toBeNull(); + expect(loadRecipesForModel).toHaveBeenCalledWith({ + modelKind: 'checkpoint', + displayName: 'Flux Base', + sha256: 'ABC123', + }); + }); }); diff --git a/tests/frontend/components/modelModal.licenseIcons.test.js b/tests/frontend/components/modelModal.licenseIcons.test.js index 7a9ad1d7..f3f94b5d 100644 --- a/tests/frontend/components/modelModal.licenseIcons.test.js +++ b/tests/frontend/components/modelModal.licenseIcons.test.js @@ -80,7 +80,7 @@ vi.mock(MODEL_VERSIONS_MODULE, () => ({ })); vi.mock(RECIPE_TAB_MODULE, () => ({ - loadRecipesForLora: vi.fn(), + loadRecipesForModel: vi.fn(), })); vi.mock(I18N_HELPERS_MODULE, () => ({ diff --git a/tests/frontend/pages/recipesPage.test.js b/tests/frontend/pages/recipesPage.test.js index 07b4706c..cb514d5e 100644 --- a/tests/frontend/pages/recipesPage.test.js +++ b/tests/frontend/pages/recipesPage.test.js @@ -6,6 +6,7 @@ const initializePageFeaturesMock = vi.fn(); const getCurrentPageStateMock = vi.fn(); const getSessionItemMock = vi.fn(); const removeSessionItemMock = vi.fn(); +const getStorageItemMock = vi.fn(); const RecipeContextMenuMock = vi.fn(); const refreshVirtualScrollMock = vi.fn(); const refreshRecipesMock = vi.fn(); @@ -51,6 +52,7 @@ vi.mock('../../../static/js/state/index.js', () => ({ vi.mock('../../../static/js/utils/storageHelpers.js', () => ({ getSessionItem: getSessionItemMock, removeSessionItem: removeSessionItemMock, + getStorageItem: getStorageItemMock, })); vi.mock('../../../static/js/components/ContextMenu/index.js', () => ({ @@ -117,11 +119,14 @@ describe('RecipeManager', () => { const map = { lora_to_recipe_filterLoraName: 'Flux Dream', lora_to_recipe_filterLoraHash: 'abc123', + checkpoint_to_recipe_filterCheckpointName: null, + checkpoint_to_recipe_filterCheckpointHash: null, viewRecipeId: '42', }; return map[key] ?? null; }); removeSessionItemMock.mockImplementation(() => { }); + getStorageItemMock.mockImplementation((_, defaultValue = null) => defaultValue); renderRecipesPage(); @@ -166,6 +171,8 @@ describe('RecipeManager', () => { active: true, loraName: 'Flux Dream', loraHash: 'abc123', + checkpointName: null, + checkpointHash: null, recipeId: '42', }); @@ -177,6 +184,8 @@ describe('RecipeManager', () => { expect(removeSessionItemMock).toHaveBeenCalledWith('lora_to_recipe_filterLoraName'); expect(removeSessionItemMock).toHaveBeenCalledWith('lora_to_recipe_filterLoraHash'); + expect(removeSessionItemMock).toHaveBeenCalledWith('checkpoint_to_recipe_filterCheckpointName'); + expect(removeSessionItemMock).toHaveBeenCalledWith('checkpoint_to_recipe_filterCheckpointHash'); expect(removeSessionItemMock).toHaveBeenCalledWith('viewRecipeId'); expect(pageState.customFilter.active).toBe(false); expect(indicator.classList.contains('hidden')).toBe(true); @@ -227,4 +236,36 @@ describe('RecipeManager', () => { await manager.refreshRecipes(); expect(refreshRecipesMock).toHaveBeenCalledTimes(1); }); + + it('restores checkpoint recipe filter state and indicator text', async () => { + getSessionItemMock.mockImplementation((key) => { + const map = { + lora_to_recipe_filterLoraName: null, + lora_to_recipe_filterLoraHash: null, + checkpoint_to_recipe_filterCheckpointName: 'Flux Base', + checkpoint_to_recipe_filterCheckpointHash: 'ckpt123', + viewRecipeId: null, + }; + return map[key] ?? null; + }); + + const manager = new RecipeManager(); + await manager.initialize(); + + expect(pageState.customFilter).toEqual({ + active: true, + loraName: null, + loraHash: null, + checkpointName: 'Flux Base', + checkpointHash: 'ckpt123', + recipeId: null, + }); + + const indicator = document.getElementById('customFilterIndicator'); + const filterText = indicator.querySelector('#customFilterText'); + + expect(filterText.innerHTML).toContain('Recipes using checkpoint:'); + expect(filterText.innerHTML).toContain('Flux Base'); + expect(filterText.getAttribute('title')).toBe('Flux Base'); + }); }); diff --git a/tests/routes/test_recipe_routes.py b/tests/routes/test_recipe_routes.py index fba9882e..1fd09ae1 100644 --- a/tests/routes/test_recipe_routes.py +++ b/tests/routes/test_recipe_routes.py @@ -43,6 +43,9 @@ class StubRecipeScanner: self.cached_raw: List[Dict[str, Any]] = [] self.recipes: Dict[str, Dict[str, Any]] = {} self.removed: List[str] = [] + self.last_paginated_params: Dict[str, Any] | None = None + self.lora_lookup: Dict[str, List[Dict[str, Any]]] = {} + self.checkpoint_lookup: Dict[str, List[Dict[str, Any]]] = {} async def _noop_get_cached_data(force_refresh: bool = False) -> None: # noqa: ARG001 - signature mirrors real scanner return None @@ -56,6 +59,7 @@ class StubRecipeScanner: return SimpleNamespace(raw_data=list(self.cached_raw)) async def get_paginated_data(self, **params: Any) -> Dict[str, Any]: + self.last_paginated_params = params items = [dict(item) for item in self.listing_items] page = int(params.get("page", 1)) page_size = int(params.get("page_size", 20)) @@ -70,6 +74,14 @@ class StubRecipeScanner: async def get_recipe_by_id(self, recipe_id: str) -> Optional[Dict[str, Any]]: return self.recipes.get(recipe_id) + async def get_recipes_for_lora(self, lora_hash: str) -> List[Dict[str, Any]]: + return list(self.lora_lookup.get(lora_hash.lower(), [])) + + async def get_recipes_for_checkpoint( + self, checkpoint_hash: str + ) -> List[Dict[str, Any]]: + return list(self.checkpoint_lookup.get(checkpoint_hash.lower(), [])) + async def get_recipe_json_path(self, recipe_id: str) -> Optional[str]: candidate = Path(self.recipes_dir) / f"{recipe_id}.recipe.json" return str(candidate) if candidate.exists() else None @@ -350,6 +362,47 @@ async def test_list_recipes_provides_file_urls(monkeypatch, tmp_path: Path) -> N assert payload["items"][0]["loras"] == [] +async def test_list_recipes_passes_checkpoint_hash_filter( + monkeypatch, tmp_path: Path +) -> None: + async with recipe_harness(monkeypatch, tmp_path) as harness: + response = await harness.client.get("/api/lm/recipes?checkpoint_hash=ckpt123") + payload = await response.json() + + assert response.status == 200 + assert payload["items"] == [] + assert harness.scanner.last_paginated_params["checkpoint_hash"] == "ckpt123" + + +async def test_get_recipes_for_checkpoint(monkeypatch, tmp_path: Path) -> None: + async with recipe_harness(monkeypatch, tmp_path) as harness: + harness.scanner.checkpoint_lookup["abc123"] = [ + {"id": "recipe-1", "title": "Linked recipe"} + ] + + response = await harness.client.get( + "/api/lm/recipes/for-checkpoint?hash=ABC123" + ) + payload = await response.json() + + assert response.status == 200 + assert payload == { + "success": True, + "recipes": [{"id": "recipe-1", "title": "Linked recipe"}], + } + + +async def test_get_recipes_for_checkpoint_requires_hash( + monkeypatch, tmp_path: Path +) -> None: + async with recipe_harness(monkeypatch, tmp_path) as harness: + response = await harness.client.get("/api/lm/recipes/for-checkpoint") + payload = await response.json() + + assert response.status == 400 + assert payload["success"] is False + + async def test_save_and_delete_recipe_round_trip(monkeypatch, tmp_path: Path) -> None: async with recipe_harness(monkeypatch, tmp_path) as harness: form = FormData() diff --git a/tests/services/test_recipe_scanner.py b/tests/services/test_recipe_scanner.py index 692f3dae..01102f40 100644 --- a/tests/services/test_recipe_scanner.py +++ b/tests/services/test_recipe_scanner.py @@ -313,6 +313,75 @@ async def test_get_recipe_by_id_handles_non_dict_checkpoint(recipe_scanner): assert recipe["checkpoint"]["file_name"] == "by-id" +@pytest.mark.asyncio +async def test_get_paginated_data_filters_by_checkpoint_hash(recipe_scanner): + scanner, _ = recipe_scanner + image_path = Path(config.loras_roots[0]) / "checkpoint-filter.webp" + await scanner.add_recipe( + { + "id": "checkpoint-match", + "file_path": str(image_path), + "title": "Checkpoint Match", + "modified": 0.0, + "created_date": 0.0, + "loras": [], + "checkpoint": { + "name": "flux-base.safetensors", + "hash": "ABC123", + }, + } + ) + await scanner.add_recipe( + { + "id": "checkpoint-miss", + "file_path": str(Path(config.loras_roots[0]) / "checkpoint-miss.webp"), + "title": "Checkpoint Miss", + "modified": 1.0, + "created_date": 1.0, + "loras": [], + "checkpoint": { + "name": "other.safetensors", + "hash": "zzz999", + }, + } + ) + await asyncio.sleep(0) + + result = await scanner.get_paginated_data( + page=1, + page_size=10, + checkpoint_hash="abc123", + ) + + assert [item["id"] for item in result["items"]] == ["checkpoint-match"] + + +@pytest.mark.asyncio +async def test_get_recipes_for_checkpoint_matches_hash_case_insensitively(recipe_scanner): + scanner, _ = recipe_scanner + image_path = Path(config.loras_roots[0]) / "checkpoint-linked.webp" + await scanner.add_recipe( + { + "id": "checkpoint-linked", + "file_path": str(image_path), + "title": "Checkpoint Linked", + "modified": 0.0, + "created_date": 0.0, + "loras": [], + "checkpoint": { + "name": "flux-base.safetensors", + "hash": "ABC123", + }, + } + ) + + recipes = await scanner.get_recipes_for_checkpoint("abc123") + + assert len(recipes) == 1 + assert recipes[0]["id"] == "checkpoint-linked" + assert recipes[0]["checkpoint"]["hash"] == "ABC123" + + def test_enrich_uses_version_index_when_hash_missing(recipe_scanner): scanner, stub = recipe_scanner version_id = 77