mirror of
https://github.com/Azornes/Comfyui-LayerForge.git
synced 2026-03-21 20:52:12 -03:00
Add WebSocket-based RAM output for CanvasNode
Introduces a WebSocket-based mechanism for CanvasNode to send and receive canvas image and mask data in RAM, enabling fast, diskless data transfer between frontend and backend. Adds a new WebSocketManager utility, updates CanvasIO to support RAM output mode, and modifies CanvasView to send canvas data via WebSocket before prompt execution. The backend (canvas_node.py) is updated to handle WebSocket data storage and retrieval, with improved locking and cleanup logic. This change improves workflow speed and reliability by avoiding unnecessary disk I/O and ensuring up-to-date canvas data is available during node execution.
This commit is contained in:
301
canvas_node.py
301
canvas_node.py
@@ -5,6 +5,8 @@ import numpy as np
|
||||
import folder_paths
|
||||
from server import PromptServer
|
||||
from aiohttp import web
|
||||
import asyncio
|
||||
import threading
|
||||
import os
|
||||
from tqdm import tqdm
|
||||
from torchvision import transforms
|
||||
@@ -91,6 +93,9 @@ class BiRefNet(torch.nn.Module):
|
||||
|
||||
|
||||
class CanvasNode:
|
||||
_canvas_data_storage = {}
|
||||
_storage_lock = threading.Lock()
|
||||
|
||||
_canvas_cache = {
|
||||
'image': None,
|
||||
'mask': None,
|
||||
@@ -99,10 +104,16 @@ class CanvasNode:
|
||||
'persistent_cache': {},
|
||||
'last_execution_id': None
|
||||
}
|
||||
|
||||
# Simple in-memory storage for canvas data, keyed by prompt_id
|
||||
# WebSocket-based storage for canvas data per node
|
||||
_websocket_data = {}
|
||||
_websocket_listeners = {}
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.flow_id = str(uuid.uuid4())
|
||||
self.node_id = None # Will be set when node is created
|
||||
|
||||
if self.__class__._canvas_cache['persistent_cache']:
|
||||
self.restore_cache()
|
||||
@@ -166,14 +177,18 @@ class CanvasNode:
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"canvas_image": ("STRING", {"default": "canvas_image.png"}),
|
||||
"trigger": ("INT", {"default": 0, "min": 0, "max": 99999999, "step": 1, "hidden": True}),
|
||||
"output_switch": ("BOOLEAN", {"default": True}),
|
||||
"cache_enabled": ("BOOLEAN", {"default": True, "label": "Enable Cache"})
|
||||
"cache_enabled": ("BOOLEAN", {"default": True, "label": "Enable Cache"}),
|
||||
"node_id": ("STRING", {"default": "0", "hidden": True}),
|
||||
},
|
||||
"optional": {
|
||||
"input_image": ("IMAGE",),
|
||||
"input_mask": ("MASK",)
|
||||
},
|
||||
"hidden": {
|
||||
"prompt": ("PROMPT",),
|
||||
"unique_id": ("UNIQUE_ID",),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -230,161 +245,72 @@ class CanvasNode:
|
||||
return None
|
||||
|
||||
# Zmienna blokująca równoczesne wykonania
|
||||
_processing_lock = None
|
||||
|
||||
def process_canvas_image(self, canvas_image, trigger, output_switch, cache_enabled, input_image=None,
|
||||
_processing_lock = threading.Lock()
|
||||
|
||||
def process_canvas_image(self, trigger, output_switch, cache_enabled, node_id, prompt=None, unique_id=None, input_image=None,
|
||||
input_mask=None):
|
||||
|
||||
log_info(f"[CanvasNode] 🔍 process_canvas_image wejście – node_id={node_id!r}, unique_id={unique_id!r}, trigger={trigger}, output_switch={output_switch}")
|
||||
|
||||
try:
|
||||
# Sprawdź czy już trwa przetwarzanie
|
||||
if self.__class__._processing_lock is not None:
|
||||
log_warn(f"Process already in progress, waiting for completion...")
|
||||
return () # Zwróć pusty wynik, aby uniknąć równoczesnych przetworzeń
|
||||
|
||||
# Ustaw blokadę
|
||||
self.__class__._processing_lock = True
|
||||
if not self.__class__._processing_lock.acquire(blocking=False):
|
||||
log_warn(f"Process already in progress for node {node_id}, skipping...")
|
||||
# Return cached data if available to avoid breaking the flow
|
||||
return self.get_cached_data()
|
||||
|
||||
log_info(f"Lock acquired. Starting process_canvas_image for node_id: {node_id} (fallback unique_id: {unique_id})")
|
||||
|
||||
current_execution = self.get_execution_id()
|
||||
log_info(f"Starting process_canvas_image - execution ID: {current_execution}, trigger: {trigger}")
|
||||
log_debug(f"Canvas image filename: {canvas_image}")
|
||||
log_debug(f"Output switch: {output_switch}, Cache enabled: {cache_enabled}")
|
||||
log_debug(f"Input image provided: {input_image is not None}")
|
||||
log_debug(f"Input mask provided: {input_mask is not None}")
|
||||
# Use node_id as the primary key, as unique_id is proving unreliable
|
||||
storage_key = node_id
|
||||
|
||||
processed_image = None
|
||||
processed_mask = None
|
||||
|
||||
if current_execution != self.__class__._canvas_cache['last_execution_id']:
|
||||
log_info(f"New execution detected: {current_execution} (previous: {self.__class__._canvas_cache['last_execution_id']})")
|
||||
with self.__class__._storage_lock:
|
||||
canvas_data = self.__class__._canvas_data_storage.pop(storage_key, None)
|
||||
|
||||
self.__class__._canvas_cache['image'] = None
|
||||
self.__class__._canvas_cache['mask'] = None
|
||||
self.__class__._canvas_cache['last_execution_id'] = current_execution
|
||||
if canvas_data:
|
||||
log_info(f"Canvas data found for node {storage_key} from WebSocket")
|
||||
if canvas_data.get('image'):
|
||||
image_data = canvas_data['image'].split(',')[1]
|
||||
image_bytes = base64.b64decode(image_data)
|
||||
pil_image = Image.open(io.BytesIO(image_bytes)).convert('RGB')
|
||||
image_array = np.array(pil_image).astype(np.float32) / 255.0
|
||||
processed_image = torch.from_numpy(image_array)[None,]
|
||||
log_debug(f"Image loaded from WebSocket, shape: {processed_image.shape}")
|
||||
|
||||
if canvas_data.get('mask'):
|
||||
mask_data = canvas_data['mask'].split(',')[1]
|
||||
mask_bytes = base64.b64decode(mask_data)
|
||||
pil_mask = Image.open(io.BytesIO(mask_bytes)).convert('L')
|
||||
mask_array = np.array(pil_mask).astype(np.float32) / 255.0
|
||||
processed_mask = torch.from_numpy(mask_array)[None,]
|
||||
log_debug(f"Mask loaded from WebSocket, shape: {processed_mask.shape}")
|
||||
else:
|
||||
log_debug(f"Same execution ID, using cached data")
|
||||
log_warn(f"No canvas data found for node {storage_key} in WebSocket cache, using fallbacks.")
|
||||
if input_image is not None:
|
||||
log_info("Using provided input_image as fallback")
|
||||
processed_image = input_image
|
||||
if input_mask is not None:
|
||||
log_info("Using provided input_mask as fallback")
|
||||
processed_mask = input_mask
|
||||
|
||||
if input_image is not None:
|
||||
log_info("Input image received, converting to PIL Image...")
|
||||
|
||||
if isinstance(input_image, torch.Tensor):
|
||||
if input_image.dim() == 4:
|
||||
input_image = input_image.squeeze(0) # 移除batch维度
|
||||
# Fallback to default tensors if nothing is loaded
|
||||
if processed_image is None:
|
||||
log_warn(f"Processed image is still None, creating default blank image.")
|
||||
processed_image = torch.zeros((1, 512, 512, 3), dtype=torch.float32)
|
||||
if processed_mask is None:
|
||||
log_warn(f"Processed mask is still None, creating default blank mask.")
|
||||
processed_mask = torch.zeros((1, 512, 512), dtype=torch.float32)
|
||||
|
||||
if input_image.shape[0] == 3: # 如果是[C, H, W]格式
|
||||
input_image = input_image.permute(1, 2, 0)
|
||||
|
||||
image_array = (input_image.cpu().numpy() * 255).astype(np.uint8)
|
||||
|
||||
if len(image_array.shape) == 2: # 如果是灰度图
|
||||
image_array = np.stack([image_array] * 3, axis=-1)
|
||||
elif len(image_array.shape) == 3 and image_array.shape[-1] != 3:
|
||||
image_array = np.transpose(image_array, (1, 2, 0))
|
||||
|
||||
try:
|
||||
|
||||
pil_image = Image.fromarray(image_array, 'RGB')
|
||||
log_debug("Successfully converted to PIL Image")
|
||||
|
||||
self.__class__._canvas_cache['image'] = pil_image
|
||||
log_debug(f"Image stored in cache with size: {pil_image.size}")
|
||||
except Exception as e:
|
||||
log_error(f"Error converting to PIL Image: {str(e)}")
|
||||
log_debug(f"Array shape: {image_array.shape}, dtype: {image_array.dtype}")
|
||||
raise
|
||||
|
||||
if input_mask is not None:
|
||||
log_info("Input mask received, converting to PIL Image...")
|
||||
if isinstance(input_mask, torch.Tensor):
|
||||
if input_mask.dim() == 4:
|
||||
input_mask = input_mask.squeeze(0)
|
||||
if input_mask.dim() == 3 and input_mask.shape[0] == 1:
|
||||
input_mask = input_mask.squeeze(0)
|
||||
|
||||
mask_array = (input_mask.cpu().numpy() * 255).astype(np.uint8)
|
||||
pil_mask = Image.fromarray(mask_array, 'L')
|
||||
log_debug("Successfully converted mask to PIL Image")
|
||||
|
||||
self.__class__._canvas_cache['mask'] = pil_mask
|
||||
log_debug(f"Mask stored in cache with size: {pil_mask.size}")
|
||||
|
||||
self.__class__._canvas_cache['cache_enabled'] = cache_enabled
|
||||
|
||||
try:
|
||||
# Wczytaj obraz bez maski
|
||||
image_without_mask_name = canvas_image.replace('.png', '_without_mask.png')
|
||||
path_image_without_mask = folder_paths.get_annotated_filepath(image_without_mask_name)
|
||||
log_debug(f"Canvas image name: {canvas_image}")
|
||||
log_debug(f"Looking for image without mask: {image_without_mask_name}")
|
||||
log_debug(f"Full path: {path_image_without_mask}")
|
||||
|
||||
# Sprawdź czy plik istnieje
|
||||
if not os.path.exists(path_image_without_mask):
|
||||
log_warn(f"Image without mask not found at: {path_image_without_mask}")
|
||||
# Spróbuj znaleźć plik w katalogu input
|
||||
input_dir = folder_paths.get_input_directory()
|
||||
alternative_path = os.path.join(input_dir, image_without_mask_name)
|
||||
log_debug(f"Trying alternative path: {alternative_path}")
|
||||
if os.path.exists(alternative_path):
|
||||
path_image_without_mask = alternative_path
|
||||
log_info(f"Found image at alternative path: {alternative_path}")
|
||||
else:
|
||||
raise FileNotFoundError(f"Image file not found: {image_without_mask_name}")
|
||||
|
||||
i = Image.open(path_image_without_mask)
|
||||
i = ImageOps.exif_transpose(i)
|
||||
if i.mode not in ['RGB', 'RGBA']:
|
||||
i = i.convert('RGB')
|
||||
image = np.array(i).astype(np.float32) / 255.0
|
||||
if i.mode == 'RGBA':
|
||||
rgb = image[..., :3]
|
||||
alpha = image[..., 3:]
|
||||
image = rgb * alpha + (1 - alpha) * 0.5
|
||||
processed_image = torch.from_numpy(image)[None,]
|
||||
log_debug(f"Successfully loaded image without mask, shape: {processed_image.shape}")
|
||||
except Exception as e:
|
||||
log_error(f"Error loading image without mask: {str(e)}")
|
||||
processed_image = torch.ones((1, 512, 512, 3), dtype=torch.float32)
|
||||
log_debug(f"Using default image, shape: {processed_image.shape}")
|
||||
|
||||
try:
|
||||
# Wczytaj maskę
|
||||
path_image = folder_paths.get_annotated_filepath(canvas_image)
|
||||
path_mask = path_image.replace('.png', '_mask.png')
|
||||
log_debug(f"Canvas image path: {path_image}")
|
||||
log_debug(f"Looking for mask at: {path_mask}")
|
||||
|
||||
# Sprawdź czy plik maski istnieje
|
||||
if not os.path.exists(path_mask):
|
||||
log_warn(f"Mask not found at: {path_mask}")
|
||||
# Spróbuj znaleźć plik w katalogu input
|
||||
input_dir = folder_paths.get_input_directory()
|
||||
mask_name = canvas_image.replace('.png', '_mask.png')
|
||||
alternative_mask_path = os.path.join(input_dir, mask_name)
|
||||
log_debug(f"Trying alternative mask path: {alternative_mask_path}")
|
||||
if os.path.exists(alternative_mask_path):
|
||||
path_mask = alternative_mask_path
|
||||
log_info(f"Found mask at alternative path: {alternative_mask_path}")
|
||||
|
||||
if os.path.exists(path_mask):
|
||||
log_debug(f"Mask file exists, loading...")
|
||||
mask = Image.open(path_mask).convert('L')
|
||||
mask = np.array(mask).astype(np.float32) / 255.0
|
||||
processed_mask = torch.from_numpy(mask)[None,]
|
||||
log_debug(f"Successfully loaded mask, shape: {processed_mask.shape}")
|
||||
else:
|
||||
log_debug(f"Mask file does not exist, creating default mask")
|
||||
processed_mask = torch.ones((1, processed_image.shape[1], processed_image.shape[2]),
|
||||
dtype=torch.float32)
|
||||
log_debug(f"Default mask created, shape: {processed_mask.shape}")
|
||||
except Exception as e:
|
||||
log_error(f"Error loading mask: {str(e)}")
|
||||
processed_mask = torch.ones((1, processed_image.shape[1], processed_image.shape[2]),
|
||||
dtype=torch.float32)
|
||||
log_debug(f"Fallback mask created, shape: {processed_mask.shape}")
|
||||
|
||||
if not output_switch:
|
||||
log_debug(f"Output switch is OFF, returning empty tuple")
|
||||
return ()
|
||||
return (None, None)
|
||||
|
||||
log_debug(f"About to return output - Image shape: {processed_image.shape}, Mask shape: {processed_mask.shape}")
|
||||
log_debug(f"Image tensor info - dtype: {processed_image.dtype}, device: {processed_image.device}")
|
||||
log_debug(f"Mask tensor info - dtype: {processed_mask.dtype}, device: {processed_mask.device}")
|
||||
|
||||
self.update_persistent_cache()
|
||||
|
||||
@@ -393,12 +319,13 @@ class CanvasNode:
|
||||
|
||||
except Exception as e:
|
||||
log_exception(f"Error in process_canvas_image: {str(e)}")
|
||||
return ()
|
||||
return (None, None)
|
||||
|
||||
finally:
|
||||
# Zwolnij blokadę
|
||||
self.__class__._processing_lock = None
|
||||
log_debug(f"Process completed, lock released")
|
||||
if self.__class__._processing_lock.locked():
|
||||
self.__class__._processing_lock.release()
|
||||
log_debug(f"Process completed for node {node_id}, lock released")
|
||||
|
||||
def get_cached_data(self):
|
||||
return {
|
||||
@@ -440,8 +367,80 @@ class CanvasNode:
|
||||
return cls._canvas_cache['data_flow_status'].get(flow_id)
|
||||
return cls._canvas_cache['data_flow_status']
|
||||
|
||||
@classmethod
|
||||
def _cleanup_old_websocket_data(cls):
|
||||
"""Clean up old WebSocket data from invalid nodes or data older than 5 minutes"""
|
||||
try:
|
||||
current_time = time.time()
|
||||
cleanup_threshold = 300 # 5 minutes
|
||||
|
||||
nodes_to_remove = []
|
||||
for node_id, data in cls._websocket_data.items():
|
||||
# Remove invalid node IDs
|
||||
if node_id < 0:
|
||||
nodes_to_remove.append(node_id)
|
||||
continue
|
||||
|
||||
# Remove old data
|
||||
if current_time - data.get('timestamp', 0) > cleanup_threshold:
|
||||
nodes_to_remove.append(node_id)
|
||||
continue
|
||||
|
||||
for node_id in nodes_to_remove:
|
||||
del cls._websocket_data[node_id]
|
||||
log_debug(f"Cleaned up old WebSocket data for node {node_id}")
|
||||
|
||||
if nodes_to_remove:
|
||||
log_info(f"Cleaned up {len(nodes_to_remove)} old WebSocket entries")
|
||||
|
||||
except Exception as e:
|
||||
log_error(f"Error during WebSocket cleanup: {str(e)}")
|
||||
|
||||
@classmethod
|
||||
def setup_routes(cls):
|
||||
@PromptServer.instance.routes.get("/layerforge/canvas_ws")
|
||||
async def handle_canvas_websocket(request):
|
||||
ws = web.WebSocketResponse()
|
||||
await ws.prepare(request)
|
||||
|
||||
async for msg in ws:
|
||||
if msg.type == web.WSMsgType.TEXT:
|
||||
try:
|
||||
data = msg.json()
|
||||
node_id = data.get('nodeId')
|
||||
if not node_id:
|
||||
await ws.send_json({'status': 'error', 'message': 'nodeId is required'})
|
||||
continue
|
||||
|
||||
image_data = data.get('image')
|
||||
mask_data = data.get('mask')
|
||||
|
||||
with cls._storage_lock:
|
||||
cls._canvas_data_storage[node_id] = {
|
||||
'image': image_data,
|
||||
'mask': mask_data,
|
||||
'timestamp': time.time()
|
||||
}
|
||||
|
||||
log_info(f"Received canvas data for node {node_id} via WebSocket")
|
||||
# Send acknowledgment back to the client
|
||||
ack_payload = {
|
||||
'type': 'ack',
|
||||
'nodeId': node_id,
|
||||
'status': 'success'
|
||||
}
|
||||
await ws.send_json(ack_payload)
|
||||
log_debug(f"Sent ACK for node {node_id}")
|
||||
|
||||
except Exception as e:
|
||||
log_error(f"Error processing WebSocket message: {e}")
|
||||
await ws.send_json({'status': 'error', 'message': str(e)})
|
||||
elif msg.type == web.WSMsgType.ERROR:
|
||||
log_error(f"WebSocket connection closed with exception {ws.exception()}")
|
||||
|
||||
log_info("WebSocket connection closed")
|
||||
return ws
|
||||
|
||||
@PromptServer.instance.routes.get("/ycnode/get_canvas_data/{node_id}")
|
||||
async def get_canvas_data(request):
|
||||
try:
|
||||
@@ -811,3 +810,15 @@ def convert_tensor_to_base64(tensor, alpha_mask=None, original_alpha=None):
|
||||
log_error(f"Error in convert_tensor_to_base64: {str(e)}")
|
||||
log_debug(f"Tensor shape: {tensor.shape}, dtype: {tensor.dtype}")
|
||||
raise
|
||||
|
||||
|
||||
# Setup original API routes when module is loaded
|
||||
CanvasNode.setup_routes()
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"CanvasNode": CanvasNode
|
||||
}
|
||||
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"CanvasNode": "LayerForge"
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user