Improve error handling in BiRefNetMatting model loading

Refines exception handling in the load_model method to provide more informative error messages and re-raise exceptions for upstream handling. Removes boolean return values in favor of exception-based flow, and updates execute to rely on exceptions for error detection.
This commit is contained in:
Dariusz L
2025-07-03 12:04:28 +02:00
parent d40f68b8c6
commit 2ab406ebfd

View File

@@ -613,42 +613,38 @@ class BiRefNetMatting:
def load_model(self, model_path): def load_model(self, model_path):
try: try:
if model_path not in self.model_cache: if model_path not in self.model_cache:
full_model_path = os.path.join(self.base_path, "BiRefNet") full_model_path = os.path.join(self.base_path, "BiRefNet")
log_info(f"Loading BiRefNet model from {full_model_path}...") log_info(f"Loading BiRefNet model from {full_model_path}...")
try: try:
self.model = AutoModelForImageSegmentation.from_pretrained( self.model = AutoModelForImageSegmentation.from_pretrained(
"ZhengPeng7/BiRefNet", "ZhengPeng7/BiRefNet",
trust_remote_code=True, trust_remote_code=True,
cache_dir=full_model_path cache_dir=full_model_path
) )
self.model.eval() self.model.eval()
if torch.cuda.is_available(): if torch.cuda.is_available():
self.model = self.model.cuda() self.model = self.model.cuda()
self.model_cache[model_path] = self.model self.model_cache[model_path] = self.model
log_info("Model loaded successfully from Hugging Face") log_info("Model loaded successfully from Hugging Face")
log_debug(f"Model type: {type(self.model)}")
log_debug(f"Model device: {next(self.model.parameters()).device}")
except Exception as e: except Exception as e:
log_error(f"Failed to load model: {str(e)}") log_error(f"Failed to load model from Hugging Face: {str(e)}")
raise # Re-raise with a more informative message
raise RuntimeError(
"Failed to download or load the matting model. "
"This could be due to a network issue, file permissions, or a corrupted model cache. "
f"Please check your internet connection and the model cache path: {full_model_path}. "
f"Original error: {str(e)}"
) from e
else: else:
self.model = self.model_cache[model_path] self.model = self.model_cache[model_path]
log_debug("Using cached model") log_debug("Using cached model")
return True
except Exception as e: except Exception as e:
# Catch the re-raised exception or any other error
log_error(f"Error loading model: {str(e)}") log_error(f"Error loading model: {str(e)}")
log_exception("Model loading failed") log_exception("Model loading failed")
return False raise # Re-raise the exception to be caught by the execute method
def preprocess_image(self, image): def preprocess_image(self, image):
@@ -678,11 +674,9 @@ class BiRefNetMatting:
def execute(self, image, model_path, threshold=0.5, refinement=1): def execute(self, image, model_path, threshold=0.5, refinement=1):
try: try:
PromptServer.instance.send_sync("matting_status", {"status": "processing"}) PromptServer.instance.send_sync("matting_status", {"status": "processing"})
if not self.load_model(model_path): self.load_model(model_path)
raise RuntimeError("Failed to load model")
if isinstance(image, torch.Tensor): if isinstance(image, torch.Tensor):
original_size = image.shape[-2:] if image.dim() == 4 else image.shape[-2:] original_size = image.shape[-2:] if image.dim() == 4 else image.shape[-2:]