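"""LayerForge canvas node for ComfyUI.

Defines the LayerForgeNode custom node, the WebSocket/HTTP routes it uses to
exchange canvas data with the frontend, and a BiRefNet-based matting backend
(/matting and /matting/check-model) that validates the downloaded model
(config.json plus model.safetensors) before use.
"""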
from PIL import Image, ImageOps
import hashlib
import torch
import numpy as np
import folder_paths
from server import PromptServer
from aiohttp import web
import asyncio
import threading
import os
from tqdm import tqdm
from torchvision import transforms

try:
    from transformers import AutoModelForImageSegmentation, PretrainedConfig
    from requests.exceptions import ConnectionError as RequestsConnectionError
    TRANSFORMERS_AVAILABLE = True
except ImportError:
    TRANSFORMERS_AVAILABLE = False
    # Fallbacks so the module can still be imported without 'transformers' /
    # 'requests'; the matting endpoints check TRANSFORMERS_AVAILABLE before
    # relying on either name.
    PretrainedConfig = object
    RequestsConnectionError = ConnectionError

import torch.nn.functional as F
import traceback
import uuid
import time
import base64
import io
import sys

try:
    from python.logger import logger, LogLevel, debug, info, warn, error, exception
    from python.config import LOG_LEVEL

    logger.set_module_level('canvas_node', LogLevel[LOG_LEVEL])

    logger.configure({
        'log_to_file': True,
        'log_dir': os.path.join(os.path.dirname(os.path.abspath(__file__)), 'logs')
    })

    log_debug = lambda *args, **kwargs: debug('canvas_node', *args, **kwargs)
    log_info = lambda *args, **kwargs: info('canvas_node', *args, **kwargs)
    log_warn = lambda *args, **kwargs: warn('canvas_node', *args, **kwargs)
    log_error = lambda *args, **kwargs: error('canvas_node', *args, **kwargs)
    log_exception = lambda *args: exception('canvas_node', *args)

    log_info("Logger initialized for canvas_node")
except ImportError as e:
    print(f"Warning: Logger module not available: {e}")

    def log_debug(*args): print("[DEBUG]", *args)
    def log_info(*args): print("[INFO]", *args)
    def log_warn(*args): print("[WARN]", *args)
    def log_error(*args): print("[ERROR]", *args)
    def log_exception(*args):
        print("[ERROR]", *args)
        traceback.print_exc()

torch.set_float32_matmul_precision('high')


class BiRefNetConfig(PretrainedConfig):
    model_type = "BiRefNet"

    def __init__(self, bb_pretrained=False, **kwargs):
        self.bb_pretrained = bb_pretrained
        # Add the missing is_encoder_decoder attribute for compatibility with newer transformers
        self.is_encoder_decoder = False
        super().__init__(**kwargs)


class BiRefNet(torch.nn.Module):
    def __init__(self, config):
        super().__init__()

        self.encoder = torch.nn.Sequential(
            torch.nn.Conv2d(3, 64, kernel_size=3, padding=1),
            torch.nn.ReLU(inplace=True),
            torch.nn.Conv2d(64, 64, kernel_size=3, padding=1),
            torch.nn.ReLU(inplace=True)
        )

        self.decoder = torch.nn.Sequential(
            torch.nn.Conv2d(64, 32, kernel_size=3, padding=1),
            torch.nn.ReLU(inplace=True),
            torch.nn.Conv2d(32, 1, kernel_size=1)
        )

    def forward(self, x):
        features = self.encoder(x)
        output = self.decoder(features)
        return [output]
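
# Illustrative shape sketch (not executed): the class above appears to be a
# lightweight placeholder; the matting path below loads the published
# ZhengPeng7/BiRefNet weights instead. The interface it mirrors is:
#
#     model = BiRefNet(BiRefNetConfig())
#     x = torch.randn(1, 3, 256, 256)   # [B, 3, H, W] RGB batch
#     logits = model(x)[-1]             # [B, 1, H, W] single-channel map
#     mask = logits.sigmoid()           # probabilities in [0, 1]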

class LayerForgeNode:
    _canvas_data_storage = {}
    _storage_lock = threading.Lock()

    _canvas_cache = {
        'image': None,
        'mask': None,
        'data_flow_status': {},
        'persistent_cache': {},
        'last_execution_id': None
    }

    _websocket_data = {}
    _websocket_listeners = {}

    def __init__(self):
        super().__init__()
        self.flow_id = str(uuid.uuid4())
        self.node_id = None  # Will be set when node is created

        if self.__class__._canvas_cache['persistent_cache']:
            self.restore_cache()

    def restore_cache(self):
        try:
            persistent = self.__class__._canvas_cache['persistent_cache']
            current_execution = self.get_execution_id()

            if current_execution != self.__class__._canvas_cache['last_execution_id']:
                log_info(f"New execution detected: {current_execution}")
                self.__class__._canvas_cache['image'] = None
                self.__class__._canvas_cache['mask'] = None
                self.__class__._canvas_cache['last_execution_id'] = current_execution
            else:
                if persistent.get('image') is not None:
                    self.__class__._canvas_cache['image'] = persistent['image']
                    log_info("Restored image from persistent cache")
                if persistent.get('mask') is not None:
                    self.__class__._canvas_cache['mask'] = persistent['mask']
                    log_info("Restored mask from persistent cache")
        except Exception as e:
            log_error(f"Error restoring cache: {str(e)}")

    def get_execution_id(self):
        try:
            return str(int(time.time() * 1000))
        except Exception as e:
            log_error(f"Error getting execution ID: {str(e)}")
            return None

    def update_persistent_cache(self):
        try:
            self.__class__._canvas_cache['persistent_cache'] = {
                'image': self.__class__._canvas_cache['image'],
                'mask': self.__class__._canvas_cache['mask']
            }
            log_debug("Updated persistent cache")
        except Exception as e:
            log_error(f"Error updating persistent cache: {str(e)}")

    def track_data_flow(self, stage, status, data_info=None):
        flow_status = {
            'timestamp': time.time(),
            'stage': stage,
            'status': status,
            'data_info': data_info
        }
        log_debug(f"Data Flow [{self.flow_id}] - Stage: {stage}, Status: {status}")
        if data_info:
            log_debug(f"Data Info: {data_info}")

        self.__class__._canvas_cache['data_flow_status'][self.flow_id] = flow_status

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "fit_on_add": ("BOOLEAN", {"default": False, "label_on": "Fit on Add/Paste", "label_off": "Default Behavior"}),
                "show_preview": ("BOOLEAN", {"default": False, "label_on": "Show Preview", "label_off": "Hide Preview"}),
                "auto_refresh_after_generation": ("BOOLEAN", {"default": False, "label_on": "True", "label_off": "False"}),
                "trigger": ("INT", {"default": 0, "min": 0, "max": 99999999, "step": 1}),
                "node_id": ("STRING", {"default": "0"}),
            },
            "optional": {
                "input_image": ("IMAGE",),
                "input_mask": ("MASK",),
            },
            "hidden": {
                "prompt": ("PROMPT",),
                "unique_id": ("UNIQUE_ID",),
            }
        }

    RETURN_TYPES = ("IMAGE", "MASK")
    RETURN_NAMES = ("image", "mask")
    FUNCTION = "process_canvas_image"
    CATEGORY = "azNodes > LayerForge"

    def add_image_to_canvas(self, input_image):
        try:
            if not isinstance(input_image, torch.Tensor):
                raise ValueError("Input image must be a torch.Tensor")

            if input_image.dim() == 4:
                input_image = input_image.squeeze(0)

            if input_image.dim() == 3 and input_image.shape[0] in [1, 3]:
                input_image = input_image.permute(1, 2, 0)

            return input_image

        except Exception as e:
            log_error(f"Error in add_image_to_canvas: {str(e)}")
            return None

    def add_mask_to_canvas(self, input_mask, input_image):
        try:
            if not isinstance(input_mask, torch.Tensor):
                raise ValueError("Input mask must be a torch.Tensor")

            if input_mask.dim() == 4:
                input_mask = input_mask.squeeze(0)
            if input_mask.dim() == 3 and input_mask.shape[0] == 1:
                input_mask = input_mask.squeeze(0)

            if input_image is not None:
                expected_shape = input_image.shape[:2]
                if input_mask.shape != expected_shape:
                    input_mask = F.interpolate(
                        input_mask.unsqueeze(0).unsqueeze(0),
                        size=expected_shape,
                        mode='bilinear',
                        align_corners=False
                    ).squeeze()

            return input_mask

        except Exception as e:
            log_error(f"Error in add_mask_to_canvas: {str(e)}")
            return None

    _processing_lock = threading.Lock()

    def process_canvas_image(self, fit_on_add, show_preview, auto_refresh_after_generation, trigger, node_id, input_image=None, input_mask=None, prompt=None, unique_id=None):
        # Track whether this call actually acquired the lock so the finally
        # block never releases a lock held by another thread.
        lock_acquired = self.__class__._processing_lock.acquire(blocking=False)
        try:
            if not lock_acquired:
                log_warn(f"Process already in progress for node {node_id}, skipping...")
                return self.get_cached_data()

            log_info(f"Lock acquired. Starting process_canvas_image for node_id: {node_id} (fallback unique_id: {unique_id})")

            # Always store fresh input data, even if None, to clear stale data
            log_info(f"Storing input data for node {node_id} - Image: {input_image is not None}, Mask: {input_mask is not None}")

            with self.__class__._storage_lock:
                input_data = {}

                if input_image is not None:
                    # Convert image tensor(s) to base64 - handle batch
                    if isinstance(input_image, torch.Tensor):
                        # Ensure correct shape [B, H, W, C]
                        if input_image.dim() == 3:
                            input_image = input_image.unsqueeze(0)

                        batch_size = input_image.shape[0]
                        log_info(f"Processing batch of {batch_size} image(s)")

                        if batch_size == 1:
                            # Single image - keep backward compatibility
                            img_np = (input_image.squeeze(0).cpu().numpy() * 255).astype(np.uint8)
                            pil_img = Image.fromarray(img_np, 'RGB')

                            # Convert to base64
                            buffered = io.BytesIO()
                            pil_img.save(buffered, format="PNG")
                            img_str = base64.b64encode(buffered.getvalue()).decode()
                            input_data['input_image'] = f"data:image/png;base64,{img_str}"
                            input_data['input_image_width'] = pil_img.width
                            input_data['input_image_height'] = pil_img.height
                            log_debug(f"Stored single input image: {pil_img.width}x{pil_img.height}")
                        else:
                            # Multiple images - store as array
                            images_array = []
                            for i in range(batch_size):
                                img_np = (input_image[i].cpu().numpy() * 255).astype(np.uint8)
                                pil_img = Image.fromarray(img_np, 'RGB')

                                # Convert to base64
                                buffered = io.BytesIO()
                                pil_img.save(buffered, format="PNG")
                                img_str = base64.b64encode(buffered.getvalue()).decode()
                                images_array.append({
                                    'data': f"data:image/png;base64,{img_str}",
                                    'width': pil_img.width,
                                    'height': pil_img.height
                                })
                                log_debug(f"Stored batch image {i+1}/{batch_size}: {pil_img.width}x{pil_img.height}")

                            input_data['input_images_batch'] = images_array
                            log_info(f"Stored batch of {batch_size} images")

                if input_mask is not None:
                    # Convert mask tensor to base64
                    if isinstance(input_mask, torch.Tensor):
                        # Ensure correct shape
                        if input_mask.dim() == 2:
                            input_mask = input_mask.unsqueeze(0)
                        if input_mask.dim() == 3 and input_mask.shape[0] == 1:
                            input_mask = input_mask.squeeze(0)

                        # Convert to numpy and then to PIL
                        mask_np = (input_mask.cpu().numpy() * 255).astype(np.uint8)
                        pil_mask = Image.fromarray(mask_np, 'L')

                        # Convert to base64
                        mask_buffered = io.BytesIO()
                        pil_mask.save(mask_buffered, format="PNG")
                        mask_str = base64.b64encode(mask_buffered.getvalue()).decode()
                        input_data['input_mask'] = f"data:image/png;base64,{mask_str}"
                        log_debug(f"Stored input mask: {pil_mask.width}x{pil_mask.height}")

                input_data['fit_on_add'] = fit_on_add

                # Store in a special key for input data (overwrites any previous data)
                self.__class__._canvas_data_storage[f"{node_id}_input"] = input_data

            storage_key = node_id

            processed_image = None
            processed_mask = None

            with self.__class__._storage_lock:
                canvas_data = self.__class__._canvas_data_storage.pop(storage_key, None)

            if canvas_data:
                log_info(f"Canvas data found for node {storage_key} from WebSocket")
                if canvas_data.get('image'):
                    image_data = canvas_data['image'].split(',')[1]
                    image_bytes = base64.b64decode(image_data)
                    pil_image = Image.open(io.BytesIO(image_bytes)).convert('RGB')
                    image_array = np.array(pil_image).astype(np.float32) / 255.0
                    processed_image = torch.from_numpy(image_array)[None,]
                    log_debug(f"Image loaded from WebSocket, shape: {processed_image.shape}")

                if canvas_data.get('mask'):
                    mask_data = canvas_data['mask'].split(',')[1]
                    mask_bytes = base64.b64decode(mask_data)
                    pil_mask = Image.open(io.BytesIO(mask_bytes)).convert('L')
                    mask_array = np.array(pil_mask).astype(np.float32) / 255.0
                    processed_mask = torch.from_numpy(mask_array)[None,]
                    log_debug(f"Mask loaded from WebSocket, shape: {processed_mask.shape}")
            else:
                log_warn(f"No canvas data found for node {storage_key} in WebSocket cache.")

            if processed_image is None:
                log_warn("Processed image is still None, creating default blank image.")
                processed_image = torch.zeros((1, 512, 512, 3), dtype=torch.float32)
            if processed_mask is None:
                log_warn("Processed mask is still None, creating default blank mask.")
                processed_mask = torch.zeros((1, 512, 512), dtype=torch.float32)

            log_debug(f"About to return output - Image shape: {processed_image.shape}, Mask shape: {processed_mask.shape}")

            self.update_persistent_cache()

            log_info("Successfully returning processed image and mask")
            return (processed_image, processed_mask)

        except Exception as e:
            log_exception(f"Error in process_canvas_image: {str(e)}")
            return (None, None)

        finally:
            if lock_acquired:
                self.__class__._processing_lock.release()
                log_debug(f"Process completed for node {node_id}, lock released")

    def get_cached_data(self):
        # Return the cached pair in the same (image, mask) tuple form as
        # process_canvas_image so it can serve as a fallback result.
        return (
            self.__class__._canvas_cache['image'],
            self.__class__._canvas_cache['mask'],
        )

    @classmethod
    def api_get_data(cls, node_id):
        try:
            return {
                'success': True,
                'data': cls._canvas_cache
            }
        except Exception as e:
            return {
                'success': False,
                'error': str(e)
            }

    @classmethod
    def get_latest_image(cls):
        output_dir = folder_paths.get_output_directory()
        files = [os.path.join(output_dir, f) for f in os.listdir(output_dir) if
                 os.path.isfile(os.path.join(output_dir, f))]

        image_files = [f for f in files if f.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp', '.gif'))]

        if not image_files:
            return None

        latest_image_path = max(image_files, key=os.path.getctime)
        return latest_image_path

    @classmethod
    def get_latest_images(cls, since_timestamp=0):
        output_dir = folder_paths.get_output_directory()
        files = []
        for f_name in os.listdir(output_dir):
            file_path = os.path.join(output_dir, f_name)
            if os.path.isfile(file_path) and file_path.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp', '.gif')):
                try:
                    mtime = os.path.getmtime(file_path)
                    if mtime > since_timestamp:
                        files.append((mtime, file_path))
                except OSError:
                    continue

        files.sort(key=lambda x: x[0])

        return [f[1] for f in files]

    @classmethod
    def get_flow_status(cls, flow_id=None):
        if flow_id:
            return cls._canvas_cache['data_flow_status'].get(flow_id)
        return cls._canvas_cache['data_flow_status']

    @classmethod
    def _cleanup_old_websocket_data(cls):
        """Clean up old WebSocket data from invalid nodes or data older than 5 minutes"""
        try:
            current_time = time.time()
            cleanup_threshold = 300  # 5 minutes

            nodes_to_remove = []
            for node_id, data in cls._websocket_data.items():
                # Ids below zero are treated as invalid; guard the comparison
                # so string keys don't raise a TypeError.
                if isinstance(node_id, int) and node_id < 0:
                    nodes_to_remove.append(node_id)
                    continue

                if current_time - data.get('timestamp', 0) > cleanup_threshold:
                    nodes_to_remove.append(node_id)
                    continue

            for node_id in nodes_to_remove:
                del cls._websocket_data[node_id]
                log_debug(f"Cleaned up old WebSocket data for node {node_id}")

            if nodes_to_remove:
                log_info(f"Cleaned up {len(nodes_to_remove)} old WebSocket entries")

        except Exception as e:
            log_error(f"Error during WebSocket cleanup: {str(e)}")

    @classmethod
    def setup_routes(cls):
        @PromptServer.instance.routes.get("/layerforge/canvas_ws")
        async def handle_canvas_websocket(request):
            ws = web.WebSocketResponse(max_msg_size=33554432)  # 32 MiB
            await ws.prepare(request)

            async for msg in ws:
                if msg.type == web.WSMsgType.TEXT:
                    try:
                        data = msg.json()
                        node_id = data.get('nodeId')
                        if not node_id:
                            await ws.send_json({'status': 'error', 'message': 'nodeId is required'})
                            continue

                        image_data = data.get('image')
                        mask_data = data.get('mask')

                        with cls._storage_lock:
                            cls._canvas_data_storage[node_id] = {
                                'image': image_data,
                                'mask': mask_data,
                                'timestamp': time.time()
                            }

                        log_info(f"Received canvas data for node {node_id} via WebSocket")

                        ack_payload = {
                            'type': 'ack',
                            'nodeId': node_id,
                            'status': 'success'
                        }
                        await ws.send_json(ack_payload)
                        log_debug(f"Sent ACK for node {node_id}")

                    except Exception as e:
                        log_error(f"Error processing WebSocket message: {e}")
                        await ws.send_json({'status': 'error', 'message': str(e)})
                elif msg.type == web.WSMsgType.ERROR:
                    log_error(f"WebSocket connection closed with exception {ws.exception()}")

            log_info("WebSocket connection closed")
            return ws
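
        # Expected client message (illustrative; field names match the
        # handler above):
        #
        #     {"nodeId": "12",
        #      "image": "data:image/png;base64,...",
        #      "mask": "data:image/png;base64,..."}
        #
        # The server acknowledges with
        # {"type": "ack", "nodeId": ..., "status": "success"}.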

        @PromptServer.instance.routes.get("/layerforge/get_input_data/{node_id}")
        async def get_input_data(request):
            try:
                node_id = request.match_info["node_id"]
                log_debug(f"Checking for input data for node: {node_id}")

                with cls._storage_lock:
                    input_key = f"{node_id}_input"
                    input_data = cls._canvas_data_storage.get(input_key, None)

                if input_data:
                    log_info(f"Input data found for node {node_id}, sending to frontend")
                    return web.json_response({
                        'success': True,
                        'has_input': True,
                        'data': input_data
                    })
                else:
                    log_debug(f"No input data found for node {node_id}")
                    return web.json_response({
                        'success': True,
                        'has_input': False
                    })

            except Exception as e:
                log_error(f"Error in get_input_data: {str(e)}")
                return web.json_response({
                    'success': False,
                    'error': str(e)
                }, status=500)

        @PromptServer.instance.routes.post("/layerforge/clear_input_data/{node_id}")
        async def clear_input_data(request):
            try:
                node_id = request.match_info["node_id"]
                log_info(f"Clearing input data for node: {node_id}")

                with cls._storage_lock:
                    input_key = f"{node_id}_input"
                    if input_key in cls._canvas_data_storage:
                        del cls._canvas_data_storage[input_key]
                        log_info(f"Input data cleared for node {node_id}")
                    else:
                        log_debug(f"No input data to clear for node {node_id}")

                return web.json_response({
                    'success': True,
                    'message': f'Input data cleared for node {node_id}'
                })

            except Exception as e:
                log_error(f"Error in clear_input_data: {str(e)}")
                return web.json_response({
                    'success': False,
                    'error': str(e)
                }, status=500)

        @PromptServer.instance.routes.get("/ycnode/get_canvas_data/{node_id}")
        async def get_canvas_data(request):
            try:
                node_id = request.match_info["node_id"]
                log_debug(f"Received request for node: {node_id}")

                cache_data = cls._canvas_cache
                log_debug(f"Cache content: {cache_data}")
                log_debug(f"Image in cache: {cache_data['image'] is not None}")

                response_data = {
                    'success': True,
                    'data': {
                        'image': None,
                        'mask': None
                    }
                }

                if cache_data['image'] is not None:
                    pil_image = cache_data['image']
                    buffered = io.BytesIO()
                    pil_image.save(buffered, format="PNG")
                    img_str = base64.b64encode(buffered.getvalue()).decode()
                    response_data['data']['image'] = f"data:image/png;base64,{img_str}"

                if cache_data['mask'] is not None:
                    pil_mask = cache_data['mask']
                    mask_buffer = io.BytesIO()
                    pil_mask.save(mask_buffer, format="PNG")
                    mask_str = base64.b64encode(mask_buffer.getvalue()).decode()
                    response_data['data']['mask'] = f"data:image/png;base64,{mask_str}"

                return web.json_response(response_data)

            except Exception as e:
                log_error(f"Error in get_canvas_data: {str(e)}")
                return web.json_response({
                    'success': False,
                    'error': str(e)
                })

        @PromptServer.instance.routes.get("/layerforge/get-latest-images/{since}")
        async def get_latest_images_route(request):
            try:
                since_timestamp = float(request.match_info.get('since', 0))
                # JS timestamps are in milliseconds, Python's are in seconds
                latest_image_paths = cls.get_latest_images(since_timestamp / 1000.0)

                images_data = []
                for image_path in latest_image_paths:
                    with open(image_path, "rb") as f:
                        encoded_string = base64.b64encode(f.read()).decode('utf-8')
                        images_data.append(f"data:image/png;base64,{encoded_string}")

                return web.json_response({
                    'success': True,
                    'images': images_data
                })
            except Exception as e:
                log_error(f"Error in get_latest_images_route: {str(e)}")
                return web.json_response({
                    'success': False,
                    'error': str(e)
                }, status=500)

        @PromptServer.instance.routes.get("/ycnode/get_latest_image")
        async def get_latest_image_route(request):
            try:
                latest_image_path = cls.get_latest_image()
                if latest_image_path:
                    with open(latest_image_path, "rb") as f:
                        encoded_string = base64.b64encode(f.read()).decode('utf-8')
                    return web.json_response({
                        'success': True,
                        'image_data': f"data:image/png;base64,{encoded_string}"
                    })
                else:
                    return web.json_response({
                        'success': False,
                        'error': 'No images found in output directory.'
                    }, status=404)
            except Exception as e:
                return web.json_response({
                    'success': False,
                    'error': str(e)
                }, status=500)

        @PromptServer.instance.routes.post("/ycnode/load_image_from_path")
        async def load_image_from_path_route(request):
            try:
                data = await request.json()
                file_path = data.get('file_path')

                if not file_path:
                    return web.json_response({
                        'success': False,
                        'error': 'file_path is required'
                    }, status=400)

                log_info(f"Attempting to load image from path: {file_path}")

                # Check if file exists and is accessible
                if not os.path.exists(file_path):
                    log_warn(f"File not found: {file_path}")
                    return web.json_response({
                        'success': False,
                        'error': f'File not found: {file_path}'
                    }, status=404)

                # Check if it's an image file
                valid_extensions = ('.png', '.jpg', '.jpeg', '.gif', '.bmp', '.webp', '.tiff', '.tif', '.ico', '.avif')
                if not file_path.lower().endswith(valid_extensions):
                    return web.json_response({
                        'success': False,
                        'error': f'Invalid image file extension. Supported: {valid_extensions}'
                    }, status=400)

                # Try to load and convert the image
                try:
                    with Image.open(file_path) as img:
                        # Convert to RGB if necessary
                        if img.mode != 'RGB':
                            img = img.convert('RGB')

                        # Convert to base64
                        buffered = io.BytesIO()
                        img.save(buffered, format="PNG")
                        img_str = base64.b64encode(buffered.getvalue()).decode('utf-8')

                        log_info(f"Successfully loaded image from path: {file_path}")
                        return web.json_response({
                            'success': True,
                            'image_data': f"data:image/png;base64,{img_str}",
                            'width': img.width,
                            'height': img.height
                        })

                except Exception as img_error:
                    log_error(f"Error processing image file {file_path}: {str(img_error)}")
                    return web.json_response({
                        'success': False,
                        'error': f'Error processing image file: {str(img_error)}'
                    }, status=500)

            except Exception as e:
                log_error(f"Error in load_image_from_path_route: {str(e)}")
                return web.json_response({
                    'success': False,
                    'error': str(e)
                }, status=500)

    def store_image(self, image_data):
        if isinstance(image_data, str) and image_data.startswith('data:image'):
            image_data = image_data.split(',')[1]
            image_bytes = base64.b64decode(image_data)
            self.cached_image = Image.open(io.BytesIO(image_bytes))
        else:
            self.cached_image = image_data

    def get_cached_image(self):
        # cached_image is only set by store_image, so fall back to None
        # instead of raising AttributeError on a fresh instance.
        if getattr(self, 'cached_image', None):
            buffered = io.BytesIO()
            self.cached_image.save(buffered, format="PNG")
            img_str = base64.b64encode(buffered.getvalue()).decode()
            return f"data:image/png;base64,{img_str}"
        return None


class BiRefNetMatting:
    def __init__(self):
        self.model = None
        self.model_path = None
        self.model_cache = {}

        self.base_path = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))),
                                      "models")

    def load_model(self, model_path):
        from json.decoder import JSONDecodeError
        try:
            if model_path not in self.model_cache:
                # Note: model_path is used only as a cache key here; the
                # weights themselves are resolved through the Hugging Face
                # cache directory below.
                full_model_path = os.path.join(self.base_path, "BiRefNet")
                log_info(f"Loading BiRefNet model from {full_model_path}...")
                try:
                    # Try loading with additional configuration to handle compatibility issues
                    self.model = AutoModelForImageSegmentation.from_pretrained(
                        "ZhengPeng7/BiRefNet",
                        trust_remote_code=True,
                        cache_dir=full_model_path,
                        # Use the cached version if available
                        force_download=False,
                        # Allow downloading if needed
                        local_files_only=False
                    )
                    self.model.eval()
                    if torch.cuda.is_available():
                        self.model = self.model.cuda()
                    self.model_cache[model_path] = self.model
                    log_info("Model loaded successfully from Hugging Face")
                except AttributeError as e:
                    if "'Config' object has no attribute 'is_encoder_decoder'" in str(e):
                        log_error("Compatibility issue detected with transformers library. This has been fixed in the code.")
                        log_error("If you're still seeing this error, please clear the model cache and try again.")
                        raise RuntimeError(
                            "Model configuration compatibility issue detected. "
                            f"Please delete the model cache directory '{full_model_path}' and restart ComfyUI. "
                            "This will download a fresh copy of the model with the updated configuration."
                        ) from e
                    else:
                        raise
                except JSONDecodeError as e:
                    log_error(f"JSONDecodeError: Failed to load model from {full_model_path}. The model's config.json may be corrupted.")
                    raise RuntimeError(
                        "The matting model's configuration file (config.json) appears to be corrupted. "
                        f"Please manually delete the directory '{full_model_path}' and try again. "
                        "This will force a fresh download of the model."
                    ) from e
                except Exception as e:
                    log_error(f"Failed to load model from Hugging Face: {str(e)}")
                    # Re-raise with a more informative message
                    raise RuntimeError(
                        "Failed to download or load the matting model. "
                        "This could be due to a network issue, file permissions, or a corrupted model cache. "
                        f"Please check your internet connection and the model cache path: {full_model_path}. "
                        f"Original error: {str(e)}"
                    ) from e
            else:
                self.model = self.model_cache[model_path]
                log_debug("Using cached model")

        except Exception as e:
            # Catch the re-raised exception or any other error
            log_error(f"Error loading model: {str(e)}")
            log_exception("Model loading failed")
            raise  # Re-raise the exception to be caught by the execute method

    def preprocess_image(self, image):
        try:
            if isinstance(image, torch.Tensor):
                if image.dim() == 4:
                    image = image.squeeze(0)
                if image.dim() == 3:
                    image = transforms.ToPILImage()(image)

            # Resize to the model's input resolution and apply ImageNet
            # mean/std normalization.
            transform_image = transforms.Compose([
                transforms.Resize((1024, 1024)),
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
            ])

            image_tensor = transform_image(image).unsqueeze(0)

            if torch.cuda.is_available():
                image_tensor = image_tensor.cuda()

            return image_tensor
        except Exception as e:
            log_error(f"Error preprocessing image: {str(e)}")
            return None

    def execute(self, image, model_path, threshold=0.5, refinement=1):
        try:
            PromptServer.instance.send_sync("matting_status", {"status": "processing"})

            self.load_model(model_path)

            if isinstance(image, torch.Tensor):
                # Assumes a channels-first layout ([B, C, H, W] or [C, H, W]).
                original_size = image.shape[-2:]
            else:
                original_size = image.size[::-1]

            log_debug(f"Original size: {original_size}")

            processed_image = self.preprocess_image(image)
            if processed_image is None:
                raise Exception("Failed to preprocess image")

            log_debug(f"Processed image shape: {processed_image.shape}")

            with torch.no_grad():
                outputs = self.model(processed_image)
                result = outputs[-1].sigmoid().cpu()
                log_debug(f"Model output shape: {result.shape}")

                if result.dim() == 3:
                    result = result.unsqueeze(1)  # add a channel dimension
                elif result.dim() == 2:
                    result = result.unsqueeze(0).unsqueeze(0)  # add batch and channel dimensions

                log_debug(f"Reshaped result shape: {result.shape}")

                result = F.interpolate(
                    result,
                    size=(original_size[0], original_size[1]),  # height and width, explicitly
                    mode='bilinear',
                    align_corners=True
                )
                log_debug(f"Resized result shape: {result.shape}")

                result = result.squeeze()  # drop the redundant dimensions
                # Min-max normalize to [0, 1]; guard against a constant map,
                # which would otherwise divide by zero.
                ma = torch.max(result)
                mi = torch.min(result)
                result = (result - mi) / (ma - mi) if ma > mi else torch.zeros_like(result)

                if threshold > 0:
                    result = (result > threshold).float()

                alpha_mask = result.unsqueeze(0).unsqueeze(0)  # ensure the mask is [1, 1, H, W]
                if isinstance(image, torch.Tensor):
                    if image.dim() == 3:
                        image = image.unsqueeze(0)
                    masked_image = image * alpha_mask
                else:
                    image_tensor = transforms.ToTensor()(image).unsqueeze(0)
                    masked_image = image_tensor * alpha_mask

            PromptServer.instance.send_sync("matting_status", {"status": "completed"})

            return (masked_image, alpha_mask)

        except Exception:
            PromptServer.instance.send_sync("matting_status", {"status": "error"})
            raise
    @classmethod
    def IS_CHANGED(cls, image, model_path, threshold, refinement):
        m = hashlib.md5()
        m.update(str(image).encode())
        m.update(str(model_path).encode())
        m.update(str(threshold).encode())
        m.update(str(refinement).encode())
        return m.hexdigest()


# Module-level guard so only one matting request runs at a time.
_matting_lock = None
@PromptServer.instance.routes.get("/matting/check-model")
|
|
async def check_matting_model(request):
|
|
"""Check if the matting model is available and ready to use"""
|
|
try:
|
|
if not TRANSFORMERS_AVAILABLE:
|
|
return web.json_response({
|
|
"available": False,
|
|
"reason": "missing_dependency",
|
|
"message": "The 'transformers' library is required for the matting feature. Please install it by running: pip install transformers"
|
|
})
|
|
|
|
# Check if model exists in cache
|
|
base_path = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))), "models")
|
|
model_path = os.path.join(base_path, "BiRefNet")
|
|
|
|
# Look for the actual BiRefNet model structure
|
|
model_files_exist = False
|
|
if os.path.exists(model_path):
|
|
# BiRefNet model from Hugging Face has a specific structure
|
|
# Check for subdirectories that indicate the model is downloaded
|
|
existing_items = os.listdir(model_path) if os.path.isdir(model_path) else []
|
|
|
|
# Look for the model subdirectory (usually named with the model ID)
|
|
model_subdirs = [d for d in existing_items if os.path.isdir(os.path.join(model_path, d)) and
|
|
(d.startswith("models--") or d == "ZhengPeng7--BiRefNet")]
|
|
|
|
if model_subdirs:
|
|
# Found model subdirectory, check inside for actual model files
|
|
for subdir in model_subdirs:
|
|
subdir_path = os.path.join(model_path, subdir)
|
|
# Navigate through the cache structure
|
|
if os.path.exists(os.path.join(subdir_path, "snapshots")):
|
|
snapshots_path = os.path.join(subdir_path, "snapshots")
|
|
snapshot_dirs = os.listdir(snapshots_path) if os.path.isdir(snapshots_path) else []
|
|
|
|
for snapshot in snapshot_dirs:
|
|
snapshot_path = os.path.join(snapshots_path, snapshot)
|
|
snapshot_files = os.listdir(snapshot_path) if os.path.isdir(snapshot_path) else []
|
|
|
|
# Check for essential files - BiRefNet uses model.safetensors
|
|
has_config = "config.json" in snapshot_files
|
|
has_model = "model.safetensors" in snapshot_files or "pytorch_model.bin" in snapshot_files
|
|
has_backbone = "backbone_swin.pth" in snapshot_files or "swin_base_patch4_window12_384_22kto1k.pth" in snapshot_files
|
|
has_birefnet = "birefnet.pth" in snapshot_files or any(f.endswith(".pth") for f in snapshot_files)
|
|
|
|
# Model is valid if it has config and either model.safetensors or other model files
|
|
if has_config and (has_model or has_backbone or has_birefnet):
|
|
model_files_exist = True
|
|
log_info(f"Found model files in: {snapshot_path} (config: {has_config}, model: {has_model})")
|
|
break
|
|
|
|
if model_files_exist:
|
|
break
|
|
|
|
# Also check if there are .pth files directly in the model_path
|
|
if not model_files_exist:
|
|
direct_files = existing_items
|
|
has_config = "config.json" in direct_files
|
|
has_model_files = any(f.endswith((".pth", ".bin", ".safetensors")) for f in direct_files)
|
|
model_files_exist = has_config and has_model_files
|
|
|
|
if model_files_exist:
|
|
log_info(f"Found model files directly in: {model_path}")
|
|
|
|
if model_files_exist:
|
|
# Model files exist, assume it's ready
|
|
log_info("BiRefNet model files detected")
|
|
return web.json_response({
|
|
"available": True,
|
|
"reason": "ready",
|
|
"message": "Model is ready to use"
|
|
})
|
|
else:
|
|
log_info(f"BiRefNet model not found in {model_path}")
|
|
return web.json_response({
|
|
"available": False,
|
|
"reason": "not_downloaded",
|
|
"message": "The matting model needs to be downloaded. This will happen automatically when you first use the matting feature (requires internet connection).",
|
|
"model_path": model_path
|
|
})
|
|
|
|
except Exception as e:
|
|
log_error(f"Error checking matting model: {str(e)}")
|
|
return web.json_response({
|
|
"available": False,
|
|
"reason": "error",
|
|
"message": f"Error checking model status: {str(e)}"
|
|
}, status=500)
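
# Response contract for /matting/check-model (as produced above): the frontend
# keys off "available" and "reason" ("missing_dependency", "not_downloaded",
# "ready", or "error") to decide whether to show a download notification.
#
#     {"available": true, "reason": "ready", "message": "Model is ready to use"}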

@PromptServer.instance.routes.post("/matting")
async def matting(request):
    global _matting_lock

    if not TRANSFORMERS_AVAILABLE:
        log_error("Matting request failed: 'transformers' library is not installed.")
        return web.json_response({
            "error": "Dependency Not Found",
            "details": "The 'transformers' library is required for the matting feature. Please install it by running: pip install transformers"
        }, status=400)

    if _matting_lock is not None:
        log_warn("Matting already in progress, rejecting request")
        return web.json_response({
            "error": "Another matting operation is in progress",
            "details": "Please wait for the current operation to complete"
        }, status=429)

    _matting_lock = True
    try:
        log_info("Received matting request")
        data = await request.json()

        # A fresh instance per request; previously downloaded weights are
        # still reused through the local Hugging Face cache directory.
        matting_instance = BiRefNetMatting()

        image_tensor, original_alpha = convert_base64_to_tensor(data["image"])
        log_debug(f"Input image shape: {image_tensor.shape}")

        matted_image, alpha_mask = matting_instance.execute(
            image_tensor,
            "BiRefNet/model.safetensors",
            threshold=data.get("threshold", 0.5),
            refinement=data.get("refinement", 1)
        )

        result_image = convert_tensor_to_base64(matted_image, alpha_mask, original_alpha)
        result_mask = convert_tensor_to_base64(alpha_mask)

        return web.json_response({
            "matted_image": result_image,
            "alpha_mask": result_mask
        })

    except RequestsConnectionError as e:
        log_error(f"Connection error during matting model download: {e}")
        return web.json_response({
            "error": "Network Connection Error",
            "details": "Failed to download the matting model from Hugging Face. Please check your internet connection."
        }, status=400)
    except RuntimeError as e:
        log_error(f"Runtime error during matting: {e}")
        return web.json_response({
            "error": "Matting Model Error",
            "details": str(e)
        }, status=500)
    except Exception as e:
        log_exception(f"Error in matting endpoint: {e}")
        # Check for offline error message from Hugging Face
        if "Offline mode is enabled" in str(e) or "Can't load 'ZhengPeng7/BiRefNet' offline" in str(e):
            return web.json_response({
                "error": "Network Connection Error",
                "details": "Failed to download the matting model from Hugging Face. Please check your internet connection and ensure you are not in offline mode."
            }, status=400)

        return web.json_response({
            "error": "An unexpected error occurred",
            "details": traceback.format_exc()
        }, status=500)
    finally:
        _matting_lock = None
        log_debug("Matting lock released")

def convert_base64_to_tensor(base64_str):
    try:
        img_data = base64.b64decode(base64_str.split(',')[1])
        img = Image.open(io.BytesIO(img_data))

        has_alpha = img.mode == 'RGBA'
        alpha = None
        if has_alpha:
            # Keep the original alpha channel and flatten the RGB onto white.
            alpha = img.split()[3]

            background = Image.new('RGB', img.size, (255, 255, 255))
            background.paste(img, mask=alpha)
            img = background
        elif img.mode != 'RGB':
            img = img.convert('RGB')

        transform = transforms.ToTensor()
        img_tensor = transform(img).unsqueeze(0)  # [1, C, H, W]

        if has_alpha:
            alpha_tensor = transforms.ToTensor()(alpha).unsqueeze(0)  # [1, 1, H, W]
            return img_tensor, alpha_tensor

        return img_tensor, None

    except Exception as e:
        log_error(f"Error in convert_base64_to_tensor: {str(e)}")
        raise

def convert_tensor_to_base64(tensor, alpha_mask=None, original_alpha=None):
    try:
        tensor = tensor.cpu()

        if tensor.dim() == 4:
            tensor = tensor.squeeze(0)  # remove the batch dimension
        if tensor.dim() == 3 and tensor.shape[0] in [1, 3]:
            tensor = tensor.permute(1, 2, 0)

        img_array = (tensor.numpy() * 255).astype(np.uint8)

        if alpha_mask is not None and original_alpha is not None:
            alpha_mask = alpha_mask.cpu().squeeze().numpy()
            alpha_mask = (alpha_mask * 255).astype(np.uint8)

            original_alpha = original_alpha.cpu().squeeze().numpy()
            original_alpha = (original_alpha * 255).astype(np.uint8)

            # Combine the predicted mask with the image's original alpha so
            # transparency that was already present is preserved.
            combined_alpha = np.minimum(alpha_mask, original_alpha)

            img = Image.fromarray(img_array, mode='RGB')
            alpha_img = Image.fromarray(combined_alpha, mode='L')
            img.putalpha(alpha_img)
        else:
            if img_array.shape[-1] == 1:
                img_array = img_array.squeeze(-1)
                img = Image.fromarray(img_array, mode='L')
            else:
                img = Image.fromarray(img_array, mode='RGB')

        buffer = io.BytesIO()
        img.save(buffer, format='PNG')
        img_str = base64.b64encode(buffer.getvalue()).decode()

        return f"data:image/png;base64,{img_str}"

    except Exception as e:
        log_error(f"Error in convert_tensor_to_base64: {str(e)}")
        log_debug(f"Tensor shape: {tensor.shape}, dtype: {tensor.dtype}")
        raise
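
if __name__ == "__main__":
    # Minimal local sanity check for the two converters above. This is a
    # sketch: the module is normally imported by ComfyUI, not run directly,
    # and the block only works where the imports at the top (folder_paths,
    # server) resolve.
    demo = Image.new('RGB', (4, 4), (200, 30, 30))
    buf = io.BytesIO()
    demo.save(buf, format='PNG')
    data_url = "data:image/png;base64," + base64.b64encode(buf.getvalue()).decode()

    tensor, alpha = convert_base64_to_tensor(data_url)
    print("tensor shape:", tuple(tensor.shape), "alpha:", alpha)  # (1, 3, 4, 4) None

    round_trip = convert_tensor_to_base64(tensor)
    print("round trip ok:", round_trip.startswith("data:image/png;base64,"))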