diff --git a/canvas_node.py b/canvas_node.py index e131bc6..5747c5f 100644 --- a/canvas_node.py +++ b/canvas_node.py @@ -613,42 +613,38 @@ class BiRefNetMatting: def load_model(self, model_path): try: if model_path not in self.model_cache: - full_model_path = os.path.join(self.base_path, "BiRefNet") - log_info(f"Loading BiRefNet model from {full_model_path}...") try: - self.model = AutoModelForImageSegmentation.from_pretrained( "ZhengPeng7/BiRefNet", trust_remote_code=True, cache_dir=full_model_path ) - self.model.eval() if torch.cuda.is_available(): self.model = self.model.cuda() - self.model_cache[model_path] = self.model log_info("Model loaded successfully from Hugging Face") - log_debug(f"Model type: {type(self.model)}") - log_debug(f"Model device: {next(self.model.parameters()).device}") - except Exception as e: - log_error(f"Failed to load model: {str(e)}") - raise - + log_error(f"Failed to load model from Hugging Face: {str(e)}") + # Re-raise with a more informative message + raise RuntimeError( + "Failed to download or load the matting model. " + "This could be due to a network issue, file permissions, or a corrupted model cache. " + f"Please check your internet connection and the model cache path: {full_model_path}. " + f"Original error: {str(e)}" + ) from e else: self.model = self.model_cache[model_path] log_debug("Using cached model") - return True - except Exception as e: + # Catch the re-raised exception or any other error log_error(f"Error loading model: {str(e)}") log_exception("Model loading failed") - return False + raise # Re-raise the exception to be caught by the execute method def preprocess_image(self, image): @@ -678,11 +674,9 @@ class BiRefNetMatting: def execute(self, image, model_path, threshold=0.5, refinement=1): try: - PromptServer.instance.send_sync("matting_status", {"status": "processing"}) - if not self.load_model(model_path): - raise RuntimeError("Failed to load model") + self.load_model(model_path) if isinstance(image, torch.Tensor): original_size = image.shape[-2:] if image.dim() == 4 else image.shape[-2:]