Improve mask loading logic on node connection

Updated mask loading to immediately use available data from connected nodes and to preserve the existing mask when no new mask is provided. Backend mask data is now fetched only after workflow execution, ensuring no stale data is loaded at connection time.
This commit is contained in:
Dariusz L
2025-08-09 02:33:28 +02:00
parent b21d6e3502
commit 06d94f6a63
7 changed files with 92 additions and 81 deletions

View File

@@ -254,82 +254,80 @@ class LayerForgeNode:
log_info(f"Lock acquired. Starting process_canvas_image for node_id: {node_id} (fallback unique_id: {unique_id})") log_info(f"Lock acquired. Starting process_canvas_image for node_id: {node_id} (fallback unique_id: {unique_id})")
# Handle input image and mask if provided # Always store fresh input data, even if None, to clear stale data
if input_image is not None or input_mask is not None: log_info(f"Storing input data for node {node_id} - Image: {input_image is not None}, Mask: {input_mask is not None}")
log_info(f"Input data detected for node {node_id} - Image: {input_image is not None}, Mask: {input_mask is not None}")
# Store input data for frontend to retrieve with self.__class__._storage_lock:
with self.__class__._storage_lock: input_data = {}
input_data = {}
if input_image is not None: if input_image is not None:
# Convert image tensor(s) to base64 - handle batch # Convert image tensor(s) to base64 - handle batch
if isinstance(input_image, torch.Tensor): if isinstance(input_image, torch.Tensor):
# Ensure correct shape [B, H, W, C] # Ensure correct shape [B, H, W, C]
if input_image.dim() == 3: if input_image.dim() == 3:
input_image = input_image.unsqueeze(0) input_image = input_image.unsqueeze(0)
batch_size = input_image.shape[0] batch_size = input_image.shape[0]
log_info(f"Processing batch of {batch_size} image(s)") log_info(f"Processing batch of {batch_size} image(s)")
if batch_size == 1: if batch_size == 1:
# Single image - keep backward compatibility # Single image - keep backward compatibility
img_np = (input_image.squeeze(0).cpu().numpy() * 255).astype(np.uint8) img_np = (input_image.squeeze(0).cpu().numpy() * 255).astype(np.uint8)
pil_img = Image.fromarray(img_np, 'RGB')
# Convert to base64
buffered = io.BytesIO()
pil_img.save(buffered, format="PNG")
img_str = base64.b64encode(buffered.getvalue()).decode()
input_data['input_image'] = f"data:image/png;base64,{img_str}"
input_data['input_image_width'] = pil_img.width
input_data['input_image_height'] = pil_img.height
log_debug(f"Stored single input image: {pil_img.width}x{pil_img.height}")
else:
# Multiple images - store as array
images_array = []
for i in range(batch_size):
img_np = (input_image[i].cpu().numpy() * 255).astype(np.uint8)
pil_img = Image.fromarray(img_np, 'RGB') pil_img = Image.fromarray(img_np, 'RGB')
# Convert to base64 # Convert to base64
buffered = io.BytesIO() buffered = io.BytesIO()
pil_img.save(buffered, format="PNG") pil_img.save(buffered, format="PNG")
img_str = base64.b64encode(buffered.getvalue()).decode() img_str = base64.b64encode(buffered.getvalue()).decode()
input_data['input_image'] = f"data:image/png;base64,{img_str}" images_array.append({
input_data['input_image_width'] = pil_img.width 'data': f"data:image/png;base64,{img_str}",
input_data['input_image_height'] = pil_img.height 'width': pil_img.width,
log_debug(f"Stored single input image: {pil_img.width}x{pil_img.height}") 'height': pil_img.height
else: })
# Multiple images - store as array log_debug(f"Stored batch image {i+1}/{batch_size}: {pil_img.width}x{pil_img.height}")
images_array = []
for i in range(batch_size):
img_np = (input_image[i].cpu().numpy() * 255).astype(np.uint8)
pil_img = Image.fromarray(img_np, 'RGB')
# Convert to base64 input_data['input_images_batch'] = images_array
buffered = io.BytesIO() log_info(f"Stored batch of {batch_size} images")
pil_img.save(buffered, format="PNG")
img_str = base64.b64encode(buffered.getvalue()).decode()
images_array.append({
'data': f"data:image/png;base64,{img_str}",
'width': pil_img.width,
'height': pil_img.height
})
log_debug(f"Stored batch image {i+1}/{batch_size}: {pil_img.width}x{pil_img.height}")
input_data['input_images_batch'] = images_array if input_mask is not None:
log_info(f"Stored batch of {batch_size} images") # Convert mask tensor to base64
if isinstance(input_mask, torch.Tensor):
# Ensure correct shape
if input_mask.dim() == 2:
input_mask = input_mask.unsqueeze(0)
if input_mask.dim() == 3 and input_mask.shape[0] == 1:
input_mask = input_mask.squeeze(0)
if input_mask is not None: # Convert to numpy and then to PIL
# Convert mask tensor to base64 mask_np = (input_mask.cpu().numpy() * 255).astype(np.uint8)
if isinstance(input_mask, torch.Tensor): pil_mask = Image.fromarray(mask_np, 'L')
# Ensure correct shape
if input_mask.dim() == 2:
input_mask = input_mask.unsqueeze(0)
if input_mask.dim() == 3 and input_mask.shape[0] == 1:
input_mask = input_mask.squeeze(0)
# Convert to numpy and then to PIL # Convert to base64
mask_np = (input_mask.cpu().numpy() * 255).astype(np.uint8) mask_buffered = io.BytesIO()
pil_mask = Image.fromarray(mask_np, 'L') pil_mask.save(mask_buffered, format="PNG")
mask_str = base64.b64encode(mask_buffered.getvalue()).decode()
input_data['input_mask'] = f"data:image/png;base64,{mask_str}"
log_debug(f"Stored input mask: {pil_mask.width}x{pil_mask.height}")
# Convert to base64 input_data['fit_on_add'] = fit_on_add
mask_buffered = io.BytesIO()
pil_mask.save(mask_buffered, format="PNG")
mask_str = base64.b64encode(mask_buffered.getvalue()).decode()
input_data['input_mask'] = f"data:image/png;base64,{mask_str}"
log_debug(f"Stored input mask: {pil_mask.width}x{pil_mask.height}")
input_data['fit_on_add'] = fit_on_add # Store in a special key for input data (overwrites any previous data)
self.__class__._canvas_data_storage[f"{node_id}_input"] = input_data
# Store in a special key for input data
self.__class__._canvas_data_storage[f"{node_id}_input"] = input_data
storage_key = node_id storage_key = node_id

