diff --git a/py/routes/checkpoints_routes.py b/py/routes/checkpoints_routes.py index e65bea5b..1d3627a5 100644 --- a/py/routes/checkpoints_routes.py +++ b/py/routes/checkpoints_routes.py @@ -3,12 +3,14 @@ import json import jinja2 from aiohttp import web import logging +import asyncio from ..utils.routes_common import ModelRouteUtils from ..utils.constants import NSFW_LEVELS from ..services.civitai_client import CivitaiClient from ..services.websocket_manager import ws_manager from ..services.checkpoint_scanner import CheckpointScanner +from ..services.download_manager import DownloadManager from ..config import config from ..services.settings_manager import settings from ..utils.utils import fuzzy_match @@ -24,6 +26,8 @@ class CheckpointsRoutes: loader=jinja2.FileSystemLoader(config.templates_path), autoescape=True ) + self.download_manager = DownloadManager() + self._download_lock = asyncio.Lock() def setup_routes(self, app): """Register routes with the aiohttp app""" @@ -34,11 +38,13 @@ class CheckpointsRoutes: app.router.add_get('/api/checkpoints/top-tags', self.get_top_tags) app.router.add_get('/api/checkpoints/scan', self.scan_checkpoints) app.router.add_get('/api/checkpoints/info/{name}', self.get_checkpoint_info) + app.router.add_get('/api/checkpoints/roots', self.get_checkpoint_roots) # Add new routes for model management similar to LoRA routes app.router.add_post('/api/checkpoints/delete', self.delete_model) app.router.add_post('/api/checkpoints/fetch-civitai', self.fetch_civitai) app.router.add_post('/api/checkpoints/replace-preview', self.replace_preview) + app.router.add_post('/api/checkpoints/download', self.download_checkpoint) async def get_checkpoints(self, request): """Get paginated checkpoint data""" @@ -478,3 +484,33 @@ class CheckpointsRoutes: async def replace_preview(self, request: web.Request) -> web.Response: """Handle preview image replacement for checkpoints""" return await ModelRouteUtils.handle_replace_preview(request, self.scanner) + + async def download_checkpoint(self, request: web.Request) -> web.Response: + """Handle checkpoint download request""" + async with self._download_lock: + # Initialize DownloadManager with the file monitor if the scanner has one + if not hasattr(self, 'download_manager') or self.download_manager is None: + file_monitor = getattr(self.scanner, 'file_monitor', None) + self.download_manager = DownloadManager(file_monitor) + + # Use the common download handler with model_type="checkpoint" + return await ModelRouteUtils.handle_download_model( + request=request, + download_manager=self.download_manager, + model_type="checkpoint" + ) + + async def get_checkpoint_roots(self, request): + """Return the checkpoint root directories""" + try: + roots = self.scanner.get_model_roots() + return web.json_response({ + "success": True, + "roots": roots + }) + except Exception as e: + logger.error(f"Error getting checkpoint roots: {e}", exc_info=True) + return web.json_response({ + "success": False, + "error": str(e) + }, status=500) diff --git a/py/services/download_manager.py b/py/services/download_manager.py index 1dc2a945..91786336 100644 --- a/py/services/download_manager.py +++ b/py/services/download_manager.py @@ -4,7 +4,7 @@ import json from typing import Optional, Dict from .civitai_client import CivitaiClient from .file_monitor import LoraFileMonitor -from ..utils.models import LoraMetadata +from ..utils.models import LoraMetadata, CheckpointMetadata from ..utils.constants import CARD_PREVIEW_WIDTH from ..utils.exif_utils import ExifUtils @@ -20,7 +20,22 @@ class DownloadManager: async def download_from_civitai(self, download_url: str = None, model_hash: str = None, model_version_id: str = None, save_dir: str = None, - relative_path: str = '', progress_callback=None) -> Dict: + relative_path: str = '', progress_callback=None, + model_type: str = "lora") -> Dict: + """Download model from Civitai + + Args: + download_url: Direct download URL for the model + model_hash: SHA256 hash of the model + model_version_id: Civitai model version ID + save_dir: Directory to save the model to + relative_path: Relative path within save_dir + progress_callback: Callback function for progress updates + model_type: Type of model ('lora' or 'checkpoint') + + Returns: + Dict with download result + """ try: # Update save directory with relative path if provided if relative_path: @@ -46,7 +61,7 @@ class DownloadManager: if not version_info: return {'success': False, 'error': 'Failed to fetch model metadata'} - # Check if this is an early access LoRA + # Check if this is an early access model if version_info.get('earlyAccessEndsAt'): early_access_date = version_info.get('earlyAccessEndsAt', '') # Convert to a readable date if possible @@ -54,12 +69,12 @@ class DownloadManager: from datetime import datetime date_obj = datetime.fromisoformat(early_access_date.replace('Z', '+00:00')) formatted_date = date_obj.strftime('%Y-%m-%d') - early_access_msg = f"This LoRA requires early access payment (until {formatted_date}). " + early_access_msg = f"This model requires early access payment (until {formatted_date}). " except: - early_access_msg = "This LoRA requires early access payment. " + early_access_msg = "This model requires early access payment. " early_access_msg += "Please ensure you have purchased early access and are logged in to Civitai." - logger.warning(f"Early access LoRA detected: {version_info.get('name', 'Unknown')}") + logger.warning(f"Early access model detected: {version_info.get('name', 'Unknown')}") # We'll still try to download, but log a warning and prepare for potential failure if progress_callback: @@ -69,26 +84,32 @@ class DownloadManager: if progress_callback: await progress_callback(0) - # 2. 获取文件信息 + # 2. Get file information file_info = next((f for f in version_info.get('files', []) if f.get('primary')), None) if not file_info: return {'success': False, 'error': 'No primary file found in metadata'} - # 3. 准备下载 + # 3. Prepare download file_name = file_info['name'] save_path = os.path.join(save_dir, file_name) file_size = file_info.get('sizeKB', 0) * 1024 - # 4. 通知文件监控系统 - 使用规范化路径和文件大小 - self.file_monitor.handler.add_ignore_path( - save_path.replace(os.sep, '/'), - file_size - ) + # 4. Notify file monitor - use normalized path and file size + if self.file_monitor and self.file_monitor.handler: + self.file_monitor.handler.add_ignore_path( + save_path.replace(os.sep, '/'), + file_size + ) - # 5. 准备元数据 - metadata = LoraMetadata.from_civitai_info(version_info, file_info, save_path) + # 5. Prepare metadata based on model type + if model_type == "checkpoint": + metadata = CheckpointMetadata.from_civitai_info(version_info, file_info, save_path) + logger.info(f"Creating CheckpointMetadata for {file_name}") + else: + metadata = LoraMetadata.from_civitai_info(version_info, file_info, save_path) + logger.info(f"Creating LoraMetadata for {file_name}") - # 5.1 获取并更新模型标签和描述信息 + # 5.1 Get and update model tags and description model_id = version_info.get('modelId') if model_id: model_metadata, _ = await self.civitai_client.get_model_metadata(str(model_id)) @@ -98,14 +119,15 @@ class DownloadManager: if model_metadata.get("description"): metadata.modelDescription = model_metadata.get("description", "") - # 6. 开始下载流程 + # 6. Start download process result = await self._execute_download( download_url=file_info.get('downloadUrl', ''), save_dir=save_dir, metadata=metadata, version_info=version_info, relative_path=relative_path, - progress_callback=progress_callback + progress_callback=progress_callback, + model_type=model_type ) return result @@ -119,8 +141,9 @@ class DownloadManager: return {'success': False, 'error': str(e)} async def _execute_download(self, download_url: str, save_dir: str, - metadata: LoraMetadata, version_info: Dict, - relative_path: str, progress_callback=None) -> Dict: + metadata, version_info: Dict, + relative_path: str, progress_callback=None, + model_type: str = "lora") -> Dict: """Execute the actual download process including preview images and model files""" try: save_path = metadata.file_path @@ -201,15 +224,21 @@ class DownloadManager: os.remove(path) return {'success': False, 'error': result} - # 4. 更新文件信息(大小和修改时间) + # 4. Update file information (size and modified time) metadata.update_file_info(save_path) - # 5. 最终更新元数据 + # 5. Final metadata update with open(metadata_path, 'w', encoding='utf-8') as f: json.dump(metadata.to_dict(), f, indent=2, ensure_ascii=False) - # 6. update lora cache - cache = await self.file_monitor.scanner.get_cached_data() + # 6. Update cache based on model type + if model_type == "checkpoint" and hasattr(self.file_monitor, "checkpoint_scanner"): + cache = await self.file_monitor.checkpoint_scanner.get_cached_data() + logger.info(f"Updating checkpoint cache for {save_path}") + else: + cache = await self.file_monitor.scanner.get_cached_data() + logger.info(f"Updating lora cache for {save_path}") + metadata_dict = metadata.to_dict() metadata_dict['folder'] = relative_path cache.raw_data.append(metadata_dict) @@ -218,11 +247,11 @@ class DownloadManager: all_folders.add(relative_path) cache.folders = sorted(list(all_folders), key=lambda x: x.lower()) - # Update the hash index with the new LoRA entry - self.file_monitor.scanner._hash_index.add_entry(metadata_dict['sha256'], metadata_dict['file_path']) - - # Update the hash index with the new LoRA entry - self.file_monitor.scanner._hash_index.add_entry(metadata_dict['sha256'], metadata_dict['file_path']) + # Update the hash index with the new model entry + if model_type == "checkpoint" and hasattr(self.file_monitor, "checkpoint_scanner"): + self.file_monitor.checkpoint_scanner._hash_index.add_entry(metadata_dict['sha256'], metadata_dict['file_path']) + else: + self.file_monitor.scanner._hash_index.add_entry(metadata_dict['sha256'], metadata_dict['file_path']) # Report 100% completion if progress_callback: diff --git a/py/utils/routes_common.py b/py/utils/routes_common.py index 5b68b368..6c0dc8d7 100644 --- a/py/utils/routes_common.py +++ b/py/utils/routes_common.py @@ -9,6 +9,7 @@ from .constants import PREVIEW_EXTENSIONS, CARD_PREVIEW_WIDTH from ..config import config from ..services.civitai_client import CivitaiClient from ..utils.exif_utils import ExifUtils +from ..services.download_manager import DownloadManager logger = logging.getLogger(__name__) @@ -421,4 +422,82 @@ class ModelRouteUtils: except Exception as e: logger.error(f"Error replacing preview: {e}", exc_info=True) - return web.Response(text=str(e), status=500) \ No newline at end of file + return web.Response(text=str(e), status=500) + + @staticmethod + async def handle_download_model(request: web.Request, download_manager: DownloadManager, model_type="lora") -> web.Response: + """Handle model download request + + Args: + request: The aiohttp request + download_manager: Instance of DownloadManager + model_type: Type of model ('lora' or 'checkpoint') + + Returns: + web.Response: The HTTP response + """ + try: + data = await request.json() + + # Create progress callback + async def progress_callback(progress): + from ..services.websocket_manager import ws_manager + await ws_manager.broadcast({ + 'status': 'progress', + 'progress': progress + }) + + # Check which identifier is provided + download_url = data.get('download_url') + model_hash = data.get('model_hash') + model_version_id = data.get('model_version_id') + + # Validate that at least one identifier is provided + if not any([download_url, model_hash, model_version_id]): + return web.Response( + status=400, + text="Missing required parameter: Please provide either 'download_url', 'hash', or 'modelVersionId'" + ) + + # Use the correct root directory based on model type + root_key = 'checkpoint_root' if model_type == 'checkpoint' else 'lora_root' + save_dir = data.get(root_key) + + result = await download_manager.download_from_civitai( + download_url=download_url, + model_hash=model_hash, + model_version_id=model_version_id, + save_dir=save_dir, + relative_path=data.get('relative_path', ''), + progress_callback=progress_callback, + model_type=model_type + ) + + if not result.get('success', False): + error_message = result.get('error', 'Unknown error') + + # Return 401 for early access errors + if 'early access' in error_message.lower(): + logger.warning(f"Early access download failed: {error_message}") + return web.Response( + status=401, # Use 401 status code to match Civitai's response + text=f"Early Access Restriction: {error_message}" + ) + + return web.Response(status=500, text=error_message) + + return web.json_response(result) + + except Exception as e: + error_message = str(e) + + # Check if this might be an early access error + if '401' in error_message: + logger.warning(f"Early access error (401): {error_message}") + return web.Response( + status=401, + text="Early Access Restriction: This model requires purchase. Please buy early access on Civitai.com." + ) + + logger.error(f"Error downloading {model_type}: {error_message}") + return web.Response(status=500, text=error_message) \ No newline at end of file diff --git a/static/js/checkpoints.js b/static/js/checkpoints.js index 72342ed7..2f1d316f 100644 --- a/static/js/checkpoints.js +++ b/static/js/checkpoints.js @@ -3,6 +3,7 @@ import { initializeInfiniteScroll } from './utils/infiniteScroll.js'; import { confirmDelete, closeDeleteModal } from './utils/modalUtils.js'; import { createPageControls } from './components/controls/index.js'; import { loadMoreCheckpoints } from './api/checkpointApi.js'; +import { CheckpointDownloadManager } from './managers/CheckpointDownloadManager.js'; // Initialize the Checkpoints page class CheckpointsPageManager { @@ -10,6 +11,9 @@ class CheckpointsPageManager { // Initialize page controls this.pageControls = createPageControls('checkpoints'); + // Initialize checkpoint download manager + window.checkpointDownloadManager = new CheckpointDownloadManager(); + // Expose only necessary functions to global scope this._exposeRequiredGlobalFunctions(); } diff --git a/static/js/components/controls/CheckpointsControls.js b/static/js/components/controls/CheckpointsControls.js index 44c6104a..8cc323f1 100644 --- a/static/js/components/controls/CheckpointsControls.js +++ b/static/js/components/controls/CheckpointsControls.js @@ -2,6 +2,7 @@ import { PageControls } from './PageControls.js'; import { loadMoreCheckpoints, resetAndReload, refreshCheckpoints, fetchCivitai } from '../../api/checkpointApi.js'; import { showToast } from '../../utils/uiHelpers.js'; +import { CheckpointDownloadManager } from '../../managers/CheckpointDownloadManager.js'; /** * CheckpointsControls class - Extends PageControls for Checkpoint-specific functionality @@ -11,6 +12,9 @@ export class CheckpointsControls extends PageControls { // Initialize with 'checkpoints' page type super('checkpoints'); + // Initialize checkpoint download manager + this.downloadManager = new CheckpointDownloadManager(); + // Register API methods specific to the Checkpoints page this.registerCheckpointsAPI(); } @@ -38,6 +42,11 @@ export class CheckpointsControls extends PageControls { return await fetchCivitai(); }, + // Add show download modal functionality + showDownloadModal: () => { + this.downloadManager.showDownloadModal(); + }, + // No clearCustomFilter implementation is needed for checkpoints // as custom filters are currently only used for LoRAs clearCustomFilter: async () => { diff --git a/static/js/components/controls/PageControls.js b/static/js/components/controls/PageControls.js index 0bc4f64e..77599fca 100644 --- a/static/js/components/controls/PageControls.js +++ b/static/js/components/controls/PageControls.js @@ -103,13 +103,12 @@ export class PageControls { fetchButton.addEventListener('click', () => this.fetchFromCivitai()); } + const downloadButton = document.querySelector('[data-action="download"]'); + if (downloadButton) { + downloadButton.addEventListener('click', () => this.showDownloadModal()); + } + if (this.pageType === 'loras') { - // Download button - LoRAs only - const downloadButton = document.querySelector('[data-action="download"]'); - if (downloadButton) { - downloadButton.addEventListener('click', () => this.showDownloadModal()); - } - // Bulk operations button - LoRAs only const bulkButton = document.querySelector('[data-action="bulk"]'); if (bulkButton) { @@ -349,14 +348,9 @@ export class PageControls { } /** - * Show download modal (LoRAs only) + * Show download modal */ showDownloadModal() { - if (this.pageType !== 'loras' || !this.api) { - console.error('Download modal is only available for LoRAs'); - return; - } - this.api.showDownloadModal(); } diff --git a/static/js/managers/CheckpointDownloadManager.js b/static/js/managers/CheckpointDownloadManager.js new file mode 100644 index 00000000..5dfb235c --- /dev/null +++ b/static/js/managers/CheckpointDownloadManager.js @@ -0,0 +1,423 @@ +import { modalManager } from './ModalManager.js'; +import { showToast } from '../utils/uiHelpers.js'; +import { LoadingManager } from './LoadingManager.js'; +import { state } from '../state/index.js'; +import { resetAndReload } from '../api/checkpointApi.js'; +import { getStorageItem } from '../utils/storageHelpers.js'; + +export class CheckpointDownloadManager { + constructor() { + this.currentVersion = null; + this.versions = []; + this.modelInfo = null; + this.modelVersionId = null; + + this.initialized = false; + this.selectedFolder = ''; + + this.loadingManager = new LoadingManager(); + this.folderClickHandler = null; + this.updateTargetPath = this.updateTargetPath.bind(this); + } + + showDownloadModal() { + console.log('Showing checkpoint download modal...'); + if (!this.initialized) { + const modal = document.getElementById('checkpointDownloadModal'); + if (!modal) { + console.error('Checkpoint download modal element not found'); + return; + } + this.initialized = true; + } + + modalManager.showModal('checkpointDownloadModal', null, () => { + // Cleanup handler when modal closes + this.cleanupFolderBrowser(); + }); + this.resetSteps(); + } + + resetSteps() { + document.querySelectorAll('#checkpointDownloadModal .download-step').forEach(step => step.style.display = 'none'); + document.getElementById('cpUrlStep').style.display = 'block'; + document.getElementById('checkpointUrl').value = ''; + document.getElementById('cpUrlError').textContent = ''; + + // Clear new folder input + const newFolderInput = document.getElementById('cpNewFolder'); + if (newFolderInput) { + newFolderInput.value = ''; + } + + this.currentVersion = null; + this.versions = []; + this.modelInfo = null; + this.modelVersionId = null; + + // Clear selected folder and remove selection from UI + this.selectedFolder = ''; + const folderBrowser = document.getElementById('cpFolderBrowser'); + if (folderBrowser) { + folderBrowser.querySelectorAll('.folder-item').forEach(f => + f.classList.remove('selected')); + } + } + + async validateAndFetchVersions() { + const url = document.getElementById('checkpointUrl').value.trim(); + const errorElement = document.getElementById('cpUrlError'); + + try { + this.loadingManager.showSimpleLoading('Fetching model versions...'); + + const modelId = this.extractModelId(url); + if (!modelId) { + throw new Error('Invalid Civitai URL format'); + } + + const response = await fetch(`/api/civitai/versions/${modelId}`); + if (!response.ok) { + throw new Error('Failed to fetch model versions'); + } + + this.versions = await response.json(); + if (!this.versions.length) { + throw new Error('No versions available for this model'); + } + + // If we have a version ID from URL, pre-select it + if (this.modelVersionId) { + this.currentVersion = this.versions.find(v => v.id.toString() === this.modelVersionId); + } + + this.showVersionStep(); + } catch (error) { + errorElement.textContent = error.message; + } finally { + this.loadingManager.hide(); + } + } + + extractModelId(url) { + const modelMatch = url.match(/civitai\.com\/models\/(\d+)/); + const versionMatch = url.match(/modelVersionId=(\d+)/); + + if (modelMatch) { + this.modelVersionId = versionMatch ? versionMatch[1] : null; + return modelMatch[1]; + } + return null; + } + + showVersionStep() { + document.getElementById('cpUrlStep').style.display = 'none'; + document.getElementById('cpVersionStep').style.display = 'block'; + + const versionList = document.getElementById('cpVersionList'); + versionList.innerHTML = this.versions.map(version => { + const firstImage = version.images?.find(img => !img.url.endsWith('.mp4')); + const thumbnailUrl = firstImage ? firstImage.url : '/loras_static/images/no-preview.png'; + + // Use version-level size or fallback to first file + const fileSize = version.modelSizeKB ? + (version.modelSizeKB / 1024).toFixed(2) : + (version.files[0]?.sizeKB / 1024).toFixed(2); + + // Use version-level existsLocally flag + const existsLocally = version.existsLocally; + const localPath = version.localPath; + + // Check if this is an early access version + const isEarlyAccess = version.availability === 'EarlyAccess'; + + // Create early access badge if needed + let earlyAccessBadge = ''; + if (isEarlyAccess) { + earlyAccessBadge = ` +
+ Early Access +
+ `; + } + + // Status badge for local models + const localStatus = existsLocally ? + `
+ In Library +
${localPath || ''}
+
` : ''; + + return ` +
+
+ Version preview +
+
+
+

