mirror of
https://github.com/Azornes/Comfyui-LayerForge.git
synced 2026-03-21 20:52:12 -03:00
Added Outpainting Logic
This commit is contained in:
357
canvas_node.py
357
canvas_node.py
@@ -17,39 +17,40 @@ import base64
|
||||
from PIL import Image
|
||||
import io
|
||||
|
||||
# Set the float32 matmul precision mode
|
||||
torch.set_float32_matmul_precision('high')
|
||||
|
||||
# Configuration class definition
|
||||
|
||||
class BiRefNetConfig(PretrainedConfig):
    """Configuration object for the ``BiRefNet`` model type.

    Stores the ``bb_pretrained`` flag on the instance and forwards every
    other keyword argument to ``PretrainedConfig``.
    """

    model_type = "BiRefNet"

    def __init__(self, bb_pretrained=False, **kwargs):
        """Record ``bb_pretrained`` and delegate the rest to the base class.

        Args:
            bb_pretrained: Flag stored verbatim as ``self.bb_pretrained``.
            **kwargs: Passed through unchanged to ``PretrainedConfig``.
        """
        # The custom field is assigned before the base initialiser runs,
        # matching the original statement order.
        self.bb_pretrained = bb_pretrained
        super().__init__(**kwargs)
|
||||
|
||||
# Model class definition
|
||||
|
||||
class BiRefNet(torch.nn.Module):
    """Small convolutional encoder/decoder network.

    The encoder is two 3x3 convolutions (3 -> 64 -> 64 channels), each
    followed by an in-place ReLU. The decoder is a 3x3 convolution
    (64 -> 32), an in-place ReLU and a 1x1 convolution (32 -> 1).
    """

    def __init__(self, config):
        """Build the encoder and decoder stacks.

        Args:
            config: Accepted for interface compatibility; not read here.
        """
        super().__init__()
        nn = torch.nn

        # Layers are created in the same order as before so parameter
        # initialisation and state-dict keys stay the same.
        encoder_layers = [
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        ]
        self.encoder = nn.Sequential(*encoder_layers)

        decoder_layers = [
            nn.Conv2d(64, 32, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 1, kernel_size=1),
        ]
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x):
        """Run the encoder then the decoder.

        Returns:
            A one-element list holding the decoder output.
        """
        return [self.decoder(self.encoder(x))]
|
||||
|
||||
|
||||
class CanvasNode:
|
||||
_canvas_cache = {
|
||||
'image': None,
|
||||
@@ -59,28 +60,26 @@ class CanvasNode:
|
||||
'persistent_cache': {},
|
||||
'last_execution_id': None
|
||||
}
|
||||
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.flow_id = str(uuid.uuid4())
|
||||
# 从持久化缓存恢复数据
|
||||
|
||||
if self.__class__._canvas_cache['persistent_cache']:
|
||||
self.restore_cache()
|
||||
|
||||
def restore_cache(self):
|
||||
"""从持久化缓存恢复数据,除非是新的执行"""
|
||||
try:
|
||||
persistent = self.__class__._canvas_cache['persistent_cache']
|
||||
current_execution = self.get_execution_id()
|
||||
|
||||
# 只有在新的执行ID时才清除缓存
|
||||
|
||||
if current_execution != self.__class__._canvas_cache['last_execution_id']:
|
||||
print(f"New execution detected: {current_execution}")
|
||||
self.__class__._canvas_cache['image'] = None
|
||||
self.__class__._canvas_cache['mask'] = None
|
||||
self.__class__._canvas_cache['last_execution_id'] = current_execution
|
||||
else:
|
||||
# 否则保留现有缓存
|
||||
|
||||
if persistent.get('image') is not None:
|
||||
self.__class__._canvas_cache['image'] = persistent['image']
|
||||
print("Restored image from persistent cache")
|
||||
@@ -91,16 +90,16 @@ class CanvasNode:
|
||||
print(f"Error restoring cache: {str(e)}")
|
||||
|
||||
def get_execution_id(self):
    """Return an identifier for the current workflow execution.

    The identifier is the current wall-clock time in whole milliseconds,
    rendered as a decimal string.

    Returns:
        The identifier string, or ``None`` if computing it raised.
    """
    # NOTE(review): the id is derived from the clock, so two calls made at
    # different moments yield different ids - confirm that is what callers
    # comparing it against a stored id expect.
    try:
        millis = int(time.time() * 1000)
        return str(millis)
    except Exception as e:
        print(f"Error getting execution ID: {str(e)}")
        return None
