diff --git a/py/nodes/save_image.py b/py/nodes/save_image.py index 7c6174c0..cb3a8e9e 100644 --- a/py/nodes/save_image.py +++ b/py/nodes/save_image.py @@ -418,11 +418,15 @@ class SaveImage: # Make sure the output directory exists os.makedirs(self.output_dir, exist_ok=True) - # Ensure images is always a list of images - if len(images.shape) == 3: # Single image (height, width, channels) - images = [images] - else: # Multiple images (batch, height, width, channels) - images = [img for img in images] + # If images is already a list or array of images, do nothing; otherwise, convert to list + if isinstance(images, (list, np.ndarray)): + pass + else: + # Ensure images is always a list of images + if len(images.shape) == 3: # Single image (height, width, channels) + images = [images] + else: # Multiple images (batch, height, width, channels) + images = [img for img in images] # Save all images results = self.save_images(