mirror of
https://github.com/Azornes/Comfyui-LayerForge.git
synced 2026-03-21 20:52:12 -03:00
Added mask and image input
This commit is contained in:
@@ -179,6 +179,10 @@ class LayerForgeNode:
|
||||
"trigger": ("INT", {"default": 0, "min": 0, "max": 99999999, "step": 1}),
|
||||
"node_id": ("STRING", {"default": "0"}),
|
||||
},
|
||||
"optional": {
|
||||
"input_image": ("IMAGE",),
|
||||
"input_mask": ("MASK",),
|
||||
},
|
||||
"hidden": {
|
||||
"prompt": ("PROMPT",),
|
||||
"unique_id": ("UNIQUE_ID",),
|
||||
@@ -239,7 +243,7 @@ class LayerForgeNode:
|
||||
|
||||
_processing_lock = threading.Lock()
|
||||
|
||||
def process_canvas_image(self, fit_on_add, show_preview, auto_refresh_after_generation, trigger, node_id, prompt=None, unique_id=None):
|
||||
def process_canvas_image(self, fit_on_add, show_preview, auto_refresh_after_generation, trigger, node_id, input_image=None, input_mask=None, prompt=None, unique_id=None):
|
||||
|
||||
try:
|
||||
|
||||
@@ -250,6 +254,59 @@ class LayerForgeNode:
|
||||
|
||||
log_info(f"Lock acquired. Starting process_canvas_image for node_id: {node_id} (fallback unique_id: {unique_id})")
|
||||
|
||||
# Handle input image and mask if provided
|
||||
if input_image is not None or input_mask is not None:
|
||||
log_info(f"Input data detected for node {node_id} - Image: {input_image is not None}, Mask: {input_mask is not None}")
|
||||
|
||||
# Store input data for frontend to retrieve
|
||||
with self.__class__._storage_lock:
|
||||
input_data = {}
|
||||
|
||||
if input_image is not None:
|
||||
# Convert image tensor to base64
|
||||
if isinstance(input_image, torch.Tensor):
|
||||
# Ensure correct shape [B, H, W, C]
|
||||
if input_image.dim() == 3:
|
||||
input_image = input_image.unsqueeze(0)
|
||||
|
||||
# Convert to numpy and then to PIL
|
||||
img_np = (input_image.squeeze(0).cpu().numpy() * 255).astype(np.uint8)
|
||||
pil_img = Image.fromarray(img_np, 'RGB')
|
||||
|
||||
# Convert to base64
|
||||
buffered = io.BytesIO()
|
||||
pil_img.save(buffered, format="PNG")
|
||||
img_str = base64.b64encode(buffered.getvalue()).decode()
|
||||
input_data['input_image'] = f"data:image/png;base64,{img_str}"
|
||||
input_data['input_image_width'] = pil_img.width
|
||||
input_data['input_image_height'] = pil_img.height
|
||||
log_debug(f"Stored input image: {pil_img.width}x{pil_img.height}")
|
||||
|
||||
if input_mask is not None:
|
||||
# Convert mask tensor to base64
|
||||
if isinstance(input_mask, torch.Tensor):
|
||||
# Ensure correct shape
|
||||
if input_mask.dim() == 2:
|
||||
input_mask = input_mask.unsqueeze(0)
|
||||
if input_mask.dim() == 3 and input_mask.shape[0] == 1:
|
||||
input_mask = input_mask.squeeze(0)
|
||||
|
||||
# Convert to numpy and then to PIL
|
||||
mask_np = (input_mask.cpu().numpy() * 255).astype(np.uint8)
|
||||
pil_mask = Image.fromarray(mask_np, 'L')
|
||||
|
||||
# Convert to base64
|
||||
mask_buffered = io.BytesIO()
|
||||
pil_mask.save(mask_buffered, format="PNG")
|
||||
mask_str = base64.b64encode(mask_buffered.getvalue()).decode()
|
||||
input_data['input_mask'] = f"data:image/png;base64,{mask_str}"
|
||||
log_debug(f"Stored input mask: {pil_mask.width}x{pil_mask.height}")
|
||||
|
||||
input_data['fit_on_add'] = fit_on_add
|
||||
|
||||
# Store in a special key for input data
|
||||
self.__class__._canvas_data_storage[f"{node_id}_input"] = input_data
|
||||
|
||||
storage_key = node_id
|
||||
|
||||
processed_image = None
|
||||
@@ -433,6 +490,37 @@ class LayerForgeNode:
|
||||
log_info("WebSocket connection closed")
|
||||
return ws
|
||||
|
||||
@PromptServer.instance.routes.get("/layerforge/get_input_data/{node_id}")
|
||||
async def get_input_data(request):
|
||||
try:
|
||||
node_id = request.match_info["node_id"]
|
||||
log_debug(f"Checking for input data for node: {node_id}")
|
||||
|
||||
with cls._storage_lock:
|
||||
input_key = f"{node_id}_input"
|
||||
input_data = cls._canvas_data_storage.pop(input_key, None)
|
||||
|
||||
if input_data:
|
||||
log_info(f"Input data found for node {node_id}, sending to frontend")
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
'has_input': True,
|
||||
'data': input_data
|
||||
})
|
||||
else:
|
||||
log_debug(f"No input data found for node {node_id}")
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
'has_input': False
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
log_error(f"Error in get_input_data: {str(e)}")
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
}, status=500)
|
||||
|
||||
@PromptServer.instance.routes.get("/ycnode/get_canvas_data/{node_id}")
|
||||
async def get_canvas_data(request):
|
||||
try:
|
||||
@@ -911,4 +999,3 @@ def convert_tensor_to_base64(tensor, alpha_mask=None, original_alpha=None):
|
||||
log_error(f"Error in convert_tensor_to_base64: {str(e)}")
|
||||
log_debug(f"Tensor shape: {tensor.shape}, dtype: {tensor.dtype}")
|
||||
raise
|
||||
|
||||
|
||||
Reference in New Issue
Block a user