feat: Rename download endpoint from /api/download-lora to /api/download-model and update related logic

This commit is contained in:
Will Miao
2025-07-02 19:21:25 +08:00
parent 971cd56a4a
commit ab22d16bad
8 changed files with 25 additions and 25 deletions

View File

@@ -56,7 +56,7 @@ class ApiRoutes:
app.router.add_get('/api/civitai/versions/{model_id}', routes.get_civitai_versions) app.router.add_get('/api/civitai/versions/{model_id}', routes.get_civitai_versions)
app.router.add_get('/api/civitai/model/version/{modelVersionId}', routes.get_civitai_model_by_version) app.router.add_get('/api/civitai/model/version/{modelVersionId}', routes.get_civitai_model_by_version)
app.router.add_get('/api/civitai/model/hash/{hash}', routes.get_civitai_model_by_hash) app.router.add_get('/api/civitai/model/hash/{hash}', routes.get_civitai_model_by_hash)
app.router.add_post('/api/download-lora', routes.download_lora) app.router.add_post('/api/download-model', routes.download_model)
app.router.add_post('/api/move_model', routes.move_model) app.router.add_post('/api/move_model', routes.move_model)
app.router.add_get('/api/lora-model-description', routes.get_lora_model_description) # Add new route app.router.add_get('/api/lora-model-description', routes.get_lora_model_description) # Add new route
app.router.add_post('/api/loras/save-metadata', routes.save_metadata) app.router.add_post('/api/loras/save-metadata', routes.save_metadata)
@@ -436,8 +436,8 @@ class ApiRoutes:
"error": str(e) "error": str(e)
}, status=500) }, status=500)
async def download_lora(self, request: web.Request) -> web.Response: async def download_model(self, request: web.Request) -> web.Response:
return await ModelRouteUtils.handle_download_model(request, self.download_manager, model_type="lora") return await ModelRouteUtils.handle_download_model(request, self.download_manager)
async def move_model(self, request: web.Request) -> web.Response: async def move_model(self, request: web.Request) -> web.Response:

View File

@@ -54,7 +54,6 @@ class CheckpointsRoutes:
app.router.add_post('/api/checkpoints/fetch-civitai', self.fetch_civitai) app.router.add_post('/api/checkpoints/fetch-civitai', self.fetch_civitai)
app.router.add_post('/api/checkpoints/relink-civitai', self.relink_civitai) # Add new relink endpoint app.router.add_post('/api/checkpoints/relink-civitai', self.relink_civitai) # Add new relink endpoint
app.router.add_post('/api/checkpoints/replace-preview', self.replace_preview) app.router.add_post('/api/checkpoints/replace-preview', self.replace_preview)
app.router.add_post('/api/checkpoints/download', self.download_checkpoint)
app.router.add_post('/api/checkpoints/save-metadata', self.save_metadata) # Add new route app.router.add_post('/api/checkpoints/save-metadata', self.save_metadata) # Add new route
app.router.add_post('/api/checkpoints/rename', self.rename_checkpoint) # Add new rename endpoint app.router.add_post('/api/checkpoints/rename', self.rename_checkpoint) # Add new rename endpoint
@@ -539,10 +538,6 @@ class CheckpointsRoutes:
"""Handle preview image replacement for checkpoints""" """Handle preview image replacement for checkpoints"""
return await ModelRouteUtils.handle_replace_preview(request, self.scanner) return await ModelRouteUtils.handle_replace_preview(request, self.scanner)
async def download_checkpoint(self, request: web.Request) -> web.Response:
"""Handle checkpoint download request"""
return await ModelRouteUtils.handle_download_model(request, self.download_manager, model_type="checkpoint")
async def get_checkpoint_roots(self, request): async def get_checkpoint_roots(self, request):
"""Return the checkpoint root directories""" """Return the checkpoint root directories"""
try: try:

View File

@@ -1,10 +1,9 @@
import logging import logging
import os import os
import json
import asyncio import asyncio
from typing import Dict from typing import Dict
from ..utils.models import LoraMetadata, CheckpointMetadata from ..utils.models import LoraMetadata, CheckpointMetadata
from ..utils.constants import CARD_PREVIEW_WIDTH from ..utils.constants import CARD_PREVIEW_WIDTH, VALID_LORA_TYPES
from ..utils.exif_utils import ExifUtils from ..utils.exif_utils import ExifUtils
from ..utils.metadata_manager import MetadataManager from ..utils.metadata_manager import MetadataManager
from .service_registry import ServiceRegistry from .service_registry import ServiceRegistry
@@ -51,7 +50,7 @@ class DownloadManager:
async def download_from_civitai(self, model_id: str = None, async def download_from_civitai(self, model_id: str = None,
model_version_id: str = None, save_dir: str = None, model_version_id: str = None, save_dir: str = None,
relative_path: str = '', progress_callback=None, relative_path: str = '', progress_callback=None,
model_type: str = "lora") -> Dict: model_type: str = None) -> Dict:
"""Download model from Civitai """Download model from Civitai
Args: Args:
@@ -81,6 +80,16 @@ class DownloadManager:
if not version_info: if not version_info:
return {'success': False, 'error': 'Failed to fetch model metadata'} return {'success': False, 'error': 'Failed to fetch model metadata'}
# Infer model_type if not provided
if model_type is None:
model_type_from_info = version_info.get('model', {}).get('type', '').lower()
if model_type_from_info == 'checkpoint':
model_type = 'checkpoint'
elif model_type_from_info in VALID_LORA_TYPES:
model_type = 'lora'
else:
return {'success': False, 'error': f'Model type "{model_type_from_info}" is not supported for download'}
# Check if this is an early access model # Check if this is an early access model
if version_info.get('earlyAccessEndsAt'): if version_info.get('earlyAccessEndsAt'):
early_access_date = version_info.get('earlyAccessEndsAt', '') early_access_date = version_info.get('earlyAccessEndsAt', '')

