diff --git a/py/routes/checkpoints_routes.py b/py/routes/checkpoints_routes.py index e65bea5b..1d3627a5 100644 --- a/py/routes/checkpoints_routes.py +++ b/py/routes/checkpoints_routes.py @@ -3,12 +3,14 @@ import json import jinja2 from aiohttp import web import logging +import asyncio from ..utils.routes_common import ModelRouteUtils from ..utils.constants import NSFW_LEVELS from ..services.civitai_client import CivitaiClient from ..services.websocket_manager import ws_manager from ..services.checkpoint_scanner import CheckpointScanner +from ..services.download_manager import DownloadManager from ..config import config from ..services.settings_manager import settings from ..utils.utils import fuzzy_match @@ -24,6 +26,8 @@ class CheckpointsRoutes: loader=jinja2.FileSystemLoader(config.templates_path), autoescape=True ) + self.download_manager = DownloadManager() + self._download_lock = asyncio.Lock() def setup_routes(self, app): """Register routes with the aiohttp app""" @@ -34,11 +38,13 @@ class CheckpointsRoutes: app.router.add_get('/api/checkpoints/top-tags', self.get_top_tags) app.router.add_get('/api/checkpoints/scan', self.scan_checkpoints) app.router.add_get('/api/checkpoints/info/{name}', self.get_checkpoint_info) + app.router.add_get('/api/checkpoints/roots', self.get_checkpoint_roots) # Add new routes for model management similar to LoRA routes app.router.add_post('/api/checkpoints/delete', self.delete_model) app.router.add_post('/api/checkpoints/fetch-civitai', self.fetch_civitai) app.router.add_post('/api/checkpoints/replace-preview', self.replace_preview) + app.router.add_post('/api/checkpoints/download', self.download_checkpoint) async def get_checkpoints(self, request): """Get paginated checkpoint data""" @@ -478,3 +484,33 @@ class CheckpointsRoutes: async def replace_preview(self, request: web.Request) -> web.Response: """Handle preview image replacement for checkpoints""" return await ModelRouteUtils.handle_replace_preview(request, self.scanner) + + async def download_checkpoint(self, request: web.Request) -> web.Response: + """Handle checkpoint download request""" + async with self._download_lock: + # Initialize DownloadManager with the file monitor if the scanner has one + if not hasattr(self, 'download_manager') or self.download_manager is None: + file_monitor = getattr(self.scanner, 'file_monitor', None) + self.download_manager = DownloadManager(file_monitor) + + # Use the common download handler with model_type="checkpoint" + return await ModelRouteUtils.handle_download_model( + request=request, + download_manager=self.download_manager, + model_type="checkpoint" + ) + + async def get_checkpoint_roots(self, request): + """Return the checkpoint root directories""" + try: + roots = self.scanner.get_model_roots() + return web.json_response({ + "success": True, + "roots": roots + }) + except Exception as e: + logger.error(f"Error getting checkpoint roots: {e}", exc_info=True) + return web.json_response({ + "success": False, + "error": str(e) + }, status=500) diff --git a/py/services/download_manager.py b/py/services/download_manager.py index 1dc2a945..91786336 100644 --- a/py/services/download_manager.py +++ b/py/services/download_manager.py @@ -4,7 +4,7 @@ import json from typing import Optional, Dict from .civitai_client import CivitaiClient from .file_monitor import LoraFileMonitor -from ..utils.models import LoraMetadata +from ..utils.models import LoraMetadata, CheckpointMetadata from ..utils.constants import CARD_PREVIEW_WIDTH from ..utils.exif_utils import ExifUtils @@ -20,7 +20,22 @@ class DownloadManager: async def download_from_civitai(self, download_url: str = None, model_hash: str = None, model_version_id: str = None, save_dir: str = None, - relative_path: str = '', progress_callback=None) -> Dict: + relative_path: str = '', progress_callback=None, + model_type: str = "lora") -> Dict: + """Download model from Civitai + + Args: + download_url: Direct download URL for the model + model_hash: SHA256 hash of the model + model_version_id: Civitai model version ID + save_dir: Directory to save the model to + relative_path: Relative path within save_dir + progress_callback: Callback function for progress updates + model_type: Type of model ('lora' or 'checkpoint') + + Returns: + Dict with download result + """ try: # Update save directory with relative path if provided if relative_path: @@ -46,7 +61,7 @@ class DownloadManager: if not version_info: return {'success': False, 'error': 'Failed to fetch model metadata'} - # Check if this is an early access LoRA + # Check if this is an early access model if version_info.get('earlyAccessEndsAt'): early_access_date = version_info.get('earlyAccessEndsAt', '') # Convert to a readable date if possible @@ -54,12 +69,12 @@ class DownloadManager: from datetime import datetime date_obj = datetime.fromisoformat(early_access_date.replace('Z', '+00:00')) formatted_date = date_obj.strftime('%Y-%m-%d') - early_access_msg = f"This LoRA requires early access payment (until {formatted_date}). " + early_access_msg = f"This model requires early access payment (until {formatted_date}). " except: - early_access_msg = "This LoRA requires early access payment. " + early_access_msg = "This model requires early access payment. " early_access_msg += "Please ensure you have purchased early access and are logged in to Civitai." - logger.warning(f"Early access LoRA detected: {version_info.get('name', 'Unknown')}") + logger.warning(f"Early access model detected: {version_info.get('name', 'Unknown')}") # We'll still try to download, but log a warning and prepare for potential failure if progress_callback: @@ -69,26 +84,32 @@ class DownloadManager: if progress_callback: await progress_callback(0) - # 2. 获取文件信息 + # 2. Get file information file_info = next((f for f in version_info.get('files', []) if f.get('primary')), None) if not file_info: return {'success': False, 'error': 'No primary file found in metadata'} - # 3. 准备下载 + # 3. Prepare download file_name = file_info['name'] save_path = os.path.join(save_dir, file_name) file_size = file_info.get('sizeKB', 0) * 1024 - # 4. 通知文件监控系统 - 使用规范化路径和文件大小 - self.file_monitor.handler.add_ignore_path( - save_path.replace(os.sep, '/'), - file_size - ) + # 4. Notify file monitor - use normalized path and file size + if self.file_monitor and self.file_monitor.handler: + self.file_monitor.handler.add_ignore_path( + save_path.replace(os.sep, '/'), + file_size + ) - # 5. 准备元数据 - metadata = LoraMetadata.from_civitai_info(version_info, file_info, save_path) + # 5. Prepare metadata based on model type + if model_type == "checkpoint": + metadata = CheckpointMetadata.from_civitai_info(version_info, file_info, save_path) + logger.info(f"Creating CheckpointMetadata for {file_name}") + else: + metadata = LoraMetadata.from_civitai_info(version_info, file_info, save_path) + logger.info(f"Creating LoraMetadata for {file_name}") - # 5.1 获取并更新模型标签和描述信息 + # 5.1 Get and update model tags and description model_id = version_info.get('modelId') if model_id: model_metadata, _ = await self.civitai_client.get_model_metadata(str(model_id)) @@ -98,14 +119,15 @@ class DownloadManager: if model_metadata.get("description"): metadata.modelDescription = model_metadata.get("description", "") - # 6. 开始下载流程 + # 6. Start download process result = await self._execute_download( download_url=file_info.get('downloadUrl', ''), save_dir=save_dir, metadata=metadata, version_info=version_info, relative_path=relative_path, - progress_callback=progress_callback + progress_callback=progress_callback, + model_type=model_type ) return result @@ -119,8 +141,9 @@ class DownloadManager: return {'success': False, 'error': str(e)} async def _execute_download(self, download_url: str, save_dir: str, - metadata: LoraMetadata, version_info: Dict, - relative_path: str, progress_callback=None) -> Dict: + metadata, version_info: Dict, + relative_path: str, progress_callback=None, + model_type: str = "lora") -> Dict: """Execute the actual download process including preview images and model files""" try: save_path = metadata.file_path @@ -201,15 +224,21 @@ class DownloadManager: os.remove(path) return {'success': False, 'error': result} - # 4. 更新文件信息(大小和修改时间) + # 4. Update file information (size and modified time) metadata.update_file_info(save_path) - # 5. 最终更新元数据 + # 5. Final metadata update with open(metadata_path, 'w', encoding='utf-8') as f: json.dump(metadata.to_dict(), f, indent=2, ensure_ascii=False) - # 6. update lora cache - cache = await self.file_monitor.scanner.get_cached_data() + # 6. Update cache based on model type + if model_type == "checkpoint" and hasattr(self.file_monitor, "checkpoint_scanner"): + cache = await self.file_monitor.checkpoint_scanner.get_cached_data() + logger.info(f"Updating checkpoint cache for {save_path}") + else: + cache = await self.file_monitor.scanner.get_cached_data() + logger.info(f"Updating lora cache for {save_path}") + metadata_dict = metadata.to_dict() metadata_dict['folder'] = relative_path cache.raw_data.append(metadata_dict) @@ -218,11 +247,11 @@ class DownloadManager: all_folders.add(relative_path) cache.folders = sorted(list(all_folders), key=lambda x: x.lower()) - # Update the hash index with the new LoRA entry - self.file_monitor.scanner._hash_index.add_entry(metadata_dict['sha256'], metadata_dict['file_path']) - - # Update the hash index with the new LoRA entry - self.file_monitor.scanner._hash_index.add_entry(metadata_dict['sha256'], metadata_dict['file_path']) + # Update the hash index with the new model entry + if model_type == "checkpoint" and hasattr(self.file_monitor, "checkpoint_scanner"): + self.file_monitor.checkpoint_scanner._hash_index.add_entry(metadata_dict['sha256'], metadata_dict['file_path']) + else: + self.file_monitor.scanner._hash_index.add_entry(metadata_dict['sha256'], metadata_dict['file_path']) # Report 100% completion if progress_callback: diff --git a/py/utils/routes_common.py b/py/utils/routes_common.py index 5b68b368..6c0dc8d7 100644 --- a/py/utils/routes_common.py +++ b/py/utils/routes_common.py @@ -9,6 +9,7 @@ from .constants import PREVIEW_EXTENSIONS, CARD_PREVIEW_WIDTH from ..config import config from ..services.civitai_client import CivitaiClient from ..utils.exif_utils import ExifUtils +from ..services.download_manager import DownloadManager logger = logging.getLogger(__name__) @@ -421,4 +422,82 @@ class ModelRouteUtils: except Exception as e: logger.error(f"Error replacing preview: {e}", exc_info=True) - return web.Response(text=str(e), status=500) \ No newline at end of file + return web.Response(text=str(e), status=500) + + @staticmethod + async def handle_download_model(request: web.Request, download_manager: DownloadManager, model_type="lora") -> web.Response: + """Handle model download request + + Args: + request: The aiohttp request + download_manager: Instance of DownloadManager + model_type: Type of model ('lora' or 'checkpoint') + + Returns: + web.Response: The HTTP response + """ + try: + data = await request.json() + + # Create progress callback + async def progress_callback(progress): + from ..services.websocket_manager import ws_manager + await ws_manager.broadcast({ + 'status': 'progress', + 'progress': progress + }) + + # Check which identifier is provided + download_url = data.get('download_url') + model_hash = data.get('model_hash') + model_version_id = data.get('model_version_id') + + # Validate that at least one identifier is provided + if not any([download_url, model_hash, model_version_id]): + return web.Response( + status=400, + text="Missing required parameter: Please provide either 'download_url', 'hash', or 'modelVersionId'" + ) + + # Use the correct root directory based on model type + root_key = 'checkpoint_root' if model_type == 'checkpoint' else 'lora_root' + save_dir = data.get(root_key) + + result = await download_manager.download_from_civitai( + download_url=download_url, + model_hash=model_hash, + model_version_id=model_version_id, + save_dir=save_dir, + relative_path=data.get('relative_path', ''), + progress_callback=progress_callback, + model_type=model_type + ) + + if not result.get('success', False): + error_message = result.get('error', 'Unknown error') + + # Return 401 for early access errors + if 'early access' in error_message.lower(): + logger.warning(f"Early access download failed: {error_message}") + return web.Response( + status=401, # Use 401 status code to match Civitai's response + text=f"Early Access Restriction: {error_message}" + ) + + return web.Response(status=500, text=error_message) + + return web.json_response(result) + + except Exception as e: + error_message = str(e) + + # Check if this might be an early access error + if '401' in error_message: + logger.warning(f"Early access error (401): {error_message}") + return web.Response( + status=401, + text="Early Access Restriction: This model requires purchase. Please buy early access on Civitai.com." + ) + + logger.error(f"Error downloading {model_type}: {error_message}") + return web.Response(status=500, text=error_message) \ No newline at end of file diff --git a/static/js/checkpoints.js b/static/js/checkpoints.js index 72342ed7..2f1d316f 100644 --- a/static/js/checkpoints.js +++ b/static/js/checkpoints.js @@ -3,6 +3,7 @@ import { initializeInfiniteScroll } from './utils/infiniteScroll.js'; import { confirmDelete, closeDeleteModal } from './utils/modalUtils.js'; import { createPageControls } from './components/controls/index.js'; import { loadMoreCheckpoints } from './api/checkpointApi.js'; +import { CheckpointDownloadManager } from './managers/CheckpointDownloadManager.js'; // Initialize the Checkpoints page class CheckpointsPageManager { @@ -10,6 +11,9 @@ class CheckpointsPageManager { // Initialize page controls this.pageControls = createPageControls('checkpoints'); + // Initialize checkpoint download manager + window.checkpointDownloadManager = new CheckpointDownloadManager(); + // Expose only necessary functions to global scope this._exposeRequiredGlobalFunctions(); } diff --git a/static/js/components/controls/CheckpointsControls.js b/static/js/components/controls/CheckpointsControls.js index 44c6104a..8cc323f1 100644 --- a/static/js/components/controls/CheckpointsControls.js +++ b/static/js/components/controls/CheckpointsControls.js @@ -2,6 +2,7 @@ import { PageControls } from './PageControls.js'; import { loadMoreCheckpoints, resetAndReload, refreshCheckpoints, fetchCivitai } from '../../api/checkpointApi.js'; import { showToast } from '../../utils/uiHelpers.js'; +import { CheckpointDownloadManager } from '../../managers/CheckpointDownloadManager.js'; /** * CheckpointsControls class - Extends PageControls for Checkpoint-specific functionality @@ -11,6 +12,9 @@ export class CheckpointsControls extends PageControls { // Initialize with 'checkpoints' page type super('checkpoints'); + // Initialize checkpoint download manager + this.downloadManager = new CheckpointDownloadManager(); + // Register API methods specific to the Checkpoints page this.registerCheckpointsAPI(); } @@ -38,6 +42,11 @@ export class CheckpointsControls extends PageControls { return await fetchCivitai(); }, + // Add show download modal functionality + showDownloadModal: () => { + this.downloadManager.showDownloadModal(); + }, + // No clearCustomFilter implementation is needed for checkpoints // as custom filters are currently only used for LoRAs clearCustomFilter: async () => { diff --git a/static/js/components/controls/PageControls.js b/static/js/components/controls/PageControls.js index 0bc4f64e..77599fca 100644 --- a/static/js/components/controls/PageControls.js +++ b/static/js/components/controls/PageControls.js @@ -103,13 +103,12 @@ export class PageControls { fetchButton.addEventListener('click', () => this.fetchFromCivitai()); } + const downloadButton = document.querySelector('[data-action="download"]'); + if (downloadButton) { + downloadButton.addEventListener('click', () => this.showDownloadModal()); + } + if (this.pageType === 'loras') { - // Download button - LoRAs only - const downloadButton = document.querySelector('[data-action="download"]'); - if (downloadButton) { - downloadButton.addEventListener('click', () => this.showDownloadModal()); - } - // Bulk operations button - LoRAs only const bulkButton = document.querySelector('[data-action="bulk"]'); if (bulkButton) { @@ -349,14 +348,9 @@ export class PageControls { } /** - * Show download modal (LoRAs only) + * Show download modal */ showDownloadModal() { - if (this.pageType !== 'loras' || !this.api) { - console.error('Download modal is only available for LoRAs'); - return; - } - this.api.showDownloadModal(); } diff --git a/static/js/managers/CheckpointDownloadManager.js b/static/js/managers/CheckpointDownloadManager.js new file mode 100644 index 00000000..5dfb235c --- /dev/null +++ b/static/js/managers/CheckpointDownloadManager.js @@ -0,0 +1,423 @@ +import { modalManager } from './ModalManager.js'; +import { showToast } from '../utils/uiHelpers.js'; +import { LoadingManager } from './LoadingManager.js'; +import { state } from '../state/index.js'; +import { resetAndReload } from '../api/checkpointApi.js'; +import { getStorageItem } from '../utils/storageHelpers.js'; + +export class CheckpointDownloadManager { + constructor() { + this.currentVersion = null; + this.versions = []; + this.modelInfo = null; + this.modelVersionId = null; + + this.initialized = false; + this.selectedFolder = ''; + + this.loadingManager = new LoadingManager(); + this.folderClickHandler = null; + this.updateTargetPath = this.updateTargetPath.bind(this); + } + + showDownloadModal() { + console.log('Showing checkpoint download modal...'); + if (!this.initialized) { + const modal = document.getElementById('checkpointDownloadModal'); + if (!modal) { + console.error('Checkpoint download modal element not found'); + return; + } + this.initialized = true; + } + + modalManager.showModal('checkpointDownloadModal', null, () => { + // Cleanup handler when modal closes + this.cleanupFolderBrowser(); + }); + this.resetSteps(); + } + + resetSteps() { + document.querySelectorAll('#checkpointDownloadModal .download-step').forEach(step => step.style.display = 'none'); + document.getElementById('cpUrlStep').style.display = 'block'; + document.getElementById('checkpointUrl').value = ''; + document.getElementById('cpUrlError').textContent = ''; + + // Clear new folder input + const newFolderInput = document.getElementById('cpNewFolder'); + if (newFolderInput) { + newFolderInput.value = ''; + } + + this.currentVersion = null; + this.versions = []; + this.modelInfo = null; + this.modelVersionId = null; + + // Clear selected folder and remove selection from UI + this.selectedFolder = ''; + const folderBrowser = document.getElementById('cpFolderBrowser'); + if (folderBrowser) { + folderBrowser.querySelectorAll('.folder-item').forEach(f => + f.classList.remove('selected')); + } + } + + async validateAndFetchVersions() { + const url = document.getElementById('checkpointUrl').value.trim(); + const errorElement = document.getElementById('cpUrlError'); + + try { + this.loadingManager.showSimpleLoading('Fetching model versions...'); + + const modelId = this.extractModelId(url); + if (!modelId) { + throw new Error('Invalid Civitai URL format'); + } + + const response = await fetch(`/api/civitai/versions/${modelId}`); + if (!response.ok) { + throw new Error('Failed to fetch model versions'); + } + + this.versions = await response.json(); + if (!this.versions.length) { + throw new Error('No versions available for this model'); + } + + // If we have a version ID from URL, pre-select it + if (this.modelVersionId) { + this.currentVersion = this.versions.find(v => v.id.toString() === this.modelVersionId); + } + + this.showVersionStep(); + } catch (error) { + errorElement.textContent = error.message; + } finally { + this.loadingManager.hide(); + } + } + + extractModelId(url) { + const modelMatch = url.match(/civitai\.com\/models\/(\d+)/); + const versionMatch = url.match(/modelVersionId=(\d+)/); + + if (modelMatch) { + this.modelVersionId = versionMatch ? versionMatch[1] : null; + return modelMatch[1]; + } + return null; + } + + showVersionStep() { + document.getElementById('cpUrlStep').style.display = 'none'; + document.getElementById('cpVersionStep').style.display = 'block'; + + const versionList = document.getElementById('cpVersionList'); + versionList.innerHTML = this.versions.map(version => { + const firstImage = version.images?.find(img => !img.url.endsWith('.mp4')); + const thumbnailUrl = firstImage ? firstImage.url : '/loras_static/images/no-preview.png'; + + // Use version-level size or fallback to first file + const fileSize = version.modelSizeKB ? + (version.modelSizeKB / 1024).toFixed(2) : + (version.files[0]?.sizeKB / 1024).toFixed(2); + + // Use version-level existsLocally flag + const existsLocally = version.existsLocally; + const localPath = version.localPath; + + // Check if this is an early access version + const isEarlyAccess = version.availability === 'EarlyAccess'; + + // Create early access badge if needed + let earlyAccessBadge = ''; + if (isEarlyAccess) { + earlyAccessBadge = ` +