Add download lora

This commit is contained in:
Will Miao
2025-02-14 10:57:33 +08:00
parent b7aca9b6fc
commit 451f77b99b
10 changed files with 283 additions and 87 deletions

View File

@@ -3,26 +3,32 @@ import json
import logging
from aiohttp import web
from typing import Dict, List
from ..services.file_monitor import LoraFileMonitor
from ..services.download_manager import DownloadManager
from ..services.civitai_client import CivitaiClient
from ..config import config
from ..services.lora_scanner import LoraScanner
from operator import itemgetter
from ..services.websocket_manager import ws_manager
from ..services.settings_manager import settings # 添加这行
from ..services.settings_manager import settings
import asyncio
logger = logging.getLogger(__name__)
class ApiRoutes:
"""API route handlers for LoRA management"""
def __init__(self):
def __init__(self, file_monitor: LoraFileMonitor):
self.scanner = LoraScanner()
self.civitai_client = CivitaiClient()
self.download_manager = DownloadManager(file_monitor)
self._download_lock = asyncio.Lock()
@classmethod
def setup_routes(cls, app: web.Application):
def setup_routes(cls, app: web.Application, monitor: LoraFileMonitor):
"""Register API routes"""
routes = cls()
routes = cls(monitor)
app.router.add_post('/api/delete_model', routes.delete_model)
app.router.add_post('/api/fetch-civitai', routes.fetch_civitai)
app.router.add_post('/api/replace_preview', routes.replace_preview)
@@ -467,41 +473,18 @@ class ApiRoutes:
return web.Response(status=500, text=str(e))
async def download_lora(self, request: web.Request) -> web.Response:
"""Handle LoRA download request"""
try:
data = await request.json()
download_url = data.get('download_url')
version_info = data.get('version_info')
lora_root = data.get('lora_root')
new_folder = data.get('new_folder', '').strip()
if not download_url or not version_info or not lora_root:
return web.Response(status=400, text="Missing required parameters")
if not os.path.isdir(lora_root):
return web.Response(status=400, text="Invalid LoRA root directory")
# 构建保存路径
save_dir = os.path.join(lora_root, new_folder) if new_folder else lora_root
os.makedirs(save_dir, exist_ok=True)
# 使用提供的下载 URL 和版本信息
result = await self.civitai_client.download_model_with_info(
download_url=download_url,
version_info=version_info,
save_dir=save_dir
)
if result.get('success'):
# 更新缓存 - 使用正确的扫描方法
await self.scanner.scan_directory(save_dir) # Changed from rescan_directory to scan_directory
async with self._download_lock:
try:
data = await request.json()
result = await self.download_manager.download_from_civitai(
download_url=data.get('download_url'),
save_dir=data.get('lora_root'),
relative_path=data.get('relative_path')
)
return web.json_response(result)
else:
return web.Response(status=500, text=result.get('error', 'Download failed'))
except Exception as e:
logger.error(f"Error downloading LoRA: {e}")
return web.Response(status=500, text=str(e))
except Exception as e:
logger.error(f"Error downloading LoRA: {e}")
return web.Response(status=500, text=str(e))
async def update_settings(self, request: web.Request) -> web.Response:
"""Update application settings"""