mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-24 06:32:12 -03:00
feat: Add initialization progress WebSocket and UI components
- Implement WebSocket route for initialization progress updates - Create initialization component with progress bar and stages - Add styles for initialization UI - Update base template to include initialization component - Enhance model scanner to broadcast progress during initialization
This commit is contained in:
@@ -13,6 +13,7 @@ from .model_cache import ModelCache
|
||||
from .model_hash_index import ModelHashIndex
|
||||
from ..utils.constants import PREVIEW_EXTENSIONS
|
||||
from .service_registry import ServiceRegistry
|
||||
from .websocket_manager import ws_manager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -61,21 +62,99 @@ class ModelScanner:
|
||||
# Set initializing flag to true
|
||||
self._is_initializing = True
|
||||
|
||||
start_time = time.time()
|
||||
# Use thread pool to execute CPU-intensive operations
|
||||
# First, count all model files to track progress
|
||||
await ws_manager.broadcast_init_progress({
|
||||
'stage': 'scan_folders',
|
||||
'progress': 0,
|
||||
'details': f"Scanning {self.model_type} folders..."
|
||||
})
|
||||
|
||||
# Count files in a separate thread to avoid blocking
|
||||
loop = asyncio.get_event_loop()
|
||||
total_files = await loop.run_in_executor(
|
||||
None, # Use default thread pool
|
||||
self._count_model_files # Run file counting in thread
|
||||
)
|
||||
|
||||
await ws_manager.broadcast_init_progress({
|
||||
'stage': 'count_models',
|
||||
'progress': 1, # Changed from 10 to 1
|
||||
'details': f"Found {total_files} {self.model_type} files"
|
||||
})
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
# Use thread pool to execute CPU-intensive operations with progress reporting
|
||||
await loop.run_in_executor(
|
||||
None, # Use default thread pool
|
||||
self._initialize_cache_sync # Run synchronous version in thread
|
||||
self._initialize_cache_sync, # Run synchronous version in thread
|
||||
total_files # Pass the total file count for progress reporting
|
||||
)
|
||||
|
||||
# Send final progress update
|
||||
await ws_manager.broadcast_init_progress({
|
||||
'stage': 'finalizing',
|
||||
'progress': 99, # Changed from 95 to 99
|
||||
'details': f"Finalizing {self.model_type} cache..."
|
||||
})
|
||||
|
||||
logger.info(f"{self.model_type.capitalize()} cache initialized in {time.time() - start_time:.2f} seconds. Found {len(self._cache.raw_data)} models")
|
||||
|
||||
# Send completion message
|
||||
await asyncio.sleep(0.5) # Small delay to ensure final progress message is sent
|
||||
await ws_manager.broadcast_init_progress({
|
||||
'stage': 'finalizing',
|
||||
'progress': 100,
|
||||
'status': 'complete',
|
||||
'details': f"Completed! Found {len(self._cache.raw_data)} {self.model_type} files."
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.model_type.capitalize()} Scanner: Error initializing cache in background: {e}")
|
||||
finally:
|
||||
# Always clear the initializing flag when done
|
||||
self._is_initializing = False
|
||||
|
||||
def _initialize_cache_sync(self):
|
||||
def _count_model_files(self) -> int:
|
||||
"""Count all model files with supported extensions in all roots
|
||||
|
||||
Returns:
|
||||
int: Total number of model files found
|
||||
"""
|
||||
total_files = 0
|
||||
visited_real_paths = set()
|
||||
|
||||
for root_path in self.get_model_roots():
|
||||
if not os.path.exists(root_path):
|
||||
continue
|
||||
|
||||
def count_recursive(path):
|
||||
nonlocal total_files
|
||||
try:
|
||||
real_path = os.path.realpath(path)
|
||||
if real_path in visited_real_paths:
|
||||
return
|
||||
visited_real_paths.add(real_path)
|
||||
|
||||
with os.scandir(path) as it:
|
||||
for entry in it:
|
||||
try:
|
||||
if entry.is_file(follow_symlinks=True):
|
||||
ext = os.path.splitext(entry.name)[1].lower()
|
||||
if ext in self.file_extensions:
|
||||
total_files += 1
|
||||
elif entry.is_dir(follow_symlinks=True):
|
||||
count_recursive(entry.path)
|
||||
except Exception as e:
|
||||
logger.error(f"Error counting files in entry {entry.path}: {e}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error counting files in {path}: {e}")
|
||||
|
||||
count_recursive(root_path)
|
||||
|
||||
return total_files
|
||||
|
||||
def _initialize_cache_sync(self, total_files=0):
|
||||
"""Synchronous version of cache initialization for thread pool execution"""
|
||||
try:
|
||||
# Create a new event loop for this thread
|
||||
@@ -84,8 +163,83 @@ class ModelScanner:
|
||||
|
||||
# Create a synchronous method to bypass the async lock
|
||||
def sync_initialize_cache():
|
||||
# Directly call the scan method to avoid lock issues
|
||||
raw_data = loop.run_until_complete(self.scan_all_models())
|
||||
# Track progress
|
||||
processed_files = 0
|
||||
last_progress_time = time.time()
|
||||
last_progress_percent = 0
|
||||
|
||||
# We need a wrapper around scan_all_models to track progress
|
||||
# This is a local function that will run in our thread's event loop
|
||||
async def scan_with_progress():
|
||||
nonlocal processed_files, last_progress_time, last_progress_percent
|
||||
|
||||
# For storing raw model data
|
||||
all_models = []
|
||||
|
||||
# Process each model root
|
||||
for root_path in self.get_model_roots():
|
||||
if not os.path.exists(root_path):
|
||||
continue
|
||||
|
||||
# Track visited paths to avoid symlink loops
|
||||
visited_paths = set()
|
||||
|
||||
# Recursively process directory
|
||||
async def scan_dir_with_progress(path):
|
||||
nonlocal processed_files, last_progress_time, last_progress_percent
|
||||
|
||||
try:
|
||||
real_path = os.path.realpath(path)
|
||||
if real_path in visited_paths:
|
||||
return
|
||||
visited_paths.add(real_path)
|
||||
|
||||
with os.scandir(path) as it:
|
||||
entries = list(it)
|
||||
for entry in entries:
|
||||
try:
|
||||
if entry.is_file(follow_symlinks=True):
|
||||
ext = os.path.splitext(entry.name)[1].lower()
|
||||
if ext in self.file_extensions:
|
||||
file_path = entry.path.replace(os.sep, "/")
|
||||
result = await self._process_model_file(file_path, root_path)
|
||||
if result:
|
||||
all_models.append(result)
|
||||
|
||||
# Update progress counter
|
||||
processed_files += 1
|
||||
|
||||
# Update progress periodically (not every file to avoid excessive updates)
|
||||
current_time = time.time()
|
||||
if total_files > 0 and (current_time - last_progress_time > 0.5 or processed_files == total_files):
|
||||
# Adjusted progress calculation
|
||||
progress_percent = min(99, int(1 + (processed_files / total_files) * 98))
|
||||
if progress_percent > last_progress_percent:
|
||||
last_progress_percent = progress_percent
|
||||
last_progress_time = current_time
|
||||
|
||||
# Send progress update through websocket
|
||||
await ws_manager.broadcast_init_progress({
|
||||
'stage': 'process_models',
|
||||
'progress': progress_percent,
|
||||
'details': f"Processing {self.model_type} files: {processed_files}/{total_files}"
|
||||
})
|
||||
|
||||
elif entry.is_dir(follow_symlinks=True):
|
||||
await scan_dir_with_progress(entry.path)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing entry {entry.path}: {e}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error scanning {path}: {e}")
|
||||
|
||||
# Process the root path
|
||||
await scan_dir_with_progress(root_path)
|
||||
|
||||
return all_models
|
||||
|
||||
# Run the progress-tracking scan function
|
||||
raw_data = loop.run_until_complete(scan_with_progress())
|
||||
|
||||
# Update hash index and tags count
|
||||
for model_data in raw_data:
|
||||
@@ -136,6 +290,7 @@ class ModelScanner:
|
||||
|
||||
async def _initialize_cache(self) -> None:
|
||||
"""Initialize or refresh the cache"""
|
||||
self._is_initializing = True # Set flag
|
||||
try:
|
||||
start_time = time.time()
|
||||
# Clear existing hash index
|
||||
@@ -171,15 +326,20 @@ class ModelScanner:
|
||||
logger.info(f"{self.model_type.capitalize()} Scanner: Cache initialization completed in {time.time() - start_time:.2f} seconds, found {len(raw_data)} models")
|
||||
except Exception as e:
|
||||
logger.error(f"{self.model_type.capitalize()} Scanner: Error initializing cache: {e}")
|
||||
self._cache = ModelCache(
|
||||
raw_data=[],
|
||||
sorted_by_name=[],
|
||||
sorted_by_date=[],
|
||||
folders=[]
|
||||
)
|
||||
# Ensure cache is at least an empty structure on error
|
||||
if self._cache is None:
|
||||
self._cache = ModelCache(
|
||||
raw_data=[],
|
||||
sorted_by_name=[],
|
||||
sorted_by_date=[],
|
||||
folders=[]
|
||||
)
|
||||
finally:
|
||||
self._is_initializing = False # Unset flag
|
||||
|
||||
async def _reconcile_cache(self) -> None:
|
||||
"""Fast cache reconciliation - only process differences between cache and filesystem"""
|
||||
self._is_initializing = True # Set flag for reconciliation duration
|
||||
try:
|
||||
start_time = time.time()
|
||||
logger.info(f"{self.model_type.capitalize()} Scanner: Starting fast cache reconciliation...")
|
||||
@@ -306,6 +466,8 @@ class ModelScanner:
|
||||
logger.info(f"{self.model_type.capitalize()} Scanner: Cache reconciliation completed in {time.time() - start_time:.2f} seconds. Added {total_added}, removed {total_removed} models.")
|
||||
except Exception as e:
|
||||
logger.error(f"{self.model_type.capitalize()} Scanner: Error reconciling cache: {e}", exc_info=True)
|
||||
finally:
|
||||
self._is_initializing = False # Unset flag
|
||||
|
||||
# These methods should be implemented in child classes
|
||||
async def scan_all_models(self) -> List[Dict]:
|
||||
|
||||
@@ -9,6 +9,7 @@ class WebSocketManager:
|
||||
|
||||
def __init__(self):
|
||||
self._websockets: Set[web.WebSocketResponse] = set()
|
||||
self._init_websockets: Set[web.WebSocketResponse] = set() # New set for initialization progress clients
|
||||
|
||||
async def handle_connection(self, request: web.Request) -> web.WebSocketResponse:
|
||||
"""Handle new WebSocket connection"""
|
||||
@@ -23,6 +24,20 @@ class WebSocketManager:
|
||||
finally:
|
||||
self._websockets.discard(ws)
|
||||
return ws
|
||||
|
||||
async def handle_init_connection(self, request: web.Request) -> web.WebSocketResponse:
|
||||
"""Handle new WebSocket connection for initialization progress"""
|
||||
ws = web.WebSocketResponse()
|
||||
await ws.prepare(request)
|
||||
self._init_websockets.add(ws)
|
||||
|
||||
try:
|
||||
async for msg in ws:
|
||||
if msg.type == web.WSMsgType.ERROR:
|
||||
logger.error(f'Init WebSocket error: {ws.exception()}')
|
||||
finally:
|
||||
self._init_websockets.discard(ws)
|
||||
return ws
|
||||
|
||||
async def broadcast(self, data: Dict):
|
||||
"""Broadcast message to all connected clients"""
|
||||
@@ -34,10 +49,25 @@ class WebSocketManager:
|
||||
await ws.send_json(data)
|
||||
except Exception as e:
|
||||
logger.error(f"Error sending progress: {e}")
|
||||
|
||||
async def broadcast_init_progress(self, data: Dict):
|
||||
"""Broadcast initialization progress to connected clients"""
|
||||
if not self._init_websockets:
|
||||
return
|
||||
|
||||
for ws in self._init_websockets:
|
||||
try:
|
||||
await ws.send_json(data)
|
||||
except Exception as e:
|
||||
logger.error(f"Error sending initialization progress: {e}")
|
||||
|
||||
def get_connected_clients_count(self) -> int:
|
||||
"""Get number of connected clients"""
|
||||
return len(self._websockets)
|
||||
|
||||
def get_init_clients_count(self) -> int:
|
||||
"""Get number of initialization progress clients"""
|
||||
return len(self._init_websockets)
|
||||
|
||||
# Global instance
|
||||
ws_manager = WebSocketManager()
|
||||
ws_manager = WebSocketManager()
|
||||
Reference in New Issue
Block a user