mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-21 21:22:11 -03:00
Add download lora
This commit is contained in:
@@ -26,13 +26,13 @@ class LoraManager:
|
||||
# Setup feature routes
|
||||
routes = LoraRoutes()
|
||||
|
||||
routes.setup_routes(app)
|
||||
ApiRoutes.setup_routes(app)
|
||||
|
||||
# Setup file monitoring
|
||||
monitor = LoraFileMonitor(routes.scanner, config.loras_roots)
|
||||
monitor.start()
|
||||
|
||||
routes.setup_routes(app)
|
||||
ApiRoutes.setup_routes(app, monitor)
|
||||
|
||||
# Store monitor in app for cleanup
|
||||
app['lora_monitor'] = monitor
|
||||
|
||||
|
||||
@@ -3,26 +3,32 @@ import json
|
||||
import logging
|
||||
from aiohttp import web
|
||||
from typing import Dict, List
|
||||
|
||||
from ..services.file_monitor import LoraFileMonitor
|
||||
from ..services.download_manager import DownloadManager
|
||||
from ..services.civitai_client import CivitaiClient
|
||||
from ..config import config
|
||||
from ..services.lora_scanner import LoraScanner
|
||||
from operator import itemgetter
|
||||
from ..services.websocket_manager import ws_manager
|
||||
from ..services.settings_manager import settings # 添加这行
|
||||
from ..services.settings_manager import settings
|
||||
import asyncio
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class ApiRoutes:
|
||||
"""API route handlers for LoRA management"""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self, file_monitor: LoraFileMonitor):
|
||||
self.scanner = LoraScanner()
|
||||
self.civitai_client = CivitaiClient()
|
||||
self.download_manager = DownloadManager(file_monitor)
|
||||
self._download_lock = asyncio.Lock()
|
||||
|
||||
@classmethod
|
||||
def setup_routes(cls, app: web.Application):
|
||||
def setup_routes(cls, app: web.Application, monitor: LoraFileMonitor):
|
||||
"""Register API routes"""
|
||||
routes = cls()
|
||||
routes = cls(monitor)
|
||||
app.router.add_post('/api/delete_model', routes.delete_model)
|
||||
app.router.add_post('/api/fetch-civitai', routes.fetch_civitai)
|
||||
app.router.add_post('/api/replace_preview', routes.replace_preview)
|
||||
@@ -467,41 +473,18 @@ class ApiRoutes:
|
||||
return web.Response(status=500, text=str(e))
|
||||
|
||||
async def download_lora(self, request: web.Request) -> web.Response:
|
||||
"""Handle LoRA download request"""
|
||||
try:
|
||||
data = await request.json()
|
||||
download_url = data.get('download_url')
|
||||
version_info = data.get('version_info')
|
||||
lora_root = data.get('lora_root')
|
||||
new_folder = data.get('new_folder', '').strip()
|
||||
|
||||
if not download_url or not version_info or not lora_root:
|
||||
return web.Response(status=400, text="Missing required parameters")
|
||||
|
||||
if not os.path.isdir(lora_root):
|
||||
return web.Response(status=400, text="Invalid LoRA root directory")
|
||||
|
||||
# 构建保存路径
|
||||
save_dir = os.path.join(lora_root, new_folder) if new_folder else lora_root
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
|
||||
# 使用提供的下载 URL 和版本信息
|
||||
result = await self.civitai_client.download_model_with_info(
|
||||
download_url=download_url,
|
||||
version_info=version_info,
|
||||
save_dir=save_dir
|
||||
)
|
||||
|
||||
if result.get('success'):
|
||||
# 更新缓存 - 使用正确的扫描方法
|
||||
await self.scanner.scan_directory(save_dir) # Changed from rescan_directory to scan_directory
|
||||
async with self._download_lock:
|
||||
try:
|
||||
data = await request.json()
|
||||
result = await self.download_manager.download_from_civitai(
|
||||
download_url=data.get('download_url'),
|
||||
save_dir=data.get('lora_root'),
|
||||
relative_path=data.get('relative_path')
|
||||
)
|
||||
return web.json_response(result)
|
||||
else:
|
||||
return web.Response(status=500, text=result.get('error', 'Download failed'))
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error downloading LoRA: {e}")
|
||||
return web.Response(status=500, text=str(e))
|
||||
except Exception as e:
|
||||
logger.error(f"Error downloading LoRA: {e}")
|
||||
return web.Response(status=500, text=str(e))
|
||||
|
||||
async def update_settings(self, request: web.Request) -> web.Response:
|
||||
"""Update application settings"""
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from datetime import datetime
|
||||
import aiohttp
|
||||
import os
|
||||
import json
|
||||
@@ -5,6 +6,7 @@ import logging
|
||||
from email.parser import Parser
|
||||
from typing import Optional, Dict, Tuple
|
||||
from urllib.parse import unquote
|
||||
from ..utils.models import LoraMetadata
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -127,50 +129,20 @@ class CivitaiClient:
|
||||
logger.error(f"Error fetching model versions: {e}")
|
||||
return None
|
||||
|
||||
async def download_model_with_info(self, download_url: str, version_info: dict, save_dir: str) -> Dict:
|
||||
"""Download model using provided version info and URL"""
|
||||
async def get_model_version_info(self, version_id: str) -> Optional[Dict]:
|
||||
"""Fetch model version metadata from Civitai"""
|
||||
try:
|
||||
# Generate default filename
|
||||
default_filename = f"lora_{version_info['id']}.safetensors"
|
||||
logger.info(f"Downloading model: {version_info.get('name', 'Unknown')}")
|
||||
session = await self.session
|
||||
url = f"{self.base_url}/model-versions/{version_id}"
|
||||
headers = self._get_request_headers()
|
||||
|
||||
# Download the model file
|
||||
success, result = await self._download_file(download_url, save_dir, default_filename)
|
||||
if not success:
|
||||
return {'success': False, 'error': result}
|
||||
|
||||
save_path = result
|
||||
|
||||
# Create metadata file
|
||||
metadata_path = os.path.splitext(save_path)[0] + '.metadata.json'
|
||||
metadata = {
|
||||
'model_name': version_info.get('name', os.path.basename(save_path)),
|
||||
'civitai': version_info,
|
||||
'preview_url': None,
|
||||
'from_civitai': True
|
||||
}
|
||||
|
||||
# Download preview image if available
|
||||
images = version_info.get('images', [])
|
||||
if images:
|
||||
preview_ext = '.mp4' if images[0].get('type') == 'video' else '.png'
|
||||
preview_path = os.path.splitext(save_path)[0] + '.preview' + preview_ext
|
||||
await self.download_preview_image(images[0]['url'], preview_path)
|
||||
metadata['preview_url'] = preview_path.replace(os.sep, '/')
|
||||
|
||||
# Save metadata
|
||||
with open(metadata_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(metadata, f, indent=2, ensure_ascii=False)
|
||||
|
||||
return {
|
||||
'success': True,
|
||||
'file_path': save_path.replace(os.sep, '/'),
|
||||
'metadata': metadata
|
||||
}
|
||||
|
||||
async with session.get(url, headers=headers) as response:
|
||||
if response.status == 200:
|
||||
return await response.json()
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Error downloading model version: {e}")
|
||||
return {'success': False, 'error': str(e)}
|
||||
logger.error(f"Error fetching model version info: {e}")
|
||||
return None
|
||||
|
||||
async def close(self):
|
||||
"""Close the session if it exists"""
|
||||
|
||||
123
services/download_manager.py
Normal file
123
services/download_manager.py
Normal file
@@ -0,0 +1,123 @@
|
||||
import logging
|
||||
import os
|
||||
import json
|
||||
from typing import Optional, Dict
|
||||
from .civitai_client import CivitaiClient
|
||||
from .file_monitor import LoraFileMonitor
|
||||
from ..utils.models import LoraMetadata
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class DownloadManager:
|
||||
def __init__(self, file_monitor: Optional[LoraFileMonitor] = None):
|
||||
self.civitai_client = CivitaiClient()
|
||||
self.file_monitor = file_monitor
|
||||
|
||||
async def download_from_civitai(self, download_url: str, save_dir: str, relative_path: str = '') -> Dict:
|
||||
try:
|
||||
# Update save directory with relative path if provided
|
||||
if relative_path:
|
||||
save_dir = os.path.join(save_dir, relative_path)
|
||||
# Create directory if it doesn't exist
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
|
||||
# Get version info
|
||||
version_id = download_url.split('/')[-1]
|
||||
version_info = await self.civitai_client.get_model_version_info(version_id)
|
||||
if not version_info:
|
||||
return {'success': False, 'error': 'Failed to fetch model metadata'}
|
||||
|
||||
# 2. 获取文件信息
|
||||
file_info = next((f for f in version_info.get('files', []) if f.get('primary')), None)
|
||||
if not file_info:
|
||||
return {'success': False, 'error': 'No primary file found in metadata'}
|
||||
|
||||
# 3. 准备下载
|
||||
file_name = file_info['name']
|
||||
save_path = os.path.join(save_dir, file_name)
|
||||
file_size = file_info.get('sizeKB', 0) * 1024
|
||||
|
||||
# 4. 通知文件监控系统
|
||||
self.file_monitor.handler.add_ignore_path(
|
||||
save_path.replace(os.sep, '/'),
|
||||
file_size
|
||||
)
|
||||
|
||||
# 5. 准备元数据
|
||||
metadata = LoraMetadata.from_civitai_info(version_info, file_info, save_path)
|
||||
|
||||
# 6. 开始下载流程
|
||||
result = await self._execute_download(
|
||||
download_url=download_url,
|
||||
save_dir=save_dir,
|
||||
metadata=metadata,
|
||||
version_info=version_info,
|
||||
relative_path=relative_path
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in download_from_civitai: {e}", exc_info=True)
|
||||
return {'success': False, 'error': str(e)}
|
||||
|
||||
async def _execute_download(self, download_url: str, save_dir: str,
|
||||
metadata: LoraMetadata, version_info: Dict, relative_path: str) -> Dict:
|
||||
"""执行实际的下载流程,包括预览图和模型文件"""
|
||||
try:
|
||||
save_path = metadata.file_path
|
||||
metadata_path = os.path.splitext(save_path)[0] + '.metadata.json'
|
||||
|
||||
# 2. 下载预览图(如果有)
|
||||
images = version_info.get('images', [])
|
||||
if images:
|
||||
preview_ext = '.mp4' if images[0].get('type') == 'video' else '.png'
|
||||
preview_path = os.path.splitext(save_path)[0] + '.preview' + preview_ext
|
||||
if await self.civitai_client.download_preview_image(images[0]['url'], preview_path):
|
||||
metadata.preview_url = preview_path.replace(os.sep, '/')
|
||||
# 更新元数据中的预览图URL
|
||||
with open(metadata_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(metadata.to_dict(), f, indent=2, ensure_ascii=False)
|
||||
|
||||
# 3. 下载模型文件
|
||||
success, result = await self.civitai_client._download_file(
|
||||
download_url,
|
||||
save_dir,
|
||||
os.path.basename(save_path)
|
||||
)
|
||||
|
||||
if not success:
|
||||
# 下载失败时清理文件
|
||||
for path in [save_path, metadata_path, metadata.preview_url]:
|
||||
if path and os.path.exists(path):
|
||||
os.remove(path)
|
||||
return {'success': False, 'error': result}
|
||||
|
||||
# 4. 更新文件信息(大小和修改时间)
|
||||
metadata.update_file_info(save_path)
|
||||
|
||||
# 5. 最终更新元数据
|
||||
with open(metadata_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(metadata.to_dict(), f, indent=2, ensure_ascii=False)
|
||||
|
||||
# 6. update lora cache
|
||||
cache = await self.file_monitor.scanner.get_cached_data()
|
||||
metadata_dict = metadata.to_dict()
|
||||
metadata_dict['folder'] = relative_path
|
||||
cache.raw_data.append(metadata_dict)
|
||||
await cache.resort()
|
||||
all_folders = set(cache.folders)
|
||||
all_folders.add(relative_path)
|
||||
cache.folders = sorted(list(all_folders))
|
||||
|
||||
return {
|
||||
'success': True
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in _execute_download: {e}", exc_info=True)
|
||||
# 确保清理任何部分下载的文件
|
||||
for path in [save_path, metadata_path]:
|
||||
if path and os.path.exists(path):
|
||||
os.remove(path)
|
||||
return {'success': False, 'error': str(e)}
|
||||
@@ -19,10 +19,41 @@ class LoraFileHandler(FileSystemEventHandler):
|
||||
self.pending_changes = set() # 待处理的变更
|
||||
self.lock = Lock() # 线程安全锁
|
||||
self.update_task = None # 异步更新任务
|
||||
self._ignore_paths = set() # Add ignore paths set
|
||||
self._min_ignore_timeout = 5 # minimum timeout in seconds
|
||||
self._download_speed = 1024 * 1024 # assume 1MB/s as base speed
|
||||
|
||||
def _should_ignore(self, path: str) -> bool:
|
||||
"""Check if path should be ignored"""
|
||||
logger.info(f"Checking ignore for {path}")
|
||||
logger.info(f"Ignore paths: {self._ignore_paths}")
|
||||
return path.replace(os.sep, '/') in self._ignore_paths
|
||||
|
||||
def add_ignore_path(self, path: str, file_size: int = 0):
|
||||
"""Add path to ignore list with dynamic timeout based on file size"""
|
||||
self._ignore_paths.add(path.replace(os.sep, '/'))
|
||||
logger.info(f"Update ignore paths: {self._ignore_paths}")
|
||||
|
||||
# Calculate timeout based on file size, with a minimum value
|
||||
# Assuming average download speed of 1MB/s
|
||||
timeout = max(
|
||||
self._min_ignore_timeout,
|
||||
(file_size / self._download_speed) * 1.5 # Add 50% buffer
|
||||
)
|
||||
|
||||
logger.debug(f"Adding {path} to ignore list for {timeout:.1f} seconds")
|
||||
|
||||
asyncio.get_event_loop().call_later(
|
||||
timeout,
|
||||
self._ignore_paths.discard,
|
||||
path
|
||||
)
|
||||
|
||||
def on_created(self, event):
|
||||
if event.is_directory or not event.src_path.endswith('.safetensors'):
|
||||
return
|
||||
if self._should_ignore(event.src_path):
|
||||
return
|
||||
logger.info(f"LoRA file created: {event.src_path}")
|
||||
self._schedule_update('add', event.src_path)
|
||||
|
||||
@@ -123,4 +154,4 @@ class LoraFileMonitor:
|
||||
def stop(self):
|
||||
"""Stop monitoring"""
|
||||
self.observer.stop()
|
||||
self.observer.join()
|
||||
self.observer.join()
|
||||
@@ -1211,4 +1211,20 @@ body.modal-open {
|
||||
color: var(--text-color);
|
||||
opacity: 0.8;
|
||||
margin-top: 4px;
|
||||
}
|
||||
|
||||
.folder-item {
|
||||
padding: 8px;
|
||||
cursor: pointer;
|
||||
border-radius: var(--border-radius-xs);
|
||||
transition: background-color 0.2s;
|
||||
}
|
||||
|
||||
.folder-item:hover {
|
||||
background: var(--lora-surface);
|
||||
}
|
||||
|
||||
.folder-item.selected {
|
||||
background: oklch(var(--lora-accent) / 0.1);
|
||||
border: 1px solid var(--lora-accent);
|
||||
}
|
||||
@@ -10,6 +10,7 @@ export class DownloadManager {
|
||||
|
||||
// Add initialization check
|
||||
this.initialized = false;
|
||||
this.selectedFolder = '';
|
||||
}
|
||||
|
||||
showDownloadModal() {
|
||||
@@ -127,6 +128,9 @@ export class DownloadManager {
|
||||
loraRoot.innerHTML = data.roots.map(root =>
|
||||
`<option value="${root}">${root}</option>`
|
||||
).join('');
|
||||
|
||||
// Initialize folder browser after loading roots
|
||||
this.initializeFolderBrowser();
|
||||
} catch (error) {
|
||||
showToast(error.message, 'error');
|
||||
}
|
||||
@@ -151,20 +155,33 @@ export class DownloadManager {
|
||||
return;
|
||||
}
|
||||
|
||||
console.log('Selected folder:', this.selectedFolder); // Log selected folder
|
||||
console.log('New folder:', newFolder); // Log new folder
|
||||
|
||||
// Construct relative path
|
||||
let relativePath = '';
|
||||
if (this.selectedFolder) {
|
||||
relativePath = this.selectedFolder;
|
||||
}
|
||||
if (newFolder) {
|
||||
relativePath = relativePath ?
|
||||
`${relativePath}/${newFolder}` : newFolder;
|
||||
}
|
||||
|
||||
try {
|
||||
const downloadUrl = this.currentVersion.downloadUrl;
|
||||
if (!downloadUrl) {
|
||||
throw new Error('No download URL available');
|
||||
}
|
||||
|
||||
// 只传递必要参数
|
||||
const response = await fetch('/api/download-lora', {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({
|
||||
download_url: downloadUrl,
|
||||
version_info: this.currentVersion,
|
||||
lora_root: loraRoot,
|
||||
new_folder: newFolder
|
||||
relative_path: relativePath
|
||||
})
|
||||
});
|
||||
|
||||
@@ -182,4 +199,25 @@ export class DownloadManager {
|
||||
showToast(error.message, 'error');
|
||||
}
|
||||
}
|
||||
|
||||
// Add new method to handle folder selection
|
||||
initializeFolderBrowser() {
|
||||
// Update folder selection handling
|
||||
const folderBrowser = document.getElementById('folderBrowser');
|
||||
if (!folderBrowser) return;
|
||||
|
||||
// Update folder selection event handling
|
||||
folderBrowser.addEventListener('click', (event) => {
|
||||
const folderItem = event.target.closest('.folder-item');
|
||||
if (!folderItem) return;
|
||||
|
||||
// Remove previous selection
|
||||
folderBrowser.querySelectorAll('.folder-item').forEach(f =>
|
||||
f.classList.remove('selected'));
|
||||
|
||||
// Add selection to clicked folder
|
||||
folderItem.classList.add('selected');
|
||||
this.selectedFolder = folderItem.dataset.folder;
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
@@ -53,7 +53,12 @@
|
||||
<div class="input-group">
|
||||
<label>Target Folder:</label>
|
||||
<div class="folder-browser" id="folderBrowser">
|
||||
<!-- Folder structure will be inserted here -->
|
||||
<!-- Folders will be dynamically inserted here -->
|
||||
{% for folder in folders %}
|
||||
<div class="folder-item" data-folder="{{ folder }}">
|
||||
{{ folder }}
|
||||
</div>
|
||||
{% endfor %}
|
||||
</div>
|
||||
</div>
|
||||
<div class="input-group">
|
||||
|
||||
@@ -7,6 +7,7 @@ BASE_MODEL_MAPPING = {
|
||||
"sdxl": "SDXL",
|
||||
"sd-v2": "SD2.0",
|
||||
"flux1": "Flux.1 D",
|
||||
"Illustrious": "IL"
|
||||
}
|
||||
|
||||
def determine_base_model(version_string: Optional[str]) -> str:
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
from dataclasses import dataclass, asdict
|
||||
from typing import Dict, Optional
|
||||
from datetime import datetime
|
||||
import os
|
||||
from .model_utils import determine_base_model
|
||||
|
||||
@dataclass
|
||||
class LoraMetadata:
|
||||
@@ -23,6 +25,25 @@ class LoraMetadata:
|
||||
data_copy = data.copy()
|
||||
return cls(**data_copy)
|
||||
|
||||
@classmethod
|
||||
def from_civitai_info(cls, version_info: Dict, file_info: Dict, save_path: str) -> 'LoraMetadata':
|
||||
"""Create LoraMetadata instance from Civitai version info"""
|
||||
file_name = file_info['name']
|
||||
base_model = determine_base_model(version_info.get('baseModel', ''))
|
||||
|
||||
return cls(
|
||||
file_name=os.path.splitext(file_name)[0],
|
||||
model_name=version_info.get('model').get('name', os.path.splitext(file_name)[0]),
|
||||
file_path=save_path.replace(os.sep, '/'),
|
||||
size=file_info.get('sizeKB', 0) * 1024,
|
||||
modified=datetime.now().timestamp(),
|
||||
sha256=file_info['hashes'].get('SHA256', ''),
|
||||
base_model=base_model,
|
||||
preview_url=None, # Will be updated after preview download
|
||||
from_civitai=True,
|
||||
civitai=version_info
|
||||
)
|
||||
|
||||
def to_dict(self) -> Dict:
|
||||
"""Convert to dictionary for JSON serialization"""
|
||||
return asdict(self)
|
||||
@@ -36,4 +57,10 @@ class LoraMetadata:
|
||||
"""Update Civitai information"""
|
||||
self.civitai = civitai_data
|
||||
|
||||
|
||||
def update_file_info(self, file_path: str) -> None:
|
||||
"""Update metadata with actual file information"""
|
||||
if os.path.exists(file_path):
|
||||
self.size = os.path.getsize(file_path)
|
||||
self.modified = os.path.getmtime(file_path)
|
||||
self.file_path = file_path.replace(os.sep, '/')
|
||||
|
||||
|
||||
Reference in New Issue
Block a user