Update canvas_node.py

This commit is contained in:
tanglup
2024-11-20 18:21:25 +08:00
committed by GitHub
parent a06448b610
commit 872814b676

View File

@@ -4,6 +4,46 @@ import torch
import numpy as np import numpy as np
import folder_paths import folder_paths
from server import PromptServer from server import PromptServer
from aiohttp import web
import os
from tqdm import tqdm
from torchvision import transforms
from transformers import AutoModelForImageSegmentation, PretrainedConfig
import torch.nn.functional as F
import traceback
# Set the process-wide float32 matmul precision mode to 'high'.
# NOTE(review): per the PyTorch docs, 'high' allows faster reduced-precision
# kernels (e.g. TF32) and is a step below the default 'highest' -- confirm
# that trading precision for speed is the intent here.
torch.set_float32_matmul_precision('high')
# Define the configuration class
class BiRefNetConfig(PretrainedConfig):
    """Configuration object registered under the "BiRefNet" model type.

    Args:
        bb_pretrained: Flag stored verbatim on the config.
            NOTE(review): presumably "backbone pretrained" -- confirm against
            the upstream BiRefNet code.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    model_type = "BiRefNet"

    def __init__(self, bb_pretrained=False, **kwargs):
        # The custom attribute is assigned before the base initializer runs;
        # keep this order.
        self.bb_pretrained = bb_pretrained
        super().__init__(**kwargs)
# Define the model class
class BiRefNet(torch.nn.Module):
    """Small convolutional encoder/decoder that produces a 1-channel map.

    The encoder maps 3 input channels to 64 features with two 3x3
    convolutions; the decoder reduces them to 32 and then to a single channel.
    Every convolution keeps the spatial size (3x3 with padding 1, or 1x1).

    Args:
        config: Accepted for interface compatibility; not read by this class.
    """

    def __init__(self, config):
        super().__init__()
        nn = torch.nn

        # Layers are created in the same order as before so parameter
        # initialisation under a fixed seed and state_dict keys are unchanged.
        encoder_layers = [
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        ]
        self.encoder = nn.Sequential(*encoder_layers)

        decoder_layers = [
            nn.Conv2d(64, 32, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 1, kernel_size=1),
        ]
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x):
        """Run encoder then decoder and return the result in a one-item list."""
        return [self.decoder(self.encoder(x))]
class CanvasView: class CanvasView:
@classmethod @classmethod
@@ -25,7 +65,7 @@ class CanvasView:
def process_canvas_image(self, canvas_image, trigger, unique_id): def process_canvas_image(self, canvas_image, trigger, unique_id):
try: try:
# 读取保存的画布图像和遮 # 读取保存的画布图像和遮
path_image = folder_paths.get_annotated_filepath(canvas_image) path_image = folder_paths.get_annotated_filepath(canvas_image)
path_mask = folder_paths.get_annotated_filepath(canvas_image.replace('.png', '_mask.png')) path_mask = folder_paths.get_annotated_filepath(canvas_image.replace('.png', '_mask.png'))
@@ -55,7 +95,302 @@ class CanvasView:
return (image, mask) return (image, mask)
except Exception as e: except Exception as e:
print(f"Error processing canvas image: {str(e)}") print(f"Error processing canvas image: {str(e)}")
# 回白色图像和空白遮罩 # 回白色图像和空白遮罩
blank = np.ones((512, 512, 3), dtype=np.float32) blank = np.ones((512, 512, 3), dtype=np.float32)
blank_mask = np.zeros((512, 512), dtype=np.float32) blank_mask = np.zeros((512, 512), dtype=np.float32)
return (torch.from_numpy(blank)[None,], torch.from_numpy(blank_mask)[None,]) return (torch.from_numpy(blank)[None,], torch.from_numpy(blank_mask)[None,])
class BiRefNetMatting:
    """Foreground matting helper built on the BiRefNet segmentation model.

    The model is loaded lazily by ``execute`` and kept in a cache shared by
    all instances, so short-lived instances (for example one per HTTP
    request) reuse the already loaded weights instead of reloading them.
    """

    # Loaded models keyed by the ``model_path`` argument of ``load_model``.
    # Deliberately class-level: a per-instance dict is empty for every new
    # instance, which would force a full reload each time.
    _shared_model_cache = {}

    def __init__(self):
        self.model = None
        self.model_path = None
        # Same dict object for every instance; kept as an instance attribute
        # so existing code reading ``self.model_cache`` keeps working.
        self.model_cache = self._shared_model_cache
        # "models" directory three levels above this file (the ComfyUI
        # models directory when installed as a custom node).
        self.base_path = os.path.join(
            os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))),
            "models",
        )

    def load_model(self, model_path):
        """Make ``self.model`` ready for inference.

        Args:
            model_path: Key under which the loaded model is cached. The
                weights are always fetched from the "ZhengPeng7/BiRefNet"
                Hugging Face repository, using ``<base_path>/BiRefNet`` as
                the download cache directory.

        Returns:
            True when a model is available, False if loading failed (the
            error and its traceback are printed).
        """
        try:
            cached = self.model_cache.get(model_path)
            if cached is not None:
                self.model = cached
                print("Using cached model")
                return True

            full_model_path = os.path.join(self.base_path, "BiRefNet")
            print(f"Loading BiRefNet model from {full_model_path}...")

            # SECURITY NOTE: trust_remote_code=True executes Python code
            # shipped in the remote model repository.
            model = AutoModelForImageSegmentation.from_pretrained(
                "ZhengPeng7/BiRefNet",
                trust_remote_code=True,
                cache_dir=full_model_path,
            )

            # Inference only: eval mode, on the GPU when one is available.
            model.eval()
            if torch.cuda.is_available():
                model = model.cuda()

            self.model = model
            self.model_cache[model_path] = model
            print("Model loaded successfully from Hugging Face")
            print(f"Model type: {type(self.model)}")
            print(f"Model device: {next(self.model.parameters()).device}")
            return True
        except Exception as e:
            print(f"Error loading model: {str(e)}")
            traceback.print_exc()
            return False

    def preprocess_image(self, image):
        """Turn ``image`` into a normalized 1024x1024 single-image batch.

        Args:
            image: A PIL image, or a tensor with 3 dimensions (optionally
                with a leading batch dimension of size 1) that
                ``ToPILImage`` can convert.

        Returns:
            The preprocessed tensor (moved to the GPU when CUDA is
            available), or None if preprocessing failed (the error is
            printed).
        """
        try:
            # Tensors are converted to PIL first so both input kinds share
            # the same resize/normalize pipeline.
            if isinstance(image, torch.Tensor):
                if image.dim() == 4:
                    image = image.squeeze(0)
                if image.dim() == 3:
                    image = transforms.ToPILImage()(image)

            transform_image = transforms.Compose([
                transforms.Resize((1024, 1024)),
                transforms.ToTensor(),
                # Standard ImageNet mean / std.
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
            ])

            image_tensor = transform_image(image).unsqueeze(0)
            if torch.cuda.is_available():
                image_tensor = image_tensor.cuda()
            return image_tensor
        except Exception as e:
            print(f"Error preprocessing image: {str(e)}")
            return None

    @staticmethod
    def _normalize_mask(mask):
        """Min-max scale ``mask``; a constant mask is returned unchanged.

        Skipping the division for a constant mask avoids 0/0 producing NaN.
        """
        lowest = torch.min(mask)
        span = torch.max(mask) - lowest
        if span <= 0:
            return mask
        return (mask - lowest) / span

    def execute(self, image, model_path, threshold=0.5, refinement=1):
        """Run matting on ``image`` and return the masked image and the mask.

        Status events ("processing", "completed", "error") are sent on the
        "matting_status" channel through ``PromptServer``.

        Args:
            image: A PIL image or a channel-first tensor ([C, H, W] or
                [1, C, H, W]); the last two tensor dimensions are taken as
                height and width.
            model_path: Cache key passed to ``load_model``.
            threshold: When greater than 0 the normalized mask is binarized
                with ``mask > threshold``; otherwise the soft mask is kept.
            refinement: Accepted for interface compatibility; currently
                unused.

        Returns:
            ``(masked_image, alpha_mask)`` where ``masked_image`` is the
            input multiplied by the mask and ``alpha_mask`` is the mask with
            two leading singleton dimensions.

        Raises:
            RuntimeError: If the model cannot be loaded.
            Exception: If preprocessing fails, or any error raised during
                inference (re-raised after the "error" status is sent).
        """
        try:
            PromptServer.instance.send_sync("matting_status", {"status": "processing"})

            if not self.load_model(model_path):
                raise RuntimeError("Failed to load model")

            # Remember (height, width) so the mask can be scaled back.
            if isinstance(image, torch.Tensor):
                original_size = image.shape[-2:]
            else:
                original_size = image.size[::-1]  # PIL size is (width, height)
            print(f"Original size: {original_size}")

            processed_image = self.preprocess_image(image)
            if processed_image is None:
                raise Exception("Failed to preprocess image")
            print(f"Processed image shape: {processed_image.shape}")

            with torch.no_grad():
                outputs = self.model(processed_image)
                result = outputs[-1].sigmoid().cpu()
            print(f"Model output shape: {result.shape}")

            # F.interpolate needs [B, C, H, W].
            if result.dim() == 3:
                result = result.unsqueeze(1)
            elif result.dim() == 2:
                result = result.unsqueeze(0).unsqueeze(0)
            print(f"Reshaped result shape: {result.shape}")

            result = F.interpolate(
                result,
                size=(original_size[0], original_size[1]),
                mode='bilinear',
                align_corners=True
            )
            print(f"Resized result shape: {result.shape}")

            # Drop only the leading batch/channel singleton dimensions; a
            # blanket squeeze() would also drop a height or width of 1.
            result = result.squeeze(0).squeeze(0)
            result = self._normalize_mask(result)

            if threshold > 0:
                result = (result > threshold).float()

            alpha_mask = result.unsqueeze(0).unsqueeze(0)
            if isinstance(image, torch.Tensor):
                if image.dim() == 3:
                    image = image.unsqueeze(0)
                # The mask was moved to the CPU above; follow the image's
                # device so the multiplication cannot hit a device mismatch.
                alpha_mask = alpha_mask.to(image.device)
                masked_image = image * alpha_mask
            else:
                image_tensor = transforms.ToTensor()(image).unsqueeze(0)
                masked_image = image_tensor * alpha_mask

            PromptServer.instance.send_sync("matting_status", {"status": "completed"})
            return (masked_image, alpha_mask)
        except Exception:
            PromptServer.instance.send_sync("matting_status", {"status": "error"})
            raise

    @classmethod
    def IS_CHANGED(cls, image, model_path, threshold, refinement):
        """Return an MD5 hex digest that changes when any input changes.

        Tensor images are hashed by shape, dtype and raw contents; hashing
        ``str(tensor)`` is unreliable because the printed form of a large
        tensor is elided.
        """
        m = hashlib.md5()
        if isinstance(image, torch.Tensor):
            data = image.detach().cpu().contiguous()
            m.update(str(tuple(data.shape)).encode())
            m.update(str(data.dtype).encode())
            try:
                raw = data.numpy().tobytes()
            except TypeError:
                # dtypes without a NumPy equivalent (e.g. bfloat16)
                raw = data.float().numpy().tobytes()
            m.update(raw)
        else:
            m.update(str(image).encode())
        m.update(str(model_path).encode())
        m.update(str(threshold).encode())
        m.update(str(refinement).encode())
        return m.hexdigest()
@PromptServer.instance.routes.post("/matting")
async def matting(request):
    """Handle ``POST /matting``: remove the background from an uploaded image.

    The JSON body must contain ``"image"`` (a base64 image, normally a data
    URL) and may contain ``"threshold"`` (default 0.5) and ``"refinement"``
    (default 1).

    Returns:
        A JSON response with ``"matted_image"`` and ``"alpha_mask"`` PNG data
        URLs on success; status 400 when ``"image"`` is missing; status 500
        with ``"error"`` and ``"details"`` (traceback) on any other failure.
    """
    import asyncio
    import functools

    try:
        print("Received matting request")
        data = await request.json()
        if not isinstance(data, dict) or "image" not in data:
            return web.json_response(
                {"error": "Request body must be a JSON object with an 'image' field"},
                status=400,
            )

        engine = BiRefNetMatting()

        # Decode the upload; the original alpha channel (if any) is kept so
        # it can be combined with the predicted mask in the result.
        image_tensor, original_alpha = convert_base64_to_tensor(data["image"])
        print(f"Input image shape: {image_tensor.shape}")

        # Model loading and inference are blocking and can take seconds; run
        # them in a worker thread so the event loop keeps serving other
        # requests and can deliver the status events while work is ongoing.
        # NOTE(review): execute() calls PromptServer.instance.send_sync from
        # that worker thread -- confirm send_sync is safe to call off-loop.
        run_matting = functools.partial(
            engine.execute,
            image_tensor,
            "BiRefNet/model.safetensors",
            threshold=data.get("threshold", 0.5),
            refinement=data.get("refinement", 1),
        )
        matted_image, alpha_mask = await asyncio.get_running_loop().run_in_executor(
            None, run_matting
        )

        # Encode the results; the matted image carries the combined alpha.
        result_image = convert_tensor_to_base64(matted_image, alpha_mask, original_alpha)
        result_mask = convert_tensor_to_base64(alpha_mask)

        return web.json_response({
            "matted_image": result_image,
            "alpha_mask": result_mask
        })
    except Exception as e:
        print(f"Error in matting endpoint: {str(e)}")
        traceback.print_exc()
        return web.json_response({
            "error": str(e),
            "details": traceback.format_exc()
        }, status=500)
def convert_base64_to_tensor(base64_str):
    """Decode a base64-encoded image into a tensor, keeping its alpha channel.

    Args:
        base64_str: Either a data URL (``"data:image/png;base64,<payload>"``)
            or the bare base64 payload.

    Returns:
        ``(img_tensor, alpha_tensor)``. ``img_tensor`` is the RGB image as
        ``[1, 3, H, W]``; RGBA input is composited onto a white background
        first. ``alpha_tensor`` is the original alpha channel as
        ``[1, 1, H, W]`` for RGBA input and None otherwise.

    Raises:
        Exception: Any decoding error is printed and re-raised.
    """
    import base64
    import io

    try:
        # Strip an optional "data:...;base64," prefix; a bare payload has no
        # comma and is used as-is instead of failing with IndexError.
        _, comma, payload = base64_str.partition(',')
        if not comma:
            payload = base64_str
        img = Image.open(io.BytesIO(base64.b64decode(payload)))

        alpha = None
        if img.mode == 'RGBA':
            # Keep the alpha channel and flatten the image onto white.
            alpha = img.split()[3]
            background = Image.new('RGB', img.size, (255, 255, 255))
            background.paste(img, mask=alpha)
            img = background
        elif img.mode != 'RGB':
            img = img.convert('RGB')

        to_tensor = transforms.ToTensor()
        img_tensor = to_tensor(img).unsqueeze(0)  # [1, C, H, W]
        alpha_tensor = None
        if alpha is not None:
            alpha_tensor = to_tensor(alpha).unsqueeze(0)  # [1, 1, H, W]
        return img_tensor, alpha_tensor
    except Exception as e:
        print(f"Error in convert_base64_to_tensor: {str(e)}")
        raise
def convert_tensor_to_base64(tensor, alpha_mask=None, original_alpha=None):
    """Encode an image tensor as a PNG data URL, optionally with alpha.

    Args:
        tensor: Image tensor with values in [0, 1], shaped ``[C, H, W]`` or
            ``[1, C, H, W]`` with C of 1 or 3. Out-of-range values are
            clamped.
        alpha_mask: Optional mask tensor. Only used together with
            ``original_alpha``.
        original_alpha: Optional alpha tensor of the source image. When both
            masks are given, the output is RGBA with alpha equal to the
            element-wise minimum of the two.

    Returns:
        A ``"data:image/png;base64,..."`` string. Without both masks the
        image is encoded as grayscale (1 channel) or RGB.

    Raises:
        Exception: Any conversion error is printed and re-raised.
    """
    import base64
    import io

    def to_uint8(values):
        # Clamp before scaling: casting an out-of-range float to uint8 wraps
        # around instead of saturating. detach() allows tensors that still
        # track gradients.
        values = values.detach().cpu()
        return (torch.clamp(values, 0, 1).numpy() * 255).astype(np.uint8)

    try:
        tensor = tensor.detach().cpu()

        # Bring the tensor to channel-last [H, W, C].
        if tensor.dim() == 4:
            tensor = tensor.squeeze(0)
        if tensor.dim() == 3 and tensor.shape[0] in [1, 3]:
            tensor = tensor.permute(1, 2, 0)

        img_array = to_uint8(tensor)

        if alpha_mask is not None and original_alpha is not None:
            # A pixel is only as opaque as both the predicted mask and the
            # source alpha allow.
            combined_alpha = np.minimum(
                to_uint8(alpha_mask.squeeze()),
                to_uint8(original_alpha.squeeze()),
            )
            img = Image.fromarray(img_array, mode='RGB')
            img.putalpha(Image.fromarray(combined_alpha, mode='L'))
        elif img_array.shape[-1] == 1:
            img = Image.fromarray(img_array.squeeze(-1), mode='L')
        else:
            img = Image.fromarray(img_array, mode='RGB')

        buffer = io.BytesIO()
        img.save(buffer, format='PNG')
        img_str = base64.b64encode(buffer.getvalue()).decode()
        return f"data:image/png;base64,{img_str}"
    except Exception as e:
        print(f"Error in convert_tensor_to_base64: {str(e)}")
        # getattr keeps this diagnostic from raising (and hiding the real
        # error) when ``tensor`` is not actually a tensor.
        print(f"Tensor shape: {getattr(tensor, 'shape', None)}, dtype: {getattr(tensor, 'dtype', None)}")
        raise