mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-25 07:05:43 -03:00
refactor: Simplify model download handling by consolidating download logic and updating parameter usage
This commit is contained in:
@@ -437,68 +437,7 @@ class ApiRoutes:
|
|||||||
}, status=500)
|
}, status=500)
|
||||||
|
|
||||||
async def download_lora(self, request: web.Request) -> web.Response:
|
async def download_lora(self, request: web.Request) -> web.Response:
|
||||||
async with self._download_lock:
|
return await ModelRouteUtils.handle_download_model(request, self.download_manager, model_type="lora")
|
||||||
try:
|
|
||||||
if self.download_manager is None:
|
|
||||||
self.download_manager = await ServiceRegistry.get_download_manager()
|
|
||||||
|
|
||||||
data = await request.json()
|
|
||||||
|
|
||||||
# Create progress callback
|
|
||||||
async def progress_callback(progress):
|
|
||||||
await ws_manager.broadcast({
|
|
||||||
'status': 'progress',
|
|
||||||
'progress': progress
|
|
||||||
})
|
|
||||||
|
|
||||||
# Check which identifier is provided
|
|
||||||
download_url = data.get('download_url')
|
|
||||||
model_hash = data.get('model_hash')
|
|
||||||
model_version_id = data.get('model_version_id')
|
|
||||||
|
|
||||||
# Validate that at least one identifier is provided
|
|
||||||
if not any([download_url, model_hash, model_version_id]):
|
|
||||||
return web.Response(
|
|
||||||
status=400,
|
|
||||||
text="Missing required parameter: Please provide either 'download_url', 'hash', or 'modelVersionId'"
|
|
||||||
)
|
|
||||||
|
|
||||||
result = await self.download_manager.download_from_civitai(
|
|
||||||
download_url=download_url,
|
|
||||||
model_hash=model_hash,
|
|
||||||
model_version_id=model_version_id,
|
|
||||||
save_dir=data.get('lora_root'),
|
|
||||||
relative_path=data.get('relative_path'),
|
|
||||||
progress_callback=progress_callback
|
|
||||||
)
|
|
||||||
|
|
||||||
if not result.get('success', False):
|
|
||||||
error_message = result.get('error', 'Unknown error')
|
|
||||||
|
|
||||||
# Return 401 for early access errors
|
|
||||||
if 'early access' in error_message.lower():
|
|
||||||
logger.warning(f"Early access download failed: {error_message}")
|
|
||||||
return web.Response(
|
|
||||||
status=401, # Use 401 status code to match Civitai's response
|
|
||||||
text=error_message
|
|
||||||
)
|
|
||||||
|
|
||||||
return web.Response(status=500, text=error_message)
|
|
||||||
|
|
||||||
return web.json_response(result)
|
|
||||||
except Exception as e:
|
|
||||||
error_message = str(e)
|
|
||||||
|
|
||||||
# Check if this might be an early access error
|
|
||||||
if '401' in error_message:
|
|
||||||
logger.warning(f"Early access error (401): {error_message}")
|
|
||||||
return web.Response(
|
|
||||||
status=401,
|
|
||||||
text="Early Access Restriction: This LoRA requires purchase. Please buy early access on Civitai.com."
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.error(f"Error downloading LoRA: {error_message}")
|
|
||||||
return web.Response(status=500, text=error_message)
|
|
||||||
|
|
||||||
|
|
||||||
async def move_model(self, request: web.Request) -> web.Response:
|
async def move_model(self, request: web.Request) -> web.Response:
|
||||||
|
|||||||
@@ -544,71 +544,7 @@ class CheckpointsRoutes:
|
|||||||
|
|
||||||
async def download_checkpoint(self, request: web.Request) -> web.Response:
|
async def download_checkpoint(self, request: web.Request) -> web.Response:
|
||||||
"""Handle checkpoint download request"""
|
"""Handle checkpoint download request"""
|
||||||
async with self._download_lock:
|
return await ModelRouteUtils.handle_download_model(request, self.download_manager, model_type="checkpoint")
|
||||||
# Get the download manager from service registry if not already initialized
|
|
||||||
if self.download_manager is None:
|
|
||||||
self.download_manager = await ServiceRegistry.get_download_manager()
|
|
||||||
|
|
||||||
try:
|
|
||||||
data = await request.json()
|
|
||||||
|
|
||||||
# Create progress callback that uses checkpoint-specific WebSocket
|
|
||||||
async def progress_callback(progress):
|
|
||||||
await ws_manager.broadcast_checkpoint_progress({
|
|
||||||
'status': 'progress',
|
|
||||||
'progress': progress
|
|
||||||
})
|
|
||||||
|
|
||||||
# Check which identifier is provided
|
|
||||||
download_url = data.get('download_url')
|
|
||||||
model_hash = data.get('model_hash')
|
|
||||||
model_version_id = data.get('model_version_id')
|
|
||||||
|
|
||||||
# Validate that at least one identifier is provided
|
|
||||||
if not any([download_url, model_hash, model_version_id]):
|
|
||||||
return web.Response(
|
|
||||||
status=400,
|
|
||||||
text="Missing required parameter: Please provide either 'download_url', 'hash', or 'modelVersionId'"
|
|
||||||
)
|
|
||||||
|
|
||||||
result = await self.download_manager.download_from_civitai(
|
|
||||||
download_url=download_url,
|
|
||||||
model_hash=model_hash,
|
|
||||||
model_version_id=model_version_id,
|
|
||||||
save_dir=data.get('checkpoint_root'),
|
|
||||||
relative_path=data.get('relative_path', ''),
|
|
||||||
progress_callback=progress_callback,
|
|
||||||
model_type="checkpoint"
|
|
||||||
)
|
|
||||||
|
|
||||||
if not result.get('success', False):
|
|
||||||
error_message = result.get('error', 'Unknown error')
|
|
||||||
|
|
||||||
# Return 401 for early access errors
|
|
||||||
if 'early access' in error_message.lower():
|
|
||||||
logger.warning(f"Early access download failed: {error_message}")
|
|
||||||
return web.Response(
|
|
||||||
status=401,
|
|
||||||
text=f"Early Access Restriction: {error_message}"
|
|
||||||
)
|
|
||||||
|
|
||||||
return web.Response(status=500, text=error_message)
|
|
||||||
|
|
||||||
return web.json_response(result)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
error_message = str(e)
|
|
||||||
|
|
||||||
# Check if this might be an early access error
|
|
||||||
if '401' in error_message:
|
|
||||||
logger.warning(f"Early access error (401): {error_message}")
|
|
||||||
return web.Response(
|
|
||||||
status=401,
|
|
||||||
text="Early Access Restriction: This model requires purchase. Please ensure you have purchased early access and are logged in to Civitai."
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.error(f"Error downloading checkpoint: {error_message}")
|
|
||||||
return web.Response(status=500, text=error_message)
|
|
||||||
|
|
||||||
async def get_checkpoint_roots(self, request):
|
async def get_checkpoint_roots(self, request):
|
||||||
"""Return the checkpoint root directories"""
|
"""Return the checkpoint root directories"""
|
||||||
|
|||||||
@@ -267,7 +267,7 @@ class CivitaiClient:
|
|||||||
# Replace index with modelId
|
# Replace index with modelId
|
||||||
if 'index' in result:
|
if 'index' in result:
|
||||||
del result['index']
|
del result['index']
|
||||||
result['modelId'] = model_id
|
result['modelId'] = int(model_id)
|
||||||
|
|
||||||
# Add model field with metadata from top level
|
# Add model field with metadata from top level
|
||||||
result['model'] = {
|
result['model'] = {
|
||||||
|
|||||||
@@ -48,16 +48,15 @@ class DownloadManager:
|
|||||||
"""Get the checkpoint scanner from registry"""
|
"""Get the checkpoint scanner from registry"""
|
||||||
return await ServiceRegistry.get_checkpoint_scanner()
|
return await ServiceRegistry.get_checkpoint_scanner()
|
||||||
|
|
||||||
async def download_from_civitai(self, download_url: str = None, model_hash: str = None,
|
async def download_from_civitai(self, model_id: str = None,
|
||||||
model_version_id: str = None, save_dir: str = None,
|
model_version_id: str = None, save_dir: str = None,
|
||||||
relative_path: str = '', progress_callback=None,
|
relative_path: str = '', progress_callback=None,
|
||||||
model_type: str = "lora") -> Dict:
|
model_type: str = "lora") -> Dict:
|
||||||
"""Download model from Civitai
|
"""Download model from Civitai
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
download_url: Direct download URL for the model
|
model_id: Civitai model ID
|
||||||
model_hash: SHA256 hash of the model
|
model_version_id: Civitai model version ID (optional, if not provided, will download the latest version)
|
||||||
model_version_id: Civitai model version ID
|
|
||||||
save_dir: Directory to save the model to
|
save_dir: Directory to save the model to
|
||||||
relative_path: Relative path within save_dir
|
relative_path: Relative path within save_dir
|
||||||
progress_callback: Callback function for progress updates
|
progress_callback: Callback function for progress updates
|
||||||
@@ -77,25 +76,10 @@ class DownloadManager:
|
|||||||
civitai_client = await self._get_civitai_client()
|
civitai_client = await self._get_civitai_client()
|
||||||
|
|
||||||
# Get version info based on the provided identifier
|
# Get version info based on the provided identifier
|
||||||
version_info = None
|
version_info = await civitai_client.get_model_version(model_id, model_version_id)
|
||||||
error_msg = None
|
|
||||||
|
|
||||||
if model_hash:
|
|
||||||
# Get model by hash
|
|
||||||
version_info = await civitai_client.get_model_by_hash(model_hash)
|
|
||||||
elif model_version_id:
|
|
||||||
# Use model version ID directly
|
|
||||||
version_info, error_msg = await civitai_client.get_model_version_info(model_version_id)
|
|
||||||
elif download_url:
|
|
||||||
# Extract version ID from download URL
|
|
||||||
version_id = download_url.split('/')[-1]
|
|
||||||
version_info, error_msg = await civitai_client.get_model_version_info(version_id)
|
|
||||||
|
|
||||||
|
|
||||||
if not version_info:
|
if not version_info:
|
||||||
if error_msg and "model not found" in error_msg.lower():
|
return {'success': False, 'error': 'Failed to fetch model metadata'}
|
||||||
return {'success': False, 'error': f'Model not found on Civitai: {error_msg}'}
|
|
||||||
return {'success': False, 'error': error_msg or 'Failed to fetch model metadata'}
|
|
||||||
|
|
||||||
# Check if this is an early access model
|
# Check if this is an early access model
|
||||||
if version_info.get('earlyAccessEndsAt'):
|
if version_info.get('earlyAccessEndsAt'):
|
||||||
@@ -137,18 +121,6 @@ class DownloadManager:
|
|||||||
metadata = LoraMetadata.from_civitai_info(version_info, file_info, save_path)
|
metadata = LoraMetadata.from_civitai_info(version_info, file_info, save_path)
|
||||||
logger.info(f"Creating LoraMetadata for {file_name}")
|
logger.info(f"Creating LoraMetadata for {file_name}")
|
||||||
|
|
||||||
# 5.1 Get and update model tags, description and creator info
|
|
||||||
model_id = version_info.get('modelId')
|
|
||||||
if model_id:
|
|
||||||
model_metadata, _ = await civitai_client.get_model_metadata(str(model_id))
|
|
||||||
if model_metadata:
|
|
||||||
if model_metadata.get("tags"):
|
|
||||||
metadata.tags = model_metadata.get("tags", [])
|
|
||||||
if model_metadata.get("description"):
|
|
||||||
metadata.modelDescription = model_metadata.get("description", "")
|
|
||||||
if model_metadata.get("creator"):
|
|
||||||
metadata.civitai["creator"] = model_metadata.get("creator")
|
|
||||||
|
|
||||||
# 6. Start download process
|
# 6. Start download process
|
||||||
result = await self._execute_download(
|
result = await self._execute_download(
|
||||||
download_url=file_info.get('downloadUrl', ''),
|
download_url=file_info.get('downloadUrl', ''),
|
||||||
|
|||||||
@@ -587,15 +587,14 @@ class ModelRouteUtils:
|
|||||||
})
|
})
|
||||||
|
|
||||||
# Check which identifier is provided
|
# Check which identifier is provided
|
||||||
download_url = data.get('download_url')
|
model_id = data.get('model_id')
|
||||||
model_hash = data.get('model_hash')
|
|
||||||
model_version_id = data.get('model_version_id')
|
model_version_id = data.get('model_version_id')
|
||||||
|
|
||||||
# Validate that at least one identifier is provided
|
# Only model_id is required, model_version_id is optional
|
||||||
if not any([download_url, model_hash, model_version_id]):
|
if not model_id:
|
||||||
return web.Response(
|
return web.Response(
|
||||||
status=400,
|
status=400,
|
||||||
text="Missing required parameter: Please provide either 'download_url', 'hash', or 'modelVersionId'"
|
text="Missing required parameter: Please provide 'model_id'"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Use the correct root directory based on model type
|
# Use the correct root directory based on model type
|
||||||
@@ -603,8 +602,7 @@ class ModelRouteUtils:
|
|||||||
save_dir = data.get(root_key)
|
save_dir = data.get(root_key)
|
||||||
|
|
||||||
result = await download_manager.download_from_civitai(
|
result = await download_manager.download_from_civitai(
|
||||||
download_url=download_url,
|
model_id=model_id,
|
||||||
model_hash=model_hash,
|
|
||||||
model_version_id=model_version_id,
|
model_version_id=model_version_id,
|
||||||
save_dir=save_dir,
|
save_dir=save_dir,
|
||||||
relative_path=data.get('relative_path', ''),
|
relative_path=data.get('relative_path', ''),
|
||||||
|
|||||||
@@ -61,6 +61,7 @@ export class CheckpointDownloadManager {
|
|||||||
this.currentVersion = null;
|
this.currentVersion = null;
|
||||||
this.versions = [];
|
this.versions = [];
|
||||||
this.modelInfo = null;
|
this.modelInfo = null;
|
||||||
|
this.modelId = null;
|
||||||
this.modelVersionId = null;
|
this.modelVersionId = null;
|
||||||
|
|
||||||
// Clear selected folder and remove selection from UI
|
// Clear selected folder and remove selection from UI
|
||||||
@@ -79,12 +80,12 @@ export class CheckpointDownloadManager {
|
|||||||
try {
|
try {
|
||||||
this.loadingManager.showSimpleLoading('Fetching model versions...');
|
this.loadingManager.showSimpleLoading('Fetching model versions...');
|
||||||
|
|
||||||
const modelId = this.extractModelId(url);
|
this.modelId = this.extractModelId(url);
|
||||||
if (!modelId) {
|
if (!this.modelId) {
|
||||||
throw new Error('Invalid Civitai URL format');
|
throw new Error('Invalid Civitai URL format');
|
||||||
}
|
}
|
||||||
|
|
||||||
const response = await fetch(`/api/checkpoints/civitai/versions/${modelId}`);
|
const response = await fetch(`/api/checkpoints/civitai/versions/${this.modelId}`);
|
||||||
if (!response.ok) {
|
if (!response.ok) {
|
||||||
const errorData = await response.json().catch(() => ({}));
|
const errorData = await response.json().catch(() => ({}));
|
||||||
if (errorData && errorData.error && errorData.error.includes('Model type mismatch')) {
|
if (errorData && errorData.error && errorData.error.includes('Model type mismatch')) {
|
||||||
@@ -296,11 +297,6 @@ export class CheckpointDownloadManager {
|
|||||||
}
|
}
|
||||||
|
|
||||||
try {
|
try {
|
||||||
const downloadUrl = this.currentVersion.downloadUrl;
|
|
||||||
if (!downloadUrl) {
|
|
||||||
throw new Error('No download URL available');
|
|
||||||
}
|
|
||||||
|
|
||||||
// Show enhanced loading with progress details
|
// Show enhanced loading with progress details
|
||||||
const updateProgress = this.loadingManager.showDownloadProgress(1);
|
const updateProgress = this.loadingManager.showDownloadProgress(1);
|
||||||
updateProgress(0, 0, this.currentVersion.name);
|
updateProgress(0, 0, this.currentVersion.name);
|
||||||
@@ -338,7 +334,8 @@ export class CheckpointDownloadManager {
|
|||||||
method: 'POST',
|
method: 'POST',
|
||||||
headers: { 'Content-Type': 'application/json' },
|
headers: { 'Content-Type': 'application/json' },
|
||||||
body: JSON.stringify({
|
body: JSON.stringify({
|
||||||
download_url: downloadUrl,
|
model_id: this.modelId,
|
||||||
|
model_version_id: this.currentVersion.id,
|
||||||
checkpoint_root: checkpointRoot,
|
checkpoint_root: checkpointRoot,
|
||||||
relative_path: targetFolder
|
relative_path: targetFolder
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -63,6 +63,7 @@ export class DownloadManager {
|
|||||||
this.currentVersion = null;
|
this.currentVersion = null;
|
||||||
this.versions = [];
|
this.versions = [];
|
||||||
this.modelInfo = null;
|
this.modelInfo = null;
|
||||||
|
this.modelId = null;
|
||||||
this.modelVersionId = null;
|
this.modelVersionId = null;
|
||||||
|
|
||||||
// Clear selected folder and remove selection from UI
|
// Clear selected folder and remove selection from UI
|
||||||
@@ -81,12 +82,12 @@ export class DownloadManager {
|
|||||||
try {
|
try {
|
||||||
this.loadingManager.showSimpleLoading('Fetching model versions...');
|
this.loadingManager.showSimpleLoading('Fetching model versions...');
|
||||||
|
|
||||||
const modelId = this.extractModelId(url);
|
this.modelId = this.extractModelId(url);
|
||||||
if (!modelId) {
|
if (!this.modelId) {
|
||||||
throw new Error('Invalid Civitai URL format');
|
throw new Error('Invalid Civitai URL format');
|
||||||
}
|
}
|
||||||
|
|
||||||
const response = await fetch(`/api/civitai/versions/${modelId}`);
|
const response = await fetch(`/api/civitai/versions/${this.modelId}`);
|
||||||
if (!response.ok) {
|
if (!response.ok) {
|
||||||
const errorData = await response.json().catch(() => ({}));
|
const errorData = await response.json().catch(() => ({}));
|
||||||
if (errorData && errorData.error && errorData.error.includes('Model type mismatch')) {
|
if (errorData && errorData.error && errorData.error.includes('Model type mismatch')) {
|
||||||
@@ -306,11 +307,6 @@ export class DownloadManager {
|
|||||||
}
|
}
|
||||||
|
|
||||||
try {
|
try {
|
||||||
const downloadUrl = this.currentVersion.downloadUrl;
|
|
||||||
if (!downloadUrl) {
|
|
||||||
throw new Error('No download URL available');
|
|
||||||
}
|
|
||||||
|
|
||||||
// Show enhanced loading with progress details
|
// Show enhanced loading with progress details
|
||||||
const updateProgress = this.loadingManager.showDownloadProgress(1);
|
const updateProgress = this.loadingManager.showDownloadProgress(1);
|
||||||
updateProgress(0, 0, this.currentVersion.name);
|
updateProgress(0, 0, this.currentVersion.name);
|
||||||
@@ -348,7 +344,8 @@ export class DownloadManager {
|
|||||||
method: 'POST',
|
method: 'POST',
|
||||||
headers: { 'Content-Type': 'application/json' },
|
headers: { 'Content-Type': 'application/json' },
|
||||||
body: JSON.stringify({
|
body: JSON.stringify({
|
||||||
download_url: downloadUrl,
|
model_id: this.modelId,
|
||||||
|
model_version_id: this.currentVersion.id,
|
||||||
lora_root: loraRoot,
|
lora_root: loraRoot,
|
||||||
relative_path: targetFolder
|
relative_path: targetFolder
|
||||||
})
|
})
|
||||||
|
|||||||
Reference in New Issue
Block a user