refactor: Enhance checkpoint download functionality with new modal and manager integration

This commit is contained in:
Will Miao
2025-04-11 18:25:37 +08:00
parent 3df96034a1
commit 1db49a4dd4
10 changed files with 699 additions and 43 deletions

View File

@@ -4,7 +4,7 @@ import json
from typing import Optional, Dict
from .civitai_client import CivitaiClient
from .file_monitor import LoraFileMonitor
from ..utils.models import LoraMetadata
from ..utils.models import LoraMetadata, CheckpointMetadata
from ..utils.constants import CARD_PREVIEW_WIDTH
from ..utils.exif_utils import ExifUtils
@@ -20,7 +20,22 @@ class DownloadManager:
async def download_from_civitai(self, download_url: str = None, model_hash: str = None,
model_version_id: str = None, save_dir: str = None,
relative_path: str = '', progress_callback=None) -> Dict:
relative_path: str = '', progress_callback=None,
model_type: str = "lora") -> Dict:
"""Download model from Civitai
Args:
download_url: Direct download URL for the model
model_hash: SHA256 hash of the model
model_version_id: Civitai model version ID
save_dir: Directory to save the model to
relative_path: Relative path within save_dir
progress_callback: Callback function for progress updates
model_type: Type of model ('lora' or 'checkpoint')
Returns:
Dict with download result
"""
try:
# Update save directory with relative path if provided
if relative_path:
@@ -46,7 +61,7 @@ class DownloadManager:
if not version_info:
return {'success': False, 'error': 'Failed to fetch model metadata'}
# Check if this is an early access LoRA
# Check if this is an early access model
if version_info.get('earlyAccessEndsAt'):
early_access_date = version_info.get('earlyAccessEndsAt', '')
# Convert to a readable date if possible
@@ -54,12 +69,12 @@ class DownloadManager:
from datetime import datetime
date_obj = datetime.fromisoformat(early_access_date.replace('Z', '+00:00'))
formatted_date = date_obj.strftime('%Y-%m-%d')
early_access_msg = f"This LoRA requires early access payment (until {formatted_date}). "
early_access_msg = f"This model requires early access payment (until {formatted_date}). "
except:
early_access_msg = "This LoRA requires early access payment. "
early_access_msg = "This model requires early access payment. "
early_access_msg += "Please ensure you have purchased early access and are logged in to Civitai."
logger.warning(f"Early access LoRA detected: {version_info.get('name', 'Unknown')}")
logger.warning(f"Early access model detected: {version_info.get('name', 'Unknown')}")
# We'll still try to download, but log a warning and prepare for potential failure
if progress_callback:
@@ -69,26 +84,32 @@ class DownloadManager:
if progress_callback:
await progress_callback(0)
# 2. 获取文件信息
# 2. Get file information
file_info = next((f for f in version_info.get('files', []) if f.get('primary')), None)
if not file_info:
return {'success': False, 'error': 'No primary file found in metadata'}
# 3. 准备下载
# 3. Prepare download
file_name = file_info['name']
save_path = os.path.join(save_dir, file_name)
file_size = file_info.get('sizeKB', 0) * 1024
# 4. 通知文件监控系统 - 使用规范化路径和文件大小
self.file_monitor.handler.add_ignore_path(
save_path.replace(os.sep, '/'),
file_size
)
# 4. Notify file monitor - use normalized path and file size
if self.file_monitor and self.file_monitor.handler:
self.file_monitor.handler.add_ignore_path(
save_path.replace(os.sep, '/'),
file_size
)
# 5. 准备元数据
metadata = LoraMetadata.from_civitai_info(version_info, file_info, save_path)
# 5. Prepare metadata based on model type
if model_type == "checkpoint":
metadata = CheckpointMetadata.from_civitai_info(version_info, file_info, save_path)
logger.info(f"Creating CheckpointMetadata for {file_name}")
else:
metadata = LoraMetadata.from_civitai_info(version_info, file_info, save_path)
logger.info(f"Creating LoraMetadata for {file_name}")
# 5.1 获取并更新模型标签和描述信息
# 5.1 Get and update model tags and description
model_id = version_info.get('modelId')
if model_id:
model_metadata, _ = await self.civitai_client.get_model_metadata(str(model_id))
@@ -98,14 +119,15 @@ class DownloadManager:
if model_metadata.get("description"):
metadata.modelDescription = model_metadata.get("description", "")
# 6. 开始下载流程
# 6. Start download process
result = await self._execute_download(
download_url=file_info.get('downloadUrl', ''),
save_dir=save_dir,
metadata=metadata,
version_info=version_info,
relative_path=relative_path,
progress_callback=progress_callback
progress_callback=progress_callback,
model_type=model_type
)
return result
@@ -119,8 +141,9 @@ class DownloadManager:
return {'success': False, 'error': str(e)}
async def _execute_download(self, download_url: str, save_dir: str,
metadata: LoraMetadata, version_info: Dict,
relative_path: str, progress_callback=None) -> Dict:
metadata, version_info: Dict,
relative_path: str, progress_callback=None,
model_type: str = "lora") -> Dict:
"""Execute the actual download process including preview images and model files"""
try:
save_path = metadata.file_path
@@ -201,15 +224,21 @@ class DownloadManager:
os.remove(path)
return {'success': False, 'error': result}
# 4. 更新文件信息(大小和修改时间)
# 4. Update file information (size and modified time)
metadata.update_file_info(save_path)
# 5. 最终更新元数据
# 5. Final metadata update
with open(metadata_path, 'w', encoding='utf-8') as f:
json.dump(metadata.to_dict(), f, indent=2, ensure_ascii=False)
# 6. update lora cache
cache = await self.file_monitor.scanner.get_cached_data()
# 6. Update cache based on model type
if model_type == "checkpoint" and hasattr(self.file_monitor, "checkpoint_scanner"):
cache = await self.file_monitor.checkpoint_scanner.get_cached_data()
logger.info(f"Updating checkpoint cache for {save_path}")
else:
cache = await self.file_monitor.scanner.get_cached_data()
logger.info(f"Updating lora cache for {save_path}")
metadata_dict = metadata.to_dict()
metadata_dict['folder'] = relative_path
cache.raw_data.append(metadata_dict)
@@ -218,11 +247,11 @@ class DownloadManager:
all_folders.add(relative_path)
cache.folders = sorted(list(all_folders), key=lambda x: x.lower())
# Update the hash index with the new LoRA entry
self.file_monitor.scanner._hash_index.add_entry(metadata_dict['sha256'], metadata_dict['file_path'])
# Update the hash index with the new LoRA entry
self.file_monitor.scanner._hash_index.add_entry(metadata_dict['sha256'], metadata_dict['file_path'])
# Update the hash index with the new model entry
if model_type == "checkpoint" and hasattr(self.file_monitor, "checkpoint_scanner"):
self.file_monitor.checkpoint_scanner._hash_index.add_entry(metadata_dict['sha256'], metadata_dict['file_path'])
else:
self.file_monitor.scanner._hash_index.add_entry(metadata_dict['sha256'], metadata_dict['file_path'])
# Report 100% completion
if progress_callback: