refactor: Simplify model download handling by consolidating download logic and updating parameter usage

This commit is contained in:
Will Miao
2025-07-02 18:25:42 +08:00
parent 9d8b7344cd
commit d7cb546c5f
7 changed files with 25 additions and 186 deletions

View File

@@ -437,68 +437,7 @@ class ApiRoutes:
}, status=500)
async def download_lora(self, request: web.Request) -> web.Response:
async with self._download_lock:
try:
if self.download_manager is None:
self.download_manager = await ServiceRegistry.get_download_manager()
data = await request.json()
# Create progress callback
async def progress_callback(progress):
await ws_manager.broadcast({
'status': 'progress',
'progress': progress
})
# Check which identifier is provided
download_url = data.get('download_url')
model_hash = data.get('model_hash')
model_version_id = data.get('model_version_id')
# Validate that at least one identifier is provided
if not any([download_url, model_hash, model_version_id]):
return web.Response(
status=400,
text="Missing required parameter: Please provide either 'download_url', 'hash', or 'modelVersionId'"
)
result = await self.download_manager.download_from_civitai(
download_url=download_url,
model_hash=model_hash,
model_version_id=model_version_id,
save_dir=data.get('lora_root'),
relative_path=data.get('relative_path'),
progress_callback=progress_callback
)
if not result.get('success', False):
error_message = result.get('error', 'Unknown error')
# Return 401 for early access errors
if 'early access' in error_message.lower():
logger.warning(f"Early access download failed: {error_message}")
return web.Response(
status=401, # Use 401 status code to match Civitai's response
text=error_message
)
return web.Response(status=500, text=error_message)
return web.json_response(result)
except Exception as e:
error_message = str(e)
# Check if this might be an early access error
if '401' in error_message:
logger.warning(f"Early access error (401): {error_message}")
return web.Response(
status=401,
text="Early Access Restriction: This LoRA requires purchase. Please buy early access on Civitai.com."
)
logger.error(f"Error downloading LoRA: {error_message}")
return web.Response(status=500, text=error_message)
return await ModelRouteUtils.handle_download_model(request, self.download_manager, model_type="lora")
async def move_model(self, request: web.Request) -> web.Response:

View File

@@ -544,71 +544,7 @@ class CheckpointsRoutes:
async def download_checkpoint(self, request: web.Request) -> web.Response:
"""Handle checkpoint download request"""
async with self._download_lock:
# Get the download manager from service registry if not already initialized
if self.download_manager is None:
self.download_manager = await ServiceRegistry.get_download_manager()
try:
data = await request.json()
# Create progress callback that uses checkpoint-specific WebSocket
async def progress_callback(progress):
await ws_manager.broadcast_checkpoint_progress({
'status': 'progress',
'progress': progress
})
# Check which identifier is provided
download_url = data.get('download_url')
model_hash = data.get('model_hash')
model_version_id = data.get('model_version_id')
# Validate that at least one identifier is provided
if not any([download_url, model_hash, model_version_id]):
return web.Response(
status=400,
text="Missing required parameter: Please provide either 'download_url', 'hash', or 'modelVersionId'"
)
result = await self.download_manager.download_from_civitai(
download_url=download_url,
model_hash=model_hash,
model_version_id=model_version_id,
save_dir=data.get('checkpoint_root'),
relative_path=data.get('relative_path', ''),
progress_callback=progress_callback,
model_type="checkpoint"
)
if not result.get('success', False):
error_message = result.get('error', 'Unknown error')
# Return 401 for early access errors
if 'early access' in error_message.lower():
logger.warning(f"Early access download failed: {error_message}")
return web.Response(
status=401,
text=f"Early Access Restriction: {error_message}"
)
return web.Response(status=500, text=error_message)
return web.json_response(result)
except Exception as e:
error_message = str(e)
# Check if this might be an early access error
if '401' in error_message:
logger.warning(f"Early access error (401): {error_message}")
return web.Response(
status=401,
text="Early Access Restriction: This model requires purchase. Please ensure you have purchased early access and are logged in to Civitai."
)
logger.error(f"Error downloading checkpoint: {error_message}")
return web.Response(status=500, text=error_message)
return await ModelRouteUtils.handle_download_model(request, self.download_manager, model_type="checkpoint")
async def get_checkpoint_roots(self, request):
"""Return the checkpoint root directories"""

View File

