From fb6a652a571440cee632ab9c48af9bcff77d4cf7 Mon Sep 17 00:00:00 2001 From: Will Miao <13051207myq@gmail.com> Date: Fri, 18 Apr 2025 23:55:45 +0800 Subject: [PATCH] feat: Add checkpoint hash retrieval and enhance metadata formatting in SaveImage class --- py/nodes/save_image.py | 63 ++++++++++++++++++++++++++++++++---------- 1 file changed, 49 insertions(+), 14 deletions(-) diff --git a/py/nodes/save_image.py b/py/nodes/save_image.py index 003092db..2d9dc084 100644 --- a/py/nodes/save_image.py +++ b/py/nodes/save_image.py @@ -5,6 +5,7 @@ import re import numpy as np import folder_paths # type: ignore from ..services.lora_scanner import LoraScanner +from ..services.checkpoint_scanner import CheckpointScanner from ..metadata_collector.metadata_processor import MetadataProcessor from ..metadata_collector import get_metadata from PIL import Image, PngImagePlugin @@ -60,6 +61,27 @@ class SaveImage: return item.get('sha256') return None + async def get_checkpoint_hash(self, checkpoint_path): + """Get the checkpoint hash from cache""" + scanner = await CheckpointScanner.get_instance() + cache = await scanner.get_cached_data() + + if not checkpoint_path: + return None + + # Extract basename without extension + checkpoint_name = os.path.basename(checkpoint_path) + checkpoint_name = os.path.splitext(checkpoint_name)[0] + + # Normalize path separators for comparison + normalized_path = checkpoint_path.replace('\\', '/') + + for item in cache.raw_data: + if item.get('file_name') == checkpoint_name and item.get('file_path').endswith(normalized_path): + return item.get('sha256') + + return None + async def format_metadata(self, metadata_dict): """Format metadata in the requested format similar to userComment example""" if not metadata_dict: @@ -102,6 +124,10 @@ class SaveImage: if 'steps' in metadata_dict: params.append(f"Steps: {metadata_dict.get('steps')}") + # Combine sampler and scheduler information + sampler_name = None + scheduler_name = None + if 'sampler' in metadata_dict: sampler = metadata_dict.get('sampler') # Convert ComfyUI sampler names to user-friendly names @@ -123,7 +149,6 @@ class SaveImage: 'ddim': 'DDIM' } sampler_name = sampler_mapping.get(sampler, sampler) - params.append(f"Sampler: {sampler_name}") if 'scheduler' in metadata_dict: scheduler = metadata_dict.get('scheduler') @@ -135,10 +160,18 @@ class SaveImage: 'sgm_quadratic': 'SGM Quadratic' } scheduler_name = scheduler_mapping.get(scheduler, scheduler) - params.append(f"Schedule type: {scheduler_name}") - # CFG scale (cfg_scale in metadata_dict) - if 'cfg_scale' in metadata_dict: + # Add combined sampler and scheduler information + if sampler_name: + if scheduler_name: + params.append(f"Sampler: {sampler_name} {scheduler_name}") + else: + params.append(f"Sampler: {sampler_name}") + + # CFG scale (Use guidance if available, otherwise fall back to cfg_scale or cfg) + if 'guidance' in metadata_dict: + params.append(f"CFG scale: {metadata_dict.get('guidance')}") + elif 'cfg_scale' in metadata_dict: params.append(f"CFG scale: {metadata_dict.get('cfg_scale')}") elif 'cfg' in metadata_dict: params.append(f"CFG scale: {metadata_dict.get('cfg')}") @@ -156,17 +189,19 @@ class SaveImage: # Ensure checkpoint is a string before processing checkpoint = metadata_dict.get('checkpoint') if checkpoint is not None: - # Handle both string and other types safely - if isinstance(checkpoint, str): - # Extract basename without path - checkpoint = os.path.basename(checkpoint) - # Remove extension if present - checkpoint = os.path.splitext(checkpoint)[0] - else: - # Convert non-string to string - checkpoint = str(checkpoint) + # Get model hash + model_hash = await self.get_checkpoint_hash(checkpoint) - params.append(f"Model: {checkpoint}") + # Extract basename without path + checkpoint_name = os.path.basename(checkpoint) + # Remove extension if present + checkpoint_name = os.path.splitext(checkpoint_name)[0] + + # Add model hash if available + if model_hash: + params.append(f"Model hash: {model_hash[:10]}, Model: {checkpoint_name}") + else: + params.append(f"Model: {checkpoint_name}") # Add LoRA hashes if available if lora_hashes: