Add download lora

This commit is contained in:
Will Miao
2025-02-14 10:57:33 +08:00
parent b7aca9b6fc
commit 451f77b99b
10 changed files with 283 additions and 87 deletions

View File

@@ -26,13 +26,13 @@ class LoraManager:
# Setup feature routes # Setup feature routes
routes = LoraRoutes() routes = LoraRoutes()
routes.setup_routes(app)
ApiRoutes.setup_routes(app)
# Setup file monitoring # Setup file monitoring
monitor = LoraFileMonitor(routes.scanner, config.loras_roots) monitor = LoraFileMonitor(routes.scanner, config.loras_roots)
monitor.start() monitor.start()
routes.setup_routes(app)
ApiRoutes.setup_routes(app, monitor)
# Store monitor in app for cleanup # Store monitor in app for cleanup
app['lora_monitor'] = monitor app['lora_monitor'] = monitor

View File

@@ -3,26 +3,32 @@ import json
import logging import logging
from aiohttp import web from aiohttp import web
from typing import Dict, List from typing import Dict, List
from ..services.file_monitor import LoraFileMonitor
from ..services.download_manager import DownloadManager
from ..services.civitai_client import CivitaiClient from ..services.civitai_client import CivitaiClient
from ..config import config from ..config import config
from ..services.lora_scanner import LoraScanner from ..services.lora_scanner import LoraScanner
from operator import itemgetter from operator import itemgetter
from ..services.websocket_manager import ws_manager from ..services.websocket_manager import ws_manager
from ..services.settings_manager import settings # 添加这行 from ..services.settings_manager import settings
import asyncio
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class ApiRoutes: class ApiRoutes:
"""API route handlers for LoRA management""" """API route handlers for LoRA management"""
def __init__(self): def __init__(self, file_monitor: LoraFileMonitor):
self.scanner = LoraScanner() self.scanner = LoraScanner()
self.civitai_client = CivitaiClient() self.civitai_client = CivitaiClient()
self.download_manager = DownloadManager(file_monitor)
self._download_lock = asyncio.Lock()
@classmethod @classmethod
def setup_routes(cls, app: web.Application): def setup_routes(cls, app: web.Application, monitor: LoraFileMonitor):
"""Register API routes""" """Register API routes"""
routes = cls() routes = cls(monitor)
app.router.add_post('/api/delete_model', routes.delete_model) app.router.add_post('/api/delete_model', routes.delete_model)
app.router.add_post('/api/fetch-civitai', routes.fetch_civitai) app.router.add_post('/api/fetch-civitai', routes.fetch_civitai)
app.router.add_post('/api/replace_preview', routes.replace_preview) app.router.add_post('/api/replace_preview', routes.replace_preview)
@@ -467,41 +473,18 @@ class ApiRoutes:
return web.Response(status=500, text=str(e)) return web.Response(status=500, text=str(e))
async def download_lora(self, request: web.Request) -> web.Response: async def download_lora(self, request: web.Request) -> web.Response:
"""Handle LoRA download request""" async with self._download_lock:
try: try:
data = await request.json() data = await request.json()
download_url = data.get('download_url') result = await self.download_manager.download_from_civitai(
version_info = data.get('version_info') download_url=data.get('download_url'),
lora_root = data.get('lora_root') save_dir=data.get('lora_root'),
new_folder = data.get('new_folder', '').strip() relative_path=data.get('relative_path')
)
if not download_url or not version_info or not lora_root:
return web.Response(status=400, text="Missing required parameters")
if not os.path.isdir(lora_root):
return web.Response(status=400, text="Invalid LoRA root directory")
# 构建保存路径
save_dir = os.path.join(lora_root, new_folder) if new_folder else lora_root
os.makedirs(save_dir, exist_ok=True)
# 使用提供的下载 URL 和版本信息
result = await self.civitai_client.download_model_with_info(
download_url=download_url,
version_info=version_info,
save_dir=save_dir
)
if result.get('success'):
# 更新缓存 - 使用正确的扫描方法
await self.scanner.scan_directory(save_dir) # Changed from rescan_directory to scan_directory
return web.json_response(result) return web.json_response(result)
else: except Exception as e:
return web.Response(status=500, text=result.get('error', 'Download failed')) logger.error(f"Error downloading LoRA: {e}")
return web.Response(status=500, text=str(e))
except Exception as e:
logger.error(f"Error downloading LoRA: {e}")
return web.Response(status=500, text=str(e))
async def update_settings(self, request: web.Request) -> web.Response: async def update_settings(self, request: web.Request) -> web.Response:
"""Update application settings""" """Update application settings"""

View File

