fix(recipe): show checkpoint-linked recipes in model modal (#851)

This commit is contained in:
Will Miao
2026-03-31 16:45:01 +08:00
parent 316f17dd46
commit 8dc2a2f76b
12 changed files with 393 additions and 51 deletions

View File

@@ -81,6 +81,7 @@ class RecipeHandlerSet:
"bulk_delete": self.management.bulk_delete, "bulk_delete": self.management.bulk_delete,
"save_recipe_from_widget": self.management.save_recipe_from_widget, "save_recipe_from_widget": self.management.save_recipe_from_widget,
"get_recipes_for_lora": self.query.get_recipes_for_lora, "get_recipes_for_lora": self.query.get_recipes_for_lora,
"get_recipes_for_checkpoint": self.query.get_recipes_for_checkpoint,
"scan_recipes": self.query.scan_recipes, "scan_recipes": self.query.scan_recipes,
"move_recipe": self.management.move_recipe, "move_recipe": self.management.move_recipe,
"repair_recipes": self.management.repair_recipes, "repair_recipes": self.management.repair_recipes,
@@ -218,6 +219,7 @@ class RecipeListingHandler:
filters["tags"] = tag_filters filters["tags"] = tag_filters
lora_hash = request.query.get("lora_hash") lora_hash = request.query.get("lora_hash")
checkpoint_hash = request.query.get("checkpoint_hash")
result = await recipe_scanner.get_paginated_data( result = await recipe_scanner.get_paginated_data(
page=page, page=page,
@@ -227,6 +229,7 @@ class RecipeListingHandler:
filters=filters, filters=filters,
search_options=search_options, search_options=search_options,
lora_hash=lora_hash, lora_hash=lora_hash,
checkpoint_hash=checkpoint_hash,
folder=folder, folder=folder,
recursive=recursive, recursive=recursive,
) )
@@ -423,6 +426,28 @@ class RecipeQueryHandler:
self._logger.error("Error getting recipes for Lora: %s", exc) self._logger.error("Error getting recipes for Lora: %s", exc)
return web.json_response({"success": False, "error": str(exc)}, status=500) return web.json_response({"success": False, "error": str(exc)}, status=500)
async def get_recipes_for_checkpoint(self, request: web.Request) -> web.Response:
try:
await self._ensure_dependencies_ready()
recipe_scanner = self._recipe_scanner_getter()
if recipe_scanner is None:
raise RuntimeError("Recipe scanner unavailable")
checkpoint_hash = request.query.get("hash")
if not checkpoint_hash:
return web.json_response(
{"success": False, "error": "Checkpoint hash is required"},
status=400,
)
matching_recipes = await recipe_scanner.get_recipes_for_checkpoint(
checkpoint_hash
)
return web.json_response({"success": True, "recipes": matching_recipes})
except Exception as exc:
self._logger.error("Error getting recipes for checkpoint: %s", exc)
return web.json_response({"success": False, "error": str(exc)}, status=500)
async def scan_recipes(self, request: web.Request) -> web.Response: async def scan_recipes(self, request: web.Request) -> web.Response:
try: try:
await self._ensure_dependencies_ready() await self._ensure_dependencies_ready()

View File

@@ -51,6 +51,9 @@ ROUTE_DEFINITIONS: tuple[RouteDefinition, ...] = (
"POST", "/api/lm/recipes/save-from-widget", "save_recipe_from_widget" "POST", "/api/lm/recipes/save-from-widget", "save_recipe_from_widget"
), ),
RouteDefinition("GET", "/api/lm/recipes/for-lora", "get_recipes_for_lora"), RouteDefinition("GET", "/api/lm/recipes/for-lora", "get_recipes_for_lora"),
RouteDefinition(
"GET", "/api/lm/recipes/for-checkpoint", "get_recipes_for_checkpoint"
),
RouteDefinition("GET", "/api/lm/recipes/scan", "scan_recipes"), RouteDefinition("GET", "/api/lm/recipes/scan", "scan_recipes"),
RouteDefinition("POST", "/api/lm/recipes/repair", "repair_recipes"), RouteDefinition("POST", "/api/lm/recipes/repair", "repair_recipes"),
RouteDefinition("POST", "/api/lm/recipes/cancel-repair", "cancel_repair"), RouteDefinition("POST", "/api/lm/recipes/cancel-repair", "cancel_repair"),

View File

@@ -1615,6 +1615,9 @@ class RecipeScanner:
) -> Optional[Dict[str, Any]]: ) -> Optional[Dict[str, Any]]:
"""Coerce legacy or malformed checkpoint entries into a dict.""" """Coerce legacy or malformed checkpoint entries into a dict."""
if checkpoint_raw is None:
return None
if isinstance(checkpoint_raw, dict): if isinstance(checkpoint_raw, dict):
return dict(checkpoint_raw) return dict(checkpoint_raw)
@@ -1632,9 +1635,6 @@ class RecipeScanner:
"file_name": file_name, "file_name": file_name,
} }
logger.warning(
"Unexpected checkpoint payload type %s", type(checkpoint_raw).__name__
)
return None return None
def _enrich_checkpoint_entry(self, checkpoint: Dict[str, Any]) -> Dict[str, Any]: def _enrich_checkpoint_entry(self, checkpoint: Dict[str, Any]) -> Dict[str, Any]:
@@ -1790,6 +1790,7 @@ class RecipeScanner:
filters: dict = None, filters: dict = None,
search_options: dict = None, search_options: dict = None,
lora_hash: str = None, lora_hash: str = None,
checkpoint_hash: str = None,
bypass_filters: bool = True, bypass_filters: bool = True,
folder: str | None = None, folder: str | None = None,
recursive: bool = True, recursive: bool = True,
@@ -1804,7 +1805,8 @@ class RecipeScanner:
filters: Dictionary of filters to apply filters: Dictionary of filters to apply
search_options: Dictionary of search options to apply search_options: Dictionary of search options to apply
lora_hash: Optional SHA256 hash of a LoRA to filter recipes by lora_hash: Optional SHA256 hash of a LoRA to filter recipes by
bypass_filters: If True, ignore other filters when a lora_hash is provided checkpoint_hash: Optional SHA256 hash of a checkpoint to filter recipes by
bypass_filters: If True, ignore other filters when a hash filter is provided
folder: Optional folder filter relative to recipes directory folder: Optional folder filter relative to recipes directory
recursive: Whether to include recipes in subfolders of the selected folder recursive: Whether to include recipes in subfolders of the selected folder
""" """
@@ -1852,9 +1854,23 @@ class RecipeScanner:
# Skip other filters if bypass_filters is True # Skip other filters if bypass_filters is True
pass pass
# Otherwise continue with normal filtering after applying LoRA hash filter # Otherwise continue with normal filtering after applying LoRA hash filter
elif checkpoint_hash:
normalized_checkpoint_hash = checkpoint_hash.lower()
filtered_data = [
item
for item in filtered_data
if isinstance(item.get("checkpoint"), dict)
and (item["checkpoint"].get("hash", "") or "").lower()
== normalized_checkpoint_hash
]
# Skip further filtering if we're only filtering by LoRA hash with bypass enabled if bypass_filters:
if not (lora_hash and bypass_filters): pass
has_hash_filter = bool(lora_hash or checkpoint_hash)
# Skip further filtering if we're only filtering by model hash with bypass enabled
if not (has_hash_filter and bypass_filters):
# Apply folder filter before other criteria # Apply folder filter before other criteria
if folder is not None: if folder is not None:
normalized_folder = folder.strip("/") normalized_folder = folder.strip("/")
@@ -2334,6 +2350,38 @@ class RecipeScanner:
return matching_recipes return matching_recipes
async def get_recipes_for_checkpoint(
self, checkpoint_hash: str
) -> List[Dict[str, Any]]:
"""Return recipes that reference a given checkpoint hash."""
if not checkpoint_hash:
return []
normalized_hash = checkpoint_hash.lower()
cache = await self.get_cached_data()
matching_recipes: List[Dict[str, Any]] = []
for recipe in cache.raw_data:
checkpoint = self._normalize_checkpoint_entry(recipe.get("checkpoint"))
if not checkpoint:
continue
enriched_checkpoint = self._enrich_checkpoint_entry(dict(checkpoint))
if (enriched_checkpoint.get("hash") or "").lower() != normalized_hash:
continue
recipe_copy = {**recipe}
recipe_copy["checkpoint"] = enriched_checkpoint
recipe_copy["loras"] = [
self._enrich_lora_entry(dict(entry))
for entry in recipe.get("loras", [])
]
recipe_copy["file_url"] = self._format_file_url(recipe.get("file_path"))
matching_recipes.append(recipe_copy)
return matching_recipes
async def get_recipe_syntax_tokens(self, recipe_id: str) -> List[str]: async def get_recipe_syntax_tokens(self, recipe_id: str) -> List[str]:
"""Build LoRA syntax tokens for a recipe.""" """Build LoRA syntax tokens for a recipe."""

View File

@@ -83,6 +83,9 @@ export async function fetchRecipesPage(page = 1, pageSize = 100) {
if (pageState.customFilter?.active && pageState.customFilter?.loraHash) { if (pageState.customFilter?.active && pageState.customFilter?.loraHash) {
params.append('lora_hash', pageState.customFilter.loraHash); params.append('lora_hash', pageState.customFilter.loraHash);
params.append('bypass_filters', 'true'); params.append('bypass_filters', 'true');
} else if (pageState.customFilter?.active && pageState.customFilter?.checkpointHash) {
params.append('checkpoint_hash', pageState.customFilter.checkpointHash);
params.append('bypass_filters', 'true');
} else { } else {
// Normal filtering logic // Normal filtering logic

View File

@@ -19,7 +19,7 @@ import { renderCompactTags, setupTagTooltip, formatFileSize, escapeAttribute, es
import { renderTriggerWords, setupTriggerWordsEditMode } from './TriggerWords.js'; import { renderTriggerWords, setupTriggerWordsEditMode } from './TriggerWords.js';
import { parsePresets, renderPresetTags } from './PresetTags.js'; import { parsePresets, renderPresetTags } from './PresetTags.js';
import { initVersionsTab } from './ModelVersionsTab.js'; import { initVersionsTab } from './ModelVersionsTab.js';
import { loadRecipesForLora } from './RecipeTab.js'; import { loadRecipesForModel } from './RecipeTab.js';
import { translate } from '../../utils/i18nHelpers.js'; import { translate } from '../../utils/i18nHelpers.js';
import { state } from '../../state/index.js'; import { state } from '../../state/index.js';
@@ -355,7 +355,9 @@ export async function showModelModal(model, modelType) {
${versionsTabBadge} ${versionsTabBadge}
</button>`.trim(); </button>`.trim();
const tabsContent = modelType === 'loras' ? const supportsRecipesTab = modelType === 'loras' || modelType === 'checkpoints';
const tabsContent = supportsRecipesTab ?
`<button class="tab-btn active" data-tab="showcase">${examplesText}</button> `<button class="tab-btn active" data-tab="showcase">${examplesText}</button>
<button class="tab-btn" data-tab="description">${descriptionText}</button> <button class="tab-btn" data-tab="description">${descriptionText}</button>
${versionsTabButton} ${versionsTabButton}
@@ -385,7 +387,7 @@ export async function showModelModal(model, modelType) {
</button> </button>
</div>`.trim(); </div>`.trim();
const tabPanesContent = modelType === 'loras' ? const tabPanesContent = supportsRecipesTab ?
`<div id="showcase-tab" class="tab-pane active"> `<div id="showcase-tab" class="tab-pane active">
<div class="example-images-loading"> <div class="example-images-loading">
<i class="fas fa-spinner fa-spin"></i> ${loadingExampleImagesText} <i class="fas fa-spinner fa-spin"></i> ${loadingExampleImagesText}
@@ -664,14 +666,23 @@ export async function showModelModal(model, modelType) {
setupNavigationShortcuts(modelType); setupNavigationShortcuts(modelType);
updateNavigationControls(); updateNavigationControls();
// LoRA specific setup // Model-specific setup
if (modelType === 'loras' || modelType === 'embeddings') { if (modelType === 'loras' || modelType === 'embeddings') {
setupTriggerWordsEditMode(); setupTriggerWordsEditMode();
}
if (modelType == 'loras') { if (modelType === 'loras') {
// Load recipes for this LoRA loadRecipesForModel({
loadRecipesForLora(modelWithFullData.model_name, modelWithFullData.sha256); modelKind: 'lora',
} displayName: modelWithFullData.model_name,
sha256: modelWithFullData.sha256,
});
} else if (modelType === 'checkpoints') {
loadRecipesForModel({
modelKind: 'checkpoint',
displayName: modelWithFullData.model_name,
sha256: modelWithFullData.sha256,
});
} }
// Load example images asynchronously - merge regular and custom images // Load example images asynchronously - merge regular and custom images

View File

@@ -1,38 +1,47 @@
/** /**
* RecipeTab - Handles the recipes tab in model modals (LoRA specific functionality) * RecipeTab - Handles the recipes tab in model modals.
* Moved to shared directory for consistency
*/ */
import { showToast, copyToClipboard } from '../../utils/uiHelpers.js'; import { showToast, copyToClipboard } from '../../utils/uiHelpers.js';
import { setSessionItem, removeSessionItem } from '../../utils/storageHelpers.js'; import { setSessionItem, removeSessionItem } from '../../utils/storageHelpers.js';
/** /**
* Loads recipes that use the specified Lora and renders them in the tab * Loads recipes that use the specified model and renders them in the tab.
* @param {string} loraName - The display name of the Lora * @param {Object} options
* @param {string} sha256 - The SHA256 hash of the Lora * @param {'lora'|'checkpoint'} options.modelKind - Model kind for copy and endpoint selection
* @param {string} options.displayName - The display name of the model
* @param {string} options.sha256 - The SHA256 hash of the model
*/ */
export function loadRecipesForLora(loraName, sha256) { export function loadRecipesForModel({ modelKind, displayName, sha256 }) {
const recipeTab = document.getElementById('recipes-tab'); const recipeTab = document.getElementById('recipes-tab');
if (!recipeTab) return; if (!recipeTab) return;
const normalizedHash = sha256?.toLowerCase?.() || '';
const modelLabel = getModelLabel(modelKind);
// Show loading state // Show loading state
recipeTab.innerHTML = ` recipeTab.innerHTML = `
<div class="recipes-loading"> <div class="recipes-loading">
<i class="fas fa-spinner fa-spin"></i> Loading recipes... <i class="fas fa-spinner fa-spin"></i> Loading recipes...
</div> </div>
`; `;
// Fetch recipes that use this Lora by hash // Fetch recipes that use this model by hash
fetch(`/api/lm/recipes/for-lora?hash=${encodeURIComponent(sha256.toLowerCase())}`) fetch(`${getRecipesEndpoint(modelKind)}?hash=${encodeURIComponent(normalizedHash)}`)
.then(response => response.json()) .then(response => response.json())
.then(data => { .then(data => {
if (!data.success) { if (!data.success) {
throw new Error(data.error || 'Failed to load recipes'); throw new Error(data.error || 'Failed to load recipes');
} }
renderRecipes(recipeTab, data.recipes, loraName, sha256); renderRecipes(recipeTab, data.recipes, {
modelKind,
displayName,
modelHash: normalizedHash,
modelLabel,
});
}) })
.catch(error => { .catch(error => {
console.error('Error loading recipes for Lora:', error); console.error(`Error loading recipes for ${modelLabel}:`, error);
recipeTab.innerHTML = ` recipeTab.innerHTML = `
<div class="recipes-error"> <div class="recipes-error">
<i class="fas fa-exclamation-circle"></i> <i class="fas fa-exclamation-circle"></i>
@@ -46,18 +55,24 @@ export function loadRecipesForLora(loraName, sha256) {
* Renders the recipe cards in the tab * Renders the recipe cards in the tab
* @param {HTMLElement} tabElement - The tab element to render into * @param {HTMLElement} tabElement - The tab element to render into
* @param {Array} recipes - Array of recipe objects * @param {Array} recipes - Array of recipe objects
* @param {string} loraName - The display name of the Lora * @param {Object} options - Render options
* @param {string} loraHash - The hash of the Lora
*/ */
function renderRecipes(tabElement, recipes, loraName, loraHash) { function renderRecipes(tabElement, recipes, options) {
const {
modelKind,
displayName,
modelHash,
modelLabel,
} = options;
if (!recipes || recipes.length === 0) { if (!recipes || recipes.length === 0) {
tabElement.innerHTML = ` tabElement.innerHTML = `
<div class="recipes-empty"> <div class="recipes-empty">
<i class="fas fa-book-open"></i> <i class="fas fa-book-open"></i>
<p>No recipes found that use this Lora.</p> <p>No recipes found that use this ${modelLabel}.</p>
</div> </div>
`; `;
return; return;
} }
@@ -73,13 +88,13 @@ function renderRecipes(tabElement, recipes, loraName, loraHash) {
headerText.appendChild(eyebrow); headerText.appendChild(eyebrow);
const title = document.createElement('h3'); const title = document.createElement('h3');
title.textContent = `${recipes.length} recipe${recipes.length > 1 ? 's' : ''} using this Lora`; title.textContent = `${recipes.length} recipe${recipes.length > 1 ? 's' : ''} using this ${modelLabel}`;
headerText.appendChild(title); headerText.appendChild(title);
const description = document.createElement('p'); const description = document.createElement('p');
description.className = 'recipes-header__description'; description.className = 'recipes-header__description';
description.textContent = loraName ? description.textContent = displayName ?
`Discover workflows crafted for ${loraName}.` : `Discover workflows crafted for ${displayName}.` :
'Discover workflows crafted for this model.'; 'Discover workflows crafted for this model.';
headerText.appendChild(description); headerText.appendChild(description);
@@ -101,7 +116,11 @@ function renderRecipes(tabElement, recipes, loraName, loraHash) {
headerElement.appendChild(viewAllButton); headerElement.appendChild(viewAllButton);
viewAllButton.addEventListener('click', () => { viewAllButton.addEventListener('click', () => {
navigateToRecipesPage(loraName, loraHash); navigateToRecipesPage({
modelKind,
displayName,
modelHash,
});
}); });
const cardGrid = document.createElement('div'); const cardGrid = document.createElement('div');
@@ -280,26 +299,32 @@ function copyRecipeSyntax(recipeId) {
} }
/** /**
* Navigates to the recipes page with filter for the current Lora * Navigates to the recipes page with filter for the current model
* @param {string} loraName - The Lora display name to filter by * @param {Object} options - Navigation options
* @param {string} loraHash - The hash of the Lora to filter by
* @param {boolean} createNew - Whether to open the create recipe dialog
*/ */
function navigateToRecipesPage(loraName, loraHash) { function navigateToRecipesPage({ modelKind, displayName, modelHash }) {
// Close the current modal // Close the current modal
if (window.modalManager) { if (window.modalManager) {
modalManager.closeModal('modelModal'); modalManager.closeModal('modelModal');
} }
// Clear any previous filters first // Clear any previous filters first
removeSessionItem('lora_to_recipe_filterLoraName'); removeSessionItem('lora_to_recipe_filterLoraName');
removeSessionItem('lora_to_recipe_filterLoraHash'); removeSessionItem('lora_to_recipe_filterLoraHash');
removeSessionItem('checkpoint_to_recipe_filterCheckpointName');
removeSessionItem('checkpoint_to_recipe_filterCheckpointHash');
removeSessionItem('viewRecipeId'); removeSessionItem('viewRecipeId');
// Store the LoRA name and hash filter in sessionStorage if (modelKind === 'checkpoint') {
setSessionItem('lora_to_recipe_filterLoraName', loraName); // Store the checkpoint name and hash filter in sessionStorage
setSessionItem('lora_to_recipe_filterLoraHash', loraHash); setSessionItem('checkpoint_to_recipe_filterCheckpointName', displayName);
setSessionItem('checkpoint_to_recipe_filterCheckpointHash', modelHash);
} else {
// Store the LoRA name and hash filter in sessionStorage
setSessionItem('lora_to_recipe_filterLoraName', displayName);
setSessionItem('lora_to_recipe_filterLoraHash', modelHash);
}
// Directly navigate to recipes page // Directly navigate to recipes page
window.location.href = '/loras/recipes'; window.location.href = '/loras/recipes';
} }
@@ -321,7 +346,18 @@ function navigateToRecipeDetails(recipeId) {
// Store the recipe ID in sessionStorage to load on recipes page // Store the recipe ID in sessionStorage to load on recipes page
setSessionItem('viewRecipeId', recipeId); setSessionItem('viewRecipeId', recipeId);
// Directly navigate to recipes page // Directly navigate to recipes page
window.location.href = '/loras/recipes'; window.location.href = '/loras/recipes';
} }
function getRecipesEndpoint(modelKind) {
if (modelKind === 'checkpoint') {
return '/api/lm/recipes/for-checkpoint';
}
return '/api/lm/recipes/for-lora';
}
function getModelLabel(modelKind) {
return modelKind === 'checkpoint' ? 'checkpoint' : 'LoRA';
}

View File

@@ -66,6 +66,8 @@ class RecipeManager {
active: false, active: false,
loraName: null, loraName: null,
loraHash: null, loraHash: null,
checkpointName: null,
checkpointHash: null,
recipeId: null recipeId: null
}; };
} }
@@ -127,16 +129,20 @@ class RecipeManager {
// Check for Lora filter // Check for Lora filter
const filterLoraName = getSessionItem('lora_to_recipe_filterLoraName'); const filterLoraName = getSessionItem('lora_to_recipe_filterLoraName');
const filterLoraHash = getSessionItem('lora_to_recipe_filterLoraHash'); const filterLoraHash = getSessionItem('lora_to_recipe_filterLoraHash');
const filterCheckpointName = getSessionItem('checkpoint_to_recipe_filterCheckpointName');
const filterCheckpointHash = getSessionItem('checkpoint_to_recipe_filterCheckpointHash');
// Check for specific recipe ID // Check for specific recipe ID
const viewRecipeId = getSessionItem('viewRecipeId'); const viewRecipeId = getSessionItem('viewRecipeId');
// Set custom filter if any parameter is present // Set custom filter if any parameter is present
if (filterLoraName || filterLoraHash || viewRecipeId) { if (filterLoraName || filterLoraHash || filterCheckpointName || filterCheckpointHash || viewRecipeId) {
this.pageState.customFilter = { this.pageState.customFilter = {
active: true, active: true,
loraName: filterLoraName, loraName: filterLoraName,
loraHash: filterLoraHash, loraHash: filterLoraHash,
checkpointName: filterCheckpointName,
checkpointHash: filterCheckpointHash,
recipeId: viewRecipeId recipeId: viewRecipeId
}; };
@@ -164,6 +170,13 @@ class RecipeManager {
loraName; loraName;
filterText = `<span>Recipes using: <span class="lora-name">${displayName}</span></span>`; filterText = `<span>Recipes using: <span class="lora-name">${displayName}</span></span>`;
} else if (this.pageState.customFilter.checkpointName) {
const checkpointName = this.pageState.customFilter.checkpointName;
const displayName = checkpointName.length > 25 ?
checkpointName.substring(0, 22) + '...' :
checkpointName;
filterText = `<span>Recipes using checkpoint: <span class="lora-name">${displayName}</span></span>`;
} else { } else {
filterText = 'Filtered recipes'; filterText = 'Filtered recipes';
} }
@@ -173,6 +186,10 @@ class RecipeManager {
// Add title attribute to show the lora name as a tooltip // Add title attribute to show the lora name as a tooltip
if (this.pageState.customFilter.loraName) { if (this.pageState.customFilter.loraName) {
textElement.setAttribute('title', this.pageState.customFilter.loraName); textElement.setAttribute('title', this.pageState.customFilter.loraName);
} else if (this.pageState.customFilter.checkpointName) {
textElement.setAttribute('title', this.pageState.customFilter.checkpointName);
} else {
textElement.removeAttribute('title');
} }
indicator.classList.remove('hidden'); indicator.classList.remove('hidden');
@@ -199,6 +216,8 @@ class RecipeManager {
active: false, active: false,
loraName: null, loraName: null,
loraHash: null, loraHash: null,
checkpointName: null,
checkpointHash: null,
recipeId: null recipeId: null
}; };
@@ -211,6 +230,8 @@ class RecipeManager {
// Clear any session storage items // Clear any session storage items
removeSessionItem('lora_to_recipe_filterLoraName'); removeSessionItem('lora_to_recipe_filterLoraName');
removeSessionItem('lora_to_recipe_filterLoraHash'); removeSessionItem('lora_to_recipe_filterLoraHash');
removeSessionItem('checkpoint_to_recipe_filterCheckpointName');
removeSessionItem('checkpoint_to_recipe_filterCheckpointHash');
removeSessionItem('viewRecipeId'); removeSessionItem('viewRecipeId');
// Reset and refresh the virtual scroller // Reset and refresh the virtual scroller

View File

@@ -82,7 +82,7 @@ vi.mock(MODEL_VERSIONS_MODULE, () => ({
})); }));
vi.mock(RECIPE_TAB_MODULE, () => ({ vi.mock(RECIPE_TAB_MODULE, () => ({
loadRecipesForLora: vi.fn(), loadRecipesForModel: vi.fn(),
})); }));
vi.mock(I18N_HELPERS_MODULE, () => ({ vi.mock(I18N_HELPERS_MODULE, () => ({
@@ -103,11 +103,14 @@ vi.mock(API_FACTORY, () => ({
describe('Model metadata interactions keep file path in sync', () => { describe('Model metadata interactions keep file path in sync', () => {
let getModelApiClient; let getModelApiClient;
let loadRecipesForModel;
beforeEach(async () => { beforeEach(async () => {
document.body.innerHTML = ''; document.body.innerHTML = '';
({ getModelApiClient } = await import(API_FACTORY)); ({ getModelApiClient } = await import(API_FACTORY));
({ loadRecipesForModel } = await import(RECIPE_TAB_MODULE));
getModelApiClient.mockReset(); getModelApiClient.mockReset();
loadRecipesForModel.mockReset();
}); });
afterEach(() => { afterEach(() => {
@@ -206,4 +209,33 @@ describe('Model metadata interactions keep file path in sync', () => {
expect(saveModelMetadata).toHaveBeenCalledWith('models/Qwen.testing.safetensors', { notes: 'Updated notes' }); expect(saveModelMetadata).toHaveBeenCalledWith('models/Qwen.testing.safetensors', { notes: 'Updated notes' });
}); });
}); });
it('shows recipes tab for checkpoint modals and loads linked recipes by hash', async () => {
const fetchModelMetadata = vi.fn().mockResolvedValue(null);
getModelApiClient.mockReturnValue({
fetchModelMetadata,
saveModelMetadata: vi.fn(),
});
const { showModelModal } = await import(MODAL_MODULE);
await showModelModal(
{
model_name: 'Flux Base',
file_path: 'models/checkpoints/flux-base.safetensors',
file_name: 'flux-base.safetensors',
sha256: 'ABC123',
civitai: {},
},
'checkpoints',
);
expect(document.querySelector('.tab-btn[data-tab="recipes"]')).not.toBeNull();
expect(loadRecipesForModel).toHaveBeenCalledWith({
modelKind: 'checkpoint',
displayName: 'Flux Base',
sha256: 'ABC123',
});
});
}); });

View File

@@ -80,7 +80,7 @@ vi.mock(MODEL_VERSIONS_MODULE, () => ({
})); }));
vi.mock(RECIPE_TAB_MODULE, () => ({ vi.mock(RECIPE_TAB_MODULE, () => ({
loadRecipesForLora: vi.fn(), loadRecipesForModel: vi.fn(),
})); }));
vi.mock(I18N_HELPERS_MODULE, () => ({ vi.mock(I18N_HELPERS_MODULE, () => ({

View File

@@ -6,6 +6,7 @@ const initializePageFeaturesMock = vi.fn();
const getCurrentPageStateMock = vi.fn(); const getCurrentPageStateMock = vi.fn();
const getSessionItemMock = vi.fn(); const getSessionItemMock = vi.fn();
const removeSessionItemMock = vi.fn(); const removeSessionItemMock = vi.fn();
const getStorageItemMock = vi.fn();
const RecipeContextMenuMock = vi.fn(); const RecipeContextMenuMock = vi.fn();
const refreshVirtualScrollMock = vi.fn(); const refreshVirtualScrollMock = vi.fn();
const refreshRecipesMock = vi.fn(); const refreshRecipesMock = vi.fn();
@@ -51,6 +52,7 @@ vi.mock('../../../static/js/state/index.js', () => ({
vi.mock('../../../static/js/utils/storageHelpers.js', () => ({ vi.mock('../../../static/js/utils/storageHelpers.js', () => ({
getSessionItem: getSessionItemMock, getSessionItem: getSessionItemMock,
removeSessionItem: removeSessionItemMock, removeSessionItem: removeSessionItemMock,
getStorageItem: getStorageItemMock,
})); }));
vi.mock('../../../static/js/components/ContextMenu/index.js', () => ({ vi.mock('../../../static/js/components/ContextMenu/index.js', () => ({
@@ -117,11 +119,14 @@ describe('RecipeManager', () => {
const map = { const map = {
lora_to_recipe_filterLoraName: 'Flux Dream', lora_to_recipe_filterLoraName: 'Flux Dream',
lora_to_recipe_filterLoraHash: 'abc123', lora_to_recipe_filterLoraHash: 'abc123',
checkpoint_to_recipe_filterCheckpointName: null,
checkpoint_to_recipe_filterCheckpointHash: null,
viewRecipeId: '42', viewRecipeId: '42',
}; };
return map[key] ?? null; return map[key] ?? null;
}); });
removeSessionItemMock.mockImplementation(() => { }); removeSessionItemMock.mockImplementation(() => { });
getStorageItemMock.mockImplementation((_, defaultValue = null) => defaultValue);
renderRecipesPage(); renderRecipesPage();
@@ -166,6 +171,8 @@ describe('RecipeManager', () => {
active: true, active: true,
loraName: 'Flux Dream', loraName: 'Flux Dream',
loraHash: 'abc123', loraHash: 'abc123',
checkpointName: null,
checkpointHash: null,
recipeId: '42', recipeId: '42',
}); });
@@ -177,6 +184,8 @@ describe('RecipeManager', () => {
expect(removeSessionItemMock).toHaveBeenCalledWith('lora_to_recipe_filterLoraName'); expect(removeSessionItemMock).toHaveBeenCalledWith('lora_to_recipe_filterLoraName');
expect(removeSessionItemMock).toHaveBeenCalledWith('lora_to_recipe_filterLoraHash'); expect(removeSessionItemMock).toHaveBeenCalledWith('lora_to_recipe_filterLoraHash');
expect(removeSessionItemMock).toHaveBeenCalledWith('checkpoint_to_recipe_filterCheckpointName');
expect(removeSessionItemMock).toHaveBeenCalledWith('checkpoint_to_recipe_filterCheckpointHash');
expect(removeSessionItemMock).toHaveBeenCalledWith('viewRecipeId'); expect(removeSessionItemMock).toHaveBeenCalledWith('viewRecipeId');
expect(pageState.customFilter.active).toBe(false); expect(pageState.customFilter.active).toBe(false);
expect(indicator.classList.contains('hidden')).toBe(true); expect(indicator.classList.contains('hidden')).toBe(true);
@@ -227,4 +236,36 @@ describe('RecipeManager', () => {
await manager.refreshRecipes(); await manager.refreshRecipes();
expect(refreshRecipesMock).toHaveBeenCalledTimes(1); expect(refreshRecipesMock).toHaveBeenCalledTimes(1);
}); });
it('restores checkpoint recipe filter state and indicator text', async () => {
getSessionItemMock.mockImplementation((key) => {
const map = {
lora_to_recipe_filterLoraName: null,
lora_to_recipe_filterLoraHash: null,
checkpoint_to_recipe_filterCheckpointName: 'Flux Base',
checkpoint_to_recipe_filterCheckpointHash: 'ckpt123',
viewRecipeId: null,
};
return map[key] ?? null;
});
const manager = new RecipeManager();
await manager.initialize();
expect(pageState.customFilter).toEqual({
active: true,
loraName: null,
loraHash: null,
checkpointName: 'Flux Base',
checkpointHash: 'ckpt123',
recipeId: null,
});
const indicator = document.getElementById('customFilterIndicator');
const filterText = indicator.querySelector('#customFilterText');
expect(filterText.innerHTML).toContain('Recipes using checkpoint:');
expect(filterText.innerHTML).toContain('Flux Base');
expect(filterText.getAttribute('title')).toBe('Flux Base');
});
}); });

View File

@@ -43,6 +43,9 @@ class StubRecipeScanner:
self.cached_raw: List[Dict[str, Any]] = [] self.cached_raw: List[Dict[str, Any]] = []
self.recipes: Dict[str, Dict[str, Any]] = {} self.recipes: Dict[str, Dict[str, Any]] = {}
self.removed: List[str] = [] self.removed: List[str] = []
self.last_paginated_params: Dict[str, Any] | None = None
self.lora_lookup: Dict[str, List[Dict[str, Any]]] = {}
self.checkpoint_lookup: Dict[str, List[Dict[str, Any]]] = {}
async def _noop_get_cached_data(force_refresh: bool = False) -> None: # noqa: ARG001 - signature mirrors real scanner async def _noop_get_cached_data(force_refresh: bool = False) -> None: # noqa: ARG001 - signature mirrors real scanner
return None return None
@@ -56,6 +59,7 @@ class StubRecipeScanner:
return SimpleNamespace(raw_data=list(self.cached_raw)) return SimpleNamespace(raw_data=list(self.cached_raw))
async def get_paginated_data(self, **params: Any) -> Dict[str, Any]: async def get_paginated_data(self, **params: Any) -> Dict[str, Any]:
self.last_paginated_params = params
items = [dict(item) for item in self.listing_items] items = [dict(item) for item in self.listing_items]
page = int(params.get("page", 1)) page = int(params.get("page", 1))
page_size = int(params.get("page_size", 20)) page_size = int(params.get("page_size", 20))
@@ -70,6 +74,14 @@ class StubRecipeScanner:
async def get_recipe_by_id(self, recipe_id: str) -> Optional[Dict[str, Any]]: async def get_recipe_by_id(self, recipe_id: str) -> Optional[Dict[str, Any]]:
return self.recipes.get(recipe_id) return self.recipes.get(recipe_id)
async def get_recipes_for_lora(self, lora_hash: str) -> List[Dict[str, Any]]:
return list(self.lora_lookup.get(lora_hash.lower(), []))
async def get_recipes_for_checkpoint(
self, checkpoint_hash: str
) -> List[Dict[str, Any]]:
return list(self.checkpoint_lookup.get(checkpoint_hash.lower(), []))
async def get_recipe_json_path(self, recipe_id: str) -> Optional[str]: async def get_recipe_json_path(self, recipe_id: str) -> Optional[str]:
candidate = Path(self.recipes_dir) / f"{recipe_id}.recipe.json" candidate = Path(self.recipes_dir) / f"{recipe_id}.recipe.json"
return str(candidate) if candidate.exists() else None return str(candidate) if candidate.exists() else None
@@ -350,6 +362,47 @@ async def test_list_recipes_provides_file_urls(monkeypatch, tmp_path: Path) -> N
assert payload["items"][0]["loras"] == [] assert payload["items"][0]["loras"] == []
async def test_list_recipes_passes_checkpoint_hash_filter(
monkeypatch, tmp_path: Path
) -> None:
async with recipe_harness(monkeypatch, tmp_path) as harness:
response = await harness.client.get("/api/lm/recipes?checkpoint_hash=ckpt123")
payload = await response.json()
assert response.status == 200
assert payload["items"] == []
assert harness.scanner.last_paginated_params["checkpoint_hash"] == "ckpt123"
async def test_get_recipes_for_checkpoint(monkeypatch, tmp_path: Path) -> None:
async with recipe_harness(monkeypatch, tmp_path) as harness:
harness.scanner.checkpoint_lookup["abc123"] = [
{"id": "recipe-1", "title": "Linked recipe"}
]
response = await harness.client.get(
"/api/lm/recipes/for-checkpoint?hash=ABC123"
)
payload = await response.json()
assert response.status == 200
assert payload == {
"success": True,
"recipes": [{"id": "recipe-1", "title": "Linked recipe"}],
}
async def test_get_recipes_for_checkpoint_requires_hash(
monkeypatch, tmp_path: Path
) -> None:
async with recipe_harness(monkeypatch, tmp_path) as harness:
response = await harness.client.get("/api/lm/recipes/for-checkpoint")
payload = await response.json()
assert response.status == 400
assert payload["success"] is False
async def test_save_and_delete_recipe_round_trip(monkeypatch, tmp_path: Path) -> None: async def test_save_and_delete_recipe_round_trip(monkeypatch, tmp_path: Path) -> None:
async with recipe_harness(monkeypatch, tmp_path) as harness: async with recipe_harness(monkeypatch, tmp_path) as harness:
form = FormData() form = FormData()

View File

@@ -313,6 +313,75 @@ async def test_get_recipe_by_id_handles_non_dict_checkpoint(recipe_scanner):
assert recipe["checkpoint"]["file_name"] == "by-id" assert recipe["checkpoint"]["file_name"] == "by-id"
@pytest.mark.asyncio
async def test_get_paginated_data_filters_by_checkpoint_hash(recipe_scanner):
scanner, _ = recipe_scanner
image_path = Path(config.loras_roots[0]) / "checkpoint-filter.webp"
await scanner.add_recipe(
{
"id": "checkpoint-match",
"file_path": str(image_path),
"title": "Checkpoint Match",
"modified": 0.0,
"created_date": 0.0,
"loras": [],
"checkpoint": {
"name": "flux-base.safetensors",
"hash": "ABC123",
},
}
)
await scanner.add_recipe(
{
"id": "checkpoint-miss",
"file_path": str(Path(config.loras_roots[0]) / "checkpoint-miss.webp"),
"title": "Checkpoint Miss",
"modified": 1.0,
"created_date": 1.0,
"loras": [],
"checkpoint": {
"name": "other.safetensors",
"hash": "zzz999",
},
}
)
await asyncio.sleep(0)
result = await scanner.get_paginated_data(
page=1,
page_size=10,
checkpoint_hash="abc123",
)
assert [item["id"] for item in result["items"]] == ["checkpoint-match"]
@pytest.mark.asyncio
async def test_get_recipes_for_checkpoint_matches_hash_case_insensitively(recipe_scanner):
scanner, _ = recipe_scanner
image_path = Path(config.loras_roots[0]) / "checkpoint-linked.webp"
await scanner.add_recipe(
{
"id": "checkpoint-linked",
"file_path": str(image_path),
"title": "Checkpoint Linked",
"modified": 0.0,
"created_date": 0.0,
"loras": [],
"checkpoint": {
"name": "flux-base.safetensors",
"hash": "ABC123",
},
}
)
recipes = await scanner.get_recipes_for_checkpoint("abc123")
assert len(recipes) == 1
assert recipes[0]["id"] == "checkpoint-linked"
assert recipes[0]["checkpoint"]["hash"] == "ABC123"
def test_enrich_uses_version_index_when_hash_missing(recipe_scanner): def test_enrich_uses_version_index_when_hash_missing(recipe_scanner):
scanner, stub = recipe_scanner scanner, stub = recipe_scanner
version_id = 77 version_id = 77