feat: add model type context to tag suggestions

- Pass modelType parameter to setupTagEditMode function
- Implement model type aware priority tag suggestions
- Add model type normalization and resolution logic
- Handle suggestion state reset when model type changes
- Maintain backward compatibility with existing functionality

The changes enable context-aware tag suggestions based on model type, improving tag relevance and user experience when editing tags for different model types.
This commit is contained in:
Will Miao
2025-10-11 20:36:38 +08:00
parent b0847f6b87
commit ec9b37eb53
6 changed files with 293 additions and 32 deletions

View File

@@ -236,7 +236,7 @@ export async function showModelModal(model, modelType) {
setupShowcaseScroll(modalId);
setupTabSwitching();
setupTagTooltip();
setupTagEditMode();
setupTagEditMode(modelType);
setupModelNameEditing(modelWithFullData.file_path);
setupBaseModelEditing(modelWithFullData.file_path);
setupFileNameEditing(modelWithFullData.file_path);
@@ -480,4 +480,4 @@ const modelModal = {
scrollToTop
};
export { modelModal };
export { modelModal };

View File

@@ -6,38 +6,120 @@ import { showToast } from '../../utils/uiHelpers.js';
import { getModelApiClient } from '../../api/modelApiFactory.js';
import { translate } from '../../utils/i18nHelpers.js';
import { getPriorityTagSuggestions } from '../../utils/priorityTagHelpers.js';
import { state } from '../../state/index.js';
const MODEL_TYPE_SUGGESTION_KEY_MAP = {
loras: 'lora',
lora: 'lora',
checkpoints: 'checkpoint',
checkpoint: 'checkpoint',
embeddings: 'embedding',
embedding: 'embedding',
};
let activeModelTypeKey = '';
let priorityTagSuggestions = [];
let priorityTagSuggestionsLoaded = false;
let priorityTagSuggestionsPromise = null;
function ensurePriorityTagSuggestions() {
function normalizeModelTypeKey(modelType) {
if (!modelType) {
return '';
}
const lower = String(modelType).toLowerCase();
if (MODEL_TYPE_SUGGESTION_KEY_MAP[lower]) {
return MODEL_TYPE_SUGGESTION_KEY_MAP[lower];
}
if (lower.endsWith('s')) {
return lower.slice(0, -1);
}
return lower;
}
function resolveModelTypeKey(modelType = null) {
if (modelType) {
return normalizeModelTypeKey(modelType);
}
if (activeModelTypeKey) {
return activeModelTypeKey;
}
if (state?.currentPageType) {
return normalizeModelTypeKey(state.currentPageType);
}
return '';
}
function resetSuggestionState() {
priorityTagSuggestions = [];
priorityTagSuggestionsLoaded = false;
priorityTagSuggestionsPromise = null;
}
function setActiveModelTypeKey(modelType = null) {
const resolvedKey = resolveModelTypeKey(modelType);
if (resolvedKey === activeModelTypeKey) {
return activeModelTypeKey;
}
activeModelTypeKey = resolvedKey;
resetSuggestionState();
return activeModelTypeKey;
}
function ensurePriorityTagSuggestions(modelType = null) {
if (modelType !== null && modelType !== undefined) {
setActiveModelTypeKey(modelType);
} else if (!activeModelTypeKey) {
setActiveModelTypeKey();
}
if (!activeModelTypeKey) {
resetSuggestionState();
priorityTagSuggestionsLoaded = true;
return Promise.resolve([]);
}
if (priorityTagSuggestionsLoaded && !priorityTagSuggestionsPromise) {
return Promise.resolve(priorityTagSuggestions);
}
if (!priorityTagSuggestionsPromise) {
priorityTagSuggestionsPromise = getPriorityTagSuggestions()
const requestKey = activeModelTypeKey;
priorityTagSuggestionsPromise = getPriorityTagSuggestions(requestKey)
.then((tags) => {
priorityTagSuggestions = tags;
priorityTagSuggestionsLoaded = true;
if (activeModelTypeKey === requestKey) {
priorityTagSuggestions = tags;
priorityTagSuggestionsLoaded = true;
}
return tags;
})
.catch(() => {
priorityTagSuggestions = [];
priorityTagSuggestionsLoaded = true;
return priorityTagSuggestions;
if (activeModelTypeKey === requestKey) {
priorityTagSuggestions = [];
priorityTagSuggestionsLoaded = true;
}
return [];
})
.finally(() => {
priorityTagSuggestionsPromise = null;
if (activeModelTypeKey === requestKey) {
priorityTagSuggestionsPromise = null;
}
});
}
return priorityTagSuggestionsLoaded && !priorityTagSuggestionsPromise
? Promise.resolve(priorityTagSuggestions)
: priorityTagSuggestionsPromise;
return priorityTagSuggestionsPromise;
}
ensurePriorityTagSuggestions();
activeModelTypeKey = resolveModelTypeKey();
if (activeModelTypeKey) {
ensurePriorityTagSuggestions();
}
window.addEventListener('lm:priority-tags-updated', () => {
priorityTagSuggestionsLoaded = false;
if (!activeModelTypeKey) {
return;
}
resetSuggestionState();
ensurePriorityTagSuggestions().then(() => {
document.querySelectorAll('.metadata-edit-container .metadata-suggestions-container').forEach((container) => {
renderPriorityTagSuggestions(container, getCurrentEditTags());
@@ -52,9 +134,12 @@ let saveTagsHandler = null;
/**
* Set up tag editing mode
*/
export function setupTagEditMode() {
export function setupTagEditMode(modelType = null) {
const editBtn = document.querySelector('.edit-tags-btn');
if (!editBtn) return;
setActiveModelTypeKey(modelType);
ensurePriorityTagSuggestions();
// Store original tags for restoring on cancel
let originalTags = [];
@@ -523,4 +608,4 @@ function getCurrentEditTags() {
function restoreOriginalTags(section, originalTags) {
// Nothing to do here as we're just hiding the edit UI
// and showing the original compact tags which weren't modified
}
}

View File

@@ -66,7 +66,11 @@ export class BulkManager {
if (!container) {
return;
}
getPriorityTagSuggestions().then((tags) => {
const currentType = state.currentPageType;
if (!currentType || currentType === 'recipes') {
return;
}
getPriorityTagSuggestions(currentType).then((tags) => {
if (!container.isConnected) {
return;
}
@@ -619,17 +623,22 @@ export class BulkManager {
container.className = 'metadata-suggestions-container';
container.innerHTML = `<div class="metadata-suggestions-loading">${translate('settings.priorityTags.loadingSuggestions', 'Loading suggestions…')}</div>`;
getPriorityTagSuggestions().then((tags) => {
if (!container.isConnected) {
return;
}
this.renderBulkSuggestionItems(container, tags);
this.updateBulkSuggestionsDropdown();
}).catch(() => {
if (container.isConnected) {
container.innerHTML = '';
}
});
const currentType = state.currentPageType;
if (!currentType || currentType === 'recipes') {
container.innerHTML = '';
} else {
getPriorityTagSuggestions(currentType).then((tags) => {
if (!container.isConnected) {
return;
}
this.renderBulkSuggestionItems(container, tags);
this.updateBulkSuggestionsDropdown();
}).catch(() => {
if (container.isConnected) {
container.innerHTML = '';
}
});
}
dropdown.appendChild(container);
return dropdown;

View File

@@ -1,5 +1,28 @@
import { DEFAULT_PRIORITY_TAG_CONFIG } from './constants.js';
const MODEL_TYPE_ALIAS_MAP = {
loras: 'lora',
lora: 'lora',
checkpoints: 'checkpoint',
checkpoint: 'checkpoint',
embeddings: 'embedding',
embedding: 'embedding',
};
function normalizeModelTypeKey(modelType) {
if (typeof modelType !== 'string') {
return '';
}
const lower = modelType.toLowerCase();
if (MODEL_TYPE_ALIAS_MAP[lower]) {
return MODEL_TYPE_ALIAS_MAP[lower];
}
if (lower.endsWith('s')) {
return lower.slice(0, -1);
}
return lower;
}
function splitPriorityEntries(raw = '') {
const segments = [];
raw.split('\n').forEach(line => {
@@ -152,7 +175,18 @@ export async function getPriorityTagSuggestionsMap() {
if (!Array.isArray(tags)) {
return;
}
normalized[modelType] = tags.filter(tag => typeof tag === 'string' && tag.trim());
const key = normalizeModelTypeKey(modelType) || (typeof modelType === 'string' ? modelType.toLowerCase() : '');
if (!key) {
return;
}
const filtered = tags
.filter((tag) => typeof tag === 'string')
.map((tag) => tag.trim())
.filter(Boolean);
if (!normalized[key]) {
normalized[key] = [];
}
normalized[key].push(...filtered);
});
const withDefaults = applyDefaultPriorityTagFallback(normalized);
@@ -172,8 +206,35 @@ export async function getPriorityTagSuggestionsMap() {
return fetchPromise;
}
export async function getPriorityTagSuggestions() {
export async function getPriorityTagSuggestions(modelType = null) {
const map = await getPriorityTagSuggestionsMap();
if (modelType) {
const lower = typeof modelType === 'string' ? modelType.toLowerCase() : '';
const normalizedKey = normalizeModelTypeKey(modelType);
const candidates = [];
if (lower) {
candidates.push(lower);
}
if (normalizedKey && !candidates.includes(normalizedKey)) {
candidates.push(normalizedKey);
}
Object.entries(MODEL_TYPE_ALIAS_MAP).forEach(([alias, target]) => {
if (alias === lower || target === normalizedKey) {
if (!candidates.includes(target)) {
candidates.push(target);
}
}
});
for (const key of candidates) {
if (Array.isArray(map[key])) {
return [...map[key]];
}
}
return [];
}
const unique = new Set();
Object.values(map).forEach((tags) => {
tags.forEach((tag) => {
@@ -195,7 +256,8 @@ function buildDefaultPriorityTagMap() {
const map = {};
Object.entries(DEFAULT_PRIORITY_TAG_CONFIG).forEach(([modelType, configString]) => {
const entries = parsePriorityTagString(configString);
map[modelType] = entries.map((entry) => entry.canonical);
const key = normalizeModelTypeKey(modelType) || modelType;
map[key] = entries.map((entry) => entry.canonical);
});
return map;
}

View File

@@ -37,6 +37,11 @@ vi.mock('../../../static/js/utils/constants.js', () => ({
DEFAULT_PATH_TEMPLATES: {},
MAPPABLE_BASE_MODELS: [],
PATH_TEMPLATE_PLACEHOLDERS: {},
DEFAULT_PRIORITY_TAG_CONFIG: {
lora: 'character, style',
checkpoint: 'base, guide',
embedding: 'hint',
},
}));
vi.mock('../../../static/js/utils/i18nHelpers.js', () => ({

View File

@@ -0,0 +1,100 @@
import { describe, it, expect, beforeEach, afterEach, vi } from 'vitest';
import { DEFAULT_PRIORITY_TAG_CONFIG } from '../../../static/js/utils/constants.js';
const MODULE_PATH = '../../../static/js/utils/priorityTagHelpers.js';
let originalFetch;
let invalidateCacheFn;
beforeEach(() => {
originalFetch = global.fetch;
invalidateCacheFn = null;
vi.resetModules();
});
afterEach(() => {
if (invalidateCacheFn) {
invalidateCacheFn();
invalidateCacheFn = null;
}
if (originalFetch === undefined) {
delete global.fetch;
} else {
global.fetch = originalFetch;
}
vi.restoreAllMocks();
});
describe('priorityTagHelpers suggestion handling', () => {
it('returns trimmed, deduplicated suggestions scoped to the requested model type', async () => {
const fetchMock = vi.fn().mockResolvedValue({
ok: true,
json: async () => ({
success: true,
tags: {
loras: ['character', 'style ', 'style'],
checkpoints: ['Base ', 'Primary'],
},
}),
});
vi.stubGlobal('fetch', fetchMock);
const module = await import(MODULE_PATH);
invalidateCacheFn = module.invalidatePriorityTagSuggestionsCache;
const loraTags = await module.getPriorityTagSuggestions('loras');
expect(loraTags).toEqual(['character', 'style']);
const checkpointTags = await module.getPriorityTagSuggestions('CHECKPOINT');
expect(checkpointTags).toEqual(['Base', 'Primary']);
const aliasTags = await module.getPriorityTagSuggestions('lora');
expect(aliasTags).toEqual(['character', 'style']);
const defaultEmbedding = module
.parsePriorityTagString(DEFAULT_PRIORITY_TAG_CONFIG.embedding)
.map((entry) => entry.canonical);
const embeddingTags = await module.getPriorityTagSuggestions('embeddings');
expect(embeddingTags).toEqual(defaultEmbedding);
expect(fetchMock).toHaveBeenCalledTimes(1);
});
it('returns a unique union of suggestions when no model type is provided', async () => {
const fetchMock = vi.fn().mockResolvedValue({
ok: true,
json: async () => ({
success: true,
tags: {
lora: ['primary', 'support'],
checkpoint: ['guide', 'primary'],
embeddings: ['hint'],
},
}),
});
vi.stubGlobal('fetch', fetchMock);
const module = await import(MODULE_PATH);
invalidateCacheFn = module.invalidatePriorityTagSuggestionsCache;
const suggestions = await module.getPriorityTagSuggestions();
expect(suggestions).toEqual(['primary', 'support', 'guide', 'hint']);
});
it('falls back to default configuration when fetching suggestions fails', async () => {
const fetchMock = vi.fn().mockRejectedValue(new Error('network error'));
vi.stubGlobal('fetch', fetchMock);
const module = await import(MODULE_PATH);
invalidateCacheFn = module.invalidatePriorityTagSuggestionsCache;
const expected = module
.parsePriorityTagString(DEFAULT_PRIORITY_TAG_CONFIG.lora)
.map((entry) => entry.canonical);
const result = await module.getPriorityTagSuggestions('loras');
expect(result).toEqual(expected);
});
});