@@ -1,3 +1,4 @@
from datetime import datetime
import aiohttp import aiohttp
import os import os
import json import json
@@ -5,6 +6,7 @@ import logging
from email.parser import Parser from email.parser import Parser
from typing import Optional, Dict, Tuple from typing import Optional, Dict, Tuple
from urllib.parse import unquote from urllib.parse import unquote
from ..utils.models import LoraMetadata
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -127,50 +129,20 @@ class CivitaiClient:
logger.error(f"Error fetching model versions: {e}") logger.error(f"Error fetching model versions: {e}")
return None return None
async def download_model_with_info(self, download_url: str, version_info: dict, save_dir: str) -> Dict: async def get_model_version_info(self, version_id: str) -> Optional[Dict]:
"""Download model using provided version info and URL""" """Fetch model version metadata from Civitai"""
try: try:
# Generate default filename session = await self.session
default_filename = f"lora_{version_info['id']}.safetensors" url = f"{self.base_url}/model-versions/{version_id}"
logger.info(f"Downloading model: {version_info.get('name', 'Unknown')}") headers = self._get_request_headers()
# Download the model file
success, result = await self._download_file(download_url, save_dir, default_filename)
if not success:
return {'success': False, 'error': result}
save_path = result
# Create metadata file
metadata_path = os.path.splitext(save_path)[0] + '.metadata.json'
metadata = {
'model_name': version_info.get('name', os.path.basename(save_path)),
'civitai': version_info,
'preview_url': None,
'from_civitai': True
}
# Download preview image if available
images = version_info.get('images', [])
if images:
preview_ext = '.mp4' if images[0].get('type') == 'video' else '.png'
preview_path = os.path.splitext(save_path)[0] + '.preview' + preview_ext
await self.download_preview_image(images[0]['url'], preview_path)
metadata['preview_url'] = preview_path.replace(os.sep, '/')
# Save metadata
with open(metadata_path, 'w', encoding='utf-8') as f:
json.dump(metadata, f, indent=2, ensure_ascii=False)
return {
'success': True,
'file_path': save_path.replace(os.sep, '/'),
'metadata': metadata
}
async with session.get(url, headers=headers) as response:
if response.status == 200:
return await response.json()
return None
except Exception as e: except Exception as e:
logger.error(f"Error downloading model version: {e}") logger.error(f"Error fetching model version info: {e}")
return {'success': False, 'error': str(e)} return None
async def close(self): async def close(self):
"""Close the session if it exists""" """Close the session if it exists"""

View File

@@ -0,0 +1,123 @@
import logging
import os
import json
from typing import Optional, Dict
from .civitai_client import CivitaiClient
from .file_monitor import LoraFileMonitor
from ..utils.models import LoraMetadata
logger = logging.getLogger(__name__)
class DownloadManager:
def __init__(self, file_monitor: Optional[LoraFileMonitor] = None):
self.civitai_client = CivitaiClient()
self.file_monitor = file_monitor
async def download_from_civitai(self, download_url: str, save_dir: str, relative_path: str = '') -> Dict:
try:
# Update save directory with relative path if provided
if relative_path:
save_dir = os.path.join(save_dir, relative_path)
# Create directory if it doesn't exist
os.makedirs(save_dir, exist_ok=True)
# Get version info
version_id = download_url.split('/')[-1]
version_info = await self.civitai_client.get_model_version_info(version_id)
if not version_info:
return {'success': False, 'error': 'Failed to fetch model metadata'}
# 2. 获取文件信息
file_info = next((f for f in version_info.get('files', []) if f.get('primary')), None)
if not file_info:
return {'success': False, 'error': 'No primary file found in metadata'}
# 3. 准备下载
file_name = file_info['name']
save_path = os.path.join(save_dir, file_name)
file_size = file_info.get('sizeKB', 0) * 1024
# 4. 通知文件监控系统
self.file_monitor.handler.add_ignore_path(
save_path.replace(os.sep, '/'),
file_size
)
# 5. 准备元数据
metadata = LoraMetadata.from_civitai_info(version_info, file_info, save_path)
# 6. 开始下载流程
result = await self._execute_download(
download_url=download_url,
save_dir=save_dir,
metadata=metadata,
version_info=version_info,
relative_path=relative_path
)
return result
except Exception as e:
logger.error(f"Error in download_from_civitai: {e}", exc_info=True)
return {'success': False, 'error': str(e)}
async def _execute_download(self, download_url: str, save_dir: str,
metadata: LoraMetadata, version_info: Dict, relative_path: str) -> Dict:
"""执行实际的下载流程,包括预览图和模型文件"""
try:
save_path = metadata.file_path
metadata_path = os.path.splitext(save_path)[0] + '.metadata.json'
# 2. 下载预览图(如果有)
images = version_info.get('images', [])
if images:
preview_ext = '.mp4' if images[0].get('type') == 'video' else '.png'
preview_path = os.path.splitext(save_path)[0] + '.preview' + preview_ext
if await self.civitai_client.download_preview_image(images[0]['url'], preview_path):
metadata.preview_url = preview_path.replace(os.sep, '/')
# 更新元数据中的预览图URL
with open(metadata_path, 'w', encoding='utf-8') as f:
json.dump(metadata.to_dict(), f, indent=2, ensure_ascii=False)
# 3. 下载模型文件
success, result = await self.civitai_client._download_file(
download_url,
save_dir,
os.path.basename(save_path)
)
if not success:
# 下载失败时清理文件
for path in [save_path, metadata_path, metadata.preview_url]:
if path and os.path.exists(path):
os.remove(path)
return {'success': False, 'error': result}
# 4. 更新文件信息(大小和修改时间)
metadata.update_file_info(save_path)
# 5. 最终更新元数据
with open(metadata_path, 'w', encoding='utf-8') as f:
json.dump(metadata.to_dict(), f, indent=2, ensure_ascii=False)
# 6. update lora cache
cache = await self.file_monitor.scanner.get_cached_data()
metadata_dict = metadata.to_dict()
metadata_dict['folder'] = relative_path
cache.raw_data.append(metadata_dict)
await cache.resort()
all_folders = set(cache.folders)
all_folders.add(relative_path)
cache.folders = sorted(list(all_folders))
return {
'success': True
}
except Exception as e:
logger.error(f"Error in _execute_download: {e}", exc_info=True)
# 确保清理任何部分下载的文件
for path in [save_path, metadata_path]:
if path and os.path.exists(path):
os.remove(path)
return {'success': False, 'error': str(e)}

