From 06d94f6a636491f9b7a3ac15c698b02ec749ef25 Mon Sep 17 00:00:00 2001 From: Dariusz L Date: Sat, 9 Aug 2025 02:33:28 +0200 Subject: [PATCH] Improve mask loading logic on node connection Updated mask loading to immediately use available data from connected nodes and preserve existing masks if none is provided. Backend mask data is only fetched after workflow execution, ensuring no stale data is loaded during connection. --- canvas_node.py | 132 +++++++++++++++++++++--------------------- js/CanvasIO.js | 9 +-- js/CanvasRenderer.js | 4 +- js/CanvasView.js | 6 +- src/CanvasIO.ts | 10 ++-- src/CanvasRenderer.ts | 4 +- src/CanvasView.ts | 8 ++- 7 files changed, 92 insertions(+), 81 deletions(-) diff --git a/canvas_node.py b/canvas_node.py index c51a2fa..eb83859 100644 --- a/canvas_node.py +++ b/canvas_node.py @@ -254,82 +254,80 @@ class LayerForgeNode: log_info(f"Lock acquired. Starting process_canvas_image for node_id: {node_id} (fallback unique_id: {unique_id})") - # Handle input image and mask if provided - if input_image is not None or input_mask is not None: - log_info(f"Input data detected for node {node_id} - Image: {input_image is not None}, Mask: {input_mask is not None}") + # Always store fresh input data, even if None, to clear stale data + log_info(f"Storing input data for node {node_id} - Image: {input_image is not None}, Mask: {input_mask is not None}") + + with self.__class__._storage_lock: + input_data = {} - # Store input data for frontend to retrieve - with self.__class__._storage_lock: - input_data = {} - - if input_image is not None: - # Convert image tensor(s) to base64 - handle batch - if isinstance(input_image, torch.Tensor): - # Ensure correct shape [B, H, W, C] - if input_image.dim() == 3: - input_image = input_image.unsqueeze(0) + if input_image is not None: + # Convert image tensor(s) to base64 - handle batch + if isinstance(input_image, torch.Tensor): + # Ensure correct shape [B, H, W, C] + if input_image.dim() == 3: + input_image = input_image.unsqueeze(0) + + batch_size = input_image.shape[0] + log_info(f"Processing batch of {batch_size} image(s)") + + if batch_size == 1: + # Single image - keep backward compatibility + img_np = (input_image.squeeze(0).cpu().numpy() * 255).astype(np.uint8) + pil_img = Image.fromarray(img_np, 'RGB') - batch_size = input_image.shape[0] - log_info(f"Processing batch of {batch_size} image(s)") - - if batch_size == 1: - # Single image - keep backward compatibility - img_np = (input_image.squeeze(0).cpu().numpy() * 255).astype(np.uint8) + # Convert to base64 + buffered = io.BytesIO() + pil_img.save(buffered, format="PNG") + img_str = base64.b64encode(buffered.getvalue()).decode() + input_data['input_image'] = f"data:image/png;base64,{img_str}" + input_data['input_image_width'] = pil_img.width + input_data['input_image_height'] = pil_img.height + log_debug(f"Stored single input image: {pil_img.width}x{pil_img.height}") + else: + # Multiple images - store as array + images_array = [] + for i in range(batch_size): + img_np = (input_image[i].cpu().numpy() * 255).astype(np.uint8) pil_img = Image.fromarray(img_np, 'RGB') # Convert to base64 buffered = io.BytesIO() pil_img.save(buffered, format="PNG") img_str = base64.b64encode(buffered.getvalue()).decode() - input_data['input_image'] = f"data:image/png;base64,{img_str}" - input_data['input_image_width'] = pil_img.width - input_data['input_image_height'] = pil_img.height - log_debug(f"Stored single input image: {pil_img.width}x{pil_img.height}") - else: - # Multiple images - store as array - images_array = [] - for i in range(batch_size): - img_np = (input_image[i].cpu().numpy() * 255).astype(np.uint8) - pil_img = Image.fromarray(img_np, 'RGB') - - # Convert to base64 - buffered = io.BytesIO() - pil_img.save(buffered, format="PNG") - img_str = base64.b64encode(buffered.getvalue()).decode() - images_array.append({ - 'data': f"data:image/png;base64,{img_str}", - 'width': pil_img.width, - 'height': pil_img.height - }) - log_debug(f"Stored batch image {i+1}/{batch_size}: {pil_img.width}x{pil_img.height}") - - input_data['input_images_batch'] = images_array - log_info(f"Stored batch of {batch_size} images") - - if input_mask is not None: - # Convert mask tensor to base64 - if isinstance(input_mask, torch.Tensor): - # Ensure correct shape - if input_mask.dim() == 2: - input_mask = input_mask.unsqueeze(0) - if input_mask.dim() == 3 and input_mask.shape[0] == 1: - input_mask = input_mask.squeeze(0) + images_array.append({ + 'data': f"data:image/png;base64,{img_str}", + 'width': pil_img.width, + 'height': pil_img.height + }) + log_debug(f"Stored batch image {i+1}/{batch_size}: {pil_img.width}x{pil_img.height}") - # Convert to numpy and then to PIL - mask_np = (input_mask.cpu().numpy() * 255).astype(np.uint8) - pil_mask = Image.fromarray(mask_np, 'L') - - # Convert to base64 - mask_buffered = io.BytesIO() - pil_mask.save(mask_buffered, format="PNG") - mask_str = base64.b64encode(mask_buffered.getvalue()).decode() - input_data['input_mask'] = f"data:image/png;base64,{mask_str}" - log_debug(f"Stored input mask: {pil_mask.width}x{pil_mask.height}") - - input_data['fit_on_add'] = fit_on_add - - # Store in a special key for input data - self.__class__._canvas_data_storage[f"{node_id}_input"] = input_data + input_data['input_images_batch'] = images_array + log_info(f"Stored batch of {batch_size} images") + + if input_mask is not None: + # Convert mask tensor to base64 + if isinstance(input_mask, torch.Tensor): + # Ensure correct shape + if input_mask.dim() == 2: + input_mask = input_mask.unsqueeze(0) + if input_mask.dim() == 3 and input_mask.shape[0] == 1: + input_mask = input_mask.squeeze(0) + + # Convert to numpy and then to PIL + mask_np = (input_mask.cpu().numpy() * 255).astype(np.uint8) + pil_mask = Image.fromarray(mask_np, 'L') + + # Convert to base64 + mask_buffered = io.BytesIO() + pil_mask.save(mask_buffered, format="PNG") + mask_str = base64.b64encode(mask_buffered.getvalue()).decode() + input_data['input_mask'] = f"data:image/png;base64,{mask_str}" + log_debug(f"Stored input mask: {pil_mask.width}x{pil_mask.height}") + + input_data['fit_on_add'] = fit_on_add + + # Store in a special key for input data (overwrites any previous data) + self.__class__._canvas_data_storage[f"{node_id}_input"] = input_data storage_key = node_id diff --git a/js/CanvasIO.js b/js/CanvasIO.js index 2918d23..011e7cb 100644 --- a/js/CanvasIO.js +++ b/js/CanvasIO.js @@ -620,10 +620,6 @@ export class CanvasIO { const hasMaskInput = this.canvas.node.inputs && this.canvas.node.inputs[1] && this.canvas.node.inputs[1].link; // If mask input is disconnected, clear any currently applied mask to ensure full separation if (!hasMaskInput) { - if (this.canvas.maskTool) { - this.canvas.maskTool.clear(); - this.canvas.render(); - } this.canvas.maskAppliedFromInput = false; this.canvas.lastLoadedMaskLinkId = undefined; log.info("Mask input disconnected - cleared mask to enforce separation from input_image"); @@ -633,6 +629,11 @@ export class CanvasIO { this.canvas.inputDataLoaded = true; return; } + // Skip backend check during mask connection if we didn't get immediate data + if (reason === "mask_connect" && !maskLoaded) { + log.info("No immediate mask data available during connection, skipping backend check to avoid stale data. Will check after execution."); + return; + } // Check backend for input data only if we have connected inputs const response = await fetch(`/layerforge/get_input_data/${nodeId}`); const result = await response.json(); diff --git a/js/CanvasRenderer.js b/js/CanvasRenderer.js index 38938e4..1bebc6f 100644 --- a/js/CanvasRenderer.js +++ b/js/CanvasRenderer.js @@ -652,8 +652,8 @@ export class CanvasRenderer { this.updateStrokeOverlaySize(); // Position above main canvas but below cursor overlay this.strokeOverlayCanvas.style.position = 'absolute'; - this.strokeOverlayCanvas.style.left = '1px'; - this.strokeOverlayCanvas.style.top = '1px'; + this.strokeOverlayCanvas.style.left = '0px'; + this.strokeOverlayCanvas.style.top = '0px'; this.strokeOverlayCanvas.style.pointerEvents = 'none'; this.strokeOverlayCanvas.style.zIndex = '19'; // Below cursor overlay (20) // Opacity is now controlled by MaskTool.previewOpacity diff --git a/js/CanvasView.js b/js/CanvasView.js index cac9f30..947cf2b 100644 --- a/js/CanvasView.js +++ b/js/CanvasView.js @@ -1200,12 +1200,16 @@ app.registerExtension({ if (index === 1) { if (connected && link_info) { log.info("Input mask connected"); + // DON'T clear existing mask when connecting a new input + // Reset the loaded mask link ID to allow loading from the new connection + canvas.lastLoadedMaskLinkId = undefined; // Mark that we have a pending mask connection canvas.hasPendingMaskConnection = true; // Check for data immediately when connected setTimeout(() => { log.info("Checking for input data after mask connection..."); - // Only load mask here; images are handled by image connect or execution + // Only load mask here if it's immediately available from the connected node + // Don't load stale masks from backend storage canvas.canvasIO.checkForInputData({ allowImage: false, allowMask: true, reason: "mask_connect" }); }, 500); } diff --git a/src/CanvasIO.ts b/src/CanvasIO.ts index dd1365f..afb9dee 100644 --- a/src/CanvasIO.ts +++ b/src/CanvasIO.ts @@ -708,10 +708,6 @@ export class CanvasIO { // If mask input is disconnected, clear any currently applied mask to ensure full separation if (!hasMaskInput) { - if (this.canvas.maskTool) { - this.canvas.maskTool.clear(); - this.canvas.render(); - } (this.canvas as any).maskAppliedFromInput = false; this.canvas.lastLoadedMaskLinkId = undefined; log.info("Mask input disconnected - cleared mask to enforce separation from input_image"); @@ -723,6 +719,12 @@ export class CanvasIO { return; } + // Skip backend check during mask connection if we didn't get immediate data + if (reason === "mask_connect" && !maskLoaded) { + log.info("No immediate mask data available during connection, skipping backend check to avoid stale data. Will check after execution."); + return; + } + // Check backend for input data only if we have connected inputs const response = await fetch(`/layerforge/get_input_data/${nodeId}`); const result = await response.json(); diff --git a/src/CanvasRenderer.ts b/src/CanvasRenderer.ts index 0c6a368..59b2af2 100644 --- a/src/CanvasRenderer.ts +++ b/src/CanvasRenderer.ts @@ -796,8 +796,8 @@ export class CanvasRenderer { // Position above main canvas but below cursor overlay this.strokeOverlayCanvas.style.position = 'absolute'; - this.strokeOverlayCanvas.style.left = '1px'; - this.strokeOverlayCanvas.style.top = '1px'; + this.strokeOverlayCanvas.style.left = '0px'; + this.strokeOverlayCanvas.style.top = '0px'; this.strokeOverlayCanvas.style.pointerEvents = 'none'; this.strokeOverlayCanvas.style.zIndex = '19'; // Below cursor overlay (20) // Opacity is now controlled by MaskTool.previewOpacity diff --git a/src/CanvasView.ts b/src/CanvasView.ts index 67fb020..7c4a665 100644 --- a/src/CanvasView.ts +++ b/src/CanvasView.ts @@ -1368,12 +1368,18 @@ app.registerExtension({ if (index === 1) { if (connected && link_info) { log.info("Input mask connected"); + + // DON'T clear existing mask when connecting a new input + // Reset the loaded mask link ID to allow loading from the new connection + canvas.lastLoadedMaskLinkId = undefined; + // Mark that we have a pending mask connection canvas.hasPendingMaskConnection = true; // Check for data immediately when connected setTimeout(() => { log.info("Checking for input data after mask connection..."); - // Only load mask here; images are handled by image connect or execution + // Only load mask here if it's immediately available from the connected node + // Don't load stale masks from backend storage canvas.canvasIO.checkForInputData({ allowImage: false, allowMask: true, reason: "mask_connect" }); }, 500); } else {