|
||||
|
||||
def update_persistent_cache(self):
|
||||
"""更新持久化缓存"""
|
||||
|
||||
try:
|
||||
self.__class__._canvas_cache['persistent_cache'] = {
|
||||
'image': self.__class__._canvas_cache['image'],
|
||||
@@ -111,7 +110,7 @@ class CanvasNode:
|
||||
print(f"Error updating persistent cache: {str(e)}")
|
||||
|
||||
def track_data_flow(self, stage, status, data_info=None):
|
||||
"""追踪数据流状态"""
|
||||
|
||||
flow_status = {
|
||||
'timestamp': time.time(),
|
||||
'stage': stage,
|
||||
@@ -121,7 +120,7 @@ class CanvasNode:
|
||||
print(f"Data Flow [{self.flow_id}] - Stage: {stage}, Status: {status}")
|
||||
if data_info:
|
||||
print(f"Data Info: {data_info}")
|
||||
|
||||
|
||||
self.__class__._canvas_cache['data_flow_status'][self.flow_id] = flow_status
|
||||
|
||||
@classmethod
|
||||
@@ -138,47 +137,43 @@ class CanvasNode:
|
||||
"input_mask": ("MASK",)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
RETURN_TYPES = ("IMAGE", "MASK")
|
||||
RETURN_NAMES = ("image", "mask")
|
||||
FUNCTION = "process_canvas_image"
|
||||
CATEGORY = "Ycanvas"
|
||||
|
||||
def add_image_to_canvas(self, input_image):
    """Normalise the layout of an input image tensor.

    Args:
        input_image: A ``torch.Tensor``; any other type is rejected.

    Returns:
        The tensor after the layout adjustments below, or ``None`` when
        validation or processing fails (the error is printed).
    """
    try:
        if not isinstance(input_image, torch.Tensor):
            raise ValueError("Input image must be a torch.Tensor")

        image = input_image

        # squeeze(0) removes the leading axis only when its size is 1.
        if image.dim() == 4:
            image = image.squeeze(0)

        # A 3-D tensor whose first axis has size 1 or 3 is treated as
        # channels-first and moved to channels-last.
        channels_first = image.dim() == 3 and image.shape[0] in [1, 3]
        return image.permute(1, 2, 0) if channels_first else image

    except Exception as e:
        print(f"Error in add_image_to_canvas: {str(e)}")
        return None
|
||||
|
||||
def add_mask_to_canvas(self, input_mask, input_image):
|
||||
"""处理输入遮罩"""
|
||||
|
||||
try:
|
||||
# 确保输入遮罩是正确的格式
|
||||
|
||||
if not isinstance(input_mask, torch.Tensor):
|
||||
raise ValueError("Input mask must be a torch.Tensor")
|
||||
|
||||
# 处理遮罩维度
|
||||
|
||||
if input_mask.dim() == 4:
|
||||
input_mask = input_mask.squeeze(0)
|
||||
if input_mask.dim() == 3 and input_mask.shape[0] == 1:
|
||||
input_mask = input_mask.squeeze(0)
|
||||
|
||||
# 确保遮罩尺寸与图像匹配
|
||||
|
||||
if input_image is not None:
|
||||
expected_shape = input_image.shape[:2]
|
||||
if input_mask.shape != expected_shape:
|
||||
@@ -188,60 +183,55 @@ class CanvasNode:
|
||||
mode='bilinear',
|
||||
align_corners=False
|
||||
).squeeze()
|
||||
|
||||
|
||||
return input_mask
|
||||
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error in add_mask_to_canvas: {str(e)}")
|
||||
return None
|
||||
|
||||
def process_canvas_image(self, canvas_image, trigger, output_switch, cache_enabled, input_image=None, input_mask=None):
|
||||
def process_canvas_image(self, canvas_image, trigger, output_switch, cache_enabled, input_image=None,
|
||||
input_mask=None):
|
||||
try:
|
||||
current_execution = self.get_execution_id()
|
||||
print(f"Processing canvas image, execution ID: {current_execution}")
|
||||
|
||||
# 检查是否是新的执行
|
||||
|
||||
if current_execution != self.__class__._canvas_cache['last_execution_id']:
|
||||
print(f"New execution detected: {current_execution}")
|
||||
# 清除旧的缓存
|
||||
|
||||
self.__class__._canvas_cache['image'] = None
|
||||
self.__class__._canvas_cache['mask'] = None
|
||||
self.__class__._canvas_cache['last_execution_id'] = current_execution
|
||||
|
||||
# 处理输入图像
|
||||
|
||||
if input_image is not None:
|
||||
print("Input image received, converting to PIL Image...")
|
||||
# 将tensor转换为PIL Image并存储到缓存
|
||||
|
||||
if isinstance(input_image, torch.Tensor):
|
||||
if input_image.dim() == 4:
|
||||
input_image = input_image.squeeze(0) # 移除batch维度
|
||||
|
||||
# 确保图像格式为[H, W, C]
|
||||
|
||||
if input_image.shape[0] == 3: # 如果是[C, H, W]格式
|
||||
input_image = input_image.permute(1, 2, 0)
|
||||
|
||||
# 转换为numpy数组并确保值范围在0-255
|
||||
|
||||
image_array = (input_image.cpu().numpy() * 255).astype(np.uint8)
|
||||
|
||||
# 确保数组形状正确
|
||||
|
||||
if len(image_array.shape) == 2: # 如果是灰度图
|
||||
image_array = np.stack([image_array] * 3, axis=-1)
|
||||
elif len(image_array.shape) == 3 and image_array.shape[-1] != 3:
|
||||
image_array = np.transpose(image_array, (1, 2, 0))
|
||||
|
||||
|
||||
try:
|
||||
# 转换为PIL Image
|
||||
|
||||
pil_image = Image.fromarray(image_array, 'RGB')
|
||||
print("Successfully converted to PIL Image")
|
||||
# 存储PIL Image到缓存
|
||||
|
||||
self.__class__._canvas_cache['image'] = pil_image
|
||||
print(f"Image stored in cache with size: {pil_image.size}")
|
||||
except Exception as e:
|
||||
print(f"Error converting to PIL Image: {str(e)}")
|
||||
print(f"Array shape: {image_array.shape}, dtype: {image_array.dtype}")
|
||||
raise
|
||||
|
||||
# 处理输入遮罩
|
||||
|
||||
if input_mask is not None:
|
||||
print("Input mask received, converting to PIL Image...")
|
||||
if isinstance(input_mask, torch.Tensor):
|
||||
@@ -249,20 +239,18 @@ class CanvasNode:
|
||||
input_mask = input_mask.squeeze(0)
|
||||
if input_mask.dim() == 3 and input_mask.shape[0] == 1:
|
||||
input_mask = input_mask.squeeze(0)
|
||||
|
||||
# 转换为PIL Image
|
||||
|
||||
mask_array = (input_mask.cpu().numpy() * 255).astype(np.uint8)
|
||||
pil_mask = Image.fromarray(mask_array, 'L')
|
||||
print("Successfully converted mask to PIL Image")
|
||||
# 存储遮罩到缓存
|
||||
|
||||
self.__class__._canvas_cache['mask'] = pil_mask
|
||||
print(f"Mask stored in cache with size: {pil_mask.size}")
|
||||
|
||||
# 更新缓存开关状态
|
||||
|
||||
self.__class__._canvas_cache['cache_enabled'] = cache_enabled
|
||||
|
||||
|
||||
try:
|
||||
# 尝试读取画布图像
|
||||
|
||||
path_image = folder_paths.get_annotated_filepath(canvas_image)
|
||||
i = Image.open(path_image)
|
||||
i = ImageOps.exif_transpose(i)
|
||||
@@ -275,47 +263,44 @@ class CanvasNode:
|
||||
image = rgb * alpha + (1 - alpha) * 0.5
|
||||
processed_image = torch.from_numpy(image)[None,]
|
||||
except Exception as e:
|
||||
# 如果读取失败,创建白色画布
|
||||
|
||||
processed_image = torch.ones((1, 512, 512, 3), dtype=torch.float32)
|
||||
|
||||
|
||||
try:
|
||||
# 尝试读取遮罩图像
|
||||
|
||||
path_mask = path_image.replace('.png', '_mask.png')
|
||||
if os.path.exists(path_mask):
|
||||
mask = Image.open(path_mask).convert('L')
|
||||
mask = np.array(mask).astype(np.float32) / 255.0
|
||||
processed_mask = torch.from_numpy(mask)[None,]
|
||||
else:
|
||||
# 如果没有遮罩文件,创建全白遮罩
|
||||
processed_mask = torch.ones((1, processed_image.shape[1], processed_image.shape[2]), dtype=torch.float32)
|
||||
|
||||
processed_mask = torch.ones((1, processed_image.shape[1], processed_image.shape[2]),
|
||||
dtype=torch.float32)
|
||||
except Exception as e:
|
||||
print(f"Error loading mask: {str(e)}")
|
||||
# 创建默认遮罩
|
||||
processed_mask = torch.ones((1, processed_image.shape[1], processed_image.shape[2]), dtype=torch.float32)
|
||||
|
||||
# 输出处理
|
||||
|
||||
processed_mask = torch.ones((1, processed_image.shape[1], processed_image.shape[2]),
|
||||
dtype=torch.float32)
|
||||
|
||||
if not output_switch:
|
||||
return ()
|
||||
|
||||
# 更新持久化缓存
|
||||
|
||||
self.update_persistent_cache()
|
||||
|
||||
# 返回处理后的图像和遮罩
|
||||
|
||||
return (processed_image, processed_mask)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error in process_canvas_image: {str(e)}")
|
||||
traceback.print_exc()
|
||||
return ()
|
||||
|
||||
# Accessor for the cached data
|
||||
def get_cached_data(self):
|
||||
return {
|
||||
'image': self.__class__._canvas_cache['image'],
|
||||
'mask': self.__class__._canvas_cache['mask']
|
||||
}
|
||||
|
||||
# API route handlers
|
||||
@classmethod
|
||||
def api_get_data(cls, node_id):
|
||||
try:
|
||||
@@ -329,9 +314,23 @@ class CanvasNode:
|
||||
'error': str(e)
|
||||
}
|
||||
|
||||
@classmethod
def get_latest_image(cls):
    """Return the path of the most recently created image in the output directory.

    Only regular files whose lower-cased path ends in .png, .jpg, .jpeg,
    .bmp or .gif are considered; recency is decided by ``os.path.getctime``.

    Returns:
        The full path of the newest image file, or ``None`` when the
        directory contains no matching file.
    """
    output_dir = folder_paths.get_output_directory()
    image_suffixes = ('.png', '.jpg', '.jpeg', '.bmp', '.gif')

    candidates = []
    for entry in os.listdir(output_dir):
        full_path = os.path.join(output_dir, entry)
        if os.path.isfile(full_path) and full_path.lower().endswith(image_suffixes):
            candidates.append(full_path)

    if not candidates:
        return None

    return max(candidates, key=os.path.getctime)
