Add download lora

This commit is contained in:
Will Miao
2025-02-14 10:57:33 +08:00
parent b7aca9b6fc
commit 451f77b99b
10 changed files with 283 additions and 87 deletions

View File

@@ -1,3 +1,4 @@
from datetime import datetime
import aiohttp
import os
import json
@@ -5,6 +6,7 @@ import logging
from email.parser import Parser
from typing import Optional, Dict, Tuple
from urllib.parse import unquote
from ..utils.models import LoraMetadata
logger = logging.getLogger(__name__)
@@ -127,50 +129,20 @@ class CivitaiClient:
logger.error(f"Error fetching model versions: {e}")
return None
async def download_model_with_info(self, download_url: str, version_info: dict, save_dir: str) -> Dict:
"""Download model using provided version info and URL"""
async def get_model_version_info(self, version_id: str) -> Optional[Dict]:
"""Fetch model version metadata from Civitai"""
try:
# Generate default filename
default_filename = f"lora_{version_info['id']}.safetensors"
logger.info(f"Downloading model: {version_info.get('name', 'Unknown')}")
session = await self.session
url = f"{self.base_url}/model-versions/{version_id}"
headers = self._get_request_headers()
# Download the model file
success, result = await self._download_file(download_url, save_dir, default_filename)
if not success:
return {'success': False, 'error': result}
save_path = result
# Create metadata file
metadata_path = os.path.splitext(save_path)[0] + '.metadata.json'
metadata = {
'model_name': version_info.get('name', os.path.basename(save_path)),
'civitai': version_info,
'preview_url': None,
'from_civitai': True
}
# Download preview image if available
images = version_info.get('images', [])
if images:
preview_ext = '.mp4' if images[0].get('type') == 'video' else '.png'
preview_path = os.path.splitext(save_path)[0] + '.preview' + preview_ext
await self.download_preview_image(images[0]['url'], preview_path)
metadata['preview_url'] = preview_path.replace(os.sep, '/')
# Save metadata
with open(metadata_path, 'w', encoding='utf-8') as f:
json.dump(metadata, f, indent=2, ensure_ascii=False)
return {
'success': True,
'file_path': save_path.replace(os.sep, '/'),
'metadata': metadata
}
async with session.get(url, headers=headers) as response:
if response.status == 200:
return await response.json()
return None
except Exception as e:
logger.error(f"Error downloading model version: {e}")
return {'success': False, 'error': str(e)}
logger.error(f"Error fetching model version info: {e}")
return None
async def close(self):
"""Close the session if it exists"""

View File

