mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-25 07:05:43 -03:00
feat(websocket-manager): implement caching for initialization progress and enhance broadcast functionality
This commit is contained in:
@@ -16,6 +16,8 @@ class WebSocketManager:
|
|||||||
self._download_websockets: Dict[str, web.WebSocketResponse] = {} # New dict for download-specific clients
|
self._download_websockets: Dict[str, web.WebSocketResponse] = {} # New dict for download-specific clients
|
||||||
# Add progress tracking dictionary
|
# Add progress tracking dictionary
|
||||||
self._download_progress: Dict[str, Dict] = {}
|
self._download_progress: Dict[str, Dict] = {}
|
||||||
|
# Cache last initialization progress payloads
|
||||||
|
self._last_init_progress: Dict[str, Dict] = {}
|
||||||
# Add auto-organize progress tracking
|
# Add auto-organize progress tracking
|
||||||
self._auto_organize_progress: Optional[Dict] = None
|
self._auto_organize_progress: Optional[Dict] = None
|
||||||
self._auto_organize_lock = asyncio.Lock()
|
self._auto_organize_lock = asyncio.Lock()
|
||||||
@@ -41,6 +43,8 @@ class WebSocketManager:
|
|||||||
self._init_websockets.add(ws)
|
self._init_websockets.add(ws)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
await self._send_cached_init_progress(ws)
|
||||||
|
|
||||||
async for msg in ws:
|
async for msg in ws:
|
||||||
if msg.type == web.WSMsgType.ERROR:
|
if msg.type == web.WSMsgType.ERROR:
|
||||||
logger.error(f'Init WebSocket error: {ws.exception()}')
|
logger.error(f'Init WebSocket error: {ws.exception()}')
|
||||||
@@ -102,22 +106,52 @@ class WebSocketManager:
|
|||||||
|
|
||||||
async def broadcast_init_progress(self, data: Dict):
|
async def broadcast_init_progress(self, data: Dict):
|
||||||
"""Broadcast initialization progress to connected clients"""
|
"""Broadcast initialization progress to connected clients"""
|
||||||
|
payload = dict(data) if data else {}
|
||||||
|
|
||||||
|
if 'stage' not in payload:
|
||||||
|
payload['stage'] = 'processing'
|
||||||
|
if 'progress' not in payload:
|
||||||
|
payload['progress'] = 0
|
||||||
|
if 'details' not in payload:
|
||||||
|
payload['details'] = 'Processing...'
|
||||||
|
|
||||||
|
key = self._get_init_progress_key(payload)
|
||||||
|
self._last_init_progress[key] = dict(payload)
|
||||||
|
|
||||||
if not self._init_websockets:
|
if not self._init_websockets:
|
||||||
return
|
return
|
||||||
|
|
||||||
# Ensure data has all required fields
|
stale_clients = []
|
||||||
if 'stage' not in data:
|
for ws in list(self._init_websockets):
|
||||||
data['stage'] = 'processing'
|
|
||||||
if 'progress' not in data:
|
|
||||||
data['progress'] = 0
|
|
||||||
if 'details' not in data:
|
|
||||||
data['details'] = 'Processing...'
|
|
||||||
|
|
||||||
for ws in self._init_websockets:
|
|
||||||
try:
|
try:
|
||||||
await ws.send_json(data)
|
await ws.send_json(payload)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error sending initialization progress: {e}")
|
logger.error(f"Error sending initialization progress: {e}")
|
||||||
|
stale_clients.append(ws)
|
||||||
|
|
||||||
|
for ws in stale_clients:
|
||||||
|
self._init_websockets.discard(ws)
|
||||||
|
|
||||||
|
async def _send_cached_init_progress(self, ws: web.WebSocketResponse) -> None:
|
||||||
|
"""Send cached initialization progress payloads to a new client"""
|
||||||
|
if not self._last_init_progress:
|
||||||
|
return
|
||||||
|
|
||||||
|
for payload in list(self._last_init_progress.values()):
|
||||||
|
try:
|
||||||
|
await ws.send_json(payload)
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f'Error sending cached initialization progress: {e}')
|
||||||
|
|
||||||
|
def _get_init_progress_key(self, data: Dict) -> str:
|
||||||
|
"""Return a stable key for caching initialization progress payloads"""
|
||||||
|
page_type = data.get('pageType')
|
||||||
|
if page_type:
|
||||||
|
return f'page:{page_type}'
|
||||||
|
scanner_type = data.get('scanner_type')
|
||||||
|
if scanner_type:
|
||||||
|
return f'scanner:{scanner_type}'
|
||||||
|
return 'global'
|
||||||
|
|
||||||
async def broadcast_download_progress(self, download_id: str, data: Dict):
|
async def broadcast_download_progress(self, download_id: str, data: Dict):
|
||||||
"""Send progress update to specific download client"""
|
"""Send progress update to specific download client"""
|
||||||
@@ -203,3 +237,4 @@ class WebSocketManager:
|
|||||||
|
|
||||||
# Global instance
|
# Global instance
|
||||||
ws_manager = WebSocketManager()
|
ws_manager = WebSocketManager()
|
||||||
|
|
||||||
|
|||||||
@@ -482,13 +482,17 @@ export class SettingsManager {
|
|||||||
librarySelect.disabled = true;
|
librarySelect.disabled = true;
|
||||||
|
|
||||||
try {
|
try {
|
||||||
|
state.loadingManager.showSimpleLoading('Activating library...');
|
||||||
await this.activateLibrary(selectedLibrary);
|
await this.activateLibrary(selectedLibrary);
|
||||||
|
// Add a short delay before reloading the page
|
||||||
|
await new Promise(resolve => setTimeout(resolve, 300));
|
||||||
window.location.reload();
|
window.location.reload();
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
console.error('Failed to activate library:', error);
|
console.error('Failed to activate library:', error);
|
||||||
showToast('toast.settings.libraryActivateFailed', { message: error.message }, 'error');
|
showToast('toast.settings.libraryActivateFailed', { message: error.message }, 'error');
|
||||||
await this.loadLibraries();
|
await this.loadLibraries();
|
||||||
} finally {
|
} finally {
|
||||||
|
state.loadingManager.hide();
|
||||||
if (!document.hidden) {
|
if (!document.hidden) {
|
||||||
librarySelect.disabled = librarySelect.options.length <= 1;
|
librarySelect.disabled = librarySelect.options.length <= 1;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -36,6 +36,25 @@ async def test_broadcast_init_progress_adds_defaults(manager):
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
async def test_broadcast_init_progress_caches_payload(manager):
|
||||||
|
await manager.broadcast_init_progress({'pageType': 'loras', 'progress': 42})
|
||||||
|
|
||||||
|
cached = manager._last_init_progress.get('page:loras')
|
||||||
|
assert cached is not None
|
||||||
|
assert cached['progress'] == 42
|
||||||
|
assert cached['stage'] == 'processing'
|
||||||
|
assert cached['details'] == 'Processing...'
|
||||||
|
|
||||||
|
|
||||||
|
async def test_send_cached_progress_to_new_client(manager):
|
||||||
|
await manager.broadcast_init_progress({'pageType': 'loras', 'progress': 87})
|
||||||
|
|
||||||
|
ws = DummyWebSocket()
|
||||||
|
await manager._send_cached_init_progress(ws)
|
||||||
|
|
||||||
|
assert ws.messages[-1]['progress'] == 87
|
||||||
|
assert ws.messages[-1]['pageType'] == 'loras'
|
||||||
|
|
||||||
async def test_broadcast_download_progress_tracks_state(manager):
|
async def test_broadcast_download_progress_tracks_state(manager):
|
||||||
ws = DummyWebSocket()
|
ws = DummyWebSocket()
|
||||||
download_id = "abc"
|
download_id = "abc"
|
||||||
|
|||||||
Reference in New Issue
Block a user