From 31223f0526b91a6ff67b99995e1ea62e5aa6ba6a Mon Sep 17 00:00:00 2001 From: Will Miao <13051207myq@gmail.com> Date: Mon, 4 Aug 2025 23:37:27 +0800 Subject: [PATCH] feat: enhance model root fetching and moving functionality across various components --- py/routes/checkpoint_routes.py | 26 +++++- static/js/api/baseModelApi.js | 6 ++ static/js/api/checkpointApi.js | 26 ++++++ .../ContextMenu/CheckpointContextMenu.js | 5 +- .../ContextMenu/EmbeddingContextMenu.js | 5 +- .../components/ContextMenu/LoraContextMenu.js | 1 + static/js/components/shared/ModelCard.js | 18 ++-- static/js/loras.js | 2 - static/js/managers/MoveManager.js | 83 ++++++++++--------- static/js/state/index.js | 40 ++++++++- templates/components/modals/move_modal.html | 6 +- 11 files changed, 158 insertions(+), 60 deletions(-) diff --git a/py/routes/checkpoint_routes.py b/py/routes/checkpoint_routes.py index b8edabac..9b5d20a6 100644 --- a/py/routes/checkpoint_routes.py +++ b/py/routes/checkpoint_routes.py @@ -111,8 +111,30 @@ class CheckpointRoutes(BaseModelRoutes): async def get_checkpoints_roots(self, request: web.Request) -> web.Response: """Return the list of checkpoint roots from config""" - return web.json_response({"checkpoints_roots": config.checkpoints_roots}) + try: + roots = config.checkpoints_roots + return web.json_response({ + "success": True, + "roots": roots + }) + except Exception as e: + logger.error(f"Error getting checkpoint roots: {e}", exc_info=True) + return web.json_response({ + "success": False, + "error": str(e) + }, status=500) async def get_unet_roots(self, request: web.Request) -> web.Response: """Return the list of unet roots from config""" - return web.json_response({"unet_roots": config.unet_roots}) \ No newline at end of file + try: + roots = config.unet_roots + return web.json_response({ + "success": True, + "roots": roots + }) + except Exception as e: + logger.error(f"Error getting unet roots: {e}", exc_info=True) + return web.json_response({ + "success": False, + "error": str(e) + }, status=500) \ No newline at end of file diff --git a/static/js/api/baseModelApi.js b/static/js/api/baseModelApi.js index eb60c2eb..2efdd313 100644 --- a/static/js/api/baseModelApi.js +++ b/static/js/api/baseModelApi.js @@ -555,6 +555,12 @@ export class BaseModelApiClient { async fetchModelRoots() { try { + // For checkpoints, use the specific method that considers modelType + // if (this.modelType === 'checkpoints') { + // const pageState = this.getPageState(); + // return await this.fetchModelRoots(pageState.modelType || 'checkpoint'); + // } + const response = await fetch(this.apiConfig.endpoints.roots); if (!response.ok) { throw new Error(`Failed to fetch ${this.apiConfig.config.displayName} roots`); diff --git a/static/js/api/checkpointApi.js b/static/js/api/checkpointApi.js index 6fdf8b58..ce69e999 100644 --- a/static/js/api/checkpointApi.js +++ b/static/js/api/checkpointApi.js @@ -64,4 +64,30 @@ export class CheckpointApiClient extends BaseModelApiClient { throw error; } } + + /** + * Get appropriate roots based on model type + */ + async fetchModelRoots(modelType = 'checkpoint') { + try { + let response; + if (modelType === 'diffusion_model') { + response = await fetch(this.apiConfig.endpoints.specific.unet_roots, { + method: 'GET' + }); + } else { + response = await fetch(this.apiConfig.endpoints.specific.checkpoints_roots, { + method: 'GET' + }); + } + + if (!response.ok) { + throw new Error(`Failed to fetch ${modelType} roots`); + } + return await response.json(); + } catch (error) { + console.error(`Error fetching ${modelType} roots:`, error); + throw error; + } + } } diff --git a/static/js/components/ContextMenu/CheckpointContextMenu.js b/static/js/components/ContextMenu/CheckpointContextMenu.js index 9f8f6216..748f27aa 100644 --- a/static/js/components/ContextMenu/CheckpointContextMenu.js +++ b/static/js/components/ContextMenu/CheckpointContextMenu.js @@ -1,8 +1,8 @@ import { BaseContextMenu } from './BaseContextMenu.js'; import { ModelContextMenuMixin } from './ModelContextMenuMixin.js'; import { getModelApiClient, resetAndReload } from '../../api/modelApiFactory.js'; -import { showToast } from '../../utils/uiHelpers.js'; import { showDeleteModal, showExcludeModal } from '../../utils/modalUtils.js'; +import { moveManager } from '../../managers/MoveManager.js'; export class CheckpointContextMenu extends BaseContextMenu { constructor() { @@ -54,8 +54,7 @@ export class CheckpointContextMenu extends BaseContextMenu { apiClient.refreshSingleModelMetadata(this.currentCard.dataset.filepath); break; case 'move': - // Move to folder (placeholder) - showToast('Move to folder feature coming soon', 'info'); + moveManager.showMoveModal(this.currentCard.dataset.filepath, this.currentCard.dataset.model_type); break; case 'exclude': showExcludeModal(this.currentCard.dataset.filepath); diff --git a/static/js/components/ContextMenu/EmbeddingContextMenu.js b/static/js/components/ContextMenu/EmbeddingContextMenu.js index a820ea16..0629bff9 100644 --- a/static/js/components/ContextMenu/EmbeddingContextMenu.js +++ b/static/js/components/ContextMenu/EmbeddingContextMenu.js @@ -1,7 +1,7 @@ import { BaseContextMenu } from './BaseContextMenu.js'; import { ModelContextMenuMixin } from './ModelContextMenuMixin.js'; import { getModelApiClient, resetAndReload } from '../../api/modelApiFactory.js'; -import { showToast } from '../../utils/uiHelpers.js'; +import { moveManager } from '../../managers/MoveManager.js'; import { showDeleteModal, showExcludeModal } from '../../utils/modalUtils.js'; export class EmbeddingContextMenu extends BaseContextMenu { @@ -54,8 +54,7 @@ export class EmbeddingContextMenu extends BaseContextMenu { apiClient.refreshSingleModelMetadata(this.currentCard.dataset.filepath); break; case 'move': - // Move to folder (placeholder) - showToast('Move to folder feature coming soon', 'info'); + moveManager.showMoveModal(this.currentCard.dataset.filepath); break; case 'exclude': showExcludeModal(this.currentCard.dataset.filepath); diff --git a/static/js/components/ContextMenu/LoraContextMenu.js b/static/js/components/ContextMenu/LoraContextMenu.js index 83719e58..298f593d 100644 --- a/static/js/components/ContextMenu/LoraContextMenu.js +++ b/static/js/components/ContextMenu/LoraContextMenu.js @@ -3,6 +3,7 @@ import { ModelContextMenuMixin } from './ModelContextMenuMixin.js'; import { getModelApiClient, resetAndReload } from '../../api/modelApiFactory.js'; import { copyToClipboard, sendLoraToWorkflow } from '../../utils/uiHelpers.js'; import { showExcludeModal, showDeleteModal } from '../../utils/modalUtils.js'; +import { moveManager } from '../../managers/MoveManager.js'; export class LoraContextMenu extends BaseContextMenu { constructor() { diff --git a/static/js/components/shared/ModelCard.js b/static/js/components/shared/ModelCard.js index 69d3ea43..e8fb6bec 100644 --- a/static/js/components/shared/ModelCard.js +++ b/static/js/components/shared/ModelCard.js @@ -5,6 +5,7 @@ import { toggleShowcase } from './showcase/ShowcaseView.js'; import { bulkManager } from '../../managers/BulkManager.js'; import { modalManager } from '../../managers/ModalManager.js'; import { NSFW_LEVELS } from '../../utils/constants.js'; +import { MODEL_TYPES } from '../../api/apiConfig.js'; import { getModelApiClient } from '../../api/modelApiFactory.js'; import { showDeleteModal } from '../../utils/modalUtils.js'; @@ -152,7 +153,7 @@ async function toggleFavorite(card) { } function handleSendToWorkflow(card, replaceMode, modelType) { - if (modelType === 'loras') { + if (modelType === MODEL_TYPES.LORA) { const usageTips = JSON.parse(card.dataset.usage_tips || '{}'); const strength = usageTips.strength || 1; const loraSyntax = ``; @@ -164,16 +165,16 @@ function handleSendToWorkflow(card, replaceMode, modelType) { } function handleCopyAction(card, modelType) { - if (modelType === 'loras') { + if (modelType === MODEL_TYPES.LORA) { const usageTips = JSON.parse(card.dataset.usage_tips || '{}'); const strength = usageTips.strength || 1; const loraSyntax = ``; copyToClipboard(loraSyntax, 'LoRA syntax copied to clipboard'); - } else if (modelType === 'checkpoints') { + } else if (modelType === MODEL_TYPES.CHECKPOINT) { // Checkpoint copy functionality - copy checkpoint name const checkpointName = card.dataset.file_name; copyToClipboard(checkpointName, 'Checkpoint name copied'); - } else if (modelType === 'embeddings') { + } else if (modelType === MODEL_TYPES.EMBEDDING) { const embeddingName = card.dataset.file_name; copyToClipboard(embeddingName, 'Embedding name copied'); } @@ -377,10 +378,15 @@ export function createModelCard(model, modelType) { card.dataset.favorite = model.favorite ? 'true' : 'false'; // LoRA specific data - if (modelType === 'loras') { + if (modelType === MODEL_TYPES.LORA) { card.dataset.usage_tips = model.usage_tips; } + // checkpoint specific data + if (modelType === MODEL_TYPES.CHECKPOINT) { + card.dataset.model_type = model.model_type; // checkpoint or diffusion_model + } + // Store metadata if available if (model.civitai) { card.dataset.meta = JSON.stringify(model.civitai || {}); @@ -406,7 +412,7 @@ export function createModelCard(model, modelType) { } // Apply selection state if in bulk mode and this card is in the selected set (LoRA only) - if (modelType === 'loras' && state.bulkMode && state.selectedLoras.has(model.file_path)) { + if (modelType === MODEL_TYPES.LORA && state.bulkMode && state.selectedLoras.has(model.file_path)) { card.classList.add('selected'); } diff --git a/static/js/loras.js b/static/js/loras.js index c26a25c6..43608ca6 100644 --- a/static/js/loras.js +++ b/static/js/loras.js @@ -2,7 +2,6 @@ import { appCore } from './core.js'; import { state } from './state/index.js'; import { updateCardsForBulkMode } from './components/shared/ModelCard.js'; import { bulkManager } from './managers/BulkManager.js'; -import { moveManager } from './managers/MoveManager.js'; import { LoraContextMenu } from './components/ContextMenu/index.js'; import { createPageControls } from './components/controls/index.js'; import { confirmDelete, closeDeleteModal, confirmExclude, closeExcludeModal } from './utils/modalUtils.js'; @@ -33,7 +32,6 @@ class LoraPageManager { window.closeDeleteModal = closeDeleteModal; window.confirmExclude = confirmExclude; window.closeExcludeModal = closeExcludeModal; - window.moveManager = moveManager; // Bulk operations window.toggleBulkMode = () => bulkManager.toggleBulkMode(); diff --git a/static/js/managers/MoveManager.js b/static/js/managers/MoveManager.js index a0c3f5e4..0dc9a1e4 100644 --- a/static/js/managers/MoveManager.js +++ b/static/js/managers/MoveManager.js @@ -9,99 +9,107 @@ class MoveManager { this.currentFilePath = null; this.bulkFilePaths = null; this.modal = document.getElementById('moveModal'); - this.loraRootSelect = document.getElementById('moveLoraRoot'); + this.modelRootSelect = document.getElementById('moveModelRoot'); this.folderBrowser = document.getElementById('moveFolderBrowser'); this.newFolderInput = document.getElementById('moveNewFolder'); this.pathDisplay = document.getElementById('moveTargetPathDisplay'); this.modalTitle = document.getElementById('moveModalTitle'); + this.rootLabel = document.getElementById('moveRootLabel'); this.initializeEventListeners(); } initializeEventListeners() { - // 初始化LoRA根目录选择器 - this.loraRootSelect.addEventListener('change', () => this.updatePathPreview()); + // Initialize model root directory selector + this.modelRootSelect.addEventListener('change', () => this.updatePathPreview()); - // 文件夹选择事件 + // Folder selection event this.folderBrowser.addEventListener('click', (e) => { const folderItem = e.target.closest('.folder-item'); if (!folderItem) return; - // 如果点击已选中的文件夹,则取消选择 + // If clicking already selected folder, deselect it if (folderItem.classList.contains('selected')) { folderItem.classList.remove('selected'); } else { - // 取消其他选中状态 + // Deselect other folders this.folderBrowser.querySelectorAll('.folder-item').forEach(item => { item.classList.remove('selected'); }); - // 设置当前选中状态 + // Select current folder folderItem.classList.add('selected'); } this.updatePathPreview(); }); - // 新文件夹输入事件 + // New folder input event this.newFolderInput.addEventListener('input', () => this.updatePathPreview()); } - async showMoveModal(filePath) { + async showMoveModal(filePath, modelType = null) { // Reset state this.currentFilePath = null; this.bulkFilePaths = null; + const apiClient = getModelApiClient(); + const currentPageType = state.currentPageType; + const modelConfig = apiClient.apiConfig.config; + // Handle bulk mode if (filePath === 'bulk') { - const selectedPaths = Array.from(state.selectedLoras); + const selectedPaths = Array.from(state.selectedModels); if (selectedPaths.length === 0) { - showToast('No LoRAs selected', 'warning'); + showToast('No models selected', 'warning'); return; } this.bulkFilePaths = selectedPaths; - this.modalTitle.textContent = `Move ${selectedPaths.length} LoRAs`; + this.modalTitle.textContent = `Move ${selectedPaths.length} ${modelConfig.displayName}s`; } else { // Single file mode this.currentFilePath = filePath; - this.modalTitle.textContent = "Move Model"; + this.modalTitle.textContent = `Move ${modelConfig.displayName}`; } - // 清除之前的选择 + // Update UI labels based on model type + this.rootLabel.textContent = `Select ${modelConfig.displayName} Root:`; + this.pathDisplay.querySelector('.path-text').textContent = `Select a ${modelConfig.displayName.toLowerCase()} root directory`; + + // Clear previous selections this.folderBrowser.querySelectorAll('.folder-item').forEach(item => { item.classList.remove('selected'); }); this.newFolderInput.value = ''; try { - // Fetch LoRA roots - const rootsResponse = await fetch('/api/loras/roots'); - if (!rootsResponse.ok) { - throw new Error('Failed to fetch LoRA roots'); + // Fetch model roots + let rootsData; + if (modelType) { + // For checkpoints, use the specific API method that considers modelType + rootsData = await apiClient.fetchModelRoots(modelType); + } else { + // For other model types, use the generic method + rootsData = await apiClient.fetchModelRoots(); } - const rootsData = await rootsResponse.json(); if (!rootsData.roots || rootsData.roots.length === 0) { - throw new Error('No LoRA roots found'); + throw new Error(`No ${modelConfig.displayName.toLowerCase()} roots found`); } - // 填充LoRA根目录选择器 - this.loraRootSelect.innerHTML = rootsData.roots.map(root => + // Populate model root selector + this.modelRootSelect.innerHTML = rootsData.roots.map(root => `` ).join(''); - // Set default lora root if available - const defaultRoot = getStorageItem('settings', {}).default_lora_root; + // Set default root if available + const settingsKey = `default_${currentPageType.slice(0, -1)}_root`; // Remove 's' from plural + const defaultRoot = getStorageItem('settings', {})[settingsKey]; if (defaultRoot && rootsData.roots.includes(defaultRoot)) { - this.loraRootSelect.value = defaultRoot; + this.modelRootSelect.value = defaultRoot; } // Fetch folders dynamically - const foldersResponse = await fetch('/api/loras/folders'); - if (!foldersResponse.ok) { - throw new Error('Failed to fetch folders'); - } - - const foldersData = await foldersResponse.json(); + const foldersData = await apiClient.fetchModelFolders(); // Update folder browser with dynamic content this.folderBrowser.innerHTML = foldersData.folders.map(folder => @@ -112,13 +120,13 @@ class MoveManager { modalManager.showModal('moveModal'); } catch (error) { - console.error('Error fetching LoRA roots or folders:', error); + console.error(`Error fetching ${modelConfig.displayName.toLowerCase()} roots or folders:`, error); showToast(error.message, 'error'); } } updatePathPreview() { - const selectedRoot = this.loraRootSelect.value; + const selectedRoot = this.modelRootSelect.value; const selectedFolder = this.folderBrowser.querySelector('.folder-item.selected')?.dataset.folder || ''; const newFolder = this.newFolderInput.value.trim(); @@ -134,7 +142,7 @@ class MoveManager { } async moveModel() { - const selectedRoot = this.loraRootSelect.value; + const selectedRoot = this.modelRootSelect.value; const selectedFolder = this.folderBrowser.querySelector('.folder-item.selected')?.dataset.folder || ''; const newFolder = this.newFolderInput.value.trim(); @@ -191,11 +199,8 @@ class MoveManager { // Refresh folder tags after successful move try { - const foldersResponse = await fetch('/api/loras/folders'); - if (foldersResponse.ok) { - const foldersData = await foldersResponse.json(); - updateFolderTags(foldersData.folders); - } + const foldersData = await apiClient.fetchModelFolders(); + updateFolderTags(foldersData.folders); } catch (error) { console.error('Error refreshing folder tags:', error); } diff --git a/static/js/state/index.js b/static/js/state/index.js index ac77bbc2..e539d41b 100644 --- a/static/js/state/index.js +++ b/static/js/state/index.js @@ -89,6 +89,9 @@ export const state = { baseModel: [], tags: [] }, + modelType: 'checkpoint', // 'checkpoint' or 'diffusion_model' + bulkMode: false, + selectedModels: new Set(), showFavoritesOnly: false, duplicatesMode: false, }, @@ -112,6 +115,8 @@ export const state = { baseModel: [], tags: [] }, + bulkMode: false, + selectedModels: new Set(), showFavoritesOnly: false, duplicatesMode: false, } @@ -154,12 +159,43 @@ export const state = { get filters() { return this.pages[this.currentPageType].filters; }, set filters(value) { this.pages[this.currentPageType].filters = value; }, - get bulkMode() { return this.pages.loras.bulkMode; }, - set bulkMode(value) { this.pages.loras.bulkMode = value; }, + get bulkMode() { + const currentType = this.currentPageType; + if (currentType === MODEL_TYPES.LORA) { + return this.pages.loras.bulkMode; + } else { + return this.pages[currentType].bulkMode; + } + }, + set bulkMode(value) { + const currentType = this.currentPageType; + if (currentType === MODEL_TYPES.LORA) { + this.pages.loras.bulkMode = value; + } else { + this.pages[currentType].bulkMode = value; + } + }, get selectedLoras() { return this.pages.loras.selectedLoras; }, set selectedLoras(value) { this.pages.loras.selectedLoras = value; }, + get selectedModels() { + const currentType = this.currentPageType; + if (currentType === MODEL_TYPES.LORA) { + return this.pages.loras.selectedLoras; + } else { + return this.pages[currentType].selectedModels; + } + }, + set selectedModels(value) { + const currentType = this.currentPageType; + if (currentType === MODEL_TYPES.LORA) { + this.pages.loras.selectedLoras = value; + } else { + this.pages[currentType].selectedModels = value; + } + }, + get loraMetadataCache() { return this.pages.loras.loraMetadataCache; }, set loraMetadataCache(value) { this.pages.loras.loraMetadataCache = value; }, diff --git a/templates/components/modals/move_modal.html b/templates/components/modals/move_modal.html index dffeb46b..b978d740 100644 --- a/templates/components/modals/move_modal.html +++ b/templates/components/modals/move_modal.html @@ -9,13 +9,13 @@
- Select a LoRA root directory + Select a model root directory
- - + +