mirror of
https://github.com/Azornes/Comfyui-LayerForge.git
synced 2026-03-25 22:35:43 -03:00
Added Outpainting Logic
This commit is contained in:
4
.github/workflows/publish.yml
vendored
4
.github/workflows/publish.yml
vendored
@@ -12,7 +12,7 @@ jobs:
|
|||||||
publish-node:
|
publish-node:
|
||||||
name: Publish Custom Node to registry
|
name: Publish Custom Node to registry
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
# if this is a forked repository. Skipping the workflow.
|
|
||||||
if: github.event.repository.fork == false
|
if: github.event.repository.fork == false
|
||||||
steps:
|
steps:
|
||||||
- name: Check out code
|
- name: Check out code
|
||||||
@@ -20,5 +20,5 @@ jobs:
|
|||||||
- name: Publish Custom Node
|
- name: Publish Custom Node
|
||||||
uses: Comfy-Org/publish-node-action@main
|
uses: Comfy-Org/publish-node-action@main
|
||||||
with:
|
with:
|
||||||
## Add your own personal access token to your Github Repository secrets and reference it here.
|
|
||||||
personal_access_token: ${{ secrets.REGISTRY_ACCESS_TOKEN }}
|
personal_access_token: ${{ secrets.REGISTRY_ACCESS_TOKEN }}
|
||||||
|
|||||||
12
README.md
12
README.md
@@ -1,4 +1,4 @@
|
|||||||
# Comfyui-Ycnode
|
|
||||||
**Canvas Node**
|
**Canvas Node**
|
||||||
|
|
||||||
**1**. Basic Operations:
|
**1**. Basic Operations:
|
||||||
@@ -27,21 +27,21 @@ Model Name: models--ZhengPeng7--BiRefNet
|
|||||||
|
|
||||||
The cloud disk link is as follows
|
The cloud disk link is as follows
|
||||||
|
|
||||||
baidu Link:https://pan.baidu.com/s/1PiZvuHcdlcZGoL7WDYnMkA?pwd=nt76
|
baidu Link:https:
|
||||||
google link: https://drive.google.com/drive/folders/1BCLInCLH89fmTpYoP8Sgs_Eqww28f_wq?usp=sharing
|
google link: https:
|
||||||
|
|
||||||
Place it in: models/BiRefNet
|
Place it in: models/BiRefNet
|
||||||
|
|
||||||
2024/11/24 Updated Features:
|
2024/11/24 Updated Features:
|
||||||
Add input images and masks; add blending mode options for images in the canvas (you can select them by selecting the image and then shift+clicking the image to pop up the menu)
|
Add input images and masks; add blending mode options for images in the canvas (you can select them by selecting the image and then shift+clicking the image to pop up the menu)
|
||||||
Note: The output blending mode does not change, and needs to be updated by slightly changing the canvas content
|
Note: The output blending mode does not change, and needs to be updated by slightly changing the canvas content
|
||||||

|

|

|

