mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-26 07:35:44 -03:00
Add WebSocket progress tracking for CivitAI metadata fetching
This commit is contained in:
@@ -8,6 +8,7 @@ from ..utils.file_utils import update_civitai_metadata, load_metadata
|
|||||||
from ..config import config
|
from ..config import config
|
||||||
from ..services.lora_scanner import LoraScanner
|
from ..services.lora_scanner import LoraScanner
|
||||||
from operator import itemgetter
|
from operator import itemgetter
|
||||||
|
from ..services.websocket_manager import ws_manager
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -26,6 +27,7 @@ class ApiRoutes:
|
|||||||
app.router.add_post('/api/replace_preview', routes.replace_preview)
|
app.router.add_post('/api/replace_preview', routes.replace_preview)
|
||||||
app.router.add_get('/api/loras', routes.get_loras)
|
app.router.add_get('/api/loras', routes.get_loras)
|
||||||
app.router.add_post('/api/fetch-all-civitai', routes.fetch_all_civitai)
|
app.router.add_post('/api/fetch-all-civitai', routes.fetch_all_civitai)
|
||||||
|
app.router.add_get('/ws/fetch-progress', ws_manager.handle_connection)
|
||||||
|
|
||||||
async def delete_model(self, request: web.Request) -> web.Response:
|
async def delete_model(self, request: web.Request) -> web.Response:
|
||||||
"""Handle model deletion request"""
|
"""Handle model deletion request"""
|
||||||
@@ -319,36 +321,63 @@ class ApiRoutes:
|
|||||||
async def fetch_all_civitai(self, request: web.Request) -> web.Response:
|
async def fetch_all_civitai(self, request: web.Request) -> web.Response:
|
||||||
"""Fetch CivitAI metadata for all loras in the background"""
|
"""Fetch CivitAI metadata for all loras in the background"""
|
||||||
try:
|
try:
|
||||||
# 获取所有 lora 数据(使用 scanner 的缓存)
|
|
||||||
cache = await self.scanner.get_cached_data()
|
cache = await self.scanner.get_cached_data()
|
||||||
total = len(cache.raw_data)
|
total = len(cache.raw_data)
|
||||||
processed = 0
|
processed = 0
|
||||||
success = 0
|
success = 0
|
||||||
needs_resort = False # 标记是否需要重新排序
|
needs_resort = False
|
||||||
|
|
||||||
for lora in cache.raw_data:
|
# 准备要处理的 loras
|
||||||
if not lora.get('sha256') or lora.get('civitai') or not lora.get('from_civitai', True):
|
to_process = [
|
||||||
continue
|
lora for lora in cache.raw_data
|
||||||
|
if lora.get('sha256') and not lora.get('civitai') and lora.get('from_civitai')
|
||||||
|
]
|
||||||
|
total_to_process = len(to_process)
|
||||||
|
|
||||||
|
# 发送初始进度
|
||||||
|
await ws_manager.broadcast({
|
||||||
|
'status': 'started',
|
||||||
|
'total': total_to_process,
|
||||||
|
'processed': 0,
|
||||||
|
'success': 0
|
||||||
|
})
|
||||||
|
|
||||||
|
for lora in to_process:
|
||||||
try:
|
try:
|
||||||
original_name = lora.get('model_name')
|
original_name = lora.get('model_name')
|
||||||
if await self._fetch_and_update_single_lora(
|
if await self._fetch_and_update_single_lora(
|
||||||
sha256=lora['sha256'],
|
sha256=lora['sha256'],
|
||||||
file_path=lora['file_path'],
|
file_path=lora['file_path'],
|
||||||
lora=lora # 直接传入缓存中的 lora 对象
|
lora=lora
|
||||||
):
|
):
|
||||||
success += 1
|
success += 1
|
||||||
# 检查 model_name 是否发生变化
|
|
||||||
if original_name != lora.get('model_name'):
|
if original_name != lora.get('model_name'):
|
||||||
needs_resort = True
|
needs_resort = True
|
||||||
|
|
||||||
processed += 1
|
processed += 1
|
||||||
|
|
||||||
|
# 每处理一个就发送进度更新
|
||||||
|
await ws_manager.broadcast({
|
||||||
|
'status': 'processing',
|
||||||
|
'total': total_to_process,
|
||||||
|
'processed': processed,
|
||||||
|
'success': success,
|
||||||
|
'current_name': lora.get('model_name', 'Unknown')
|
||||||
|
})
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error fetching CivitAI data for {lora['file_path']}: {e}")
|
logger.error(f"Error fetching CivitAI data for {lora['file_path']}: {e}")
|
||||||
|
|
||||||
# 只在需要时进行一次排序
|
|
||||||
if needs_resort:
|
if needs_resort:
|
||||||
cache.sorted_by_name = sorted(cache.raw_data, key=itemgetter('model_name'))
|
cache.sorted_by_name = sorted(cache.raw_data, key=itemgetter('model_name'))
|
||||||
|
|
||||||
|
# 发送完成消息
|
||||||
|
await ws_manager.broadcast({
|
||||||
|
'status': 'completed',
|
||||||
|
'total': total_to_process,
|
||||||
|
'processed': processed,
|
||||||
|
'success': success
|
||||||
|
})
|
||||||
|
|
||||||
return web.json_response({
|
return web.json_response({
|
||||||
"success": True,
|
"success": True,
|
||||||
@@ -356,11 +385,13 @@ class ApiRoutes:
|
|||||||
})
|
})
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
# 发送错误消息
|
||||||
|
await ws_manager.broadcast({
|
||||||
|
'status': 'error',
|
||||||
|
'error': str(e)
|
||||||
|
})
|
||||||
logger.error(f"Error in fetch_all_civitai: {e}")
|
logger.error(f"Error in fetch_all_civitai: {e}")
|
||||||
return web.Response(
|
return web.Response(text=str(e), status=500)
|
||||||
text=str(e),
|
|
||||||
status=500
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _fetch_and_update_single_lora(self, sha256: str, file_path: str, lora: dict) -> bool:
|
async def _fetch_and_update_single_lora(self, sha256: str, file_path: str, lora: dict) -> bool:
|
||||||
"""Fetch and update metadata for a single lora without sorting
|
"""Fetch and update metadata for a single lora without sorting
|
||||||
|
|||||||
43
services/websocket_manager.py
Normal file
43
services/websocket_manager.py
Normal file
@@ -0,0 +1,43 @@
|
|||||||
|
import logging
|
||||||
|
from aiohttp import web
|
||||||
|
from typing import Set, Dict, Optional
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
class WebSocketManager:
|
||||||
|
"""Manages WebSocket connections and broadcasts"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self._websockets: Set[web.WebSocketResponse] = set()
|
||||||
|
|
||||||
|
async def handle_connection(self, request: web.Request) -> web.WebSocketResponse:
|
||||||
|
"""Handle new WebSocket connection"""
|
||||||
|
ws = web.WebSocketResponse()
|
||||||
|
await ws.prepare(request)
|
||||||
|
self._websockets.add(ws)
|
||||||
|
|
||||||
|
try:
|
||||||
|
async for msg in ws:
|
||||||
|
if msg.type == web.WSMsgType.ERROR:
|
||||||
|
logger.error(f'WebSocket error: {ws.exception()}')
|
||||||
|
finally:
|
||||||
|
self._websockets.discard(ws)
|
||||||
|
return ws
|
||||||
|
|
||||||
|
async def broadcast(self, data: Dict):
|
||||||
|
"""Broadcast message to all connected clients"""
|
||||||
|
if not self._websockets:
|
||||||
|
return
|
||||||
|
|
||||||
|
for ws in self._websockets:
|
||||||
|
try:
|
||||||
|
await ws.send_json(data)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error sending progress: {e}")
|
||||||
|
|
||||||
|
def get_connected_clients_count(self) -> int:
|
||||||
|
"""Get number of connected clients"""
|
||||||
|
return len(self._websockets)
|
||||||
|
|
||||||
|
# Global instance
|
||||||
|
ws_manager = WebSocketManager()
|
||||||
@@ -832,10 +832,57 @@ async function replacePreview(filePath) {
|
|||||||
|
|
||||||
// Fetch CivitAI metadata for all loras
|
// Fetch CivitAI metadata for all loras
|
||||||
async function fetchCivitai() {
|
async function fetchCivitai() {
|
||||||
|
let ws = null;
|
||||||
|
|
||||||
await state.loadingManager.showWithProgress(async (loading) => {
|
await state.loadingManager.showWithProgress(async (loading) => {
|
||||||
try {
|
try {
|
||||||
loading.setStatus('Fetching metadata for all loras...');
|
// 建立 WebSocket 连接
|
||||||
|
ws = new WebSocket(`ws://${window.location.host}/ws/fetch-progress`);
|
||||||
|
|
||||||
|
// 等待操作完成的 Promise
|
||||||
|
const operationComplete = new Promise((resolve, reject) => {
|
||||||
|
ws.onmessage = (event) => {
|
||||||
|
const data = JSON.parse(event.data);
|
||||||
|
|
||||||
|
switch(data.status) {
|
||||||
|
case 'started':
|
||||||
|
loading.setStatus('Starting metadata fetch...');
|
||||||
|
break;
|
||||||
|
|
||||||
|
case 'processing':
|
||||||
|
const percent = ((data.processed / data.total) * 100).toFixed(1);
|
||||||
|
loading.setProgress(percent);
|
||||||
|
loading.setStatus(
|
||||||
|
`Processing (${data.processed}/${data.total}) ${data.current_name}`
|
||||||
|
);
|
||||||
|
break;
|
||||||
|
|
||||||
|
case 'completed':
|
||||||
|
loading.setProgress(100);
|
||||||
|
loading.setStatus(
|
||||||
|
`Completed: Updated ${data.success} of ${data.processed} loras`
|
||||||
|
);
|
||||||
|
resolve(); // 完成操作
|
||||||
|
break;
|
||||||
|
|
||||||
|
case 'error':
|
||||||
|
reject(new Error(data.error));
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
ws.onerror = (error) => {
|
||||||
|
reject(new Error('WebSocket error: ' + error.message));
|
||||||
|
};
|
||||||
|
});
|
||||||
|
|
||||||
|
// 等待 WebSocket 连接建立
|
||||||
|
await new Promise((resolve, reject) => {
|
||||||
|
ws.onopen = resolve;
|
||||||
|
ws.onerror = reject;
|
||||||
|
});
|
||||||
|
|
||||||
|
// 发起获取请求
|
||||||
const response = await fetch('/api/fetch-all-civitai', {
|
const response = await fetch('/api/fetch-all-civitai', {
|
||||||
method: 'POST',
|
method: 'POST',
|
||||||
headers: { 'Content-Type': 'application/json' }
|
headers: { 'Content-Type': 'application/json' }
|
||||||
@@ -845,8 +892,8 @@ async function fetchCivitai() {
|
|||||||
throw new Error('Failed to fetch metadata');
|
throw new Error('Failed to fetch metadata');
|
||||||
}
|
}
|
||||||
|
|
||||||
const result = await response.json();
|
// 等待操作完成
|
||||||
showToast(result.message, 'success');
|
await operationComplete;
|
||||||
|
|
||||||
// 重置并重新加载当前视图
|
// 重置并重新加载当前视图
|
||||||
await resetAndReload();
|
await resetAndReload();
|
||||||
@@ -854,9 +901,14 @@ async function fetchCivitai() {
|
|||||||
} catch (error) {
|
} catch (error) {
|
||||||
console.error('Error fetching metadata:', error);
|
console.error('Error fetching metadata:', error);
|
||||||
showToast('Failed to fetch metadata: ' + error.message, 'error');
|
showToast('Failed to fetch metadata: ' + error.message, 'error');
|
||||||
|
} finally {
|
||||||
|
// 关闭 WebSocket 连接
|
||||||
|
if (ws) {
|
||||||
|
ws.close();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}, {
|
}, {
|
||||||
initialMessage: 'Starting metadata fetch...',
|
initialMessage: 'Connecting...',
|
||||||
completionMessage: 'Metadata update complete'
|
completionMessage: 'Metadata update complete'
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
Reference in New Issue
Block a user