mirror of
https://github.com/Azornes/Comfyui-LayerForge.git
synced 2026-03-21 12:52:10 -03:00
Fix matting model check and frontend flow
Added proper backend validation for both config.json and model.safetensors to confirm model availability. Updated frontend logic to use /matting/check-model response, preventing unnecessary download notifications.
This commit is contained in:
109
canvas_node.py
109
canvas_node.py
@@ -64,6 +64,8 @@ class BiRefNetConfig(PretrainedConfig):
|
||||
|
||||
def __init__(self, bb_pretrained=False, **kwargs):
|
||||
self.bb_pretrained = bb_pretrained
|
||||
# Add the missing is_encoder_decoder attribute for compatibility with newer transformers
|
||||
self.is_encoder_decoder = False
|
||||
super().__init__(**kwargs)
|
||||
|
||||
|
||||
@@ -755,16 +757,32 @@ class BiRefNetMatting:
|
||||
full_model_path = os.path.join(self.base_path, "BiRefNet")
|
||||
log_info(f"Loading BiRefNet model from {full_model_path}...")
|
||||
try:
|
||||
# Try loading with additional configuration to handle compatibility issues
|
||||
self.model = AutoModelForImageSegmentation.from_pretrained(
|
||||
"ZhengPeng7/BiRefNet",
|
||||
trust_remote_code=True,
|
||||
cache_dir=full_model_path
|
||||
cache_dir=full_model_path,
|
||||
# Add force_download=False to use cached version if available
|
||||
force_download=False,
|
||||
# Add local_files_only=False to allow downloading if needed
|
||||
local_files_only=False
|
||||
)
|
||||
self.model.eval()
|
||||
if torch.cuda.is_available():
|
||||
self.model = self.model.cuda()
|
||||
self.model_cache[model_path] = self.model
|
||||
log_info("Model loaded successfully from Hugging Face")
|
||||
except AttributeError as e:
|
||||
if "'Config' object has no attribute 'is_encoder_decoder'" in str(e):
|
||||
log_error("Compatibility issue detected with transformers library. This has been fixed in the code.")
|
||||
log_error("If you're still seeing this error, please clear the model cache and try again.")
|
||||
raise RuntimeError(
|
||||
"Model configuration compatibility issue detected. "
|
||||
f"Please delete the model cache directory '{full_model_path}' and restart ComfyUI. "
|
||||
"This will download a fresh copy of the model with the updated configuration."
|
||||
) from e
|
||||
else:
|
||||
raise e
|
||||
except JSONDecodeError as e:
|
||||
log_error(f"JSONDecodeError: Failed to load model from {full_model_path}. The model's config.json may be corrupted.")
|
||||
raise RuntimeError(
|
||||
@@ -894,6 +912,95 @@ class BiRefNetMatting:
|
||||
|
||||
_matting_lock = None
|
||||
|
||||
@PromptServer.instance.routes.get("/matting/check-model")
|
||||
async def check_matting_model(request):
|
||||
"""Check if the matting model is available and ready to use"""
|
||||
try:
|
||||
if not TRANSFORMERS_AVAILABLE:
|
||||
return web.json_response({
|
||||
"available": False,
|
||||
"reason": "missing_dependency",
|
||||
"message": "The 'transformers' library is required for the matting feature. Please install it by running: pip install transformers"
|
||||
})
|
||||
|
||||
# Check if model exists in cache
|
||||
base_path = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))), "models")
|
||||
model_path = os.path.join(base_path, "BiRefNet")
|
||||
|
||||
# Look for the actual BiRefNet model structure
|
||||
model_files_exist = False
|
||||
if os.path.exists(model_path):
|
||||
# BiRefNet model from Hugging Face has a specific structure
|
||||
# Check for subdirectories that indicate the model is downloaded
|
||||
existing_items = os.listdir(model_path) if os.path.isdir(model_path) else []
|
||||
|
||||
# Look for the model subdirectory (usually named with the model ID)
|
||||
model_subdirs = [d for d in existing_items if os.path.isdir(os.path.join(model_path, d)) and
|
||||
(d.startswith("models--") or d == "ZhengPeng7--BiRefNet")]
|
||||
|
||||
if model_subdirs:
|
||||
# Found model subdirectory, check inside for actual model files
|
||||
for subdir in model_subdirs:
|
||||
subdir_path = os.path.join(model_path, subdir)
|
||||
# Navigate through the cache structure
|
||||
if os.path.exists(os.path.join(subdir_path, "snapshots")):
|
||||
snapshots_path = os.path.join(subdir_path, "snapshots")
|
||||
snapshot_dirs = os.listdir(snapshots_path) if os.path.isdir(snapshots_path) else []
|
||||
|
||||
for snapshot in snapshot_dirs:
|
||||
snapshot_path = os.path.join(snapshots_path, snapshot)
|
||||
snapshot_files = os.listdir(snapshot_path) if os.path.isdir(snapshot_path) else []
|
||||
|
||||
# Check for essential files - BiRefNet uses model.safetensors
|
||||
has_config = "config.json" in snapshot_files
|
||||
has_model = "model.safetensors" in snapshot_files or "pytorch_model.bin" in snapshot_files
|
||||
has_backbone = "backbone_swin.pth" in snapshot_files or "swin_base_patch4_window12_384_22kto1k.pth" in snapshot_files
|
||||
has_birefnet = "birefnet.pth" in snapshot_files or any(f.endswith(".pth") for f in snapshot_files)
|
||||
|
||||
# Model is valid if it has config and either model.safetensors or other model files
|
||||
if has_config and (has_model or has_backbone or has_birefnet):
|
||||
model_files_exist = True
|
||||
log_info(f"Found model files in: {snapshot_path} (config: {has_config}, model: {has_model})")
|
||||
break
|
||||
|
||||
if model_files_exist:
|
||||
break
|
||||
|
||||
# Also check if there are .pth files directly in the model_path
|
||||
if not model_files_exist:
|
||||
direct_files = existing_items
|
||||
has_config = "config.json" in direct_files
|
||||
has_model_files = any(f.endswith((".pth", ".bin", ".safetensors")) for f in direct_files)
|
||||
model_files_exist = has_config and has_model_files
|
||||
|
||||
if model_files_exist:
|
||||
log_info(f"Found model files directly in: {model_path}")
|
||||
|
||||
if model_files_exist:
|
||||
# Model files exist, assume it's ready
|
||||
log_info("BiRefNet model files detected")
|
||||
return web.json_response({
|
||||
"available": True,
|
||||
"reason": "ready",
|
||||
"message": "Model is ready to use"
|
||||
})
|
||||
else:
|
||||
log_info(f"BiRefNet model not found in {model_path}")
|
||||
return web.json_response({
|
||||
"available": False,
|
||||
"reason": "not_downloaded",
|
||||
"message": "The matting model needs to be downloaded. This will happen automatically when you first use the matting feature (requires internet connection).",
|
||||
"model_path": model_path
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
log_error(f"Error checking matting model: {str(e)}")
|
||||
return web.json_response({
|
||||
"available": False,
|
||||
"reason": "error",
|
||||
"message": f"Error checking model status: {str(e)}"
|
||||
}, status=500)
|
||||
|
||||
@PromptServer.instance.routes.post("/matting")
|
||||
async def matting(request):
|
||||
global _matting_lock
|
||||
|
||||
@@ -343,11 +343,38 @@ async function createCanvasWidget(node, widget, app) {
|
||||
const button = e.target.closest('.matting-button');
|
||||
if (button.classList.contains('loading'))
|
||||
return;
|
||||
const spinner = $el("div.matting-spinner");
|
||||
button.appendChild(spinner);
|
||||
button.classList.add('loading');
|
||||
showInfoNotification("Starting background removal process...", 2000);
|
||||
try {
|
||||
// First check if model is available
|
||||
const modelCheckResponse = await fetch("/matting/check-model");
|
||||
const modelStatus = await modelCheckResponse.json();
|
||||
if (!modelStatus.available) {
|
||||
switch (modelStatus.reason) {
|
||||
case 'missing_dependency':
|
||||
showErrorNotification(modelStatus.message, 8000);
|
||||
return;
|
||||
case 'not_downloaded':
|
||||
showWarningNotification("The matting model needs to be downloaded first. This will happen automatically when you proceed (requires internet connection).", 5000);
|
||||
// Ask user if they want to proceed with download
|
||||
if (!confirm("The matting model needs to be downloaded (about 1GB). This is a one-time download. Do you want to proceed?")) {
|
||||
return;
|
||||
}
|
||||
showInfoNotification("Downloading matting model... This may take a few minutes.", 10000);
|
||||
break;
|
||||
case 'corrupted':
|
||||
showErrorNotification(modelStatus.message, 8000);
|
||||
return;
|
||||
case 'error':
|
||||
showErrorNotification(`Error checking model: ${modelStatus.message}`, 5000);
|
||||
return;
|
||||
}
|
||||
}
|
||||
// Proceed with matting
|
||||
const spinner = $el("div.matting-spinner");
|
||||
button.appendChild(spinner);
|
||||
button.classList.add('loading');
|
||||
if (modelStatus.available) {
|
||||
showInfoNotification("Starting background removal process...", 2000);
|
||||
}
|
||||
if (canvas.canvasSelection.selectedLayers.length !== 1) {
|
||||
throw new Error("Please select exactly one image layer for matting.");
|
||||
}
|
||||
@@ -363,7 +390,20 @@ async function createCanvasWidget(node, widget, app) {
|
||||
if (!response.ok) {
|
||||
let errorMsg = `Server error: ${response.status} - ${response.statusText}`;
|
||||
if (result && result.error) {
|
||||
errorMsg = `Error: ${result.error}. Details: ${result.details || 'Check console'}`;
|
||||
// Handle specific error types
|
||||
if (result.error === "Network Connection Error") {
|
||||
showErrorNotification("Failed to download the matting model. Please check your internet connection and try again.", 8000);
|
||||
return;
|
||||
}
|
||||
else if (result.error === "Matting Model Error") {
|
||||
showErrorNotification(result.details || "Model loading error. Please check the console for details.", 8000);
|
||||
return;
|
||||
}
|
||||
else if (result.error === "Dependency Not Found") {
|
||||
showErrorNotification(result.details || "Missing required dependencies.", 8000);
|
||||
return;
|
||||
}
|
||||
errorMsg = `${result.error}: ${result.details || 'Check console'}`;
|
||||
}
|
||||
throw new Error(errorMsg);
|
||||
}
|
||||
@@ -383,11 +423,16 @@ async function createCanvasWidget(node, widget, app) {
|
||||
catch (error) {
|
||||
log.error("Matting error:", error);
|
||||
const errorMessage = error.message || "An unknown error occurred.";
|
||||
showErrorNotification(`Matting Failed: ${errorMessage}`);
|
||||
if (!errorMessage.includes("Network Connection Error") &&
|
||||
!errorMessage.includes("Matting Model Error") &&
|
||||
!errorMessage.includes("Dependency Not Found")) {
|
||||
showErrorNotification(`Matting Failed: ${errorMessage}`);
|
||||
}
|
||||
}
|
||||
finally {
|
||||
button.classList.remove('loading');
|
||||
if (button.contains(spinner)) {
|
||||
const spinner = button.querySelector('.matting-spinner');
|
||||
if (spinner && button.contains(spinner)) {
|
||||
button.removeChild(spinner);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -418,13 +418,46 @@ async function createCanvasWidget(node: ComfyNode, widget: any, app: ComfyApp):
|
||||
const button = (e.target as HTMLElement).closest('.matting-button') as HTMLButtonElement;
|
||||
if (button.classList.contains('loading')) return;
|
||||
|
||||
const spinner = $el("div.matting-spinner") as HTMLDivElement;
|
||||
button.appendChild(spinner);
|
||||
button.classList.add('loading');
|
||||
|
||||
showInfoNotification("Starting background removal process...", 2000);
|
||||
|
||||
try {
|
||||
// First check if model is available
|
||||
const modelCheckResponse = await fetch("/matting/check-model");
|
||||
const modelStatus = await modelCheckResponse.json();
|
||||
|
||||
if (!modelStatus.available) {
|
||||
switch (modelStatus.reason) {
|
||||
case 'missing_dependency':
|
||||
showErrorNotification(modelStatus.message, 8000);
|
||||
return;
|
||||
|
||||
case 'not_downloaded':
|
||||
showWarningNotification("The matting model needs to be downloaded first. This will happen automatically when you proceed (requires internet connection).", 5000);
|
||||
|
||||
// Ask user if they want to proceed with download
|
||||
if (!confirm("The matting model needs to be downloaded (about 1GB). This is a one-time download. Do you want to proceed?")) {
|
||||
return;
|
||||
}
|
||||
showInfoNotification("Downloading matting model... This may take a few minutes.", 10000);
|
||||
break;
|
||||
|
||||
case 'corrupted':
|
||||
showErrorNotification(modelStatus.message, 8000);
|
||||
return;
|
||||
|
||||
case 'error':
|
||||
showErrorNotification(`Error checking model: ${modelStatus.message}`, 5000);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
// Proceed with matting
|
||||
const spinner = $el("div.matting-spinner") as HTMLDivElement;
|
||||
button.appendChild(spinner);
|
||||
button.classList.add('loading');
|
||||
|
||||
if (modelStatus.available) {
|
||||
showInfoNotification("Starting background removal process...", 2000);
|
||||
}
|
||||
|
||||
if (canvas.canvasSelection.selectedLayers.length !== 1) {
|
||||
throw new Error("Please select exactly one image layer for matting.");
|
||||
}
|
||||
@@ -443,7 +476,18 @@ async function createCanvasWidget(node: ComfyNode, widget: any, app: ComfyApp):
|
||||
if (!response.ok) {
|
||||
let errorMsg = `Server error: ${response.status} - ${response.statusText}`;
|
||||
if (result && result.error) {
|
||||
errorMsg = `Error: ${result.error}. Details: ${result.details || 'Check console'}`;
|
||||
// Handle specific error types
|
||||
if (result.error === "Network Connection Error") {
|
||||
showErrorNotification("Failed to download the matting model. Please check your internet connection and try again.", 8000);
|
||||
return;
|
||||
} else if (result.error === "Matting Model Error") {
|
||||
showErrorNotification(result.details || "Model loading error. Please check the console for details.", 8000);
|
||||
return;
|
||||
} else if (result.error === "Dependency Not Found") {
|
||||
showErrorNotification(result.details || "Missing required dependencies.", 8000);
|
||||
return;
|
||||
}
|
||||
errorMsg = `${result.error}: ${result.details || 'Check console'}`;
|
||||
}
|
||||
throw new Error(errorMsg);
|
||||
}
|
||||
@@ -468,10 +512,15 @@ async function createCanvasWidget(node: ComfyNode, widget: any, app: ComfyApp):
|
||||
} catch (error: any) {
|
||||
log.error("Matting error:", error);
|
||||
const errorMessage = error.message || "An unknown error occurred.";
|
||||
showErrorNotification(`Matting Failed: ${errorMessage}`);
|
||||
if (!errorMessage.includes("Network Connection Error") &&
|
||||
!errorMessage.includes("Matting Model Error") &&
|
||||
!errorMessage.includes("Dependency Not Found")) {
|
||||
showErrorNotification(`Matting Failed: ${errorMessage}`);
|
||||
}
|
||||
} finally {
|
||||
button.classList.remove('loading');
|
||||
if (button.contains(spinner)) {
|
||||
const spinner = button.querySelector('.matting-spinner');
|
||||
if (spinner && button.contains(spinner)) {
|
||||
button.removeChild(spinner);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user