Fix matting model check and frontend flow

Added proper backend validation for both config.json and model.safetensors to confirm model availability. Updated frontend logic to use /matting/check-model response, preventing unnecessary download notifications.
This commit is contained in:
Dariusz L
2025-09-04 23:10:22 +02:00
parent 20ab861315
commit 7a5ecb3919
3 changed files with 218 additions and 17 deletions

View File

@@ -64,6 +64,8 @@ class BiRefNetConfig(PretrainedConfig):
def __init__(self, bb_pretrained=False, **kwargs):
self.bb_pretrained = bb_pretrained
# Add the missing is_encoder_decoder attribute for compatibility with newer transformers
self.is_encoder_decoder = False
super().__init__(**kwargs)
@@ -755,16 +757,32 @@ class BiRefNetMatting:
full_model_path = os.path.join(self.base_path, "BiRefNet")
log_info(f"Loading BiRefNet model from {full_model_path}...")
try:
# Try loading with additional configuration to handle compatibility issues
self.model = AutoModelForImageSegmentation.from_pretrained(
"ZhengPeng7/BiRefNet",
trust_remote_code=True,
cache_dir=full_model_path
cache_dir=full_model_path,
# Add force_download=False to use cached version if available
force_download=False,
# Add local_files_only=False to allow downloading if needed
local_files_only=False
)
self.model.eval()
if torch.cuda.is_available():
self.model = self.model.cuda()
self.model_cache[model_path] = self.model
log_info("Model loaded successfully from Hugging Face")
except AttributeError as e:
if "'Config' object has no attribute 'is_encoder_decoder'" in str(e):
log_error("Compatibility issue detected with transformers library. This has been fixed in the code.")
log_error("If you're still seeing this error, please clear the model cache and try again.")
raise RuntimeError(
"Model configuration compatibility issue detected. "
f"Please delete the model cache directory '{full_model_path}' and restart ComfyUI. "
"This will download a fresh copy of the model with the updated configuration."
) from e
else:
raise e
except JSONDecodeError as e:
log_error(f"JSONDecodeError: Failed to load model from {full_model_path}. The model's config.json may be corrupted.")
raise RuntimeError(
@@ -894,6 +912,95 @@ class BiRefNetMatting:
_matting_lock = None
@PromptServer.instance.routes.get("/matting/check-model")
async def check_matting_model(request):
"""Check if the matting model is available and ready to use"""
try:
if not TRANSFORMERS_AVAILABLE:
return web.json_response({
"available": False,
"reason": "missing_dependency",
"message": "The 'transformers' library is required for the matting feature. Please install it by running: pip install transformers"
})
# Check if model exists in cache
base_path = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))), "models")
model_path = os.path.join(base_path, "BiRefNet")
# Look for the actual BiRefNet model structure
model_files_exist = False
if os.path.exists(model_path):
# BiRefNet model from Hugging Face has a specific structure
# Check for subdirectories that indicate the model is downloaded
existing_items = os.listdir(model_path) if os.path.isdir(model_path) else []
# Look for the model subdirectory (usually named with the model ID)
model_subdirs = [d for d in existing_items if os.path.isdir(os.path.join(model_path, d)) and
(d.startswith("models--") or d == "ZhengPeng7--BiRefNet")]
if model_subdirs:
# Found model subdirectory, check inside for actual model files
for subdir in model_subdirs:
subdir_path = os.path.join(model_path, subdir)
# Navigate through the cache structure
if os.path.exists(os.path.join(subdir_path, "snapshots")):
snapshots_path = os.path.join(subdir_path, "snapshots")
snapshot_dirs = os.listdir(snapshots_path) if os.path.isdir(snapshots_path) else []
for snapshot in snapshot_dirs:
snapshot_path = os.path.join(snapshots_path, snapshot)
snapshot_files = os.listdir(snapshot_path) if os.path.isdir(snapshot_path) else []
# Check for essential files - BiRefNet uses model.safetensors
has_config = "config.json" in snapshot_files
has_model = "model.safetensors" in snapshot_files or "pytorch_model.bin" in snapshot_files
has_backbone = "backbone_swin.pth" in snapshot_files or "swin_base_patch4_window12_384_22kto1k.pth" in snapshot_files
has_birefnet = "birefnet.pth" in snapshot_files or any(f.endswith(".pth") for f in snapshot_files)
# Model is valid if it has config and either model.safetensors or other model files
if has_config and (has_model or has_backbone or has_birefnet):
model_files_exist = True
log_info(f"Found model files in: {snapshot_path} (config: {has_config}, model: {has_model})")
break
if model_files_exist:
break
# Also check if there are .pth files directly in the model_path
if not model_files_exist:
direct_files = existing_items
has_config = "config.json" in direct_files
has_model_files = any(f.endswith((".pth", ".bin", ".safetensors")) for f in direct_files)
model_files_exist = has_config and has_model_files
if model_files_exist:
log_info(f"Found model files directly in: {model_path}")
if model_files_exist:
# Model files exist, assume it's ready
log_info("BiRefNet model files detected")
return web.json_response({
"available": True,
"reason": "ready",
"message": "Model is ready to use"
})
else:
log_info(f"BiRefNet model not found in {model_path}")
return web.json_response({
"available": False,
"reason": "not_downloaded",
"message": "The matting model needs to be downloaded. This will happen automatically when you first use the matting feature (requires internet connection).",
"model_path": model_path
})
except Exception as e:
log_error(f"Error checking matting model: {str(e)}")
return web.json_response({
"available": False,
"reason": "error",
"message": f"Error checking model status: {str(e)}"
}, status=500)
@PromptServer.instance.routes.post("/matting")
async def matting(request):
global _matting_lock