refactor: Simplify model download handling by consolidating download logic and updating parameter usage

This commit is contained in:
Will Miao
2025-07-02 18:25:42 +08:00
parent 9d8b7344cd
commit d7cb546c5f
7 changed files with 25 additions and 186 deletions

View File

@@ -267,7 +267,7 @@ class CivitaiClient:
# Replace index with modelId
if 'index' in result:
del result['index']
result['modelId'] = model_id
result['modelId'] = int(model_id)
# Add model field with metadata from top level
result['model'] = {

View File

@@ -48,16 +48,15 @@ class DownloadManager:
"""Get the checkpoint scanner from registry"""
return await ServiceRegistry.get_checkpoint_scanner()
async def download_from_civitai(self, download_url: str = None, model_hash: str = None,
async def download_from_civitai(self, model_id: str = None,
model_version_id: str = None, save_dir: str = None,
relative_path: str = '', progress_callback=None,
model_type: str = "lora") -> Dict:
"""Download model from Civitai
Args:
download_url: Direct download URL for the model
model_hash: SHA256 hash of the model
model_version_id: Civitai model version ID
model_id: Civitai model ID
model_version_id: Civitai model version ID (optional, if not provided, will download the latest version)
save_dir: Directory to save the model to
relative_path: Relative path within save_dir
progress_callback: Callback function for progress updates
@@ -77,25 +76,10 @@ class DownloadManager:
civitai_client = await self._get_civitai_client()
# Get version info based on the provided identifier
version_info = None
error_msg = None
if model_hash:
# Get model by hash
version_info = await civitai_client.get_model_by_hash(model_hash)
elif model_version_id:
# Use model version ID directly
version_info, error_msg = await civitai_client.get_model_version_info(model_version_id)
elif download_url:
# Extract version ID from download URL
version_id = download_url.split('/')[-1]
version_info, error_msg = await civitai_client.get_model_version_info(version_id)
version_info = await civitai_client.get_model_version(model_id, model_version_id)
if not version_info:
if error_msg and "model not found" in error_msg.lower():
return {'success': False, 'error': f'Model not found on Civitai: {error_msg}'}
return {'success': False, 'error': error_msg or 'Failed to fetch model metadata'}
return {'success': False, 'error': 'Failed to fetch model metadata'}
# Check if this is an early access model
if version_info.get('earlyAccessEndsAt'):
@@ -137,18 +121,6 @@ class DownloadManager:
metadata = LoraMetadata.from_civitai_info(version_info, file_info, save_path)
logger.info(f"Creating LoraMetadata for {file_name}")
# 5.1 Get and update model tags, description and creator info
model_id = version_info.get('modelId')
if model_id:
model_metadata, _ = await civitai_client.get_model_metadata(str(model_id))
if model_metadata:
if model_metadata.get("tags"):
metadata.tags = model_metadata.get("tags", [])
if model_metadata.get("description"):
metadata.modelDescription = model_metadata.get("description", "")
if model_metadata.get("creator"):
metadata.civitai["creator"] = model_metadata.get("creator")
# 6. Start download process
result = await self._execute_download(
download_url=file_info.get('downloadUrl', ''),