|
||||
|
||||
@classmethod
|
||||
def get_flow_status(cls, flow_id=None):
|
||||
"""获取数据流状态"""
|
||||
|
||||
if flow_id:
|
||||
return cls._canvas_cache['data_flow_status'].get(flow_id)
|
||||
return cls._canvas_cache['data_flow_status']
|
||||
@@ -343,11 +342,11 @@ class CanvasNode:
|
||||
try:
|
||||
node_id = request.match_info["node_id"]
|
||||
print(f"Received request for node: {node_id}")
|
||||
|
||||
|
||||
cache_data = cls._canvas_cache
|
||||
print(f"Cache content: {cache_data}")
|
||||
print(f"Image in cache: {cache_data['image'] is not None}")
|
||||
|
||||
|
||||
response_data = {
|
||||
'success': True,
|
||||
'data': {
|
||||
@@ -355,23 +354,23 @@ class CanvasNode:
|
||||
'mask': None
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
if cache_data['image'] is not None:
|
||||
pil_image = cache_data['image']
|
||||
buffered = io.BytesIO()
|
||||
pil_image.save(buffered, format="PNG")
|
||||
img_str = base64.b64encode(buffered.getvalue()).decode()
|
||||
response_data['data']['image'] = f"data:image/png;base64,{img_str}"
|
||||
|
||||
|
||||
if cache_data['mask'] is not None:
|
||||
pil_mask = cache_data['mask']
|
||||
mask_buffer = io.BytesIO()
|
||||
pil_mask.save(mask_buffer, format="PNG")
|
||||
mask_str = base64.b64encode(mask_buffer.getvalue()).decode()
|
||||
response_data['data']['mask'] = f"data:image/png;base64,{mask_str}"
|
||||
|
||||
|
||||
return web.json_response(response_data)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error in get_canvas_data: {str(e)}")
|
||||
return web.json_response({
|
||||
@@ -379,17 +378,39 @@ class CanvasNode:
|
||||
'error': str(e)
|
||||
})
|
||||
|
||||
@PromptServer.instance.routes.get("/ycnode/get_latest_image")
async def get_latest_image_route(request):
    """Serve the newest output image as a base64 data URL.

    Responds with ``{'success': True, 'image_data': <data URL>}`` when an
    image exists, a 404 JSON error when ``cls.get_latest_image()`` returns
    nothing, and a 500 JSON error carrying the exception text on failure.
    """
    # Imported locally so this handler does not depend on a file-level import.
    import mimetypes

    try:
        latest_image_path = cls.get_latest_image()
        if not latest_image_path:
            return web.json_response({
                'success': False,
                'error': 'No images found in output directory.'
            }, status=404)

        with open(latest_image_path, "rb") as f:
            encoded_string = base64.b64encode(f.read()).decode('utf-8')

        # The newest file may be a JPEG, BMP or GIF, so label the payload
        # with its real type instead of always claiming PNG. Fall back to
        # PNG when the type cannot be guessed from the file name.
        mime_type, _ = mimetypes.guess_type(latest_image_path)
        if not mime_type or not mime_type.startswith('image/'):
            mime_type = 'image/png'

        return web.json_response({
            'success': True,
            'image_data': f"data:{mime_type};base64,{encoded_string}"
        })
    except Exception as e:
        return web.json_response({
            'success': False,
            'error': str(e)
        }, status=500)
|
||||
|
||||
def store_image(self, image_data):
|
||||
# 将base64数据转换为PIL Image并存储
|
||||
|
||||
if isinstance(image_data, str) and image_data.startswith('data:image'):
|
||||
image_data = image_data.split(',')[1]
|
||||
image_bytes = base64.b64decode(image_data)
|
||||
self.cached_image = Image.open(io.BytesIO(image_bytes))
|
||||
else:
|
||||
self.cached_image = image_data
|
||||
|
||||
|
||||
def get_cached_image(self):
|
||||
# 将PIL Image转换为base64
|
||||
|
||||
if self.cached_image:
|
||||
buffered = io.BytesIO()
|
||||
self.cached_image.save(buffered, format="PNG")
|
||||
@@ -397,78 +418,77 @@ class CanvasNode:
|
||||
return f"data:image/png;base64,{img_str}"
|
||||
return None
|
||||
|
||||
|
||||
class BiRefNetMatting:
|
||||
def __init__(self):
|
||||
self.model = None
|
||||
self.model_path = None
|
||||
self.model_cache = {}
|
||||
# 使用 ComfyUI models 目录
|
||||
self.base_path = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))), "models")
|
||||
|
||||
self.base_path = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))),
|
||||
"models")
|
||||
|
||||
def load_model(self, model_path):
|
||||
try:
|
||||
if model_path not in self.model_cache:
|
||||
# 使用 ComfyUI models 目录下的 BiRefNet 路径
|
||||
|
||||
full_model_path = os.path.join(self.base_path, "BiRefNet")
|
||||
|
||||
|
||||
print(f"Loading BiRefNet model from {full_model_path}...")
|
||||
|
||||
|
||||
try:
|
||||
# 直接从Hugging Face加载
|
||||
|
||||
self.model = AutoModelForImageSegmentation.from_pretrained(
|
||||
"ZhengPeng7/BiRefNet",
|
||||
trust_remote_code=True,
|
||||
cache_dir=full_model_path # 使用本地缓存目录
|
||||
cache_dir=full_model_path
|
||||
)
|
||||
|
||||
# 设置为评估模式并移动到GPU
|
||||
|
||||
self.model.eval()
|
||||
if torch.cuda.is_available():
|
||||
self.model = self.model.cuda()
|
||||
|
||||
|
||||
self.model_cache[model_path] = self.model
|
||||
print("Model loaded successfully from Hugging Face")
|
||||
print(f"Model type: {type(self.model)}")
|
||||
print(f"Model device: {next(self.model.parameters()).device}")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
print(f"Failed to load model: {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
else:
|
||||
self.model = self.model_cache[model_path]
|
||||
print("Using cached model")
|
||||
|
||||
|
||||
return True
|
||||
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error loading model: {str(e)}")
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
def preprocess_image(self, image):
|
||||
"""预处理输入图像"""
|
||||
|
||||
try:
|
||||
# 转换为PIL图像
|
||||
|
||||
if isinstance(image, torch.Tensor):
|
||||
if image.dim() == 4:
|
||||
image = image.squeeze(0)
|
||||
if image.dim() == 3:
|
||||
image = transforms.ToPILImage()(image)
|
||||
|
||||
# 参考nodes.py的预处理
|
||||
|
||||
transform_image = transforms.Compose([
|
||||
transforms.Resize((1024, 1024)),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
|
||||
])
|
||||
|
||||
# 转换为tensor并添加batch维度
|
||||
|
||||
image_tensor = transform_image(image).unsqueeze(0)
|
||||
|
||||
|
||||
if torch.cuda.is_available():
|
||||
image_tensor = image_tensor.cuda()
|
||||
|
||||
|
||||
return image_tensor
|
||||
except Exception as e:
|
||||
print(f"Error preprocessing image: {str(e)}")
|
||||
@@ -476,43 +496,37 @@ class BiRefNetMatting:
|
||||
|
||||
def execute(self, image, model_path, threshold=0.5, refinement=1):
|
||||
try:
|
||||
# 发送开始状态
|
||||
|
||||
PromptServer.instance.send_sync("matting_status", {"status": "processing"})
|
||||
|
||||
# 加载模型
|
||||
|
||||
if not self.load_model(model_path):
|
||||
raise RuntimeError("Failed to load model")
|
||||
|
||||
# 获取原始尺寸
|
||||
|
||||
if isinstance(image, torch.Tensor):
|
||||
original_size = image.shape[-2:] if image.dim() == 4 else image.shape[-2:]
|
||||
else:
|
||||
original_size = image.size[::-1]
|
||||
|
||||
|
||||
print(f"Original size: {original_size}")
|
||||
|
||||
# 预处理图像
|
||||
|
||||
processed_image = self.preprocess_image(image)
|
||||
if processed_image is None:
|
||||
raise Exception("Failed to preprocess image")
|
||||
|
||||
|
||||
print(f"Processed image shape: {processed_image.shape}")
|
||||
|
||||
# 执行推理
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = self.model(processed_image)
|
||||
result = outputs[-1].sigmoid().cpu()
|
||||
print(f"Model output shape: {result.shape}")
|
||||
|
||||
# 确保结果有正的维度格式 [B, C, H, W]
|
||||
|
||||
if result.dim() == 3:
|
||||
result = result.unsqueeze(1) # 添加通道维度
|
||||
elif result.dim() == 2:
|
||||
result = result.unsqueeze(0).unsqueeze(0) # 添加batch和通道维度
|
||||
|
||||
|
||||
print(f"Reshaped result shape: {result.shape}")
|
||||
|
||||
# 调整大小
|
||||
|
||||
result = F.interpolate(
|
||||
result,
|
||||
size=(original_size[0], original_size[1]), # 明确指定高度和宽度
|
||||
@@ -520,18 +534,15 @@ class BiRefNetMatting:
|
||||
align_corners=True
|
||||
)
|
||||
print(f"Resized result shape: {result.shape}")
|
||||
|
||||
# 归一化
|
||||
|
||||
result = result.squeeze() # 移除多余的维度
|
||||
ma = torch.max(result)
|
||||
mi = torch.min(result)
|
||||
result = (result-mi)/(ma-mi)
|
||||
|
||||
# 应用阈值
|
||||
result = (result - mi) / (ma - mi)
|
||||
|
||||
if threshold > 0:
|
||||
result = (result > threshold).float()
|
||||
|
||||
# 创建mask和结果图像
|
||||
|
||||
alpha_mask = result.unsqueeze(0).unsqueeze(0) # 确保mask是 [1, 1, H, W]
|
||||
if isinstance(image, torch.Tensor):
|
||||
if image.dim() == 3:
|
||||
@@ -540,20 +551,19 @@ class BiRefNetMatting:
|
||||
else:
|
||||
image_tensor = transforms.ToTensor()(image).unsqueeze(0)
|
||||
masked_image = image_tensor * alpha_mask
|
||||
|
||||
# 发送完成状态
|
||||
|
||||
PromptServer.instance.send_sync("matting_status", {"status": "completed"})
|
||||
|
||||
|
||||
return (masked_image, alpha_mask)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
# 发送错误状态
|
||||
|
||||
PromptServer.instance.send_sync("matting_status", {"status": "error"})
|
||||
raise e
|
||||
|
||||
@classmethod
|
||||
def IS_CHANGED(cls, image, model_path, threshold, refinement):
|
||||
"""检查输入是否改变"""
|
||||
|
||||
m = hashlib.md5()
|
||||
m.update(str(image).encode())
|
||||
m.update(str(model_path).encode())
|
||||
@@ -561,36 +571,33 @@ class BiRefNetMatting:
|
||||
m.update(str(refinement).encode())
|
||||
return m.hexdigest()
|
||||
|
||||
|
||||
@PromptServer.instance.routes.post("/matting")
|
||||
async def matting(request):
|
||||
try:
|
||||
print("Received matting request")
|
||||
data = await request.json()
|
||||
|
||||
# 取BiRefNet实例
|
||||
|
||||
matting = BiRefNetMatting()
|
||||
|
||||
# 处理图像数据,现在返回图像tensor和alpha通道
|
||||
|
||||
image_tensor, original_alpha = convert_base64_to_tensor(data["image"])
|
||||
print(f"Input image shape: {image_tensor.shape}")
|
||||
|
||||
# 执行抠图
|
||||
|
||||
matted_image, alpha_mask = matting.execute(
|
||||
image_tensor,
|
||||
image_tensor,
|
||||
"BiRefNet/model.safetensors",
|
||||
threshold=data.get("threshold", 0.5),
|
||||
refinement=data.get("refinement", 1)
|
||||
)
|
||||
|
||||
# 转换结果为base64,包含原始alpha信息
|
||||
|
||||
result_image = convert_tensor_to_base64(matted_image, alpha_mask, original_alpha)
|
||||
result_mask = convert_tensor_to_base64(alpha_mask)
|
||||
|
||||
|
||||
return web.json_response({
|
||||
"matted_image": result_image,
|
||||
"alpha_mask": result_mask
|
||||
})
|
||||
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error in matting endpoint: {str(e)}")
|
||||
import traceback
|
||||
@@ -600,93 +607,83 @@ async def matting(request):
|
||||
"details": traceback.format_exc()
|
||||
}, status=500)
|
||||
|
||||
|
||||
def convert_base64_to_tensor(base64_str):
    """Decode a base64-encoded image into a tensor, keeping any alpha channel.

    Args:
        base64_str: Either a data URL (``"data:image/...;base64,<payload>"``)
            or a bare base64 payload with no header.

    Returns:
        A tuple ``(img_tensor, alpha_tensor)``. ``img_tensor`` is the RGB
        image passed through ``transforms.ToTensor()`` with a leading batch
        axis added. ``alpha_tensor`` is the original alpha channel converted
        the same way when the decoded image mode is ``'RGBA'``, else ``None``.

    Raises:
        Exception: Any decoding or conversion error is printed and re-raised.
    """
    import base64
    import io

    try:
        # Strip the data-URL header when present. A bare payload has no
        # comma, which previously caused an IndexError.
        _, separator, payload = base64_str.partition(',')
        if not separator:
            payload = base64_str

        img = Image.open(io.BytesIO(base64.b64decode(payload)))

        alpha = None
        if img.mode == 'RGBA':
            # Keep the original alpha, then flatten onto a white background.
            alpha = img.split()[3]
            background = Image.new('RGB', img.size, (255, 255, 255))
            background.paste(img, mask=alpha)
            img = background
        elif img.mode != 'RGB':
            img = img.convert('RGB')

        to_tensor = transforms.ToTensor()
        img_tensor = to_tensor(img).unsqueeze(0)  # batch axis added in front

        if alpha is not None:
            alpha_tensor = to_tensor(alpha).unsqueeze(0)
            return img_tensor, alpha_tensor

        return img_tensor, None

    except Exception as e:
        print(f"Error in convert_base64_to_tensor: {str(e)}")
        raise
