diff --git a/routes/api_routes.py b/routes/api_routes.py index f10502b3..5e5adde5 100644 --- a/routes/api_routes.py +++ b/routes/api_routes.py @@ -8,6 +8,7 @@ from ..utils.file_utils import update_civitai_metadata, load_metadata from ..config import config from ..services.lora_scanner import LoraScanner from operator import itemgetter +from ..services.websocket_manager import ws_manager logger = logging.getLogger(__name__) @@ -26,6 +27,7 @@ class ApiRoutes: app.router.add_post('/api/replace_preview', routes.replace_preview) app.router.add_get('/api/loras', routes.get_loras) app.router.add_post('/api/fetch-all-civitai', routes.fetch_all_civitai) + app.router.add_get('/ws/fetch-progress', ws_manager.handle_connection) async def delete_model(self, request: web.Request) -> web.Response: """Handle model deletion request""" @@ -319,36 +321,63 @@ class ApiRoutes: async def fetch_all_civitai(self, request: web.Request) -> web.Response: """Fetch CivitAI metadata for all loras in the background""" try: - # 获取所有 lora 数据(使用 scanner 的缓存) cache = await self.scanner.get_cached_data() total = len(cache.raw_data) processed = 0 success = 0 - needs_resort = False # 标记是否需要重新排序 + needs_resort = False - for lora in cache.raw_data: - if not lora.get('sha256') or lora.get('civitai') or not lora.get('from_civitai', True): - continue - + # 准备要处理的 loras + to_process = [ + lora for lora in cache.raw_data + if lora.get('sha256') and not lora.get('civitai') and lora.get('from_civitai') + ] + total_to_process = len(to_process) + + # 发送初始进度 + await ws_manager.broadcast({ + 'status': 'started', + 'total': total_to_process, + 'processed': 0, + 'success': 0 + }) + + for lora in to_process: try: original_name = lora.get('model_name') if await self._fetch_and_update_single_lora( sha256=lora['sha256'], file_path=lora['file_path'], - lora=lora # 直接传入缓存中的 lora 对象 + lora=lora ): success += 1 - # 检查 model_name 是否发生变化 if original_name != lora.get('model_name'): needs_resort = True + processed += 1 + # 每处理一个就发送进度更新 + await ws_manager.broadcast({ + 'status': 'processing', + 'total': total_to_process, + 'processed': processed, + 'success': success, + 'current_name': lora.get('model_name', 'Unknown') + }) + except Exception as e: logger.error(f"Error fetching CivitAI data for {lora['file_path']}: {e}") - # 只在需要时进行一次排序 if needs_resort: cache.sorted_by_name = sorted(cache.raw_data, key=itemgetter('model_name')) + + # 发送完成消息 + await ws_manager.broadcast({ + 'status': 'completed', + 'total': total_to_process, + 'processed': processed, + 'success': success + }) return web.json_response({ "success": True, @@ -356,11 +385,13 @@ class ApiRoutes: }) except Exception as e: + # 发送错误消息 + await ws_manager.broadcast({ + 'status': 'error', + 'error': str(e) + }) logger.error(f"Error in fetch_all_civitai: {e}") - return web.Response( - text=str(e), - status=500 - ) + return web.Response(text=str(e), status=500) async def _fetch_and_update_single_lora(self, sha256: str, file_path: str, lora: dict) -> bool: """Fetch and update metadata for a single lora without sorting diff --git a/services/websocket_manager.py b/services/websocket_manager.py new file mode 100644 index 00000000..fccd5e41 --- /dev/null +++ b/services/websocket_manager.py @@ -0,0 +1,43 @@ +import logging +from aiohttp import web +from typing import Set, Dict, Optional + +logger = logging.getLogger(__name__) + +class WebSocketManager: + """Manages WebSocket connections and broadcasts""" + + def __init__(self): + self._websockets: Set[web.WebSocketResponse] = set() + + async def handle_connection(self, request: web.Request) -> web.WebSocketResponse: + """Handle new WebSocket connection""" + ws = web.WebSocketResponse() + await ws.prepare(request) + self._websockets.add(ws) + + try: + async for msg in ws: + if msg.type == web.WSMsgType.ERROR: + logger.error(f'WebSocket error: {ws.exception()}') + finally: + self._websockets.discard(ws) + return ws + + async def broadcast(self, data: Dict): + """Broadcast message to all connected clients""" + if not self._websockets: + return + + for ws in self._websockets: + try: + await ws.send_json(data) + except Exception as e: + logger.error(f"Error sending progress: {e}") + + def get_connected_clients_count(self) -> int: + """Get number of connected clients""" + return len(self._websockets) + +# Global instance +ws_manager = WebSocketManager() \ No newline at end of file diff --git a/static/js/script.js b/static/js/script.js index 8c9d690b..069152f1 100644 --- a/static/js/script.js +++ b/static/js/script.js @@ -832,10 +832,57 @@ async function replacePreview(filePath) { // Fetch CivitAI metadata for all loras async function fetchCivitai() { + let ws = null; + await state.loadingManager.showWithProgress(async (loading) => { try { - loading.setStatus('Fetching metadata for all loras...'); + // 建立 WebSocket 连接 + ws = new WebSocket(`ws://${window.location.host}/ws/fetch-progress`); + // 等待操作完成的 Promise + const operationComplete = new Promise((resolve, reject) => { + ws.onmessage = (event) => { + const data = JSON.parse(event.data); + + switch(data.status) { + case 'started': + loading.setStatus('Starting metadata fetch...'); + break; + + case 'processing': + const percent = ((data.processed / data.total) * 100).toFixed(1); + loading.setProgress(percent); + loading.setStatus( + `Processing (${data.processed}/${data.total}) ${data.current_name}` + ); + break; + + case 'completed': + loading.setProgress(100); + loading.setStatus( + `Completed: Updated ${data.success} of ${data.processed} loras` + ); + resolve(); // 完成操作 + break; + + case 'error': + reject(new Error(data.error)); + break; + } + }; + + ws.onerror = (error) => { + reject(new Error('WebSocket error: ' + error.message)); + }; + }); + + // 等待 WebSocket 连接建立 + await new Promise((resolve, reject) => { + ws.onopen = resolve; + ws.onerror = reject; + }); + + // 发起获取请求 const response = await fetch('/api/fetch-all-civitai', { method: 'POST', headers: { 'Content-Type': 'application/json' } @@ -845,8 +892,8 @@ async function fetchCivitai() { throw new Error('Failed to fetch metadata'); } - const result = await response.json(); - showToast(result.message, 'success'); + // 等待操作完成 + await operationComplete; // 重置并重新加载当前视图 await resetAndReload(); @@ -854,9 +901,14 @@ async function fetchCivitai() { } catch (error) { console.error('Error fetching metadata:', error); showToast('Failed to fetch metadata: ' + error.message, 'error'); + } finally { + // 关闭 WebSocket 连接 + if (ws) { + ws.close(); + } } }, { - initialMessage: 'Starting metadata fetch...', + initialMessage: 'Connecting...', completionMessage: 'Metadata update complete' }); } \ No newline at end of file