|
CanvasNode.setup_routes()
|
||||||
|
|
||||||
NODE_CLASS_MAPPINGS = {
|
NODE_CLASS_MAPPINGS = {
|
||||||
|
|||||||
183
canvas_node.py
183
canvas_node.py
@@ -17,21 +17,21 @@ import base64
|
|||||||
from PIL import Image
|
from PIL import Image
|
||||||
import io
|
import io
|
||||||
|
|
||||||
# 设置高精度计算
|
|
||||||
torch.set_float32_matmul_precision('high')
|
torch.set_float32_matmul_precision('high')
|
||||||
|
|
||||||
# 定义配置类
|
|
||||||
class BiRefNetConfig(PretrainedConfig):
|
class BiRefNetConfig(PretrainedConfig):
|
||||||
model_type = "BiRefNet"
|
model_type = "BiRefNet"
|
||||||
|
|
||||||
def __init__(self, bb_pretrained=False, **kwargs):
|
def __init__(self, bb_pretrained=False, **kwargs):
|
||||||
self.bb_pretrained = bb_pretrained
|
self.bb_pretrained = bb_pretrained
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
# 定义模型类
|
|
||||||
class BiRefNet(torch.nn.Module):
|
class BiRefNet(torch.nn.Module):
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
# 基本网络结构
|
|
||||||
self.encoder = torch.nn.Sequential(
|
self.encoder = torch.nn.Sequential(
|
||||||
torch.nn.Conv2d(3, 64, kernel_size=3, padding=1),
|
torch.nn.Conv2d(3, 64, kernel_size=3, padding=1),
|
||||||
torch.nn.ReLU(inplace=True),
|
torch.nn.ReLU(inplace=True),
|
||||||
@@ -50,6 +50,7 @@ class BiRefNet(torch.nn.Module):
|
|||||||
output = self.decoder(features)
|
output = self.decoder(features)
|
||||||
return [output]
|
return [output]
|
||||||
|
|
||||||
|
|
||||||
class CanvasNode:
|
class CanvasNode:
|
||||||
_canvas_cache = {
|
_canvas_cache = {
|
||||||
'image': None,
|
'image': None,
|
||||||
@@ -63,24 +64,22 @@ class CanvasNode:
|
|||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.flow_id = str(uuid.uuid4())
|
self.flow_id = str(uuid.uuid4())
|
||||||
# 从持久化缓存恢复数据
|
|
||||||
if self.__class__._canvas_cache['persistent_cache']:
|
if self.__class__._canvas_cache['persistent_cache']:
|
||||||
self.restore_cache()
|
self.restore_cache()
|
||||||
|
|
||||||
def restore_cache(self):
|
def restore_cache(self):
|
||||||
"""从持久化缓存恢复数据,除非是新的执行"""
|
|
||||||
try:
|
try:
|
||||||
persistent = self.__class__._canvas_cache['persistent_cache']
|
persistent = self.__class__._canvas_cache['persistent_cache']
|
||||||
current_execution = self.get_execution_id()
|
current_execution = self.get_execution_id()
|
||||||
|
|
||||||
# 只有在新的执行ID时才清除缓存
|
|
||||||
if current_execution != self.__class__._canvas_cache['last_execution_id']:
|
if current_execution != self.__class__._canvas_cache['last_execution_id']:
|
||||||
print(f"New execution detected: {current_execution}")
|
print(f"New execution detected: {current_execution}")
|
||||||
self.__class__._canvas_cache['image'] = None
|
self.__class__._canvas_cache['image'] = None
|
||||||
self.__class__._canvas_cache['mask'] = None
|
self.__class__._canvas_cache['mask'] = None
|
||||||
self.__class__._canvas_cache['last_execution_id'] = current_execution
|
self.__class__._canvas_cache['last_execution_id'] = current_execution
|
||||||
else:
|
else:
|
||||||
# 否则保留现有缓存
|
|
||||||
if persistent.get('image') is not None:
|
if persistent.get('image') is not None:
|
||||||
self.__class__._canvas_cache['image'] = persistent['image']
|
self.__class__._canvas_cache['image'] = persistent['image']
|
||||||
print("Restored image from persistent cache")
|
print("Restored image from persistent cache")
|
||||||
@@ -91,16 +90,16 @@ class CanvasNode:
|
|||||||
print(f"Error restoring cache: {str(e)}")
|
print(f"Error restoring cache: {str(e)}")
|
||||||
|
|
||||||
def get_execution_id(self):
|
def get_execution_id(self):
|
||||||
"""获取当前工作流执行ID"""
|
|
||||||
try:
|
try:
|
||||||
# 可以使用时间戳或其他唯一标识
|
|
||||||
return str(int(time.time() * 1000))
|
return str(int(time.time() * 1000))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error getting execution ID: {str(e)}")
|
print(f"Error getting execution ID: {str(e)}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def update_persistent_cache(self):
|
def update_persistent_cache(self):
|
||||||
"""更新持久化缓存"""
|
|
||||||
try:
|
try:
|
||||||
self.__class__._canvas_cache['persistent_cache'] = {
|
self.__class__._canvas_cache['persistent_cache'] = {
|
||||||
'image': self.__class__._canvas_cache['image'],
|
'image': self.__class__._canvas_cache['image'],
|
||||||
@@ -111,7 +110,7 @@ class CanvasNode:
|
|||||||
print(f"Error updating persistent cache: {str(e)}")
|
print(f"Error updating persistent cache: {str(e)}")
|
||||||
|
|
||||||
def track_data_flow(self, stage, status, data_info=None):
|
def track_data_flow(self, stage, status, data_info=None):
|
||||||
"""追踪数据流状态"""
|
|
||||||
flow_status = {
|
flow_status = {
|
||||||
'timestamp': time.time(),
|
'timestamp': time.time(),
|
||||||
'stage': stage,
|
'stage': stage,
|
||||||
@@ -145,17 +144,15 @@ class CanvasNode:
|
|||||||
CATEGORY = "Ycanvas"
|
CATEGORY = "Ycanvas"
|
||||||
|
|
||||||
def add_image_to_canvas(self, input_image):
|
def add_image_to_canvas(self, input_image):
|
||||||
"""处理输入图像"""
|
|
||||||
try:
|
try:
|
||||||
# 确保输入图像是正确的格式
|
|
||||||
if not isinstance(input_image, torch.Tensor):
|
if not isinstance(input_image, torch.Tensor):
|
||||||
raise ValueError("Input image must be a torch.Tensor")
|
raise ValueError("Input image must be a torch.Tensor")
|
||||||
|
|
||||||
# 处理图像维度
|
|
||||||
if input_image.dim() == 4:
|
if input_image.dim() == 4:
|
||||||
input_image = input_image.squeeze(0)
|
input_image = input_image.squeeze(0)
|
||||||
|
|
||||||
# 转换为标准格式
|
|
||||||
if input_image.dim() == 3 and input_image.shape[0] in [1, 3]:
|
if input_image.dim() == 3 and input_image.shape[0] in [1, 3]:
|
||||||
input_image = input_image.permute(1, 2, 0)
|
input_image = input_image.permute(1, 2, 0)
|
||||||
|
|
||||||
@@ -166,19 +163,17 @@ class CanvasNode:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
def add_mask_to_canvas(self, input_mask, input_image):
|
def add_mask_to_canvas(self, input_mask, input_image):
|
||||||
"""处理输入遮罩"""
|
|
||||||
try:
|
try:
|
||||||
# 确保输入遮罩是正确的格式
|
|
||||||
if not isinstance(input_mask, torch.Tensor):
|
if not isinstance(input_mask, torch.Tensor):
|
||||||
raise ValueError("Input mask must be a torch.Tensor")
|
raise ValueError("Input mask must be a torch.Tensor")
|
||||||
|
|
||||||
# 处理遮罩维度
|
|
||||||
if input_mask.dim() == 4:
|
if input_mask.dim() == 4:
|
||||||
input_mask = input_mask.squeeze(0)
|
input_mask = input_mask.squeeze(0)
|
||||||
if input_mask.dim() == 3 and input_mask.shape[0] == 1:
|
if input_mask.dim() == 3 and input_mask.shape[0] == 1:
|
||||||
input_mask = input_mask.squeeze(0)
|
input_mask = input_mask.squeeze(0)
|
||||||
|
|
||||||
# 确保遮罩尺寸与图像匹配
|
|
||||||
if input_image is not None:
|
if input_image is not None:
|
||||||
expected_shape = input_image.shape[:2]
|
expected_shape = input_image.shape[:2]
|
||||||
if input_mask.shape != expected_shape:
|
if input_mask.shape != expected_shape:
|
||||||
@@ -195,45 +190,41 @@ class CanvasNode:
|
|||||||
print(f"Error in add_mask_to_canvas: {str(e)}")
|
print(f"Error in add_mask_to_canvas: {str(e)}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def process_canvas_image(self, canvas_image, trigger, output_switch, cache_enabled, input_image=None, input_mask=None):
|
def process_canvas_image(self, canvas_image, trigger, output_switch, cache_enabled, input_image=None,
|
||||||
|
input_mask=None):
|
||||||
try:
|
try:
|
||||||
current_execution = self.get_execution_id()
|
current_execution = self.get_execution_id()
|
||||||
print(f"Processing canvas image, execution ID: {current_execution}")
|
print(f"Processing canvas image, execution ID: {current_execution}")
|
||||||
|
|
||||||
# 检查是否是新的执行
|
|
||||||
if current_execution != self.__class__._canvas_cache['last_execution_id']:
|
if current_execution != self.__class__._canvas_cache['last_execution_id']:
|
||||||
print(f"New execution detected: {current_execution}")
|
print(f"New execution detected: {current_execution}")
|
||||||
# 清除旧的缓存
|
|
||||||
self.__class__._canvas_cache['image'] = None
|
self.__class__._canvas_cache['image'] = None
|
||||||
self.__class__._canvas_cache['mask'] = None
|
self.__class__._canvas_cache['mask'] = None
|
||||||
self.__class__._canvas_cache['last_execution_id'] = current_execution
|
self.__class__._canvas_cache['last_execution_id'] = current_execution
|
||||||
|
|
||||||
# 处理输入图像
|
|
||||||
if input_image is not None:
|
if input_image is not None:
|
||||||
print("Input image received, converting to PIL Image...")
|
print("Input image received, converting to PIL Image...")
|
||||||
# 将tensor转换为PIL Image并存储到缓存
|
|
||||||
if isinstance(input_image, torch.Tensor):
|
if isinstance(input_image, torch.Tensor):
|
||||||
if input_image.dim() == 4:
|
if input_image.dim() == 4:
|
||||||
input_image = input_image.squeeze(0) # 移除batch维度
|
input_image = input_image.squeeze(0) # 移除batch维度
|
||||||
|
|
||||||
# 确保图像格式为[H, W, C]
|
|
||||||
if input_image.shape[0] == 3: # 如果是[C, H, W]格式
|
if input_image.shape[0] == 3: # 如果是[C, H, W]格式
|
||||||
input_image = input_image.permute(1, 2, 0)
|
input_image = input_image.permute(1, 2, 0)
|
||||||
|
|
||||||
# 转换为numpy数组并确保值范围在0-255
|
|
||||||
image_array = (input_image.cpu().numpy() * 255).astype(np.uint8)
|
image_array = (input_image.cpu().numpy() * 255).astype(np.uint8)
|
||||||
|
|
||||||
# 确保数组形状正确
|
|
||||||
if len(image_array.shape) == 2: # 如果是灰度图
|
if len(image_array.shape) == 2: # 如果是灰度图
|
||||||
image_array = np.stack([image_array] * 3, axis=-1)
|
image_array = np.stack([image_array] * 3, axis=-1)
|
||||||
elif len(image_array.shape) == 3 and image_array.shape[-1] != 3:
|
elif len(image_array.shape) == 3 and image_array.shape[-1] != 3:
|
||||||
image_array = np.transpose(image_array, (1, 2, 0))
|
image_array = np.transpose(image_array, (1, 2, 0))
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 转换为PIL Image
|
|
||||||
pil_image = Image.fromarray(image_array, 'RGB')
|
pil_image = Image.fromarray(image_array, 'RGB')
|
||||||
print("Successfully converted to PIL Image")
|
print("Successfully converted to PIL Image")
|
||||||
# 存储PIL Image到缓存
|
|
||||||
self.__class__._canvas_cache['image'] = pil_image
|
self.__class__._canvas_cache['image'] = pil_image
|
||||||
print(f"Image stored in cache with size: {pil_image.size}")
|
print(f"Image stored in cache with size: {pil_image.size}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -241,7 +232,6 @@ class CanvasNode:
|
|||||||
print(f"Array shape: {image_array.shape}, dtype: {image_array.dtype}")
|
print(f"Array shape: {image_array.shape}, dtype: {image_array.dtype}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
# 处理输入遮罩
|
|
||||||
if input_mask is not None:
|
if input_mask is not None:
|
||||||
print("Input mask received, converting to PIL Image...")
|
print("Input mask received, converting to PIL Image...")
|
||||||
if isinstance(input_mask, torch.Tensor):
|
if isinstance(input_mask, torch.Tensor):
|
||||||
@@ -250,19 +240,17 @@ class CanvasNode:
|
|||||||
if input_mask.dim() == 3 and input_mask.shape[0] == 1:
|
if input_mask.dim() == 3 and input_mask.shape[0] == 1:
|
||||||
input_mask = input_mask.squeeze(0)
|
input_mask = input_mask.squeeze(0)
|
||||||
|
|
||||||
# 转换为PIL Image
|
|
||||||
mask_array = (input_mask.cpu().numpy() * 255).astype(np.uint8)
|
mask_array = (input_mask.cpu().numpy() * 255).astype(np.uint8)
|
||||||
pil_mask = Image.fromarray(mask_array, 'L')
|
pil_mask = Image.fromarray(mask_array, 'L')
|
||||||
print("Successfully converted mask to PIL Image")
|
print("Successfully converted mask to PIL Image")
|
||||||
# 存储遮罩到缓存
|
|
||||||
self.__class__._canvas_cache['mask'] = pil_mask
|
self.__class__._canvas_cache['mask'] = pil_mask
|
||||||
print(f"Mask stored in cache with size: {pil_mask.size}")
|
print(f"Mask stored in cache with size: {pil_mask.size}")
|
||||||
|
|
||||||
# 更新缓存开关状态
|
|
||||||
self.__class__._canvas_cache['cache_enabled'] = cache_enabled
|
self.__class__._canvas_cache['cache_enabled'] = cache_enabled
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 尝试读取画布图像
|
|
||||||
path_image = folder_paths.get_annotated_filepath(canvas_image)
|
path_image = folder_paths.get_annotated_filepath(canvas_image)
|
||||||
i = Image.open(path_image)
|
i = Image.open(path_image)
|
||||||
i = ImageOps.exif_transpose(i)
|
i = ImageOps.exif_transpose(i)
|
||||||
@@ -275,32 +263,31 @@ class CanvasNode:
|
|||||||
image = rgb * alpha + (1 - alpha) * 0.5
|
image = rgb * alpha + (1 - alpha) * 0.5
|
||||||
processed_image = torch.from_numpy(image)[None,]
|
processed_image = torch.from_numpy(image)[None,]
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# 如果读取失败,创建白色画布
|
|
||||||
processed_image = torch.ones((1, 512, 512, 3), dtype=torch.float32)
|
processed_image = torch.ones((1, 512, 512, 3), dtype=torch.float32)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 尝试读取遮罩图像
|
|
||||||
path_mask = path_image.replace('.png', '_mask.png')
|
path_mask = path_image.replace('.png', '_mask.png')
|
||||||
if os.path.exists(path_mask):
|
if os.path.exists(path_mask):
|
||||||
mask = Image.open(path_mask).convert('L')
|
mask = Image.open(path_mask).convert('L')
|
||||||
mask = np.array(mask).astype(np.float32) / 255.0
|
mask = np.array(mask).astype(np.float32) / 255.0
|
||||||
processed_mask = torch.from_numpy(mask)[None,]
|
processed_mask = torch.from_numpy(mask)[None,]
|
||||||
else:
|
else:
|
||||||
# 如果没有遮罩文件,创建全白遮罩
|
|
||||||
processed_mask = torch.ones((1, processed_image.shape[1], processed_image.shape[2]), dtype=torch.float32)
|
processed_mask = torch.ones((1, processed_image.shape[1], processed_image.shape[2]),
|
||||||
|
dtype=torch.float32)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error loading mask: {str(e)}")
|
print(f"Error loading mask: {str(e)}")
|
||||||
# 创建默认遮罩
|
|
||||||
processed_mask = torch.ones((1, processed_image.shape[1], processed_image.shape[2]), dtype=torch.float32)
|
|
||||||
|
|
||||||
# 输出处理
|
processed_mask = torch.ones((1, processed_image.shape[1], processed_image.shape[2]),
|
||||||
|
dtype=torch.float32)
|
||||||
|
|
||||||
if not output_switch:
|
if not output_switch:
|
||||||
return ()
|
return ()
|
||||||
|
|
||||||
# 更新持久化缓存
|
|
||||||
self.update_persistent_cache()
|
self.update_persistent_cache()
|
||||||
|
|
||||||
# 返回处理后的图像和遮罩
|
|
||||||
return (processed_image, processed_mask)
|
return (processed_image, processed_mask)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -308,14 +295,12 @@ class CanvasNode:
|
|||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
return ()
|
return ()
|
||||||
|
|
||||||
# 添加获取缓存数据的方法
|
|
||||||
def get_cached_data(self):
|
def get_cached_data(self):
|
||||||
return {
|
return {
|
||||||
'image': self.__class__._canvas_cache['image'],
|
'image': self.__class__._canvas_cache['image'],
|
||||||
'mask': self.__class__._canvas_cache['mask']
|
'mask': self.__class__._canvas_cache['mask']
|
||||||
}
|
}
|
||||||
|
|
||||||
# 添加API路由处理器
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def api_get_data(cls, node_id):
|
def api_get_data(cls, node_id):
|
||||||
try:
|
try:
|
||||||
@@ -329,9 +314,23 @@ class CanvasNode:
|
|||||||
'error': str(e)
|
'error': str(e)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_latest_image(cls):
|
||||||
|
output_dir = folder_paths.get_output_directory()
|
||||||
|
files = [os.path.join(output_dir, f) for f in os.listdir(output_dir) if
|
||||||
|
os.path.isfile(os.path.join(output_dir, f))]
|
||||||
|
|
||||||
|
image_files = [f for f in files if f.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp', '.gif'))]
|
||||||
|
|
||||||
|
if not image_files:
|
||||||
|
return None
|
||||||
|
|
||||||
|
latest_image_path = max(image_files, key=os.path.getctime)
|
||||||
|
return latest_image_path
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_flow_status(cls, flow_id=None):
|
def get_flow_status(cls, flow_id=None):
|
||||||
"""获取数据流状态"""
|
|
||||||
if flow_id:
|
if flow_id:
|
||||||
return cls._canvas_cache['data_flow_status'].get(flow_id)
|
return cls._canvas_cache['data_flow_status'].get(flow_id)
|
||||||
return cls._canvas_cache['data_flow_status']
|
return cls._canvas_cache['data_flow_status']
|
||||||
@@ -379,8 +378,30 @@ class CanvasNode:
|
|||||||
'error': str(e)
|
'error': str(e)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@PromptServer.instance.routes.get("/ycnode/get_latest_image")
|
||||||
|
async def get_latest_image_route(request):
|
||||||
|
try:
|
||||||
|
latest_image_path = cls.get_latest_image()
|
||||||
|
if latest_image_path:
|
||||||
|
with open(latest_image_path, "rb") as f:
|
||||||
|
encoded_string = base64.b64encode(f.read()).decode('utf-8')
|
||||||
|
return web.json_response({
|
||||||
|
'success': True,
|
||||||
|
'image_data': f"data:image/png;base64,{encoded_string}"
|
||||||
|
})
|
||||||
|
else:
|
||||||
|
return web.json_response({
|
||||||
|
'success': False,
|
||||||
|
'error': 'No images found in output directory.'
|
||||||
|
}, status=404)
|
||||||
|
except Exception as e:
|
||||||
|
return web.json_response({
|
||||||
|
'success': False,
|
||||||
|
'error': str(e)
|
||||||
|
}, status=500)
|
||||||
|
|
||||||
def store_image(self, image_data):
|
def store_image(self, image_data):
|
||||||
# 将base64数据转换为PIL Image并存储
|
|
||||||
if isinstance(image_data, str) and image_data.startswith('data:image'):
|
if isinstance(image_data, str) and image_data.startswith('data:image'):
|
||||||
image_data = image_data.split(',')[1]
|
image_data = image_data.split(',')[1]
|
||||||
image_bytes = base64.b64decode(image_data)
|
image_bytes = base64.b64decode(image_data)
|
||||||
@@ -389,7 +410,7 @@ class CanvasNode:
|
|||||||
self.cached_image = image_data
|
self.cached_image = image_data
|
||||||
|
|
||||||
def get_cached_image(self):
|
def get_cached_image(self):
|
||||||
# 将PIL Image转换为base64
|
|
||||||
if self.cached_image:
|
if self.cached_image:
|
||||||
buffered = io.BytesIO()
|
buffered = io.BytesIO()
|
||||||
self.cached_image.save(buffered, format="PNG")
|
self.cached_image.save(buffered, format="PNG")
|
||||||
@@ -397,31 +418,32 @@ class CanvasNode:
|
|||||||
return f"data:image/png;base64,{img_str}"
|
return f"data:image/png;base64,{img_str}"
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
class BiRefNetMatting:
|
class BiRefNetMatting:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.model = None
|
self.model = None
|
||||||
self.model_path = None
|
self.model_path = None
|
||||||
self.model_cache = {}
|
self.model_cache = {}
|
||||||
# 使用 ComfyUI models 目录
|
|
||||||
self.base_path = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))), "models")
|
self.base_path = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))),
|
||||||
|
"models")
|
||||||
|
|
||||||
def load_model(self, model_path):
|
def load_model(self, model_path):
|
||||||
try:
|
try:
|
||||||
if model_path not in self.model_cache:
|
if model_path not in self.model_cache:
|
||||||
# 使用 ComfyUI models 目录下的 BiRefNet 路径
|
|
||||||
full_model_path = os.path.join(self.base_path, "BiRefNet")
|
full_model_path = os.path.join(self.base_path, "BiRefNet")
|
||||||
|
|
||||||
print(f"Loading BiRefNet model from {full_model_path}...")
|
print(f"Loading BiRefNet model from {full_model_path}...")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 直接从Hugging Face加载
|
|
||||||
self.model = AutoModelForImageSegmentation.from_pretrained(
|
self.model = AutoModelForImageSegmentation.from_pretrained(
|
||||||
"ZhengPeng7/BiRefNet",
|
"ZhengPeng7/BiRefNet",
|
||||||
trust_remote_code=True,
|
trust_remote_code=True,
|
||||||
cache_dir=full_model_path # 使用本地缓存目录
|
cache_dir=full_model_path
|
||||||
)
|
)
|
||||||
|
|
||||||
# 设置为评估模式并移动到GPU
|
|
||||||
self.model.eval()
|
self.model.eval()
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
self.model = self.model.cuda()
|
self.model = self.model.cuda()
|
||||||
@@ -447,23 +469,21 @@ class BiRefNetMatting:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
def preprocess_image(self, image):
|
def preprocess_image(self, image):
|
||||||
"""预处理输入图像"""
|
|
||||||
try:
|
try:
|
||||||
# 转换为PIL图像
|
|
||||||
if isinstance(image, torch.Tensor):
|
if isinstance(image, torch.Tensor):
|
||||||
if image.dim() == 4:
|
if image.dim() == 4:
|
||||||
image = image.squeeze(0)
|
image = image.squeeze(0)
|
||||||
if image.dim() == 3:
|
if image.dim() == 3:
|
||||||
image = transforms.ToPILImage()(image)
|
image = transforms.ToPILImage()(image)
|
||||||
|
|
||||||
# 参考nodes.py的预处理
|
|
||||||
transform_image = transforms.Compose([
|
transform_image = transforms.Compose([
|
||||||
transforms.Resize((1024, 1024)),
|
transforms.Resize((1024, 1024)),
|
||||||
transforms.ToTensor(),
|
transforms.ToTensor(),
|
||||||
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
|
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
|
||||||
])
|
])
|
||||||
|
|
||||||
# 转换为tensor并添加batch维度
|
|
||||||
image_tensor = transform_image(image).unsqueeze(0)
|
image_tensor = transform_image(image).unsqueeze(0)
|
||||||
|
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
@@ -476,14 +496,12 @@ class BiRefNetMatting:
|
|||||||
|
|
||||||
def execute(self, image, model_path, threshold=0.5, refinement=1):
|
def execute(self, image, model_path, threshold=0.5, refinement=1):
|
||||||
try:
|
try:
|
||||||
# 发送开始状态
|
|
||||||
PromptServer.instance.send_sync("matting_status", {"status": "processing"})
|
PromptServer.instance.send_sync("matting_status", {"status": "processing"})
|
||||||
|
|
||||||
# 加载模型
|
|
||||||
if not self.load_model(model_path):
|
if not self.load_model(model_path):
|
||||||
raise RuntimeError("Failed to load model")
|
raise RuntimeError("Failed to load model")
|
||||||
|
|
||||||
# 获取原始尺寸
|
|
||||||
if isinstance(image, torch.Tensor):
|
if isinstance(image, torch.Tensor):
|
||||||
original_size = image.shape[-2:] if image.dim() == 4 else image.shape[-2:]
|
original_size = image.shape[-2:] if image.dim() == 4 else image.shape[-2:]
|
||||||
else:
|
else:
|
||||||
@@ -491,20 +509,17 @@ class BiRefNetMatting:
|
|||||||
|
|
||||||
print(f"Original size: {original_size}")
|
print(f"Original size: {original_size}")
|
||||||
|
|
||||||
# 预处理图像
|
|
||||||
processed_image = self.preprocess_image(image)
|
processed_image = self.preprocess_image(image)
|
||||||
if processed_image is None:
|
if processed_image is None:
|
||||||
raise Exception("Failed to preprocess image")
|
raise Exception("Failed to preprocess image")
|
||||||
|
|
||||||
print(f"Processed image shape: {processed_image.shape}")
|
print(f"Processed image shape: {processed_image.shape}")
|
||||||
|
|
||||||
# 执行推理
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
outputs = self.model(processed_image)
|
outputs = self.model(processed_image)
|
||||||
result = outputs[-1].sigmoid().cpu()
|
result = outputs[-1].sigmoid().cpu()
|
||||||
print(f"Model output shape: {result.shape}")
|
print(f"Model output shape: {result.shape}")
|
||||||
|
|
||||||
# 确保结果有正的维度格式 [B, C, H, W]
|
|
||||||
if result.dim() == 3:
|
if result.dim() == 3:
|
||||||
result = result.unsqueeze(1) # 添加通道维度
|
result = result.unsqueeze(1) # 添加通道维度
|
||||||
elif result.dim() == 2:
|
elif result.dim() == 2:
|
||||||
@@ -512,7 +527,6 @@ class BiRefNetMatting:
|
|||||||
|
|
||||||
print(f"Reshaped result shape: {result.shape}")
|
print(f"Reshaped result shape: {result.shape}")
|
||||||
|
|
||||||
# 调整大小
|
|
||||||
result = F.interpolate(
|
result = F.interpolate(
|
||||||
result,
|
result,
|
||||||
size=(original_size[0], original_size[1]), # 明确指定高度和宽度
|
size=(original_size[0], original_size[1]), # 明确指定高度和宽度
|
||||||
@@ -521,17 +535,14 @@ class BiRefNetMatting:
|
|||||||
)
|
)
|
||||||
print(f"Resized result shape: {result.shape}")
|
print(f"Resized result shape: {result.shape}")
|
||||||
|
|
||||||
# 归一化
|
|
||||||
result = result.squeeze() # 移除多余的维度
|
result = result.squeeze() # 移除多余的维度
|
||||||
ma = torch.max(result)
|
ma = torch.max(result)
|
||||||
mi = torch.min(result)
|
mi = torch.min(result)
|
||||||
result = (result - mi) / (ma - mi)
|
result = (result - mi) / (ma - mi)
|
||||||
|
|
||||||
# 应用阈值
|
|
||||||
if threshold > 0:
|
if threshold > 0:
|
||||||
result = (result > threshold).float()
|
result = (result > threshold).float()
|
||||||
|
|
||||||
# 创建mask和结果图像
|
|
||||||
alpha_mask = result.unsqueeze(0).unsqueeze(0) # 确保mask是 [1, 1, H, W]
|
alpha_mask = result.unsqueeze(0).unsqueeze(0) # 确保mask是 [1, 1, H, W]
|
||||||
if isinstance(image, torch.Tensor):
|
if isinstance(image, torch.Tensor):
|
||||||
if image.dim() == 3:
|
if image.dim() == 3:
|
||||||
@@ -541,19 +552,18 @@ class BiRefNetMatting:
|
|||||||
image_tensor = transforms.ToTensor()(image).unsqueeze(0)
|
image_tensor = transforms.ToTensor()(image).unsqueeze(0)
|
||||||
masked_image = image_tensor * alpha_mask
|
masked_image = image_tensor * alpha_mask
|
||||||
|
|
||||||
# 发送完成状态
|
|
||||||
PromptServer.instance.send_sync("matting_status", {"status": "completed"})
|
PromptServer.instance.send_sync("matting_status", {"status": "completed"})
|
||||||
|
|
||||||
return (masked_image, alpha_mask)
|
return (masked_image, alpha_mask)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# 发送错误状态
|
|
||||||
PromptServer.instance.send_sync("matting_status", {"status": "error"})
|
PromptServer.instance.send_sync("matting_status", {"status": "error"})
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def IS_CHANGED(cls, image, model_path, threshold, refinement):
|
def IS_CHANGED(cls, image, model_path, threshold, refinement):
|
||||||
"""检查输入是否改变"""
|
|
||||||
m = hashlib.md5()
|
m = hashlib.md5()
|
||||||
m.update(str(image).encode())
|
m.update(str(image).encode())
|
||||||
m.update(str(model_path).encode())
|
m.update(str(model_path).encode())
|
||||||
@@ -561,20 +571,18 @@ class BiRefNetMatting:
|
|||||||
m.update(str(refinement).encode())
|
m.update(str(refinement).encode())
|
||||||
return m.hexdigest()
|
return m.hexdigest()
|
||||||
|
|
||||||
|
|
||||||
@PromptServer.instance.routes.post("/matting")
|
@PromptServer.instance.routes.post("/matting")
|
||||||
async def matting(request):
|
async def matting(request):
|
||||||
try:
|
try:
|
||||||
print("Received matting request")
|
print("Received matting request")
|
||||||
data = await request.json()
|
data = await request.json()
|
||||||
|
|
||||||
# 取BiRefNet实例
|
|
||||||
matting = BiRefNetMatting()
|
matting = BiRefNetMatting()
|
||||||
|
|
||||||
# 处理图像数据,现在返回图像tensor和alpha通道
|
|
||||||
image_tensor, original_alpha = convert_base64_to_tensor(data["image"])
|
image_tensor, original_alpha = convert_base64_to_tensor(data["image"])
|
||||||
print(f"Input image shape: {image_tensor.shape}")
|
print(f"Input image shape: {image_tensor.shape}")
|
||||||
|
|
||||||
# 执行抠图
|
|
||||||
matted_image, alpha_mask = matting.execute(
|
matted_image, alpha_mask = matting.execute(
|
||||||
image_tensor,
|
image_tensor,
|
||||||
"BiRefNet/model.safetensors",
|
"BiRefNet/model.safetensors",
|
||||||
@@ -582,7 +590,6 @@ async def matting(request):
|
|||||||
refinement=data.get("refinement", 1)
|
refinement=data.get("refinement", 1)
|
||||||
)
|
)
|
||||||
|
|
||||||
# 转换结果为base64,包含原始alpha信息
|
|
||||||
result_image = convert_tensor_to_base64(matted_image, alpha_mask, original_alpha)
|
result_image = convert_tensor_to_base64(matted_image, alpha_mask, original_alpha)
|
||||||
result_mask = convert_tensor_to_base64(alpha_mask)
|
result_mask = convert_tensor_to_base64(alpha_mask)
|
||||||
|
|
||||||
@@ -600,34 +607,31 @@ async def matting(request):
|
|||||||
"details": traceback.format_exc()
|
"details": traceback.format_exc()
|
||||||
}, status=500)
|
}, status=500)
|
||||||
|
|
||||||
|
|
||||||
def convert_base64_to_tensor(base64_str):
|
def convert_base64_to_tensor(base64_str):
|
||||||
"""将base64图像数据转换为tensor,保留alpha通道"""
|
|
||||||
import base64
|
import base64
|
||||||
import io
|
import io
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 解码base64数据
|
|
||||||
img_data = base64.b64decode(base64_str.split(',')[1])
|
img_data = base64.b64decode(base64_str.split(',')[1])
|
||||||
img = Image.open(io.BytesIO(img_data))
|
img = Image.open(io.BytesIO(img_data))
|
||||||
|
|
||||||
# 保存原始alpha通道
|
|
||||||
has_alpha = img.mode == 'RGBA'
|
has_alpha = img.mode == 'RGBA'
|
||||||
alpha = None
|
alpha = None
|
||||||
if has_alpha:
|
if has_alpha:
|
||||||
alpha = img.split()[3]
|
alpha = img.split()[3]
|
||||||
# 创建白色背景
|
|
||||||
background = Image.new('RGB', img.size, (255, 255, 255))
|
background = Image.new('RGB', img.size, (255, 255, 255))
|
||||||
background.paste(img, mask=alpha)
|
background.paste(img, mask=alpha)
|
||||||
img = background
|
img = background
|
||||||
elif img.mode != 'RGB':
|
elif img.mode != 'RGB':
|
||||||
img = img.convert('RGB')
|
img = img.convert('RGB')
|
||||||
|
|
||||||
# 转换为tensor
|
|
||||||
transform = transforms.ToTensor()
|
transform = transforms.ToTensor()
|
||||||
img_tensor = transform(img).unsqueeze(0) # [1, C, H, W]
|
img_tensor = transform(img).unsqueeze(0) # [1, C, H, W]
|
||||||
|
|
||||||
if has_alpha:
|
if has_alpha:
|
||||||
# 将alpha转换为tensor并保存
|
|
||||||
alpha_tensor = transforms.ToTensor()(alpha).unsqueeze(0) # [1, 1, H, W]
|
alpha_tensor = transforms.ToTensor()(alpha).unsqueeze(0) # [1, 1, H, W]
|
||||||
return img_tensor, alpha_tensor
|
return img_tensor, alpha_tensor
|
||||||
|
|
||||||
@@ -637,50 +641,43 @@ def convert_base64_to_tensor(base64_str):
|
|||||||
print(f"Error in convert_base64_to_tensor: {str(e)}")
|
print(f"Error in convert_base64_to_tensor: {str(e)}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
def convert_tensor_to_base64(tensor, alpha_mask=None, original_alpha=None):
|
def convert_tensor_to_base64(tensor, alpha_mask=None, original_alpha=None):
|
||||||
"""将tensor转换为base64图像数据,支持alpha通道"""
|
|
||||||
import base64
|
import base64
|
||||||
import io
|
import io
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 确保tensor在CPU上
|
|
||||||
tensor = tensor.cpu()
|
tensor = tensor.cpu()
|
||||||
|
|
||||||
# 处理维度
|
|
||||||
if tensor.dim() == 4:
|
if tensor.dim() == 4:
|
||||||
tensor = tensor.squeeze(0) # 移除batch维度
|
tensor = tensor.squeeze(0) # 移除batch维度
|
||||||
if tensor.dim() == 3 and tensor.shape[0] in [1, 3]:
|
if tensor.dim() == 3 and tensor.shape[0] in [1, 3]:
|
||||||
tensor = tensor.permute(1, 2, 0)
|
tensor = tensor.permute(1, 2, 0)
|
||||||
|
|
||||||
# 转换为numpy数组并调整值范围到0-255
|
|
||||||
img_array = (tensor.numpy() * 255).astype(np.uint8)
|
img_array = (tensor.numpy() * 255).astype(np.uint8)
|
||||||
|
|
||||||
# 如果有alpha遮罩和原始alpha
|
|
||||||
if alpha_mask is not None and original_alpha is not None:
|
if alpha_mask is not None and original_alpha is not None:
|
||||||
# 将alpha_mask转换为正确的格式
|
|
||||||
alpha_mask = alpha_mask.cpu().squeeze().numpy()
|
alpha_mask = alpha_mask.cpu().squeeze().numpy()
|
||||||
alpha_mask = (alpha_mask * 255).astype(np.uint8)
|
alpha_mask = (alpha_mask * 255).astype(np.uint8)
|
||||||
|
|
||||||
# 将原始alpha转换为正确的格式
|
|
||||||
original_alpha = original_alpha.cpu().squeeze().numpy()
|
original_alpha = original_alpha.cpu().squeeze().numpy()
|
||||||
original_alpha = (original_alpha * 255).astype(np.uint8)
|
original_alpha = (original_alpha * 255).astype(np.uint8)
|
||||||
|
|
||||||
# 组合alpha_mask和original_alpha
|
|
||||||
combined_alpha = np.minimum(alpha_mask, original_alpha)
|
combined_alpha = np.minimum(alpha_mask, original_alpha)
|
||||||
|
|
||||||
# 创建RGBA图像
|
|
||||||
img = Image.fromarray(img_array, mode='RGB')
|
img = Image.fromarray(img_array, mode='RGB')
|
||||||
alpha_img = Image.fromarray(combined_alpha, mode='L')
|
alpha_img = Image.fromarray(combined_alpha, mode='L')
|
||||||
img.putalpha(alpha_img)
|
img.putalpha(alpha_img)
|
||||||
else:
|
else:
|
||||||
# 处理没有alpha通道的情况
|
|
||||||
if img_array.shape[-1] == 1:
|
if img_array.shape[-1] == 1:
|
||||||
img_array = img_array.squeeze(-1)
|
img_array = img_array.squeeze(-1)
|
||||||
img = Image.fromarray(img_array, mode='L')
|
img = Image.fromarray(img_array, mode='L')
|
||||||
else:
|
else:
|
||||||
img = Image.fromarray(img_array, mode='RGB')
|
img = Image.fromarray(img_array, mode='RGB')
|
||||||
|
|
||||||
# 转换为base64
|
|
||||||
buffer = io.BytesIO()
|
buffer = io.BytesIO()
|
||||||
img.save(buffer, format='PNG')
|
img.save(buffer, format='PNG')
|
||||||
img_str = base64.b64encode(buffer.getvalue()).decode()
|
img_str = base64.b64encode(buffer.getvalue()).decode()
|
||||||
|
|||||||
1366
js/Canvas.js
1366
js/Canvas.js
File diff suppressed because it is too large
Load Diff
@@ -6,7 +6,6 @@ import { Canvas } from "./Canvas.js";
|
|||||||
async function createCanvasWidget(node, widget, app) {
|
async function createCanvasWidget(node, widget, app) {
|
||||||
const canvas = new Canvas(node, widget);
|
const canvas = new Canvas(node, widget);
|
||||||
|
|
||||||
// 添加全局样式
|
|
||||||
const style = document.createElement('style');
|
const style = document.createElement('style');
|
||||||
style.textContent = `
|
style.textContent = `
|
||||||
.painter-button {
|
.painter-button {
|
||||||
@@ -59,6 +58,12 @@ async function createCanvasWidget(node, widget, app) {
|
|||||||
border: 1px solid #4a5a6a;
|
border: 1px solid #4a5a6a;
|
||||||
border-radius: 6px;
|
border-radius: 6px;
|
||||||
box-shadow: inset 0 0 10px rgba(0,0,0,0.1);
|
box-shadow: inset 0 0 10px rgba(0,0,0,0.1);
|
||||||
|
transition: border-color 0.3s ease; /* Dodano dla płynnej zmiany ramki */
|
||||||
|
}
|
||||||
|
|
||||||
|
.painter-container.drag-over {
|
||||||
|
border-color: #00ff00; /* Zielona ramka podczas przeciągania */
|
||||||
|
border-style: dashed;
|
||||||
}
|
}
|
||||||
|
|
||||||
.painter-dialog {
|
.painter-dialog {
|
||||||
@@ -115,7 +120,6 @@ async function createCanvasWidget(node, widget, app) {
|
|||||||
`;
|
`;
|
||||||
document.head.appendChild(style);
|
document.head.appendChild(style);
|
||||||
|
|
||||||
// 修改控制面板,使其高度自适应
|
|
||||||
const controlPanel = $el("div.painterControlPanel", {}, [
|
const controlPanel = $el("div.painterControlPanel", {}, [
|
||||||
$el("div.controls.painter-controls", {
|
$el("div.controls.painter-controls", {
|
||||||
style: {
|
style: {
|
||||||
@@ -123,7 +127,7 @@ async function createCanvasWidget(node, widget, app) {
|
|||||||
top: "0",
|
top: "0",
|
||||||
left: "0",
|
left: "0",
|
||||||
right: "0",
|
right: "0",
|
||||||
minHeight: "50px", // 改为最小高度
|
minHeight: "50px",
|
||||||
zIndex: "10",
|
zIndex: "10",
|
||||||
background: "linear-gradient(to bottom, #404040, #383838)",
|
background: "linear-gradient(to bottom, #404040, #383838)",
|
||||||
borderBottom: "1px solid #2a2a2a",
|
borderBottom: "1px solid #2a2a2a",
|
||||||
@@ -134,7 +138,7 @@ async function createCanvasWidget(node, widget, app) {
|
|||||||
flexWrap: "wrap",
|
flexWrap: "wrap",
|
||||||
alignItems: "center"
|
alignItems: "center"
|
||||||
},
|
},
|
||||||
// 添加监听器来动态整画布容器的位置
|
|
||||||
onresize: (entries) => {
|
onresize: (entries) => {
|
||||||
const controlsHeight = entries[0].target.offsetHeight;
|
const controlsHeight = entries[0].target.offsetHeight;
|
||||||
canvasContainer.style.top = (controlsHeight + 10) + "px";
|
canvasContainer.style.top = (controlsHeight + 10) + "px";
|
||||||
@@ -149,16 +153,15 @@ async function createCanvasWidget(node, widget, app) {
|
|||||||
input.multiple = true;
|
input.multiple = true;
|
||||||
input.onchange = async (e) => {
|
input.onchange = async (e) => {
|
||||||
for (const file of e.target.files) {
|
for (const file of e.target.files) {
|
||||||
// 创建图片对象
|
|
||||||
const img = new Image();
|
const img = new Image();
|
||||||
img.onload = async () => {
|
img.onload = async () => {
|
||||||
// 计算适当的缩放比例
|
|
||||||
const scale = Math.min(
|
const scale = Math.min(
|
||||||
canvas.width / img.width * 0.8,
|
canvas.width / img.width * 0.8,
|
||||||
canvas.height / img.height * 0.8
|
canvas.height / img.height * 0.8
|
||||||
);
|
);
|
||||||
|
|
||||||
// 创建新图层
|
|
||||||
const layer = {
|
const layer = {
|
||||||
image: img,
|
image: img,
|
||||||
x: (canvas.width - img.width * scale) / 2,
|
x: (canvas.width - img.width * scale) / 2,
|
||||||
@@ -169,17 +172,13 @@ async function createCanvasWidget(node, widget, app) {
|
|||||||
zIndex: canvas.layers.length
|
zIndex: canvas.layers.length
|
||||||
};
|
};
|
||||||
|
|
||||||
// 添加图层并选中
|
|
||||||
canvas.layers.push(layer);
|
canvas.layers.push(layer);
|
||||||
canvas.selectedLayer = layer;
|
canvas.selectedLayer = layer;
|
||||||
|
|
||||||
// 渲染画布
|
|
||||||
canvas.render();
|
canvas.render();
|
||||||
|
|
||||||
// 立即保存并触发输出更新
|
|
||||||
await canvas.saveToServer(widget.value);
|
await canvas.saveToServer(widget.value);
|
||||||
|
|
||||||
// 触发节点更新
|
|
||||||
app.graph.runStep();
|
app.graph.runStep();
|
||||||
};
|
};
|
||||||
img.src = URL.createObjectURL(file);
|
img.src = URL.createObjectURL(file);
|
||||||
@@ -193,32 +192,13 @@ async function createCanvasWidget(node, widget, app) {
|
|||||||
onclick: async () => {
|
onclick: async () => {
|
||||||
try {
|
try {
|
||||||
console.log("Import Input clicked");
|
console.log("Import Input clicked");
|
||||||
console.log("Node ID:", node.id);
|
const success = await canvas.importLatestImage();
|
||||||
|
if (success) {
|
||||||
const response = await fetch(`/ycnode/get_canvas_data/${node.id}`);
|
|
||||||
console.log("Response status:", response.status);
|
|
||||||
|
|
||||||
const result = await response.json();
|
|
||||||
console.log("Full response data:", result);
|
|
||||||
|
|
||||||
if (result.success && result.data) {
|
|
||||||
if (result.data.image) {
|
|
||||||
console.log("Found image data, importing...");
|
|
||||||
await canvas.importImage({
|
|
||||||
image: result.data.image,
|
|
||||||
mask: result.data.mask
|
|
||||||
});
|
|
||||||
await canvas.saveToServer(widget.value);
|
await canvas.saveToServer(widget.value);
|
||||||
app.graph.runStep();
|
app.graph.runStep();
|
||||||
} else {
|
|
||||||
throw new Error("No image data found in cache");
|
|
||||||
}
|
}
|
||||||
} else {
|
|
||||||
throw new Error("Invalid response format");
|
|
||||||
}
|
|
||||||
|
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
console.error("Error importing input:", error);
|
console.error("Error during import input process:", error);
|
||||||
alert(`Failed to import input: ${error.message}`);
|
alert(`Failed to import input: ${error.message}`);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -341,21 +321,21 @@ async function createCanvasWidget(node, widget, app) {
|
|||||||
app.graph.runStep();
|
app.graph.runStep();
|
||||||
}
|
}
|
||||||
}),
|
}),
|
||||||
// 添加水平镜像按钮
|
|
||||||
$el("button.painter-button", {
|
$el("button.painter-button", {
|
||||||
textContent: "Mirror H",
|
textContent: "Mirror H",
|
||||||
onclick: () => {
|
onclick: () => {
|
||||||
canvas.mirrorHorizontal();
|
canvas.mirrorHorizontal();
|
||||||
}
|
}
|
||||||
}),
|
}),
|
||||||
// 添加垂直镜像按钮
|
|
||||||
$el("button.painter-button", {
|
$el("button.painter-button", {
|
||||||
textContent: "Mirror V",
|
textContent: "Mirror V",
|
||||||
onclick: () => {
|
onclick: () => {
|
||||||
canvas.mirrorVertical();
|
canvas.mirrorVertical();
|
||||||
}
|
}
|
||||||
}),
|
}),
|
||||||
// 在控制面板中添加抠图按钮
|
|
||||||
$el("button.painter-button", {
|
$el("button.painter-button", {
|
||||||
textContent: "Matting",
|
textContent: "Matting",
|
||||||
onclick: async () => {
|
onclick: async () => {
|
||||||
@@ -364,10 +344,8 @@ async function createCanvasWidget(node, widget, app) {
|
|||||||
throw new Error("Please select an image first");
|
throw new Error("Please select an image first");
|
||||||
}
|
}
|
||||||
|
|
||||||
// 获取或创建状态指示器
|
|
||||||
const statusIndicator = MattingStatusIndicator.getInstance(controlPanel.querySelector('.controls'));
|
const statusIndicator = MattingStatusIndicator.getInstance(controlPanel.querySelector('.controls'));
|
||||||
|
|
||||||
// 添加状态监听
|
|
||||||
const updateStatus = (event) => {
|
const updateStatus = (event) => {
|
||||||
const {status} = event.detail;
|
const {status} = event.detail;
|
||||||
statusIndicator.setStatus(status);
|
statusIndicator.setStatus(status);
|
||||||
@@ -376,11 +354,10 @@ async function createCanvasWidget(node, widget, app) {
|
|||||||
api.addEventListener("matting_status", updateStatus);
|
api.addEventListener("matting_status", updateStatus);
|
||||||
|
|
||||||
try {
|
try {
|
||||||
// 获取图像据
|
|
||||||
const imageData = await canvas.getLayerImageData(canvas.selectedLayer);
|
const imageData = await canvas.getLayerImageData(canvas.selectedLayer);
|
||||||
console.log("Sending image to server...");
|
console.log("Sending image to server...");
|
||||||
|
|
||||||
// 发送请求
|
|
||||||
const response = await fetch("/matting", {
|
const response = await fetch("/matting", {
|
||||||
method: "POST",
|
method: "POST",
|
||||||
headers: {
|
headers: {
|
||||||
@@ -400,23 +377,20 @@ async function createCanvasWidget(node, widget, app) {
|
|||||||
const result = await response.json();
|
const result = await response.json();
|
||||||
console.log("Creating new layer with matting result...");
|
console.log("Creating new layer with matting result...");
|
||||||
|
|
||||||
// 创建新图层
|
|
||||||
const mattedImage = new Image();
|
const mattedImage = new Image();
|
||||||
mattedImage.onload = async () => {
|
mattedImage.onload = async () => {
|
||||||
// 创建临时画布来处理透明度
|
|
||||||
const tempCanvas = document.createElement('canvas');
|
const tempCanvas = document.createElement('canvas');
|
||||||
const tempCtx = tempCanvas.getContext('2d');
|
const tempCtx = tempCanvas.getContext('2d');
|
||||||
tempCanvas.width = canvas.selectedLayer.width;
|
tempCanvas.width = canvas.selectedLayer.width;
|
||||||
tempCanvas.height = canvas.selectedLayer.height;
|
tempCanvas.height = canvas.selectedLayer.height;
|
||||||
|
|
||||||
// 绘制原始图像
|
|
||||||
tempCtx.drawImage(
|
tempCtx.drawImage(
|
||||||
mattedImage,
|
mattedImage,
|
||||||
0, 0,
|
0, 0,
|
||||||
tempCanvas.width, tempCanvas.height
|
tempCanvas.width, tempCanvas.height
|
||||||
);
|
);
|
||||||
|
|
||||||
// 创建新图层
|
|
||||||
const newImage = new Image();
|
const newImage = new Image();
|
||||||
newImage.onload = async () => {
|
newImage.onload = async () => {
|
||||||
const newLayer = {
|
const newLayer = {
|
||||||
@@ -433,12 +407,10 @@ async function createCanvasWidget(node, widget, app) {
|
|||||||
canvas.selectedLayer = newLayer;
|
canvas.selectedLayer = newLayer;
|
||||||
canvas.render();
|
canvas.render();
|
||||||
|
|
||||||
// 保存并更新
|
|
||||||
await canvas.saveToServer(widget.value);
|
await canvas.saveToServer(widget.value);
|
||||||
app.graph.runStep();
|
app.graph.runStep();
|
||||||
};
|
};
|
||||||
|
|
||||||
// 转换为PNG并保持透明度
|
|
||||||
newImage.src = tempCanvas.toDataURL('image/png');
|
newImage.src = tempCanvas.toDataURL('image/png');
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -458,29 +430,24 @@ async function createCanvasWidget(node, widget, app) {
|
|||||||
])
|
])
|
||||||
]);
|
]);
|
||||||
|
|
||||||
// 创建ResizeObserver来监控控制面板的高度变化
|
|
||||||
const resizeObserver = new ResizeObserver((entries) => {
|
const resizeObserver = new ResizeObserver((entries) => {
|
||||||
const controlsHeight = entries[0].target.offsetHeight;
|
const controlsHeight = entries[0].target.offsetHeight;
|
||||||
canvasContainer.style.top = (controlsHeight + 10) + "px";
|
canvasContainer.style.top = (controlsHeight + 10) + "px";
|
||||||
});
|
});
|
||||||
|
|
||||||
// 监控控制面板的大小变化
|
|
||||||
resizeObserver.observe(controlPanel.querySelector('.controls'));
|
resizeObserver.observe(controlPanel.querySelector('.controls'));
|
||||||
|
|
||||||
// 获取触发器widget
|
|
||||||
const triggerWidget = node.widgets.find(w => w.name === "trigger");
|
const triggerWidget = node.widgets.find(w => w.name === "trigger");
|
||||||
|
|
||||||
// 创建更新函数
|
|
||||||
const updateOutput = async () => {
|
const updateOutput = async () => {
|
||||||
// 保存画布
|
|
||||||
await canvas.saveToServer(widget.value);
|
await canvas.saveToServer(widget.value);
|
||||||
// 更新触发器值
|
|
||||||
triggerWidget.value = (triggerWidget.value + 1) % 99999999;
|
triggerWidget.value = (triggerWidget.value + 1) % 99999999;
|
||||||
// 触发节点更新
|
|
||||||
app.graph.runStep();
|
app.graph.runStep();
|
||||||
};
|
};
|
||||||
|
|
||||||
// 修改所有可能触发更新的操作
|
|
||||||
const addUpdateToButton = (button) => {
|
const addUpdateToButton = (button) => {
|
||||||
const origClick = button.onclick;
|
const origClick = button.onclick;
|
||||||
button.onclick = async (...args) => {
|
button.onclick = async (...args) => {
|
||||||
@@ -489,63 +456,27 @@ async function createCanvasWidget(node, widget, app) {
|
|||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
|
||||||
// 为所有按钮添加更新逻辑
|
|
||||||
controlPanel.querySelectorAll('button').forEach(addUpdateToButton);
|
controlPanel.querySelectorAll('button').forEach(addUpdateToButton);
|
||||||
|
|
||||||
// 修改画布容器样式,使用动态top值
|
|
||||||
const canvasContainer = $el("div.painterCanvasContainer.painter-container", {
|
const canvasContainer = $el("div.painterCanvasContainer.painter-container", {
|
||||||
style: {
|
style: {
|
||||||
position: "absolute",
|
position: "absolute",
|
||||||
top: "60px", // 初始值
|
top: "60px",
|
||||||
left: "10px",
|
left: "10px",
|
||||||
right: "10px",
|
right: "10px",
|
||||||
bottom: "10px",
|
bottom: "10px",
|
||||||
display: "flex",
|
|
||||||
justifyContent: "center",
|
|
||||||
alignItems: "center",
|
|
||||||
overflow: "hidden"
|
overflow: "hidden"
|
||||||
}
|
}
|
||||||
}, [canvas.canvas]);
|
}, [canvas.canvas]);
|
||||||
|
|
||||||
// 修改节点大小调整逻辑
|
|
||||||
node.onResize = function () {
|
node.onResize = function () {
|
||||||
const minSize = 300;
|
|
||||||
const controlsElement = controlPanel.querySelector('.controls');
|
|
||||||
const controlPanelHeight = controlsElement.offsetHeight; // 取实际高
|
|
||||||
const padding = 20;
|
|
||||||
|
|
||||||
// 保持节点宽度,高度根据画布比例调整
|
|
||||||
const width = Math.max(this.size[0], minSize);
|
|
||||||
const height = Math.max(
|
|
||||||
width * (canvas.height / canvas.width) + controlPanelHeight + padding * 2,
|
|
||||||
minSize + controlPanelHeight
|
|
||||||
);
|
|
||||||
|
|
||||||
this.size[0] = width;
|
|
||||||
this.size[1] = height;
|
|
||||||
|
|
||||||
// 计算画布的实际可用空间
|
|
||||||
const availableWidth = width - padding * 2;
|
|
||||||
const availableHeight = height - controlPanelHeight - padding * 2;
|
|
||||||
|
|
||||||
// 更新画布尺寸,保持比例
|
|
||||||
const scale = Math.min(
|
|
||||||
availableWidth / canvas.width,
|
|
||||||
availableHeight / canvas.height
|
|
||||||
);
|
|
||||||
|
|
||||||
canvas.canvas.style.width = (canvas.width * scale) + "px";
|
|
||||||
canvas.canvas.style.height = (canvas.height * scale) + "px";
|
|
||||||
|
|
||||||
// 强制重新渲染
|
|
||||||
canvas.render();
|
canvas.render();
|
||||||
};
|
};
|
||||||
|
|
||||||
// 添加拖拽事件监听
|
|
||||||
canvas.canvas.addEventListener('mouseup', updateOutput);
|
canvas.canvas.addEventListener('mouseup', updateOutput);
|
||||||
canvas.canvas.addEventListener('mouseleave', updateOutput);
|
canvas.canvas.addEventListener('mouseleave', updateOutput);
|
||||||
|
|
||||||
// 创建一个包含控制面板和画布的容器
|
|
||||||
const mainContainer = $el("div.painterMainContainer", {
|
const mainContainer = $el("div.painterMainContainer", {
|
||||||
style: {
|
style: {
|
||||||
position: "relative",
|
position: "relative",
|
||||||
@@ -553,19 +484,80 @@ async function createCanvasWidget(node, widget, app) {
|
|||||||
height: "100%"
|
height: "100%"
|
||||||
}
|
}
|
||||||
}, [controlPanel, canvasContainer]);
|
}, [controlPanel, canvasContainer]);
|
||||||
|
const handleFileLoad = async (file) => {
|
||||||
|
// Sprawdzamy, czy plik jest obrazem
|
||||||
|
if (!file.type.startsWith('image/')) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const img = new Image();
|
||||||
|
img.onload = async () => {
|
||||||
|
// Logika dodawania obrazu jest taka sama jak w przycisku "Add Image"
|
||||||
|
const scale = Math.min(
|
||||||
|
canvas.width / img.width * 0.8,
|
||||||
|
canvas.height / img.height * 0.8
|
||||||
|
);
|
||||||
|
|
||||||
|
const layer = {
|
||||||
|
image: img,
|
||||||
|
x: (canvas.width - img.width * scale) / 2,
|
||||||
|
y: (canvas.height - img.height * scale) / 2,
|
||||||
|
width: img.width * scale,
|
||||||
|
height: img.height * scale,
|
||||||
|
rotation: 0,
|
||||||
|
zIndex: canvas.layers.length,
|
||||||
|
blendMode: 'normal',
|
||||||
|
opacity: 1
|
||||||
|
};
|
||||||
|
|
||||||
|
canvas.layers.push(layer);
|
||||||
|
canvas.selectedLayer = layer;
|
||||||
|
canvas.render();
|
||||||
|
|
||||||
|
// Używamy funkcji updateOutput, aby zapisać stan i uruchomić graf
|
||||||
|
await updateOutput();
|
||||||
|
|
||||||
|
// Zwolnienie zasobu URL
|
||||||
|
URL.revokeObjectURL(img.src);
|
||||||
|
};
|
||||||
|
img.src = URL.createObjectURL(file);
|
||||||
|
};
|
||||||
|
|
||||||
|
mainContainer.addEventListener('dragover', (e) => {
|
||||||
|
e.preventDefault(); // Niezbędne, aby zdarzenie 'drop' zadziałało
|
||||||
|
e.stopPropagation();
|
||||||
|
// Dodajemy klasę, aby pokazać wizualną informację zwrotną
|
||||||
|
canvasContainer.classList.add('drag-over');
|
||||||
|
});
|
||||||
|
|
||||||
|
mainContainer.addEventListener('dragleave', (e) => {
|
||||||
|
e.preventDefault();
|
||||||
|
e.stopPropagation();
|
||||||
|
// Usuwamy klasę po opuszczeniu obszaru
|
||||||
|
canvasContainer.classList.remove('drag-over');
|
||||||
|
});
|
||||||
|
|
||||||
|
mainContainer.addEventListener('drop', async (e) => {
|
||||||
|
e.preventDefault();
|
||||||
|
e.stopPropagation();
|
||||||
|
// Usuwamy klasę po upuszczeniu pliku
|
||||||
|
canvasContainer.classList.remove('drag-over');
|
||||||
|
|
||||||
|
if (e.dataTransfer.files) {
|
||||||
|
// Przetwarzamy wszystkie upuszczone pliki
|
||||||
|
for (const file of e.dataTransfer.files) {
|
||||||
|
await handleFileLoad(file);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
// 将主容器添加到节点
|
|
||||||
const mainWidget = node.addDOMWidget("mainContainer", "widget", mainContainer);
|
const mainWidget = node.addDOMWidget("mainContainer", "widget", mainContainer);
|
||||||
|
|
||||||
// 设置节点的默认大小
|
node.size = [500, 500];
|
||||||
node.size = [500, 500]; // 设置初始大小为正方形
|
|
||||||
|
|
||||||
// 在执行开始时保存数据
|
|
||||||
api.addEventListener("execution_start", async () => {
|
api.addEventListener("execution_start", async () => {
|
||||||
// 保存画布
|
|
||||||
await canvas.saveToServer(widget.value);
|
await canvas.saveToServer(widget.value);
|
||||||
|
|
||||||
// 保存当前节点的输入数据
|
|
||||||
if (node.inputs[0].link) {
|
if (node.inputs[0].link) {
|
||||||
const linkId = node.inputs[0].link;
|
const linkId = node.inputs[0].link;
|
||||||
const inputData = app.nodeOutputs[linkId];
|
const inputData = app.nodeOutputs[linkId];
|
||||||
@@ -575,22 +567,21 @@ async function createCanvasWidget(node, widget, app) {
|
|||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
// 移除原来在 saveToServer 中的缓存清理
|
|
||||||
const originalSaveToServer = canvas.saveToServer;
|
const originalSaveToServer = canvas.saveToServer;
|
||||||
canvas.saveToServer = async function (fileName) {
|
canvas.saveToServer = async function (fileName) {
|
||||||
const result = await originalSaveToServer.call(this, fileName);
|
const result = await originalSaveToServer.call(this, fileName);
|
||||||
// 移除这里的缓存清理
|
|
||||||
// ImageCache.clear();
|
|
||||||
return result;
|
return result;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
node.canvasWidget = canvas;
|
||||||
|
|
||||||
return {
|
return {
|
||||||
canvas: canvas,
|
canvas: canvas,
|
||||||
panel: controlPanel
|
panel: controlPanel
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
// 修改状态指示器类,确保单例模式
|
|
||||||
class MattingStatusIndicator {
|
class MattingStatusIndicator {
|
||||||
static instance = null;
|
static instance = null;
|
||||||
|
|
||||||
@@ -637,7 +628,7 @@ class MattingStatusIndicator {
|
|||||||
}
|
}
|
||||||
|
|
||||||
setStatus(status) {
|
setStatus(status) {
|
||||||
this.indicator.className = ''; // 清除所有状态
|
this.indicator.className = '';
|
||||||
if (status) {
|
if (status) {
|
||||||
this.indicator.classList.add(status);
|
this.indicator.classList.add(status);
|
||||||
}
|
}
|
||||||
@@ -649,9 +640,8 @@ class MattingStatusIndicator {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 验证 ComfyUI 的图像数据格式
|
|
||||||
function validateImageData(data) {
|
function validateImageData(data) {
|
||||||
// 打印完整的输入数据结构
|
|
||||||
console.log("Validating data structure:", {
|
console.log("Validating data structure:", {
|
||||||
hasData: !!data,
|
hasData: !!data,
|
||||||
type: typeof data,
|
type: typeof data,
|
||||||
@@ -659,36 +649,31 @@ function validateImageData(data) {
|
|||||||
keys: data ? Object.keys(data) : null,
|
keys: data ? Object.keys(data) : null,
|
||||||
shape: data?.shape,
|
shape: data?.shape,
|
||||||
dataType: data?.data ? data.data.constructor.name : null,
|
dataType: data?.data ? data.data.constructor.name : null,
|
||||||
fullData: data // 打印完整数据
|
fullData: data
|
||||||
});
|
});
|
||||||
|
|
||||||
// 检查是否为空
|
|
||||||
if (!data) {
|
if (!data) {
|
||||||
console.log("Data is null or undefined");
|
console.log("Data is null or undefined");
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
// 如果是数组,获取第一个元素
|
|
||||||
if (Array.isArray(data)) {
|
if (Array.isArray(data)) {
|
||||||
console.log("Data is array, getting first element");
|
console.log("Data is array, getting first element");
|
||||||
data = data[0];
|
data = data[0];
|
||||||
}
|
}
|
||||||
|
|
||||||
// 检查数据结构
|
|
||||||
if (!data || typeof data !== 'object') {
|
if (!data || typeof data !== 'object') {
|
||||||
console.log("Invalid data type");
|
console.log("Invalid data type");
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
// 检查是否有数据属性
|
|
||||||
if (!data.data) {
|
if (!data.data) {
|
||||||
console.log("Missing data property");
|
console.log("Missing data property");
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
// 检查数据类型
|
|
||||||
if (!(data.data instanceof Float32Array)) {
|
if (!(data.data instanceof Float32Array)) {
|
||||||
// 如果不是 Float32Array,尝试转换
|
|
||||||
try {
|
try {
|
||||||
data.data = new Float32Array(data.data);
|
data.data = new Float32Array(data.data);
|
||||||
} catch (e) {
|
} catch (e) {
|
||||||
@@ -700,53 +685,44 @@ function validateImageData(data) {
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
// 转换 ComfyUI 图像数据为画布可用格式
|
|
||||||
function convertImageData(data) {
|
function convertImageData(data) {
|
||||||
console.log("Converting image data:", data);
|
console.log("Converting image data:", data);
|
||||||
|
|
||||||
// 如果是数组,获取第一个元素
|
|
||||||
if (Array.isArray(data)) {
|
if (Array.isArray(data)) {
|
||||||
data = data[0];
|
data = data[0];
|
||||||
}
|
}
|
||||||
|
|
||||||
// 获取维度信息 [batch, height, width, channels]
|
|
||||||
const shape = data.shape;
|
const shape = data.shape;
|
||||||
const height = shape[1]; // 1393
|
const height = shape[1];
|
||||||
const width = shape[2]; // 1393
|
const width = shape[2];
|
||||||
const channels = shape[3]; // 3
|
const channels = shape[3];
|
||||||
const floatData = new Float32Array(data.data);
|
const floatData = new Float32Array(data.data);
|
||||||
|
|
||||||
console.log("Processing dimensions:", {height, width, channels});
|
console.log("Processing dimensions:", {height, width, channels});
|
||||||
|
|
||||||
// 创建画布格式的数据 (RGBA)
|
|
||||||
const rgbaData = new Uint8ClampedArray(width * height * 4);
|
const rgbaData = new Uint8ClampedArray(width * height * 4);
|
||||||
|
|
||||||
// 转换数据格式 [batch, height, width, channels] -> RGBA
|
|
||||||
for (let h = 0; h < height; h++) {
|
for (let h = 0; h < height; h++) {
|
||||||
for (let w = 0; w < width; w++) {
|
for (let w = 0; w < width; w++) {
|
||||||
const pixelIndex = (h * width + w) * 4;
|
const pixelIndex = (h * width + w) * 4;
|
||||||
const tensorIndex = (h * width + w) * channels;
|
const tensorIndex = (h * width + w) * channels;
|
||||||
|
|
||||||
// 复制 RGB 通道并转换值范围 (0-1 -> 0-255)
|
|
||||||
for (let c = 0; c < channels; c++) {
|
for (let c = 0; c < channels; c++) {
|
||||||
const value = floatData[tensorIndex + c];
|
const value = floatData[tensorIndex + c];
|
||||||
rgbaData[pixelIndex + c] = Math.max(0, Math.min(255, Math.round(value * 255)));
|
rgbaData[pixelIndex + c] = Math.max(0, Math.min(255, Math.round(value * 255)));
|
||||||
}
|
}
|
||||||
|
|
||||||
// 设置 alpha 通道为完全不透明
|
|
||||||
rgbaData[pixelIndex + 3] = 255;
|
rgbaData[pixelIndex + 3] = 255;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 返回画布可用的格式
|
|
||||||
return {
|
return {
|
||||||
data: rgbaData, // Uint8ClampedArray 格式的 RGBA 数据
|
data: rgbaData,
|
||||||
width: width, // 图像宽度
|
width: width,
|
||||||
height: height // 图像高度
|
height: height
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
// 处理遮罩数据
|
|
||||||
function applyMaskToImageData(imageData, maskData) {
|
function applyMaskToImageData(imageData, maskData) {
|
||||||
console.log("Applying mask to image data");
|
console.log("Applying mask to image data");
|
||||||
|
|
||||||
@@ -754,18 +730,16 @@ function applyMaskToImageData(imageData, maskData) {
|
|||||||
const width = imageData.width;
|
const width = imageData.width;
|
||||||
const height = imageData.height;
|
const height = imageData.height;
|
||||||
|
|
||||||
// 获取遮罩数据 [batch, height, width]
|
|
||||||
const maskShape = maskData.shape;
|
const maskShape = maskData.shape;
|
||||||
const maskFloatData = new Float32Array(maskData.data);
|
const maskFloatData = new Float32Array(maskData.data);
|
||||||
|
|
||||||
console.log(`Applying mask of shape: ${maskShape}`);
|
console.log(`Applying mask of shape: ${maskShape}`);
|
||||||
|
|
||||||
// 将遮罩数据应用到 alpha 通道
|
|
||||||
for (let h = 0; h < height; h++) {
|
for (let h = 0; h < height; h++) {
|
||||||
for (let w = 0; w < width; w++) {
|
for (let w = 0; w < width; w++) {
|
||||||
const pixelIndex = (h * width + w) * 4;
|
const pixelIndex = (h * width + w) * 4;
|
||||||
const maskIndex = h * width + w;
|
const maskIndex = h * width + w;
|
||||||
// 使遮罩值作为 alpha 值,转换值范围从 0-1 到 0-255
|
|
||||||
const alpha = maskFloatData[maskIndex];
|
const alpha = maskFloatData[maskIndex];
|
||||||
rgbaData[pixelIndex + 3] = Math.max(0, Math.min(255, Math.round(alpha * 255)));
|
rgbaData[pixelIndex + 3] = Math.max(0, Math.min(255, Math.round(alpha * 255)));
|
||||||
}
|
}
|
||||||
@@ -780,41 +754,35 @@ function applyMaskToImageData(imageData, maskData) {
|
|||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
// 修改缓存管理
|
|
||||||
const ImageCache = {
|
const ImageCache = {
|
||||||
cache: new Map(),
|
cache: new Map(),
|
||||||
|
|
||||||
// 存储图像数据
|
|
||||||
set(key, imageData) {
|
set(key, imageData) {
|
||||||
console.log("Caching image data for key:", key);
|
console.log("Caching image data for key:", key);
|
||||||
this.cache.set(key, imageData);
|
this.cache.set(key, imageData);
|
||||||
},
|
},
|
||||||
|
|
||||||
// 获取图像数据
|
|
||||||
get(key) {
|
get(key) {
|
||||||
const data = this.cache.get(key);
|
const data = this.cache.get(key);
|
||||||
console.log("Retrieved cached data for key:", key, !!data);
|
console.log("Retrieved cached data for key:", key, !!data);
|
||||||
return data;
|
return data;
|
||||||
},
|
},
|
||||||
|
|
||||||
// 检查是否存在
|
|
||||||
has(key) {
|
has(key) {
|
||||||
return this.cache.has(key);
|
return this.cache.has(key);
|
||||||
},
|
},
|
||||||
|
|
||||||
// 清除缓存
|
|
||||||
clear() {
|
clear() {
|
||||||
console.log("Clearing image cache");
|
console.log("Clearing image cache");
|
||||||
this.cache.clear();
|
this.cache.clear();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
// 改进数据准备函数
|
|
||||||
function prepareImageForCanvas(inputImage) {
|
function prepareImageForCanvas(inputImage) {
|
||||||
console.log("Preparing image for canvas:", inputImage);
|
console.log("Preparing image for canvas:", inputImage);
|
||||||
|
|
||||||
try {
|
try {
|
||||||
// 如果是数组,获取第一个元素
|
|
||||||
if (Array.isArray(inputImage)) {
|
if (Array.isArray(inputImage)) {
|
||||||
inputImage = inputImage[0];
|
inputImage = inputImage[0];
|
||||||
}
|
}
|
||||||
@@ -823,7 +791,6 @@ function prepareImageForCanvas(inputImage) {
|
|||||||
throw new Error("Invalid input image format");
|
throw new Error("Invalid input image format");
|
||||||
}
|
}
|
||||||
|
|
||||||
// 获取维度信息 [batch, height, width, channels]
|
|
||||||
const shape = inputImage.shape;
|
const shape = inputImage.shape;
|
||||||
const height = shape[1];
|
const height = shape[1];
|
||||||
const width = shape[2];
|
const width = shape[2];
|
||||||
@@ -832,27 +799,22 @@ function prepareImageForCanvas(inputImage) {
|
|||||||
|
|
||||||
console.log("Image dimensions:", {height, width, channels});
|
console.log("Image dimensions:", {height, width, channels});
|
||||||
|
|
||||||
// 创建 RGBA 格式数据
|
|
||||||
const rgbaData = new Uint8ClampedArray(width * height * 4);
|
const rgbaData = new Uint8ClampedArray(width * height * 4);
|
||||||
|
|
||||||
// 转换数据格式 [batch, height, width, channels] -> RGBA
|
|
||||||
for (let h = 0; h < height; h++) {
|
for (let h = 0; h < height; h++) {
|
||||||
for (let w = 0; w < width; w++) {
|
for (let w = 0; w < width; w++) {
|
||||||
const pixelIndex = (h * width + w) * 4;
|
const pixelIndex = (h * width + w) * 4;
|
||||||
const tensorIndex = (h * width + w) * channels;
|
const tensorIndex = (h * width + w) * channels;
|
||||||
|
|
||||||
// 转换 RGB 通道 (0-1 -> 0-255)
|
|
||||||
for (let c = 0; c < channels; c++) {
|
for (let c = 0; c < channels; c++) {
|
||||||
const value = floatData[tensorIndex + c];
|
const value = floatData[tensorIndex + c];
|
||||||
rgbaData[pixelIndex + c] = Math.max(0, Math.min(255, Math.round(value * 255)));
|
rgbaData[pixelIndex + c] = Math.max(0, Math.min(255, Math.round(value * 255)));
|
||||||
}
|
}
|
||||||
|
|
||||||
// 设置 alpha 通道
|
|
||||||
rgbaData[pixelIndex + 3] = 255;
|
rgbaData[pixelIndex + 3] = 255;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 返回画布需要的格式
|
|
||||||
return {
|
return {
|
||||||
data: rgbaData,
|
data: rgbaData,
|
||||||
width: width,
|
width: width,
|
||||||
@@ -877,6 +839,63 @@ app.registerExtension({
|
|||||||
|
|
||||||
return r;
|
return r;
|
||||||
};
|
};
|
||||||
|
const originalGetExtraMenuOptions = nodeType.prototype.getExtraMenuOptions;
|
||||||
|
nodeType.prototype.getExtraMenuOptions = function (_, options) {
|
||||||
|
originalGetExtraMenuOptions?.apply(this, arguments);
|
||||||
|
|
||||||
|
const self = this;
|
||||||
|
const newOptions = [
|
||||||
|
{
|
||||||
|
content: "Open Image",
|
||||||
|
callback: async () => {
|
||||||
|
try {
|
||||||
|
const blob = await self.canvasWidget.getFlattenedCanvasAsBlob();
|
||||||
|
const url = URL.createObjectURL(blob);
|
||||||
|
window.open(url, '_blank');
|
||||||
|
setTimeout(() => URL.revokeObjectURL(url), 1000);
|
||||||
|
} catch (e) {
|
||||||
|
console.error("Error opening image:", e);
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
content: "Copy Image",
|
||||||
|
callback: async () => {
|
||||||
|
try {
|
||||||
|
const blob = await self.canvasWidget.getFlattenedCanvasAsBlob();
|
||||||
|
const item = new ClipboardItem({'image/png': blob});
|
||||||
|
await navigator.clipboard.write([item]);
|
||||||
|
console.log("Image copied to clipboard.");
|
||||||
|
} catch (e) {
|
||||||
|
console.error("Error copying image:", e);
|
||||||
|
alert("Failed to copy image to clipboard.");
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
content: "Save Image",
|
||||||
|
callback: async () => {
|
||||||
|
try {
|
||||||
|
const blob = await self.canvasWidget.getFlattenedCanvasAsBlob();
|
||||||
|
const url = URL.createObjectURL(blob);
|
||||||
|
const a = document.createElement('a');
|
||||||
|
a.href = url;
|
||||||
|
a.download = 'canvas_output.png';
|
||||||
|
document.body.appendChild(a);
|
||||||
|
a.click();
|
||||||
|
document.body.removeChild(a);
|
||||||
|
setTimeout(() => URL.revokeObjectURL(url), 1000);
|
||||||
|
} catch (e) {
|
||||||
|
console.error("Error saving image:", e);
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
];
|
||||||
|
if (options.length > 0) {
|
||||||
|
options.unshift({content: "___", disabled: true});
|
||||||
|
}
|
||||||
|
options.unshift(...newOptions);
|
||||||
|
};
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|||||||
@@ -6,8 +6,8 @@ license = {file = "LICENSE"}
|
|||||||
dependencies = ["torch", "torchvision", "transformers", "aiohttp", "numpy", "tqdm", "Pillow"]
|
dependencies = ["torch", "torchvision", "transformers", "aiohttp", "numpy", "tqdm", "Pillow"]
|
||||||
|
|
||||||
[project.urls]
|
[project.urls]
|
||||||
Repository = "https://github.com/yichengup/Comfyui-Ycanvas"
|
Repository = "https:
|
||||||
# Used by Comfy Registry https://comfyregistry.org
|
|
||||||
|
|
||||||
[tool.comfy]
|
[tool.comfy]
|
||||||
PublisherId = ""
|
PublisherId = ""
|
||||||
|
|||||||
Reference in New Issue
Block a user