diff --git a/canvas_node.py b/canvas_node.py index d11f715..d2efe80 100644 --- a/canvas_node.py +++ b/canvas_node.py @@ -92,7 +92,7 @@ class BiRefNet(torch.nn.Module): return [output] -class LayerForgeNode: +class LayerForgeNode: _canvas_data_storage = {} _storage_lock = threading.Lock() @@ -731,42 +731,139 @@ class LayerForgeNode: else: self.cached_image = image_data - def get_cached_image(self): - - if self.cached_image: - buffered = io.BytesIO() - self.cached_image.save(buffered, format="PNG") - img_str = base64.b64encode(buffered.getvalue()).decode() - return f"data:image/png;base64,{img_str}" - return None - - -class BiRefNetMatting: - def __init__(self): - self.model = None - self.model_path = None - self.model_cache = {} - - self.base_path = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))), - "models") - - def load_model(self, model_path): - from json.decoder import JSONDecodeError - try: - if model_path not in self.model_cache: - full_model_path = os.path.join(self.base_path, "BiRefNet") - log_info(f"Loading BiRefNet model from {full_model_path}...") - try: - # Try loading with additional configuration to handle compatibility issues - self.model = AutoModelForImageSegmentation.from_pretrained( - "ZhengPeng7/BiRefNet", - trust_remote_code=True, - cache_dir=full_model_path, - # Add force_download=False to use cached version if available - force_download=False, - # Add local_files_only=False to allow downloading if needed - local_files_only=False - ) + def get_cached_image(self): + + if self.cached_image: + buffered = io.BytesIO() + self.cached_image.save(buffered, format="PNG") + img_str = base64.b64encode(buffered.getvalue()).decode() + return f"data:image/png;base64,{img_str}" + return None + + +def _get_birefnet_base_paths(): + paths = [] + + comfy_models_dir = getattr(folder_paths, "models_dir", None) + if comfy_models_dir: + paths.append(os.path.join(comfy_models_dir, "BiRefNet")) + + legacy_models_dir = os.path.join( + os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))), + "models", + "BiRefNet" + ) + paths.append(legacy_models_dir) + + unique_paths = [] + seen = set() + for path in paths: + normalized = os.path.normpath(path) + if normalized not in seen: + seen.add(normalized) + unique_paths.append(path) + + return unique_paths + + +def _is_valid_birefnet_model_dir(path): + if not os.path.isdir(path): + return False + + try: + files = os.listdir(path) + except OSError: + return False + + has_config = "config.json" in files + has_model = "model.safetensors" in files or "pytorch_model.bin" in files + has_backbone = "backbone_swin.pth" in files or "swin_base_patch4_window12_384_22kto1k.pth" in files + has_birefnet = "birefnet.pth" in files or any(f.endswith(".pth") for f in files) + + return has_config and (has_model or has_backbone or has_birefnet) + + +def _find_local_birefnet_model(): + for base_path in _get_birefnet_base_paths(): + if not os.path.isdir(base_path): + continue + + if _is_valid_birefnet_model_dir(base_path): + return base_path + + try: + existing_items = os.listdir(base_path) + except OSError: + continue + + model_subdirs = [ + d for d in existing_items + if os.path.isdir(os.path.join(base_path, d)) and + (d.startswith("models--") or d == "ZhengPeng7--BiRefNet") + ] + + for subdir in model_subdirs: + snapshots_path = os.path.join(base_path, subdir, "snapshots") + if not os.path.isdir(snapshots_path): + continue + + try: + snapshot_dirs = os.listdir(snapshots_path) + except OSError: + continue + + for snapshot in snapshot_dirs: + snapshot_path = os.path.join(snapshots_path, snapshot) + if _is_valid_birefnet_model_dir(snapshot_path): + return snapshot_path + + return None + + +class BiRefNetMatting: + def __init__(self): + self.model = None + self.model_path = None + self.model_cache = {} + self.base_paths = _get_birefnet_base_paths() + + def load_model(self, model_path): + from json.decoder import JSONDecodeError + try: + if model_path not in self.model_cache: + local_model_path = _find_local_birefnet_model() + cache_dir = self.base_paths[0] if self.base_paths else None + + if local_model_path: + log_info(f"Loading BiRefNet model from local path {local_model_path}...") + try: + self.model = AutoModelForImageSegmentation.from_pretrained( + local_model_path, + trust_remote_code=True + ) + self.model.eval() + if torch.cuda.is_available(): + self.model = self.model.cuda() + self.model_cache[model_path] = self.model + log_info("Model loaded successfully from local disk") + return + except Exception as local_error: + log_warn(f"Failed to load local BiRefNet model from {local_model_path}: {str(local_error)}") + log_info("Falling back to Hugging Face model loading") + + full_model_path = cache_dir or "BiRefNet" + log_info(f"Loading BiRefNet model from Hugging Face cache {full_model_path}...") + try: + # Try loading with additional configuration to handle compatibility issues + self.model = AutoModelForImageSegmentation.from_pretrained( + "ZhengPeng7/BiRefNet", + trust_remote_code=True, + cache_dir=cache_dir, + # Add force_download=False to use cached version if available + force_download=False, + # Add local_files_only=False to allow downloading if needed + local_files_only=False + ) self.model.eval() if torch.cuda.is_available(): self.model = self.model.cuda() @@ -923,75 +1020,27 @@ async def check_matting_model(request): "message": "The 'transformers' library is required for the matting feature. Please install it by running: pip install transformers" }) - # Check if model exists in cache - base_path = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))), "models") - model_path = os.path.join(base_path, "BiRefNet") - - # Look for the actual BiRefNet model structure - model_files_exist = False - if os.path.exists(model_path): - # BiRefNet model from Hugging Face has a specific structure - # Check for subdirectories that indicate the model is downloaded - existing_items = os.listdir(model_path) if os.path.isdir(model_path) else [] - - # Look for the model subdirectory (usually named with the model ID) - model_subdirs = [d for d in existing_items if os.path.isdir(os.path.join(model_path, d)) and - (d.startswith("models--") or d == "ZhengPeng7--BiRefNet")] - - if model_subdirs: - # Found model subdirectory, check inside for actual model files - for subdir in model_subdirs: - subdir_path = os.path.join(model_path, subdir) - # Navigate through the cache structure - if os.path.exists(os.path.join(subdir_path, "snapshots")): - snapshots_path = os.path.join(subdir_path, "snapshots") - snapshot_dirs = os.listdir(snapshots_path) if os.path.isdir(snapshots_path) else [] - - for snapshot in snapshot_dirs: - snapshot_path = os.path.join(snapshots_path, snapshot) - snapshot_files = os.listdir(snapshot_path) if os.path.isdir(snapshot_path) else [] - - # Check for essential files - BiRefNet uses model.safetensors - has_config = "config.json" in snapshot_files - has_model = "model.safetensors" in snapshot_files or "pytorch_model.bin" in snapshot_files - has_backbone = "backbone_swin.pth" in snapshot_files or "swin_base_patch4_window12_384_22kto1k.pth" in snapshot_files - has_birefnet = "birefnet.pth" in snapshot_files or any(f.endswith(".pth") for f in snapshot_files) - - # Model is valid if it has config and either model.safetensors or other model files - if has_config and (has_model or has_backbone or has_birefnet): - model_files_exist = True - log_info(f"Found model files in: {snapshot_path} (config: {has_config}, model: {has_model})") - break - - if model_files_exist: - break - - # Also check if there are .pth files directly in the model_path - if not model_files_exist: - direct_files = existing_items - has_config = "config.json" in direct_files - has_model_files = any(f.endswith((".pth", ".bin", ".safetensors")) for f in direct_files) - model_files_exist = has_config and has_model_files - - if model_files_exist: - log_info(f"Found model files directly in: {model_path}") - - if model_files_exist: - # Model files exist, assume it's ready - log_info("BiRefNet model files detected") - return web.json_response({ - "available": True, - "reason": "ready", - "message": "Model is ready to use" - }) - else: - log_info(f"BiRefNet model not found in {model_path}") - return web.json_response({ - "available": False, - "reason": "not_downloaded", - "message": "The matting model needs to be downloaded. This will happen automatically when you first use the matting feature (requires internet connection).", - "model_path": model_path - }) + # Check if model exists in cache + local_model_path = _find_local_birefnet_model() + + if local_model_path: + # Model files exist, assume it's ready + log_info(f"BiRefNet model files detected at {local_model_path}") + return web.json_response({ + "available": True, + "reason": "ready", + "message": "Model is ready to use", + "model_path": local_model_path + }) + else: + searched_paths = _get_birefnet_base_paths() + log_info(f"BiRefNet model not found in any of: {searched_paths}") + return web.json_response({ + "available": False, + "reason": "not_downloaded", + "message": "The matting model needs to be downloaded. This will happen automatically when you first use the matting feature (requires internet connection).", + "model_path": searched_paths[0] if searched_paths else None + }) except Exception as e: log_error(f"Error checking matting model: {str(e)}")