View File

@@ -564,7 +564,7 @@ class ModelRouteUtils:
return web.Response(text=str(e), status=500) return web.Response(text=str(e), status=500)
@staticmethod @staticmethod
async def handle_download_model(request: web.Request, download_manager: DownloadManager, model_type="lora") -> web.Response: async def handle_download_model(request: web.Request, download_manager: DownloadManager, model_type=None) -> web.Response:
"""Handle model download request """Handle model download request
Args: Args:
@@ -597,14 +597,10 @@ class ModelRouteUtils:
text="Missing required parameter: Please provide 'model_id'" text="Missing required parameter: Please provide 'model_id'"
) )
# Use the correct root directory based on model type
root_key = 'checkpoint_root' if model_type == 'checkpoint' else 'lora_root'
save_dir = data.get(root_key)
result = await download_manager.download_from_civitai( result = await download_manager.download_from_civitai(
model_id=model_id, model_id=model_id,
model_version_id=model_version_id, model_version_id=model_version_id,
save_dir=save_dir, save_dir=data.get('model_root'),
relative_path=data.get('relative_path', ''), relative_path=data.get('relative_path', ''),
progress_callback=progress_callback, progress_callback=progress_callback,
model_type=model_type model_type=model_type
@@ -636,7 +632,7 @@ class ModelRouteUtils:
text="Early Access Restriction: This model requires purchase. Please buy early access on Civitai.com." text="Early Access Restriction: This model requires purchase. Please buy early access on Civitai.com."
) )
logger.error(f"Error downloading {model_type}: {error_message}") logger.error(f"Error downloading model: {error_message}")
return web.Response(status=500, text=error_message) return web.Response(status=500, text=error_message)
@staticmethod @staticmethod

View File

@@ -330,13 +330,13 @@ export class CheckpointDownloadManager {
}; };
// Start download using checkpoint download endpoint // Start download using checkpoint download endpoint
const response = await fetch('/api/checkpoints/download', { const response = await fetch('/api/download-model', {
method: 'POST', method: 'POST',
headers: { 'Content-Type': 'application/json' }, headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({ body: JSON.stringify({
model_id: this.modelId, model_id: this.modelId,
model_version_id: this.currentVersion.id, model_version_id: this.currentVersion.id,
checkpoint_root: checkpointRoot, model_root: checkpointRoot,
relative_path: targetFolder relative_path: targetFolder
}) })
}); });

View File

@@ -340,13 +340,13 @@ export class DownloadManager {
}; };
// Start download // Start download
const response = await fetch('/api/download-lora', { const response = await fetch('/api/download-model', {
method: 'POST', method: 'POST',
headers: { 'Content-Type': 'application/json' }, headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({ body: JSON.stringify({
model_id: this.modelId, model_id: this.modelId,
model_version_id: this.currentVersion.id, model_version_id: this.currentVersion.id,
lora_root: loraRoot, model_root: loraRoot,
relative_path: targetFolder relative_path: targetFolder
}) })
}); });

View File

@@ -189,14 +189,14 @@ export class DownloadManager {
try { try {
// Download the LoRA // Download the LoRA
const response = await fetch('/api/download-lora', { const response = await fetch('/api/download-model', {
method: 'POST', method: 'POST',
headers: { 'Content-Type': 'application/json' }, headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({ body: JSON.stringify({
download_url: lora.downloadUrl, download_url: lora.downloadUrl,
model_version_id: lora.modelVersionId, model_version_id: lora.modelVersionId,
model_hash: lora.hash, model_hash: lora.hash,
lora_root: loraRoot, model_root: loraRoot,
relative_path: targetPath.replace(loraRoot + '/', '') relative_path: targetPath.replace(loraRoot + '/', '')
}) })
}); });

View File

@@ -7,7 +7,7 @@ export const apiRoutes = {
delete: (id) => `/api/loras/${id}`, delete: (id) => `/api/loras/${id}`,
update: (id) => `/api/loras/${id}`, update: (id) => `/api/loras/${id}`,
civitai: (id) => `/api/loras/${id}/civitai`, civitai: (id) => `/api/loras/${id}/civitai`,
download: '/api/download-lora', download: '/api/download-model',
move: '/api/move-lora', move: '/api/move-lora',
scan: '/api/scan-loras' scan: '/api/scan-loras'
}, },