mirror of
https://github.com/Azornes/Comfyui-LayerForge.git
synced 2026-03-21 20:52:12 -03:00
Improve error handling in BiRefNetMatting model loading
Refines exception handling in the load_model method to provide more informative error messages and re-raise exceptions for upstream handling. Removes boolean return values in favor of exception-based flow, and updates execute to rely on exceptions for error detection.
This commit is contained in:
@@ -613,42 +613,38 @@ class BiRefNetMatting:
|
|||||||
def load_model(self, model_path):
|
def load_model(self, model_path):
|
||||||
try:
|
try:
|
||||||
if model_path not in self.model_cache:
|
if model_path not in self.model_cache:
|
||||||
|
|
||||||
full_model_path = os.path.join(self.base_path, "BiRefNet")
|
full_model_path = os.path.join(self.base_path, "BiRefNet")
|
||||||
|
|
||||||
log_info(f"Loading BiRefNet model from {full_model_path}...")
|
log_info(f"Loading BiRefNet model from {full_model_path}...")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
||||||
self.model = AutoModelForImageSegmentation.from_pretrained(
|
self.model = AutoModelForImageSegmentation.from_pretrained(
|
||||||
"ZhengPeng7/BiRefNet",
|
"ZhengPeng7/BiRefNet",
|
||||||
trust_remote_code=True,
|
trust_remote_code=True,
|
||||||
cache_dir=full_model_path
|
cache_dir=full_model_path
|
||||||
)
|
)
|
||||||
|
|
||||||
self.model.eval()
|
self.model.eval()
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
self.model = self.model.cuda()
|
self.model = self.model.cuda()
|
||||||
|
|
||||||
self.model_cache[model_path] = self.model
|
self.model_cache[model_path] = self.model
|
||||||
log_info("Model loaded successfully from Hugging Face")
|
log_info("Model loaded successfully from Hugging Face")
|
||||||
log_debug(f"Model type: {type(self.model)}")
|
|
||||||
log_debug(f"Model device: {next(self.model.parameters()).device}")
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log_error(f"Failed to load model: {str(e)}")
|
log_error(f"Failed to load model from Hugging Face: {str(e)}")
|
||||||
raise
|
# Re-raise with a more informative message
|
||||||
|
raise RuntimeError(
|
||||||
|
"Failed to download or load the matting model. "
|
||||||
|
"This could be due to a network issue, file permissions, or a corrupted model cache. "
|
||||||
|
f"Please check your internet connection and the model cache path: {full_model_path}. "
|
||||||
|
f"Original error: {str(e)}"
|
||||||
|
) from e
|
||||||
else:
|
else:
|
||||||
self.model = self.model_cache[model_path]
|
self.model = self.model_cache[model_path]
|
||||||
log_debug("Using cached model")
|
log_debug("Using cached model")
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
# Catch the re-raised exception or any other error
|
||||||
log_error(f"Error loading model: {str(e)}")
|
log_error(f"Error loading model: {str(e)}")
|
||||||
log_exception("Model loading failed")
|
log_exception("Model loading failed")
|
||||||
return False
|
raise # Re-raise the exception to be caught by the execute method
|
||||||
|
|
||||||
def preprocess_image(self, image):
|
def preprocess_image(self, image):
|
||||||
|
|
||||||
@@ -678,11 +674,9 @@ class BiRefNetMatting:
|
|||||||
|
|
||||||
def execute(self, image, model_path, threshold=0.5, refinement=1):
|
def execute(self, image, model_path, threshold=0.5, refinement=1):
|
||||||
try:
|
try:
|
||||||
|
|
||||||
PromptServer.instance.send_sync("matting_status", {"status": "processing"})
|
PromptServer.instance.send_sync("matting_status", {"status": "processing"})
|
||||||
|
|
||||||
if not self.load_model(model_path):
|
self.load_model(model_path)
|
||||||
raise RuntimeError("Failed to load model")
|
|
||||||
|
|
||||||
if isinstance(image, torch.Tensor):
|
if isinstance(image, torch.Tensor):
|
||||||
original_size = image.shape[-2:] if image.dim() == 4 else image.shape[-2:]
|
original_size = image.shape[-2:] if image.dim() == 4 else image.shape[-2:]
|
||||||
|
|||||||
Reference in New Issue
Block a user