|
||||
|
||||
|
||||
def convert_tensor_to_base64(tensor, alpha_mask=None, original_alpha=None):
|
||||
"""将tensor转换为base64图像数据,支持alpha通道"""
|
||||
import base64
|
||||
import io
|
||||
|
||||
|
||||
try:
|
||||
# 确保tensor在CPU上
|
||||
|
||||
tensor = tensor.cpu()
|
||||
|
||||
# 处理维度
|
||||
|
||||
if tensor.dim() == 4:
|
||||
tensor = tensor.squeeze(0) # 移除batch维度
|
||||
if tensor.dim() == 3 and tensor.shape[0] in [1, 3]:
|
||||
tensor = tensor.permute(1, 2, 0)
|
||||
|
||||
# 转换为numpy数组并调整值范围到0-255
|
||||
|
||||
img_array = (tensor.numpy() * 255).astype(np.uint8)
|
||||
|
||||
# 如果有alpha遮罩和原始alpha
|
||||
|
||||
if alpha_mask is not None and original_alpha is not None:
|
||||
# 将alpha_mask转换为正确的格式
|
||||
|
||||
alpha_mask = alpha_mask.cpu().squeeze().numpy()
|
||||
alpha_mask = (alpha_mask * 255).astype(np.uint8)
|
||||
|
||||
# 将原始alpha转换为正确的格式
|
||||
|
||||
original_alpha = original_alpha.cpu().squeeze().numpy()
|
||||
original_alpha = (original_alpha * 255).astype(np.uint8)
|
||||
|
||||
# 组合alpha_mask和original_alpha
|
||||
|
||||
combined_alpha = np.minimum(alpha_mask, original_alpha)
|
||||
|
||||
# 创建RGBA图像
|
||||
|
||||
img = Image.fromarray(img_array, mode='RGB')
|
||||
alpha_img = Image.fromarray(combined_alpha, mode='L')
|
||||
img.putalpha(alpha_img)
|
||||
else:
|
||||
# 处理没有alpha通道的情况
|
||||
|
||||
if img_array.shape[-1] == 1:
|
||||
img_array = img_array.squeeze(-1)
|
||||
img = Image.fromarray(img_array, mode='L')
|
||||
else:
|
||||
img = Image.fromarray(img_array, mode='RGB')
|
||||
|
||||
# 转换为base64
|
||||
|
||||
buffer = io.BytesIO()
|
||||
img.save(buffer, format='PNG')
|
||||
img_str = base64.b64encode(buffer.getvalue()).decode()
|
||||
|
||||
|
||||
return f"data:image/png;base64,{img_str}"
|
||||
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error in convert_tensor_to_base64: {str(e)}")
|
||||
print(f"Tensor shape: {tensor.shape}, dtype: {tensor.dtype}")
|
||||
|
||||
Reference in New Issue
Block a user