View File

@@ -19,10 +19,41 @@ class LoraFileHandler(FileSystemEventHandler):
self.pending_changes = set() # 待处理的变更 self.pending_changes = set() # 待处理的变更
self.lock = Lock() # 线程安全锁 self.lock = Lock() # 线程安全锁
self.update_task = None # 异步更新任务 self.update_task = None # 异步更新任务
self._ignore_paths = set() # Add ignore paths set
self._min_ignore_timeout = 5 # minimum timeout in seconds
self._download_speed = 1024 * 1024 # assume 1MB/s as base speed
def _should_ignore(self, path: str) -> bool:
"""Check if path should be ignored"""
logger.info(f"Checking ignore for {path}")
logger.info(f"Ignore paths: {self._ignore_paths}")
return path.replace(os.sep, '/') in self._ignore_paths
def add_ignore_path(self, path: str, file_size: int = 0):
"""Add path to ignore list with dynamic timeout based on file size"""
self._ignore_paths.add(path.replace(os.sep, '/'))
logger.info(f"Update ignore paths: {self._ignore_paths}")
# Calculate timeout based on file size, with a minimum value
# Assuming average download speed of 1MB/s
timeout = max(
self._min_ignore_timeout,
(file_size / self._download_speed) * 1.5 # Add 50% buffer
)
logger.debug(f"Adding {path} to ignore list for {timeout:.1f} seconds")
asyncio.get_event_loop().call_later(
timeout,
self._ignore_paths.discard,
path
)
def on_created(self, event): def on_created(self, event):
if event.is_directory or not event.src_path.endswith('.safetensors'): if event.is_directory or not event.src_path.endswith('.safetensors'):
return return
if self._should_ignore(event.src_path):
return
logger.info(f"LoRA file created: {event.src_path}") logger.info(f"LoRA file created: {event.src_path}")
self._schedule_update('add', event.src_path) self._schedule_update('add', event.src_path)

View File

@@ -1212,3 +1212,19 @@ body.modal-open {
opacity: 0.8; opacity: 0.8;
margin-top: 4px; margin-top: 4px;
} }
.folder-item {
padding: 8px;
cursor: pointer;
border-radius: var(--border-radius-xs);
transition: background-color 0.2s;
}
.folder-item:hover {
background: var(--lora-surface);
}
.folder-item.selected {
background: oklch(var(--lora-accent) / 0.1);
border: 1px solid var(--lora-accent);
}

View File

