mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-25 07:05:43 -03:00
refactor: Update model_id and model_version_id types to integers and add validation in routes
This commit is contained in:
@@ -225,7 +225,7 @@ class CivitaiClient:
|
|||||||
logger.error(f"Error fetching model versions: {e}")
|
logger.error(f"Error fetching model versions: {e}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def get_model_version(self, model_id: str, version_id: str = "") -> Optional[Dict]:
|
async def get_model_version(self, model_id: int, version_id: int) -> Optional[Dict]:
|
||||||
"""Get specific model version with additional metadata
|
"""Get specific model version with additional metadata
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -250,7 +250,7 @@ class CivitaiClient:
|
|||||||
if version_id:
|
if version_id:
|
||||||
# If version_id provided, find exact match
|
# If version_id provided, find exact match
|
||||||
for version in model_versions:
|
for version in model_versions:
|
||||||
if str(version.get('id')) == str(version_id):
|
if version.get('id') == version_id:
|
||||||
matched_version = version
|
matched_version = version
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
@@ -267,7 +267,7 @@ class CivitaiClient:
|
|||||||
# Replace index with modelId
|
# Replace index with modelId
|
||||||
if 'index' in result:
|
if 'index' in result:
|
||||||
del result['index']
|
del result['index']
|
||||||
result['modelId'] = int(model_id)
|
result['modelId'] = model_id
|
||||||
|
|
||||||
# Add model field with metadata from top level
|
# Add model field with metadata from top level
|
||||||
result['model'] = {
|
result['model'] = {
|
||||||
|
|||||||
@@ -48,8 +48,8 @@ class DownloadManager:
|
|||||||
"""Get the checkpoint scanner from registry"""
|
"""Get the checkpoint scanner from registry"""
|
||||||
return await ServiceRegistry.get_checkpoint_scanner()
|
return await ServiceRegistry.get_checkpoint_scanner()
|
||||||
|
|
||||||
async def download_from_civitai(self, model_id: str = None,
|
async def download_from_civitai(self, model_id: int,
|
||||||
model_version_id: str = None, save_dir: str = None,
|
model_version_id: int, save_dir: str = None,
|
||||||
relative_path: str = '', progress_callback=None, use_default_paths: bool = False) -> Dict:
|
relative_path: str = '', progress_callback=None, use_default_paths: bool = False) -> Dict:
|
||||||
"""Download model from Civitai
|
"""Download model from Civitai
|
||||||
|
|
||||||
@@ -65,6 +65,20 @@ class DownloadManager:
|
|||||||
Dict with download result
|
Dict with download result
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
|
# Check if model version already exists in library
|
||||||
|
if model_version_id is not None:
|
||||||
|
# Case 1: model_version_id is provided, check both scanners
|
||||||
|
lora_scanner = await self._get_lora_scanner()
|
||||||
|
checkpoint_scanner = await self._get_checkpoint_scanner()
|
||||||
|
|
||||||
|
# Check lora scanner first
|
||||||
|
if await lora_scanner.check_model_version_exists(model_id, model_version_id):
|
||||||
|
return {'success': False, 'error': 'Model version already exists in lora library'}
|
||||||
|
|
||||||
|
# Check checkpoint scanner
|
||||||
|
if await checkpoint_scanner.check_model_version_exists(model_id, model_version_id):
|
||||||
|
return {'success': False, 'error': 'Model version already exists in checkpoint library'}
|
||||||
|
|
||||||
# Get civitai client
|
# Get civitai client
|
||||||
civitai_client = await self._get_civitai_client()
|
civitai_client = await self._get_civitai_client()
|
||||||
|
|
||||||
@@ -82,10 +96,21 @@ class DownloadManager:
|
|||||||
else:
|
else:
|
||||||
return {'success': False, 'error': f'Model type "{model_type_from_info}" is not supported for download'}
|
return {'success': False, 'error': f'Model type "{model_type_from_info}" is not supported for download'}
|
||||||
|
|
||||||
scanner = model_type == 'checkpoint' and await self._get_checkpoint_scanner() or await self._get_lora_scanner()
|
# Case 2: model_version_id was None, check after getting version_info
|
||||||
|
if model_version_id is None:
|
||||||
|
version_model_id = version_info.get('modelId')
|
||||||
|
version_id = version_info.get('id')
|
||||||
|
|
||||||
if scanner.check_model_version_exists(model_id, model_version_id):
|
if model_type == 'lora':
|
||||||
return {'success': False, 'error': 'Model version already exists in library'}
|
# Check lora scanner
|
||||||
|
lora_scanner = await self._get_lora_scanner()
|
||||||
|
if await lora_scanner.check_model_version_exists(version_model_id, version_id):
|
||||||
|
return {'success': False, 'error': 'Model version already exists in lora library'}
|
||||||
|
elif model_type == 'checkpoint':
|
||||||
|
# Check checkpoint scanner
|
||||||
|
checkpoint_scanner = await self._get_checkpoint_scanner()
|
||||||
|
if await checkpoint_scanner.check_model_version_exists(version_model_id, version_id):
|
||||||
|
return {'success': False, 'error': 'Model version already exists in checkpoint library'}
|
||||||
|
|
||||||
# Handle use_default_paths
|
# Handle use_default_paths
|
||||||
if use_default_paths:
|
if use_default_paths:
|
||||||
|
|||||||
@@ -52,9 +52,6 @@ class WebSocketManager:
|
|||||||
if not download_id:
|
if not download_id:
|
||||||
# Generate a new download ID if not provided
|
# Generate a new download ID if not provided
|
||||||
download_id = str(uuid4())
|
download_id = str(uuid4())
|
||||||
logger.info(f"Created new download ID: {download_id}")
|
|
||||||
else:
|
|
||||||
logger.info(f"Using provided download ID: {download_id}")
|
|
||||||
|
|
||||||
# Store the websocket with its download ID
|
# Store the websocket with its download ID
|
||||||
self._download_websockets[download_id] = ws
|
self._download_websockets[download_id] = ws
|
||||||
|
|||||||
@@ -590,9 +590,25 @@ class ModelRouteUtils:
|
|||||||
'download_id': download_id
|
'download_id': download_id
|
||||||
})
|
})
|
||||||
|
|
||||||
# Check which identifier is provided
|
# Check which identifier is provided and convert to int
|
||||||
model_id = data.get('model_id')
|
try:
|
||||||
model_version_id = data.get('model_version_id')
|
model_id = int(data.get('model_id'))
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
return web.Response(
|
||||||
|
status=400,
|
||||||
|
text="Invalid model_id: Must be an integer"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Convert model_version_id to int if provided
|
||||||
|
model_version_id = None
|
||||||
|
if data.get('model_version_id'):
|
||||||
|
try:
|
||||||
|
model_version_id = int(data.get('model_version_id'))
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
return web.Response(
|
||||||
|
status=400,
|
||||||
|
text="Invalid model_version_id: Must be an integer"
|
||||||
|
)
|
||||||
|
|
||||||
# Only model_id is required, model_version_id is optional
|
# Only model_id is required, model_version_id is optional
|
||||||
if not model_id:
|
if not model_id:
|
||||||
@@ -696,8 +712,10 @@ class ModelRouteUtils:
|
|||||||
try:
|
try:
|
||||||
data = await request.json()
|
data = await request.json()
|
||||||
file_path = data.get('file_path')
|
file_path = data.get('file_path')
|
||||||
model_id = data.get('model_id')
|
model_id = int(data.get('model_id'))
|
||||||
model_version_id = data.get('model_version_id')
|
model_version_id = None
|
||||||
|
if data.get('model_version_id'):
|
||||||
|
model_version_id = int(data.get('model_version_id'))
|
||||||
|
|
||||||
if not file_path or not model_id:
|
if not file_path or not model_id:
|
||||||
return web.json_response({"success": False, "error": "Both file_path and model_id are required"}, status=400)
|
return web.json_response({"success": False, "error": "Both file_path and model_id are required"}, status=400)
|
||||||
|
|||||||
Reference in New Issue
Block a user