From ec9b37eb532d9c4338a8a1a795b157dad8941549 Mon Sep 17 00:00:00 2001 From: Will Miao <13051207myq@gmail.com> Date: Sat, 11 Oct 2025 20:36:38 +0800 Subject: [PATCH] feat: add model type context to tag suggestions - Pass modelType parameter to setupTagEditMode function - Implement model type aware priority tag suggestions - Add model type normalization and resolution logic - Handle suggestion state reset when model type changes - Maintain backward compatibility with existing functionality The changes enable context-aware tag suggestions based on model type, improving tag relevance and user experience when editing tags for different model types. --- static/js/components/shared/ModelModal.js | 4 +- static/js/components/shared/ModelTags.js | 115 +++++++++++++++--- static/js/managers/BulkManager.js | 33 +++-- static/js/utils/priorityTagHelpers.js | 68 ++++++++++- .../managers/settingsManager.library.test.js | 5 + .../frontend/utils/priorityTagHelpers.test.js | 100 +++++++++++++++ 6 files changed, 293 insertions(+), 32 deletions(-) create mode 100644 tests/frontend/utils/priorityTagHelpers.test.js diff --git a/static/js/components/shared/ModelModal.js b/static/js/components/shared/ModelModal.js index b0032f3b..7fc366ce 100644 --- a/static/js/components/shared/ModelModal.js +++ b/static/js/components/shared/ModelModal.js @@ -236,7 +236,7 @@ export async function showModelModal(model, modelType) { setupShowcaseScroll(modalId); setupTabSwitching(); setupTagTooltip(); - setupTagEditMode(); + setupTagEditMode(modelType); setupModelNameEditing(modelWithFullData.file_path); setupBaseModelEditing(modelWithFullData.file_path); setupFileNameEditing(modelWithFullData.file_path); @@ -480,4 +480,4 @@ const modelModal = { scrollToTop }; -export { modelModal }; \ No newline at end of file +export { modelModal }; diff --git a/static/js/components/shared/ModelTags.js b/static/js/components/shared/ModelTags.js index fb270884..e1bf19c0 100644 --- a/static/js/components/shared/ModelTags.js +++ b/static/js/components/shared/ModelTags.js @@ -6,38 +6,120 @@ import { showToast } from '../../utils/uiHelpers.js'; import { getModelApiClient } from '../../api/modelApiFactory.js'; import { translate } from '../../utils/i18nHelpers.js'; import { getPriorityTagSuggestions } from '../../utils/priorityTagHelpers.js'; +import { state } from '../../state/index.js'; +const MODEL_TYPE_SUGGESTION_KEY_MAP = { + loras: 'lora', + lora: 'lora', + checkpoints: 'checkpoint', + checkpoint: 'checkpoint', + embeddings: 'embedding', + embedding: 'embedding', +}; + +let activeModelTypeKey = ''; let priorityTagSuggestions = []; let priorityTagSuggestionsLoaded = false; let priorityTagSuggestionsPromise = null; -function ensurePriorityTagSuggestions() { +function normalizeModelTypeKey(modelType) { + if (!modelType) { + return ''; + } + const lower = String(modelType).toLowerCase(); + if (MODEL_TYPE_SUGGESTION_KEY_MAP[lower]) { + return MODEL_TYPE_SUGGESTION_KEY_MAP[lower]; + } + if (lower.endsWith('s')) { + return lower.slice(0, -1); + } + return lower; +} + +function resolveModelTypeKey(modelType = null) { + if (modelType) { + return normalizeModelTypeKey(modelType); + } + if (activeModelTypeKey) { + return activeModelTypeKey; + } + if (state?.currentPageType) { + return normalizeModelTypeKey(state.currentPageType); + } + return ''; +} + +function resetSuggestionState() { + priorityTagSuggestions = []; + priorityTagSuggestionsLoaded = false; + priorityTagSuggestionsPromise = null; +} + +function setActiveModelTypeKey(modelType = null) { + const resolvedKey = resolveModelTypeKey(modelType); + if (resolvedKey === activeModelTypeKey) { + return activeModelTypeKey; + } + activeModelTypeKey = resolvedKey; + resetSuggestionState(); + return activeModelTypeKey; +} + +function ensurePriorityTagSuggestions(modelType = null) { + if (modelType !== null && modelType !== undefined) { + setActiveModelTypeKey(modelType); + } else if (!activeModelTypeKey) { + setActiveModelTypeKey(); + } + + if (!activeModelTypeKey) { + resetSuggestionState(); + priorityTagSuggestionsLoaded = true; + return Promise.resolve([]); + } + + if (priorityTagSuggestionsLoaded && !priorityTagSuggestionsPromise) { + return Promise.resolve(priorityTagSuggestions); + } + if (!priorityTagSuggestionsPromise) { - priorityTagSuggestionsPromise = getPriorityTagSuggestions() + const requestKey = activeModelTypeKey; + priorityTagSuggestionsPromise = getPriorityTagSuggestions(requestKey) .then((tags) => { - priorityTagSuggestions = tags; - priorityTagSuggestionsLoaded = true; + if (activeModelTypeKey === requestKey) { + priorityTagSuggestions = tags; + priorityTagSuggestionsLoaded = true; + } return tags; }) .catch(() => { - priorityTagSuggestions = []; - priorityTagSuggestionsLoaded = true; - return priorityTagSuggestions; + if (activeModelTypeKey === requestKey) { + priorityTagSuggestions = []; + priorityTagSuggestionsLoaded = true; + } + return []; }) .finally(() => { - priorityTagSuggestionsPromise = null; + if (activeModelTypeKey === requestKey) { + priorityTagSuggestionsPromise = null; + } }); } - return priorityTagSuggestionsLoaded && !priorityTagSuggestionsPromise - ? Promise.resolve(priorityTagSuggestions) - : priorityTagSuggestionsPromise; + return priorityTagSuggestionsPromise; } -ensurePriorityTagSuggestions(); +activeModelTypeKey = resolveModelTypeKey(); + +if (activeModelTypeKey) { + ensurePriorityTagSuggestions(); +} window.addEventListener('lm:priority-tags-updated', () => { - priorityTagSuggestionsLoaded = false; + if (!activeModelTypeKey) { + return; + } + resetSuggestionState(); ensurePriorityTagSuggestions().then(() => { document.querySelectorAll('.metadata-edit-container .metadata-suggestions-container').forEach((container) => { renderPriorityTagSuggestions(container, getCurrentEditTags()); @@ -52,9 +134,12 @@ let saveTagsHandler = null; /** * Set up tag editing mode */ -export function setupTagEditMode() { +export function setupTagEditMode(modelType = null) { const editBtn = document.querySelector('.edit-tags-btn'); if (!editBtn) return; + + setActiveModelTypeKey(modelType); + ensurePriorityTagSuggestions(); // Store original tags for restoring on cancel let originalTags = []; @@ -523,4 +608,4 @@ function getCurrentEditTags() { function restoreOriginalTags(section, originalTags) { // Nothing to do here as we're just hiding the edit UI // and showing the original compact tags which weren't modified -} \ No newline at end of file +} diff --git a/static/js/managers/BulkManager.js b/static/js/managers/BulkManager.js index fd0e5449..53de725f 100644 --- a/static/js/managers/BulkManager.js +++ b/static/js/managers/BulkManager.js @@ -66,7 +66,11 @@ export class BulkManager { if (!container) { return; } - getPriorityTagSuggestions().then((tags) => { + const currentType = state.currentPageType; + if (!currentType || currentType === 'recipes') { + return; + } + getPriorityTagSuggestions(currentType).then((tags) => { if (!container.isConnected) { return; } @@ -619,17 +623,22 @@ export class BulkManager { container.className = 'metadata-suggestions-container'; container.innerHTML = `
`; - getPriorityTagSuggestions().then((tags) => { - if (!container.isConnected) { - return; - } - this.renderBulkSuggestionItems(container, tags); - this.updateBulkSuggestionsDropdown(); - }).catch(() => { - if (container.isConnected) { - container.innerHTML = ''; - } - }); + const currentType = state.currentPageType; + if (!currentType || currentType === 'recipes') { + container.innerHTML = ''; + } else { + getPriorityTagSuggestions(currentType).then((tags) => { + if (!container.isConnected) { + return; + } + this.renderBulkSuggestionItems(container, tags); + this.updateBulkSuggestionsDropdown(); + }).catch(() => { + if (container.isConnected) { + container.innerHTML = ''; + } + }); + } dropdown.appendChild(container); return dropdown; diff --git a/static/js/utils/priorityTagHelpers.js b/static/js/utils/priorityTagHelpers.js index 67b4f8dd..cfd1f9c4 100644 --- a/static/js/utils/priorityTagHelpers.js +++ b/static/js/utils/priorityTagHelpers.js @@ -1,5 +1,28 @@ import { DEFAULT_PRIORITY_TAG_CONFIG } from './constants.js'; +const MODEL_TYPE_ALIAS_MAP = { + loras: 'lora', + lora: 'lora', + checkpoints: 'checkpoint', + checkpoint: 'checkpoint', + embeddings: 'embedding', + embedding: 'embedding', +}; + +function normalizeModelTypeKey(modelType) { + if (typeof modelType !== 'string') { + return ''; + } + const lower = modelType.toLowerCase(); + if (MODEL_TYPE_ALIAS_MAP[lower]) { + return MODEL_TYPE_ALIAS_MAP[lower]; + } + if (lower.endsWith('s')) { + return lower.slice(0, -1); + } + return lower; +} + function splitPriorityEntries(raw = '') { const segments = []; raw.split('\n').forEach(line => { @@ -152,7 +175,18 @@ export async function getPriorityTagSuggestionsMap() { if (!Array.isArray(tags)) { return; } - normalized[modelType] = tags.filter(tag => typeof tag === 'string' && tag.trim()); + const key = normalizeModelTypeKey(modelType) || (typeof modelType === 'string' ? modelType.toLowerCase() : ''); + if (!key) { + return; + } + const filtered = tags + .filter((tag) => typeof tag === 'string') + .map((tag) => tag.trim()) + .filter(Boolean); + if (!normalized[key]) { + normalized[key] = []; + } + normalized[key].push(...filtered); }); const withDefaults = applyDefaultPriorityTagFallback(normalized); @@ -172,8 +206,35 @@ export async function getPriorityTagSuggestionsMap() { return fetchPromise; } -export async function getPriorityTagSuggestions() { +export async function getPriorityTagSuggestions(modelType = null) { const map = await getPriorityTagSuggestionsMap(); + + if (modelType) { + const lower = typeof modelType === 'string' ? modelType.toLowerCase() : ''; + const normalizedKey = normalizeModelTypeKey(modelType); + const candidates = []; + if (lower) { + candidates.push(lower); + } + if (normalizedKey && !candidates.includes(normalizedKey)) { + candidates.push(normalizedKey); + } + Object.entries(MODEL_TYPE_ALIAS_MAP).forEach(([alias, target]) => { + if (alias === lower || target === normalizedKey) { + if (!candidates.includes(target)) { + candidates.push(target); + } + } + }); + + for (const key of candidates) { + if (Array.isArray(map[key])) { + return [...map[key]]; + } + } + return []; + } + const unique = new Set(); Object.values(map).forEach((tags) => { tags.forEach((tag) => { @@ -195,7 +256,8 @@ function buildDefaultPriorityTagMap() { const map = {}; Object.entries(DEFAULT_PRIORITY_TAG_CONFIG).forEach(([modelType, configString]) => { const entries = parsePriorityTagString(configString); - map[modelType] = entries.map((entry) => entry.canonical); + const key = normalizeModelTypeKey(modelType) || modelType; + map[key] = entries.map((entry) => entry.canonical); }); return map; } diff --git a/tests/frontend/managers/settingsManager.library.test.js b/tests/frontend/managers/settingsManager.library.test.js index 9b5d6dd5..98fcbf21 100644 --- a/tests/frontend/managers/settingsManager.library.test.js +++ b/tests/frontend/managers/settingsManager.library.test.js @@ -37,6 +37,11 @@ vi.mock('../../../static/js/utils/constants.js', () => ({ DEFAULT_PATH_TEMPLATES: {}, MAPPABLE_BASE_MODELS: [], PATH_TEMPLATE_PLACEHOLDERS: {}, + DEFAULT_PRIORITY_TAG_CONFIG: { + lora: 'character, style', + checkpoint: 'base, guide', + embedding: 'hint', + }, })); vi.mock('../../../static/js/utils/i18nHelpers.js', () => ({ diff --git a/tests/frontend/utils/priorityTagHelpers.test.js b/tests/frontend/utils/priorityTagHelpers.test.js new file mode 100644 index 00000000..6d7d0fb3 --- /dev/null +++ b/tests/frontend/utils/priorityTagHelpers.test.js @@ -0,0 +1,100 @@ +import { describe, it, expect, beforeEach, afterEach, vi } from 'vitest'; +import { DEFAULT_PRIORITY_TAG_CONFIG } from '../../../static/js/utils/constants.js'; + +const MODULE_PATH = '../../../static/js/utils/priorityTagHelpers.js'; + +let originalFetch; +let invalidateCacheFn; + +beforeEach(() => { + originalFetch = global.fetch; + invalidateCacheFn = null; + vi.resetModules(); +}); + +afterEach(() => { + if (invalidateCacheFn) { + invalidateCacheFn(); + invalidateCacheFn = null; + } + + if (originalFetch === undefined) { + delete global.fetch; + } else { + global.fetch = originalFetch; + } + + vi.restoreAllMocks(); +}); + +describe('priorityTagHelpers suggestion handling', () => { + it('returns trimmed, deduplicated suggestions scoped to the requested model type', async () => { + const fetchMock = vi.fn().mockResolvedValue({ + ok: true, + json: async () => ({ + success: true, + tags: { + loras: ['character', 'style ', 'style'], + checkpoints: ['Base ', 'Primary'], + }, + }), + }); + vi.stubGlobal('fetch', fetchMock); + + const module = await import(MODULE_PATH); + invalidateCacheFn = module.invalidatePriorityTagSuggestionsCache; + + const loraTags = await module.getPriorityTagSuggestions('loras'); + expect(loraTags).toEqual(['character', 'style']); + + const checkpointTags = await module.getPriorityTagSuggestions('CHECKPOINT'); + expect(checkpointTags).toEqual(['Base', 'Primary']); + + const aliasTags = await module.getPriorityTagSuggestions('lora'); + expect(aliasTags).toEqual(['character', 'style']); + + const defaultEmbedding = module + .parsePriorityTagString(DEFAULT_PRIORITY_TAG_CONFIG.embedding) + .map((entry) => entry.canonical); + const embeddingTags = await module.getPriorityTagSuggestions('embeddings'); + expect(embeddingTags).toEqual(defaultEmbedding); + + expect(fetchMock).toHaveBeenCalledTimes(1); + }); + + it('returns a unique union of suggestions when no model type is provided', async () => { + const fetchMock = vi.fn().mockResolvedValue({ + ok: true, + json: async () => ({ + success: true, + tags: { + lora: ['primary', 'support'], + checkpoint: ['guide', 'primary'], + embeddings: ['hint'], + }, + }), + }); + vi.stubGlobal('fetch', fetchMock); + + const module = await import(MODULE_PATH); + invalidateCacheFn = module.invalidatePriorityTagSuggestionsCache; + + const suggestions = await module.getPriorityTagSuggestions(); + expect(suggestions).toEqual(['primary', 'support', 'guide', 'hint']); + }); + + it('falls back to default configuration when fetching suggestions fails', async () => { + const fetchMock = vi.fn().mockRejectedValue(new Error('network error')); + vi.stubGlobal('fetch', fetchMock); + + const module = await import(MODULE_PATH); + invalidateCacheFn = module.invalidatePriorityTagSuggestionsCache; + + const expected = module + .parsePriorityTagString(DEFAULT_PRIORITY_TAG_CONFIG.lora) + .map((entry) => entry.canonical); + + const result = await module.getPriorityTagSuggestions('loras'); + expect(result).toEqual(expected); + }); +});