mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-25 15:15:44 -03:00
feat: Add WebSocket support for checkpoint download progress and update related components
This commit is contained in:
@@ -54,6 +54,9 @@ class CheckpointsRoutes:
|
|||||||
app.router.add_post('/api/checkpoints/download', self.download_checkpoint)
|
app.router.add_post('/api/checkpoints/download', self.download_checkpoint)
|
||||||
app.router.add_post('/api/checkpoints/save-metadata', self.save_metadata) # Add new route
|
app.router.add_post('/api/checkpoints/save-metadata', self.save_metadata) # Add new route
|
||||||
|
|
||||||
|
# Add new WebSocket endpoint for checkpoint progress
|
||||||
|
app.router.add_get('/ws/checkpoint-progress', ws_manager.handle_checkpoint_connection)
|
||||||
|
|
||||||
async def get_checkpoints(self, request):
|
async def get_checkpoints(self, request):
|
||||||
"""Get paginated checkpoint data"""
|
"""Get paginated checkpoint data"""
|
||||||
try:
|
try:
|
||||||
@@ -501,13 +504,67 @@ class CheckpointsRoutes:
|
|||||||
if self.download_manager is None:
|
if self.download_manager is None:
|
||||||
self.download_manager = await ServiceRegistry.get_download_manager()
|
self.download_manager = await ServiceRegistry.get_download_manager()
|
||||||
|
|
||||||
# Use the common download handler with model_type="checkpoint"
|
try:
|
||||||
return await ModelRouteUtils.handle_download_model(
|
data = await request.json()
|
||||||
request=request,
|
|
||||||
download_manager=self.download_manager,
|
# Create progress callback that uses checkpoint-specific WebSocket
|
||||||
|
async def progress_callback(progress):
|
||||||
|
await ws_manager.broadcast_checkpoint_progress({
|
||||||
|
'status': 'progress',
|
||||||
|
'progress': progress
|
||||||
|
})
|
||||||
|
|
||||||
|
# Check which identifier is provided
|
||||||
|
download_url = data.get('download_url')
|
||||||
|
model_hash = data.get('model_hash')
|
||||||
|
model_version_id = data.get('model_version_id')
|
||||||
|
|
||||||
|
# Validate that at least one identifier is provided
|
||||||
|
if not any([download_url, model_hash, model_version_id]):
|
||||||
|
return web.Response(
|
||||||
|
status=400,
|
||||||
|
text="Missing required parameter: Please provide either 'download_url', 'hash', or 'modelVersionId'"
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await self.download_manager.download_from_civitai(
|
||||||
|
download_url=download_url,
|
||||||
|
model_hash=model_hash,
|
||||||
|
model_version_id=model_version_id,
|
||||||
|
save_dir=data.get('checkpoint_root'),
|
||||||
|
relative_path=data.get('relative_path', ''),
|
||||||
|
progress_callback=progress_callback,
|
||||||
model_type="checkpoint"
|
model_type="checkpoint"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if not result.get('success', False):
|
||||||
|
error_message = result.get('error', 'Unknown error')
|
||||||
|
|
||||||
|
# Return 401 for early access errors
|
||||||
|
if 'early access' in error_message.lower():
|
||||||
|
logger.warning(f"Early access download failed: {error_message}")
|
||||||
|
return web.Response(
|
||||||
|
status=401,
|
||||||
|
text=f"Early Access Restriction: {error_message}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return web.Response(status=500, text=error_message)
|
||||||
|
|
||||||
|
return web.json_response(result)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
error_message = str(e)
|
||||||
|
|
||||||
|
# Check if this might be an early access error
|
||||||
|
if '401' in error_message:
|
||||||
|
logger.warning(f"Early access error (401): {error_message}")
|
||||||
|
return web.Response(
|
||||||
|
status=401,
|
||||||
|
text="Early Access Restriction: This model requires purchase. Please ensure you have purchased early access and are logged in to Civitai."
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.error(f"Error downloading checkpoint: {error_message}")
|
||||||
|
return web.Response(status=500, text=error_message)
|
||||||
|
|
||||||
async def get_checkpoint_roots(self, request):
|
async def get_checkpoint_roots(self, request):
|
||||||
"""Return the checkpoint root directories"""
|
"""Return the checkpoint root directories"""
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ class WebSocketManager:
|
|||||||
def __init__(self):
|
def __init__(self):
|
||||||
self._websockets: Set[web.WebSocketResponse] = set()
|
self._websockets: Set[web.WebSocketResponse] = set()
|
||||||
self._init_websockets: Set[web.WebSocketResponse] = set() # New set for initialization progress clients
|
self._init_websockets: Set[web.WebSocketResponse] = set() # New set for initialization progress clients
|
||||||
|
self._checkpoint_websockets: Set[web.WebSocketResponse] = set() # New set for checkpoint download progress
|
||||||
|
|
||||||
async def handle_connection(self, request: web.Request) -> web.WebSocketResponse:
|
async def handle_connection(self, request: web.Request) -> web.WebSocketResponse:
|
||||||
"""Handle new WebSocket connection"""
|
"""Handle new WebSocket connection"""
|
||||||
@@ -39,6 +40,20 @@ class WebSocketManager:
|
|||||||
self._init_websockets.discard(ws)
|
self._init_websockets.discard(ws)
|
||||||
return ws
|
return ws
|
||||||
|
|
||||||
|
async def handle_checkpoint_connection(self, request: web.Request) -> web.WebSocketResponse:
|
||||||
|
"""Handle new WebSocket connection for checkpoint download progress"""
|
||||||
|
ws = web.WebSocketResponse()
|
||||||
|
await ws.prepare(request)
|
||||||
|
self._checkpoint_websockets.add(ws)
|
||||||
|
|
||||||
|
try:
|
||||||
|
async for msg in ws:
|
||||||
|
if msg.type == web.WSMsgType.ERROR:
|
||||||
|
logger.error(f'Checkpoint WebSocket error: {ws.exception()}')
|
||||||
|
finally:
|
||||||
|
self._checkpoint_websockets.discard(ws)
|
||||||
|
return ws
|
||||||
|
|
||||||
async def broadcast(self, data: Dict):
|
async def broadcast(self, data: Dict):
|
||||||
"""Broadcast message to all connected clients"""
|
"""Broadcast message to all connected clients"""
|
||||||
if not self._websockets:
|
if not self._websockets:
|
||||||
@@ -69,6 +84,17 @@ class WebSocketManager:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error sending initialization progress: {e}")
|
logger.error(f"Error sending initialization progress: {e}")
|
||||||
|
|
||||||
|
async def broadcast_checkpoint_progress(self, data: Dict):
|
||||||
|
"""Broadcast checkpoint download progress to connected clients"""
|
||||||
|
if not self._checkpoint_websockets:
|
||||||
|
return
|
||||||
|
|
||||||
|
for ws in self._checkpoint_websockets:
|
||||||
|
try:
|
||||||
|
await ws.send_json(data)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error sending checkpoint progress: {e}")
|
||||||
|
|
||||||
def get_connected_clients_count(self) -> int:
|
def get_connected_clients_count(self) -> int:
|
||||||
"""Get number of connected clients"""
|
"""Get number of connected clients"""
|
||||||
return len(self._websockets)
|
return len(self._websockets)
|
||||||
@@ -77,5 +103,9 @@ class WebSocketManager:
|
|||||||
"""Get number of initialization progress clients"""
|
"""Get number of initialization progress clients"""
|
||||||
return len(self._init_websockets)
|
return len(self._init_websockets)
|
||||||
|
|
||||||
|
def get_checkpoint_clients_count(self) -> int:
|
||||||
|
"""Get number of checkpoint progress clients"""
|
||||||
|
return len(self._checkpoint_websockets)
|
||||||
|
|
||||||
# Global instance
|
# Global instance
|
||||||
ws_manager = WebSocketManager()
|
ws_manager = WebSocketManager()
|
||||||
@@ -292,9 +292,9 @@ export class CheckpointDownloadManager {
|
|||||||
const updateProgress = this.loadingManager.showDownloadProgress(1);
|
const updateProgress = this.loadingManager.showDownloadProgress(1);
|
||||||
updateProgress(0, 0, this.currentVersion.name);
|
updateProgress(0, 0, this.currentVersion.name);
|
||||||
|
|
||||||
// Setup WebSocket for progress updates
|
// Setup WebSocket for progress updates using checkpoint-specific endpoint
|
||||||
const wsProtocol = window.location.protocol === 'https:' ? 'wss://' : 'ws://';
|
const wsProtocol = window.location.protocol === 'https:' ? 'wss://' : 'ws://';
|
||||||
const ws = new WebSocket(`${wsProtocol}${window.location.host}/ws/fetch-progress`);
|
const ws = new WebSocket(`${wsProtocol}${window.location.host}/ws/checkpoint-progress`);
|
||||||
|
|
||||||
ws.onmessage = (event) => {
|
ws.onmessage = (event) => {
|
||||||
const data = JSON.parse(event.data);
|
const data = JSON.parse(event.data);
|
||||||
|
|||||||
Reference in New Issue
Block a user