View File

@@ -620,10 +620,6 @@ export class CanvasIO {
const hasMaskInput = this.canvas.node.inputs && this.canvas.node.inputs[1] && this.canvas.node.inputs[1].link; const hasMaskInput = this.canvas.node.inputs && this.canvas.node.inputs[1] && this.canvas.node.inputs[1].link;
// If mask input is disconnected, clear any currently applied mask to ensure full separation // If mask input is disconnected, clear any currently applied mask to ensure full separation
if (!hasMaskInput) { if (!hasMaskInput) {
if (this.canvas.maskTool) {
this.canvas.maskTool.clear();
this.canvas.render();
}
this.canvas.maskAppliedFromInput = false; this.canvas.maskAppliedFromInput = false;
this.canvas.lastLoadedMaskLinkId = undefined; this.canvas.lastLoadedMaskLinkId = undefined;
log.info("Mask input disconnected - cleared mask to enforce separation from input_image"); log.info("Mask input disconnected - cleared mask to enforce separation from input_image");
@@ -633,6 +629,11 @@ export class CanvasIO {
this.canvas.inputDataLoaded = true; this.canvas.inputDataLoaded = true;
return; return;
} }
// Skip backend check during mask connection if we didn't get immediate data
if (reason === "mask_connect" && !maskLoaded) {
log.info("No immediate mask data available during connection, skipping backend check to avoid stale data. Will check after execution.");
return;
}
// Check backend for input data only if we have connected inputs // Check backend for input data only if we have connected inputs
const response = await fetch(`/layerforge/get_input_data/${nodeId}`); const response = await fetch(`/layerforge/get_input_data/${nodeId}`);
const result = await response.json(); const result = await response.json();

View File

@@ -652,8 +652,8 @@ export class CanvasRenderer {
this.updateStrokeOverlaySize(); this.updateStrokeOverlaySize();
// Position above main canvas but below cursor overlay // Position above main canvas but below cursor overlay
this.strokeOverlayCanvas.style.position = 'absolute'; this.strokeOverlayCanvas.style.position = 'absolute';
this.strokeOverlayCanvas.style.left = '1px'; this.strokeOverlayCanvas.style.left = '0px';
this.strokeOverlayCanvas.style.top = '1px'; this.strokeOverlayCanvas.style.top = '0px';
this.strokeOverlayCanvas.style.pointerEvents = 'none'; this.strokeOverlayCanvas.style.pointerEvents = 'none';
this.strokeOverlayCanvas.style.zIndex = '19'; // Below cursor overlay (20) this.strokeOverlayCanvas.style.zIndex = '19'; // Below cursor overlay (20)
// Opacity is now controlled by MaskTool.previewOpacity // Opacity is now controlled by MaskTool.previewOpacity

View File

@@ -1200,12 +1200,16 @@ app.registerExtension({
if (index === 1) { if (index === 1) {
if (connected && link_info) { if (connected && link_info) {
log.info("Input mask connected"); log.info("Input mask connected");
// DON'T clear existing mask when connecting a new input
// Reset the loaded mask link ID to allow loading from the new connection
canvas.lastLoadedMaskLinkId = undefined;
// Mark that we have a pending mask connection // Mark that we have a pending mask connection
canvas.hasPendingMaskConnection = true; canvas.hasPendingMaskConnection = true;
// Check for data immediately when connected // Check for data immediately when connected
setTimeout(() => { setTimeout(() => {
log.info("Checking for input data after mask connection..."); log.info("Checking for input data after mask connection...");
// Only load mask here; images are handled by image connect or execution // Only load mask here if it's immediately available from the connected node
// Don't load stale masks from backend storage
canvas.canvasIO.checkForInputData({ allowImage: false, allowMask: true, reason: "mask_connect" }); canvas.canvasIO.checkForInputData({ allowImage: false, allowMask: true, reason: "mask_connect" });
}, 500); }, 500);
} }

View File

@@ -708,10 +708,6 @@ export class CanvasIO {
// If mask input is disconnected, clear any currently applied mask to ensure full separation // If mask input is disconnected, clear any currently applied mask to ensure full separation
if (!hasMaskInput) { if (!hasMaskInput) {
if (this.canvas.maskTool) {
this.canvas.maskTool.clear();
this.canvas.render();
}
(this.canvas as any).maskAppliedFromInput = false; (this.canvas as any).maskAppliedFromInput = false;
this.canvas.lastLoadedMaskLinkId = undefined; this.canvas.lastLoadedMaskLinkId = undefined;
log.info("Mask input disconnected - cleared mask to enforce separation from input_image"); log.info("Mask input disconnected - cleared mask to enforce separation from input_image");
@@ -723,6 +719,12 @@ export class CanvasIO {
return; return;
} }
// Skip backend check during mask connection if we didn't get immediate data
if (reason === "mask_connect" && !maskLoaded) {
log.info("No immediate mask data available during connection, skipping backend check to avoid stale data. Will check after execution.");
return;
}
// Check backend for input data only if we have connected inputs // Check backend for input data only if we have connected inputs
const response = await fetch(`/layerforge/get_input_data/${nodeId}`); const response = await fetch(`/layerforge/get_input_data/${nodeId}`);
const result = await response.json(); const result = await response.json();

View File

@@ -796,8 +796,8 @@ export class CanvasRenderer {
// Position above main canvas but below cursor overlay // Position above main canvas but below cursor overlay
this.strokeOverlayCanvas.style.position = 'absolute'; this.strokeOverlayCanvas.style.position = 'absolute';
this.strokeOverlayCanvas.style.left = '1px'; this.strokeOverlayCanvas.style.left = '0px';
this.strokeOverlayCanvas.style.top = '1px'; this.strokeOverlayCanvas.style.top = '0px';
this.strokeOverlayCanvas.style.pointerEvents = 'none'; this.strokeOverlayCanvas.style.pointerEvents = 'none';
this.strokeOverlayCanvas.style.zIndex = '19'; // Below cursor overlay (20) this.strokeOverlayCanvas.style.zIndex = '19'; // Below cursor overlay (20)
// Opacity is now controlled by MaskTool.previewOpacity // Opacity is now controlled by MaskTool.previewOpacity

View File

@@ -1368,12 +1368,18 @@ app.registerExtension({
if (index === 1) { if (index === 1) {
if (connected && link_info) { if (connected && link_info) {
log.info("Input mask connected"); log.info("Input mask connected");
// DON'T clear existing mask when connecting a new input
// Reset the loaded mask link ID to allow loading from the new connection
canvas.lastLoadedMaskLinkId = undefined;
// Mark that we have a pending mask connection // Mark that we have a pending mask connection
canvas.hasPendingMaskConnection = true; canvas.hasPendingMaskConnection = true;
// Check for data immediately when connected // Check for data immediately when connected
setTimeout(() => { setTimeout(() => {
log.info("Checking for input data after mask connection..."); log.info("Checking for input data after mask connection...");
// Only load mask here; images are handled by image connect or execution // Only load mask here if it's immediately available from the connected node
// Don't load stale masks from backend storage
canvas.canvasIO.checkForInputData({ allowImage: false, allowMask: true, reason: "mask_connect" }); canvas.canvasIO.checkForInputData({ allowImage: false, allowMask: true, reason: "mask_connect" });
}, 500); }, 500);
} else { } else {