refactor: Update model_id and model_version_id types to integers and add validation in routes

This commit is contained in:
Will Miao
2025-07-09 14:21:49 +08:00
parent c692713ffb
commit 79011bd257
4 changed files with 57 additions and 17 deletions

View File

@@ -48,8 +48,8 @@ class DownloadManager:
"""Get the checkpoint scanner from registry"""
return await ServiceRegistry.get_checkpoint_scanner()
async def download_from_civitai(self, model_id: str = None,
model_version_id: str = None, save_dir: str = None,
async def download_from_civitai(self, model_id: int,
model_version_id: int, save_dir: str = None,
relative_path: str = '', progress_callback=None, use_default_paths: bool = False) -> Dict:
"""Download model from Civitai
@@ -65,6 +65,20 @@ class DownloadManager:
Dict with download result
"""
try:
# Check if model version already exists in library
if model_version_id is not None:
# Case 1: model_version_id is provided, check both scanners
lora_scanner = await self._get_lora_scanner()
checkpoint_scanner = await self._get_checkpoint_scanner()
# Check lora scanner first
if await lora_scanner.check_model_version_exists(model_id, model_version_id):
return {'success': False, 'error': 'Model version already exists in lora library'}
# Check checkpoint scanner
if await checkpoint_scanner.check_model_version_exists(model_id, model_version_id):
return {'success': False, 'error': 'Model version already exists in checkpoint library'}
# Get civitai client
civitai_client = await self._get_civitai_client()
@@ -82,10 +96,21 @@ class DownloadManager:
else:
return {'success': False, 'error': f'Model type "{model_type_from_info}" is not supported for download'}
scanner = model_type == 'checkpoint' and await self._get_checkpoint_scanner() or await self._get_lora_scanner()
if scanner.check_model_version_exists(model_id, model_version_id):
return {'success': False, 'error': 'Model version already exists in library'}
# Case 2: model_version_id was None, check after getting version_info
if model_version_id is None:
version_model_id = version_info.get('modelId')
version_id = version_info.get('id')
if model_type == 'lora':
# Check lora scanner
lora_scanner = await self._get_lora_scanner()
if await lora_scanner.check_model_version_exists(version_model_id, version_id):
return {'success': False, 'error': 'Model version already exists in lora library'}
elif model_type == 'checkpoint':
# Check checkpoint scanner
checkpoint_scanner = await self._get_checkpoint_scanner()
if await checkpoint_scanner.check_model_version_exists(version_model_id, version_id):
return {'success': False, 'error': 'Model version already exists in checkpoint library'}
# Handle use_default_paths
if use_default_paths: