diff --git a/canvas_node.py b/canvas_node.py index eb83859..d11f715 100644 --- a/canvas_node.py +++ b/canvas_node.py @@ -64,6 +64,8 @@ class BiRefNetConfig(PretrainedConfig): def __init__(self, bb_pretrained=False, **kwargs): self.bb_pretrained = bb_pretrained + # Add the missing is_encoder_decoder attribute for compatibility with newer transformers + self.is_encoder_decoder = False super().__init__(**kwargs) @@ -755,16 +757,32 @@ class BiRefNetMatting: full_model_path = os.path.join(self.base_path, "BiRefNet") log_info(f"Loading BiRefNet model from {full_model_path}...") try: + # Try loading with additional configuration to handle compatibility issues self.model = AutoModelForImageSegmentation.from_pretrained( "ZhengPeng7/BiRefNet", trust_remote_code=True, - cache_dir=full_model_path + cache_dir=full_model_path, + # Add force_download=False to use cached version if available + force_download=False, + # Add local_files_only=False to allow downloading if needed + local_files_only=False ) self.model.eval() if torch.cuda.is_available(): self.model = self.model.cuda() self.model_cache[model_path] = self.model log_info("Model loaded successfully from Hugging Face") + except AttributeError as e: + if "'Config' object has no attribute 'is_encoder_decoder'" in str(e): + log_error("Compatibility issue detected with transformers library. This has been fixed in the code.") + log_error("If you're still seeing this error, please clear the model cache and try again.") + raise RuntimeError( + "Model configuration compatibility issue detected. " + f"Please delete the model cache directory '{full_model_path}' and restart ComfyUI. " + "This will download a fresh copy of the model with the updated configuration." + ) from e + else: + raise e except JSONDecodeError as e: log_error(f"JSONDecodeError: Failed to load model from {full_model_path}. The model's config.json may be corrupted.") raise RuntimeError( @@ -894,6 +912,95 @@ class BiRefNetMatting: _matting_lock = None +@PromptServer.instance.routes.get("/matting/check-model") +async def check_matting_model(request): + """Check if the matting model is available and ready to use""" + try: + if not TRANSFORMERS_AVAILABLE: + return web.json_response({ + "available": False, + "reason": "missing_dependency", + "message": "The 'transformers' library is required for the matting feature. Please install it by running: pip install transformers" + }) + + # Check if model exists in cache + base_path = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))), "models") + model_path = os.path.join(base_path, "BiRefNet") + + # Look for the actual BiRefNet model structure + model_files_exist = False + if os.path.exists(model_path): + # BiRefNet model from Hugging Face has a specific structure + # Check for subdirectories that indicate the model is downloaded + existing_items = os.listdir(model_path) if os.path.isdir(model_path) else [] + + # Look for the model subdirectory (usually named with the model ID) + model_subdirs = [d for d in existing_items if os.path.isdir(os.path.join(model_path, d)) and + (d.startswith("models--") or d == "ZhengPeng7--BiRefNet")] + + if model_subdirs: + # Found model subdirectory, check inside for actual model files + for subdir in model_subdirs: + subdir_path = os.path.join(model_path, subdir) + # Navigate through the cache structure + if os.path.exists(os.path.join(subdir_path, "snapshots")): + snapshots_path = os.path.join(subdir_path, "snapshots") + snapshot_dirs = os.listdir(snapshots_path) if os.path.isdir(snapshots_path) else [] + + for snapshot in snapshot_dirs: + snapshot_path = os.path.join(snapshots_path, snapshot) + snapshot_files = os.listdir(snapshot_path) if os.path.isdir(snapshot_path) else [] + + # Check for essential files - BiRefNet uses model.safetensors + has_config = "config.json" in snapshot_files + has_model = "model.safetensors" in snapshot_files or "pytorch_model.bin" in snapshot_files + has_backbone = "backbone_swin.pth" in snapshot_files or "swin_base_patch4_window12_384_22kto1k.pth" in snapshot_files + has_birefnet = "birefnet.pth" in snapshot_files or any(f.endswith(".pth") for f in snapshot_files) + + # Model is valid if it has config and either model.safetensors or other model files + if has_config and (has_model or has_backbone or has_birefnet): + model_files_exist = True + log_info(f"Found model files in: {snapshot_path} (config: {has_config}, model: {has_model})") + break + + if model_files_exist: + break + + # Also check if there are .pth files directly in the model_path + if not model_files_exist: + direct_files = existing_items + has_config = "config.json" in direct_files + has_model_files = any(f.endswith((".pth", ".bin", ".safetensors")) for f in direct_files) + model_files_exist = has_config and has_model_files + + if model_files_exist: + log_info(f"Found model files directly in: {model_path}") + + if model_files_exist: + # Model files exist, assume it's ready + log_info("BiRefNet model files detected") + return web.json_response({ + "available": True, + "reason": "ready", + "message": "Model is ready to use" + }) + else: + log_info(f"BiRefNet model not found in {model_path}") + return web.json_response({ + "available": False, + "reason": "not_downloaded", + "message": "The matting model needs to be downloaded. This will happen automatically when you first use the matting feature (requires internet connection).", + "model_path": model_path + }) + + except Exception as e: + log_error(f"Error checking matting model: {str(e)}") + return web.json_response({ + "available": False, + "reason": "error", + "message": f"Error checking model status: {str(e)}" + }, status=500) + @PromptServer.instance.routes.post("/matting") async def matting(request): global _matting_lock diff --git a/js/CanvasView.js b/js/CanvasView.js index e2d94bc..4b12db2 100644 --- a/js/CanvasView.js +++ b/js/CanvasView.js @@ -343,11 +343,38 @@ async function createCanvasWidget(node, widget, app) { const button = e.target.closest('.matting-button'); if (button.classList.contains('loading')) return; - const spinner = $el("div.matting-spinner"); - button.appendChild(spinner); - button.classList.add('loading'); - showInfoNotification("Starting background removal process...", 2000); try { + // First check if model is available + const modelCheckResponse = await fetch("/matting/check-model"); + const modelStatus = await modelCheckResponse.json(); + if (!modelStatus.available) { + switch (modelStatus.reason) { + case 'missing_dependency': + showErrorNotification(modelStatus.message, 8000); + return; + case 'not_downloaded': + showWarningNotification("The matting model needs to be downloaded first. This will happen automatically when you proceed (requires internet connection).", 5000); + // Ask user if they want to proceed with download + if (!confirm("The matting model needs to be downloaded (about 1GB). This is a one-time download. Do you want to proceed?")) { + return; + } + showInfoNotification("Downloading matting model... This may take a few minutes.", 10000); + break; + case 'corrupted': + showErrorNotification(modelStatus.message, 8000); + return; + case 'error': + showErrorNotification(`Error checking model: ${modelStatus.message}`, 5000); + return; + } + } + // Proceed with matting + const spinner = $el("div.matting-spinner"); + button.appendChild(spinner); + button.classList.add('loading'); + if (modelStatus.available) { + showInfoNotification("Starting background removal process...", 2000); + } if (canvas.canvasSelection.selectedLayers.length !== 1) { throw new Error("Please select exactly one image layer for matting."); } @@ -363,7 +390,20 @@ async function createCanvasWidget(node, widget, app) { if (!response.ok) { let errorMsg = `Server error: ${response.status} - ${response.statusText}`; if (result && result.error) { - errorMsg = `Error: ${result.error}. Details: ${result.details || 'Check console'}`; + // Handle specific error types + if (result.error === "Network Connection Error") { + showErrorNotification("Failed to download the matting model. Please check your internet connection and try again.", 8000); + return; + } + else if (result.error === "Matting Model Error") { + showErrorNotification(result.details || "Model loading error. Please check the console for details.", 8000); + return; + } + else if (result.error === "Dependency Not Found") { + showErrorNotification(result.details || "Missing required dependencies.", 8000); + return; + } + errorMsg = `${result.error}: ${result.details || 'Check console'}`; } throw new Error(errorMsg); } @@ -383,11 +423,16 @@ async function createCanvasWidget(node, widget, app) { catch (error) { log.error("Matting error:", error); const errorMessage = error.message || "An unknown error occurred."; - showErrorNotification(`Matting Failed: ${errorMessage}`); + if (!errorMessage.includes("Network Connection Error") && + !errorMessage.includes("Matting Model Error") && + !errorMessage.includes("Dependency Not Found")) { + showErrorNotification(`Matting Failed: ${errorMessage}`); + } } finally { button.classList.remove('loading'); - if (button.contains(spinner)) { + const spinner = button.querySelector('.matting-spinner'); + if (spinner && button.contains(spinner)) { button.removeChild(spinner); } } diff --git a/src/CanvasView.ts b/src/CanvasView.ts index 8584880..a1940c7 100644 --- a/src/CanvasView.ts +++ b/src/CanvasView.ts @@ -418,13 +418,46 @@ async function createCanvasWidget(node: ComfyNode, widget: any, app: ComfyApp): const button = (e.target as HTMLElement).closest('.matting-button') as HTMLButtonElement; if (button.classList.contains('loading')) return; - const spinner = $el("div.matting-spinner") as HTMLDivElement; - button.appendChild(spinner); - button.classList.add('loading'); - - showInfoNotification("Starting background removal process...", 2000); - try { + // First check if model is available + const modelCheckResponse = await fetch("/matting/check-model"); + const modelStatus = await modelCheckResponse.json(); + + if (!modelStatus.available) { + switch (modelStatus.reason) { + case 'missing_dependency': + showErrorNotification(modelStatus.message, 8000); + return; + + case 'not_downloaded': + showWarningNotification("The matting model needs to be downloaded first. This will happen automatically when you proceed (requires internet connection).", 5000); + + // Ask user if they want to proceed with download + if (!confirm("The matting model needs to be downloaded (about 1GB). This is a one-time download. Do you want to proceed?")) { + return; + } + showInfoNotification("Downloading matting model... This may take a few minutes.", 10000); + break; + + case 'corrupted': + showErrorNotification(modelStatus.message, 8000); + return; + + case 'error': + showErrorNotification(`Error checking model: ${modelStatus.message}`, 5000); + return; + } + } + + // Proceed with matting + const spinner = $el("div.matting-spinner") as HTMLDivElement; + button.appendChild(spinner); + button.classList.add('loading'); + + if (modelStatus.available) { + showInfoNotification("Starting background removal process...", 2000); + } + if (canvas.canvasSelection.selectedLayers.length !== 1) { throw new Error("Please select exactly one image layer for matting."); } @@ -443,7 +476,18 @@ async function createCanvasWidget(node: ComfyNode, widget: any, app: ComfyApp): if (!response.ok) { let errorMsg = `Server error: ${response.status} - ${response.statusText}`; if (result && result.error) { - errorMsg = `Error: ${result.error}. Details: ${result.details || 'Check console'}`; + // Handle specific error types + if (result.error === "Network Connection Error") { + showErrorNotification("Failed to download the matting model. Please check your internet connection and try again.", 8000); + return; + } else if (result.error === "Matting Model Error") { + showErrorNotification(result.details || "Model loading error. Please check the console for details.", 8000); + return; + } else if (result.error === "Dependency Not Found") { + showErrorNotification(result.details || "Missing required dependencies.", 8000); + return; + } + errorMsg = `${result.error}: ${result.details || 'Check console'}`; } throw new Error(errorMsg); } @@ -468,10 +512,15 @@ async function createCanvasWidget(node: ComfyNode, widget: any, app: ComfyApp): } catch (error: any) { log.error("Matting error:", error); const errorMessage = error.message || "An unknown error occurred."; - showErrorNotification(`Matting Failed: ${errorMessage}`); + if (!errorMessage.includes("Network Connection Error") && + !errorMessage.includes("Matting Model Error") && + !errorMessage.includes("Dependency Not Found")) { + showErrorNotification(`Matting Failed: ${errorMessage}`); + } } finally { button.classList.remove('loading'); - if (button.contains(spinner)) { + const spinner = button.querySelector('.matting-spinner'); + if (spinner && button.contains(spinner)) { button.removeChild(spinner); } }