${version.name}

+ ${localStatus} +
+
+ ${version.baseModel ? `
${version.baseModel}
` : ''} + ${earlyAccessBadge} +
+
+ ${new Date(version.createdAt).toLocaleDateString()} + ${fileSize} MB +
+
+
+ `; + }).join(''); + + // Update Next button state based on initial selection + this.updateNextButtonState(); + } + + selectVersion(versionId) { + this.currentVersion = this.versions.find(v => v.id.toString() === versionId.toString()); + if (!this.currentVersion) return; + + document.querySelectorAll('#cpVersionList .version-item').forEach(item => { + item.classList.toggle('selected', item.querySelector('h3').textContent === this.currentVersion.name); + }); + + // Update Next button state after selection + this.updateNextButtonState(); + } + + updateNextButtonState() { + const nextButton = document.querySelector('#cpVersionStep .primary-btn'); + if (!nextButton) return; + + const existsLocally = this.currentVersion?.existsLocally; + + if (existsLocally) { + nextButton.disabled = true; + nextButton.classList.add('disabled'); + nextButton.textContent = 'Already in Library'; + } else { + nextButton.disabled = false; + nextButton.classList.remove('disabled'); + nextButton.textContent = 'Next'; + } + } + + async proceedToLocation() { + if (!this.currentVersion) { + showToast('Please select a version', 'error'); + return; + } + + // Double-check if the version exists locally + const existsLocally = this.currentVersion.existsLocally; + if (existsLocally) { + showToast('This version already exists in your library', 'info'); + return; + } + + document.getElementById('cpVersionStep').style.display = 'none'; + document.getElementById('cpLocationStep').style.display = 'block'; + + try { + // Use checkpoint roots endpoint instead of lora roots + const response = await fetch('/api/checkpoints/roots'); + if (!response.ok) { + throw new Error('Failed to fetch checkpoint roots'); + } + + const data = await response.json(); + const checkpointRoot = document.getElementById('checkpointRoot'); + checkpointRoot.innerHTML = data.roots.map(root => + `` + ).join(''); + + // Set default checkpoint root if available + const defaultRoot = getStorageItem('settings', {}).default_checkpoints_root; + if (defaultRoot && data.roots.includes(defaultRoot)) { + checkpointRoot.value = defaultRoot; + } + + // Initialize folder browser after loading roots + this.initializeFolderBrowser(); + } catch (error) { + showToast(error.message, 'error'); + } + } + + backToUrl() { + document.getElementById('cpVersionStep').style.display = 'none'; + document.getElementById('cpUrlStep').style.display = 'block'; + } + + backToVersions() { + document.getElementById('cpLocationStep').style.display = 'none'; + document.getElementById('cpVersionStep').style.display = 'block'; + } + + async startDownload() { + const checkpointRoot = document.getElementById('checkpointRoot').value; + const newFolder = document.getElementById('cpNewFolder').value.trim(); + + if (!checkpointRoot) { + showToast('Please select a checkpoint root directory', 'error'); + return; + } + + // Construct relative path + let targetFolder = ''; + if (this.selectedFolder) { + targetFolder = this.selectedFolder; + } + if (newFolder) { + targetFolder = targetFolder ? + `${targetFolder}/${newFolder}` : newFolder; + } + + try { + const downloadUrl = this.currentVersion.downloadUrl; + if (!downloadUrl) { + throw new Error('No download URL available'); + } + + // Show enhanced loading with progress details + const updateProgress = this.loadingManager.showDownloadProgress(1); + updateProgress(0, 0, this.currentVersion.name); + + // Setup WebSocket for progress updates + const wsProtocol = window.location.protocol === 'https:' ? 'wss://' : 'ws://'; + const ws = new WebSocket(`${wsProtocol}${window.location.host}/ws/fetch-progress`); + + ws.onmessage = (event) => { + const data = JSON.parse(event.data); + if (data.status === 'progress') { + // Update progress display with current progress + updateProgress(data.progress, 0, this.currentVersion.name); + + // Add more detailed status messages based on progress + if (data.progress < 3) { + this.loadingManager.setStatus(`Preparing download...`); + } else if (data.progress === 3) { + this.loadingManager.setStatus(`Downloaded preview image`); + } else if (data.progress > 3 && data.progress < 100) { + this.loadingManager.setStatus(`Downloading checkpoint file`); + } else { + this.loadingManager.setStatus(`Finalizing download...`); + } + } + }; + + ws.onerror = (error) => { + console.error('WebSocket error:', error); + // Continue with download even if WebSocket fails + }; + + // Start download using checkpoint download endpoint + const response = await fetch('/api/checkpoints/download', { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({ + download_url: downloadUrl, + checkpoint_root: checkpointRoot, + relative_path: targetFolder + }) + }); + + if (!response.ok) { + throw new Error(await response.text()); + } + + showToast('Download completed successfully', 'success'); + modalManager.closeModal('checkpointDownloadModal'); + + // Update state and trigger reload with folder update + state.activeFolder = targetFolder; + await resetAndReload(true); // Pass true to update folders + + } catch (error) { + showToast(error.message, 'error'); + } finally { + this.loadingManager.hide(); + } + } + + initializeFolderBrowser() { + const folderBrowser = document.getElementById('cpFolderBrowser'); + if (!folderBrowser) return; + + // Cleanup existing handler if any + this.cleanupFolderBrowser(); + + // Create new handler + this.folderClickHandler = (event) => { + const folderItem = event.target.closest('.folder-item'); + if (!folderItem) return; + + if (folderItem.classList.contains('selected')) { + folderItem.classList.remove('selected'); + this.selectedFolder = ''; + } else { + folderBrowser.querySelectorAll('.folder-item').forEach(f => + f.classList.remove('selected')); + folderItem.classList.add('selected'); + this.selectedFolder = folderItem.dataset.folder; + } + + // Update path display after folder selection + this.updateTargetPath(); + }; + + // Add the new handler + folderBrowser.addEventListener('click', this.folderClickHandler); + + // Add event listeners for path updates + const checkpointRoot = document.getElementById('checkpointRoot'); + const newFolder = document.getElementById('cpNewFolder'); + + checkpointRoot.addEventListener('change', this.updateTargetPath); + newFolder.addEventListener('input', this.updateTargetPath); + + // Update initial path + this.updateTargetPath(); + } + + cleanupFolderBrowser() { + if (this.folderClickHandler) { + const folderBrowser = document.getElementById('cpFolderBrowser'); + if (folderBrowser) { + folderBrowser.removeEventListener('click', this.folderClickHandler); + this.folderClickHandler = null; + } + } + + // Remove path update listeners + const checkpointRoot = document.getElementById('checkpointRoot'); + const newFolder = document.getElementById('cpNewFolder'); + + if (checkpointRoot) checkpointRoot.removeEventListener('change', this.updateTargetPath); + if (newFolder) newFolder.removeEventListener('input', this.updateTargetPath); + } + + updateTargetPath() { + const pathDisplay = document.getElementById('cpTargetPathDisplay'); + const checkpointRoot = document.getElementById('checkpointRoot').value; + const newFolder = document.getElementById('cpNewFolder').value.trim(); + + let fullPath = checkpointRoot || 'Select a checkpoint root directory'; + + if (checkpointRoot) { + if (this.selectedFolder) { + fullPath += '/' + this.selectedFolder; + } + if (newFolder) { + fullPath += '/' + newFolder; + } + } + + pathDisplay.innerHTML = `${fullPath}`; + } +} \ No newline at end of file diff --git a/static/js/managers/ModalManager.js b/static/js/managers/ModalManager.js index 5c2a2700..989a2806 100644 --- a/static/js/managers/ModalManager.js +++ b/static/js/managers/ModalManager.js @@ -35,6 +35,19 @@ export class ModalManager { closeOnOutsideClick: true }); } + + // Add checkpointDownloadModal registration + const checkpointDownloadModal = document.getElementById('checkpointDownloadModal'); + if (checkpointDownloadModal) { + this.registerModal('checkpointDownloadModal', { + element: checkpointDownloadModal, + onClose: () => { + this.getModal('checkpointDownloadModal').element.style.display = 'none'; + document.body.classList.remove('modal-open'); + }, + closeOnOutsideClick: true + }); + } const deleteModal = document.getElementById('deleteModal'); if (deleteModal) { diff --git a/templates/components/checkpoint_modals.html b/templates/components/checkpoint_modals.html index af971475..0a0ea127 100644 --- a/templates/components/checkpoint_modals.html +++ b/templates/components/checkpoint_modals.html @@ -32,4 +32,73 @@ + + + + \ No newline at end of file diff --git a/templates/components/lora_modals.html b/templates/components/lora_modals.html index 3bc0ab30..8eeb438b 100644 --- a/templates/components/lora_modals.html +++ b/templates/components/lora_modals.html @@ -4,7 +4,7 @@ - +