@@ -0,0 +1,123 @@
import logging
import os
import json
from typing import Optional, Dict
from .civitai_client import CivitaiClient
from .file_monitor import LoraFileMonitor
from ..utils.models import LoraMetadata
logger = logging.getLogger(__name__)
class DownloadManager:
def __init__(self, file_monitor: Optional[LoraFileMonitor] = None):
self.civitai_client = CivitaiClient()
self.file_monitor = file_monitor
async def download_from_civitai(self, download_url: str, save_dir: str, relative_path: str = '') -> Dict:
try:
# Update save directory with relative path if provided
if relative_path:
save_dir = os.path.join(save_dir, relative_path)
# Create directory if it doesn't exist
os.makedirs(save_dir, exist_ok=True)
# Get version info
version_id = download_url.split('/')[-1]
version_info = await self.civitai_client.get_model_version_info(version_id)
if not version_info:
return {'success': False, 'error': 'Failed to fetch model metadata'}
# 2. 获取文件信息
file_info = next((f for f in version_info.get('files', []) if f.get('primary')), None)
if not file_info:
return {'success': False, 'error': 'No primary file found in metadata'}
# 3. 准备下载
file_name = file_info['name']
save_path = os.path.join(save_dir, file_name)
file_size = file_info.get('sizeKB', 0) * 1024
# 4. 通知文件监控系统
self.file_monitor.handler.add_ignore_path(
save_path.replace(os.sep, '/'),
file_size
)
# 5. 准备元数据
metadata = LoraMetadata.from_civitai_info(version_info, file_info, save_path)
# 6. 开始下载流程
result = await self._execute_download(
download_url=download_url,
save_dir=save_dir,
metadata=metadata,
version_info=version_info,
relative_path=relative_path
)
return result
except Exception as e:
logger.error(f"Error in download_from_civitai: {e}", exc_info=True)
return {'success': False, 'error': str(e)}
async def _execute_download(self, download_url: str, save_dir: str,
metadata: LoraMetadata, version_info: Dict, relative_path: str) -> Dict:
"""执行实际的下载流程,包括预览图和模型文件"""
try:
save_path = metadata.file_path
metadata_path = os.path.splitext(save_path)[0] + '.metadata.json'
# 2. 下载预览图(如果有)
images = version_info.get('images', [])
if images:
preview_ext = '.mp4' if images[0].get('type') == 'video' else '.png'
preview_path = os.path.splitext(save_path)[0] + '.preview' + preview_ext
if await self.civitai_client.download_preview_image(images[0]['url'], preview_path):
metadata.preview_url = preview_path.replace(os.sep, '/')
# 更新元数据中的预览图URL
with open(metadata_path, 'w', encoding='utf-8') as f:
json.dump(metadata.to_dict(), f, indent=2, ensure_ascii=False)
# 3. 下载模型文件
success, result = await self.civitai_client._download_file(
download_url,
save_dir,
os.path.basename(save_path)
)
if not success:
# 下载失败时清理文件
for path in [save_path, metadata_path, metadata.preview_url]:
if path and os.path.exists(path):
os.remove(path)
return {'success': False, 'error': result}
# 4. 更新文件信息(大小和修改时间)
metadata.update_file_info(save_path)
# 5. 最终更新元数据
with open(metadata_path, 'w', encoding='utf-8') as f:
json.dump(metadata.to_dict(), f, indent=2, ensure_ascii=False)
# 6. update lora cache
cache = await self.file_monitor.scanner.get_cached_data()
metadata_dict = metadata.to_dict()
metadata_dict['folder'] = relative_path
cache.raw_data.append(metadata_dict)
await cache.resort()
all_folders = set(cache.folders)
all_folders.add(relative_path)
cache.folders = sorted(list(all_folders))
return {
'success': True
}
except Exception as e:
logger.error(f"Error in _execute_download: {e}", exc_info=True)
# 确保清理任何部分下载的文件
for path in [save_path, metadata_path]:
if path and os.path.exists(path):
os.remove(path)
return {'success': False, 'error': str(e)}

View File

@@ -19,10 +19,41 @@ class LoraFileHandler(FileSystemEventHandler):
self.pending_changes = set() # 待处理的变更
self.lock = Lock() # 线程安全锁
self.update_task = None # 异步更新任务
self._ignore_paths = set() # Add ignore paths set
self._min_ignore_timeout = 5 # minimum timeout in seconds
self._download_speed = 1024 * 1024 # assume 1MB/s as base speed
def _should_ignore(self, path: str) -> bool:
"""Check if path should be ignored"""
logger.info(f"Checking ignore for {path}")
logger.info(f"Ignore paths: {self._ignore_paths}")
return path.replace(os.sep, '/') in self._ignore_paths
def add_ignore_path(self, path: str, file_size: int = 0):
"""Add path to ignore list with dynamic timeout based on file size"""
self._ignore_paths.add(path.replace(os.sep, '/'))
logger.info(f"Update ignore paths: {self._ignore_paths}")
# Calculate timeout based on file size, with a minimum value
# Assuming average download speed of 1MB/s
timeout = max(
self._min_ignore_timeout,
(file_size / self._download_speed) * 1.5 # Add 50% buffer
)
logger.debug(f"Adding {path} to ignore list for {timeout:.1f} seconds")
asyncio.get_event_loop().call_later(
timeout,
self._ignore_paths.discard,
path
)
def on_created(self, event):
if event.is_directory or not event.src_path.endswith('.safetensors'):
return
if self._should_ignore(event.src_path):
return
logger.info(f"LoRA file created: {event.src_path}")
self._schedule_update('add', event.src_path)
@@ -123,4 +154,4 @@ class LoraFileMonitor:
def stop(self):
"""Stop monitoring"""
self.observer.stop()
self.observer.join()
self.observer.join()