feat: Add WebSocket support for checkpoint download progress and update related components

This commit is contained in:
Will Miao
2025-04-13 21:31:01 +08:00
parent 9822f2c614
commit 76fc9e5a3d
3 changed files with 95 additions and 8 deletions

View File

@@ -53,6 +53,9 @@ class CheckpointsRoutes:
app.router.add_post('/api/checkpoints/replace-preview', self.replace_preview)
app.router.add_post('/api/checkpoints/download', self.download_checkpoint)
app.router.add_post('/api/checkpoints/save-metadata', self.save_metadata) # Add new route
# Add new WebSocket endpoint for checkpoint progress
app.router.add_get('/ws/checkpoint-progress', ws_manager.handle_checkpoint_connection)
async def get_checkpoints(self, request):
"""Get paginated checkpoint data"""
@@ -501,12 +504,66 @@ class CheckpointsRoutes:
if self.download_manager is None:
self.download_manager = await ServiceRegistry.get_download_manager()
# Use the common download handler with model_type="checkpoint"
return await ModelRouteUtils.handle_download_model(
request=request,
download_manager=self.download_manager,
model_type="checkpoint"
)
try:
data = await request.json()
# Create progress callback that uses checkpoint-specific WebSocket
async def progress_callback(progress):
await ws_manager.broadcast_checkpoint_progress({
'status': 'progress',
'progress': progress
})
# Check which identifier is provided
download_url = data.get('download_url')
model_hash = data.get('model_hash')
model_version_id = data.get('model_version_id')
# Validate that at least one identifier is provided
if not any([download_url, model_hash, model_version_id]):
return web.Response(
status=400,
text="Missing required parameter: Please provide either 'download_url', 'hash', or 'modelVersionId'"
)
result = await self.download_manager.download_from_civitai(
download_url=download_url,
model_hash=model_hash,
model_version_id=model_version_id,
save_dir=data.get('checkpoint_root'),
relative_path=data.get('relative_path', ''),
progress_callback=progress_callback,
model_type="checkpoint"
)
if not result.get('success', False):
error_message = result.get('error', 'Unknown error')
# Return 401 for early access errors
if 'early access' in error_message.lower():
logger.warning(f"Early access download failed: {error_message}")
return web.Response(
status=401,
text=f"Early Access Restriction: {error_message}"
)
return web.Response(status=500, text=error_message)
return web.json_response(result)
except Exception as e:
error_message = str(e)
# Check if this might be an early access error
if '401' in error_message:
logger.warning(f"Early access error (401): {error_message}")
return web.Response(
status=401,
text="Early Access Restriction: This model requires purchase. Please ensure you have purchased early access and are logged in to Civitai."
)
logger.error(f"Error downloading checkpoint: {error_message}")
return web.Response(status=500, text=error_message)
async def get_checkpoint_roots(self, request):
"""Return the checkpoint root directories"""

View File

@@ -10,6 +10,7 @@ class WebSocketManager:
def __init__(self):
self._websockets: Set[web.WebSocketResponse] = set()
self._init_websockets: Set[web.WebSocketResponse] = set() # New set for initialization progress clients
self._checkpoint_websockets: Set[web.WebSocketResponse] = set() # New set for checkpoint download progress
async def handle_connection(self, request: web.Request) -> web.WebSocketResponse:
"""Handle new WebSocket connection"""
@@ -39,6 +40,20 @@ class WebSocketManager:
self._init_websockets.discard(ws)
return ws
async def handle_checkpoint_connection(self, request: web.Request) -> web.WebSocketResponse:
"""Handle new WebSocket connection for checkpoint download progress"""
ws = web.WebSocketResponse()
await ws.prepare(request)
self._checkpoint_websockets.add(ws)
try:
async for msg in ws:
if msg.type == web.WSMsgType.ERROR:
logger.error(f'Checkpoint WebSocket error: {ws.exception()}')
finally:
self._checkpoint_websockets.discard(ws)
return ws
async def broadcast(self, data: Dict):
"""Broadcast message to all connected clients"""
if not self._websockets:
@@ -68,6 +83,17 @@ class WebSocketManager:
await ws.send_json(data)
except Exception as e:
logger.error(f"Error sending initialization progress: {e}")
async def broadcast_checkpoint_progress(self, data: Dict):
"""Broadcast checkpoint download progress to connected clients"""
if not self._checkpoint_websockets:
return
for ws in self._checkpoint_websockets:
try:
await ws.send_json(data)
except Exception as e:
logger.error(f"Error sending checkpoint progress: {e}")
def get_connected_clients_count(self) -> int:
"""Get number of connected clients"""
@@ -77,5 +103,9 @@ class WebSocketManager:
"""Get number of initialization progress clients"""
return len(self._init_websockets)
def get_checkpoint_clients_count(self) -> int:
"""Get number of checkpoint progress clients"""
return len(self._checkpoint_websockets)
# Global instance
ws_manager = WebSocketManager()

View File

@@ -292,9 +292,9 @@ export class CheckpointDownloadManager {
const updateProgress = this.loadingManager.showDownloadProgress(1);
updateProgress(0, 0, this.currentVersion.name);
// Setup WebSocket for progress updates
// Setup WebSocket for progress updates using checkpoint-specific endpoint
const wsProtocol = window.location.protocol === 'https:' ? 'wss://' : 'ws://';
const ws = new WebSocket(`${wsProtocol}${window.location.host}/ws/fetch-progress`);
const ws = new WebSocket(`${wsProtocol}${window.location.host}/ws/checkpoint-progress`);
ws.onmessage = (event) => {
const data = JSON.parse(event.data);