mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-26 07:35:44 -03:00
Add WebSocket progress tracking for CivitAI metadata fetching
This commit is contained in:
@@ -8,6 +8,7 @@ from ..utils.file_utils import update_civitai_metadata, load_metadata
|
||||
from ..config import config
|
||||
from ..services.lora_scanner import LoraScanner
|
||||
from operator import itemgetter
|
||||
from ..services.websocket_manager import ws_manager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -26,6 +27,7 @@ class ApiRoutes:
|
||||
app.router.add_post('/api/replace_preview', routes.replace_preview)
|
||||
app.router.add_get('/api/loras', routes.get_loras)
|
||||
app.router.add_post('/api/fetch-all-civitai', routes.fetch_all_civitai)
|
||||
app.router.add_get('/ws/fetch-progress', ws_manager.handle_connection)
|
||||
|
||||
async def delete_model(self, request: web.Request) -> web.Response:
|
||||
"""Handle model deletion request"""
|
||||
@@ -319,36 +321,63 @@ class ApiRoutes:
|
||||
async def fetch_all_civitai(self, request: web.Request) -> web.Response:
|
||||
"""Fetch CivitAI metadata for all loras in the background"""
|
||||
try:
|
||||
# 获取所有 lora 数据(使用 scanner 的缓存)
|
||||
cache = await self.scanner.get_cached_data()
|
||||
total = len(cache.raw_data)
|
||||
processed = 0
|
||||
success = 0
|
||||
needs_resort = False # 标记是否需要重新排序
|
||||
needs_resort = False
|
||||
|
||||
for lora in cache.raw_data:
|
||||
if not lora.get('sha256') or lora.get('civitai') or not lora.get('from_civitai', True):
|
||||
continue
|
||||
|
||||
# 准备要处理的 loras
|
||||
to_process = [
|
||||
lora for lora in cache.raw_data
|
||||
if lora.get('sha256') and not lora.get('civitai') and lora.get('from_civitai')
|
||||
]
|
||||
total_to_process = len(to_process)
|
||||
|
||||
# 发送初始进度
|
||||
await ws_manager.broadcast({
|
||||
'status': 'started',
|
||||
'total': total_to_process,
|
||||
'processed': 0,
|
||||
'success': 0
|
||||
})
|
||||
|
||||
for lora in to_process:
|
||||
try:
|
||||
original_name = lora.get('model_name')
|
||||
if await self._fetch_and_update_single_lora(
|
||||
sha256=lora['sha256'],
|
||||
file_path=lora['file_path'],
|
||||
lora=lora # 直接传入缓存中的 lora 对象
|
||||
lora=lora
|
||||
):
|
||||
success += 1
|
||||
# 检查 model_name 是否发生变化
|
||||
if original_name != lora.get('model_name'):
|
||||
needs_resort = True
|
||||
|
||||
processed += 1
|
||||
|
||||
# 每处理一个就发送进度更新
|
||||
await ws_manager.broadcast({
|
||||
'status': 'processing',
|
||||
'total': total_to_process,
|
||||
'processed': processed,
|
||||
'success': success,
|
||||
'current_name': lora.get('model_name', 'Unknown')
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching CivitAI data for {lora['file_path']}: {e}")
|
||||
|
||||
# 只在需要时进行一次排序
|
||||
if needs_resort:
|
||||
cache.sorted_by_name = sorted(cache.raw_data, key=itemgetter('model_name'))
|
||||
|
||||
# 发送完成消息
|
||||
await ws_manager.broadcast({
|
||||
'status': 'completed',
|
||||
'total': total_to_process,
|
||||
'processed': processed,
|
||||
'success': success
|
||||
})
|
||||
|
||||
return web.json_response({
|
||||
"success": True,
|
||||
@@ -356,11 +385,13 @@ class ApiRoutes:
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
# 发送错误消息
|
||||
await ws_manager.broadcast({
|
||||
'status': 'error',
|
||||
'error': str(e)
|
||||
})
|
||||
logger.error(f"Error in fetch_all_civitai: {e}")
|
||||
return web.Response(
|
||||
text=str(e),
|
||||
status=500
|
||||
)
|
||||
return web.Response(text=str(e), status=500)
|
||||
|
||||
async def _fetch_and_update_single_lora(self, sha256: str, file_path: str, lora: dict) -> bool:
|
||||
"""Fetch and update metadata for a single lora without sorting
|
||||
|
||||
Reference in New Issue
Block a user