mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-25 07:05:43 -03:00
refactor: Implement download management routes and update API endpoints for LoRA
This commit is contained in:
@@ -1,4 +1,5 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
import asyncio
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
from aiohttp import web
|
from aiohttp import web
|
||||||
@@ -7,6 +8,7 @@ from typing import Dict
|
|||||||
import jinja2
|
import jinja2
|
||||||
|
|
||||||
from ..utils.routes_common import ModelRouteUtils
|
from ..utils.routes_common import ModelRouteUtils
|
||||||
|
from ..services.service_registry import ServiceRegistry
|
||||||
from ..services.websocket_manager import ws_manager
|
from ..services.websocket_manager import ws_manager
|
||||||
from ..services.settings_manager import settings
|
from ..services.settings_manager import settings
|
||||||
from ..config import config
|
from ..config import config
|
||||||
@@ -56,6 +58,12 @@ class BaseModelRoutes(ABC):
|
|||||||
app.router.add_get(f'/api/{prefix}/find-duplicates', self.find_duplicate_models)
|
app.router.add_get(f'/api/{prefix}/find-duplicates', self.find_duplicate_models)
|
||||||
app.router.add_get(f'/api/{prefix}/find-filename-conflicts', self.find_filename_conflicts)
|
app.router.add_get(f'/api/{prefix}/find-filename-conflicts', self.find_filename_conflicts)
|
||||||
|
|
||||||
|
# Common Download management
|
||||||
|
app.router.add_post(f'/api/download-model', self.download_model)
|
||||||
|
app.router.add_get(f'/api/download-model-get', self.download_model_get)
|
||||||
|
app.router.add_get(f'/api/cancel-download-get', self.cancel_download_get)
|
||||||
|
app.router.add_get(f'/api/download-progress/{{download_id}}', self.get_download_progress)
|
||||||
|
|
||||||
# CivitAI integration routes
|
# CivitAI integration routes
|
||||||
app.router.add_post(f'/api/{prefix}/fetch-all-civitai', self.fetch_all_civitai)
|
app.router.add_post(f'/api/{prefix}/fetch-all-civitai', self.fetch_all_civitai)
|
||||||
app.router.add_get(f'/api/civitai/versions/{{model_id}}', self.get_civitai_versions)
|
app.router.add_get(f'/api/civitai/versions/{{model_id}}', self.get_civitai_versions)
|
||||||
@@ -409,6 +417,111 @@ class BaseModelRoutes(ABC):
|
|||||||
"error": str(e)
|
"error": str(e)
|
||||||
}, status=500)
|
}, status=500)
|
||||||
|
|
||||||
|
# Download management methods
|
||||||
|
async def download_model(self, request: web.Request) -> web.Response:
|
||||||
|
"""Handle model download request"""
|
||||||
|
return await ModelRouteUtils.handle_download_model(request)
|
||||||
|
|
||||||
|
async def download_model_get(self, request: web.Request) -> web.Response:
|
||||||
|
"""Handle model download request via GET method"""
|
||||||
|
try:
|
||||||
|
# Extract query parameters
|
||||||
|
model_id = request.query.get('model_id')
|
||||||
|
if not model_id:
|
||||||
|
return web.Response(
|
||||||
|
status=400,
|
||||||
|
text="Missing required parameter: Please provide 'model_id'"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get optional parameters
|
||||||
|
model_version_id = request.query.get('model_version_id')
|
||||||
|
download_id = request.query.get('download_id')
|
||||||
|
use_default_paths = request.query.get('use_default_paths', 'false').lower() == 'true'
|
||||||
|
|
||||||
|
# Create a data dictionary that mimics what would be received from a POST request
|
||||||
|
data = {
|
||||||
|
'model_id': model_id
|
||||||
|
}
|
||||||
|
|
||||||
|
# Add optional parameters only if they are provided
|
||||||
|
if model_version_id:
|
||||||
|
data['model_version_id'] = model_version_id
|
||||||
|
|
||||||
|
if download_id:
|
||||||
|
data['download_id'] = download_id
|
||||||
|
|
||||||
|
data['use_default_paths'] = use_default_paths
|
||||||
|
|
||||||
|
# Create a mock request object with the data
|
||||||
|
future = asyncio.get_event_loop().create_future()
|
||||||
|
future.set_result(data)
|
||||||
|
|
||||||
|
mock_request = type('MockRequest', (), {
|
||||||
|
'json': lambda self=None: future
|
||||||
|
})()
|
||||||
|
|
||||||
|
# Call the existing download handler
|
||||||
|
return await ModelRouteUtils.handle_download_model(mock_request)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
error_message = str(e)
|
||||||
|
logger.error(f"Error downloading model via GET: {error_message}", exc_info=True)
|
||||||
|
return web.Response(status=500, text=error_message)
|
||||||
|
|
||||||
|
async def cancel_download_get(self, request: web.Request) -> web.Response:
|
||||||
|
"""Handle GET request for cancelling a download by download_id"""
|
||||||
|
try:
|
||||||
|
download_id = request.query.get('download_id')
|
||||||
|
if not download_id:
|
||||||
|
return web.json_response({
|
||||||
|
'success': False,
|
||||||
|
'error': 'Download ID is required'
|
||||||
|
}, status=400)
|
||||||
|
|
||||||
|
# Create a mock request with match_info for compatibility
|
||||||
|
mock_request = type('MockRequest', (), {
|
||||||
|
'match_info': {'download_id': download_id}
|
||||||
|
})()
|
||||||
|
return await ModelRouteUtils.handle_cancel_download(mock_request)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error cancelling download via GET: {e}", exc_info=True)
|
||||||
|
return web.json_response({
|
||||||
|
'success': False,
|
||||||
|
'error': str(e)
|
||||||
|
}, status=500)
|
||||||
|
|
||||||
|
async def get_download_progress(self, request: web.Request) -> web.Response:
|
||||||
|
"""Handle request for download progress by download_id"""
|
||||||
|
try:
|
||||||
|
# Get download_id from URL path
|
||||||
|
download_id = request.match_info.get('download_id')
|
||||||
|
if not download_id:
|
||||||
|
return web.json_response({
|
||||||
|
'success': False,
|
||||||
|
'error': 'Download ID is required'
|
||||||
|
}, status=400)
|
||||||
|
|
||||||
|
# Get progress information from websocket manager
|
||||||
|
from ..services.websocket_manager import ws_manager
|
||||||
|
progress_data = ws_manager.get_download_progress(download_id)
|
||||||
|
|
||||||
|
if progress_data is None:
|
||||||
|
return web.json_response({
|
||||||
|
'success': False,
|
||||||
|
'error': 'Download ID not found'
|
||||||
|
}, status=404)
|
||||||
|
|
||||||
|
return web.json_response({
|
||||||
|
'success': True,
|
||||||
|
'progress': progress_data.get('progress', 0)
|
||||||
|
})
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error getting download progress: {e}", exc_info=True)
|
||||||
|
return web.json_response({
|
||||||
|
'success': False,
|
||||||
|
'error': str(e)
|
||||||
|
}, status=500)
|
||||||
|
|
||||||
async def fetch_all_civitai(self, request: web.Request) -> web.Response:
|
async def fetch_all_civitai(self, request: web.Request) -> web.Response:
|
||||||
"""Fetch CivitAI metadata for all models in the background"""
|
"""Fetch CivitAI metadata for all models in the background"""
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -20,8 +20,6 @@ class LoraRoutes(BaseModelRoutes):
|
|||||||
# Service will be initialized later via setup_routes
|
# Service will be initialized later via setup_routes
|
||||||
self.service = None
|
self.service = None
|
||||||
self.civitai_client = None
|
self.civitai_client = None
|
||||||
self.download_manager = None
|
|
||||||
self._download_lock = asyncio.Lock()
|
|
||||||
self.template_name = "loras.html"
|
self.template_name = "loras.html"
|
||||||
|
|
||||||
async def initialize_services(self):
|
async def initialize_services(self):
|
||||||
@@ -29,7 +27,6 @@ class LoraRoutes(BaseModelRoutes):
|
|||||||
lora_scanner = await ServiceRegistry.get_lora_scanner()
|
lora_scanner = await ServiceRegistry.get_lora_scanner()
|
||||||
self.service = LoraService(lora_scanner)
|
self.service = LoraService(lora_scanner)
|
||||||
self.civitai_client = await ServiceRegistry.get_civitai_client()
|
self.civitai_client = await ServiceRegistry.get_civitai_client()
|
||||||
self.download_manager = await ServiceRegistry.get_download_manager()
|
|
||||||
|
|
||||||
# Initialize parent with the service
|
# Initialize parent with the service
|
||||||
super().__init__(self.service)
|
super().__init__(self.service)
|
||||||
@@ -63,22 +60,9 @@ class LoraRoutes(BaseModelRoutes):
|
|||||||
app.router.add_get(f'/api/civitai/model/version/{{modelVersionId}}', self.get_civitai_model_by_version)
|
app.router.add_get(f'/api/civitai/model/version/{{modelVersionId}}', self.get_civitai_model_by_version)
|
||||||
app.router.add_get(f'/api/civitai/model/hash/{{hash}}', self.get_civitai_model_by_hash)
|
app.router.add_get(f'/api/civitai/model/hash/{{hash}}', self.get_civitai_model_by_hash)
|
||||||
|
|
||||||
# Download management
|
|
||||||
app.router.add_post(f'/api/download-model', self.download_model)
|
|
||||||
app.router.add_get(f'/api/download-model-get', self.download_model_get)
|
|
||||||
app.router.add_get(f'/api/cancel-download-get', self.cancel_download_get)
|
|
||||||
app.router.add_get(f'/api/download-progress/{{download_id}}', self.get_download_progress)
|
|
||||||
|
|
||||||
# ComfyUI integration
|
# ComfyUI integration
|
||||||
app.router.add_post(f'/loramanager/get_trigger_words', self.get_trigger_words)
|
app.router.add_post(f'/loramanager/get_trigger_words', self.get_trigger_words)
|
||||||
|
|
||||||
# Legacy API compatibility
|
|
||||||
app.router.add_post(f'/api/delete_model', self.delete_model)
|
|
||||||
app.router.add_post(f'/api/fetch-civitai', self.fetch_civitai)
|
|
||||||
app.router.add_post(f'/api/relink-civitai', self.relink_civitai)
|
|
||||||
app.router.add_post(f'/api/replace_preview', self.replace_preview)
|
|
||||||
app.router.add_post(f'/api/fetch-all-civitai', self.fetch_all_civitai)
|
|
||||||
|
|
||||||
def _parse_specific_params(self, request: web.Request) -> Dict:
|
def _parse_specific_params(self, request: web.Request) -> Dict:
|
||||||
"""Parse LoRA-specific parameters"""
|
"""Parse LoRA-specific parameters"""
|
||||||
params = {}
|
params = {}
|
||||||
@@ -358,111 +342,6 @@ class LoraRoutes(BaseModelRoutes):
|
|||||||
"error": str(e)
|
"error": str(e)
|
||||||
}, status=500)
|
}, status=500)
|
||||||
|
|
||||||
# Download management methods
|
|
||||||
async def download_model(self, request: web.Request) -> web.Response:
|
|
||||||
"""Handle model download request"""
|
|
||||||
return await ModelRouteUtils.handle_download_model(request, self.download_manager)
|
|
||||||
|
|
||||||
async def download_model_get(self, request: web.Request) -> web.Response:
|
|
||||||
"""Handle model download request via GET method"""
|
|
||||||
try:
|
|
||||||
# Extract query parameters
|
|
||||||
model_id = request.query.get('model_id')
|
|
||||||
if not model_id:
|
|
||||||
return web.Response(
|
|
||||||
status=400,
|
|
||||||
text="Missing required parameter: Please provide 'model_id'"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Get optional parameters
|
|
||||||
model_version_id = request.query.get('model_version_id')
|
|
||||||
download_id = request.query.get('download_id')
|
|
||||||
use_default_paths = request.query.get('use_default_paths', 'false').lower() == 'true'
|
|
||||||
|
|
||||||
# Create a data dictionary that mimics what would be received from a POST request
|
|
||||||
data = {
|
|
||||||
'model_id': model_id
|
|
||||||
}
|
|
||||||
|
|
||||||
# Add optional parameters only if they are provided
|
|
||||||
if model_version_id:
|
|
||||||
data['model_version_id'] = model_version_id
|
|
||||||
|
|
||||||
if download_id:
|
|
||||||
data['download_id'] = download_id
|
|
||||||
|
|
||||||
data['use_default_paths'] = use_default_paths
|
|
||||||
|
|
||||||
# Create a mock request object with the data
|
|
||||||
future = asyncio.get_event_loop().create_future()
|
|
||||||
future.set_result(data)
|
|
||||||
|
|
||||||
mock_request = type('MockRequest', (), {
|
|
||||||
'json': lambda self=None: future
|
|
||||||
})()
|
|
||||||
|
|
||||||
# Call the existing download handler
|
|
||||||
return await ModelRouteUtils.handle_download_model(mock_request, self.download_manager)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
error_message = str(e)
|
|
||||||
logger.error(f"Error downloading model via GET: {error_message}", exc_info=True)
|
|
||||||
return web.Response(status=500, text=error_message)
|
|
||||||
|
|
||||||
async def cancel_download_get(self, request: web.Request) -> web.Response:
|
|
||||||
"""Handle GET request for cancelling a download by download_id"""
|
|
||||||
try:
|
|
||||||
download_id = request.query.get('download_id')
|
|
||||||
if not download_id:
|
|
||||||
return web.json_response({
|
|
||||||
'success': False,
|
|
||||||
'error': 'Download ID is required'
|
|
||||||
}, status=400)
|
|
||||||
|
|
||||||
# Create a mock request with match_info for compatibility
|
|
||||||
mock_request = type('MockRequest', (), {
|
|
||||||
'match_info': {'download_id': download_id}
|
|
||||||
})()
|
|
||||||
return await ModelRouteUtils.handle_cancel_download(mock_request, self.download_manager)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error cancelling download via GET: {e}", exc_info=True)
|
|
||||||
return web.json_response({
|
|
||||||
'success': False,
|
|
||||||
'error': str(e)
|
|
||||||
}, status=500)
|
|
||||||
|
|
||||||
async def get_download_progress(self, request: web.Request) -> web.Response:
|
|
||||||
"""Handle request for download progress by download_id"""
|
|
||||||
try:
|
|
||||||
# Get download_id from URL path
|
|
||||||
download_id = request.match_info.get('download_id')
|
|
||||||
if not download_id:
|
|
||||||
return web.json_response({
|
|
||||||
'success': False,
|
|
||||||
'error': 'Download ID is required'
|
|
||||||
}, status=400)
|
|
||||||
|
|
||||||
# Get progress information from websocket manager
|
|
||||||
from ..services.websocket_manager import ws_manager
|
|
||||||
progress_data = ws_manager.get_download_progress(download_id)
|
|
||||||
|
|
||||||
if progress_data is None:
|
|
||||||
return web.json_response({
|
|
||||||
'success': False,
|
|
||||||
'error': 'Download ID not found'
|
|
||||||
}, status=404)
|
|
||||||
|
|
||||||
return web.json_response({
|
|
||||||
'success': True,
|
|
||||||
'progress': progress_data.get('progress', 0)
|
|
||||||
})
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error getting download progress: {e}", exc_info=True)
|
|
||||||
return web.json_response({
|
|
||||||
'success': False,
|
|
||||||
'error': str(e)
|
|
||||||
}, status=500)
|
|
||||||
|
|
||||||
# Model management methods
|
# Model management methods
|
||||||
async def move_model(self, request: web.Request) -> web.Response:
|
async def move_model(self, request: web.Request) -> web.Response:
|
||||||
"""Handle model move request"""
|
"""Handle model move request"""
|
||||||
|
|||||||
@@ -566,9 +566,10 @@ class ModelRouteUtils:
|
|||||||
return web.Response(text=str(e), status=500)
|
return web.Response(text=str(e), status=500)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def handle_download_model(request: web.Request, download_manager: DownloadManager) -> web.Response:
|
async def handle_download_model(request: web.Request) -> web.Response:
|
||||||
"""Handle model download request"""
|
"""Handle model download request"""
|
||||||
try:
|
try:
|
||||||
|
download_manager = await ServiceRegistry.get_download_manager()
|
||||||
data = await request.json()
|
data = await request.json()
|
||||||
|
|
||||||
# Get or generate a download ID
|
# Get or generate a download ID
|
||||||
@@ -663,17 +664,17 @@ class ModelRouteUtils:
|
|||||||
}, status=500)
|
}, status=500)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def handle_cancel_download(request: web.Request, download_manager: DownloadManager) -> web.Response:
|
async def handle_cancel_download(request: web.Request) -> web.Response:
|
||||||
"""Handle cancellation of a download task
|
"""Handle cancellation of a download task
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
request: The aiohttp request
|
request: The aiohttp request
|
||||||
download_manager: The download manager instance
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
web.Response: The HTTP response
|
web.Response: The HTTP response
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
|
download_manager = await ServiceRegistry.get_download_manager()
|
||||||
download_id = request.match_info.get('download_id')
|
download_id = request.match_info.get('download_id')
|
||||||
if not download_id:
|
if not download_id:
|
||||||
return web.json_response({
|
return web.json_response({
|
||||||
@@ -701,17 +702,17 @@ class ModelRouteUtils:
|
|||||||
}, status=500)
|
}, status=500)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def handle_list_downloads(request: web.Request, download_manager: DownloadManager) -> web.Response:
|
async def handle_list_downloads(request: web.Request) -> web.Response:
|
||||||
"""Get list of active downloads
|
"""Get list of active downloads
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
request: The aiohttp request
|
request: The aiohttp request
|
||||||
download_manager: The download manager instance
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
web.Response: The HTTP response with list of downloads
|
web.Response: The HTTP response with list of downloads
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
|
download_manager = await ServiceRegistry.get_download_manager()
|
||||||
result = await download_manager.get_active_downloads()
|
result = await download_manager.get_active_downloads()
|
||||||
return web.json_response(result)
|
return web.json_response(result)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
@@ -277,7 +277,7 @@ export async function deleteModel(filePath, modelType = 'lora') {
|
|||||||
|
|
||||||
const endpoint = modelType === 'checkpoint'
|
const endpoint = modelType === 'checkpoint'
|
||||||
? '/api/checkpoints/delete'
|
? '/api/checkpoints/delete'
|
||||||
: '/api/delete_model';
|
: '/api/loras/delete';
|
||||||
|
|
||||||
const response = await fetch(endpoint, {
|
const response = await fetch(endpoint, {
|
||||||
method: 'POST',
|
method: 'POST',
|
||||||
@@ -454,7 +454,7 @@ export async function refreshSingleModelMetadata(filePath, modelType = 'lora') {
|
|||||||
|
|
||||||
const endpoint = modelType === 'checkpoint'
|
const endpoint = modelType === 'checkpoint'
|
||||||
? '/api/checkpoints/fetch-civitai'
|
? '/api/checkpoints/fetch-civitai'
|
||||||
: '/api/fetch-civitai';
|
: '/api/loras/fetch-civitai';
|
||||||
|
|
||||||
const response = await fetch(endpoint, {
|
const response = await fetch(endpoint, {
|
||||||
method: 'POST',
|
method: 'POST',
|
||||||
@@ -557,7 +557,7 @@ export async function uploadPreview(filePath, file, modelType = 'lora', nsfwLeve
|
|||||||
// Set endpoint based on model type
|
// Set endpoint based on model type
|
||||||
const endpoint = modelType === 'checkpoint'
|
const endpoint = modelType === 'checkpoint'
|
||||||
? '/api/checkpoints/replace-preview'
|
? '/api/checkpoints/replace-preview'
|
||||||
: '/api/replace_preview';
|
: '/api/loras/replace_preview';
|
||||||
|
|
||||||
const response = await fetch(endpoint, {
|
const response = await fetch(endpoint, {
|
||||||
method: 'POST',
|
method: 'POST',
|
||||||
|
|||||||
@@ -89,7 +89,7 @@ export async function fetchLorasPage(page = 1, pageSize = 100) {
|
|||||||
export async function fetchCivitai() {
|
export async function fetchCivitai() {
|
||||||
return fetchCivitaiMetadata({
|
return fetchCivitaiMetadata({
|
||||||
modelType: 'lora',
|
modelType: 'lora',
|
||||||
fetchEndpoint: '/api/fetch-all-civitai',
|
fetchEndpoint: '/api/loras/fetch-all-civitai',
|
||||||
resetAndReloadFunction: resetAndReload
|
resetAndReloadFunction: resetAndReload
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -125,7 +125,7 @@ export const ModelContextMenuMixin = {
|
|||||||
|
|
||||||
const endpoint = this.modelType === 'checkpoint' ?
|
const endpoint = this.modelType === 'checkpoint' ?
|
||||||
'/api/checkpoints/relink-civitai' :
|
'/api/checkpoints/relink-civitai' :
|
||||||
'/api/relink-civitai';
|
'/api/loras/relink-civitai';
|
||||||
|
|
||||||
const response = await fetch(endpoint, {
|
const response = await fetch(endpoint, {
|
||||||
method: 'POST',
|
method: 'POST',
|
||||||
|
|||||||
Reference in New Issue
Block a user