Update canvas_node.py

This commit is contained in:
tanglup
2024-11-24 10:34:20 +08:00
committed by GitHub
parent 6e020ca3b8
commit af0a32a4f9

View File

@@ -11,6 +11,11 @@ from torchvision import transforms
from transformers import AutoModelForImageSegmentation, PretrainedConfig
import torch.nn.functional as F
import traceback
import uuid
import time
import base64
from PIL import Image
import io
# 设置高精度计算
torch.set_float32_matmul_precision('high')
@@ -45,16 +50,92 @@ class BiRefNet(torch.nn.Module):
output = self.decoder(features)
return [output]
class CanvasView:
class CanvasNode:
_canvas_cache = {
'image': None,
'mask': None,
'cache_enabled': True,
'data_flow_status': {},
'persistent_cache': {},
'last_execution_id': None
}
def __init__(self):
super().__init__()
self.flow_id = str(uuid.uuid4())
# 从持久化缓存恢复数据
if self.__class__._canvas_cache['persistent_cache']:
self.restore_cache()
def restore_cache(self):
"""从持久化缓存恢复数据,除非是新的执行"""
try:
persistent = self.__class__._canvas_cache['persistent_cache']
current_execution = self.get_execution_id()
# 只有在新的执行ID时才清除缓存
if current_execution != self.__class__._canvas_cache['last_execution_id']:
print(f"New execution detected: {current_execution}")
self.__class__._canvas_cache['image'] = None
self.__class__._canvas_cache['mask'] = None
self.__class__._canvas_cache['last_execution_id'] = current_execution
else:
# 否则保留现有缓存
if persistent.get('image') is not None:
self.__class__._canvas_cache['image'] = persistent['image']
print("Restored image from persistent cache")
if persistent.get('mask') is not None:
self.__class__._canvas_cache['mask'] = persistent['mask']
print("Restored mask from persistent cache")
except Exception as e:
print(f"Error restoring cache: {str(e)}")
def get_execution_id(self):
"""获取当前工作流执行ID"""
try:
# 可以使用时间戳或其他唯一标识
return str(int(time.time() * 1000))
except Exception as e:
print(f"Error getting execution ID: {str(e)}")
return None
def update_persistent_cache(self):
"""更新持久化缓存"""
try:
self.__class__._canvas_cache['persistent_cache'] = {
'image': self.__class__._canvas_cache['image'],
'mask': self.__class__._canvas_cache['mask']
}
print("Updated persistent cache")
except Exception as e:
print(f"Error updating persistent cache: {str(e)}")
def track_data_flow(self, stage, status, data_info=None):
"""追踪数据流状态"""
flow_status = {
'timestamp': time.time(),
'stage': stage,
'status': status,
'data_info': data_info
}
print(f"Data Flow [{self.flow_id}] - Stage: {stage}, Status: {status}")
if data_info:
print(f"Data Info: {data_info}")
self.__class__._canvas_cache['data_flow_status'][self.flow_id] = flow_status
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"canvas_image": ("STRING", {"default": "canvas_image.png"}),
"trigger": ("INT", {"default": 0, "min": 0, "max": 99999999, "step": 1})
"trigger": ("INT", {"default": 0, "min": 0, "max": 99999999, "step": 1, "hidden": True}),
"output_switch": ("BOOLEAN", {"default": True}),
"cache_enabled": ("BOOLEAN", {"default": True, "label": "Enable Cache"})
},
"hidden": {
"unique_id": "UNIQUE_ID"
"optional": {
"input_image": ("IMAGE",),
"input_mask": ("MASK",)
}
}
@@ -63,42 +144,258 @@ class CanvasView:
FUNCTION = "process_canvas_image"
CATEGORY = "ycnode"
def process_canvas_image(self, canvas_image, trigger, unique_id):
def add_image_to_canvas(self, input_image):
"""处理输入图像"""
try:
# 读取保存的画布图像和遮
path_image = folder_paths.get_annotated_filepath(canvas_image)
path_mask = folder_paths.get_annotated_filepath(canvas_image.replace('.png', '_mask.png'))
# 确保输入图像是正确的格式
if not isinstance(input_image, torch.Tensor):
raise ValueError("Input image must be a torch.Tensor")
# 处理图像
i = Image.open(path_image)
i = ImageOps.exif_transpose(i)
if i.mode not in ['RGB', 'RGBA']:
i = i.convert('RGB')
image = np.array(i).astype(np.float32) / 255.0
if i.mode == 'RGBA':
rgb = image[..., :3]
alpha = image[..., 3:]
image = rgb * alpha + (1 - alpha) * 0.5
# 处理图像维度
if input_image.dim() == 4:
input_image = input_image.squeeze(0)
# 处理遮罩图像
try:
mask = Image.open(path_mask).convert('L')
mask = np.array(mask).astype(np.float32) / 255.0
mask = torch.from_numpy(mask)[None,]
except:
# 如果没有遮罩文件,创建空白遮罩
mask = torch.zeros((1, image.shape[0], image.shape[1]), dtype=torch.float32)
# 转换为标准格式
if input_image.dim() == 3 and input_image.shape[0] in [1, 3]:
input_image = input_image.permute(1, 2, 0)
# 转换为tensor
image = torch.from_numpy(image)[None,]
return input_image
return (image, mask)
except Exception as e:
print(f"Error processing canvas image: {str(e)}")
# 回白色图像和空白遮罩
blank = np.ones((512, 512, 3), dtype=np.float32)
blank_mask = np.zeros((512, 512), dtype=np.float32)
return (torch.from_numpy(blank)[None,], torch.from_numpy(blank_mask)[None,])
print(f"Error in add_image_to_canvas: {str(e)}")
return None
def add_mask_to_canvas(self, input_mask, input_image):
"""处理输入遮罩"""
try:
# 确保输入遮罩是正确的格式
if not isinstance(input_mask, torch.Tensor):
raise ValueError("Input mask must be a torch.Tensor")
# 处理遮罩维度
if input_mask.dim() == 4:
input_mask = input_mask.squeeze(0)
if input_mask.dim() == 3 and input_mask.shape[0] == 1:
input_mask = input_mask.squeeze(0)
# 确保遮罩尺寸与图像匹配
if input_image is not None:
expected_shape = input_image.shape[:2]
if input_mask.shape != expected_shape:
input_mask = F.interpolate(
input_mask.unsqueeze(0).unsqueeze(0),
size=expected_shape,
mode='bilinear',
align_corners=False
).squeeze()
return input_mask
except Exception as e:
print(f"Error in add_mask_to_canvas: {str(e)}")
return None
def process_canvas_image(self, canvas_image, trigger, output_switch, cache_enabled, input_image=None, input_mask=None):
try:
current_execution = self.get_execution_id()
print(f"Processing canvas image, execution ID: {current_execution}")
# 检查是否是新的执行
if current_execution != self.__class__._canvas_cache['last_execution_id']:
print(f"New execution detected: {current_execution}")
# 清除旧的缓存
self.__class__._canvas_cache['image'] = None
self.__class__._canvas_cache['mask'] = None
self.__class__._canvas_cache['last_execution_id'] = current_execution
# 处理输入图像
if input_image is not None:
print("Input image received, converting to PIL Image...")
# 将tensor转换为PIL Image并存储到缓存
if isinstance(input_image, torch.Tensor):
if input_image.dim() == 4:
input_image = input_image.squeeze(0) # 移除batch维度
# 确保图像格式为[H, W, C]
if input_image.shape[0] == 3: # 如果是[C, H, W]格式
input_image = input_image.permute(1, 2, 0)
# 转换为numpy数组并确保值范围在0-255
image_array = (input_image.cpu().numpy() * 255).astype(np.uint8)
# 确保数组形状正确
if len(image_array.shape) == 2: # 如果是灰度图
image_array = np.stack([image_array] * 3, axis=-1)
elif len(image_array.shape) == 3 and image_array.shape[-1] != 3:
image_array = np.transpose(image_array, (1, 2, 0))
try:
# 转换为PIL Image
pil_image = Image.fromarray(image_array, 'RGB')
print("Successfully converted to PIL Image")
# 存储PIL Image到缓存
self.__class__._canvas_cache['image'] = pil_image
print(f"Image stored in cache with size: {pil_image.size}")
except Exception as e:
print(f"Error converting to PIL Image: {str(e)}")
print(f"Array shape: {image_array.shape}, dtype: {image_array.dtype}")
raise
# 处理输入遮罩
if input_mask is not None:
print("Input mask received, converting to PIL Image...")
if isinstance(input_mask, torch.Tensor):
if input_mask.dim() == 4:
input_mask = input_mask.squeeze(0)
if input_mask.dim() == 3 and input_mask.shape[0] == 1:
input_mask = input_mask.squeeze(0)
# 转换为PIL Image
mask_array = (input_mask.cpu().numpy() * 255).astype(np.uint8)
pil_mask = Image.fromarray(mask_array, 'L')
print("Successfully converted mask to PIL Image")
# 存储遮罩到缓存
self.__class__._canvas_cache['mask'] = pil_mask
print(f"Mask stored in cache with size: {pil_mask.size}")
# 更新缓存开关状态
self.__class__._canvas_cache['cache_enabled'] = cache_enabled
try:
# 尝试读取画布图像
path_image = folder_paths.get_annotated_filepath(canvas_image)
i = Image.open(path_image)
i = ImageOps.exif_transpose(i)
if i.mode not in ['RGB', 'RGBA']:
i = i.convert('RGB')
image = np.array(i).astype(np.float32) / 255.0
if i.mode == 'RGBA':
rgb = image[..., :3]
alpha = image[..., 3:]
image = rgb * alpha + (1 - alpha) * 0.5
processed_image = torch.from_numpy(image)[None,]
except Exception as e:
# 如果读取失败,创建白色画布
processed_image = torch.ones((1, 512, 512, 3), dtype=torch.float32)
try:
# 尝试读取遮罩图像
path_mask = path_image.replace('.png', '_mask.png')
if os.path.exists(path_mask):
mask = Image.open(path_mask).convert('L')
mask = np.array(mask).astype(np.float32) / 255.0
processed_mask = torch.from_numpy(mask)[None,]
else:
# 如果没有遮罩文件,创建全白遮罩
processed_mask = torch.ones((1, processed_image.shape[1], processed_image.shape[2]), dtype=torch.float32)
except Exception as e:
print(f"Error loading mask: {str(e)}")
# 创建默认遮罩
processed_mask = torch.ones((1, processed_image.shape[1], processed_image.shape[2]), dtype=torch.float32)
# 输出处理
if not output_switch:
return ()
# 更新持久化缓存
self.update_persistent_cache()
# 返回处理后的图像和遮罩
return (processed_image, processed_mask)
except Exception as e:
print(f"Error in process_canvas_image: {str(e)}")
traceback.print_exc()
return ()
# 添加获取缓存数据的方法
def get_cached_data(self):
return {
'image': self.__class__._canvas_cache['image'],
'mask': self.__class__._canvas_cache['mask']
}
# 添加API路由处理器
@classmethod
def api_get_data(cls, node_id):
try:
return {
'success': True,
'data': cls._canvas_cache
}
except Exception as e:
return {
'success': False,
'error': str(e)
}
@classmethod
def get_flow_status(cls, flow_id=None):
"""获取数据流状态"""
if flow_id:
return cls._canvas_cache['data_flow_status'].get(flow_id)
return cls._canvas_cache['data_flow_status']
@classmethod
def setup_routes(cls):
@PromptServer.instance.routes.get("/ycnode/get_canvas_data/{node_id}")
async def get_canvas_data(request):
try:
node_id = request.match_info["node_id"]
print(f"Received request for node: {node_id}")
cache_data = cls._canvas_cache
print(f"Cache content: {cache_data}")
print(f"Image in cache: {cache_data['image'] is not None}")
response_data = {
'success': True,
'data': {
'image': None,
'mask': None
}
}
if cache_data['image'] is not None:
pil_image = cache_data['image']
buffered = io.BytesIO()
pil_image.save(buffered, format="PNG")
img_str = base64.b64encode(buffered.getvalue()).decode()
response_data['data']['image'] = f"data:image/png;base64,{img_str}"
if cache_data['mask'] is not None:
pil_mask = cache_data['mask']
mask_buffer = io.BytesIO()
pil_mask.save(mask_buffer, format="PNG")
mask_str = base64.b64encode(mask_buffer.getvalue()).decode()
response_data['data']['mask'] = f"data:image/png;base64,{mask_str}"
return web.json_response(response_data)
except Exception as e:
print(f"Error in get_canvas_data: {str(e)}")
return web.json_response({
'success': False,
'error': str(e)
})
def store_image(self, image_data):
# 将base64数据转换为PIL Image并存储
if isinstance(image_data, str) and image_data.startswith('data:image'):
image_data = image_data.split(',')[1]
image_bytes = base64.b64decode(image_data)
self.cached_image = Image.open(io.BytesIO(image_bytes))
else:
self.cached_image = image_data
def get_cached_image(self):
# 将PIL Image转换为base64
if self.cached_image:
buffered = io.BytesIO()
self.cached_image.save(buffered, format="PNG")
img_str = base64.b64encode(buffered.getvalue()).decode()
return f"data:image/png;base64,{img_str}"
return None
class BiRefNetMatting:
def __init__(self):
@@ -270,7 +567,7 @@ async def matting(request):
print("Received matting request")
data = await request.json()
# 取BiRefNet实例
# 取BiRefNet实例
matting = BiRefNetMatting()
# 处理图像数据,现在返回图像tensor和alpha通道