# Standard-library modules introduced for this class.  PIL's Image/ImageOps,
# numpy (np), torch, torch.nn.functional (F), os, traceback, folder_paths,
# PromptServer and aiohttp's ``web`` are referenced below and are expected to
# be imported by the surrounding module.
import base64
import io
import time
import uuid


class CanvasNode:
    """ComfyUI node that exposes a front-end canvas as IMAGE / MASK outputs.

    Optional ``input_image`` / ``input_mask`` tensors are converted to PIL
    images and kept in a class-level cache so the HTTP route registered by
    ``setup_routes`` can hand them to the front end as PNG data URLs.
    """

    # Shared by every instance of the node (class-level, mutable on purpose).
    _canvas_cache = {
        'image': None,
        'mask': None,
        'cache_enabled': True,
        'data_flow_status': {},
        'persistent_cache': {},
        'last_execution_id': None
    }

    # Side length of the fallback canvas used when nothing can be loaded.
    _BLANK_SIZE = 512

    def __init__(self):
        super().__init__()
        self.flow_id = str(uuid.uuid4())
        # Restore data from the persistent cache when it holds anything.
        if self.__class__._canvas_cache['persistent_cache']:
            self.restore_cache()

    def restore_cache(self):
        """Restore image/mask from the persistent cache unless a new execution started.

        NOTE(review): ``get_execution_id`` returns a fresh millisecond
        timestamp on every call, so the comparison below practically always
        takes the "new execution" branch and the restore branch is effectively
        unreachable.  A stable per-run identifier is needed for the restore to
        ever happen; behaviour is left unchanged here.
        """
        cache = self.__class__._canvas_cache
        try:
            persistent = cache['persistent_cache']
            current_execution = self.get_execution_id()

            # Only a new execution ID clears the live cache.
            if current_execution != cache['last_execution_id']:
                print(f"New execution detected: {current_execution}")
                cache['image'] = None
                cache['mask'] = None
                cache['last_execution_id'] = current_execution
            else:
                # Same execution: bring back whatever was persisted.
                if persistent.get('image') is not None:
                    cache['image'] = persistent['image']
                    print("Restored image from persistent cache")
                if persistent.get('mask') is not None:
                    cache['mask'] = persistent['mask']
                    print("Restored mask from persistent cache")
        except Exception as e:
            print(f"Error restoring cache: {str(e)}")

    def get_execution_id(self):
        """Return an identifier for the current run: the wall clock in ms, as a string."""
        try:
            return str(int(time.time() * 1000))
        except Exception as e:
            print(f"Error getting execution ID: {str(e)}")
            return None

    def update_persistent_cache(self):
        """Copy the live image/mask cache entries into the persistent cache."""
        cache = self.__class__._canvas_cache
        try:
            cache['persistent_cache'] = {
                'image': cache['image'],
                'mask': cache['mask']
            }
            print("Updated persistent cache")
        except Exception as e:
            print(f"Error updating persistent cache: {str(e)}")

    def track_data_flow(self, stage, status, data_info=None):
        """Log a data-flow event and record it under this instance's ``flow_id``."""
        flow_status = {
            'timestamp': time.time(),
            'stage': stage,
            'status': status,
            'data_info': data_info
        }
        print(f"Data Flow [{self.flow_id}] - Stage: {stage}, Status: {status}")
        if data_info:
            print(f"Data Info: {data_info}")

        # Only the latest event per flow_id is kept.
        self.__class__._canvas_cache['data_flow_status'][self.flow_id] = flow_status

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "canvas_image": ("STRING", {"default": "canvas_image.png"}),
                "trigger": ("INT", {"default": 0, "min": 0, "max": 99999999, "step": 1, "hidden": True}),
                "output_switch": ("BOOLEAN", {"default": True}),
                "cache_enabled": ("BOOLEAN", {"default": True, "label": "Enable Cache"})
            },
            "optional": {
                "input_image": ("IMAGE",),
                "input_mask": ("MASK",)
            }
        }

    # NOTE(review): the class-level lines between INPUT_TYPES and FUNCTION lie
    # outside the visible diff hunks.  RETURN_TYPES is presumably
    # ("IMAGE", "MASK"), matching what process_canvas_image returns; confirm
    # against the full file and restore RETURN_NAMES if it was defined there.
    RETURN_TYPES = ("IMAGE", "MASK")
    FUNCTION = "process_canvas_image"
    CATEGORY = "ycnode"

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _image_tensor_to_rgb_array(image_tensor):
        """Convert an image tensor to a contiguous H x W x 3 uint8 array.

        Accepts [B, H, W, C], [H, W, C], [C, H, W] or [H, W]; only the first
        element of a batch is used.  A 3-D tensor is treated as channels-first
        only when its last dimension is not a valid channel count (1, 3, 4)
        while its first one is.  Values are clamped to [0, 1] before scaling so
        out-of-range inputs cannot wrap around in uint8.

        Raises:
            ValueError: if the tensor rank or channel count is unsupported.
        """
        channel_counts = (1, 3, 4)
        if image_tensor.dim() == 4:
            image_tensor = image_tensor[0]
        if image_tensor.dim() == 2:
            image_tensor = image_tensor.unsqueeze(-1)
        if image_tensor.dim() != 3:
            raise ValueError(f"Unsupported image tensor shape: {tuple(image_tensor.shape)}")

        if image_tensor.shape[-1] not in channel_counts and image_tensor.shape[0] in channel_counts:
            image_tensor = image_tensor.permute(1, 2, 0)

        channels = image_tensor.shape[-1]
        if channels == 1:
            image_tensor = image_tensor.repeat(1, 1, 3)
        elif channels == 4:
            image_tensor = image_tensor[..., :3]  # drop alpha
        elif channels != 3:
            raise ValueError(f"Unsupported image tensor shape: {tuple(image_tensor.shape)}")

        scaled = image_tensor.detach().float().clamp(0.0, 1.0) * 255.0
        return np.ascontiguousarray(scaled.round().cpu().numpy().astype(np.uint8))

    @staticmethod
    def _mask_tensor_to_array(mask_tensor):
        """Convert a mask tensor to a contiguous H x W uint8 array.

        Leading dimensions are stripped by taking their first element, so
        [B, H, W] and [B, 1, H, W] both yield the first mask.

        Raises:
            ValueError: if the tensor has fewer than two dimensions.
        """
        while mask_tensor.dim() > 2:
            mask_tensor = mask_tensor[0]
        if mask_tensor.dim() != 2:
            raise ValueError(f"Unsupported mask tensor shape: {tuple(mask_tensor.shape)}")
        scaled = mask_tensor.detach().float().clamp(0.0, 1.0) * 255.0
        return np.ascontiguousarray(scaled.round().cpu().numpy().astype(np.uint8))

    @staticmethod
    def _pil_to_data_url(pil_image):
        """Encode an image object (anything with a PIL-style ``save``) as a PNG data URL."""
        buffered = io.BytesIO()
        pil_image.save(buffered, format="PNG")
        encoded = base64.b64encode(buffered.getvalue()).decode()
        return f"data:image/png;base64,{encoded}"

    @classmethod
    def _blank_outputs(cls):
        """Return a white image and an all-ones mask of the fallback size."""
        side = cls._BLANK_SIZE
        return (
            torch.ones((1, side, side, 3), dtype=torch.float32),
            torch.ones((1, side, side), dtype=torch.float32),
        )

    @classmethod
    def _load_canvas_image(cls, canvas_image):
        """Load the saved canvas as a [1, H, W, 3] float tensor.

        Returns:
            ``(tensor, path)``; on any failure a white canvas and ``None``.
        """
        try:
            path_image = folder_paths.get_annotated_filepath(canvas_image)
            with Image.open(path_image) as opened:
                img = ImageOps.exif_transpose(opened)
                if img.mode not in ('RGB', 'RGBA'):
                    img = img.convert('RGB')
                has_alpha = img.mode == 'RGBA'
                pixels = np.array(img).astype(np.float32) / 255.0
            if has_alpha:
                # Composite over mid-grey using the alpha channel.
                rgb = pixels[..., :3]
                alpha = pixels[..., 3:]
                pixels = rgb * alpha + (1 - alpha) * 0.5
            return torch.from_numpy(pixels)[None,], path_image
        except Exception as e:
            # Reading failed: fall back to a white canvas, but say why.
            print(f"Error loading canvas image: {str(e)}")
            return cls._blank_outputs()[0], None

    @staticmethod
    def _load_canvas_mask(path_image, processed_image):
        """Load ``<canvas>_mask.png`` next to the canvas file as a [1, H, W] tensor.

        Falls back to an all-ones mask sized like ``processed_image`` when the
        canvas path is unknown, the mask file is missing, or reading fails.
        """
        default_mask = torch.ones(
            (1, processed_image.shape[1], processed_image.shape[2]), dtype=torch.float32
        )
        if path_image is None:
            return default_mask
        try:
            # Swap only the extension; str.replace would rewrite every ".png"
            # occurring anywhere in the path.
            path_mask = os.path.splitext(path_image)[0] + '_mask.png'
            if not os.path.exists(path_mask):
                return default_mask
            with Image.open(path_mask) as opened:
                mask = np.array(opened.convert('L')).astype(np.float32) / 255.0
            return torch.from_numpy(mask)[None,]
        except Exception as e:
            print(f"Error loading mask: {str(e)}")
            return default_mask

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------
    def add_image_to_canvas(self, input_image):
        """Normalise an input image tensor to [H, W, C].

        Returns the tensor, or ``None`` (after logging) when the input is not
        a ``torch.Tensor``.
        """
        try:
            if not isinstance(input_image, torch.Tensor):
                raise ValueError("Input image must be a torch.Tensor")

            if input_image.dim() == 4:
                input_image = input_image.squeeze(0)

            # Channels-first only when the last dim cannot be a channel count;
            # an H x W x C image whose height is 1 or 3 must not be permuted.
            if (input_image.dim() == 3
                    and input_image.shape[0] in (1, 3)
                    and input_image.shape[-1] not in (1, 3, 4)):
                input_image = input_image.permute(1, 2, 0)

            return input_image

        except Exception as e:
            print(f"Error in add_image_to_canvas: {str(e)}")
            return None

    def add_mask_to_canvas(self, input_mask, input_image):
        """Normalise an input mask to [H, W], resized to ``input_image`` if given.

        Returns the mask tensor, or ``None`` (after logging) on failure.
        """
        try:
            if not isinstance(input_mask, torch.Tensor):
                raise ValueError("Input mask must be a torch.Tensor")

            if input_mask.dim() == 4:
                input_mask = input_mask.squeeze(0)
            if input_mask.dim() == 3 and input_mask.shape[0] == 1:
                input_mask = input_mask.squeeze(0)

            # Make the mask match the image's spatial size.
            if input_image is not None:
                expected_shape = tuple(input_image.shape[:2])
                if tuple(input_mask.shape) != expected_shape:
                    resized = F.interpolate(
                        input_mask.unsqueeze(0).unsqueeze(0).float(),
                        size=expected_shape,
                        mode='bilinear',
                        align_corners=False
                    )
                    # Strip only the two dims added above; a bare squeeze()
                    # would also drop a legitimate size-1 height or width.
                    input_mask = resized.squeeze(0).squeeze(0)

            return input_mask

        except Exception as e:
            print(f"Error in add_mask_to_canvas: {str(e)}")
            return None

    def process_canvas_image(self, canvas_image, trigger, output_switch, cache_enabled, input_image=None, input_mask=None):
        """Node entry point: cache optional inputs and return the saved canvas.

        Args:
            canvas_image: file name resolved via ``folder_paths.get_annotated_filepath``.
            trigger: unused here; present so the front end can force re-execution.
            output_switch: when false, ``()`` is returned instead of outputs.
            cache_enabled: stored into the shared cache under ``'cache_enabled'``.
            input_image: optional image tensor to expose to the front end.
            input_mask: optional mask tensor to expose to the front end.

        Returns:
            ``(image, mask)`` tensors shaped [1, H, W, 3] and [1, H, W]; ``()``
            when ``output_switch`` is false; a blank white image and all-ones
            mask when processing raises, so downstream nodes still receive two
            outputs.
        """
        cache = self.__class__._canvas_cache
        try:
            current_execution = self.get_execution_id()
            print(f"Processing canvas image, execution ID: {current_execution}")

            # A new execution invalidates whatever was cached before.
            if current_execution != cache['last_execution_id']:
                print(f"New execution detected: {current_execution}")
                cache['image'] = None
                cache['mask'] = None
                cache['last_execution_id'] = current_execution

            if input_image is not None:
                print("Input image received, converting to PIL Image...")
                if isinstance(input_image, torch.Tensor):
                    image_array = self._image_tensor_to_rgb_array(input_image)
                    try:
                        # uint8 H x W x 3 is inferred as mode "RGB".
                        pil_image = Image.fromarray(image_array)
                    except Exception as e:
                        print(f"Error converting to PIL Image: {str(e)}")
                        print(f"Array shape: {image_array.shape}, dtype: {image_array.dtype}")
                        raise
                    print("Successfully converted to PIL Image")
                    cache['image'] = pil_image
                    print(f"Image stored in cache with size: {pil_image.size}")

            if input_mask is not None:
                print("Input mask received, converting to PIL Image...")
                if isinstance(input_mask, torch.Tensor):
                    mask_array = self._mask_tensor_to_array(input_mask)
                    # uint8 H x W is inferred as mode "L".
                    pil_mask = Image.fromarray(mask_array)
                    print("Successfully converted mask to PIL Image")
                    cache['mask'] = pil_mask
                    print(f"Mask stored in cache with size: {pil_mask.size}")

            # Remember the cache switch state.
            cache['cache_enabled'] = cache_enabled

            processed_image, path_image = self._load_canvas_image(canvas_image)
            processed_mask = self._load_canvas_mask(path_image, processed_image)

            if not output_switch:
                return ()

            self.update_persistent_cache()
            return (processed_image, processed_mask)

        except Exception as e:
            print(f"Error in process_canvas_image: {str(e)}")
            traceback.print_exc()
            return self._blank_outputs()

    def get_cached_data(self):
        """Return the currently cached image and mask."""
        cache = self.__class__._canvas_cache
        return {
            'image': cache['image'],
            'mask': cache['mask']
        }

    @classmethod
    def api_get_data(cls, node_id):
        """Return the shared cache wrapped in a success envelope.

        ``node_id`` is accepted but not used.  The returned ``data`` is the
        live cache dict itself and may hold PIL images.
        """
        try:
            return {
                'success': True,
                'data': cls._canvas_cache
            }
        except Exception as e:
            return {
                'success': False,
                'error': str(e)
            }

    @classmethod
    def get_flow_status(cls, flow_id=None):
        """Return the recorded status for ``flow_id``, or all statuses when omitted."""
        if flow_id:
            return cls._canvas_cache['data_flow_status'].get(flow_id)
        return cls._canvas_cache['data_flow_status']

    @classmethod
    def setup_routes(cls):
        """Register the HTTP route that serves the cached image/mask as data URLs."""
        @PromptServer.instance.routes.get("/ycnode/get_canvas_data/{node_id}")
        async def get_canvas_data(request):
            try:
                node_id = request.match_info["node_id"]
                print(f"Received request for node: {node_id}")

                cache_data = cls._canvas_cache
                print(f"Cache content: {cache_data}")
                print(f"Image in cache: {cache_data['image'] is not None}")

                response_data = {
                    'success': True,
                    'data': {
                        'image': None,
                        'mask': None
                    }
                }

                for key in ('image', 'mask'):
                    if cache_data[key] is not None:
                        response_data['data'][key] = cls._pil_to_data_url(cache_data[key])

                return web.json_response(response_data)

            except Exception as e:
                print(f"Error in get_canvas_data: {str(e)}")
                return web.json_response({
                    'success': False,
                    'error': str(e)
                })

    def store_image(self, image_data):
        """Store an image on the instance, decoding it first if it is a data URL."""
        if isinstance(image_data, str) and image_data.startswith('data:image'):
            # Everything after the first comma is the base64 payload.
            payload = image_data.split(',', 1)[1]
            image_bytes = base64.b64decode(payload)
            self.cached_image = Image.open(io.BytesIO(image_bytes))
        else:
            self.cached_image = image_data

    def get_cached_image(self):
        """Return the stored image as a PNG data URL, or ``None`` if nothing is stored."""
        # store_image may never have run on this instance.
        cached = getattr(self, 'cached_image', None)
        if cached is None:
            return None
        return self._pil_to_data_url(cached)
matting(request): print("Received matting request") data = await request.json() - # 获取BiRefNet实例 + # 取BiRefNet实例 matting = BiRefNetMatting() # 处理图像数据,现在返回图像tensor和alpha通道