Add download lora

This commit is contained in:
Will Miao
2025-02-14 10:57:33 +08:00
parent b7aca9b6fc
commit 451f77b99b
10 changed files with 283 additions and 87 deletions

View File

@@ -1,3 +1,4 @@
from datetime import datetime
import aiohttp
import os
import json
@@ -5,6 +6,7 @@ import logging
from email.parser import Parser
from typing import Optional, Dict, Tuple
from urllib.parse import unquote
from ..utils.models import LoraMetadata
logger = logging.getLogger(__name__)
@@ -127,50 +129,20 @@ class CivitaiClient:
logger.error(f"Error fetching model versions: {e}")
return None
async def download_model_with_info(self, download_url: str, version_info: dict, save_dir: str) -> Dict:
"""Download model using provided version info and URL"""
async def get_model_version_info(self, version_id: str) -> Optional[Dict]:
"""Fetch model version metadata from Civitai"""
try:
# Generate default filename
default_filename = f"lora_{version_info['id']}.safetensors"
logger.info(f"Downloading model: {version_info.get('name', 'Unknown')}")
session = await self.session
url = f"{self.base_url}/model-versions/{version_id}"
headers = self._get_request_headers()
# Download the model file
success, result = await self._download_file(download_url, save_dir, default_filename)
if not success:
return {'success': False, 'error': result}
save_path = result
# Create metadata file
metadata_path = os.path.splitext(save_path)[0] + '.metadata.json'
metadata = {
'model_name': version_info.get('name', os.path.basename(save_path)),
'civitai': version_info,
'preview_url': None,
'from_civitai': True
}
# Download preview image if available
images = version_info.get('images', [])
if images:
preview_ext = '.mp4' if images[0].get('type') == 'video' else '.png'
preview_path = os.path.splitext(save_path)[0] + '.preview' + preview_ext
await self.download_preview_image(images[0]['url'], preview_path)
metadata['preview_url'] = preview_path.replace(os.sep, '/')
# Save metadata
with open(metadata_path, 'w', encoding='utf-8') as f:
json.dump(metadata, f, indent=2, ensure_ascii=False)
return {
'success': True,
'file_path': save_path.replace(os.sep, '/'),
'metadata': metadata
}
async with session.get(url, headers=headers) as response:
if response.status == 200:
return await response.json()
return None
except Exception as e:
logger.error(f"Error downloading model version: {e}")
return {'success': False, 'error': str(e)}
logger.error(f"Error fetching model version info: {e}")
return None
async def close(self):
"""Close the session if it exists"""