@@ -267,7 +267,7 @@ class CivitaiClient:
# Replace index with modelId
if 'index' in result:
del result['index']
result['modelId'] = model_id
result['modelId'] = int(model_id)
# Add model field with metadata from top level
result['model'] = {

View File

@@ -48,16 +48,15 @@ class DownloadManager:
"""Get the checkpoint scanner from registry"""
return await ServiceRegistry.get_checkpoint_scanner()
async def download_from_civitai(self, download_url: str = None, model_hash: str = None,
async def download_from_civitai(self, model_id: str = None,
model_version_id: str = None, save_dir: str = None,
relative_path: str = '', progress_callback=None,
model_type: str = "lora") -> Dict:
"""Download model from Civitai
Args:
download_url: Direct download URL for the model
model_hash: SHA256 hash of the model
model_version_id: Civitai model version ID
model_id: Civitai model ID
model_version_id: Civitai model version ID (optional, if not provided, will download the latest version)
save_dir: Directory to save the model to
relative_path: Relative path within save_dir
progress_callback: Callback function for progress updates
@@ -77,25 +76,10 @@ class DownloadManager:
civitai_client = await self._get_civitai_client()
# Get version info based on the provided identifier
version_info = None
error_msg = None
if model_hash:
# Get model by hash
version_info = await civitai_client.get_model_by_hash(model_hash)
elif model_version_id:
# Use model version ID directly
version_info, error_msg = await civitai_client.get_model_version_info(model_version_id)
elif download_url:
# Extract version ID from download URL
version_id = download_url.split('/')[-1]
version_info, error_msg = await civitai_client.get_model_version_info(version_id)
version_info = await civitai_client.get_model_version(model_id, model_version_id)
if not version_info:
if error_msg and "model not found" in error_msg.lower():
return {'success': False, 'error': f'Model not found on Civitai: {error_msg}'}
return {'success': False, 'error': error_msg or 'Failed to fetch model metadata'}
return {'success': False, 'error': 'Failed to fetch model metadata'}
# Check if this is an early access model
if version_info.get('earlyAccessEndsAt'):
@@ -137,18 +121,6 @@ class DownloadManager:
metadata = LoraMetadata.from_civitai_info(version_info, file_info, save_path)
logger.info(f"Creating LoraMetadata for {file_name}")
# 5.1 Get and update model tags, description and creator info
model_id = version_info.get('modelId')
if model_id:
model_metadata, _ = await civitai_client.get_model_metadata(str(model_id))
if model_metadata:
if model_metadata.get("tags"):
metadata.tags = model_metadata.get("tags", [])
if model_metadata.get("description"):
metadata.modelDescription = model_metadata.get("description", "")
if model_metadata.get("creator"):
metadata.civitai["creator"] = model_metadata.get("creator")
# 6. Start download process
result = await self._execute_download(
download_url=file_info.get('downloadUrl', ''),

View File

@@ -587,15 +587,14 @@ class ModelRouteUtils:
})
# Check which identifier is provided
download_url = data.get('download_url')
model_hash = data.get('model_hash')
model_id = data.get('model_id')
model_version_id = data.get('model_version_id')
# Validate that at least one identifier is provided
if not any([download_url, model_hash, model_version_id]):
# Only model_id is required, model_version_id is optional
if not model_id:
return web.Response(
status=400,
text="Missing required parameter: Please provide either 'download_url', 'hash', or 'modelVersionId'"
text="Missing required parameter: Please provide 'model_id'"
)
# Use the correct root directory based on model type
@@ -603,8 +602,7 @@ class ModelRouteUtils:
save_dir = data.get(root_key)
result = await download_manager.download_from_civitai(
download_url=download_url,
model_hash=model_hash,
model_id=model_id,
model_version_id=model_version_id,
save_dir=save_dir,
relative_path=data.get('relative_path', ''),

View File

@@ -61,6 +61,7 @@ export class CheckpointDownloadManager {
this.currentVersion = null;
this.versions = [];
this.modelInfo = null;
this.modelId = null;
this.modelVersionId = null;
// Clear selected folder and remove selection from UI
@@ -79,12 +80,12 @@ export class CheckpointDownloadManager {
try {
this.loadingManager.showSimpleLoading('Fetching model versions...');
const modelId = this.extractModelId(url);
if (!modelId) {
this.modelId = this.extractModelId(url);
if (!this.modelId) {
throw new Error('Invalid Civitai URL format');
}
const response = await fetch(`/api/checkpoints/civitai/versions/${modelId}`);
const response = await fetch(`/api/checkpoints/civitai/versions/${this.modelId}`);
if (!response.ok) {
const errorData = await response.json().catch(() => ({}));
if (errorData && errorData.error && errorData.error.includes('Model type mismatch')) {
@@ -296,11 +297,6 @@ export class CheckpointDownloadManager {
}
try {
const downloadUrl = this.currentVersion.downloadUrl;
if (!downloadUrl) {
throw new Error('No download URL available');
}
// Show enhanced loading with progress details
const updateProgress = this.loadingManager.showDownloadProgress(1);
updateProgress(0, 0, this.currentVersion.name);
@@ -338,7 +334,8 @@ export class CheckpointDownloadManager {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({
download_url: downloadUrl,
model_id: this.modelId,
model_version_id: this.currentVersion.id,
checkpoint_root: checkpointRoot,
relative_path: targetFolder
})

View File

@@ -63,6 +63,7 @@ export class DownloadManager {
this.currentVersion = null;
this.versions = [];
this.modelInfo = null;
this.modelId = null;
this.modelVersionId = null;
// Clear selected folder and remove selection from UI
@@ -81,12 +82,12 @@ export class DownloadManager {
try {
this.loadingManager.showSimpleLoading('Fetching model versions...');
const modelId = this.extractModelId(url);
if (!modelId) {
this.modelId = this.extractModelId(url);
if (!this.modelId) {
throw new Error('Invalid Civitai URL format');
}
const response = await fetch(`/api/civitai/versions/${modelId}`);
const response = await fetch(`/api/civitai/versions/${this.modelId}`);
if (!response.ok) {
const errorData = await response.json().catch(() => ({}));
if (errorData && errorData.error && errorData.error.includes('Model type mismatch')) {
@@ -306,11 +307,6 @@ export class DownloadManager {
}
try {
const downloadUrl = this.currentVersion.downloadUrl;
if (!downloadUrl) {
throw new Error('No download URL available');
}
// Show enhanced loading with progress details
const updateProgress = this.loadingManager.showDownloadProgress(1);
updateProgress(0, 0, this.currentVersion.name);
@@ -348,7 +344,8 @@ export class DownloadManager {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({
download_url: downloadUrl,
model_id: this.modelId,
model_version_id: this.currentVersion.id,
lora_root: loraRoot,
relative_path: targetFolder
})