mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-04-02 02:38:52 -03:00
fix(recipe): show checkpoint-linked recipes in model modal (#851)
This commit is contained in:
@@ -81,6 +81,7 @@ class RecipeHandlerSet:
|
||||
"bulk_delete": self.management.bulk_delete,
|
||||
"save_recipe_from_widget": self.management.save_recipe_from_widget,
|
||||
"get_recipes_for_lora": self.query.get_recipes_for_lora,
|
||||
"get_recipes_for_checkpoint": self.query.get_recipes_for_checkpoint,
|
||||
"scan_recipes": self.query.scan_recipes,
|
||||
"move_recipe": self.management.move_recipe,
|
||||
"repair_recipes": self.management.repair_recipes,
|
||||
@@ -218,6 +219,7 @@ class RecipeListingHandler:
|
||||
filters["tags"] = tag_filters
|
||||
|
||||
lora_hash = request.query.get("lora_hash")
|
||||
checkpoint_hash = request.query.get("checkpoint_hash")
|
||||
|
||||
result = await recipe_scanner.get_paginated_data(
|
||||
page=page,
|
||||
@@ -227,6 +229,7 @@ class RecipeListingHandler:
|
||||
filters=filters,
|
||||
search_options=search_options,
|
||||
lora_hash=lora_hash,
|
||||
checkpoint_hash=checkpoint_hash,
|
||||
folder=folder,
|
||||
recursive=recursive,
|
||||
)
|
||||
@@ -423,6 +426,28 @@ class RecipeQueryHandler:
|
||||
self._logger.error("Error getting recipes for Lora: %s", exc)
|
||||
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
||||
|
||||
async def get_recipes_for_checkpoint(self, request: web.Request) -> web.Response:
|
||||
try:
|
||||
await self._ensure_dependencies_ready()
|
||||
recipe_scanner = self._recipe_scanner_getter()
|
||||
if recipe_scanner is None:
|
||||
raise RuntimeError("Recipe scanner unavailable")
|
||||
|
||||
checkpoint_hash = request.query.get("hash")
|
||||
if not checkpoint_hash:
|
||||
return web.json_response(
|
||||
{"success": False, "error": "Checkpoint hash is required"},
|
||||
status=400,
|
||||
)
|
||||
|
||||
matching_recipes = await recipe_scanner.get_recipes_for_checkpoint(
|
||||
checkpoint_hash
|
||||
)
|
||||
return web.json_response({"success": True, "recipes": matching_recipes})
|
||||
except Exception as exc:
|
||||
self._logger.error("Error getting recipes for checkpoint: %s", exc)
|
||||
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
||||
|
||||
async def scan_recipes(self, request: web.Request) -> web.Response:
|
||||
try:
|
||||
await self._ensure_dependencies_ready()
|
||||
|
||||
@@ -51,6 +51,9 @@ ROUTE_DEFINITIONS: tuple[RouteDefinition, ...] = (
|
||||
"POST", "/api/lm/recipes/save-from-widget", "save_recipe_from_widget"
|
||||
),
|
||||
RouteDefinition("GET", "/api/lm/recipes/for-lora", "get_recipes_for_lora"),
|
||||
RouteDefinition(
|
||||
"GET", "/api/lm/recipes/for-checkpoint", "get_recipes_for_checkpoint"
|
||||
),
|
||||
RouteDefinition("GET", "/api/lm/recipes/scan", "scan_recipes"),
|
||||
RouteDefinition("POST", "/api/lm/recipes/repair", "repair_recipes"),
|
||||
RouteDefinition("POST", "/api/lm/recipes/cancel-repair", "cancel_repair"),
|
||||
|
||||
@@ -1615,6 +1615,9 @@ class RecipeScanner:
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""Coerce legacy or malformed checkpoint entries into a dict."""
|
||||
|
||||
if checkpoint_raw is None:
|
||||
return None
|
||||
|
||||
if isinstance(checkpoint_raw, dict):
|
||||
return dict(checkpoint_raw)
|
||||
|
||||
@@ -1632,9 +1635,6 @@ class RecipeScanner:
|
||||
"file_name": file_name,
|
||||
}
|
||||
|
||||
logger.warning(
|
||||
"Unexpected checkpoint payload type %s", type(checkpoint_raw).__name__
|
||||
)
|
||||
return None
|
||||
|
||||
def _enrich_checkpoint_entry(self, checkpoint: Dict[str, Any]) -> Dict[str, Any]:
|
||||
@@ -1790,6 +1790,7 @@ class RecipeScanner:
|
||||
filters: dict = None,
|
||||
search_options: dict = None,
|
||||
lora_hash: str = None,
|
||||
checkpoint_hash: str = None,
|
||||
bypass_filters: bool = True,
|
||||
folder: str | None = None,
|
||||
recursive: bool = True,
|
||||
@@ -1804,7 +1805,8 @@ class RecipeScanner:
|
||||
filters: Dictionary of filters to apply
|
||||
search_options: Dictionary of search options to apply
|
||||
lora_hash: Optional SHA256 hash of a LoRA to filter recipes by
|
||||
bypass_filters: If True, ignore other filters when a lora_hash is provided
|
||||
checkpoint_hash: Optional SHA256 hash of a checkpoint to filter recipes by
|
||||
bypass_filters: If True, ignore other filters when a hash filter is provided
|
||||
folder: Optional folder filter relative to recipes directory
|
||||
recursive: Whether to include recipes in subfolders of the selected folder
|
||||
"""
|
||||
@@ -1852,9 +1854,23 @@ class RecipeScanner:
|
||||
# Skip other filters if bypass_filters is True
|
||||
pass
|
||||
# Otherwise continue with normal filtering after applying LoRA hash filter
|
||||
elif checkpoint_hash:
|
||||
normalized_checkpoint_hash = checkpoint_hash.lower()
|
||||
filtered_data = [
|
||||
item
|
||||
for item in filtered_data
|
||||
if isinstance(item.get("checkpoint"), dict)
|
||||
and (item["checkpoint"].get("hash", "") or "").lower()
|
||||
== normalized_checkpoint_hash
|
||||
]
|
||||
|
||||
# Skip further filtering if we're only filtering by LoRA hash with bypass enabled
|
||||
if not (lora_hash and bypass_filters):
|
||||
if bypass_filters:
|
||||
pass
|
||||
|
||||
has_hash_filter = bool(lora_hash or checkpoint_hash)
|
||||
|
||||
# Skip further filtering if we're only filtering by model hash with bypass enabled
|
||||
if not (has_hash_filter and bypass_filters):
|
||||
# Apply folder filter before other criteria
|
||||
if folder is not None:
|
||||
normalized_folder = folder.strip("/")
|
||||
@@ -2334,6 +2350,38 @@ class RecipeScanner:
|
||||
|
||||
return matching_recipes
|
||||
|
||||
async def get_recipes_for_checkpoint(
|
||||
self, checkpoint_hash: str
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Return recipes that reference a given checkpoint hash."""
|
||||
|
||||
if not checkpoint_hash:
|
||||
return []
|
||||
|
||||
normalized_hash = checkpoint_hash.lower()
|
||||
cache = await self.get_cached_data()
|
||||
matching_recipes: List[Dict[str, Any]] = []
|
||||
|
||||
for recipe in cache.raw_data:
|
||||
checkpoint = self._normalize_checkpoint_entry(recipe.get("checkpoint"))
|
||||
if not checkpoint:
|
||||
continue
|
||||
|
||||
enriched_checkpoint = self._enrich_checkpoint_entry(dict(checkpoint))
|
||||
if (enriched_checkpoint.get("hash") or "").lower() != normalized_hash:
|
||||
continue
|
||||
|
||||
recipe_copy = {**recipe}
|
||||
recipe_copy["checkpoint"] = enriched_checkpoint
|
||||
recipe_copy["loras"] = [
|
||||
self._enrich_lora_entry(dict(entry))
|
||||
for entry in recipe.get("loras", [])
|
||||
]
|
||||
recipe_copy["file_url"] = self._format_file_url(recipe.get("file_path"))
|
||||
matching_recipes.append(recipe_copy)
|
||||
|
||||
return matching_recipes
|
||||
|
||||
async def get_recipe_syntax_tokens(self, recipe_id: str) -> List[str]:
|
||||
"""Build LoRA syntax tokens for a recipe."""
|
||||
|
||||
|
||||
@@ -83,6 +83,9 @@ export async function fetchRecipesPage(page = 1, pageSize = 100) {
|
||||
if (pageState.customFilter?.active && pageState.customFilter?.loraHash) {
|
||||
params.append('lora_hash', pageState.customFilter.loraHash);
|
||||
params.append('bypass_filters', 'true');
|
||||
} else if (pageState.customFilter?.active && pageState.customFilter?.checkpointHash) {
|
||||
params.append('checkpoint_hash', pageState.customFilter.checkpointHash);
|
||||
params.append('bypass_filters', 'true');
|
||||
} else {
|
||||
// Normal filtering logic
|
||||
|
||||
|
||||
@@ -19,7 +19,7 @@ import { renderCompactTags, setupTagTooltip, formatFileSize, escapeAttribute, es
|
||||
import { renderTriggerWords, setupTriggerWordsEditMode } from './TriggerWords.js';
|
||||
import { parsePresets, renderPresetTags } from './PresetTags.js';
|
||||
import { initVersionsTab } from './ModelVersionsTab.js';
|
||||
import { loadRecipesForLora } from './RecipeTab.js';
|
||||
import { loadRecipesForModel } from './RecipeTab.js';
|
||||
import { translate } from '../../utils/i18nHelpers.js';
|
||||
import { state } from '../../state/index.js';
|
||||
|
||||
@@ -355,7 +355,9 @@ export async function showModelModal(model, modelType) {
|
||||
${versionsTabBadge}
|
||||
</button>`.trim();
|
||||
|
||||
const tabsContent = modelType === 'loras' ?
|
||||
const supportsRecipesTab = modelType === 'loras' || modelType === 'checkpoints';
|
||||
|
||||
const tabsContent = supportsRecipesTab ?
|
||||
`<button class="tab-btn active" data-tab="showcase">${examplesText}</button>
|
||||
<button class="tab-btn" data-tab="description">${descriptionText}</button>
|
||||
${versionsTabButton}
|
||||
@@ -385,7 +387,7 @@ export async function showModelModal(model, modelType) {
|
||||
</button>
|
||||
</div>`.trim();
|
||||
|
||||
const tabPanesContent = modelType === 'loras' ?
|
||||
const tabPanesContent = supportsRecipesTab ?
|
||||
`<div id="showcase-tab" class="tab-pane active">
|
||||
<div class="example-images-loading">
|
||||
<i class="fas fa-spinner fa-spin"></i> ${loadingExampleImagesText}
|
||||
@@ -664,14 +666,23 @@ export async function showModelModal(model, modelType) {
|
||||
setupNavigationShortcuts(modelType);
|
||||
updateNavigationControls();
|
||||
|
||||
// LoRA specific setup
|
||||
// Model-specific setup
|
||||
if (modelType === 'loras' || modelType === 'embeddings') {
|
||||
setupTriggerWordsEditMode();
|
||||
}
|
||||
|
||||
if (modelType == 'loras') {
|
||||
// Load recipes for this LoRA
|
||||
loadRecipesForLora(modelWithFullData.model_name, modelWithFullData.sha256);
|
||||
}
|
||||
if (modelType === 'loras') {
|
||||
loadRecipesForModel({
|
||||
modelKind: 'lora',
|
||||
displayName: modelWithFullData.model_name,
|
||||
sha256: modelWithFullData.sha256,
|
||||
});
|
||||
} else if (modelType === 'checkpoints') {
|
||||
loadRecipesForModel({
|
||||
modelKind: 'checkpoint',
|
||||
displayName: modelWithFullData.model_name,
|
||||
sha256: modelWithFullData.sha256,
|
||||
});
|
||||
}
|
||||
|
||||
// Load example images asynchronously - merge regular and custom images
|
||||
|
||||
@@ -1,19 +1,23 @@
|
||||
/**
|
||||
* RecipeTab - Handles the recipes tab in model modals (LoRA specific functionality)
|
||||
* Moved to shared directory for consistency
|
||||
* RecipeTab - Handles the recipes tab in model modals.
|
||||
*/
|
||||
import { showToast, copyToClipboard } from '../../utils/uiHelpers.js';
|
||||
import { setSessionItem, removeSessionItem } from '../../utils/storageHelpers.js';
|
||||
|
||||
/**
|
||||
* Loads recipes that use the specified Lora and renders them in the tab
|
||||
* @param {string} loraName - The display name of the Lora
|
||||
* @param {string} sha256 - The SHA256 hash of the Lora
|
||||
* Loads recipes that use the specified model and renders them in the tab.
|
||||
* @param {Object} options
|
||||
* @param {'lora'|'checkpoint'} options.modelKind - Model kind for copy and endpoint selection
|
||||
* @param {string} options.displayName - The display name of the model
|
||||
* @param {string} options.sha256 - The SHA256 hash of the model
|
||||
*/
|
||||
export function loadRecipesForLora(loraName, sha256) {
|
||||
export function loadRecipesForModel({ modelKind, displayName, sha256 }) {
|
||||
const recipeTab = document.getElementById('recipes-tab');
|
||||
if (!recipeTab) return;
|
||||
|
||||
const normalizedHash = sha256?.toLowerCase?.() || '';
|
||||
const modelLabel = getModelLabel(modelKind);
|
||||
|
||||
// Show loading state
|
||||
recipeTab.innerHTML = `
|
||||
<div class="recipes-loading">
|
||||
@@ -21,18 +25,23 @@ export function loadRecipesForLora(loraName, sha256) {
|
||||
</div>
|
||||
`;
|
||||
|
||||
// Fetch recipes that use this Lora by hash
|
||||
fetch(`/api/lm/recipes/for-lora?hash=${encodeURIComponent(sha256.toLowerCase())}`)
|
||||
// Fetch recipes that use this model by hash
|
||||
fetch(`${getRecipesEndpoint(modelKind)}?hash=${encodeURIComponent(normalizedHash)}`)
|
||||
.then(response => response.json())
|
||||
.then(data => {
|
||||
if (!data.success) {
|
||||
throw new Error(data.error || 'Failed to load recipes');
|
||||
}
|
||||
|
||||
renderRecipes(recipeTab, data.recipes, loraName, sha256);
|
||||
renderRecipes(recipeTab, data.recipes, {
|
||||
modelKind,
|
||||
displayName,
|
||||
modelHash: normalizedHash,
|
||||
modelLabel,
|
||||
});
|
||||
})
|
||||
.catch(error => {
|
||||
console.error('Error loading recipes for Lora:', error);
|
||||
console.error(`Error loading recipes for ${modelLabel}:`, error);
|
||||
recipeTab.innerHTML = `
|
||||
<div class="recipes-error">
|
||||
<i class="fas fa-exclamation-circle"></i>
|
||||
@@ -46,15 +55,21 @@ export function loadRecipesForLora(loraName, sha256) {
|
||||
* Renders the recipe cards in the tab
|
||||
* @param {HTMLElement} tabElement - The tab element to render into
|
||||
* @param {Array} recipes - Array of recipe objects
|
||||
* @param {string} loraName - The display name of the Lora
|
||||
* @param {string} loraHash - The hash of the Lora
|
||||
* @param {Object} options - Render options
|
||||
*/
|
||||
function renderRecipes(tabElement, recipes, loraName, loraHash) {
|
||||
function renderRecipes(tabElement, recipes, options) {
|
||||
const {
|
||||
modelKind,
|
||||
displayName,
|
||||
modelHash,
|
||||
modelLabel,
|
||||
} = options;
|
||||
|
||||
if (!recipes || recipes.length === 0) {
|
||||
tabElement.innerHTML = `
|
||||
<div class="recipes-empty">
|
||||
<i class="fas fa-book-open"></i>
|
||||
<p>No recipes found that use this Lora.</p>
|
||||
<p>No recipes found that use this ${modelLabel}.</p>
|
||||
</div>
|
||||
`;
|
||||
|
||||
@@ -73,13 +88,13 @@ function renderRecipes(tabElement, recipes, loraName, loraHash) {
|
||||
headerText.appendChild(eyebrow);
|
||||
|
||||
const title = document.createElement('h3');
|
||||
title.textContent = `${recipes.length} recipe${recipes.length > 1 ? 's' : ''} using this Lora`;
|
||||
title.textContent = `${recipes.length} recipe${recipes.length > 1 ? 's' : ''} using this ${modelLabel}`;
|
||||
headerText.appendChild(title);
|
||||
|
||||
const description = document.createElement('p');
|
||||
description.className = 'recipes-header__description';
|
||||
description.textContent = loraName ?
|
||||
`Discover workflows crafted for ${loraName}.` :
|
||||
description.textContent = displayName ?
|
||||
`Discover workflows crafted for ${displayName}.` :
|
||||
'Discover workflows crafted for this model.';
|
||||
headerText.appendChild(description);
|
||||
|
||||
@@ -101,7 +116,11 @@ function renderRecipes(tabElement, recipes, loraName, loraHash) {
|
||||
headerElement.appendChild(viewAllButton);
|
||||
|
||||
viewAllButton.addEventListener('click', () => {
|
||||
navigateToRecipesPage(loraName, loraHash);
|
||||
navigateToRecipesPage({
|
||||
modelKind,
|
||||
displayName,
|
||||
modelHash,
|
||||
});
|
||||
});
|
||||
|
||||
const cardGrid = document.createElement('div');
|
||||
@@ -280,12 +299,10 @@ function copyRecipeSyntax(recipeId) {
|
||||
}
|
||||
|
||||
/**
|
||||
* Navigates to the recipes page with filter for the current Lora
|
||||
* @param {string} loraName - The Lora display name to filter by
|
||||
* @param {string} loraHash - The hash of the Lora to filter by
|
||||
* @param {boolean} createNew - Whether to open the create recipe dialog
|
||||
* Navigates to the recipes page with filter for the current model
|
||||
* @param {Object} options - Navigation options
|
||||
*/
|
||||
function navigateToRecipesPage(loraName, loraHash) {
|
||||
function navigateToRecipesPage({ modelKind, displayName, modelHash }) {
|
||||
// Close the current modal
|
||||
if (window.modalManager) {
|
||||
modalManager.closeModal('modelModal');
|
||||
@@ -294,11 +311,19 @@ function navigateToRecipesPage(loraName, loraHash) {
|
||||
// Clear any previous filters first
|
||||
removeSessionItem('lora_to_recipe_filterLoraName');
|
||||
removeSessionItem('lora_to_recipe_filterLoraHash');
|
||||
removeSessionItem('checkpoint_to_recipe_filterCheckpointName');
|
||||
removeSessionItem('checkpoint_to_recipe_filterCheckpointHash');
|
||||
removeSessionItem('viewRecipeId');
|
||||
|
||||
// Store the LoRA name and hash filter in sessionStorage
|
||||
setSessionItem('lora_to_recipe_filterLoraName', loraName);
|
||||
setSessionItem('lora_to_recipe_filterLoraHash', loraHash);
|
||||
if (modelKind === 'checkpoint') {
|
||||
// Store the checkpoint name and hash filter in sessionStorage
|
||||
setSessionItem('checkpoint_to_recipe_filterCheckpointName', displayName);
|
||||
setSessionItem('checkpoint_to_recipe_filterCheckpointHash', modelHash);
|
||||
} else {
|
||||
// Store the LoRA name and hash filter in sessionStorage
|
||||
setSessionItem('lora_to_recipe_filterLoraName', displayName);
|
||||
setSessionItem('lora_to_recipe_filterLoraHash', modelHash);
|
||||
}
|
||||
|
||||
// Directly navigate to recipes page
|
||||
window.location.href = '/loras/recipes';
|
||||
@@ -325,3 +350,14 @@ function navigateToRecipeDetails(recipeId) {
|
||||
// Directly navigate to recipes page
|
||||
window.location.href = '/loras/recipes';
|
||||
}
|
||||
|
||||
function getRecipesEndpoint(modelKind) {
|
||||
if (modelKind === 'checkpoint') {
|
||||
return '/api/lm/recipes/for-checkpoint';
|
||||
}
|
||||
return '/api/lm/recipes/for-lora';
|
||||
}
|
||||
|
||||
function getModelLabel(modelKind) {
|
||||
return modelKind === 'checkpoint' ? 'checkpoint' : 'LoRA';
|
||||
}
|
||||
|
||||
@@ -66,6 +66,8 @@ class RecipeManager {
|
||||
active: false,
|
||||
loraName: null,
|
||||
loraHash: null,
|
||||
checkpointName: null,
|
||||
checkpointHash: null,
|
||||
recipeId: null
|
||||
};
|
||||
}
|
||||
@@ -127,16 +129,20 @@ class RecipeManager {
|
||||
// Check for Lora filter
|
||||
const filterLoraName = getSessionItem('lora_to_recipe_filterLoraName');
|
||||
const filterLoraHash = getSessionItem('lora_to_recipe_filterLoraHash');
|
||||
const filterCheckpointName = getSessionItem('checkpoint_to_recipe_filterCheckpointName');
|
||||
const filterCheckpointHash = getSessionItem('checkpoint_to_recipe_filterCheckpointHash');
|
||||
|
||||
// Check for specific recipe ID
|
||||
const viewRecipeId = getSessionItem('viewRecipeId');
|
||||
|
||||
// Set custom filter if any parameter is present
|
||||
if (filterLoraName || filterLoraHash || viewRecipeId) {
|
||||
if (filterLoraName || filterLoraHash || filterCheckpointName || filterCheckpointHash || viewRecipeId) {
|
||||
this.pageState.customFilter = {
|
||||
active: true,
|
||||
loraName: filterLoraName,
|
||||
loraHash: filterLoraHash,
|
||||
checkpointName: filterCheckpointName,
|
||||
checkpointHash: filterCheckpointHash,
|
||||
recipeId: viewRecipeId
|
||||
};
|
||||
|
||||
@@ -164,6 +170,13 @@ class RecipeManager {
|
||||
loraName;
|
||||
|
||||
filterText = `<span>Recipes using: <span class="lora-name">${displayName}</span></span>`;
|
||||
} else if (this.pageState.customFilter.checkpointName) {
|
||||
const checkpointName = this.pageState.customFilter.checkpointName;
|
||||
const displayName = checkpointName.length > 25 ?
|
||||
checkpointName.substring(0, 22) + '...' :
|
||||
checkpointName;
|
||||
|
||||
filterText = `<span>Recipes using checkpoint: <span class="lora-name">${displayName}</span></span>`;
|
||||
} else {
|
||||
filterText = 'Filtered recipes';
|
||||
}
|
||||
@@ -173,6 +186,10 @@ class RecipeManager {
|
||||
// Add title attribute to show the lora name as a tooltip
|
||||
if (this.pageState.customFilter.loraName) {
|
||||
textElement.setAttribute('title', this.pageState.customFilter.loraName);
|
||||
} else if (this.pageState.customFilter.checkpointName) {
|
||||
textElement.setAttribute('title', this.pageState.customFilter.checkpointName);
|
||||
} else {
|
||||
textElement.removeAttribute('title');
|
||||
}
|
||||
indicator.classList.remove('hidden');
|
||||
|
||||
@@ -199,6 +216,8 @@ class RecipeManager {
|
||||
active: false,
|
||||
loraName: null,
|
||||
loraHash: null,
|
||||
checkpointName: null,
|
||||
checkpointHash: null,
|
||||
recipeId: null
|
||||
};
|
||||
|
||||
@@ -211,6 +230,8 @@ class RecipeManager {
|
||||
// Clear any session storage items
|
||||
removeSessionItem('lora_to_recipe_filterLoraName');
|
||||
removeSessionItem('lora_to_recipe_filterLoraHash');
|
||||
removeSessionItem('checkpoint_to_recipe_filterCheckpointName');
|
||||
removeSessionItem('checkpoint_to_recipe_filterCheckpointHash');
|
||||
removeSessionItem('viewRecipeId');
|
||||
|
||||
// Reset and refresh the virtual scroller
|
||||
|
||||
@@ -82,7 +82,7 @@ vi.mock(MODEL_VERSIONS_MODULE, () => ({
|
||||
}));
|
||||
|
||||
vi.mock(RECIPE_TAB_MODULE, () => ({
|
||||
loadRecipesForLora: vi.fn(),
|
||||
loadRecipesForModel: vi.fn(),
|
||||
}));
|
||||
|
||||
vi.mock(I18N_HELPERS_MODULE, () => ({
|
||||
@@ -103,11 +103,14 @@ vi.mock(API_FACTORY, () => ({
|
||||
|
||||
describe('Model metadata interactions keep file path in sync', () => {
|
||||
let getModelApiClient;
|
||||
let loadRecipesForModel;
|
||||
|
||||
beforeEach(async () => {
|
||||
document.body.innerHTML = '';
|
||||
({ getModelApiClient } = await import(API_FACTORY));
|
||||
({ loadRecipesForModel } = await import(RECIPE_TAB_MODULE));
|
||||
getModelApiClient.mockReset();
|
||||
loadRecipesForModel.mockReset();
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
@@ -206,4 +209,33 @@ describe('Model metadata interactions keep file path in sync', () => {
|
||||
expect(saveModelMetadata).toHaveBeenCalledWith('models/Qwen.testing.safetensors', { notes: 'Updated notes' });
|
||||
});
|
||||
});
|
||||
|
||||
it('shows recipes tab for checkpoint modals and loads linked recipes by hash', async () => {
|
||||
const fetchModelMetadata = vi.fn().mockResolvedValue(null);
|
||||
|
||||
getModelApiClient.mockReturnValue({
|
||||
fetchModelMetadata,
|
||||
saveModelMetadata: vi.fn(),
|
||||
});
|
||||
|
||||
const { showModelModal } = await import(MODAL_MODULE);
|
||||
|
||||
await showModelModal(
|
||||
{
|
||||
model_name: 'Flux Base',
|
||||
file_path: 'models/checkpoints/flux-base.safetensors',
|
||||
file_name: 'flux-base.safetensors',
|
||||
sha256: 'ABC123',
|
||||
civitai: {},
|
||||
},
|
||||
'checkpoints',
|
||||
);
|
||||
|
||||
expect(document.querySelector('.tab-btn[data-tab="recipes"]')).not.toBeNull();
|
||||
expect(loadRecipesForModel).toHaveBeenCalledWith({
|
||||
modelKind: 'checkpoint',
|
||||
displayName: 'Flux Base',
|
||||
sha256: 'ABC123',
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -80,7 +80,7 @@ vi.mock(MODEL_VERSIONS_MODULE, () => ({
|
||||
}));
|
||||
|
||||
vi.mock(RECIPE_TAB_MODULE, () => ({
|
||||
loadRecipesForLora: vi.fn(),
|
||||
loadRecipesForModel: vi.fn(),
|
||||
}));
|
||||
|
||||
vi.mock(I18N_HELPERS_MODULE, () => ({
|
||||
|
||||
@@ -6,6 +6,7 @@ const initializePageFeaturesMock = vi.fn();
|
||||
const getCurrentPageStateMock = vi.fn();
|
||||
const getSessionItemMock = vi.fn();
|
||||
const removeSessionItemMock = vi.fn();
|
||||
const getStorageItemMock = vi.fn();
|
||||
const RecipeContextMenuMock = vi.fn();
|
||||
const refreshVirtualScrollMock = vi.fn();
|
||||
const refreshRecipesMock = vi.fn();
|
||||
@@ -51,6 +52,7 @@ vi.mock('../../../static/js/state/index.js', () => ({
|
||||
vi.mock('../../../static/js/utils/storageHelpers.js', () => ({
|
||||
getSessionItem: getSessionItemMock,
|
||||
removeSessionItem: removeSessionItemMock,
|
||||
getStorageItem: getStorageItemMock,
|
||||
}));
|
||||
|
||||
vi.mock('../../../static/js/components/ContextMenu/index.js', () => ({
|
||||
@@ -117,11 +119,14 @@ describe('RecipeManager', () => {
|
||||
const map = {
|
||||
lora_to_recipe_filterLoraName: 'Flux Dream',
|
||||
lora_to_recipe_filterLoraHash: 'abc123',
|
||||
checkpoint_to_recipe_filterCheckpointName: null,
|
||||
checkpoint_to_recipe_filterCheckpointHash: null,
|
||||
viewRecipeId: '42',
|
||||
};
|
||||
return map[key] ?? null;
|
||||
});
|
||||
removeSessionItemMock.mockImplementation(() => { });
|
||||
getStorageItemMock.mockImplementation((_, defaultValue = null) => defaultValue);
|
||||
|
||||
renderRecipesPage();
|
||||
|
||||
@@ -166,6 +171,8 @@ describe('RecipeManager', () => {
|
||||
active: true,
|
||||
loraName: 'Flux Dream',
|
||||
loraHash: 'abc123',
|
||||
checkpointName: null,
|
||||
checkpointHash: null,
|
||||
recipeId: '42',
|
||||
});
|
||||
|
||||
@@ -177,6 +184,8 @@ describe('RecipeManager', () => {
|
||||
|
||||
expect(removeSessionItemMock).toHaveBeenCalledWith('lora_to_recipe_filterLoraName');
|
||||
expect(removeSessionItemMock).toHaveBeenCalledWith('lora_to_recipe_filterLoraHash');
|
||||
expect(removeSessionItemMock).toHaveBeenCalledWith('checkpoint_to_recipe_filterCheckpointName');
|
||||
expect(removeSessionItemMock).toHaveBeenCalledWith('checkpoint_to_recipe_filterCheckpointHash');
|
||||
expect(removeSessionItemMock).toHaveBeenCalledWith('viewRecipeId');
|
||||
expect(pageState.customFilter.active).toBe(false);
|
||||
expect(indicator.classList.contains('hidden')).toBe(true);
|
||||
@@ -227,4 +236,36 @@ describe('RecipeManager', () => {
|
||||
await manager.refreshRecipes();
|
||||
expect(refreshRecipesMock).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
|
||||
it('restores checkpoint recipe filter state and indicator text', async () => {
|
||||
getSessionItemMock.mockImplementation((key) => {
|
||||
const map = {
|
||||
lora_to_recipe_filterLoraName: null,
|
||||
lora_to_recipe_filterLoraHash: null,
|
||||
checkpoint_to_recipe_filterCheckpointName: 'Flux Base',
|
||||
checkpoint_to_recipe_filterCheckpointHash: 'ckpt123',
|
||||
viewRecipeId: null,
|
||||
};
|
||||
return map[key] ?? null;
|
||||
});
|
||||
|
||||
const manager = new RecipeManager();
|
||||
await manager.initialize();
|
||||
|
||||
expect(pageState.customFilter).toEqual({
|
||||
active: true,
|
||||
loraName: null,
|
||||
loraHash: null,
|
||||
checkpointName: 'Flux Base',
|
||||
checkpointHash: 'ckpt123',
|
||||
recipeId: null,
|
||||
});
|
||||
|
||||
const indicator = document.getElementById('customFilterIndicator');
|
||||
const filterText = indicator.querySelector('#customFilterText');
|
||||
|
||||
expect(filterText.innerHTML).toContain('Recipes using checkpoint:');
|
||||
expect(filterText.innerHTML).toContain('Flux Base');
|
||||
expect(filterText.getAttribute('title')).toBe('Flux Base');
|
||||
});
|
||||
});
|
||||
|
||||
@@ -43,6 +43,9 @@ class StubRecipeScanner:
|
||||
self.cached_raw: List[Dict[str, Any]] = []
|
||||
self.recipes: Dict[str, Dict[str, Any]] = {}
|
||||
self.removed: List[str] = []
|
||||
self.last_paginated_params: Dict[str, Any] | None = None
|
||||
self.lora_lookup: Dict[str, List[Dict[str, Any]]] = {}
|
||||
self.checkpoint_lookup: Dict[str, List[Dict[str, Any]]] = {}
|
||||
|
||||
async def _noop_get_cached_data(force_refresh: bool = False) -> None: # noqa: ARG001 - signature mirrors real scanner
|
||||
return None
|
||||
@@ -56,6 +59,7 @@ class StubRecipeScanner:
|
||||
return SimpleNamespace(raw_data=list(self.cached_raw))
|
||||
|
||||
async def get_paginated_data(self, **params: Any) -> Dict[str, Any]:
|
||||
self.last_paginated_params = params
|
||||
items = [dict(item) for item in self.listing_items]
|
||||
page = int(params.get("page", 1))
|
||||
page_size = int(params.get("page_size", 20))
|
||||
@@ -70,6 +74,14 @@ class StubRecipeScanner:
|
||||
async def get_recipe_by_id(self, recipe_id: str) -> Optional[Dict[str, Any]]:
|
||||
return self.recipes.get(recipe_id)
|
||||
|
||||
async def get_recipes_for_lora(self, lora_hash: str) -> List[Dict[str, Any]]:
|
||||
return list(self.lora_lookup.get(lora_hash.lower(), []))
|
||||
|
||||
async def get_recipes_for_checkpoint(
|
||||
self, checkpoint_hash: str
|
||||
) -> List[Dict[str, Any]]:
|
||||
return list(self.checkpoint_lookup.get(checkpoint_hash.lower(), []))
|
||||
|
||||
async def get_recipe_json_path(self, recipe_id: str) -> Optional[str]:
|
||||
candidate = Path(self.recipes_dir) / f"{recipe_id}.recipe.json"
|
||||
return str(candidate) if candidate.exists() else None
|
||||
@@ -350,6 +362,47 @@ async def test_list_recipes_provides_file_urls(monkeypatch, tmp_path: Path) -> N
|
||||
assert payload["items"][0]["loras"] == []
|
||||
|
||||
|
||||
async def test_list_recipes_passes_checkpoint_hash_filter(
|
||||
monkeypatch, tmp_path: Path
|
||||
) -> None:
|
||||
async with recipe_harness(monkeypatch, tmp_path) as harness:
|
||||
response = await harness.client.get("/api/lm/recipes?checkpoint_hash=ckpt123")
|
||||
payload = await response.json()
|
||||
|
||||
assert response.status == 200
|
||||
assert payload["items"] == []
|
||||
assert harness.scanner.last_paginated_params["checkpoint_hash"] == "ckpt123"
|
||||
|
||||
|
||||
async def test_get_recipes_for_checkpoint(monkeypatch, tmp_path: Path) -> None:
|
||||
async with recipe_harness(monkeypatch, tmp_path) as harness:
|
||||
harness.scanner.checkpoint_lookup["abc123"] = [
|
||||
{"id": "recipe-1", "title": "Linked recipe"}
|
||||
]
|
||||
|
||||
response = await harness.client.get(
|
||||
"/api/lm/recipes/for-checkpoint?hash=ABC123"
|
||||
)
|
||||
payload = await response.json()
|
||||
|
||||
assert response.status == 200
|
||||
assert payload == {
|
||||
"success": True,
|
||||
"recipes": [{"id": "recipe-1", "title": "Linked recipe"}],
|
||||
}
|
||||
|
||||
|
||||
async def test_get_recipes_for_checkpoint_requires_hash(
|
||||
monkeypatch, tmp_path: Path
|
||||
) -> None:
|
||||
async with recipe_harness(monkeypatch, tmp_path) as harness:
|
||||
response = await harness.client.get("/api/lm/recipes/for-checkpoint")
|
||||
payload = await response.json()
|
||||
|
||||
assert response.status == 400
|
||||
assert payload["success"] is False
|
||||
|
||||
|
||||
async def test_save_and_delete_recipe_round_trip(monkeypatch, tmp_path: Path) -> None:
|
||||
async with recipe_harness(monkeypatch, tmp_path) as harness:
|
||||
form = FormData()
|
||||
|
||||
@@ -313,6 +313,75 @@ async def test_get_recipe_by_id_handles_non_dict_checkpoint(recipe_scanner):
|
||||
assert recipe["checkpoint"]["file_name"] == "by-id"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_paginated_data_filters_by_checkpoint_hash(recipe_scanner):
|
||||
scanner, _ = recipe_scanner
|
||||
image_path = Path(config.loras_roots[0]) / "checkpoint-filter.webp"
|
||||
await scanner.add_recipe(
|
||||
{
|
||||
"id": "checkpoint-match",
|
||||
"file_path": str(image_path),
|
||||
"title": "Checkpoint Match",
|
||||
"modified": 0.0,
|
||||
"created_date": 0.0,
|
||||
"loras": [],
|
||||
"checkpoint": {
|
||||
"name": "flux-base.safetensors",
|
||||
"hash": "ABC123",
|
||||
},
|
||||
}
|
||||
)
|
||||
await scanner.add_recipe(
|
||||
{
|
||||
"id": "checkpoint-miss",
|
||||
"file_path": str(Path(config.loras_roots[0]) / "checkpoint-miss.webp"),
|
||||
"title": "Checkpoint Miss",
|
||||
"modified": 1.0,
|
||||
"created_date": 1.0,
|
||||
"loras": [],
|
||||
"checkpoint": {
|
||||
"name": "other.safetensors",
|
||||
"hash": "zzz999",
|
||||
},
|
||||
}
|
||||
)
|
||||
await asyncio.sleep(0)
|
||||
|
||||
result = await scanner.get_paginated_data(
|
||||
page=1,
|
||||
page_size=10,
|
||||
checkpoint_hash="abc123",
|
||||
)
|
||||
|
||||
assert [item["id"] for item in result["items"]] == ["checkpoint-match"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_recipes_for_checkpoint_matches_hash_case_insensitively(recipe_scanner):
|
||||
scanner, _ = recipe_scanner
|
||||
image_path = Path(config.loras_roots[0]) / "checkpoint-linked.webp"
|
||||
await scanner.add_recipe(
|
||||
{
|
||||
"id": "checkpoint-linked",
|
||||
"file_path": str(image_path),
|
||||
"title": "Checkpoint Linked",
|
||||
"modified": 0.0,
|
||||
"created_date": 0.0,
|
||||
"loras": [],
|
||||
"checkpoint": {
|
||||
"name": "flux-base.safetensors",
|
||||
"hash": "ABC123",
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
recipes = await scanner.get_recipes_for_checkpoint("abc123")
|
||||
|
||||
assert len(recipes) == 1
|
||||
assert recipes[0]["id"] == "checkpoint-linked"
|
||||
assert recipes[0]["checkpoint"]["hash"] == "ABC123"
|
||||
|
||||
|
||||
def test_enrich_uses_version_index_when_hash_missing(recipe_scanner):
|
||||
scanner, stub = recipe_scanner
|
||||
version_id = 77
|
||||
|
||||
Reference in New Issue
Block a user