From 451f77b99bdcafd14292b042b52828496259c712 Mon Sep 17 00:00:00 2001 From: Will Miao <13051207myq@gmail.com> Date: Fri, 14 Feb 2025 10:57:33 +0800 Subject: [PATCH] Add download lora --- lora_manager.py | 6 +- routes/api_routes.py | 59 +++++------- services/civitai_client.py | 54 +++-------- services/download_manager.py | 123 ++++++++++++++++++++++++++ services/file_monitor.py | 33 ++++++- static/css/style.css | 16 ++++ static/js/managers/DownloadManager.js | 42 ++++++++- templates/components/modals.html | 7 +- utils/model_utils.py | 1 + utils/models.py | 29 +++++- 10 files changed, 283 insertions(+), 87 deletions(-) create mode 100644 services/download_manager.py diff --git a/lora_manager.py b/lora_manager.py index 63ebd761..da80ec5b 100644 --- a/lora_manager.py +++ b/lora_manager.py @@ -26,13 +26,13 @@ class LoraManager: # Setup feature routes routes = LoraRoutes() - routes.setup_routes(app) - ApiRoutes.setup_routes(app) - # Setup file monitoring monitor = LoraFileMonitor(routes.scanner, config.loras_roots) monitor.start() + routes.setup_routes(app) + ApiRoutes.setup_routes(app, monitor) + # Store monitor in app for cleanup app['lora_monitor'] = monitor diff --git a/routes/api_routes.py b/routes/api_routes.py index 20e30e41..8f6f455c 100644 --- a/routes/api_routes.py +++ b/routes/api_routes.py @@ -3,26 +3,32 @@ import json import logging from aiohttp import web from typing import Dict, List + +from ..services.file_monitor import LoraFileMonitor +from ..services.download_manager import DownloadManager from ..services.civitai_client import CivitaiClient from ..config import config from ..services.lora_scanner import LoraScanner from operator import itemgetter from ..services.websocket_manager import ws_manager -from ..services.settings_manager import settings # 添加这行 +from ..services.settings_manager import settings +import asyncio logger = logging.getLogger(__name__) class ApiRoutes: """API route handlers for LoRA management""" - def __init__(self): + def __init__(self, file_monitor: LoraFileMonitor): self.scanner = LoraScanner() self.civitai_client = CivitaiClient() + self.download_manager = DownloadManager(file_monitor) + self._download_lock = asyncio.Lock() @classmethod - def setup_routes(cls, app: web.Application): + def setup_routes(cls, app: web.Application, monitor: LoraFileMonitor): """Register API routes""" - routes = cls() + routes = cls(monitor) app.router.add_post('/api/delete_model', routes.delete_model) app.router.add_post('/api/fetch-civitai', routes.fetch_civitai) app.router.add_post('/api/replace_preview', routes.replace_preview) @@ -467,41 +473,18 @@ class ApiRoutes: return web.Response(status=500, text=str(e)) async def download_lora(self, request: web.Request) -> web.Response: - """Handle LoRA download request""" - try: - data = await request.json() - download_url = data.get('download_url') - version_info = data.get('version_info') - lora_root = data.get('lora_root') - new_folder = data.get('new_folder', '').strip() - - if not download_url or not version_info or not lora_root: - return web.Response(status=400, text="Missing required parameters") - - if not os.path.isdir(lora_root): - return web.Response(status=400, text="Invalid LoRA root directory") - - # 构建保存路径 - save_dir = os.path.join(lora_root, new_folder) if new_folder else lora_root - os.makedirs(save_dir, exist_ok=True) - - # 使用提供的下载 URL 和版本信息 - result = await self.civitai_client.download_model_with_info( - download_url=download_url, - version_info=version_info, - save_dir=save_dir - ) - - if result.get('success'): - # 更新缓存 - 使用正确的扫描方法 - await self.scanner.scan_directory(save_dir) # Changed from rescan_directory to scan_directory + async with self._download_lock: + try: + data = await request.json() + result = await self.download_manager.download_from_civitai( + download_url=data.get('download_url'), + save_dir=data.get('lora_root'), + relative_path=data.get('relative_path') + ) return web.json_response(result) - else: - return web.Response(status=500, text=result.get('error', 'Download failed')) - - except Exception as e: - logger.error(f"Error downloading LoRA: {e}") - return web.Response(status=500, text=str(e)) + except Exception as e: + logger.error(f"Error downloading LoRA: {e}") + return web.Response(status=500, text=str(e)) async def update_settings(self, request: web.Request) -> web.Response: """Update application settings""" diff --git a/services/civitai_client.py b/services/civitai_client.py index c1e527d6..92702863 100644 --- a/services/civitai_client.py +++ b/services/civitai_client.py @@ -1,3 +1,4 @@ +from datetime import datetime import aiohttp import os import json @@ -5,6 +6,7 @@ import logging from email.parser import Parser from typing import Optional, Dict, Tuple from urllib.parse import unquote +from ..utils.models import LoraMetadata logger = logging.getLogger(__name__) @@ -127,50 +129,20 @@ class CivitaiClient: logger.error(f"Error fetching model versions: {e}") return None - async def download_model_with_info(self, download_url: str, version_info: dict, save_dir: str) -> Dict: - """Download model using provided version info and URL""" + async def get_model_version_info(self, version_id: str) -> Optional[Dict]: + """Fetch model version metadata from Civitai""" try: - # Generate default filename - default_filename = f"lora_{version_info['id']}.safetensors" - logger.info(f"Downloading model: {version_info.get('name', 'Unknown')}") + session = await self.session + url = f"{self.base_url}/model-versions/{version_id}" + headers = self._get_request_headers() - # Download the model file - success, result = await self._download_file(download_url, save_dir, default_filename) - if not success: - return {'success': False, 'error': result} - - save_path = result - - # Create metadata file - metadata_path = os.path.splitext(save_path)[0] + '.metadata.json' - metadata = { - 'model_name': version_info.get('name', os.path.basename(save_path)), - 'civitai': version_info, - 'preview_url': None, - 'from_civitai': True - } - - # Download preview image if available - images = version_info.get('images', []) - if images: - preview_ext = '.mp4' if images[0].get('type') == 'video' else '.png' - preview_path = os.path.splitext(save_path)[0] + '.preview' + preview_ext - await self.download_preview_image(images[0]['url'], preview_path) - metadata['preview_url'] = preview_path.replace(os.sep, '/') - - # Save metadata - with open(metadata_path, 'w', encoding='utf-8') as f: - json.dump(metadata, f, indent=2, ensure_ascii=False) - - return { - 'success': True, - 'file_path': save_path.replace(os.sep, '/'), - 'metadata': metadata - } - + async with session.get(url, headers=headers) as response: + if response.status == 200: + return await response.json() + return None except Exception as e: - logger.error(f"Error downloading model version: {e}") - return {'success': False, 'error': str(e)} + logger.error(f"Error fetching model version info: {e}") + return None async def close(self): """Close the session if it exists""" diff --git a/services/download_manager.py b/services/download_manager.py new file mode 100644 index 00000000..b7cfd6e7 --- /dev/null +++ b/services/download_manager.py @@ -0,0 +1,123 @@ +import logging +import os +import json +from typing import Optional, Dict +from .civitai_client import CivitaiClient +from .file_monitor import LoraFileMonitor +from ..utils.models import LoraMetadata + +logger = logging.getLogger(__name__) + +class DownloadManager: + def __init__(self, file_monitor: Optional[LoraFileMonitor] = None): + self.civitai_client = CivitaiClient() + self.file_monitor = file_monitor + + async def download_from_civitai(self, download_url: str, save_dir: str, relative_path: str = '') -> Dict: + try: + # Update save directory with relative path if provided + if relative_path: + save_dir = os.path.join(save_dir, relative_path) + # Create directory if it doesn't exist + os.makedirs(save_dir, exist_ok=True) + + # Get version info + version_id = download_url.split('/')[-1] + version_info = await self.civitai_client.get_model_version_info(version_id) + if not version_info: + return {'success': False, 'error': 'Failed to fetch model metadata'} + + # 2. 获取文件信息 + file_info = next((f for f in version_info.get('files', []) if f.get('primary')), None) + if not file_info: + return {'success': False, 'error': 'No primary file found in metadata'} + + # 3. 准备下载 + file_name = file_info['name'] + save_path = os.path.join(save_dir, file_name) + file_size = file_info.get('sizeKB', 0) * 1024 + + # 4. 通知文件监控系统 + self.file_monitor.handler.add_ignore_path( + save_path.replace(os.sep, '/'), + file_size + ) + + # 5. 准备元数据 + metadata = LoraMetadata.from_civitai_info(version_info, file_info, save_path) + + # 6. 开始下载流程 + result = await self._execute_download( + download_url=download_url, + save_dir=save_dir, + metadata=metadata, + version_info=version_info, + relative_path=relative_path + ) + + return result + + except Exception as e: + logger.error(f"Error in download_from_civitai: {e}", exc_info=True) + return {'success': False, 'error': str(e)} + + async def _execute_download(self, download_url: str, save_dir: str, + metadata: LoraMetadata, version_info: Dict, relative_path: str) -> Dict: + """执行实际的下载流程,包括预览图和模型文件""" + try: + save_path = metadata.file_path + metadata_path = os.path.splitext(save_path)[0] + '.metadata.json' + + # 2. 下载预览图(如果有) + images = version_info.get('images', []) + if images: + preview_ext = '.mp4' if images[0].get('type') == 'video' else '.png' + preview_path = os.path.splitext(save_path)[0] + '.preview' + preview_ext + if await self.civitai_client.download_preview_image(images[0]['url'], preview_path): + metadata.preview_url = preview_path.replace(os.sep, '/') + # 更新元数据中的预览图URL + with open(metadata_path, 'w', encoding='utf-8') as f: + json.dump(metadata.to_dict(), f, indent=2, ensure_ascii=False) + + # 3. 下载模型文件 + success, result = await self.civitai_client._download_file( + download_url, + save_dir, + os.path.basename(save_path) + ) + + if not success: + # 下载失败时清理文件 + for path in [save_path, metadata_path, metadata.preview_url]: + if path and os.path.exists(path): + os.remove(path) + return {'success': False, 'error': result} + + # 4. 更新文件信息(大小和修改时间) + metadata.update_file_info(save_path) + + # 5. 最终更新元数据 + with open(metadata_path, 'w', encoding='utf-8') as f: + json.dump(metadata.to_dict(), f, indent=2, ensure_ascii=False) + + # 6. update lora cache + cache = await self.file_monitor.scanner.get_cached_data() + metadata_dict = metadata.to_dict() + metadata_dict['folder'] = relative_path + cache.raw_data.append(metadata_dict) + await cache.resort() + all_folders = set(cache.folders) + all_folders.add(relative_path) + cache.folders = sorted(list(all_folders)) + + return { + 'success': True + } + + except Exception as e: + logger.error(f"Error in _execute_download: {e}", exc_info=True) + # 确保清理任何部分下载的文件 + for path in [save_path, metadata_path]: + if path and os.path.exists(path): + os.remove(path) + return {'success': False, 'error': str(e)} \ No newline at end of file diff --git a/services/file_monitor.py b/services/file_monitor.py index 82437f39..f5045d4e 100644 --- a/services/file_monitor.py +++ b/services/file_monitor.py @@ -19,10 +19,41 @@ class LoraFileHandler(FileSystemEventHandler): self.pending_changes = set() # 待处理的变更 self.lock = Lock() # 线程安全锁 self.update_task = None # 异步更新任务 + self._ignore_paths = set() # Add ignore paths set + self._min_ignore_timeout = 5 # minimum timeout in seconds + self._download_speed = 1024 * 1024 # assume 1MB/s as base speed + + def _should_ignore(self, path: str) -> bool: + """Check if path should be ignored""" + logger.info(f"Checking ignore for {path}") + logger.info(f"Ignore paths: {self._ignore_paths}") + return path.replace(os.sep, '/') in self._ignore_paths + + def add_ignore_path(self, path: str, file_size: int = 0): + """Add path to ignore list with dynamic timeout based on file size""" + self._ignore_paths.add(path.replace(os.sep, '/')) + logger.info(f"Update ignore paths: {self._ignore_paths}") + + # Calculate timeout based on file size, with a minimum value + # Assuming average download speed of 1MB/s + timeout = max( + self._min_ignore_timeout, + (file_size / self._download_speed) * 1.5 # Add 50% buffer + ) + + logger.debug(f"Adding {path} to ignore list for {timeout:.1f} seconds") + + asyncio.get_event_loop().call_later( + timeout, + self._ignore_paths.discard, + path + ) def on_created(self, event): if event.is_directory or not event.src_path.endswith('.safetensors'): return + if self._should_ignore(event.src_path): + return logger.info(f"LoRA file created: {event.src_path}") self._schedule_update('add', event.src_path) @@ -123,4 +154,4 @@ class LoraFileMonitor: def stop(self): """Stop monitoring""" self.observer.stop() - self.observer.join() \ No newline at end of file + self.observer.join() \ No newline at end of file diff --git a/static/css/style.css b/static/css/style.css index cec3bce9..fd28b21c 100644 --- a/static/css/style.css +++ b/static/css/style.css @@ -1211,4 +1211,20 @@ body.modal-open { color: var(--text-color); opacity: 0.8; margin-top: 4px; +} + +.folder-item { + padding: 8px; + cursor: pointer; + border-radius: var(--border-radius-xs); + transition: background-color 0.2s; +} + +.folder-item:hover { + background: var(--lora-surface); +} + +.folder-item.selected { + background: oklch(var(--lora-accent) / 0.1); + border: 1px solid var(--lora-accent); } \ No newline at end of file diff --git a/static/js/managers/DownloadManager.js b/static/js/managers/DownloadManager.js index 32edca79..bdf6d395 100644 --- a/static/js/managers/DownloadManager.js +++ b/static/js/managers/DownloadManager.js @@ -10,6 +10,7 @@ export class DownloadManager { // Add initialization check this.initialized = false; + this.selectedFolder = ''; } showDownloadModal() { @@ -127,6 +128,9 @@ export class DownloadManager { loraRoot.innerHTML = data.roots.map(root => `` ).join(''); + + // Initialize folder browser after loading roots + this.initializeFolderBrowser(); } catch (error) { showToast(error.message, 'error'); } @@ -151,20 +155,33 @@ export class DownloadManager { return; } + console.log('Selected folder:', this.selectedFolder); // Log selected folder + console.log('New folder:', newFolder); // Log new folder + + // Construct relative path + let relativePath = ''; + if (this.selectedFolder) { + relativePath = this.selectedFolder; + } + if (newFolder) { + relativePath = relativePath ? + `${relativePath}/${newFolder}` : newFolder; + } + try { const downloadUrl = this.currentVersion.downloadUrl; if (!downloadUrl) { throw new Error('No download URL available'); } + // 只传递必要参数 const response = await fetch('/api/download-lora', { method: 'POST', headers: { 'Content-Type': 'application/json' }, body: JSON.stringify({ download_url: downloadUrl, - version_info: this.currentVersion, lora_root: loraRoot, - new_folder: newFolder + relative_path: relativePath }) }); @@ -182,4 +199,25 @@ export class DownloadManager { showToast(error.message, 'error'); } } + + // Add new method to handle folder selection + initializeFolderBrowser() { + // Update folder selection handling + const folderBrowser = document.getElementById('folderBrowser'); + if (!folderBrowser) return; + + // Update folder selection event handling + folderBrowser.addEventListener('click', (event) => { + const folderItem = event.target.closest('.folder-item'); + if (!folderItem) return; + + // Remove previous selection + folderBrowser.querySelectorAll('.folder-item').forEach(f => + f.classList.remove('selected')); + + // Add selection to clicked folder + folderItem.classList.add('selected'); + this.selectedFolder = folderItem.dataset.folder; + }); + } } diff --git a/templates/components/modals.html b/templates/components/modals.html index 2ba52c5e..086556d4 100644 --- a/templates/components/modals.html +++ b/templates/components/modals.html @@ -53,7 +53,12 @@
- + + {% for folder in folders %} +
+ {{ folder }} +
+ {% endfor %}
diff --git a/utils/model_utils.py b/utils/model_utils.py index d046d0fc..893a0f5d 100644 --- a/utils/model_utils.py +++ b/utils/model_utils.py @@ -7,6 +7,7 @@ BASE_MODEL_MAPPING = { "sdxl": "SDXL", "sd-v2": "SD2.0", "flux1": "Flux.1 D", + "Illustrious": "IL" } def determine_base_model(version_string: Optional[str]) -> str: diff --git a/utils/models.py b/utils/models.py index 07d3d54c..0b4d0c67 100644 --- a/utils/models.py +++ b/utils/models.py @@ -1,6 +1,8 @@ from dataclasses import dataclass, asdict from typing import Dict, Optional from datetime import datetime +import os +from .model_utils import determine_base_model @dataclass class LoraMetadata: @@ -23,6 +25,25 @@ class LoraMetadata: data_copy = data.copy() return cls(**data_copy) + @classmethod + def from_civitai_info(cls, version_info: Dict, file_info: Dict, save_path: str) -> 'LoraMetadata': + """Create LoraMetadata instance from Civitai version info""" + file_name = file_info['name'] + base_model = determine_base_model(version_info.get('baseModel', '')) + + return cls( + file_name=os.path.splitext(file_name)[0], + model_name=version_info.get('model').get('name', os.path.splitext(file_name)[0]), + file_path=save_path.replace(os.sep, '/'), + size=file_info.get('sizeKB', 0) * 1024, + modified=datetime.now().timestamp(), + sha256=file_info['hashes'].get('SHA256', ''), + base_model=base_model, + preview_url=None, # Will be updated after preview download + from_civitai=True, + civitai=version_info + ) + def to_dict(self) -> Dict: """Convert to dictionary for JSON serialization""" return asdict(self) @@ -36,4 +57,10 @@ class LoraMetadata: """Update Civitai information""" self.civitai = civitai_data - \ No newline at end of file + def update_file_info(self, file_path: str) -> None: + """Update metadata with actual file information""" + if os.path.exists(file_path): + self.size = os.path.getsize(file_path) + self.modified = os.path.getmtime(file_path) + self.file_path = file_path.replace(os.sep, '/') +