From be4fae2964f3401cbbf23ae94584bfe383f8ebe5 Mon Sep 17 00:00:00 2001 From: Dariusz L Date: Fri, 27 Jun 2025 05:28:13 +0200 Subject: [PATCH] Add WebSocket-based RAM output for CanvasNode Introduces a WebSocket-based mechanism for CanvasNode to send and receive canvas image and mask data in RAM, enabling fast, diskless data transfer between frontend and backend. Adds a new WebSocketManager utility, updates CanvasIO to support RAM output mode, and modifies CanvasView to send canvas data via WebSocket before prompt execution. The backend (canvas_node.py) is updated to handle WebSocket data storage and retrieval, with improved locking and cleanup logic. This change improves workflow speed and reliability by avoiding unnecessary disk I/O and ensuring up-to-date canvas data is available during node execution. --- canvas_node.py | 301 ++++++++++++++++++----------------- js/Canvas.js | 3 - js/CanvasIO.js | 193 +++++++++++++++++++--- js/CanvasView.js | 172 +++++++++++--------- js/utils/CommonUtils.js | 27 +++- js/utils/WebSocketManager.js | 160 +++++++++++++++++++ 6 files changed, 602 insertions(+), 254 deletions(-) create mode 100644 js/utils/WebSocketManager.js diff --git a/canvas_node.py b/canvas_node.py index 72b31a0..c4fe199 100644 --- a/canvas_node.py +++ b/canvas_node.py @@ -5,6 +5,8 @@ import numpy as np import folder_paths from server import PromptServer from aiohttp import web +import asyncio +import threading import os from tqdm import tqdm from torchvision import transforms @@ -91,6 +93,9 @@ class BiRefNet(torch.nn.Module): class CanvasNode: + _canvas_data_storage = {} + _storage_lock = threading.Lock() + _canvas_cache = { 'image': None, 'mask': None, @@ -99,10 +104,16 @@ class CanvasNode: 'persistent_cache': {}, 'last_execution_id': None } + + # Simple in-memory storage for canvas data, keyed by prompt_id + # WebSocket-based storage for canvas data per node + _websocket_data = {} + _websocket_listeners = {} def __init__(self): super().__init__() self.flow_id = str(uuid.uuid4()) + self.node_id = None # Will be set when node is created if self.__class__._canvas_cache['persistent_cache']: self.restore_cache() @@ -166,14 +177,18 @@ class CanvasNode: def INPUT_TYPES(cls): return { "required": { - "canvas_image": ("STRING", {"default": "canvas_image.png"}), "trigger": ("INT", {"default": 0, "min": 0, "max": 99999999, "step": 1, "hidden": True}), "output_switch": ("BOOLEAN", {"default": True}), - "cache_enabled": ("BOOLEAN", {"default": True, "label": "Enable Cache"}) + "cache_enabled": ("BOOLEAN", {"default": True, "label": "Enable Cache"}), + "node_id": ("STRING", {"default": "0", "hidden": True}), }, "optional": { "input_image": ("IMAGE",), "input_mask": ("MASK",) + }, + "hidden": { + "prompt": ("PROMPT",), + "unique_id": ("UNIQUE_ID",), } } @@ -230,161 +245,72 @@ class CanvasNode: return None # Zmienna blokująca równoczesne wykonania - _processing_lock = None - - def process_canvas_image(self, canvas_image, trigger, output_switch, cache_enabled, input_image=None, + _processing_lock = threading.Lock() + + def process_canvas_image(self, trigger, output_switch, cache_enabled, node_id, prompt=None, unique_id=None, input_image=None, input_mask=None): + + log_info(f"[CanvasNode] 🔍 process_canvas_image wejście – node_id={node_id!r}, unique_id={unique_id!r}, trigger={trigger}, output_switch={output_switch}") + try: # Sprawdź czy już trwa przetwarzanie - if self.__class__._processing_lock is not None: - log_warn(f"Process already in progress, waiting for completion...") - return () # Zwróć pusty wynik, aby uniknąć równoczesnych przetworzeń - - # Ustaw blokadę - self.__class__._processing_lock = True + if not self.__class__._processing_lock.acquire(blocking=False): + log_warn(f"Process already in progress for node {node_id}, skipping...") + # Return cached data if available to avoid breaking the flow + return self.get_cached_data() + + log_info(f"Lock acquired. Starting process_canvas_image for node_id: {node_id} (fallback unique_id: {unique_id})") - current_execution = self.get_execution_id() - log_info(f"Starting process_canvas_image - execution ID: {current_execution}, trigger: {trigger}") - log_debug(f"Canvas image filename: {canvas_image}") - log_debug(f"Output switch: {output_switch}, Cache enabled: {cache_enabled}") - log_debug(f"Input image provided: {input_image is not None}") - log_debug(f"Input mask provided: {input_mask is not None}") + # Use node_id as the primary key, as unique_id is proving unreliable + storage_key = node_id + + processed_image = None + processed_mask = None - if current_execution != self.__class__._canvas_cache['last_execution_id']: - log_info(f"New execution detected: {current_execution} (previous: {self.__class__._canvas_cache['last_execution_id']})") + with self.__class__._storage_lock: + canvas_data = self.__class__._canvas_data_storage.pop(storage_key, None) - self.__class__._canvas_cache['image'] = None - self.__class__._canvas_cache['mask'] = None - self.__class__._canvas_cache['last_execution_id'] = current_execution + if canvas_data: + log_info(f"Canvas data found for node {storage_key} from WebSocket") + if canvas_data.get('image'): + image_data = canvas_data['image'].split(',')[1] + image_bytes = base64.b64decode(image_data) + pil_image = Image.open(io.BytesIO(image_bytes)).convert('RGB') + image_array = np.array(pil_image).astype(np.float32) / 255.0 + processed_image = torch.from_numpy(image_array)[None,] + log_debug(f"Image loaded from WebSocket, shape: {processed_image.shape}") + + if canvas_data.get('mask'): + mask_data = canvas_data['mask'].split(',')[1] + mask_bytes = base64.b64decode(mask_data) + pil_mask = Image.open(io.BytesIO(mask_bytes)).convert('L') + mask_array = np.array(pil_mask).astype(np.float32) / 255.0 + processed_mask = torch.from_numpy(mask_array)[None,] + log_debug(f"Mask loaded from WebSocket, shape: {processed_mask.shape}") else: - log_debug(f"Same execution ID, using cached data") + log_warn(f"No canvas data found for node {storage_key} in WebSocket cache, using fallbacks.") + if input_image is not None: + log_info("Using provided input_image as fallback") + processed_image = input_image + if input_mask is not None: + log_info("Using provided input_mask as fallback") + processed_mask = input_mask - if input_image is not None: - log_info("Input image received, converting to PIL Image...") - if isinstance(input_image, torch.Tensor): - if input_image.dim() == 4: - input_image = input_image.squeeze(0) # 移除batch维度 + # Fallback to default tensors if nothing is loaded + if processed_image is None: + log_warn(f"Processed image is still None, creating default blank image.") + processed_image = torch.zeros((1, 512, 512, 3), dtype=torch.float32) + if processed_mask is None: + log_warn(f"Processed mask is still None, creating default blank mask.") + processed_mask = torch.zeros((1, 512, 512), dtype=torch.float32) - if input_image.shape[0] == 3: # 如果是[C, H, W]格式 - input_image = input_image.permute(1, 2, 0) - - image_array = (input_image.cpu().numpy() * 255).astype(np.uint8) - - if len(image_array.shape) == 2: # 如果是灰度图 - image_array = np.stack([image_array] * 3, axis=-1) - elif len(image_array.shape) == 3 and image_array.shape[-1] != 3: - image_array = np.transpose(image_array, (1, 2, 0)) - - try: - - pil_image = Image.fromarray(image_array, 'RGB') - log_debug("Successfully converted to PIL Image") - - self.__class__._canvas_cache['image'] = pil_image - log_debug(f"Image stored in cache with size: {pil_image.size}") - except Exception as e: - log_error(f"Error converting to PIL Image: {str(e)}") - log_debug(f"Array shape: {image_array.shape}, dtype: {image_array.dtype}") - raise - - if input_mask is not None: - log_info("Input mask received, converting to PIL Image...") - if isinstance(input_mask, torch.Tensor): - if input_mask.dim() == 4: - input_mask = input_mask.squeeze(0) - if input_mask.dim() == 3 and input_mask.shape[0] == 1: - input_mask = input_mask.squeeze(0) - - mask_array = (input_mask.cpu().numpy() * 255).astype(np.uint8) - pil_mask = Image.fromarray(mask_array, 'L') - log_debug("Successfully converted mask to PIL Image") - - self.__class__._canvas_cache['mask'] = pil_mask - log_debug(f"Mask stored in cache with size: {pil_mask.size}") - - self.__class__._canvas_cache['cache_enabled'] = cache_enabled - - try: - # Wczytaj obraz bez maski - image_without_mask_name = canvas_image.replace('.png', '_without_mask.png') - path_image_without_mask = folder_paths.get_annotated_filepath(image_without_mask_name) - log_debug(f"Canvas image name: {canvas_image}") - log_debug(f"Looking for image without mask: {image_without_mask_name}") - log_debug(f"Full path: {path_image_without_mask}") - - # Sprawdź czy plik istnieje - if not os.path.exists(path_image_without_mask): - log_warn(f"Image without mask not found at: {path_image_without_mask}") - # Spróbuj znaleźć plik w katalogu input - input_dir = folder_paths.get_input_directory() - alternative_path = os.path.join(input_dir, image_without_mask_name) - log_debug(f"Trying alternative path: {alternative_path}") - if os.path.exists(alternative_path): - path_image_without_mask = alternative_path - log_info(f"Found image at alternative path: {alternative_path}") - else: - raise FileNotFoundError(f"Image file not found: {image_without_mask_name}") - - i = Image.open(path_image_without_mask) - i = ImageOps.exif_transpose(i) - if i.mode not in ['RGB', 'RGBA']: - i = i.convert('RGB') - image = np.array(i).astype(np.float32) / 255.0 - if i.mode == 'RGBA': - rgb = image[..., :3] - alpha = image[..., 3:] - image = rgb * alpha + (1 - alpha) * 0.5 - processed_image = torch.from_numpy(image)[None,] - log_debug(f"Successfully loaded image without mask, shape: {processed_image.shape}") - except Exception as e: - log_error(f"Error loading image without mask: {str(e)}") - processed_image = torch.ones((1, 512, 512, 3), dtype=torch.float32) - log_debug(f"Using default image, shape: {processed_image.shape}") - - try: - # Wczytaj maskę - path_image = folder_paths.get_annotated_filepath(canvas_image) - path_mask = path_image.replace('.png', '_mask.png') - log_debug(f"Canvas image path: {path_image}") - log_debug(f"Looking for mask at: {path_mask}") - - # Sprawdź czy plik maski istnieje - if not os.path.exists(path_mask): - log_warn(f"Mask not found at: {path_mask}") - # Spróbuj znaleźć plik w katalogu input - input_dir = folder_paths.get_input_directory() - mask_name = canvas_image.replace('.png', '_mask.png') - alternative_mask_path = os.path.join(input_dir, mask_name) - log_debug(f"Trying alternative mask path: {alternative_mask_path}") - if os.path.exists(alternative_mask_path): - path_mask = alternative_mask_path - log_info(f"Found mask at alternative path: {alternative_mask_path}") - - if os.path.exists(path_mask): - log_debug(f"Mask file exists, loading...") - mask = Image.open(path_mask).convert('L') - mask = np.array(mask).astype(np.float32) / 255.0 - processed_mask = torch.from_numpy(mask)[None,] - log_debug(f"Successfully loaded mask, shape: {processed_mask.shape}") - else: - log_debug(f"Mask file does not exist, creating default mask") - processed_mask = torch.ones((1, processed_image.shape[1], processed_image.shape[2]), - dtype=torch.float32) - log_debug(f"Default mask created, shape: {processed_mask.shape}") - except Exception as e: - log_error(f"Error loading mask: {str(e)}") - processed_mask = torch.ones((1, processed_image.shape[1], processed_image.shape[2]), - dtype=torch.float32) - log_debug(f"Fallback mask created, shape: {processed_mask.shape}") if not output_switch: log_debug(f"Output switch is OFF, returning empty tuple") - return () + return (None, None) log_debug(f"About to return output - Image shape: {processed_image.shape}, Mask shape: {processed_mask.shape}") - log_debug(f"Image tensor info - dtype: {processed_image.dtype}, device: {processed_image.device}") - log_debug(f"Mask tensor info - dtype: {processed_mask.dtype}, device: {processed_mask.device}") self.update_persistent_cache() @@ -393,12 +319,13 @@ class CanvasNode: except Exception as e: log_exception(f"Error in process_canvas_image: {str(e)}") - return () + return (None, None) finally: # Zwolnij blokadę - self.__class__._processing_lock = None - log_debug(f"Process completed, lock released") + if self.__class__._processing_lock.locked(): + self.__class__._processing_lock.release() + log_debug(f"Process completed for node {node_id}, lock released") def get_cached_data(self): return { @@ -440,8 +367,80 @@ class CanvasNode: return cls._canvas_cache['data_flow_status'].get(flow_id) return cls._canvas_cache['data_flow_status'] + @classmethod + def _cleanup_old_websocket_data(cls): + """Clean up old WebSocket data from invalid nodes or data older than 5 minutes""" + try: + current_time = time.time() + cleanup_threshold = 300 # 5 minutes + + nodes_to_remove = [] + for node_id, data in cls._websocket_data.items(): + # Remove invalid node IDs + if node_id < 0: + nodes_to_remove.append(node_id) + continue + + # Remove old data + if current_time - data.get('timestamp', 0) > cleanup_threshold: + nodes_to_remove.append(node_id) + continue + + for node_id in nodes_to_remove: + del cls._websocket_data[node_id] + log_debug(f"Cleaned up old WebSocket data for node {node_id}") + + if nodes_to_remove: + log_info(f"Cleaned up {len(nodes_to_remove)} old WebSocket entries") + + except Exception as e: + log_error(f"Error during WebSocket cleanup: {str(e)}") + @classmethod def setup_routes(cls): + @PromptServer.instance.routes.get("/layerforge/canvas_ws") + async def handle_canvas_websocket(request): + ws = web.WebSocketResponse() + await ws.prepare(request) + + async for msg in ws: + if msg.type == web.WSMsgType.TEXT: + try: + data = msg.json() + node_id = data.get('nodeId') + if not node_id: + await ws.send_json({'status': 'error', 'message': 'nodeId is required'}) + continue + + image_data = data.get('image') + mask_data = data.get('mask') + + with cls._storage_lock: + cls._canvas_data_storage[node_id] = { + 'image': image_data, + 'mask': mask_data, + 'timestamp': time.time() + } + + log_info(f"Received canvas data for node {node_id} via WebSocket") + # Send acknowledgment back to the client + ack_payload = { + 'type': 'ack', + 'nodeId': node_id, + 'status': 'success' + } + await ws.send_json(ack_payload) + log_debug(f"Sent ACK for node {node_id}") + + except Exception as e: + log_error(f"Error processing WebSocket message: {e}") + await ws.send_json({'status': 'error', 'message': str(e)}) + elif msg.type == web.WSMsgType.ERROR: + log_error(f"WebSocket connection closed with exception {ws.exception()}") + + log_info("WebSocket connection closed") + return ws + @PromptServer.instance.routes.get("/ycnode/get_canvas_data/{node_id}") async def get_canvas_data(request): try: @@ -811,3 +810,15 @@ def convert_tensor_to_base64(tensor, alpha_mask=None, original_alpha=None): log_error(f"Error in convert_tensor_to_base64: {str(e)}") log_debug(f"Tensor shape: {tensor.shape}, dtype: {tensor.dtype}") raise + + +# Setup original API routes when module is loaded +CanvasNode.setup_routes() + +NODE_CLASS_MAPPINGS = { + "CanvasNode": CanvasNode +} + +NODE_DISPLAY_NAME_MAPPINGS = { + "CanvasNode": "LayerForge" +} diff --git a/js/Canvas.js b/js/Canvas.js index 0a6b382..253b8f5 100644 --- a/js/Canvas.js +++ b/js/Canvas.js @@ -266,9 +266,6 @@ export class Canvas { - async saveToServer(fileName) { - return this.canvasIO.saveToServer(fileName); - } async getFlattenedCanvasAsBlob() { return this.canvasLayers.getFlattenedCanvasAsBlob(); diff --git a/js/CanvasIO.js b/js/CanvasIO.js index a2ff041..d5750ba 100644 --- a/js/CanvasIO.js +++ b/js/CanvasIO.js @@ -1,5 +1,6 @@ import {createCanvas} from "./utils/CommonUtils.js"; import {createModuleLogger} from "./utils/LoggerUtils.js"; +import {webSocketManager} from "./utils/WebSocketManager.js"; const log = createModuleLogger('CanvasIO'); @@ -9,35 +10,38 @@ export class CanvasIO { this._saveInProgress = null; } - async saveToServer(fileName) { - if (!window.canvasSaveStates) { - window.canvasSaveStates = new Map(); - } - - const nodeId = this.canvas.node.id; - const saveKey = `${nodeId}_${fileName}`; - if (this._saveInProgress || window.canvasSaveStates.get(saveKey)) { - log.warn(`Save already in progress for node ${nodeId}, waiting...`); - return this._saveInProgress || window.canvasSaveStates.get(saveKey); - } + async saveToServer(fileName, outputMode = 'disk') { + if (outputMode === 'disk') { + if (!window.canvasSaveStates) { + window.canvasSaveStates = new Map(); + } - log.info(`Starting saveToServer with fileName: ${fileName} for node: ${nodeId}`); - log.debug(`Canvas dimensions: ${this.canvas.width}x${this.canvas.height}`); - log.debug(`Number of layers: ${this.canvas.layers.length}`); - this._saveInProgress = this._performSave(fileName); - window.canvasSaveStates.set(saveKey, this._saveInProgress); - - try { - const result = await this._saveInProgress; - return result; - } finally { - this._saveInProgress = null; - window.canvasSaveStates.delete(saveKey); - log.debug(`Save completed for node ${nodeId}, lock released`); + const nodeId = this.canvas.node.id; + const saveKey = `${nodeId}_${fileName}`; + if (this._saveInProgress || window.canvasSaveStates.get(saveKey)) { + log.warn(`Save already in progress for node ${nodeId}, waiting...`); + return this._saveInProgress || window.canvasSaveStates.get(saveKey); + } + + log.info(`Starting saveToServer (disk) with fileName: ${fileName} for node: ${nodeId}`); + this._saveInProgress = this._performSave(fileName, outputMode); + window.canvasSaveStates.set(saveKey, this._saveInProgress); + + try { + return await this._saveInProgress; + } finally { + this._saveInProgress = null; + window.canvasSaveStates.delete(saveKey); + log.debug(`Save completed for node ${nodeId}, lock released`); + } + } else { + // For RAM mode, we don't need the lock/state management as it's synchronous + log.info(`Starting saveToServer (RAM) for node: ${this.canvas.node.id}`); + return this._performSave(fileName, outputMode); } } - async _performSave(fileName) { + async _performSave(fileName, outputMode) { if (this.canvas.layers.length === 0) { log.warn(`Node ${this.canvas.node.id} has no layers, creating empty canvas`); return Promise.resolve(true); @@ -152,6 +156,15 @@ export class CanvasIO { maskCtx.globalCompositeOperation = 'source-over'; maskCtx.drawImage(tempMaskCanvas, 0, 0); } + if (outputMode === 'ram') { + const imageData = tempCanvas.toDataURL('image/png'); + const maskData = maskCanvas.toDataURL('image/png'); + log.info("Returning image and mask data as base64 for RAM mode."); + resolve({ image: imageData, mask: maskData }); + return; + } + + // --- Disk Mode (original logic) --- const fileNameWithoutMask = fileName.replace('.png', '_without_mask.png'); log.info(`Saving image without mask as: ${fileNameWithoutMask}`); @@ -204,7 +217,9 @@ export class CanvasIO { if (maskResp.status === 200) { const data = await resp.json(); - this.canvas.widget.value = fileName; + if (this.canvas.widget) { + this.canvas.widget.value = fileName; + } log.info(`All files saved successfully, widget value set to: ${fileName}`); resolve(true); } else { @@ -228,6 +243,132 @@ export class CanvasIO { }); } + async _renderOutputData() { + return new Promise((resolve) => { + const { canvas: tempCanvas, ctx: tempCtx } = createCanvas(this.canvas.width, this.canvas.height); + const { canvas: maskCanvas, ctx: maskCtx } = createCanvas(this.canvas.width, this.canvas.height); + + // This logic is mostly mirrored from _performSave to ensure consistency + tempCtx.fillStyle = '#ffffff'; + tempCtx.fillRect(0, 0, this.canvas.width, this.canvas.height); + const visibilityCanvas = document.createElement('canvas'); + visibilityCanvas.width = this.canvas.width; + visibilityCanvas.height = this.canvas.height; + const visibilityCtx = visibilityCanvas.getContext('2d', { alpha: true }); + maskCtx.fillStyle = '#ffffff'; // Start with a white mask (nothing masked) + maskCtx.fillRect(0, 0, this.canvas.width, this.canvas.height); + + const sortedLayers = this.canvas.layers.sort((a, b) => a.zIndex - b.zIndex); + sortedLayers.forEach((layer) => { + // Render layer to main canvas + tempCtx.save(); + tempCtx.globalCompositeOperation = layer.blendMode || 'normal'; + tempCtx.globalAlpha = layer.opacity !== undefined ? layer.opacity : 1; + tempCtx.translate(layer.x + layer.width / 2, layer.y + layer.height / 2); + tempCtx.rotate(layer.rotation * Math.PI / 180); + tempCtx.drawImage(layer.image, -layer.width / 2, -layer.height / 2, layer.width, layer.height); + tempCtx.restore(); + + // Render layer to visibility canvas for the mask + visibilityCtx.save(); + visibilityCtx.translate(layer.x + layer.width / 2, layer.y + layer.height / 2); + visibilityCtx.rotate(layer.rotation * Math.PI / 180); + visibilityCtx.drawImage(layer.image, -layer.width / 2, -layer.height / 2, layer.width, layer.height); + visibilityCtx.restore(); + }); + + // Create layer visibility mask + const visibilityData = visibilityCtx.getImageData(0, 0, this.canvas.width, this.canvas.height); + const maskData = maskCtx.getImageData(0, 0, this.canvas.width, this.canvas.height); + for (let i = 0; i < visibilityData.data.length; i += 4) { + const alpha = visibilityData.data[i + 3]; + const maskValue = 255 - alpha; // Invert alpha to create the mask + maskData.data[i] = maskData.data[i + 1] = maskData.data[i + 2] = maskValue; + maskData.data[i + 3] = 255; // Solid mask + } + maskCtx.putImageData(maskData, 0, 0); + + // Composite the tool mask on top + const toolMaskCanvas = this.canvas.maskTool.getMask(); + if (toolMaskCanvas) { + // Create a temp canvas for processing the mask + const tempMaskCanvas = document.createElement('canvas'); + tempMaskCanvas.width = this.canvas.width; + tempMaskCanvas.height = this.canvas.height; + const tempMaskCtx = tempMaskCanvas.getContext('2d'); + + // Clear the canvas + tempMaskCtx.clearRect(0, 0, tempMaskCanvas.width, tempMaskCanvas.height); + + // Calculate the correct position to extract the mask + const maskX = this.canvas.maskTool.x; + const maskY = this.canvas.maskTool.y; + + log.debug(`[renderOutputData] Extracting mask from world position (${maskX}, ${maskY})`); + + const sourceX = Math.max(0, -maskX); + const sourceY = Math.max(0, -maskY); + const destX = Math.max(0, maskX); + const destY = Math.max(0, maskY); + + const copyWidth = Math.min(toolMaskCanvas.width - sourceX, this.canvas.width - destX); + const copyHeight = Math.min(toolMaskCanvas.height - sourceY, this.canvas.height - destY); + + if (copyWidth > 0 && copyHeight > 0) { + tempMaskCtx.drawImage( + toolMaskCanvas, + sourceX, sourceY, copyWidth, copyHeight, + destX, destY, copyWidth, copyHeight + ); + } + + // Convert the brush mask (white with alpha) to a solid white mask on black background. + const tempMaskData = tempMaskCtx.getImageData(0, 0, this.canvas.width, this.canvas.height); + for (let i = 0; i < tempMaskData.data.length; i += 4) { + const alpha = tempMaskData.data[i + 3]; + // The painted area (alpha > 0) should become white (255). + tempMaskData.data[i] = tempMaskData.data[i+1] = tempMaskData.data[i+2] = alpha; + tempMaskData.data[i + 3] = 255; // Solid alpha + } + tempMaskCtx.putImageData(tempMaskData, 0, 0); + + // Use 'screen' blending mode. This correctly adds the white brush mask + // to the existing layer visibility mask. (white + anything = white) + maskCtx.globalCompositeOperation = 'screen'; + maskCtx.drawImage(tempMaskCanvas, 0, 0); + } + + const imageDataUrl = tempCanvas.toDataURL('image/png'); + const maskDataUrl = maskCanvas.toDataURL('image/png'); + + resolve({ image: imageDataUrl, mask: maskDataUrl }); + }); + } + + async sendDataViaWebSocket(nodeId) { + log.info(`Preparing to send data for node ${nodeId} via WebSocket.`); + + const { image, mask } = await this._renderOutputData(); + + try { + log.info(`Sending data for node ${nodeId}...`); + await webSocketManager.sendMessage({ + type: 'canvas_data', + nodeId: String(nodeId), + image: image, + mask: mask, + }, true); // `true` requires an acknowledgment + + log.info(`Data for node ${nodeId} has been sent and acknowledged by the server.`); + return true; + } catch (error) { + log.error(`Failed to send data for node ${nodeId}:`, error); + // We can alert the user here or handle it silently. + // For now, let's throw to make it clear the process failed. + throw new Error(`Failed to get confirmation from server for node ${nodeId}. The workflow might not have the latest canvas data.`); + } + } + async addInputToCanvas(inputImage, inputMask) { try { log.debug("Adding input to canvas:", {inputImage}); diff --git a/js/CanvasView.js b/js/CanvasView.js index b18835e..f2044db 100644 --- a/js/CanvasView.js +++ b/js/CanvasView.js @@ -377,8 +377,7 @@ async function createCanvasWidget(node, widget, app) { const img = new Image(); img.onload = async () => { canvas.addLayer(img); - await saveWithFallback(widget.value); - app.graph.runStep(); + await updateOutput(); }; img.src = event.target.result; }; @@ -392,8 +391,7 @@ async function createCanvasWidget(node, widget, app) { textContent: "Import Input", onclick: async () => { if (await canvas.importLatestImage()) { - await saveWithFallback(widget.value); - app.graph.runStep(); + await updateOutput(); } } }), @@ -574,8 +572,7 @@ async function createCanvasWidget(node, widget, app) { canvas.updateSelection([newLayer]); canvas.render(); canvas.saveState(); - await saveWithFallback(widget.value); - app.graph.runStep(); + await updateOutput(); } catch (error) { log.error("Matting error:", error); alert(`Error during matting process: ${error.message}`); @@ -745,7 +742,8 @@ async function createCanvasWidget(node, widget, app) { const triggerWidget = node.widgets.find(w => w.name === "trigger"); const updateOutput = async () => { - await saveWithFallback(widget.value); + // Only increment trigger and run step - don't save to disk here + // Saving to disk will happen during execution_start event triggerWidget.value = (triggerWidget.value + 1) % 99999999; app.graph.runStep(); }; @@ -790,8 +788,9 @@ async function createCanvasWidget(node, widget, app) { canvas.render(); }; - canvas.canvas.addEventListener('mouseup', updateOutput); - canvas.canvas.addEventListener('mouseleave', updateOutput); + // Remove automatic saving on mouse events - only save during execution + // canvas.canvas.addEventListener('mouseup', updateOutput); + // canvas.canvas.addEventListener('mouseleave', updateOutput); const mainContainer = $el("div.painterMainContainer", { @@ -922,66 +921,8 @@ async function createCanvasWidget(node, widget, app) { if (!window.canvasExecutionStates) { window.canvasExecutionStates = new Map(); } - const saveWithFallback = async (fileName) => { - try { - const uniqueFileName = generateUniqueFileName(fileName, node.id); - log.debug(`Attempting to save with unique name: ${uniqueFileName}`); - return await canvas.saveToServer(uniqueFileName); - } catch (error) { - log.warn(`Failed to save with unique name, falling back to original: ${fileName}`, error); - return await canvas.saveToServer(fileName); - } - }; - api.addEventListener("execution_start", async (event) => { - const executionData = event.detail || {}; - const currentPromptId = executionData.prompt_id; - - log.info(`Execution start event for node ${node.id}, prompt_id: ${currentPromptId}`); - log.debug(`Widget value: ${widget.value}`); - log.debug(`Node inputs: ${node.inputs?.length || 0}`); - log.debug(`Canvas layers count: ${canvas.layers.length}`); - if (window.canvasExecutionStates.get(node.id)) { - log.warn(`Execution already in progress for node ${node.id}, skipping...`); - return; - } - window.canvasExecutionStates.set(node.id, true); - - try { - if (canvas.layers.length === 0) { - log.warn(`Node ${node.id} has no layers, skipping save to server`); - } else { - await saveWithFallback(widget.value); - log.info(`Canvas saved to server for node ${node.id}`); - } - if (node.inputs[0]?.link) { - const linkId = node.inputs[0].link; - const inputData = app.nodeOutputs[linkId]; - log.debug(`Input link ${linkId} has data: ${!!inputData}`); - if (inputData) { - imageCache.set(linkId, inputData); - log.debug(`Input data cached for link ${linkId}`); - } - } else { - log.debug(`No input link found`); - } - } catch (error) { - log.error(`Error during execution for node ${node.id}:`, error); - } finally { - window.canvasExecutionStates.set(node.id, false); - log.debug(`Execution completed for node ${node.id}, flag released`); - } - }); - - const originalSaveToServer = canvas.saveToServer; - canvas.saveToServer = async function (fileName) { - log.debug(`saveToServer called with fileName: ${fileName}`); - log.debug(`Current execution context - node ID: ${node.id}`); - const result = await originalSaveToServer.call(this, fileName); - log.debug(`saveToServer completed, result: ${result}`); - return result; - }; node.canvasWidget = canvas; @@ -996,30 +937,111 @@ async function createCanvasWidget(node, widget, app) { } +const canvasNodeInstances = new Map(); + app.registerExtension({ name: "Comfy.CanvasNode", + + init() { + // Monkey-patch the queuePrompt function to send canvas data via WebSocket before sending the prompt + const originalQueuePrompt = app.queuePrompt; + app.queuePrompt = async function(number, prompt) { + log.info("Preparing to queue prompt..."); + + if (canvasNodeInstances.size > 0) { + log.info(`Found ${canvasNodeInstances.size} CanvasNode(s). Sending data via WebSocket...`); + + const sendPromises = []; + for (const [nodeId, canvasWidget] of canvasNodeInstances.entries()) { + // Ensure the node still exists on the graph before sending data + if (app.graph.getNodeById(nodeId) && canvasWidget.canvas && canvasWidget.canvas.canvasIO) { + log.debug(`Sending data for canvas node ${nodeId}`); + // This now returns a promise that resolves upon server ACK + sendPromises.push(canvasWidget.canvas.canvasIO.sendDataViaWebSocket(nodeId)); + } else { + // If node doesn't exist, it might have been deleted, so we can clean up the map + log.warn(`Node ${nodeId} not found in graph, removing from instances map.`); + canvasNodeInstances.delete(nodeId); + } + } + + try { + // Wait for all WebSocket messages to be acknowledged + await Promise.all(sendPromises); + log.info("All canvas data has been sent and acknowledged by the server."); + } catch (error) { + log.error("Failed to send canvas data for one or more nodes. Aborting prompt.", error); + // IMPORTANT: Stop the prompt from queueing if data transfer fails. + // You might want to show a user-facing error here. + alert(`CanvasNode Error: ${error.message}`); + return; // Stop execution + } + } + + log.info("All pre-prompt tasks complete. Proceeding with original queuePrompt."); + // Proceed with the original queuePrompt logic + return originalQueuePrompt.apply(this, arguments); + }; + }, + async beforeRegisterNodeDef(nodeType, nodeData, app) { if (nodeType.comfyClass === "CanvasNode") { const onNodeCreated = nodeType.prototype.onNodeCreated; - nodeType.prototype.onNodeCreated = async function () { - log.info("CanvasNode created, ID:", this.id); + nodeType.prototype.onNodeCreated = function () { + log.debug("CanvasNode onNodeCreated: Base widget setup."); + // Call original onNodeCreated to ensure widgets are created const r = onNodeCreated?.apply(this, arguments); - - const widget = this.widgets.find(w => w.name === "canvas_image"); - log.debug("Found canvas_image widget:", widget); - await createCanvasWidget(this, widget, app); - + // The main initialization is moved to onAdded return r; }; + // onAdded is the most reliable callback for when a node is fully added to the graph and has an ID + nodeType.prototype.onAdded = async function() { + log.info(`CanvasNode onAdded, ID: ${this.id}`); + log.debug(`Available widgets in onAdded:`, this.widgets.map(w => w.name)); + + // Prevent re-initialization if the widget already exists + if (this.canvasWidget) { + log.warn(`CanvasNode ${this.id} already initialized. Skipping onAdded setup.`); + return; + } + + // Now that we are in onAdded, this.id is guaranteed to be correct. + // Set the hidden node_id widget's value for backend communication. + const nodeIdWidget = this.widgets.find(w => w.name === "node_id"); + if (nodeIdWidget) { + nodeIdWidget.value = String(this.id); + log.debug(`Set hidden node_id widget to: ${nodeIdWidget.value}`); + } else { + log.error("Could not find the hidden node_id widget!"); + } + + // Create the main canvas widget and register it in our global map + // We pass `null` for the widget parameter as we are not using a pre-defined widget. + const canvasWidget = await createCanvasWidget(this, null, app); + canvasNodeInstances.set(this.id, canvasWidget); + log.info(`Registered CanvasNode instance for ID: ${this.id}`); + }; + const onRemoved = nodeType.prototype.onRemoved; nodeType.prototype.onRemoved = function () { + log.info(`Cleaning up canvas node ${this.id}`); + + // Clean up from our instance map + canvasNodeInstances.delete(this.id); + log.info(`Deregistered CanvasNode instance for ID: ${this.id}`); + + // Clean up execution state + if (window.canvasExecutionStates) { + window.canvasExecutionStates.delete(this.id); + } + const tooltip = document.getElementById(`painter-help-tooltip-${this.id}`); if (tooltip) { tooltip.remove(); } const backdrop = document.querySelector('.painter-modal-backdrop'); - if (backdrop && backdrop.contains(this.canvasWidget.canvas)) { + if (backdrop && backdrop.contains(this.canvasWidget?.canvas)) { document.body.removeChild(backdrop); } diff --git a/js/utils/CommonUtils.js b/js/utils/CommonUtils.js index 3fffc89..88824f1 100644 --- a/js/utils/CommonUtils.js +++ b/js/utils/CommonUtils.js @@ -125,12 +125,29 @@ export function cloneLayers(layers) { * @returns {string} Sygnatura JSON */ export function getStateSignature(layers) { - return JSON.stringify(layers.map(layer => { - const sig = {...layer}; - if (sig.imageId) { - sig.imageId = sig.imageId; + return JSON.stringify(layers.map((layer, index) => { + const sig = { + index: index, + x: Math.round(layer.x * 100) / 100, // Round to avoid floating point precision issues + y: Math.round(layer.y * 100) / 100, + width: Math.round(layer.width * 100) / 100, + height: Math.round(layer.height * 100) / 100, + rotation: Math.round((layer.rotation || 0) * 100) / 100, + zIndex: layer.zIndex, + blendMode: layer.blendMode || 'normal', + opacity: layer.opacity !== undefined ? Math.round(layer.opacity * 100) / 100 : 1 + }; + + // Include imageId if available + if (layer.imageId) { + sig.imageId = layer.imageId; } - delete sig.image; + + // Include image src as fallback identifier + if (layer.image && layer.image.src) { + sig.imageSrc = layer.image.src.substring(0, 100); // First 100 chars to avoid huge signatures + } + return sig; })); } diff --git a/js/utils/WebSocketManager.js b/js/utils/WebSocketManager.js new file mode 100644 index 0000000..ae94684 --- /dev/null +++ b/js/utils/WebSocketManager.js @@ -0,0 +1,160 @@ +import {createModuleLogger} from "./LoggerUtils.js"; + +const log = createModuleLogger('WebSocketManager'); + +class WebSocketManager { + constructor(url) { + this.url = url; + this.socket = null; + this.messageQueue = []; + this.isConnecting = false; + this.reconnectAttempts = 0; + this.maxReconnectAttempts = 10; + this.reconnectInterval = 5000; // 5 seconds + this.ackCallbacks = new Map(); // Store callbacks for messages awaiting ACK + this.messageIdCounter = 0; + + this.connect(); + } + + connect() { + if (this.socket && this.socket.readyState === WebSocket.OPEN) { + log.debug("WebSocket is already open."); + return; + } + + if (this.isConnecting) { + log.debug("Connection attempt already in progress."); + return; + } + + this.isConnecting = true; + log.info(`Connecting to WebSocket at ${this.url}...`); + + try { + this.socket = new WebSocket(this.url); + + this.socket.onopen = () => { + this.isConnecting = false; + this.reconnectAttempts = 0; + log.info("WebSocket connection established."); + this.flushMessageQueue(); + }; + + this.socket.onmessage = (event) => { + try { + const data = JSON.parse(event.data); + log.debug("Received message:", data); + + if (data.type === 'ack' && data.nodeId) { + const callback = this.ackCallbacks.get(data.nodeId); + if (callback) { + log.debug(`ACK received for nodeId: ${data.nodeId}, resolving promise.`); + callback.resolve(data); + this.ackCallbacks.delete(data.nodeId); + } + } + // Handle other incoming messages if needed + } catch (error) { + log.error("Error parsing incoming WebSocket message:", error); + } + }; + + this.socket.onclose = (event) => { + this.isConnecting = false; + if (event.wasClean) { + log.info(`WebSocket closed cleanly, code=${event.code}, reason=${event.reason}`); + } else { + log.warn("WebSocket connection died. Attempting to reconnect..."); + this.handleReconnect(); + } + }; + + this.socket.onerror = (error) => { + this.isConnecting = false; + log.error("WebSocket error:", error); + // The onclose event will be fired next, which will handle reconnection. + }; + } catch (error) { + this.isConnecting = false; + log.error("Failed to create WebSocket connection:", error); + this.handleReconnect(); + } + } + + handleReconnect() { + if (this.reconnectAttempts < this.maxReconnectAttempts) { + this.reconnectAttempts++; + log.info(`Reconnect attempt ${this.reconnectAttempts}/${this.maxReconnectAttempts}...`); + setTimeout(() => this.connect(), this.reconnectInterval); + } else { + log.error("Max reconnect attempts reached. Giving up."); + } + } + + sendMessage(data, requiresAck = false) { + return new Promise((resolve, reject) => { + const nodeId = data.nodeId; + if (requiresAck && !nodeId) { + return reject(new Error("A nodeId is required for messages that need acknowledgment.")); + } + + const message = JSON.stringify(data); + + if (this.socket && this.socket.readyState === WebSocket.OPEN) { + this.socket.send(message); + log.debug("Sent message:", data); + if (requiresAck) { + log.debug(`Message for nodeId ${nodeId} requires ACK. Setting up callback.`); + // Set a timeout for the ACK + const timeout = setTimeout(() => { + this.ackCallbacks.delete(nodeId); + reject(new Error(`ACK timeout for nodeId ${nodeId}`)); + log.warn(`ACK timeout for nodeId ${nodeId}.`); + }, 10000); // 10-second timeout + + this.ackCallbacks.set(nodeId, { + resolve: (responseData) => { + clearTimeout(timeout); + resolve(responseData); + }, + reject: (error) => { + clearTimeout(timeout); + reject(error); + } + }); + } else { + resolve(); // Resolve immediately if no ACK is needed + } + } else { + log.warn("WebSocket not open. Queuing message."); + // Note: The current queueing doesn't support ACK promises well. + // For simplicity, we'll focus on the connected case. + // A more robust implementation would wrap the queued message in a function. + this.messageQueue.push(message); + if (!this.isConnecting) { + this.connect(); + } + // For now, we reject if not connected and ACK is required. + if (requiresAck) { + reject(new Error("Cannot send message with ACK required while disconnected.")); + } + } + }); + } + + flushMessageQueue() { + log.debug(`Flushing ${this.messageQueue.length} queued messages.`); + // Note: This simple flush doesn't handle ACKs for queued messages. + // This should be acceptable as data is sent right before queueing a prompt, + // at which point the socket should ideally be connected. + while (this.messageQueue.length > 0) { + const message = this.messageQueue.shift(); + this.socket.send(message); + } + } +} + +// Create a singleton instance of the WebSocketManager +const wsUrl = `ws://${window.location.host}/layerforge/canvas_ws`; +export const webSocketManager = new WebSocketManager(wsUrl);