@@ -10,6 +10,7 @@ export class DownloadManager {
// Add initialization check // Add initialization check
this.initialized = false; this.initialized = false;
this.selectedFolder = '';
} }
showDownloadModal() { showDownloadModal() {
@@ -127,6 +128,9 @@ export class DownloadManager {
loraRoot.innerHTML = data.roots.map(root => loraRoot.innerHTML = data.roots.map(root =>
`<option value="${root}">${root}</option>` `<option value="${root}">${root}</option>`
).join(''); ).join('');
// Initialize folder browser after loading roots
this.initializeFolderBrowser();
} catch (error) { } catch (error) {
showToast(error.message, 'error'); showToast(error.message, 'error');
} }
@@ -151,20 +155,33 @@ export class DownloadManager {
return; return;
} }
console.log('Selected folder:', this.selectedFolder); // Log selected folder
console.log('New folder:', newFolder); // Log new folder
// Construct relative path
let relativePath = '';
if (this.selectedFolder) {
relativePath = this.selectedFolder;
}
if (newFolder) {
relativePath = relativePath ?
`${relativePath}/${newFolder}` : newFolder;
}
try { try {
const downloadUrl = this.currentVersion.downloadUrl; const downloadUrl = this.currentVersion.downloadUrl;
if (!downloadUrl) { if (!downloadUrl) {
throw new Error('No download URL available'); throw new Error('No download URL available');
} }
// 只传递必要参数
const response = await fetch('/api/download-lora', { const response = await fetch('/api/download-lora', {
method: 'POST', method: 'POST',
headers: { 'Content-Type': 'application/json' }, headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({ body: JSON.stringify({
download_url: downloadUrl, download_url: downloadUrl,
version_info: this.currentVersion,
lora_root: loraRoot, lora_root: loraRoot,
new_folder: newFolder relative_path: relativePath
}) })
}); });
@@ -182,4 +199,25 @@ export class DownloadManager {
showToast(error.message, 'error'); showToast(error.message, 'error');
} }
} }
// Add new method to handle folder selection
initializeFolderBrowser() {
// Update folder selection handling
const folderBrowser = document.getElementById('folderBrowser');
if (!folderBrowser) return;
// Update folder selection event handling
folderBrowser.addEventListener('click', (event) => {
const folderItem = event.target.closest('.folder-item');
if (!folderItem) return;
// Remove previous selection
folderBrowser.querySelectorAll('.folder-item').forEach(f =>
f.classList.remove('selected'));
// Add selection to clicked folder
folderItem.classList.add('selected');
this.selectedFolder = folderItem.dataset.folder;
});
}
} }

View File

@@ -53,7 +53,12 @@
<div class="input-group"> <div class="input-group">
<label>Target Folder:</label> <label>Target Folder:</label>
<div class="folder-browser" id="folderBrowser"> <div class="folder-browser" id="folderBrowser">
<!-- Folder structure will be inserted here --> <!-- Folders will be dynamically inserted here -->
{% for folder in folders %}
<div class="folder-item" data-folder="{{ folder }}">
{{ folder }}
</div>
{% endfor %}
</div> </div>
</div> </div>
<div class="input-group"> <div class="input-group">

View File

@@ -7,6 +7,7 @@ BASE_MODEL_MAPPING = {
"sdxl": "SDXL", "sdxl": "SDXL",
"sd-v2": "SD2.0", "sd-v2": "SD2.0",
"flux1": "Flux.1 D", "flux1": "Flux.1 D",
"Illustrious": "IL"
} }
def determine_base_model(version_string: Optional[str]) -> str: def determine_base_model(version_string: Optional[str]) -> str:

View File

@@ -1,6 +1,8 @@
from dataclasses import dataclass, asdict from dataclasses import dataclass, asdict
from typing import Dict, Optional from typing import Dict, Optional
from datetime import datetime from datetime import datetime
import os
from .model_utils import determine_base_model
@dataclass @dataclass
class LoraMetadata: class LoraMetadata:
@@ -23,6 +25,25 @@ class LoraMetadata:
data_copy = data.copy() data_copy = data.copy()
return cls(**data_copy) return cls(**data_copy)
@classmethod
def from_civitai_info(cls, version_info: Dict, file_info: Dict, save_path: str) -> 'LoraMetadata':
"""Create LoraMetadata instance from Civitai version info"""
file_name = file_info['name']
base_model = determine_base_model(version_info.get('baseModel', ''))
return cls(
file_name=os.path.splitext(file_name)[0],
model_name=version_info.get('model').get('name', os.path.splitext(file_name)[0]),
file_path=save_path.replace(os.sep, '/'),
size=file_info.get('sizeKB', 0) * 1024,
modified=datetime.now().timestamp(),
sha256=file_info['hashes'].get('SHA256', ''),
base_model=base_model,
preview_url=None, # Will be updated after preview download
from_civitai=True,
civitai=version_info
)
def to_dict(self) -> Dict: def to_dict(self) -> Dict:
"""Convert to dictionary for JSON serialization""" """Convert to dictionary for JSON serialization"""
return asdict(self) return asdict(self)
@@ -36,4 +57,10 @@ class LoraMetadata:
"""Update Civitai information""" """Update Civitai information"""
self.civitai = civitai_data self.civitai = civitai_data
def update_file_info(self, file_path: str) -> None:
"""Update metadata with actual file information"""
if os.path.exists(file_path):
self.size = os.path.getsize(file_path)
self.modified = os.path.getmtime(file_path)
self.file_path = file_path.replace(os.sep, '/')