Add WebSocket-based RAM output for CanvasNode

Introduces a WebSocket-based mechanism for CanvasNode to send and receive canvas image and mask data in RAM, enabling fast, diskless data transfer between frontend and backend. Adds a new WebSocketManager utility, updates CanvasIO to support RAM output mode, and modifies CanvasView to send canvas data via WebSocket before prompt execution. The backend (canvas_node.py) is updated to handle WebSocket data storage and retrieval, with improved locking and cleanup logic. This change improves workflow speed and reliability by avoiding unnecessary disk I/O and ensuring up-to-date canvas data is available during node execution.
This commit is contained in:
Dariusz L
2025-06-27 05:28:13 +02:00
parent daf3abeea7
commit be4fae2964
6 changed files with 602 additions and 254 deletions

View File

@@ -5,6 +5,8 @@ import numpy as np
import folder_paths import folder_paths
from server import PromptServer from server import PromptServer
from aiohttp import web from aiohttp import web
import asyncio
import threading
import os import os
from tqdm import tqdm from tqdm import tqdm
from torchvision import transforms from torchvision import transforms
@@ -91,6 +93,9 @@ class BiRefNet(torch.nn.Module):
class CanvasNode: class CanvasNode:
_canvas_data_storage = {}
_storage_lock = threading.Lock()
_canvas_cache = { _canvas_cache = {
'image': None, 'image': None,
'mask': None, 'mask': None,
@@ -100,9 +105,15 @@ class CanvasNode:
'last_execution_id': None 'last_execution_id': None
} }
# Simple in-memory storage for canvas data, keyed by prompt_id
# WebSocket-based storage for canvas data per node
_websocket_data = {}
_websocket_listeners = {}
def __init__(self): def __init__(self):
super().__init__() super().__init__()
self.flow_id = str(uuid.uuid4()) self.flow_id = str(uuid.uuid4())
self.node_id = None # Will be set when node is created
if self.__class__._canvas_cache['persistent_cache']: if self.__class__._canvas_cache['persistent_cache']:
self.restore_cache() self.restore_cache()
@@ -166,14 +177,18 @@ class CanvasNode:
def INPUT_TYPES(cls): def INPUT_TYPES(cls):
return { return {
"required": { "required": {
"canvas_image": ("STRING", {"default": "canvas_image.png"}),
"trigger": ("INT", {"default": 0, "min": 0, "max": 99999999, "step": 1, "hidden": True}), "trigger": ("INT", {"default": 0, "min": 0, "max": 99999999, "step": 1, "hidden": True}),
"output_switch": ("BOOLEAN", {"default": True}), "output_switch": ("BOOLEAN", {"default": True}),
"cache_enabled": ("BOOLEAN", {"default": True, "label": "Enable Cache"}) "cache_enabled": ("BOOLEAN", {"default": True, "label": "Enable Cache"}),
"node_id": ("STRING", {"default": "0", "hidden": True}),
}, },
"optional": { "optional": {
"input_image": ("IMAGE",), "input_image": ("IMAGE",),
"input_mask": ("MASK",) "input_mask": ("MASK",)
},
"hidden": {
"prompt": ("PROMPT",),
"unique_id": ("UNIQUE_ID",),
} }
} }
@@ -230,161 +245,72 @@ class CanvasNode:
return None return None
# Zmienna blokująca równoczesne wykonania # Zmienna blokująca równoczesne wykonania
_processing_lock = None _processing_lock = threading.Lock()
def process_canvas_image(self, canvas_image, trigger, output_switch, cache_enabled, input_image=None, def process_canvas_image(self, trigger, output_switch, cache_enabled, node_id, prompt=None, unique_id=None, input_image=None,
input_mask=None): input_mask=None):
log_info(f"[CanvasNode] 🔍 process_canvas_image wejście node_id={node_id!r}, unique_id={unique_id!r}, trigger={trigger}, output_switch={output_switch}")
try: try:
# Sprawdź czy już trwa przetwarzanie # Sprawdź czy już trwa przetwarzanie
if self.__class__._processing_lock is not None: if not self.__class__._processing_lock.acquire(blocking=False):
log_warn(f"Process already in progress, waiting for completion...") log_warn(f"Process already in progress for node {node_id}, skipping...")
return () # Zwróć pusty wynik, aby uniknąć równoczesnych przetworzeń # Return cached data if available to avoid breaking the flow
return self.get_cached_data()
# Ustaw blokadę log_info(f"Lock acquired. Starting process_canvas_image for node_id: {node_id} (fallback unique_id: {unique_id})")
self.__class__._processing_lock = True
current_execution = self.get_execution_id() # Use node_id as the primary key, as unique_id is proving unreliable
log_info(f"Starting process_canvas_image - execution ID: {current_execution}, trigger: {trigger}") storage_key = node_id
log_debug(f"Canvas image filename: {canvas_image}")
log_debug(f"Output switch: {output_switch}, Cache enabled: {cache_enabled}")
log_debug(f"Input image provided: {input_image is not None}")
log_debug(f"Input mask provided: {input_mask is not None}")
if current_execution != self.__class__._canvas_cache['last_execution_id']: processed_image = None
log_info(f"New execution detected: {current_execution} (previous: {self.__class__._canvas_cache['last_execution_id']})") processed_mask = None
self.__class__._canvas_cache['image'] = None with self.__class__._storage_lock:
self.__class__._canvas_cache['mask'] = None canvas_data = self.__class__._canvas_data_storage.pop(storage_key, None)
self.__class__._canvas_cache['last_execution_id'] = current_execution
if canvas_data:
log_info(f"Canvas data found for node {storage_key} from WebSocket")
if canvas_data.get('image'):
image_data = canvas_data['image'].split(',')[1]
image_bytes = base64.b64decode(image_data)
pil_image = Image.open(io.BytesIO(image_bytes)).convert('RGB')
image_array = np.array(pil_image).astype(np.float32) / 255.0
processed_image = torch.from_numpy(image_array)[None,]
log_debug(f"Image loaded from WebSocket, shape: {processed_image.shape}")
if canvas_data.get('mask'):
mask_data = canvas_data['mask'].split(',')[1]
mask_bytes = base64.b64decode(mask_data)
pil_mask = Image.open(io.BytesIO(mask_bytes)).convert('L')
mask_array = np.array(pil_mask).astype(np.float32) / 255.0
processed_mask = torch.from_numpy(mask_array)[None,]
log_debug(f"Mask loaded from WebSocket, shape: {processed_mask.shape}")
else: else:
log_debug(f"Same execution ID, using cached data") log_warn(f"No canvas data found for node {storage_key} in WebSocket cache, using fallbacks.")
if input_image is not None: if input_image is not None:
log_info("Input image received, converting to PIL Image...") log_info("Using provided input_image as fallback")
processed_image = input_image
if isinstance(input_image, torch.Tensor):
if input_image.dim() == 4:
input_image = input_image.squeeze(0) # 移除batch维度
if input_image.shape[0] == 3: # 如果是[C, H, W]格式
input_image = input_image.permute(1, 2, 0)
image_array = (input_image.cpu().numpy() * 255).astype(np.uint8)
if len(image_array.shape) == 2: # 如果是灰度图
image_array = np.stack([image_array] * 3, axis=-1)
elif len(image_array.shape) == 3 and image_array.shape[-1] != 3:
image_array = np.transpose(image_array, (1, 2, 0))
try:
pil_image = Image.fromarray(image_array, 'RGB')
log_debug("Successfully converted to PIL Image")
self.__class__._canvas_cache['image'] = pil_image
log_debug(f"Image stored in cache with size: {pil_image.size}")
except Exception as e:
log_error(f"Error converting to PIL Image: {str(e)}")
log_debug(f"Array shape: {image_array.shape}, dtype: {image_array.dtype}")
raise
if input_mask is not None: if input_mask is not None:
log_info("Input mask received, converting to PIL Image...") log_info("Using provided input_mask as fallback")
if isinstance(input_mask, torch.Tensor): processed_mask = input_mask
if input_mask.dim() == 4:
input_mask = input_mask.squeeze(0)
if input_mask.dim() == 3 and input_mask.shape[0] == 1:
input_mask = input_mask.squeeze(0)
mask_array = (input_mask.cpu().numpy() * 255).astype(np.uint8)
pil_mask = Image.fromarray(mask_array, 'L')
log_debug("Successfully converted mask to PIL Image")
self.__class__._canvas_cache['mask'] = pil_mask # Fallback to default tensors if nothing is loaded
log_debug(f"Mask stored in cache with size: {pil_mask.size}") if processed_image is None:
log_warn(f"Processed image is still None, creating default blank image.")
processed_image = torch.zeros((1, 512, 512, 3), dtype=torch.float32)
if processed_mask is None:
log_warn(f"Processed mask is still None, creating default blank mask.")
processed_mask = torch.zeros((1, 512, 512), dtype=torch.float32)
self.__class__._canvas_cache['cache_enabled'] = cache_enabled
try:
# Wczytaj obraz bez maski
image_without_mask_name = canvas_image.replace('.png', '_without_mask.png')
path_image_without_mask = folder_paths.get_annotated_filepath(image_without_mask_name)
log_debug(f"Canvas image name: {canvas_image}")
log_debug(f"Looking for image without mask: {image_without_mask_name}")
log_debug(f"Full path: {path_image_without_mask}")
# Sprawdź czy plik istnieje
if not os.path.exists(path_image_without_mask):
log_warn(f"Image without mask not found at: {path_image_without_mask}")
# Spróbuj znaleźć plik w katalogu input
input_dir = folder_paths.get_input_directory()
alternative_path = os.path.join(input_dir, image_without_mask_name)
log_debug(f"Trying alternative path: {alternative_path}")
if os.path.exists(alternative_path):
path_image_without_mask = alternative_path
log_info(f"Found image at alternative path: {alternative_path}")
else:
raise FileNotFoundError(f"Image file not found: {image_without_mask_name}")
i = Image.open(path_image_without_mask)
i = ImageOps.exif_transpose(i)
if i.mode not in ['RGB', 'RGBA']:
i = i.convert('RGB')
image = np.array(i).astype(np.float32) / 255.0
if i.mode == 'RGBA':
rgb = image[..., :3]
alpha = image[..., 3:]
image = rgb * alpha + (1 - alpha) * 0.5
processed_image = torch.from_numpy(image)[None,]
log_debug(f"Successfully loaded image without mask, shape: {processed_image.shape}")
except Exception as e:
log_error(f"Error loading image without mask: {str(e)}")
processed_image = torch.ones((1, 512, 512, 3), dtype=torch.float32)
log_debug(f"Using default image, shape: {processed_image.shape}")
try:
# Wczytaj maskę
path_image = folder_paths.get_annotated_filepath(canvas_image)
path_mask = path_image.replace('.png', '_mask.png')
log_debug(f"Canvas image path: {path_image}")
log_debug(f"Looking for mask at: {path_mask}")
# Sprawdź czy plik maski istnieje
if not os.path.exists(path_mask):
log_warn(f"Mask not found at: {path_mask}")
# Spróbuj znaleźć plik w katalogu input
input_dir = folder_paths.get_input_directory()
mask_name = canvas_image.replace('.png', '_mask.png')
alternative_mask_path = os.path.join(input_dir, mask_name)
log_debug(f"Trying alternative mask path: {alternative_mask_path}")
if os.path.exists(alternative_mask_path):
path_mask = alternative_mask_path
log_info(f"Found mask at alternative path: {alternative_mask_path}")
if os.path.exists(path_mask):
log_debug(f"Mask file exists, loading...")
mask = Image.open(path_mask).convert('L')
mask = np.array(mask).astype(np.float32) / 255.0
processed_mask = torch.from_numpy(mask)[None,]
log_debug(f"Successfully loaded mask, shape: {processed_mask.shape}")
else:
log_debug(f"Mask file does not exist, creating default mask")
processed_mask = torch.ones((1, processed_image.shape[1], processed_image.shape[2]),
dtype=torch.float32)
log_debug(f"Default mask created, shape: {processed_mask.shape}")
except Exception as e:
log_error(f"Error loading mask: {str(e)}")
processed_mask = torch.ones((1, processed_image.shape[1], processed_image.shape[2]),
dtype=torch.float32)
log_debug(f"Fallback mask created, shape: {processed_mask.shape}")
if not output_switch: if not output_switch:
log_debug(f"Output switch is OFF, returning empty tuple") log_debug(f"Output switch is OFF, returning empty tuple")
return () return (None, None)
log_debug(f"About to return output - Image shape: {processed_image.shape}, Mask shape: {processed_mask.shape}") log_debug(f"About to return output - Image shape: {processed_image.shape}, Mask shape: {processed_mask.shape}")
log_debug(f"Image tensor info - dtype: {processed_image.dtype}, device: {processed_image.device}")
log_debug(f"Mask tensor info - dtype: {processed_mask.dtype}, device: {processed_mask.device}")
self.update_persistent_cache() self.update_persistent_cache()
@@ -393,12 +319,13 @@ class CanvasNode:
except Exception as e: except Exception as e:
log_exception(f"Error in process_canvas_image: {str(e)}") log_exception(f"Error in process_canvas_image: {str(e)}")
return () return (None, None)
finally: finally:
# Zwolnij blokadę # Zwolnij blokadę
self.__class__._processing_lock = None if self.__class__._processing_lock.locked():
log_debug(f"Process completed, lock released") self.__class__._processing_lock.release()
log_debug(f"Process completed for node {node_id}, lock released")
def get_cached_data(self): def get_cached_data(self):
return { return {
@@ -440,8 +367,80 @@ class CanvasNode:
return cls._canvas_cache['data_flow_status'].get(flow_id) return cls._canvas_cache['data_flow_status'].get(flow_id)
return cls._canvas_cache['data_flow_status'] return cls._canvas_cache['data_flow_status']
@classmethod
def _cleanup_old_websocket_data(cls):
"""Clean up old WebSocket data from invalid nodes or data older than 5 minutes"""
try:
current_time = time.time()
cleanup_threshold = 300 # 5 minutes
nodes_to_remove = []
for node_id, data in cls._websocket_data.items():
# Remove invalid node IDs
if node_id < 0:
nodes_to_remove.append(node_id)
continue
# Remove old data
if current_time - data.get('timestamp', 0) > cleanup_threshold:
nodes_to_remove.append(node_id)
continue
for node_id in nodes_to_remove:
del cls._websocket_data[node_id]
log_debug(f"Cleaned up old WebSocket data for node {node_id}")
if nodes_to_remove:
log_info(f"Cleaned up {len(nodes_to_remove)} old WebSocket entries")
except Exception as e:
log_error(f"Error during WebSocket cleanup: {str(e)}")
@classmethod @classmethod
def setup_routes(cls): def setup_routes(cls):
@PromptServer.instance.routes.get("/layerforge/canvas_ws")
async def handle_canvas_websocket(request):
ws = web.WebSocketResponse()
await ws.prepare(request)
async for msg in ws:
if msg.type == web.WSMsgType.TEXT:
try:
data = msg.json()
node_id = data.get('nodeId')
if not node_id:
await ws.send_json({'status': 'error', 'message': 'nodeId is required'})
continue
image_data = data.get('image')
mask_data = data.get('mask')
with cls._storage_lock:
cls._canvas_data_storage[node_id] = {
'image': image_data,
'mask': mask_data,
'timestamp': time.time()
}
log_info(f"Received canvas data for node {node_id} via WebSocket")
# Send acknowledgment back to the client
ack_payload = {
'type': 'ack',
'nodeId': node_id,
'status': 'success'
}
await ws.send_json(ack_payload)
log_debug(f"Sent ACK for node {node_id}")
except Exception as e:
log_error(f"Error processing WebSocket message: {e}")
await ws.send_json({'status': 'error', 'message': str(e)})
elif msg.type == web.WSMsgType.ERROR:
log_error(f"WebSocket connection closed with exception {ws.exception()}")
log_info("WebSocket connection closed")
return ws
@PromptServer.instance.routes.get("/ycnode/get_canvas_data/{node_id}") @PromptServer.instance.routes.get("/ycnode/get_canvas_data/{node_id}")
async def get_canvas_data(request): async def get_canvas_data(request):
try: try:
@@ -811,3 +810,15 @@ def convert_tensor_to_base64(tensor, alpha_mask=None, original_alpha=None):
log_error(f"Error in convert_tensor_to_base64: {str(e)}") log_error(f"Error in convert_tensor_to_base64: {str(e)}")
log_debug(f"Tensor shape: {tensor.shape}, dtype: {tensor.dtype}") log_debug(f"Tensor shape: {tensor.shape}, dtype: {tensor.dtype}")
raise raise
# Setup original API routes when module is loaded
CanvasNode.setup_routes()
NODE_CLASS_MAPPINGS = {
"CanvasNode": CanvasNode
}
NODE_DISPLAY_NAME_MAPPINGS = {
"CanvasNode": "LayerForge"
}

View File

@@ -266,9 +266,6 @@ export class Canvas {
async saveToServer(fileName) {
return this.canvasIO.saveToServer(fileName);
}
async getFlattenedCanvasAsBlob() { async getFlattenedCanvasAsBlob() {
return this.canvasLayers.getFlattenedCanvasAsBlob(); return this.canvasLayers.getFlattenedCanvasAsBlob();

View File

@@ -1,5 +1,6 @@
import {createCanvas} from "./utils/CommonUtils.js"; import {createCanvas} from "./utils/CommonUtils.js";
import {createModuleLogger} from "./utils/LoggerUtils.js"; import {createModuleLogger} from "./utils/LoggerUtils.js";
import {webSocketManager} from "./utils/WebSocketManager.js";
const log = createModuleLogger('CanvasIO'); const log = createModuleLogger('CanvasIO');
@@ -9,7 +10,8 @@ export class CanvasIO {
this._saveInProgress = null; this._saveInProgress = null;
} }
async saveToServer(fileName) { async saveToServer(fileName, outputMode = 'disk') {
if (outputMode === 'disk') {
if (!window.canvasSaveStates) { if (!window.canvasSaveStates) {
window.canvasSaveStates = new Map(); window.canvasSaveStates = new Map();
} }
@@ -21,23 +23,25 @@ export class CanvasIO {
return this._saveInProgress || window.canvasSaveStates.get(saveKey); return this._saveInProgress || window.canvasSaveStates.get(saveKey);
} }
log.info(`Starting saveToServer with fileName: ${fileName} for node: ${nodeId}`); log.info(`Starting saveToServer (disk) with fileName: ${fileName} for node: ${nodeId}`);
log.debug(`Canvas dimensions: ${this.canvas.width}x${this.canvas.height}`); this._saveInProgress = this._performSave(fileName, outputMode);
log.debug(`Number of layers: ${this.canvas.layers.length}`);
this._saveInProgress = this._performSave(fileName);
window.canvasSaveStates.set(saveKey, this._saveInProgress); window.canvasSaveStates.set(saveKey, this._saveInProgress);
try { try {
const result = await this._saveInProgress; return await this._saveInProgress;
return result;
} finally { } finally {
this._saveInProgress = null; this._saveInProgress = null;
window.canvasSaveStates.delete(saveKey); window.canvasSaveStates.delete(saveKey);
log.debug(`Save completed for node ${nodeId}, lock released`); log.debug(`Save completed for node ${nodeId}, lock released`);
} }
} else {
// For RAM mode, we don't need the lock/state management as it's synchronous
log.info(`Starting saveToServer (RAM) for node: ${this.canvas.node.id}`);
return this._performSave(fileName, outputMode);
}
} }
async _performSave(fileName) { async _performSave(fileName, outputMode) {
if (this.canvas.layers.length === 0) { if (this.canvas.layers.length === 0) {
log.warn(`Node ${this.canvas.node.id} has no layers, creating empty canvas`); log.warn(`Node ${this.canvas.node.id} has no layers, creating empty canvas`);
return Promise.resolve(true); return Promise.resolve(true);
@@ -152,6 +156,15 @@ export class CanvasIO {
maskCtx.globalCompositeOperation = 'source-over'; maskCtx.globalCompositeOperation = 'source-over';
maskCtx.drawImage(tempMaskCanvas, 0, 0); maskCtx.drawImage(tempMaskCanvas, 0, 0);
} }
if (outputMode === 'ram') {
const imageData = tempCanvas.toDataURL('image/png');
const maskData = maskCanvas.toDataURL('image/png');
log.info("Returning image and mask data as base64 for RAM mode.");
resolve({ image: imageData, mask: maskData });
return;
}
// --- Disk Mode (original logic) ---
const fileNameWithoutMask = fileName.replace('.png', '_without_mask.png'); const fileNameWithoutMask = fileName.replace('.png', '_without_mask.png');
log.info(`Saving image without mask as: ${fileNameWithoutMask}`); log.info(`Saving image without mask as: ${fileNameWithoutMask}`);
@@ -204,7 +217,9 @@ export class CanvasIO {
if (maskResp.status === 200) { if (maskResp.status === 200) {
const data = await resp.json(); const data = await resp.json();
if (this.canvas.widget) {
this.canvas.widget.value = fileName; this.canvas.widget.value = fileName;
}
log.info(`All files saved successfully, widget value set to: ${fileName}`); log.info(`All files saved successfully, widget value set to: ${fileName}`);
resolve(true); resolve(true);
} else { } else {
@@ -228,6 +243,132 @@ export class CanvasIO {
}); });
} }
async _renderOutputData() {
return new Promise((resolve) => {
const { canvas: tempCanvas, ctx: tempCtx } = createCanvas(this.canvas.width, this.canvas.height);
const { canvas: maskCanvas, ctx: maskCtx } = createCanvas(this.canvas.width, this.canvas.height);
// This logic is mostly mirrored from _performSave to ensure consistency
tempCtx.fillStyle = '#ffffff';
tempCtx.fillRect(0, 0, this.canvas.width, this.canvas.height);
const visibilityCanvas = document.createElement('canvas');
visibilityCanvas.width = this.canvas.width;
visibilityCanvas.height = this.canvas.height;
const visibilityCtx = visibilityCanvas.getContext('2d', { alpha: true });
maskCtx.fillStyle = '#ffffff'; // Start with a white mask (nothing masked)
maskCtx.fillRect(0, 0, this.canvas.width, this.canvas.height);
const sortedLayers = this.canvas.layers.sort((a, b) => a.zIndex - b.zIndex);
sortedLayers.forEach((layer) => {
// Render layer to main canvas
tempCtx.save();
tempCtx.globalCompositeOperation = layer.blendMode || 'normal';
tempCtx.globalAlpha = layer.opacity !== undefined ? layer.opacity : 1;
tempCtx.translate(layer.x + layer.width / 2, layer.y + layer.height / 2);
tempCtx.rotate(layer.rotation * Math.PI / 180);
tempCtx.drawImage(layer.image, -layer.width / 2, -layer.height / 2, layer.width, layer.height);
tempCtx.restore();
// Render layer to visibility canvas for the mask
visibilityCtx.save();
visibilityCtx.translate(layer.x + layer.width / 2, layer.y + layer.height / 2);
visibilityCtx.rotate(layer.rotation * Math.PI / 180);
visibilityCtx.drawImage(layer.image, -layer.width / 2, -layer.height / 2, layer.width, layer.height);
visibilityCtx.restore();
});
// Create layer visibility mask
const visibilityData = visibilityCtx.getImageData(0, 0, this.canvas.width, this.canvas.height);
const maskData = maskCtx.getImageData(0, 0, this.canvas.width, this.canvas.height);
for (let i = 0; i < visibilityData.data.length; i += 4) {
const alpha = visibilityData.data[i + 3];
const maskValue = 255 - alpha; // Invert alpha to create the mask
maskData.data[i] = maskData.data[i + 1] = maskData.data[i + 2] = maskValue;
maskData.data[i + 3] = 255; // Solid mask
}
maskCtx.putImageData(maskData, 0, 0);
// Composite the tool mask on top
const toolMaskCanvas = this.canvas.maskTool.getMask();
if (toolMaskCanvas) {
// Create a temp canvas for processing the mask
const tempMaskCanvas = document.createElement('canvas');
tempMaskCanvas.width = this.canvas.width;
tempMaskCanvas.height = this.canvas.height;
const tempMaskCtx = tempMaskCanvas.getContext('2d');
// Clear the canvas
tempMaskCtx.clearRect(0, 0, tempMaskCanvas.width, tempMaskCanvas.height);
// Calculate the correct position to extract the mask
const maskX = this.canvas.maskTool.x;
const maskY = this.canvas.maskTool.y;
log.debug(`[renderOutputData] Extracting mask from world position (${maskX}, ${maskY})`);
const sourceX = Math.max(0, -maskX);
const sourceY = Math.max(0, -maskY);
const destX = Math.max(0, maskX);
const destY = Math.max(0, maskY);
const copyWidth = Math.min(toolMaskCanvas.width - sourceX, this.canvas.width - destX);
const copyHeight = Math.min(toolMaskCanvas.height - sourceY, this.canvas.height - destY);
if (copyWidth > 0 && copyHeight > 0) {
tempMaskCtx.drawImage(
toolMaskCanvas,
sourceX, sourceY, copyWidth, copyHeight,
destX, destY, copyWidth, copyHeight
);
}
// Convert the brush mask (white with alpha) to a solid white mask on black background.
const tempMaskData = tempMaskCtx.getImageData(0, 0, this.canvas.width, this.canvas.height);
for (let i = 0; i < tempMaskData.data.length; i += 4) {
const alpha = tempMaskData.data[i + 3];
// The painted area (alpha > 0) should become white (255).
tempMaskData.data[i] = tempMaskData.data[i+1] = tempMaskData.data[i+2] = alpha;
tempMaskData.data[i + 3] = 255; // Solid alpha
}
tempMaskCtx.putImageData(tempMaskData, 0, 0);
// Use 'screen' blending mode. This correctly adds the white brush mask
// to the existing layer visibility mask. (white + anything = white)
maskCtx.globalCompositeOperation = 'screen';
maskCtx.drawImage(tempMaskCanvas, 0, 0);
}
const imageDataUrl = tempCanvas.toDataURL('image/png');
const maskDataUrl = maskCanvas.toDataURL('image/png');
resolve({ image: imageDataUrl, mask: maskDataUrl });
});
}
async sendDataViaWebSocket(nodeId) {
log.info(`Preparing to send data for node ${nodeId} via WebSocket.`);
const { image, mask } = await this._renderOutputData();
try {
log.info(`Sending data for node ${nodeId}...`);
await webSocketManager.sendMessage({
type: 'canvas_data',
nodeId: String(nodeId),
image: image,
mask: mask,
}, true); // `true` requires an acknowledgment
log.info(`Data for node ${nodeId} has been sent and acknowledged by the server.`);
return true;
} catch (error) {
log.error(`Failed to send data for node ${nodeId}:`, error);
// We can alert the user here or handle it silently.
// For now, let's throw to make it clear the process failed.
throw new Error(`Failed to get confirmation from server for node ${nodeId}. The workflow might not have the latest canvas data.`);
}
}
async addInputToCanvas(inputImage, inputMask) { async addInputToCanvas(inputImage, inputMask) {
try { try {
log.debug("Adding input to canvas:", {inputImage}); log.debug("Adding input to canvas:", {inputImage});

View File

@@ -377,8 +377,7 @@ async function createCanvasWidget(node, widget, app) {
const img = new Image(); const img = new Image();
img.onload = async () => { img.onload = async () => {
canvas.addLayer(img); canvas.addLayer(img);
await saveWithFallback(widget.value); await updateOutput();
app.graph.runStep();
}; };
img.src = event.target.result; img.src = event.target.result;
}; };
@@ -392,8 +391,7 @@ async function createCanvasWidget(node, widget, app) {
textContent: "Import Input", textContent: "Import Input",
onclick: async () => { onclick: async () => {
if (await canvas.importLatestImage()) { if (await canvas.importLatestImage()) {
await saveWithFallback(widget.value); await updateOutput();
app.graph.runStep();
} }
} }
}), }),
@@ -574,8 +572,7 @@ async function createCanvasWidget(node, widget, app) {
canvas.updateSelection([newLayer]); canvas.updateSelection([newLayer]);
canvas.render(); canvas.render();
canvas.saveState(); canvas.saveState();
await saveWithFallback(widget.value); await updateOutput();
app.graph.runStep();
} catch (error) { } catch (error) {
log.error("Matting error:", error); log.error("Matting error:", error);
alert(`Error during matting process: ${error.message}`); alert(`Error during matting process: ${error.message}`);
@@ -745,7 +742,8 @@ async function createCanvasWidget(node, widget, app) {
const triggerWidget = node.widgets.find(w => w.name === "trigger"); const triggerWidget = node.widgets.find(w => w.name === "trigger");
const updateOutput = async () => { const updateOutput = async () => {
await saveWithFallback(widget.value); // Only increment trigger and run step - don't save to disk here
// Saving to disk will happen during execution_start event
triggerWidget.value = (triggerWidget.value + 1) % 99999999; triggerWidget.value = (triggerWidget.value + 1) % 99999999;
app.graph.runStep(); app.graph.runStep();
}; };
@@ -790,8 +788,9 @@ async function createCanvasWidget(node, widget, app) {
canvas.render(); canvas.render();
}; };
canvas.canvas.addEventListener('mouseup', updateOutput); // Remove automatic saving on mouse events - only save during execution
canvas.canvas.addEventListener('mouseleave', updateOutput); // canvas.canvas.addEventListener('mouseup', updateOutput);
// canvas.canvas.addEventListener('mouseleave', updateOutput);
const mainContainer = $el("div.painterMainContainer", { const mainContainer = $el("div.painterMainContainer", {
@@ -922,66 +921,8 @@ async function createCanvasWidget(node, widget, app) {
if (!window.canvasExecutionStates) { if (!window.canvasExecutionStates) {
window.canvasExecutionStates = new Map(); window.canvasExecutionStates = new Map();
} }
const saveWithFallback = async (fileName) => {
try {
const uniqueFileName = generateUniqueFileName(fileName, node.id);
log.debug(`Attempting to save with unique name: ${uniqueFileName}`);
return await canvas.saveToServer(uniqueFileName);
} catch (error) {
log.warn(`Failed to save with unique name, falling back to original: ${fileName}`, error);
return await canvas.saveToServer(fileName);
}
};
api.addEventListener("execution_start", async (event) => {
const executionData = event.detail || {};
const currentPromptId = executionData.prompt_id;
log.info(`Execution start event for node ${node.id}, prompt_id: ${currentPromptId}`);
log.debug(`Widget value: ${widget.value}`);
log.debug(`Node inputs: ${node.inputs?.length || 0}`);
log.debug(`Canvas layers count: ${canvas.layers.length}`);
if (window.canvasExecutionStates.get(node.id)) {
log.warn(`Execution already in progress for node ${node.id}, skipping...`);
return;
}
window.canvasExecutionStates.set(node.id, true);
try {
if (canvas.layers.length === 0) {
log.warn(`Node ${node.id} has no layers, skipping save to server`);
} else {
await saveWithFallback(widget.value);
log.info(`Canvas saved to server for node ${node.id}`);
}
if (node.inputs[0]?.link) {
const linkId = node.inputs[0].link;
const inputData = app.nodeOutputs[linkId];
log.debug(`Input link ${linkId} has data: ${!!inputData}`);
if (inputData) {
imageCache.set(linkId, inputData);
log.debug(`Input data cached for link ${linkId}`);
}
} else {
log.debug(`No input link found`);
}
} catch (error) {
log.error(`Error during execution for node ${node.id}:`, error);
} finally {
window.canvasExecutionStates.set(node.id, false);
log.debug(`Execution completed for node ${node.id}, flag released`);
}
});
const originalSaveToServer = canvas.saveToServer;
canvas.saveToServer = async function (fileName) {
log.debug(`saveToServer called with fileName: ${fileName}`);
log.debug(`Current execution context - node ID: ${node.id}`);
const result = await originalSaveToServer.call(this, fileName);
log.debug(`saveToServer completed, result: ${result}`);
return result;
};
node.canvasWidget = canvas; node.canvasWidget = canvas;
@@ -996,30 +937,111 @@ async function createCanvasWidget(node, widget, app) {
} }
const canvasNodeInstances = new Map();
app.registerExtension({ app.registerExtension({
name: "Comfy.CanvasNode", name: "Comfy.CanvasNode",
init() {
// Monkey-patch the queuePrompt function to send canvas data via WebSocket before sending the prompt
const originalQueuePrompt = app.queuePrompt;
app.queuePrompt = async function(number, prompt) {
log.info("Preparing to queue prompt...");
if (canvasNodeInstances.size > 0) {
log.info(`Found ${canvasNodeInstances.size} CanvasNode(s). Sending data via WebSocket...`);
const sendPromises = [];
for (const [nodeId, canvasWidget] of canvasNodeInstances.entries()) {
// Ensure the node still exists on the graph before sending data
if (app.graph.getNodeById(nodeId) && canvasWidget.canvas && canvasWidget.canvas.canvasIO) {
log.debug(`Sending data for canvas node ${nodeId}`);
// This now returns a promise that resolves upon server ACK
sendPromises.push(canvasWidget.canvas.canvasIO.sendDataViaWebSocket(nodeId));
} else {
// If node doesn't exist, it might have been deleted, so we can clean up the map
log.warn(`Node ${nodeId} not found in graph, removing from instances map.`);
canvasNodeInstances.delete(nodeId);
}
}
try {
// Wait for all WebSocket messages to be acknowledged
await Promise.all(sendPromises);
log.info("All canvas data has been sent and acknowledged by the server.");
} catch (error) {
log.error("Failed to send canvas data for one or more nodes. Aborting prompt.", error);
// IMPORTANT: Stop the prompt from queueing if data transfer fails.
// You might want to show a user-facing error here.
alert(`CanvasNode Error: ${error.message}`);
return; // Stop execution
}
}
log.info("All pre-prompt tasks complete. Proceeding with original queuePrompt.");
// Proceed with the original queuePrompt logic
return originalQueuePrompt.apply(this, arguments);
};
},
async beforeRegisterNodeDef(nodeType, nodeData, app) { async beforeRegisterNodeDef(nodeType, nodeData, app) {
if (nodeType.comfyClass === "CanvasNode") { if (nodeType.comfyClass === "CanvasNode") {
const onNodeCreated = nodeType.prototype.onNodeCreated; const onNodeCreated = nodeType.prototype.onNodeCreated;
nodeType.prototype.onNodeCreated = async function () { nodeType.prototype.onNodeCreated = function () {
log.info("CanvasNode created, ID:", this.id); log.debug("CanvasNode onNodeCreated: Base widget setup.");
// Call original onNodeCreated to ensure widgets are created
const r = onNodeCreated?.apply(this, arguments); const r = onNodeCreated?.apply(this, arguments);
// The main initialization is moved to onAdded
const widget = this.widgets.find(w => w.name === "canvas_image");
log.debug("Found canvas_image widget:", widget);
await createCanvasWidget(this, widget, app);
return r; return r;
}; };
// onAdded is the most reliable callback for when a node is fully added to the graph and has an ID
nodeType.prototype.onAdded = async function() {
log.info(`CanvasNode onAdded, ID: ${this.id}`);
log.debug(`Available widgets in onAdded:`, this.widgets.map(w => w.name));
// Prevent re-initialization if the widget already exists
if (this.canvasWidget) {
log.warn(`CanvasNode ${this.id} already initialized. Skipping onAdded setup.`);
return;
}
// Now that we are in onAdded, this.id is guaranteed to be correct.
// Set the hidden node_id widget's value for backend communication.
const nodeIdWidget = this.widgets.find(w => w.name === "node_id");
if (nodeIdWidget) {
nodeIdWidget.value = String(this.id);
log.debug(`Set hidden node_id widget to: ${nodeIdWidget.value}`);
} else {
log.error("Could not find the hidden node_id widget!");
}
// Create the main canvas widget and register it in our global map
// We pass `null` for the widget parameter as we are not using a pre-defined widget.
const canvasWidget = await createCanvasWidget(this, null, app);
canvasNodeInstances.set(this.id, canvasWidget);
log.info(`Registered CanvasNode instance for ID: ${this.id}`);
};
const onRemoved = nodeType.prototype.onRemoved; const onRemoved = nodeType.prototype.onRemoved;
nodeType.prototype.onRemoved = function () { nodeType.prototype.onRemoved = function () {
log.info(`Cleaning up canvas node ${this.id}`);
// Clean up from our instance map
canvasNodeInstances.delete(this.id);
log.info(`Deregistered CanvasNode instance for ID: ${this.id}`);
// Clean up execution state
if (window.canvasExecutionStates) {
window.canvasExecutionStates.delete(this.id);
}
const tooltip = document.getElementById(`painter-help-tooltip-${this.id}`); const tooltip = document.getElementById(`painter-help-tooltip-${this.id}`);
if (tooltip) { if (tooltip) {
tooltip.remove(); tooltip.remove();
} }
const backdrop = document.querySelector('.painter-modal-backdrop'); const backdrop = document.querySelector('.painter-modal-backdrop');
if (backdrop && backdrop.contains(this.canvasWidget.canvas)) { if (backdrop && backdrop.contains(this.canvasWidget?.canvas)) {
document.body.removeChild(backdrop); document.body.removeChild(backdrop);
} }

View File

@@ -125,12 +125,29 @@ export function cloneLayers(layers) {
* @returns {string} Sygnatura JSON * @returns {string} Sygnatura JSON
*/ */
export function getStateSignature(layers) { export function getStateSignature(layers) {
return JSON.stringify(layers.map(layer => { return JSON.stringify(layers.map((layer, index) => {
const sig = {...layer}; const sig = {
if (sig.imageId) { index: index,
sig.imageId = sig.imageId; x: Math.round(layer.x * 100) / 100, // Round to avoid floating point precision issues
y: Math.round(layer.y * 100) / 100,
width: Math.round(layer.width * 100) / 100,
height: Math.round(layer.height * 100) / 100,
rotation: Math.round((layer.rotation || 0) * 100) / 100,
zIndex: layer.zIndex,
blendMode: layer.blendMode || 'normal',
opacity: layer.opacity !== undefined ? Math.round(layer.opacity * 100) / 100 : 1
};
// Include imageId if available
if (layer.imageId) {
sig.imageId = layer.imageId;
} }
delete sig.image;
// Include image src as fallback identifier
if (layer.image && layer.image.src) {
sig.imageSrc = layer.image.src.substring(0, 100); // First 100 chars to avoid huge signatures
}
return sig; return sig;
})); }));
} }

View File

@@ -0,0 +1,160 @@
import {createModuleLogger} from "./LoggerUtils.js";
const log = createModuleLogger('WebSocketManager');
class WebSocketManager {
constructor(url) {
this.url = url;
this.socket = null;
this.messageQueue = [];
this.isConnecting = false;
this.reconnectAttempts = 0;
this.maxReconnectAttempts = 10;
this.reconnectInterval = 5000; // 5 seconds
this.ackCallbacks = new Map(); // Store callbacks for messages awaiting ACK
this.messageIdCounter = 0;
this.connect();
}
connect() {
if (this.socket && this.socket.readyState === WebSocket.OPEN) {
log.debug("WebSocket is already open.");
return;
}
if (this.isConnecting) {
log.debug("Connection attempt already in progress.");
return;
}
this.isConnecting = true;
log.info(`Connecting to WebSocket at ${this.url}...`);
try {
this.socket = new WebSocket(this.url);
this.socket.onopen = () => {
this.isConnecting = false;
this.reconnectAttempts = 0;
log.info("WebSocket connection established.");
this.flushMessageQueue();
};
this.socket.onmessage = (event) => {
try {
const data = JSON.parse(event.data);
log.debug("Received message:", data);
if (data.type === 'ack' && data.nodeId) {
const callback = this.ackCallbacks.get(data.nodeId);
if (callback) {
log.debug(`ACK received for nodeId: ${data.nodeId}, resolving promise.`);
callback.resolve(data);
this.ackCallbacks.delete(data.nodeId);
}
}
// Handle other incoming messages if needed
} catch (error) {
log.error("Error parsing incoming WebSocket message:", error);
}
};
this.socket.onclose = (event) => {
this.isConnecting = false;
if (event.wasClean) {
log.info(`WebSocket closed cleanly, code=${event.code}, reason=${event.reason}`);
} else {
log.warn("WebSocket connection died. Attempting to reconnect...");
this.handleReconnect();
}
};
this.socket.onerror = (error) => {
this.isConnecting = false;
log.error("WebSocket error:", error);
// The onclose event will be fired next, which will handle reconnection.
};
} catch (error) {
this.isConnecting = false;
log.error("Failed to create WebSocket connection:", error);
this.handleReconnect();
}
}
handleReconnect() {
if (this.reconnectAttempts < this.maxReconnectAttempts) {
this.reconnectAttempts++;
log.info(`Reconnect attempt ${this.reconnectAttempts}/${this.maxReconnectAttempts}...`);
setTimeout(() => this.connect(), this.reconnectInterval);
} else {
log.error("Max reconnect attempts reached. Giving up.");
}
}
sendMessage(data, requiresAck = false) {
return new Promise((resolve, reject) => {
const nodeId = data.nodeId;
if (requiresAck && !nodeId) {
return reject(new Error("A nodeId is required for messages that need acknowledgment."));
}
const message = JSON.stringify(data);
if (this.socket && this.socket.readyState === WebSocket.OPEN) {
this.socket.send(message);
log.debug("Sent message:", data);
if (requiresAck) {
log.debug(`Message for nodeId ${nodeId} requires ACK. Setting up callback.`);
// Set a timeout for the ACK
const timeout = setTimeout(() => {
this.ackCallbacks.delete(nodeId);
reject(new Error(`ACK timeout for nodeId ${nodeId}`));
log.warn(`ACK timeout for nodeId ${nodeId}.`);
}, 10000); // 10-second timeout
this.ackCallbacks.set(nodeId, {
resolve: (responseData) => {
clearTimeout(timeout);
resolve(responseData);
},
reject: (error) => {
clearTimeout(timeout);
reject(error);
}
});
} else {
resolve(); // Resolve immediately if no ACK is needed
}
} else {
log.warn("WebSocket not open. Queuing message.");
// Note: The current queueing doesn't support ACK promises well.
// For simplicity, we'll focus on the connected case.
// A more robust implementation would wrap the queued message in a function.
this.messageQueue.push(message);
if (!this.isConnecting) {
this.connect();
}
// For now, we reject if not connected and ACK is required.
if (requiresAck) {
reject(new Error("Cannot send message with ACK required while disconnected."));
}
}
});
}
flushMessageQueue() {
log.debug(`Flushing ${this.messageQueue.length} queued messages.`);
// Note: This simple flush doesn't handle ACKs for queued messages.
// This should be acceptable as data is sent right before queueing a prompt,
// at which point the socket should ideally be connected.
while (this.messageQueue.length > 0) {
const message = this.messageQueue.shift();
this.socket.send(message);
}
}
}
// Create a singleton instance of the WebSocketManager
const wsUrl = `ws://${window.location.host}/layerforge/canvas_ws`;
export const webSocketManager = new WebSocketManager(wsUrl);