From 872814b6765a888b1b07ac1859f93fb9b6da8bfc Mon Sep 17 00:00:00 2001
From: tanglup <160642431+tanglup@users.noreply.github.com>
Date: Wed, 20 Nov 2024 18:21:25 +0800
Subject: [PATCH] Update canvas_node.py

---
 canvas_node.py | 341 ++++++++++++++++++++++++++++++++++++++++++++++++-
 1 file changed, 338 insertions(+), 3 deletions(-)

diff --git a/canvas_node.py b/canvas_node.py
index 45579a9..9e0bd51 100644
--- a/canvas_node.py
+++ b/canvas_node.py
@@ -4,6 +4,46 @@ import torch
 import numpy as np
 import folder_paths
 from server import PromptServer
+from aiohttp import web
+import os
+from tqdm import tqdm
+from torchvision import transforms
+from transformers import AutoModelForImageSegmentation, PretrainedConfig
+import torch.nn.functional as F
+import traceback
+
+# 设置高精度计算
+torch.set_float32_matmul_precision('high')
+
+# 定义配置类
+class BiRefNetConfig(PretrainedConfig):
+    model_type = "BiRefNet"
+    def __init__(self, bb_pretrained=False, **kwargs):
+        self.bb_pretrained = bb_pretrained
+        super().__init__(**kwargs)
+
+# 定义模型类
+class BiRefNet(torch.nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        # 基本网络结构
+        self.encoder = torch.nn.Sequential(
+            torch.nn.Conv2d(3, 64, kernel_size=3, padding=1),
+            torch.nn.ReLU(inplace=True),
+            torch.nn.Conv2d(64, 64, kernel_size=3, padding=1),
+            torch.nn.ReLU(inplace=True)
+        )
+
+        self.decoder = torch.nn.Sequential(
+            torch.nn.Conv2d(64, 32, kernel_size=3, padding=1),
+            torch.nn.ReLU(inplace=True),
+            torch.nn.Conv2d(32, 1, kernel_size=1)
+        )
+
+    def forward(self, x):
+        features = self.encoder(x)
+        output = self.decoder(features)
+        return [output]
 
 class CanvasView:
     @classmethod
@@ -25,7 +65,7 @@ class CanvasView:
 
     def process_canvas_image(self, canvas_image, trigger, unique_id):
         try:
-            # 读取保存的画布图像和遮罩
+            # 读取保存的画布图像和遮罩
             path_image = folder_paths.get_annotated_filepath(canvas_image)
             path_mask = folder_paths.get_annotated_filepath(canvas_image.replace('.png', '_mask.png'))
 
@@ -55,7 +95,302 @@ class CanvasView:
             return (image, mask)
         except Exception as e:
             print(f"Error processing canvas image: {str(e)}")
-            # 返回白色图像和空白遮罩
+            # 返回白色图像和空白遮罩
             blank = np.ones((512, 512, 3), dtype=np.float32)
             blank_mask = np.zeros((512, 512), dtype=np.float32)
-            return (torch.from_numpy(blank)[None,], torch.from_numpy(blank_mask)[None,])
\ No newline at end of file
+            return (torch.from_numpy(blank)[None,], torch.from_numpy(blank_mask)[None,])
+
+class BiRefNetMatting:
+    def __init__(self):
+        self.model = None
+        self.model_path = None
+        self.model_cache = {}
+        # 使用 ComfyUI models 目录
+        self.base_path = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))), "models")
+
+    def load_model(self, model_path):
+        try:
+            if model_path not in self.model_cache:
+                # 使用 ComfyUI models 目录下的 BiRefNet 路径
+                full_model_path = os.path.join(self.base_path, "BiRefNet")
+
+                print(f"Loading BiRefNet model from {full_model_path}...")
+
+                try:
+                    # 直接从Hugging Face加载
+                    self.model = AutoModelForImageSegmentation.from_pretrained(
+                        "ZhengPeng7/BiRefNet",
+                        trust_remote_code=True,
+                        cache_dir=full_model_path  # 使用本地缓存目录
+                    )
+
+                    # 设置为评估模式并移动到GPU
+                    self.model.eval()
+                    if torch.cuda.is_available():
+                        self.model = self.model.cuda()
+
+                    self.model_cache[model_path] = self.model
+                    print("Model loaded successfully from Hugging Face")
+                    print(f"Model type: {type(self.model)}")
+                    print(f"Model device: {next(self.model.parameters()).device}")
+
+                except Exception as e:
+                    print(f"Failed to load model: {str(e)}")
+                    raise
+
+            else:
+                self.model = self.model_cache[model_path]
+                print("Using cached model")
+
+            return True
+
+        except Exception as e:
+            print(f"Error loading model: {str(e)}")
+            traceback.print_exc()
+            return False
+
+    def preprocess_image(self, image):
+        """预处理输入图像"""
+        try:
+            # 转换为PIL图像
+            if isinstance(image, torch.Tensor):
+                if image.dim() == 4:
+                    image = image.squeeze(0)
+                if image.dim() == 3:
+                    image = transforms.ToPILImage()(image)
+
+            # 参考nodes.py的预处理
+            transform_image = transforms.Compose([
+                transforms.Resize((1024, 1024)),
+                transforms.ToTensor(),
+                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
+            ])
+
+            # 转换为tensor并添加batch维度
+            image_tensor = transform_image(image).unsqueeze(0)
+
+            if torch.cuda.is_available():
+                image_tensor = image_tensor.cuda()
+
+            return image_tensor
+        except Exception as e:
+            print(f"Error preprocessing image: {str(e)}")
+            return None
+
+    def execute(self, image, model_path, threshold=0.5, refinement=1):
+        try:
+            # 发送开始状态
+            PromptServer.instance.send_sync("matting_status", {"status": "processing"})
+
+            # 加载模型
+            if not self.load_model(model_path):
+                raise RuntimeError("Failed to load model")
+
+            # 获取原始尺寸
+            if isinstance(image, torch.Tensor):
+                original_size = image.shape[-2:] if image.dim() == 4 else image.shape[-2:]
+            else:
+                original_size = image.size[::-1]
+
+            print(f"Original size: {original_size}")
+
+            # 预处理图像
+            processed_image = self.preprocess_image(image)
+            if processed_image is None:
+                raise Exception("Failed to preprocess image")
+
+            print(f"Processed image shape: {processed_image.shape}")
+
+            # 执行推理
+            with torch.no_grad():
+                outputs = self.model(processed_image)
+                result = outputs[-1].sigmoid().cpu()
+                print(f"Model output shape: {result.shape}")
+
+                # 确保结果有正确的维度格式 [B, C, H, W]
+                if result.dim() == 3:
+                    result = result.unsqueeze(1)  # 添加通道维度
+                elif result.dim() == 2:
+                    result = result.unsqueeze(0).unsqueeze(0)  # 添加batch和通道维度
+
+                print(f"Reshaped result shape: {result.shape}")
+
+                # 调整大小
+                result = F.interpolate(
+                    result,
+                    size=(original_size[0], original_size[1]),  # 明确指定高度和宽度
+                    mode='bilinear',
+                    align_corners=True
+                )
+                print(f"Resized result shape: {result.shape}")
+
+                # 归一化
+                result = result.squeeze()  # 移除多余的维度
+                ma = torch.max(result)
+                mi = torch.min(result)
+                result = (result-mi)/(ma-mi)
+
+                # 应用阈值
+                if threshold > 0:
+                    result = (result > threshold).float()
+
+            # 创建mask和结果图像
+            alpha_mask = result.unsqueeze(0).unsqueeze(0)  # 确保mask是 [1, 1, H, W]
+            if isinstance(image, torch.Tensor):
+                if image.dim() == 3:
+                    image = image.unsqueeze(0)
+                masked_image = image * alpha_mask
+            else:
+                image_tensor = transforms.ToTensor()(image).unsqueeze(0)
+                masked_image = image_tensor * alpha_mask
+
+            # 发送完成状态
+            PromptServer.instance.send_sync("matting_status", {"status": "completed"})
+
+            return (masked_image, alpha_mask)
+
+        except Exception as e:
+            # 发送错误状态
+            PromptServer.instance.send_sync("matting_status", {"status": "error"})
+            raise e
+
+    @classmethod
+    def IS_CHANGED(cls, image, model_path, threshold, refinement):
+        """检查输入是否改变"""
+        m = hashlib.md5()
+        m.update(str(image).encode())
+        m.update(str(model_path).encode())
+        m.update(str(threshold).encode())
+        m.update(str(refinement).encode())
+        return m.hexdigest()
+
+@PromptServer.instance.routes.post("/matting")
+async def matting(request):
+    try:
+        print("Received matting request")
+        data = await request.json()
+
+        # 获取BiRefNet实例
+        matting = BiRefNetMatting()
+
+        # 处理图像数据,现在返回图像tensor和alpha通道
+        image_tensor, original_alpha = convert_base64_to_tensor(data["image"])
+        print(f"Input image shape: {image_tensor.shape}")
+
+        # 执行抠图
+        matted_image, alpha_mask = matting.execute(
+            image_tensor,
+            "BiRefNet/model.safetensors",
+            threshold=data.get("threshold", 0.5),
+            refinement=data.get("refinement", 1)
+        )
+
+        # 转换结果为base64,包含原始alpha信息
+        result_image = convert_tensor_to_base64(matted_image, alpha_mask, original_alpha)
+        result_mask = convert_tensor_to_base64(alpha_mask)
+
+        return web.json_response({
+            "matted_image": result_image,
+            "alpha_mask": result_mask
+        })
+
+    except Exception as e:
+        print(f"Error in matting endpoint: {str(e)}")
+        import traceback
+        traceback.print_exc()
+        return web.json_response({
+            "error": str(e),
+            "details": traceback.format_exc()
+        }, status=500)
+
+def convert_base64_to_tensor(base64_str):
+    """将base64图像数据转换为tensor,保留alpha通道"""
+    import base64
+    import io
+
+    try:
+        # 解码base64数据
+        img_data = base64.b64decode(base64_str.split(',')[1])
+        img = Image.open(io.BytesIO(img_data))
+
+        # 保存原始alpha通道
+        has_alpha = img.mode == 'RGBA'
+        alpha = None
+        if has_alpha:
+            alpha = img.split()[3]
+            # 创建白色背景
+            background = Image.new('RGB', img.size, (255, 255, 255))
+            background.paste(img, mask=alpha)
+            img = background
+        elif img.mode != 'RGB':
+            img = img.convert('RGB')
+
+        # 转换为tensor
+        transform = transforms.ToTensor()
+        img_tensor = transform(img).unsqueeze(0)  # [1, C, H, W]
+
+        if has_alpha:
+            # 将alpha转换为tensor并保存
+            alpha_tensor = transforms.ToTensor()(alpha).unsqueeze(0)  # [1, 1, H, W]
+            return img_tensor, alpha_tensor
+
+        return img_tensor, None
+
+    except Exception as e:
+        print(f"Error in convert_base64_to_tensor: {str(e)}")
+        raise
+
+def convert_tensor_to_base64(tensor, alpha_mask=None, original_alpha=None):
+    """将tensor转换为base64图像数据,支持alpha通道"""
+    import base64
+    import io
+
+    try:
+        # 确保tensor在CPU上
+        tensor = tensor.cpu()
+
+        # 处理维度
+        if tensor.dim() == 4:
+            tensor = tensor.squeeze(0)  # 移除batch维度
+        if tensor.dim() == 3 and tensor.shape[0] in [1, 3]:
+            tensor = tensor.permute(1, 2, 0)
+
+        # 转换为numpy数组并调整值范围到0-255
+        img_array = (tensor.numpy() * 255).astype(np.uint8)
+
+        # 如果有alpha遮罩和原始alpha
+        if alpha_mask is not None and original_alpha is not None:
+            # 将alpha_mask转换为正确的格式
+            alpha_mask = alpha_mask.cpu().squeeze().numpy()
+            alpha_mask = (alpha_mask * 255).astype(np.uint8)
+
+            # 将原始alpha转换为正确的格式
+            original_alpha = original_alpha.cpu().squeeze().numpy()
+            original_alpha = (original_alpha * 255).astype(np.uint8)
+
+            # 组合alpha_mask和original_alpha
+            combined_alpha = np.minimum(alpha_mask, original_alpha)
+
+            # 创建RGBA图像
+            img = Image.fromarray(img_array, mode='RGB')
+            alpha_img = Image.fromarray(combined_alpha, mode='L')
+            img.putalpha(alpha_img)
+        else:
+            # 处理没有alpha通道的情况
+            if img_array.shape[-1] == 1:
+                img_array = img_array.squeeze(-1)
+                img = Image.fromarray(img_array, mode='L')
+            else:
+                img = Image.fromarray(img_array, mode='RGB')
+
+        # 转换为base64
+        buffer = io.BytesIO()
+        img.save(buffer, format='PNG')
+        img_str = base64.b64encode(buffer.getvalue()).decode()
+
+        return f"data:image/png;base64,{img_str}"
+
+    except Exception as e:
+        print(f"Error in convert_tensor_to_base64: {str(e)}")
+        print(f"Tensor shape: {tensor.shape}, dtype: {tensor.dtype}")
+        raise