diff --git a/py/nodes/save_image.py b/py/nodes/save_image.py index b9e75a5a..39fea253 100644 --- a/py/nodes/save_image.py +++ b/py/nodes/save_image.py @@ -34,6 +34,7 @@ class SaveImage: "file_format": (["png", "jpeg", "webp"],), }, "optional": { + "custom_prompt": ("STRING", {"default": "", "forceInput": True}), "lossless_webp": ("BOOLEAN", {"default": True}), "quality": ("INT", {"default": 100, "min": 1, "max": 100}), "embed_workflow": ("BOOLEAN", {"default": False}), @@ -60,7 +61,7 @@ class SaveImage: return item.get('sha256') return None - async def format_metadata(self, parsed_workflow): + async def format_metadata(self, parsed_workflow, custom_prompt=None): """Format metadata in the requested format similar to userComment example""" if not parsed_workflow: return "" @@ -69,6 +70,10 @@ class SaveImage: prompt = parsed_workflow.get('prompt', '') negative_prompt = parsed_workflow.get('negative_prompt', '') + # Override prompt with custom_prompt if provided + if custom_prompt: + prompt = custom_prompt + # Extract loras from the prompt if present loras_text = parsed_workflow.get('loras', '') lora_hashes = {} @@ -240,7 +245,8 @@ class SaveImage: return filename def save_images(self, images, filename_prefix, file_format, prompt=None, extra_pnginfo=None, - lossless_webp=True, quality=100, embed_workflow=False, add_counter_to_filename=True): + lossless_webp=True, quality=100, embed_workflow=False, add_counter_to_filename=True, + custom_prompt=None): """Save images with metadata""" results = [] @@ -248,11 +254,11 @@ class SaveImage: parser = WorkflowParser() if prompt: parsed_workflow = parser.parse_workflow(prompt) else: parsed_workflow = {} # Get or create metadata asynchronously - metadata = asyncio.run(self.format_metadata(parsed_workflow)) + metadata = asyncio.run(self.format_metadata(parsed_workflow, custom_prompt)) # Process filename_prefix with pattern substitution filename_prefix = 
self.format_filename(filename_prefix, parsed_workflow) @@ -338,7 +345,8 @@ class SaveImage: return results def process_image(self, images, filename_prefix="ComfyUI", file_format="png", prompt=None, extra_pnginfo=None, - lossless_webp=True, quality=100, embed_workflow=False, add_counter_to_filename=True): + lossless_webp=True, quality=100, embed_workflow=False, add_counter_to_filename=True, + custom_prompt=""): """Process and save image with metadata""" # Make sure the output directory exists os.makedirs(self.output_dir, exist_ok=True) @@ -356,7 +364,8 @@ class SaveImage: lossless_webp, quality, embed_workflow, - add_counter_to_filename + add_counter_to_filename, + custom_prompt if custom_prompt.strip() else None ) return (images,) \ No newline at end of file diff --git a/py/utils/models.py b/py/utils/models.py index e543cc8c..13fcbad7 100644 --- a/py/utils/models.py +++ b/py/utils/models.py @@ -75,3 +75,31 @@ class LoraMetadata: self.modified = os.path.getmtime(file_path) self.file_path = file_path.replace(os.sep, '/') +@dataclass +class CheckpointMetadata: + """Represents the metadata structure for a Checkpoint model""" + file_name: str # The filename without extension + model_name: str # The checkpoint's name defined by the creator + file_path: str # Full path to the model file + size: int # File size in bytes + modified: float # Last modified timestamp + sha256: str # SHA256 hash of the file + base_model: str # Base model type (SD1.5/SD2.1/SDXL/etc.) + preview_url: str # Preview image URL + preview_nsfw_level: int = 0 # NSFW level of the preview image + model_type: str = "checkpoint" # Model type (checkpoint, inpainting, etc.) 
+ notes: str = "" # Additional notes + from_civitai: bool = True # Whether from Civitai + civitai: Optional[Dict] = None # Civitai API data if available + tags: Optional[List[str]] = None # Model tags + modelDescription: str = "" # Full model description + + # Additional checkpoint-specific fields + resolution: Optional[str] = None # Native resolution (e.g., 512x512, 1024x1024) + vae_included: bool = False # Whether VAE is included in the checkpoint + architecture: str = "" # Model architecture (if known) + + def __post_init__(self): + if self.tags is None: + self.tags = [] + diff --git a/py/workflow/mappers.py b/py/workflow/mappers.py index 33bc9e9b..528aeefb 100644 --- a/py/workflow/mappers.py +++ b/py/workflow/mappers.py @@ -134,7 +134,7 @@ def transform_lora_loader(inputs: Dict) -> Dict: "loras": " ".join(lora_texts) } - if "clip" in inputs: + if "clip" in inputs and isinstance(inputs["clip"], dict): result["clip_skip"] = inputs["clip"].get("clip_skip", "-1") return result