mirror of
https://github.com/Azornes/Comfyui-LayerForge.git
synced 2026-03-21 20:52:12 -03:00
Improve mask loading logic on node connection
Updated mask loading to immediately use available data from connected nodes and preserve existing masks if none is provided. Backend mask data is only fetched after workflow execution, ensuring no stale data is loaded during connection.
This commit is contained in:
132
canvas_node.py
132
canvas_node.py
@@ -254,82 +254,80 @@ class LayerForgeNode:
|
||||
|
||||
log_info(f"Lock acquired. Starting process_canvas_image for node_id: {node_id} (fallback unique_id: {unique_id})")
|
||||
|
||||
# Handle input image and mask if provided
|
||||
if input_image is not None or input_mask is not None:
|
||||
log_info(f"Input data detected for node {node_id} - Image: {input_image is not None}, Mask: {input_mask is not None}")
|
||||
# Always store fresh input data, even if None, to clear stale data
|
||||
log_info(f"Storing input data for node {node_id} - Image: {input_image is not None}, Mask: {input_mask is not None}")
|
||||
|
||||
with self.__class__._storage_lock:
|
||||
input_data = {}
|
||||
|
||||
# Store input data for frontend to retrieve
|
||||
with self.__class__._storage_lock:
|
||||
input_data = {}
|
||||
|
||||
if input_image is not None:
|
||||
# Convert image tensor(s) to base64 - handle batch
|
||||
if isinstance(input_image, torch.Tensor):
|
||||
# Ensure correct shape [B, H, W, C]
|
||||
if input_image.dim() == 3:
|
||||
input_image = input_image.unsqueeze(0)
|
||||
if input_image is not None:
|
||||
# Convert image tensor(s) to base64 - handle batch
|
||||
if isinstance(input_image, torch.Tensor):
|
||||
# Ensure correct shape [B, H, W, C]
|
||||
if input_image.dim() == 3:
|
||||
input_image = input_image.unsqueeze(0)
|
||||
|
||||
batch_size = input_image.shape[0]
|
||||
log_info(f"Processing batch of {batch_size} image(s)")
|
||||
|
||||
if batch_size == 1:
|
||||
# Single image - keep backward compatibility
|
||||
img_np = (input_image.squeeze(0).cpu().numpy() * 255).astype(np.uint8)
|
||||
pil_img = Image.fromarray(img_np, 'RGB')
|
||||
|
||||
batch_size = input_image.shape[0]
|
||||
log_info(f"Processing batch of {batch_size} image(s)")
|
||||
|
||||
if batch_size == 1:
|
||||
# Single image - keep backward compatibility
|
||||
img_np = (input_image.squeeze(0).cpu().numpy() * 255).astype(np.uint8)
|
||||
# Convert to base64
|
||||
buffered = io.BytesIO()
|
||||
pil_img.save(buffered, format="PNG")
|
||||
img_str = base64.b64encode(buffered.getvalue()).decode()
|
||||
input_data['input_image'] = f"data:image/png;base64,{img_str}"
|
||||
input_data['input_image_width'] = pil_img.width
|
||||
input_data['input_image_height'] = pil_img.height
|
||||
log_debug(f"Stored single input image: {pil_img.width}x{pil_img.height}")
|
||||
else:
|
||||
# Multiple images - store as array
|
||||
images_array = []
|
||||
for i in range(batch_size):
|
||||
img_np = (input_image[i].cpu().numpy() * 255).astype(np.uint8)
|
||||
pil_img = Image.fromarray(img_np, 'RGB')
|
||||
|
||||
# Convert to base64
|
||||
buffered = io.BytesIO()
|
||||
pil_img.save(buffered, format="PNG")
|
||||
img_str = base64.b64encode(buffered.getvalue()).decode()
|
||||
input_data['input_image'] = f"data:image/png;base64,{img_str}"
|
||||
input_data['input_image_width'] = pil_img.width
|
||||
input_data['input_image_height'] = pil_img.height
|
||||
log_debug(f"Stored single input image: {pil_img.width}x{pil_img.height}")
|
||||
else:
|
||||
# Multiple images - store as array
|
||||
images_array = []
|
||||
for i in range(batch_size):
|
||||
img_np = (input_image[i].cpu().numpy() * 255).astype(np.uint8)
|
||||
pil_img = Image.fromarray(img_np, 'RGB')
|
||||
|
||||
# Convert to base64
|
||||
buffered = io.BytesIO()
|
||||
pil_img.save(buffered, format="PNG")
|
||||
img_str = base64.b64encode(buffered.getvalue()).decode()
|
||||
images_array.append({
|
||||
'data': f"data:image/png;base64,{img_str}",
|
||||
'width': pil_img.width,
|
||||
'height': pil_img.height
|
||||
})
|
||||
log_debug(f"Stored batch image {i+1}/{batch_size}: {pil_img.width}x{pil_img.height}")
|
||||
|
||||
input_data['input_images_batch'] = images_array
|
||||
log_info(f"Stored batch of {batch_size} images")
|
||||
|
||||
if input_mask is not None:
|
||||
# Convert mask tensor to base64
|
||||
if isinstance(input_mask, torch.Tensor):
|
||||
# Ensure correct shape
|
||||
if input_mask.dim() == 2:
|
||||
input_mask = input_mask.unsqueeze(0)
|
||||
if input_mask.dim() == 3 and input_mask.shape[0] == 1:
|
||||
input_mask = input_mask.squeeze(0)
|
||||
images_array.append({
|
||||
'data': f"data:image/png;base64,{img_str}",
|
||||
'width': pil_img.width,
|
||||
'height': pil_img.height
|
||||
})
|
||||
log_debug(f"Stored batch image {i+1}/{batch_size}: {pil_img.width}x{pil_img.height}")
|
||||
|
||||
# Convert to numpy and then to PIL
|
||||
mask_np = (input_mask.cpu().numpy() * 255).astype(np.uint8)
|
||||
pil_mask = Image.fromarray(mask_np, 'L')
|
||||
|
||||
# Convert to base64
|
||||
mask_buffered = io.BytesIO()
|
||||
pil_mask.save(mask_buffered, format="PNG")
|
||||
mask_str = base64.b64encode(mask_buffered.getvalue()).decode()
|
||||
input_data['input_mask'] = f"data:image/png;base64,{mask_str}"
|
||||
log_debug(f"Stored input mask: {pil_mask.width}x{pil_mask.height}")
|
||||
|
||||
input_data['fit_on_add'] = fit_on_add
|
||||
|
||||
# Store in a special key for input data
|
||||
self.__class__._canvas_data_storage[f"{node_id}_input"] = input_data
|
||||
input_data['input_images_batch'] = images_array
|
||||
log_info(f"Stored batch of {batch_size} images")
|
||||
|
||||
if input_mask is not None:
|
||||
# Convert mask tensor to base64
|
||||
if isinstance(input_mask, torch.Tensor):
|
||||
# Ensure correct shape
|
||||
if input_mask.dim() == 2:
|
||||
input_mask = input_mask.unsqueeze(0)
|
||||
if input_mask.dim() == 3 and input_mask.shape[0] == 1:
|
||||
input_mask = input_mask.squeeze(0)
|
||||
|
||||
# Convert to numpy and then to PIL
|
||||
mask_np = (input_mask.cpu().numpy() * 255).astype(np.uint8)
|
||||
pil_mask = Image.fromarray(mask_np, 'L')
|
||||
|
||||
# Convert to base64
|
||||
mask_buffered = io.BytesIO()
|
||||
pil_mask.save(mask_buffered, format="PNG")
|
||||
mask_str = base64.b64encode(mask_buffered.getvalue()).decode()
|
||||
input_data['input_mask'] = f"data:image/png;base64,{mask_str}"
|
||||
log_debug(f"Stored input mask: {pil_mask.width}x{pil_mask.height}")
|
||||
|
||||
input_data['fit_on_add'] = fit_on_add
|
||||
|
||||
# Store in a special key for input data (overwrites any previous data)
|
||||
self.__class__._canvas_data_storage[f"{node_id}_input"] = input_data
|
||||
|
||||
storage_key = node_id
|
||||
|
||||
|
||||
Reference in New Issue
Block a user