mirror of
https://github.com/Azornes/Comfyui-LayerForge.git
synced 2026-03-21 20:52:12 -03:00
Add WebSocket-based RAM output for CanvasNode
Introduces a WebSocket-based mechanism for CanvasNode to send and receive canvas image and mask data in RAM, enabling fast, diskless data transfer between frontend and backend. Adds a new WebSocketManager utility, updates CanvasIO to support RAM output mode, and modifies CanvasView to send canvas data via WebSocket before prompt execution. The backend (canvas_node.py) is updated to handle WebSocket data storage and retrieval, with improved locking and cleanup logic. This change improves workflow speed and reliability by avoiding unnecessary disk I/O and ensuring up-to-date canvas data is available during node execution.
This commit is contained in:
301
canvas_node.py
301
canvas_node.py
@@ -5,6 +5,8 @@ import numpy as np
|
|||||||
import folder_paths
|
import folder_paths
|
||||||
from server import PromptServer
|
from server import PromptServer
|
||||||
from aiohttp import web
|
from aiohttp import web
|
||||||
|
import asyncio
|
||||||
|
import threading
|
||||||
import os
|
import os
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from torchvision import transforms
|
from torchvision import transforms
|
||||||
@@ -91,6 +93,9 @@ class BiRefNet(torch.nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class CanvasNode:
|
class CanvasNode:
|
||||||
|
_canvas_data_storage = {}
|
||||||
|
_storage_lock = threading.Lock()
|
||||||
|
|
||||||
_canvas_cache = {
|
_canvas_cache = {
|
||||||
'image': None,
|
'image': None,
|
||||||
'mask': None,
|
'mask': None,
|
||||||
@@ -99,10 +104,16 @@ class CanvasNode:
|
|||||||
'persistent_cache': {},
|
'persistent_cache': {},
|
||||||
'last_execution_id': None
|
'last_execution_id': None
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# Simple in-memory storage for canvas data, keyed by prompt_id
|
||||||
|
# WebSocket-based storage for canvas data per node
|
||||||
|
_websocket_data = {}
|
||||||
|
_websocket_listeners = {}
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.flow_id = str(uuid.uuid4())
|
self.flow_id = str(uuid.uuid4())
|
||||||
|
self.node_id = None # Will be set when node is created
|
||||||
|
|
||||||
if self.__class__._canvas_cache['persistent_cache']:
|
if self.__class__._canvas_cache['persistent_cache']:
|
||||||
self.restore_cache()
|
self.restore_cache()
|
||||||
@@ -166,14 +177,18 @@ class CanvasNode:
|
|||||||
def INPUT_TYPES(cls):
|
def INPUT_TYPES(cls):
|
||||||
return {
|
return {
|
||||||
"required": {
|
"required": {
|
||||||
"canvas_image": ("STRING", {"default": "canvas_image.png"}),
|
|
||||||
"trigger": ("INT", {"default": 0, "min": 0, "max": 99999999, "step": 1, "hidden": True}),
|
"trigger": ("INT", {"default": 0, "min": 0, "max": 99999999, "step": 1, "hidden": True}),
|
||||||
"output_switch": ("BOOLEAN", {"default": True}),
|
"output_switch": ("BOOLEAN", {"default": True}),
|
||||||
"cache_enabled": ("BOOLEAN", {"default": True, "label": "Enable Cache"})
|
"cache_enabled": ("BOOLEAN", {"default": True, "label": "Enable Cache"}),
|
||||||
|
"node_id": ("STRING", {"default": "0", "hidden": True}),
|
||||||
},
|
},
|
||||||
"optional": {
|
"optional": {
|
||||||
"input_image": ("IMAGE",),
|
"input_image": ("IMAGE",),
|
||||||
"input_mask": ("MASK",)
|
"input_mask": ("MASK",)
|
||||||
|
},
|
||||||
|
"hidden": {
|
||||||
|
"prompt": ("PROMPT",),
|
||||||
|
"unique_id": ("UNIQUE_ID",),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -230,161 +245,72 @@ class CanvasNode:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
# Zmienna blokująca równoczesne wykonania
|
# Zmienna blokująca równoczesne wykonania
|
||||||
_processing_lock = None
|
_processing_lock = threading.Lock()
|
||||||
|
|
||||||
def process_canvas_image(self, canvas_image, trigger, output_switch, cache_enabled, input_image=None,
|
def process_canvas_image(self, trigger, output_switch, cache_enabled, node_id, prompt=None, unique_id=None, input_image=None,
|
||||||
input_mask=None):
|
input_mask=None):
|
||||||
|
|
||||||
|
log_info(f"[CanvasNode] 🔍 process_canvas_image wejście – node_id={node_id!r}, unique_id={unique_id!r}, trigger={trigger}, output_switch={output_switch}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Sprawdź czy już trwa przetwarzanie
|
# Sprawdź czy już trwa przetwarzanie
|
||||||
if self.__class__._processing_lock is not None:
|
if not self.__class__._processing_lock.acquire(blocking=False):
|
||||||
log_warn(f"Process already in progress, waiting for completion...")
|
log_warn(f"Process already in progress for node {node_id}, skipping...")
|
||||||
return () # Zwróć pusty wynik, aby uniknąć równoczesnych przetworzeń
|
# Return cached data if available to avoid breaking the flow
|
||||||
|
return self.get_cached_data()
|
||||||
# Ustaw blokadę
|
|
||||||
self.__class__._processing_lock = True
|
log_info(f"Lock acquired. Starting process_canvas_image for node_id: {node_id} (fallback unique_id: {unique_id})")
|
||||||
|
|
||||||
current_execution = self.get_execution_id()
|
# Use node_id as the primary key, as unique_id is proving unreliable
|
||||||
log_info(f"Starting process_canvas_image - execution ID: {current_execution}, trigger: {trigger}")
|
storage_key = node_id
|
||||||
log_debug(f"Canvas image filename: {canvas_image}")
|
|
||||||
log_debug(f"Output switch: {output_switch}, Cache enabled: {cache_enabled}")
|
processed_image = None
|
||||||
log_debug(f"Input image provided: {input_image is not None}")
|
processed_mask = None
|
||||||
log_debug(f"Input mask provided: {input_mask is not None}")
|
|
||||||
|
|
||||||
if current_execution != self.__class__._canvas_cache['last_execution_id']:
|
with self.__class__._storage_lock:
|
||||||
log_info(f"New execution detected: {current_execution} (previous: {self.__class__._canvas_cache['last_execution_id']})")
|
canvas_data = self.__class__._canvas_data_storage.pop(storage_key, None)
|
||||||
|
|
||||||
self.__class__._canvas_cache['image'] = None
|
if canvas_data:
|
||||||
self.__class__._canvas_cache['mask'] = None
|
log_info(f"Canvas data found for node {storage_key} from WebSocket")
|
||||||
self.__class__._canvas_cache['last_execution_id'] = current_execution
|
if canvas_data.get('image'):
|
||||||
|
image_data = canvas_data['image'].split(',')[1]
|
||||||
|
image_bytes = base64.b64decode(image_data)
|
||||||
|
pil_image = Image.open(io.BytesIO(image_bytes)).convert('RGB')
|
||||||
|
image_array = np.array(pil_image).astype(np.float32) / 255.0
|
||||||
|
processed_image = torch.from_numpy(image_array)[None,]
|
||||||
|
log_debug(f"Image loaded from WebSocket, shape: {processed_image.shape}")
|
||||||
|
|
||||||
|
if canvas_data.get('mask'):
|
||||||
|
mask_data = canvas_data['mask'].split(',')[1]
|
||||||
|
mask_bytes = base64.b64decode(mask_data)
|
||||||
|
pil_mask = Image.open(io.BytesIO(mask_bytes)).convert('L')
|
||||||
|
mask_array = np.array(pil_mask).astype(np.float32) / 255.0
|
||||||
|
processed_mask = torch.from_numpy(mask_array)[None,]
|
||||||
|
log_debug(f"Mask loaded from WebSocket, shape: {processed_mask.shape}")
|
||||||
else:
|
else:
|
||||||
log_debug(f"Same execution ID, using cached data")
|
log_warn(f"No canvas data found for node {storage_key} in WebSocket cache, using fallbacks.")
|
||||||
|
if input_image is not None:
|
||||||
|
log_info("Using provided input_image as fallback")
|
||||||
|
processed_image = input_image
|
||||||
|
if input_mask is not None:
|
||||||
|
log_info("Using provided input_mask as fallback")
|
||||||
|
processed_mask = input_mask
|
||||||
|
|
||||||
if input_image is not None:
|
|
||||||
log_info("Input image received, converting to PIL Image...")
|
|
||||||
|
|
||||||
if isinstance(input_image, torch.Tensor):
|
# Fallback to default tensors if nothing is loaded
|
||||||
if input_image.dim() == 4:
|
if processed_image is None:
|
||||||
input_image = input_image.squeeze(0) # 移除batch维度
|
log_warn(f"Processed image is still None, creating default blank image.")
|
||||||
|
processed_image = torch.zeros((1, 512, 512, 3), dtype=torch.float32)
|
||||||
|
if processed_mask is None:
|
||||||
|
log_warn(f"Processed mask is still None, creating default blank mask.")
|
||||||
|
processed_mask = torch.zeros((1, 512, 512), dtype=torch.float32)
|
||||||
|
|
||||||
if input_image.shape[0] == 3: # 如果是[C, H, W]格式
|
|
||||||
input_image = input_image.permute(1, 2, 0)
|
|
||||||
|
|
||||||
image_array = (input_image.cpu().numpy() * 255).astype(np.uint8)
|
|
||||||
|
|
||||||
if len(image_array.shape) == 2: # 如果是灰度图
|
|
||||||
image_array = np.stack([image_array] * 3, axis=-1)
|
|
||||||
elif len(image_array.shape) == 3 and image_array.shape[-1] != 3:
|
|
||||||
image_array = np.transpose(image_array, (1, 2, 0))
|
|
||||||
|
|
||||||
try:
|
|
||||||
|
|
||||||
pil_image = Image.fromarray(image_array, 'RGB')
|
|
||||||
log_debug("Successfully converted to PIL Image")
|
|
||||||
|
|
||||||
self.__class__._canvas_cache['image'] = pil_image
|
|
||||||
log_debug(f"Image stored in cache with size: {pil_image.size}")
|
|
||||||
except Exception as e:
|
|
||||||
log_error(f"Error converting to PIL Image: {str(e)}")
|
|
||||||
log_debug(f"Array shape: {image_array.shape}, dtype: {image_array.dtype}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
if input_mask is not None:
|
|
||||||
log_info("Input mask received, converting to PIL Image...")
|
|
||||||
if isinstance(input_mask, torch.Tensor):
|
|
||||||
if input_mask.dim() == 4:
|
|
||||||
input_mask = input_mask.squeeze(0)
|
|
||||||
if input_mask.dim() == 3 and input_mask.shape[0] == 1:
|
|
||||||
input_mask = input_mask.squeeze(0)
|
|
||||||
|
|
||||||
mask_array = (input_mask.cpu().numpy() * 255).astype(np.uint8)
|
|
||||||
pil_mask = Image.fromarray(mask_array, 'L')
|
|
||||||
log_debug("Successfully converted mask to PIL Image")
|
|
||||||
|
|
||||||
self.__class__._canvas_cache['mask'] = pil_mask
|
|
||||||
log_debug(f"Mask stored in cache with size: {pil_mask.size}")
|
|
||||||
|
|
||||||
self.__class__._canvas_cache['cache_enabled'] = cache_enabled
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Wczytaj obraz bez maski
|
|
||||||
image_without_mask_name = canvas_image.replace('.png', '_without_mask.png')
|
|
||||||
path_image_without_mask = folder_paths.get_annotated_filepath(image_without_mask_name)
|
|
||||||
log_debug(f"Canvas image name: {canvas_image}")
|
|
||||||
log_debug(f"Looking for image without mask: {image_without_mask_name}")
|
|
||||||
log_debug(f"Full path: {path_image_without_mask}")
|
|
||||||
|
|
||||||
# Sprawdź czy plik istnieje
|
|
||||||
if not os.path.exists(path_image_without_mask):
|
|
||||||
log_warn(f"Image without mask not found at: {path_image_without_mask}")
|
|
||||||
# Spróbuj znaleźć plik w katalogu input
|
|
||||||
input_dir = folder_paths.get_input_directory()
|
|
||||||
alternative_path = os.path.join(input_dir, image_without_mask_name)
|
|
||||||
log_debug(f"Trying alternative path: {alternative_path}")
|
|
||||||
if os.path.exists(alternative_path):
|
|
||||||
path_image_without_mask = alternative_path
|
|
||||||
log_info(f"Found image at alternative path: {alternative_path}")
|
|
||||||
else:
|
|
||||||
raise FileNotFoundError(f"Image file not found: {image_without_mask_name}")
|
|
||||||
|
|
||||||
i = Image.open(path_image_without_mask)
|
|
||||||
i = ImageOps.exif_transpose(i)
|
|
||||||
if i.mode not in ['RGB', 'RGBA']:
|
|
||||||
i = i.convert('RGB')
|
|
||||||
image = np.array(i).astype(np.float32) / 255.0
|
|
||||||
if i.mode == 'RGBA':
|
|
||||||
rgb = image[..., :3]
|
|
||||||
alpha = image[..., 3:]
|
|
||||||
image = rgb * alpha + (1 - alpha) * 0.5
|
|
||||||
processed_image = torch.from_numpy(image)[None,]
|
|
||||||
log_debug(f"Successfully loaded image without mask, shape: {processed_image.shape}")
|
|
||||||
except Exception as e:
|
|
||||||
log_error(f"Error loading image without mask: {str(e)}")
|
|
||||||
processed_image = torch.ones((1, 512, 512, 3), dtype=torch.float32)
|
|
||||||
log_debug(f"Using default image, shape: {processed_image.shape}")
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Wczytaj maskę
|
|
||||||
path_image = folder_paths.get_annotated_filepath(canvas_image)
|
|
||||||
path_mask = path_image.replace('.png', '_mask.png')
|
|
||||||
log_debug(f"Canvas image path: {path_image}")
|
|
||||||
log_debug(f"Looking for mask at: {path_mask}")
|
|
||||||
|
|
||||||
# Sprawdź czy plik maski istnieje
|
|
||||||
if not os.path.exists(path_mask):
|
|
||||||
log_warn(f"Mask not found at: {path_mask}")
|
|
||||||
# Spróbuj znaleźć plik w katalogu input
|
|
||||||
input_dir = folder_paths.get_input_directory()
|
|
||||||
mask_name = canvas_image.replace('.png', '_mask.png')
|
|
||||||
alternative_mask_path = os.path.join(input_dir, mask_name)
|
|
||||||
log_debug(f"Trying alternative mask path: {alternative_mask_path}")
|
|
||||||
if os.path.exists(alternative_mask_path):
|
|
||||||
path_mask = alternative_mask_path
|
|
||||||
log_info(f"Found mask at alternative path: {alternative_mask_path}")
|
|
||||||
|
|
||||||
if os.path.exists(path_mask):
|
|
||||||
log_debug(f"Mask file exists, loading...")
|
|
||||||
mask = Image.open(path_mask).convert('L')
|
|
||||||
mask = np.array(mask).astype(np.float32) / 255.0
|
|
||||||
processed_mask = torch.from_numpy(mask)[None,]
|
|
||||||
log_debug(f"Successfully loaded mask, shape: {processed_mask.shape}")
|
|
||||||
else:
|
|
||||||
log_debug(f"Mask file does not exist, creating default mask")
|
|
||||||
processed_mask = torch.ones((1, processed_image.shape[1], processed_image.shape[2]),
|
|
||||||
dtype=torch.float32)
|
|
||||||
log_debug(f"Default mask created, shape: {processed_mask.shape}")
|
|
||||||
except Exception as e:
|
|
||||||
log_error(f"Error loading mask: {str(e)}")
|
|
||||||
processed_mask = torch.ones((1, processed_image.shape[1], processed_image.shape[2]),
|
|
||||||
dtype=torch.float32)
|
|
||||||
log_debug(f"Fallback mask created, shape: {processed_mask.shape}")
|
|
||||||
|
|
||||||
if not output_switch:
|
if not output_switch:
|
||||||
log_debug(f"Output switch is OFF, returning empty tuple")
|
log_debug(f"Output switch is OFF, returning empty tuple")
|
||||||
return ()
|
return (None, None)
|
||||||
|
|
||||||
log_debug(f"About to return output - Image shape: {processed_image.shape}, Mask shape: {processed_mask.shape}")
|
log_debug(f"About to return output - Image shape: {processed_image.shape}, Mask shape: {processed_mask.shape}")
|
||||||
log_debug(f"Image tensor info - dtype: {processed_image.dtype}, device: {processed_image.device}")
|
|
||||||
log_debug(f"Mask tensor info - dtype: {processed_mask.dtype}, device: {processed_mask.device}")
|
|
||||||
|
|
||||||
self.update_persistent_cache()
|
self.update_persistent_cache()
|
||||||
|
|
||||||
@@ -393,12 +319,13 @@ class CanvasNode:
|
|||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log_exception(f"Error in process_canvas_image: {str(e)}")
|
log_exception(f"Error in process_canvas_image: {str(e)}")
|
||||||
return ()
|
return (None, None)
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
# Zwolnij blokadę
|
# Zwolnij blokadę
|
||||||
self.__class__._processing_lock = None
|
if self.__class__._processing_lock.locked():
|
||||||
log_debug(f"Process completed, lock released")
|
self.__class__._processing_lock.release()
|
||||||
|
log_debug(f"Process completed for node {node_id}, lock released")
|
||||||
|
|
||||||
def get_cached_data(self):
|
def get_cached_data(self):
|
||||||
return {
|
return {
|
||||||
@@ -440,8 +367,80 @@ class CanvasNode:
|
|||||||
return cls._canvas_cache['data_flow_status'].get(flow_id)
|
return cls._canvas_cache['data_flow_status'].get(flow_id)
|
||||||
return cls._canvas_cache['data_flow_status']
|
return cls._canvas_cache['data_flow_status']
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _cleanup_old_websocket_data(cls):
|
||||||
|
"""Clean up old WebSocket data from invalid nodes or data older than 5 minutes"""
|
||||||
|
try:
|
||||||
|
current_time = time.time()
|
||||||
|
cleanup_threshold = 300 # 5 minutes
|
||||||
|
|
||||||
|
nodes_to_remove = []
|
||||||
|
for node_id, data in cls._websocket_data.items():
|
||||||
|
# Remove invalid node IDs
|
||||||
|
if node_id < 0:
|
||||||
|
nodes_to_remove.append(node_id)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Remove old data
|
||||||
|
if current_time - data.get('timestamp', 0) > cleanup_threshold:
|
||||||
|
nodes_to_remove.append(node_id)
|
||||||
|
continue
|
||||||
|
|
||||||
|
for node_id in nodes_to_remove:
|
||||||
|
del cls._websocket_data[node_id]
|
||||||
|
log_debug(f"Cleaned up old WebSocket data for node {node_id}")
|
||||||
|
|
||||||
|
if nodes_to_remove:
|
||||||
|
log_info(f"Cleaned up {len(nodes_to_remove)} old WebSocket entries")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
log_error(f"Error during WebSocket cleanup: {str(e)}")
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def setup_routes(cls):
|
def setup_routes(cls):
|
||||||
|
@PromptServer.instance.routes.get("/layerforge/canvas_ws")
|
||||||
|
async def handle_canvas_websocket(request):
|
||||||
|
ws = web.WebSocketResponse()
|
||||||
|
await ws.prepare(request)
|
||||||
|
|
||||||
|
async for msg in ws:
|
||||||
|
if msg.type == web.WSMsgType.TEXT:
|
||||||
|
try:
|
||||||
|
data = msg.json()
|
||||||
|
node_id = data.get('nodeId')
|
||||||
|
if not node_id:
|
||||||
|
await ws.send_json({'status': 'error', 'message': 'nodeId is required'})
|
||||||
|
continue
|
||||||
|
|
||||||
|
image_data = data.get('image')
|
||||||
|
mask_data = data.get('mask')
|
||||||
|
|
||||||
|
with cls._storage_lock:
|
||||||
|
cls._canvas_data_storage[node_id] = {
|
||||||
|
'image': image_data,
|
||||||
|
'mask': mask_data,
|
||||||
|
'timestamp': time.time()
|
||||||
|
}
|
||||||
|
|
||||||
|
log_info(f"Received canvas data for node {node_id} via WebSocket")
|
||||||
|
# Send acknowledgment back to the client
|
||||||
|
ack_payload = {
|
||||||
|
'type': 'ack',
|
||||||
|
'nodeId': node_id,
|
||||||
|
'status': 'success'
|
||||||
|
}
|
||||||
|
await ws.send_json(ack_payload)
|
||||||
|
log_debug(f"Sent ACK for node {node_id}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
log_error(f"Error processing WebSocket message: {e}")
|
||||||
|
await ws.send_json({'status': 'error', 'message': str(e)})
|
||||||
|
elif msg.type == web.WSMsgType.ERROR:
|
||||||
|
log_error(f"WebSocket connection closed with exception {ws.exception()}")
|
||||||
|
|
||||||
|
log_info("WebSocket connection closed")
|
||||||
|
return ws
|
||||||
|
|
||||||
@PromptServer.instance.routes.get("/ycnode/get_canvas_data/{node_id}")
|
@PromptServer.instance.routes.get("/ycnode/get_canvas_data/{node_id}")
|
||||||
async def get_canvas_data(request):
|
async def get_canvas_data(request):
|
||||||
try:
|
try:
|
||||||
@@ -811,3 +810,15 @@ def convert_tensor_to_base64(tensor, alpha_mask=None, original_alpha=None):
|
|||||||
log_error(f"Error in convert_tensor_to_base64: {str(e)}")
|
log_error(f"Error in convert_tensor_to_base64: {str(e)}")
|
||||||
log_debug(f"Tensor shape: {tensor.shape}, dtype: {tensor.dtype}")
|
log_debug(f"Tensor shape: {tensor.shape}, dtype: {tensor.dtype}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
# Setup original API routes when module is loaded
|
||||||
|
CanvasNode.setup_routes()
|
||||||
|
|
||||||
|
NODE_CLASS_MAPPINGS = {
|
||||||
|
"CanvasNode": CanvasNode
|
||||||
|
}
|
||||||
|
|
||||||
|
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||||
|
"CanvasNode": "LayerForge"
|
||||||
|
}
|
||||||
|
|||||||
@@ -266,9 +266,6 @@ export class Canvas {
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
async saveToServer(fileName) {
|
|
||||||
return this.canvasIO.saveToServer(fileName);
|
|
||||||
}
|
|
||||||
|
|
||||||
async getFlattenedCanvasAsBlob() {
|
async getFlattenedCanvasAsBlob() {
|
||||||
return this.canvasLayers.getFlattenedCanvasAsBlob();
|
return this.canvasLayers.getFlattenedCanvasAsBlob();
|
||||||
|
|||||||
193
js/CanvasIO.js
193
js/CanvasIO.js
@@ -1,5 +1,6 @@
|
|||||||
import {createCanvas} from "./utils/CommonUtils.js";
|
import {createCanvas} from "./utils/CommonUtils.js";
|
||||||
import {createModuleLogger} from "./utils/LoggerUtils.js";
|
import {createModuleLogger} from "./utils/LoggerUtils.js";
|
||||||
|
import {webSocketManager} from "./utils/WebSocketManager.js";
|
||||||
|
|
||||||
const log = createModuleLogger('CanvasIO');
|
const log = createModuleLogger('CanvasIO');
|
||||||
|
|
||||||
@@ -9,35 +10,38 @@ export class CanvasIO {
|
|||||||
this._saveInProgress = null;
|
this._saveInProgress = null;
|
||||||
}
|
}
|
||||||
|
|
||||||
async saveToServer(fileName) {
|
async saveToServer(fileName, outputMode = 'disk') {
|
||||||
if (!window.canvasSaveStates) {
|
if (outputMode === 'disk') {
|
||||||
window.canvasSaveStates = new Map();
|
if (!window.canvasSaveStates) {
|
||||||
}
|
window.canvasSaveStates = new Map();
|
||||||
|
}
|
||||||
const nodeId = this.canvas.node.id;
|
|
||||||
const saveKey = `${nodeId}_${fileName}`;
|
|
||||||
if (this._saveInProgress || window.canvasSaveStates.get(saveKey)) {
|
|
||||||
log.warn(`Save already in progress for node ${nodeId}, waiting...`);
|
|
||||||
return this._saveInProgress || window.canvasSaveStates.get(saveKey);
|
|
||||||
}
|
|
||||||
|
|
||||||
log.info(`Starting saveToServer with fileName: ${fileName} for node: ${nodeId}`);
|
const nodeId = this.canvas.node.id;
|
||||||
log.debug(`Canvas dimensions: ${this.canvas.width}x${this.canvas.height}`);
|
const saveKey = `${nodeId}_${fileName}`;
|
||||||
log.debug(`Number of layers: ${this.canvas.layers.length}`);
|
if (this._saveInProgress || window.canvasSaveStates.get(saveKey)) {
|
||||||
this._saveInProgress = this._performSave(fileName);
|
log.warn(`Save already in progress for node ${nodeId}, waiting...`);
|
||||||
window.canvasSaveStates.set(saveKey, this._saveInProgress);
|
return this._saveInProgress || window.canvasSaveStates.get(saveKey);
|
||||||
|
}
|
||||||
try {
|
|
||||||
const result = await this._saveInProgress;
|
log.info(`Starting saveToServer (disk) with fileName: ${fileName} for node: ${nodeId}`);
|
||||||
return result;
|
this._saveInProgress = this._performSave(fileName, outputMode);
|
||||||
} finally {
|
window.canvasSaveStates.set(saveKey, this._saveInProgress);
|
||||||
this._saveInProgress = null;
|
|
||||||
window.canvasSaveStates.delete(saveKey);
|
try {
|
||||||
log.debug(`Save completed for node ${nodeId}, lock released`);
|
return await this._saveInProgress;
|
||||||
|
} finally {
|
||||||
|
this._saveInProgress = null;
|
||||||
|
window.canvasSaveStates.delete(saveKey);
|
||||||
|
log.debug(`Save completed for node ${nodeId}, lock released`);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// For RAM mode, we don't need the lock/state management as it's synchronous
|
||||||
|
log.info(`Starting saveToServer (RAM) for node: ${this.canvas.node.id}`);
|
||||||
|
return this._performSave(fileName, outputMode);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async _performSave(fileName) {
|
async _performSave(fileName, outputMode) {
|
||||||
if (this.canvas.layers.length === 0) {
|
if (this.canvas.layers.length === 0) {
|
||||||
log.warn(`Node ${this.canvas.node.id} has no layers, creating empty canvas`);
|
log.warn(`Node ${this.canvas.node.id} has no layers, creating empty canvas`);
|
||||||
return Promise.resolve(true);
|
return Promise.resolve(true);
|
||||||
@@ -152,6 +156,15 @@ export class CanvasIO {
|
|||||||
maskCtx.globalCompositeOperation = 'source-over';
|
maskCtx.globalCompositeOperation = 'source-over';
|
||||||
maskCtx.drawImage(tempMaskCanvas, 0, 0);
|
maskCtx.drawImage(tempMaskCanvas, 0, 0);
|
||||||
}
|
}
|
||||||
|
if (outputMode === 'ram') {
|
||||||
|
const imageData = tempCanvas.toDataURL('image/png');
|
||||||
|
const maskData = maskCanvas.toDataURL('image/png');
|
||||||
|
log.info("Returning image and mask data as base64 for RAM mode.");
|
||||||
|
resolve({ image: imageData, mask: maskData });
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Disk Mode (original logic) ---
|
||||||
const fileNameWithoutMask = fileName.replace('.png', '_without_mask.png');
|
const fileNameWithoutMask = fileName.replace('.png', '_without_mask.png');
|
||||||
log.info(`Saving image without mask as: ${fileNameWithoutMask}`);
|
log.info(`Saving image without mask as: ${fileNameWithoutMask}`);
|
||||||
|
|
||||||
@@ -204,7 +217,9 @@ export class CanvasIO {
|
|||||||
|
|
||||||
if (maskResp.status === 200) {
|
if (maskResp.status === 200) {
|
||||||
const data = await resp.json();
|
const data = await resp.json();
|
||||||
this.canvas.widget.value = fileName;
|
if (this.canvas.widget) {
|
||||||
|
this.canvas.widget.value = fileName;
|
||||||
|
}
|
||||||
log.info(`All files saved successfully, widget value set to: ${fileName}`);
|
log.info(`All files saved successfully, widget value set to: ${fileName}`);
|
||||||
resolve(true);
|
resolve(true);
|
||||||
} else {
|
} else {
|
||||||
@@ -228,6 +243,132 @@ export class CanvasIO {
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async _renderOutputData() {
|
||||||
|
return new Promise((resolve) => {
|
||||||
|
const { canvas: tempCanvas, ctx: tempCtx } = createCanvas(this.canvas.width, this.canvas.height);
|
||||||
|
const { canvas: maskCanvas, ctx: maskCtx } = createCanvas(this.canvas.width, this.canvas.height);
|
||||||
|
|
||||||
|
// This logic is mostly mirrored from _performSave to ensure consistency
|
||||||
|
tempCtx.fillStyle = '#ffffff';
|
||||||
|
tempCtx.fillRect(0, 0, this.canvas.width, this.canvas.height);
|
||||||
|
const visibilityCanvas = document.createElement('canvas');
|
||||||
|
visibilityCanvas.width = this.canvas.width;
|
||||||
|
visibilityCanvas.height = this.canvas.height;
|
||||||
|
const visibilityCtx = visibilityCanvas.getContext('2d', { alpha: true });
|
||||||
|
maskCtx.fillStyle = '#ffffff'; // Start with a white mask (nothing masked)
|
||||||
|
maskCtx.fillRect(0, 0, this.canvas.width, this.canvas.height);
|
||||||
|
|
||||||
|
const sortedLayers = this.canvas.layers.sort((a, b) => a.zIndex - b.zIndex);
|
||||||
|
sortedLayers.forEach((layer) => {
|
||||||
|
// Render layer to main canvas
|
||||||
|
tempCtx.save();
|
||||||
|
tempCtx.globalCompositeOperation = layer.blendMode || 'normal';
|
||||||
|
tempCtx.globalAlpha = layer.opacity !== undefined ? layer.opacity : 1;
|
||||||
|
tempCtx.translate(layer.x + layer.width / 2, layer.y + layer.height / 2);
|
||||||
|
tempCtx.rotate(layer.rotation * Math.PI / 180);
|
||||||
|
tempCtx.drawImage(layer.image, -layer.width / 2, -layer.height / 2, layer.width, layer.height);
|
||||||
|
tempCtx.restore();
|
||||||
|
|
||||||
|
// Render layer to visibility canvas for the mask
|
||||||
|
visibilityCtx.save();
|
||||||
|
visibilityCtx.translate(layer.x + layer.width / 2, layer.y + layer.height / 2);
|
||||||
|
visibilityCtx.rotate(layer.rotation * Math.PI / 180);
|
||||||
|
visibilityCtx.drawImage(layer.image, -layer.width / 2, -layer.height / 2, layer.width, layer.height);
|
||||||
|
visibilityCtx.restore();
|
||||||
|
});
|
||||||
|
|
||||||
|
// Create layer visibility mask
|
||||||
|
const visibilityData = visibilityCtx.getImageData(0, 0, this.canvas.width, this.canvas.height);
|
||||||
|
const maskData = maskCtx.getImageData(0, 0, this.canvas.width, this.canvas.height);
|
||||||
|
for (let i = 0; i < visibilityData.data.length; i += 4) {
|
||||||
|
const alpha = visibilityData.data[i + 3];
|
||||||
|
const maskValue = 255 - alpha; // Invert alpha to create the mask
|
||||||
|
maskData.data[i] = maskData.data[i + 1] = maskData.data[i + 2] = maskValue;
|
||||||
|
maskData.data[i + 3] = 255; // Solid mask
|
||||||
|
}
|
||||||
|
maskCtx.putImageData(maskData, 0, 0);
|
||||||
|
|
||||||
|
// Composite the tool mask on top
|
||||||
|
const toolMaskCanvas = this.canvas.maskTool.getMask();
|
||||||
|
if (toolMaskCanvas) {
|
||||||
|
// Create a temp canvas for processing the mask
|
||||||
|
const tempMaskCanvas = document.createElement('canvas');
|
||||||
|
tempMaskCanvas.width = this.canvas.width;
|
||||||
|
tempMaskCanvas.height = this.canvas.height;
|
||||||
|
const tempMaskCtx = tempMaskCanvas.getContext('2d');
|
||||||
|
|
||||||
|
// Clear the canvas
|
||||||
|
tempMaskCtx.clearRect(0, 0, tempMaskCanvas.width, tempMaskCanvas.height);
|
||||||
|
|
||||||
|
// Calculate the correct position to extract the mask
|
||||||
|
const maskX = this.canvas.maskTool.x;
|
||||||
|
const maskY = this.canvas.maskTool.y;
|
||||||
|
|
||||||
|
log.debug(`[renderOutputData] Extracting mask from world position (${maskX}, ${maskY})`);
|
||||||
|
|
||||||
|
const sourceX = Math.max(0, -maskX);
|
||||||
|
const sourceY = Math.max(0, -maskY);
|
||||||
|
const destX = Math.max(0, maskX);
|
||||||
|
const destY = Math.max(0, maskY);
|
||||||
|
|
||||||
|
const copyWidth = Math.min(toolMaskCanvas.width - sourceX, this.canvas.width - destX);
|
||||||
|
const copyHeight = Math.min(toolMaskCanvas.height - sourceY, this.canvas.height - destY);
|
||||||
|
|
||||||
|
if (copyWidth > 0 && copyHeight > 0) {
|
||||||
|
tempMaskCtx.drawImage(
|
||||||
|
toolMaskCanvas,
|
||||||
|
sourceX, sourceY, copyWidth, copyHeight,
|
||||||
|
destX, destY, copyWidth, copyHeight
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert the brush mask (white with alpha) to a solid white mask on black background.
|
||||||
|
const tempMaskData = tempMaskCtx.getImageData(0, 0, this.canvas.width, this.canvas.height);
|
||||||
|
for (let i = 0; i < tempMaskData.data.length; i += 4) {
|
||||||
|
const alpha = tempMaskData.data[i + 3];
|
||||||
|
// The painted area (alpha > 0) should become white (255).
|
||||||
|
tempMaskData.data[i] = tempMaskData.data[i+1] = tempMaskData.data[i+2] = alpha;
|
||||||
|
tempMaskData.data[i + 3] = 255; // Solid alpha
|
||||||
|
}
|
||||||
|
tempMaskCtx.putImageData(tempMaskData, 0, 0);
|
||||||
|
|
||||||
|
// Use 'screen' blending mode. This correctly adds the white brush mask
|
||||||
|
// to the existing layer visibility mask. (white + anything = white)
|
||||||
|
maskCtx.globalCompositeOperation = 'screen';
|
||||||
|
maskCtx.drawImage(tempMaskCanvas, 0, 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
const imageDataUrl = tempCanvas.toDataURL('image/png');
|
||||||
|
const maskDataUrl = maskCanvas.toDataURL('image/png');
|
||||||
|
|
||||||
|
resolve({ image: imageDataUrl, mask: maskDataUrl });
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
async sendDataViaWebSocket(nodeId) {
|
||||||
|
log.info(`Preparing to send data for node ${nodeId} via WebSocket.`);
|
||||||
|
|
||||||
|
const { image, mask } = await this._renderOutputData();
|
||||||
|
|
||||||
|
try {
|
||||||
|
log.info(`Sending data for node ${nodeId}...`);
|
||||||
|
await webSocketManager.sendMessage({
|
||||||
|
type: 'canvas_data',
|
||||||
|
nodeId: String(nodeId),
|
||||||
|
image: image,
|
||||||
|
mask: mask,
|
||||||
|
}, true); // `true` requires an acknowledgment
|
||||||
|
|
||||||
|
log.info(`Data for node ${nodeId} has been sent and acknowledged by the server.`);
|
||||||
|
return true;
|
||||||
|
} catch (error) {
|
||||||
|
log.error(`Failed to send data for node ${nodeId}:`, error);
|
||||||
|
// We can alert the user here or handle it silently.
|
||||||
|
// For now, let's throw to make it clear the process failed.
|
||||||
|
throw new Error(`Failed to get confirmation from server for node ${nodeId}. The workflow might not have the latest canvas data.`);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
async addInputToCanvas(inputImage, inputMask) {
|
async addInputToCanvas(inputImage, inputMask) {
|
||||||
try {
|
try {
|
||||||
log.debug("Adding input to canvas:", {inputImage});
|
log.debug("Adding input to canvas:", {inputImage});
|
||||||
|
|||||||
172
js/CanvasView.js
172
js/CanvasView.js
@@ -377,8 +377,7 @@ async function createCanvasWidget(node, widget, app) {
|
|||||||
const img = new Image();
|
const img = new Image();
|
||||||
img.onload = async () => {
|
img.onload = async () => {
|
||||||
canvas.addLayer(img);
|
canvas.addLayer(img);
|
||||||
await saveWithFallback(widget.value);
|
await updateOutput();
|
||||||
app.graph.runStep();
|
|
||||||
};
|
};
|
||||||
img.src = event.target.result;
|
img.src = event.target.result;
|
||||||
};
|
};
|
||||||
@@ -392,8 +391,7 @@ async function createCanvasWidget(node, widget, app) {
|
|||||||
textContent: "Import Input",
|
textContent: "Import Input",
|
||||||
onclick: async () => {
|
onclick: async () => {
|
||||||
if (await canvas.importLatestImage()) {
|
if (await canvas.importLatestImage()) {
|
||||||
await saveWithFallback(widget.value);
|
await updateOutput();
|
||||||
app.graph.runStep();
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}),
|
}),
|
||||||
@@ -574,8 +572,7 @@ async function createCanvasWidget(node, widget, app) {
|
|||||||
canvas.updateSelection([newLayer]);
|
canvas.updateSelection([newLayer]);
|
||||||
canvas.render();
|
canvas.render();
|
||||||
canvas.saveState();
|
canvas.saveState();
|
||||||
await saveWithFallback(widget.value);
|
await updateOutput();
|
||||||
app.graph.runStep();
|
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
log.error("Matting error:", error);
|
log.error("Matting error:", error);
|
||||||
alert(`Error during matting process: ${error.message}`);
|
alert(`Error during matting process: ${error.message}`);
|
||||||
@@ -745,7 +742,8 @@ async function createCanvasWidget(node, widget, app) {
|
|||||||
const triggerWidget = node.widgets.find(w => w.name === "trigger");
|
const triggerWidget = node.widgets.find(w => w.name === "trigger");
|
||||||
|
|
||||||
const updateOutput = async () => {
|
const updateOutput = async () => {
|
||||||
await saveWithFallback(widget.value);
|
// Only increment trigger and run step - don't save to disk here
|
||||||
|
// Saving to disk will happen during execution_start event
|
||||||
triggerWidget.value = (triggerWidget.value + 1) % 99999999;
|
triggerWidget.value = (triggerWidget.value + 1) % 99999999;
|
||||||
app.graph.runStep();
|
app.graph.runStep();
|
||||||
};
|
};
|
||||||
@@ -790,8 +788,9 @@ async function createCanvasWidget(node, widget, app) {
|
|||||||
canvas.render();
|
canvas.render();
|
||||||
};
|
};
|
||||||
|
|
||||||
canvas.canvas.addEventListener('mouseup', updateOutput);
|
// Remove automatic saving on mouse events - only save during execution
|
||||||
canvas.canvas.addEventListener('mouseleave', updateOutput);
|
// canvas.canvas.addEventListener('mouseup', updateOutput);
|
||||||
|
// canvas.canvas.addEventListener('mouseleave', updateOutput);
|
||||||
|
|
||||||
|
|
||||||
const mainContainer = $el("div.painterMainContainer", {
|
const mainContainer = $el("div.painterMainContainer", {
|
||||||
@@ -922,66 +921,8 @@ async function createCanvasWidget(node, widget, app) {
|
|||||||
if (!window.canvasExecutionStates) {
|
if (!window.canvasExecutionStates) {
|
||||||
window.canvasExecutionStates = new Map();
|
window.canvasExecutionStates = new Map();
|
||||||
}
|
}
|
||||||
const saveWithFallback = async (fileName) => {
|
|
||||||
try {
|
|
||||||
const uniqueFileName = generateUniqueFileName(fileName, node.id);
|
|
||||||
log.debug(`Attempting to save with unique name: ${uniqueFileName}`);
|
|
||||||
return await canvas.saveToServer(uniqueFileName);
|
|
||||||
} catch (error) {
|
|
||||||
log.warn(`Failed to save with unique name, falling back to original: ${fileName}`, error);
|
|
||||||
return await canvas.saveToServer(fileName);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
api.addEventListener("execution_start", async (event) => {
|
|
||||||
const executionData = event.detail || {};
|
|
||||||
const currentPromptId = executionData.prompt_id;
|
|
||||||
|
|
||||||
log.info(`Execution start event for node ${node.id}, prompt_id: ${currentPromptId}`);
|
|
||||||
log.debug(`Widget value: ${widget.value}`);
|
|
||||||
log.debug(`Node inputs: ${node.inputs?.length || 0}`);
|
|
||||||
log.debug(`Canvas layers count: ${canvas.layers.length}`);
|
|
||||||
if (window.canvasExecutionStates.get(node.id)) {
|
|
||||||
log.warn(`Execution already in progress for node ${node.id}, skipping...`);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
window.canvasExecutionStates.set(node.id, true);
|
|
||||||
|
|
||||||
try {
|
|
||||||
if (canvas.layers.length === 0) {
|
|
||||||
log.warn(`Node ${node.id} has no layers, skipping save to server`);
|
|
||||||
} else {
|
|
||||||
await saveWithFallback(widget.value);
|
|
||||||
log.info(`Canvas saved to server for node ${node.id}`);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (node.inputs[0]?.link) {
|
|
||||||
const linkId = node.inputs[0].link;
|
|
||||||
const inputData = app.nodeOutputs[linkId];
|
|
||||||
log.debug(`Input link ${linkId} has data: ${!!inputData}`);
|
|
||||||
if (inputData) {
|
|
||||||
imageCache.set(linkId, inputData);
|
|
||||||
log.debug(`Input data cached for link ${linkId}`);
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
log.debug(`No input link found`);
|
|
||||||
}
|
|
||||||
} catch (error) {
|
|
||||||
log.error(`Error during execution for node ${node.id}:`, error);
|
|
||||||
} finally {
|
|
||||||
window.canvasExecutionStates.set(node.id, false);
|
|
||||||
log.debug(`Execution completed for node ${node.id}, flag released`);
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
const originalSaveToServer = canvas.saveToServer;
|
|
||||||
canvas.saveToServer = async function (fileName) {
|
|
||||||
log.debug(`saveToServer called with fileName: ${fileName}`);
|
|
||||||
log.debug(`Current execution context - node ID: ${node.id}`);
|
|
||||||
const result = await originalSaveToServer.call(this, fileName);
|
|
||||||
log.debug(`saveToServer completed, result: ${result}`);
|
|
||||||
return result;
|
|
||||||
};
|
|
||||||
|
|
||||||
node.canvasWidget = canvas;
|
node.canvasWidget = canvas;
|
||||||
|
|
||||||
@@ -996,30 +937,111 @@ async function createCanvasWidget(node, widget, app) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
const canvasNodeInstances = new Map();
|
||||||
|
|
||||||
app.registerExtension({
|
app.registerExtension({
|
||||||
name: "Comfy.CanvasNode",
|
name: "Comfy.CanvasNode",
|
||||||
|
|
||||||
|
init() {
|
||||||
|
// Monkey-patch the queuePrompt function to send canvas data via WebSocket before sending the prompt
|
||||||
|
const originalQueuePrompt = app.queuePrompt;
|
||||||
|
app.queuePrompt = async function(number, prompt) {
|
||||||
|
log.info("Preparing to queue prompt...");
|
||||||
|
|
||||||
|
if (canvasNodeInstances.size > 0) {
|
||||||
|
log.info(`Found ${canvasNodeInstances.size} CanvasNode(s). Sending data via WebSocket...`);
|
||||||
|
|
||||||
|
const sendPromises = [];
|
||||||
|
for (const [nodeId, canvasWidget] of canvasNodeInstances.entries()) {
|
||||||
|
// Ensure the node still exists on the graph before sending data
|
||||||
|
if (app.graph.getNodeById(nodeId) && canvasWidget.canvas && canvasWidget.canvas.canvasIO) {
|
||||||
|
log.debug(`Sending data for canvas node ${nodeId}`);
|
||||||
|
// This now returns a promise that resolves upon server ACK
|
||||||
|
sendPromises.push(canvasWidget.canvas.canvasIO.sendDataViaWebSocket(nodeId));
|
||||||
|
} else {
|
||||||
|
// If node doesn't exist, it might have been deleted, so we can clean up the map
|
||||||
|
log.warn(`Node ${nodeId} not found in graph, removing from instances map.`);
|
||||||
|
canvasNodeInstances.delete(nodeId);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
try {
|
||||||
|
// Wait for all WebSocket messages to be acknowledged
|
||||||
|
await Promise.all(sendPromises);
|
||||||
|
log.info("All canvas data has been sent and acknowledged by the server.");
|
||||||
|
} catch (error) {
|
||||||
|
log.error("Failed to send canvas data for one or more nodes. Aborting prompt.", error);
|
||||||
|
// IMPORTANT: Stop the prompt from queueing if data transfer fails.
|
||||||
|
// You might want to show a user-facing error here.
|
||||||
|
alert(`CanvasNode Error: ${error.message}`);
|
||||||
|
return; // Stop execution
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
log.info("All pre-prompt tasks complete. Proceeding with original queuePrompt.");
|
||||||
|
// Proceed with the original queuePrompt logic
|
||||||
|
return originalQueuePrompt.apply(this, arguments);
|
||||||
|
};
|
||||||
|
},
|
||||||
|
|
||||||
async beforeRegisterNodeDef(nodeType, nodeData, app) {
|
async beforeRegisterNodeDef(nodeType, nodeData, app) {
|
||||||
if (nodeType.comfyClass === "CanvasNode") {
|
if (nodeType.comfyClass === "CanvasNode") {
|
||||||
const onNodeCreated = nodeType.prototype.onNodeCreated;
|
const onNodeCreated = nodeType.prototype.onNodeCreated;
|
||||||
nodeType.prototype.onNodeCreated = async function () {
|
nodeType.prototype.onNodeCreated = function () {
|
||||||
log.info("CanvasNode created, ID:", this.id);
|
log.debug("CanvasNode onNodeCreated: Base widget setup.");
|
||||||
|
// Call original onNodeCreated to ensure widgets are created
|
||||||
const r = onNodeCreated?.apply(this, arguments);
|
const r = onNodeCreated?.apply(this, arguments);
|
||||||
|
// The main initialization is moved to onAdded
|
||||||
const widget = this.widgets.find(w => w.name === "canvas_image");
|
|
||||||
log.debug("Found canvas_image widget:", widget);
|
|
||||||
await createCanvasWidget(this, widget, app);
|
|
||||||
|
|
||||||
return r;
|
return r;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// onAdded is the most reliable callback for when a node is fully added to the graph and has an ID
|
||||||
|
nodeType.prototype.onAdded = async function() {
|
||||||
|
log.info(`CanvasNode onAdded, ID: ${this.id}`);
|
||||||
|
log.debug(`Available widgets in onAdded:`, this.widgets.map(w => w.name));
|
||||||
|
|
||||||
|
// Prevent re-initialization if the widget already exists
|
||||||
|
if (this.canvasWidget) {
|
||||||
|
log.warn(`CanvasNode ${this.id} already initialized. Skipping onAdded setup.`);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Now that we are in onAdded, this.id is guaranteed to be correct.
|
||||||
|
// Set the hidden node_id widget's value for backend communication.
|
||||||
|
const nodeIdWidget = this.widgets.find(w => w.name === "node_id");
|
||||||
|
if (nodeIdWidget) {
|
||||||
|
nodeIdWidget.value = String(this.id);
|
||||||
|
log.debug(`Set hidden node_id widget to: ${nodeIdWidget.value}`);
|
||||||
|
} else {
|
||||||
|
log.error("Could not find the hidden node_id widget!");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create the main canvas widget and register it in our global map
|
||||||
|
// We pass `null` for the widget parameter as we are not using a pre-defined widget.
|
||||||
|
const canvasWidget = await createCanvasWidget(this, null, app);
|
||||||
|
canvasNodeInstances.set(this.id, canvasWidget);
|
||||||
|
log.info(`Registered CanvasNode instance for ID: ${this.id}`);
|
||||||
|
};
|
||||||
|
|
||||||
const onRemoved = nodeType.prototype.onRemoved;
|
const onRemoved = nodeType.prototype.onRemoved;
|
||||||
nodeType.prototype.onRemoved = function () {
|
nodeType.prototype.onRemoved = function () {
|
||||||
|
log.info(`Cleaning up canvas node ${this.id}`);
|
||||||
|
|
||||||
|
// Clean up from our instance map
|
||||||
|
canvasNodeInstances.delete(this.id);
|
||||||
|
log.info(`Deregistered CanvasNode instance for ID: ${this.id}`);
|
||||||
|
|
||||||
|
// Clean up execution state
|
||||||
|
if (window.canvasExecutionStates) {
|
||||||
|
window.canvasExecutionStates.delete(this.id);
|
||||||
|
}
|
||||||
|
|
||||||
const tooltip = document.getElementById(`painter-help-tooltip-${this.id}`);
|
const tooltip = document.getElementById(`painter-help-tooltip-${this.id}`);
|
||||||
if (tooltip) {
|
if (tooltip) {
|
||||||
tooltip.remove();
|
tooltip.remove();
|
||||||
}
|
}
|
||||||
const backdrop = document.querySelector('.painter-modal-backdrop');
|
const backdrop = document.querySelector('.painter-modal-backdrop');
|
||||||
if (backdrop && backdrop.contains(this.canvasWidget.canvas)) {
|
if (backdrop && backdrop.contains(this.canvasWidget?.canvas)) {
|
||||||
document.body.removeChild(backdrop);
|
document.body.removeChild(backdrop);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -125,12 +125,29 @@ export function cloneLayers(layers) {
|
|||||||
* @returns {string} Sygnatura JSON
|
* @returns {string} Sygnatura JSON
|
||||||
*/
|
*/
|
||||||
export function getStateSignature(layers) {
|
export function getStateSignature(layers) {
|
||||||
return JSON.stringify(layers.map(layer => {
|
return JSON.stringify(layers.map((layer, index) => {
|
||||||
const sig = {...layer};
|
const sig = {
|
||||||
if (sig.imageId) {
|
index: index,
|
||||||
sig.imageId = sig.imageId;
|
x: Math.round(layer.x * 100) / 100, // Round to avoid floating point precision issues
|
||||||
|
y: Math.round(layer.y * 100) / 100,
|
||||||
|
width: Math.round(layer.width * 100) / 100,
|
||||||
|
height: Math.round(layer.height * 100) / 100,
|
||||||
|
rotation: Math.round((layer.rotation || 0) * 100) / 100,
|
||||||
|
zIndex: layer.zIndex,
|
||||||
|
blendMode: layer.blendMode || 'normal',
|
||||||
|
opacity: layer.opacity !== undefined ? Math.round(layer.opacity * 100) / 100 : 1
|
||||||
|
};
|
||||||
|
|
||||||
|
// Include imageId if available
|
||||||
|
if (layer.imageId) {
|
||||||
|
sig.imageId = layer.imageId;
|
||||||
}
|
}
|
||||||
delete sig.image;
|
|
||||||
|
// Include image src as fallback identifier
|
||||||
|
if (layer.image && layer.image.src) {
|
||||||
|
sig.imageSrc = layer.image.src.substring(0, 100); // First 100 chars to avoid huge signatures
|
||||||
|
}
|
||||||
|
|
||||||
return sig;
|
return sig;
|
||||||
}));
|
}));
|
||||||
}
|
}
|
||||||
|
|||||||
160
js/utils/WebSocketManager.js
Normal file
160
js/utils/WebSocketManager.js
Normal file
@@ -0,0 +1,160 @@
|
|||||||
|
import {createModuleLogger} from "./LoggerUtils.js";
|
||||||
|
|
||||||
|
const log = createModuleLogger('WebSocketManager');
|
||||||
|
|
||||||
|
class WebSocketManager {
|
||||||
|
constructor(url) {
|
||||||
|
this.url = url;
|
||||||
|
this.socket = null;
|
||||||
|
this.messageQueue = [];
|
||||||
|
this.isConnecting = false;
|
||||||
|
this.reconnectAttempts = 0;
|
||||||
|
this.maxReconnectAttempts = 10;
|
||||||
|
this.reconnectInterval = 5000; // 5 seconds
|
||||||
|
this.ackCallbacks = new Map(); // Store callbacks for messages awaiting ACK
|
||||||
|
this.messageIdCounter = 0;
|
||||||
|
|
||||||
|
this.connect();
|
||||||
|
}
|
||||||
|
|
||||||
|
connect() {
|
||||||
|
if (this.socket && this.socket.readyState === WebSocket.OPEN) {
|
||||||
|
log.debug("WebSocket is already open.");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (this.isConnecting) {
|
||||||
|
log.debug("Connection attempt already in progress.");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
this.isConnecting = true;
|
||||||
|
log.info(`Connecting to WebSocket at ${this.url}...`);
|
||||||
|
|
||||||
|
try {
|
||||||
|
this.socket = new WebSocket(this.url);
|
||||||
|
|
||||||
|
this.socket.onopen = () => {
|
||||||
|
this.isConnecting = false;
|
||||||
|
this.reconnectAttempts = 0;
|
||||||
|
log.info("WebSocket connection established.");
|
||||||
|
this.flushMessageQueue();
|
||||||
|
};
|
||||||
|
|
||||||
|
this.socket.onmessage = (event) => {
|
||||||
|
try {
|
||||||
|
const data = JSON.parse(event.data);
|
||||||
|
log.debug("Received message:", data);
|
||||||
|
|
||||||
|
if (data.type === 'ack' && data.nodeId) {
|
||||||
|
const callback = this.ackCallbacks.get(data.nodeId);
|
||||||
|
if (callback) {
|
||||||
|
log.debug(`ACK received for nodeId: ${data.nodeId}, resolving promise.`);
|
||||||
|
callback.resolve(data);
|
||||||
|
this.ackCallbacks.delete(data.nodeId);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Handle other incoming messages if needed
|
||||||
|
} catch (error) {
|
||||||
|
log.error("Error parsing incoming WebSocket message:", error);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
this.socket.onclose = (event) => {
|
||||||
|
this.isConnecting = false;
|
||||||
|
if (event.wasClean) {
|
||||||
|
log.info(`WebSocket closed cleanly, code=${event.code}, reason=${event.reason}`);
|
||||||
|
} else {
|
||||||
|
log.warn("WebSocket connection died. Attempting to reconnect...");
|
||||||
|
this.handleReconnect();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
this.socket.onerror = (error) => {
|
||||||
|
this.isConnecting = false;
|
||||||
|
log.error("WebSocket error:", error);
|
||||||
|
// The onclose event will be fired next, which will handle reconnection.
|
||||||
|
};
|
||||||
|
} catch (error) {
|
||||||
|
this.isConnecting = false;
|
||||||
|
log.error("Failed to create WebSocket connection:", error);
|
||||||
|
this.handleReconnect();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
handleReconnect() {
|
||||||
|
if (this.reconnectAttempts < this.maxReconnectAttempts) {
|
||||||
|
this.reconnectAttempts++;
|
||||||
|
log.info(`Reconnect attempt ${this.reconnectAttempts}/${this.maxReconnectAttempts}...`);
|
||||||
|
setTimeout(() => this.connect(), this.reconnectInterval);
|
||||||
|
} else {
|
||||||
|
log.error("Max reconnect attempts reached. Giving up.");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
sendMessage(data, requiresAck = false) {
|
||||||
|
return new Promise((resolve, reject) => {
|
||||||
|
const nodeId = data.nodeId;
|
||||||
|
if (requiresAck && !nodeId) {
|
||||||
|
return reject(new Error("A nodeId is required for messages that need acknowledgment."));
|
||||||
|
}
|
||||||
|
|
||||||
|
const message = JSON.stringify(data);
|
||||||
|
|
||||||
|
if (this.socket && this.socket.readyState === WebSocket.OPEN) {
|
||||||
|
this.socket.send(message);
|
||||||
|
log.debug("Sent message:", data);
|
||||||
|
if (requiresAck) {
|
||||||
|
log.debug(`Message for nodeId ${nodeId} requires ACK. Setting up callback.`);
|
||||||
|
// Set a timeout for the ACK
|
||||||
|
const timeout = setTimeout(() => {
|
||||||
|
this.ackCallbacks.delete(nodeId);
|
||||||
|
reject(new Error(`ACK timeout for nodeId ${nodeId}`));
|
||||||
|
log.warn(`ACK timeout for nodeId ${nodeId}.`);
|
||||||
|
}, 10000); // 10-second timeout
|
||||||
|
|
||||||
|
this.ackCallbacks.set(nodeId, {
|
||||||
|
resolve: (responseData) => {
|
||||||
|
clearTimeout(timeout);
|
||||||
|
resolve(responseData);
|
||||||
|
},
|
||||||
|
reject: (error) => {
|
||||||
|
clearTimeout(timeout);
|
||||||
|
reject(error);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
} else {
|
||||||
|
resolve(); // Resolve immediately if no ACK is needed
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
log.warn("WebSocket not open. Queuing message.");
|
||||||
|
// Note: The current queueing doesn't support ACK promises well.
|
||||||
|
// For simplicity, we'll focus on the connected case.
|
||||||
|
// A more robust implementation would wrap the queued message in a function.
|
||||||
|
this.messageQueue.push(message);
|
||||||
|
if (!this.isConnecting) {
|
||||||
|
this.connect();
|
||||||
|
}
|
||||||
|
// For now, we reject if not connected and ACK is required.
|
||||||
|
if (requiresAck) {
|
||||||
|
reject(new Error("Cannot send message with ACK required while disconnected."));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
flushMessageQueue() {
|
||||||
|
log.debug(`Flushing ${this.messageQueue.length} queued messages.`);
|
||||||
|
// Note: This simple flush doesn't handle ACKs for queued messages.
|
||||||
|
// This should be acceptable as data is sent right before queueing a prompt,
|
||||||
|
// at which point the socket should ideally be connected.
|
||||||
|
while (this.messageQueue.length > 0) {
|
||||||
|
const message = this.messageQueue.shift();
|
||||||
|
this.socket.send(message);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a singleton instance of the WebSocketManager
|
||||||
|
const wsUrl = `ws://${window.location.host}/layerforge/canvas_ws`;
|
||||||
|
export const webSocketManager = new WebSocketManager(wsUrl);
|
||||||
Reference in New Issue
Block a user