feat: Add checkpoint hash retrieval and enhance metadata formatting in SaveImage class

This commit is contained in:
Will Miao
2025-04-18 23:55:45 +08:00
parent ea34d753c1
commit fb6a652a57

View File

@@ -5,6 +5,7 @@ import re
import numpy as np
import folder_paths # type: ignore
from ..services.lora_scanner import LoraScanner
from ..services.checkpoint_scanner import CheckpointScanner
from ..metadata_collector.metadata_processor import MetadataProcessor
from ..metadata_collector import get_metadata
from PIL import Image, PngImagePlugin
@@ -60,6 +61,27 @@ class SaveImage:
return item.get('sha256')
return None
async def get_checkpoint_hash(self, checkpoint_path):
"""Get the checkpoint hash from cache"""
scanner = await CheckpointScanner.get_instance()
cache = await scanner.get_cached_data()
if not checkpoint_path:
return None
# Extract basename without extension
checkpoint_name = os.path.basename(checkpoint_path)
checkpoint_name = os.path.splitext(checkpoint_name)[0]
# Normalize path separators for comparison
normalized_path = checkpoint_path.replace('\\', '/')
for item in cache.raw_data:
if item.get('file_name') == checkpoint_name and item.get('file_path').endswith(normalized_path):
return item.get('sha256')
return None
async def format_metadata(self, metadata_dict):
"""Format metadata in the requested format similar to userComment example"""
if not metadata_dict:
@@ -102,6 +124,10 @@ class SaveImage:
if 'steps' in metadata_dict:
params.append(f"Steps: {metadata_dict.get('steps')}")
# Combine sampler and scheduler information
sampler_name = None
scheduler_name = None
if 'sampler' in metadata_dict:
sampler = metadata_dict.get('sampler')
# Convert ComfyUI sampler names to user-friendly names
@@ -123,7 +149,6 @@ class SaveImage:
'ddim': 'DDIM'
}
sampler_name = sampler_mapping.get(sampler, sampler)
params.append(f"Sampler: {sampler_name}")
if 'scheduler' in metadata_dict:
scheduler = metadata_dict.get('scheduler')
@@ -135,10 +160,18 @@ class SaveImage:
'sgm_quadratic': 'SGM Quadratic'
}
scheduler_name = scheduler_mapping.get(scheduler, scheduler)
params.append(f"Schedule type: {scheduler_name}")
# CFG scale (cfg_scale in metadata_dict)
if 'cfg_scale' in metadata_dict:
# Add combined sampler and scheduler information
if sampler_name:
if scheduler_name:
params.append(f"Sampler: {sampler_name} {scheduler_name}")
else:
params.append(f"Sampler: {sampler_name}")
# CFG scale (Use guidance if available, otherwise fall back to cfg_scale or cfg)
if 'guidance' in metadata_dict:
params.append(f"CFG scale: {metadata_dict.get('guidance')}")
elif 'cfg_scale' in metadata_dict:
params.append(f"CFG scale: {metadata_dict.get('cfg_scale')}")
elif 'cfg' in metadata_dict:
params.append(f"CFG scale: {metadata_dict.get('cfg')}")
@@ -156,17 +189,19 @@ class SaveImage:
# Ensure checkpoint is a string before processing
checkpoint = metadata_dict.get('checkpoint')
if checkpoint is not None:
# Handle both string and other types safely
if isinstance(checkpoint, str):
# Extract basename without path
checkpoint = os.path.basename(checkpoint)
# Remove extension if present
checkpoint = os.path.splitext(checkpoint)[0]
else:
# Convert non-string to string
checkpoint = str(checkpoint)
# Get model hash
model_hash = await self.get_checkpoint_hash(checkpoint)
params.append(f"Model: {checkpoint}")
# Extract basename without path
checkpoint_name = os.path.basename(checkpoint)
# Remove extension if present
checkpoint_name = os.path.splitext(checkpoint_name)[0]
# Add model hash if available
if model_hash:
params.append(f"Model hash: {model_hash[:10]}, Model: {checkpoint_name}")
else:
params.append(f"Model: {checkpoint_name}")
# Add LoRA hashes if available
if lora_hashes: