mirror of
https://github.com/Azornes/Comfyui-LayerForge.git
synced 2026-03-21 20:52:12 -03:00
Improve mask loading logic on node connection
Updated mask loading to immediately use available data from connected nodes and preserve existing masks if none is provided. Backend mask data is only fetched after workflow execution, ensuring no stale data is loaded during connection.
This commit is contained in:
118
canvas_node.py
118
canvas_node.py
@@ -254,82 +254,80 @@ class LayerForgeNode:
|
|||||||
|
|
||||||
log_info(f"Lock acquired. Starting process_canvas_image for node_id: {node_id} (fallback unique_id: {unique_id})")
|
log_info(f"Lock acquired. Starting process_canvas_image for node_id: {node_id} (fallback unique_id: {unique_id})")
|
||||||
|
|
||||||
# Handle input image and mask if provided
|
# Always store fresh input data, even if None, to clear stale data
|
||||||
if input_image is not None or input_mask is not None:
|
log_info(f"Storing input data for node {node_id} - Image: {input_image is not None}, Mask: {input_mask is not None}")
|
||||||
log_info(f"Input data detected for node {node_id} - Image: {input_image is not None}, Mask: {input_mask is not None}")
|
|
||||||
|
|
||||||
# Store input data for frontend to retrieve
|
with self.__class__._storage_lock:
|
||||||
with self.__class__._storage_lock:
|
input_data = {}
|
||||||
input_data = {}
|
|
||||||
|
|
||||||
if input_image is not None:
|
if input_image is not None:
|
||||||
# Convert image tensor(s) to base64 - handle batch
|
# Convert image tensor(s) to base64 - handle batch
|
||||||
if isinstance(input_image, torch.Tensor):
|
if isinstance(input_image, torch.Tensor):
|
||||||
# Ensure correct shape [B, H, W, C]
|
# Ensure correct shape [B, H, W, C]
|
||||||
if input_image.dim() == 3:
|
if input_image.dim() == 3:
|
||||||
input_image = input_image.unsqueeze(0)
|
input_image = input_image.unsqueeze(0)
|
||||||
|
|
||||||
batch_size = input_image.shape[0]
|
batch_size = input_image.shape[0]
|
||||||
log_info(f"Processing batch of {batch_size} image(s)")
|
log_info(f"Processing batch of {batch_size} image(s)")
|
||||||
|
|
||||||
if batch_size == 1:
|
if batch_size == 1:
|
||||||
# Single image - keep backward compatibility
|
# Single image - keep backward compatibility
|
||||||
img_np = (input_image.squeeze(0).cpu().numpy() * 255).astype(np.uint8)
|
img_np = (input_image.squeeze(0).cpu().numpy() * 255).astype(np.uint8)
|
||||||
|
pil_img = Image.fromarray(img_np, 'RGB')
|
||||||
|
|
||||||
|
# Convert to base64
|
||||||
|
buffered = io.BytesIO()
|
||||||
|
pil_img.save(buffered, format="PNG")
|
||||||
|
img_str = base64.b64encode(buffered.getvalue()).decode()
|
||||||
|
input_data['input_image'] = f"data:image/png;base64,{img_str}"
|
||||||
|
input_data['input_image_width'] = pil_img.width
|
||||||
|
input_data['input_image_height'] = pil_img.height
|
||||||
|
log_debug(f"Stored single input image: {pil_img.width}x{pil_img.height}")
|
||||||
|
else:
|
||||||
|
# Multiple images - store as array
|
||||||
|
images_array = []
|
||||||
|
for i in range(batch_size):
|
||||||
|
img_np = (input_image[i].cpu().numpy() * 255).astype(np.uint8)
|
||||||
pil_img = Image.fromarray(img_np, 'RGB')
|
pil_img = Image.fromarray(img_np, 'RGB')
|
||||||
|
|
||||||
# Convert to base64
|
# Convert to base64
|
||||||
buffered = io.BytesIO()
|
buffered = io.BytesIO()
|
||||||
pil_img.save(buffered, format="PNG")
|
pil_img.save(buffered, format="PNG")
|
||||||
img_str = base64.b64encode(buffered.getvalue()).decode()
|
img_str = base64.b64encode(buffered.getvalue()).decode()
|
||||||
input_data['input_image'] = f"data:image/png;base64,{img_str}"
|
images_array.append({
|
||||||
input_data['input_image_width'] = pil_img.width
|
'data': f"data:image/png;base64,{img_str}",
|
||||||
input_data['input_image_height'] = pil_img.height
|
'width': pil_img.width,
|
||||||
log_debug(f"Stored single input image: {pil_img.width}x{pil_img.height}")
|
'height': pil_img.height
|
||||||
else:
|
})
|
||||||
# Multiple images - store as array
|
log_debug(f"Stored batch image {i+1}/{batch_size}: {pil_img.width}x{pil_img.height}")
|
||||||
images_array = []
|
|
||||||
for i in range(batch_size):
|
|
||||||
img_np = (input_image[i].cpu().numpy() * 255).astype(np.uint8)
|
|
||||||
pil_img = Image.fromarray(img_np, 'RGB')
|
|
||||||
|
|
||||||
# Convert to base64
|
input_data['input_images_batch'] = images_array
|
||||||
buffered = io.BytesIO()
|
log_info(f"Stored batch of {batch_size} images")
|
||||||
pil_img.save(buffered, format="PNG")
|
|
||||||
img_str = base64.b64encode(buffered.getvalue()).decode()
|
|
||||||
images_array.append({
|
|
||||||
'data': f"data:image/png;base64,{img_str}",
|
|
||||||
'width': pil_img.width,
|
|
||||||
'height': pil_img.height
|
|
||||||
})
|
|
||||||
log_debug(f"Stored batch image {i+1}/{batch_size}: {pil_img.width}x{pil_img.height}")
|
|
||||||
|
|
||||||
input_data['input_images_batch'] = images_array
|
if input_mask is not None:
|
||||||
log_info(f"Stored batch of {batch_size} images")
|
# Convert mask tensor to base64
|
||||||
|
if isinstance(input_mask, torch.Tensor):
|
||||||
|
# Ensure correct shape
|
||||||
|
if input_mask.dim() == 2:
|
||||||
|
input_mask = input_mask.unsqueeze(0)
|
||||||
|
if input_mask.dim() == 3 and input_mask.shape[0] == 1:
|
||||||
|
input_mask = input_mask.squeeze(0)
|
||||||
|
|
||||||
if input_mask is not None:
|
# Convert to numpy and then to PIL
|
||||||
# Convert mask tensor to base64
|
mask_np = (input_mask.cpu().numpy() * 255).astype(np.uint8)
|
||||||
if isinstance(input_mask, torch.Tensor):
|
pil_mask = Image.fromarray(mask_np, 'L')
|
||||||
# Ensure correct shape
|
|
||||||
if input_mask.dim() == 2:
|
|
||||||
input_mask = input_mask.unsqueeze(0)
|
|
||||||
if input_mask.dim() == 3 and input_mask.shape[0] == 1:
|
|
||||||
input_mask = input_mask.squeeze(0)
|
|
||||||
|
|
||||||
# Convert to numpy and then to PIL
|
# Convert to base64
|
||||||
mask_np = (input_mask.cpu().numpy() * 255).astype(np.uint8)
|
mask_buffered = io.BytesIO()
|
||||||
pil_mask = Image.fromarray(mask_np, 'L')
|
pil_mask.save(mask_buffered, format="PNG")
|
||||||
|
mask_str = base64.b64encode(mask_buffered.getvalue()).decode()
|
||||||
|
input_data['input_mask'] = f"data:image/png;base64,{mask_str}"
|
||||||
|
log_debug(f"Stored input mask: {pil_mask.width}x{pil_mask.height}")
|
||||||
|
|
||||||
# Convert to base64
|
input_data['fit_on_add'] = fit_on_add
|
||||||
mask_buffered = io.BytesIO()
|
|
||||||
pil_mask.save(mask_buffered, format="PNG")
|
|
||||||
mask_str = base64.b64encode(mask_buffered.getvalue()).decode()
|
|
||||||
input_data['input_mask'] = f"data:image/png;base64,{mask_str}"
|
|
||||||
log_debug(f"Stored input mask: {pil_mask.width}x{pil_mask.height}")
|
|
||||||
|
|
||||||
input_data['fit_on_add'] = fit_on_add
|
# Store in a special key for input data (overwrites any previous data)
|
||||||
|
self.__class__._canvas_data_storage[f"{node_id}_input"] = input_data
|
||||||
# Store in a special key for input data
|
|
||||||
self.__class__._canvas_data_storage[f"{node_id}_input"] = input_data
|
|
||||||
|
|
||||||
storage_key = node_id
|
storage_key = node_id
|
||||||
|
|
||||||
|
|||||||
@@ -620,10 +620,6 @@ export class CanvasIO {
|
|||||||
const hasMaskInput = this.canvas.node.inputs && this.canvas.node.inputs[1] && this.canvas.node.inputs[1].link;
|
const hasMaskInput = this.canvas.node.inputs && this.canvas.node.inputs[1] && this.canvas.node.inputs[1].link;
|
||||||
// If mask input is disconnected, clear any currently applied mask to ensure full separation
|
// If mask input is disconnected, clear any currently applied mask to ensure full separation
|
||||||
if (!hasMaskInput) {
|
if (!hasMaskInput) {
|
||||||
if (this.canvas.maskTool) {
|
|
||||||
this.canvas.maskTool.clear();
|
|
||||||
this.canvas.render();
|
|
||||||
}
|
|
||||||
this.canvas.maskAppliedFromInput = false;
|
this.canvas.maskAppliedFromInput = false;
|
||||||
this.canvas.lastLoadedMaskLinkId = undefined;
|
this.canvas.lastLoadedMaskLinkId = undefined;
|
||||||
log.info("Mask input disconnected - cleared mask to enforce separation from input_image");
|
log.info("Mask input disconnected - cleared mask to enforce separation from input_image");
|
||||||
@@ -633,6 +629,11 @@ export class CanvasIO {
|
|||||||
this.canvas.inputDataLoaded = true;
|
this.canvas.inputDataLoaded = true;
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
// Skip backend check during mask connection if we didn't get immediate data
|
||||||
|
if (reason === "mask_connect" && !maskLoaded) {
|
||||||
|
log.info("No immediate mask data available during connection, skipping backend check to avoid stale data. Will check after execution.");
|
||||||
|
return;
|
||||||
|
}
|
||||||
// Check backend for input data only if we have connected inputs
|
// Check backend for input data only if we have connected inputs
|
||||||
const response = await fetch(`/layerforge/get_input_data/${nodeId}`);
|
const response = await fetch(`/layerforge/get_input_data/${nodeId}`);
|
||||||
const result = await response.json();
|
const result = await response.json();
|
||||||
|
|||||||
@@ -652,8 +652,8 @@ export class CanvasRenderer {
|
|||||||
this.updateStrokeOverlaySize();
|
this.updateStrokeOverlaySize();
|
||||||
// Position above main canvas but below cursor overlay
|
// Position above main canvas but below cursor overlay
|
||||||
this.strokeOverlayCanvas.style.position = 'absolute';
|
this.strokeOverlayCanvas.style.position = 'absolute';
|
||||||
this.strokeOverlayCanvas.style.left = '1px';
|
this.strokeOverlayCanvas.style.left = '0px';
|
||||||
this.strokeOverlayCanvas.style.top = '1px';
|
this.strokeOverlayCanvas.style.top = '0px';
|
||||||
this.strokeOverlayCanvas.style.pointerEvents = 'none';
|
this.strokeOverlayCanvas.style.pointerEvents = 'none';
|
||||||
this.strokeOverlayCanvas.style.zIndex = '19'; // Below cursor overlay (20)
|
this.strokeOverlayCanvas.style.zIndex = '19'; // Below cursor overlay (20)
|
||||||
// Opacity is now controlled by MaskTool.previewOpacity
|
// Opacity is now controlled by MaskTool.previewOpacity
|
||||||
|
|||||||
@@ -1200,12 +1200,16 @@ app.registerExtension({
|
|||||||
if (index === 1) {
|
if (index === 1) {
|
||||||
if (connected && link_info) {
|
if (connected && link_info) {
|
||||||
log.info("Input mask connected");
|
log.info("Input mask connected");
|
||||||
|
// DON'T clear existing mask when connecting a new input
|
||||||
|
// Reset the loaded mask link ID to allow loading from the new connection
|
||||||
|
canvas.lastLoadedMaskLinkId = undefined;
|
||||||
// Mark that we have a pending mask connection
|
// Mark that we have a pending mask connection
|
||||||
canvas.hasPendingMaskConnection = true;
|
canvas.hasPendingMaskConnection = true;
|
||||||
// Check for data immediately when connected
|
// Check for data immediately when connected
|
||||||
setTimeout(() => {
|
setTimeout(() => {
|
||||||
log.info("Checking for input data after mask connection...");
|
log.info("Checking for input data after mask connection...");
|
||||||
// Only load mask here; images are handled by image connect or execution
|
// Only load mask here if it's immediately available from the connected node
|
||||||
|
// Don't load stale masks from backend storage
|
||||||
canvas.canvasIO.checkForInputData({ allowImage: false, allowMask: true, reason: "mask_connect" });
|
canvas.canvasIO.checkForInputData({ allowImage: false, allowMask: true, reason: "mask_connect" });
|
||||||
}, 500);
|
}, 500);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -708,10 +708,6 @@ export class CanvasIO {
|
|||||||
|
|
||||||
// If mask input is disconnected, clear any currently applied mask to ensure full separation
|
// If mask input is disconnected, clear any currently applied mask to ensure full separation
|
||||||
if (!hasMaskInput) {
|
if (!hasMaskInput) {
|
||||||
if (this.canvas.maskTool) {
|
|
||||||
this.canvas.maskTool.clear();
|
|
||||||
this.canvas.render();
|
|
||||||
}
|
|
||||||
(this.canvas as any).maskAppliedFromInput = false;
|
(this.canvas as any).maskAppliedFromInput = false;
|
||||||
this.canvas.lastLoadedMaskLinkId = undefined;
|
this.canvas.lastLoadedMaskLinkId = undefined;
|
||||||
log.info("Mask input disconnected - cleared mask to enforce separation from input_image");
|
log.info("Mask input disconnected - cleared mask to enforce separation from input_image");
|
||||||
@@ -723,6 +719,12 @@ export class CanvasIO {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Skip backend check during mask connection if we didn't get immediate data
|
||||||
|
if (reason === "mask_connect" && !maskLoaded) {
|
||||||
|
log.info("No immediate mask data available during connection, skipping backend check to avoid stale data. Will check after execution.");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
// Check backend for input data only if we have connected inputs
|
// Check backend for input data only if we have connected inputs
|
||||||
const response = await fetch(`/layerforge/get_input_data/${nodeId}`);
|
const response = await fetch(`/layerforge/get_input_data/${nodeId}`);
|
||||||
const result = await response.json();
|
const result = await response.json();
|
||||||
|
|||||||
@@ -796,8 +796,8 @@ export class CanvasRenderer {
|
|||||||
|
|
||||||
// Position above main canvas but below cursor overlay
|
// Position above main canvas but below cursor overlay
|
||||||
this.strokeOverlayCanvas.style.position = 'absolute';
|
this.strokeOverlayCanvas.style.position = 'absolute';
|
||||||
this.strokeOverlayCanvas.style.left = '1px';
|
this.strokeOverlayCanvas.style.left = '0px';
|
||||||
this.strokeOverlayCanvas.style.top = '1px';
|
this.strokeOverlayCanvas.style.top = '0px';
|
||||||
this.strokeOverlayCanvas.style.pointerEvents = 'none';
|
this.strokeOverlayCanvas.style.pointerEvents = 'none';
|
||||||
this.strokeOverlayCanvas.style.zIndex = '19'; // Below cursor overlay (20)
|
this.strokeOverlayCanvas.style.zIndex = '19'; // Below cursor overlay (20)
|
||||||
// Opacity is now controlled by MaskTool.previewOpacity
|
// Opacity is now controlled by MaskTool.previewOpacity
|
||||||
|
|||||||
@@ -1368,12 +1368,18 @@ app.registerExtension({
|
|||||||
if (index === 1) {
|
if (index === 1) {
|
||||||
if (connected && link_info) {
|
if (connected && link_info) {
|
||||||
log.info("Input mask connected");
|
log.info("Input mask connected");
|
||||||
|
|
||||||
|
// DON'T clear existing mask when connecting a new input
|
||||||
|
// Reset the loaded mask link ID to allow loading from the new connection
|
||||||
|
canvas.lastLoadedMaskLinkId = undefined;
|
||||||
|
|
||||||
// Mark that we have a pending mask connection
|
// Mark that we have a pending mask connection
|
||||||
canvas.hasPendingMaskConnection = true;
|
canvas.hasPendingMaskConnection = true;
|
||||||
// Check for data immediately when connected
|
// Check for data immediately when connected
|
||||||
setTimeout(() => {
|
setTimeout(() => {
|
||||||
log.info("Checking for input data after mask connection...");
|
log.info("Checking for input data after mask connection...");
|
||||||
// Only load mask here; images are handled by image connect or execution
|
// Only load mask here if it's immediately available from the connected node
|
||||||
|
// Don't load stale masks from backend storage
|
||||||
canvas.canvasIO.checkForInputData({ allowImage: false, allowMask: true, reason: "mask_connect" });
|
canvas.canvasIO.checkForInputData({ allowImage: false, allowMask: true, reason: "mask_connect" });
|
||||||
}, 500);
|
}, 500);
|
||||||
} else {
|
} else {
|
||||||
|
|||||||
Reference in New Issue
Block a user