diff --git a/py/metadata_collector/metadata_registry.py b/py/metadata_collector/metadata_registry.py
index bcf2284a..e287c5b1 100644
--- a/py/metadata_collector/metadata_registry.py
+++ b/py/metadata_collector/metadata_registry.py
@@ -239,6 +239,14 @@ class MetadataRegistry:
 
         metadata = self.prompt_metadata[key]
         if IMAGES in metadata and "first_decode" in metadata[IMAGES]:
-            return metadata[IMAGES]["first_decode"]["image"]
+            image_data = metadata[IMAGES]["first_decode"]["image"]
+
+            # If it's an image batch or tuple, handle various formats
+            if isinstance(image_data, (list, tuple)) and len(image_data) > 0:
+                # Return first element of list/tuple
+                return image_data[0]
+
+            # If it's a tensor, return as is for processing in the route handler
+            return image_data
 
         return None
diff --git a/py/routes/recipe_routes.py b/py/routes/recipe_routes.py
index 2e13fdf8..2b97832b 100644
--- a/py/routes/recipe_routes.py
+++ b/py/routes/recipe_routes.py
@@ -1,5 +1,9 @@
 import os
 import time
+import numpy as np
+from PIL import Image
+import torch
+import io
 import logging
 from aiohttp import web
 from typing import Dict
@@ -15,6 +19,7 @@ from ..metadata_collector import get_metadata # Add MetadataCollector import
 from ..metadata_collector.metadata_processor import MetadataProcessor # Add MetadataProcessor import
 from ..utils.utils import download_civitai_image
 from ..services.service_registry import ServiceRegistry # Add ServiceRegistry import
+from ..metadata_collector.metadata_registry import MetadataRegistry
 
 logger = logging.getLogger(__name__)
 
@@ -657,8 +662,8 @@ class RecipeRoutes:
             logger.error(f"Error retrieving base models: {e}", exc_info=True)
             return web.json_response({
                 'success': False,
-                'error': str(e)
-            }, status=500)
+                'error': str(e)}
+            , status=500)
 
     async def share_recipe(self, request: web.Request) -> web.Response:
         """Process a recipe image for sharing by adding metadata to EXIF"""
@@ -795,21 +800,63 @@ class RecipeRoutes:
             if not metadata_dict:
                 return web.json_response({"error": "No generation metadata found"}, status=400)
 
-            # Find the latest image in the temp directory
-            temp_dir = config.temp_directory
-            image_files = []
+            # Get the most recent image from metadata registry instead of temp directory
+            metadata_registry = MetadataRegistry()
+            latest_image = metadata_registry.get_first_decoded_image()
 
-            for file in os.listdir(temp_dir):
-                if file.lower().endswith(('.png', '.jpg', '.jpeg', '.webp')):
-                    file_path = os.path.join(temp_dir, file)
-                    image_files.append((file_path, os.path.getmtime(file_path)))
+            # Compare against None explicitly: the registry may hand back a tensor,
+            # and truth-testing a multi-element tensor raises instead of returning a bool
+            if latest_image is None:
+                return web.json_response({"error": "No recent images found to use for recipe. Try generating an image first."}, status=400)
 
-            if not image_files:
-                return web.json_response({"error": "No recent images found to use for recipe"}, status=400)
+            # Convert the image data to bytes - handle tuple and tensor cases
+            logger.debug(f"Image type: {type(latest_image)}")
 
-            # Sort by modification time (newest first)
-            image_files.sort(key=lambda x: x[1], reverse=True)
-            latest_image_path = image_files[0][0]
+            try:
+                # Handle the tuple case first
+                if isinstance(latest_image, tuple):
+                    # Extract the tensor from the tuple
+                    if len(latest_image) > 0:
+                        tensor_image = latest_image[0]
+                    else:
+                        return web.json_response({"error": "Empty image tuple received"}, status=400)
+                else:
+                    tensor_image = latest_image
+
+                # Get the shape info for debugging
+                if hasattr(tensor_image, 'shape'):
+                    shape_info = tensor_image.shape
+                    logger.debug(f"Tensor shape: {shape_info}, dtype: {tensor_image.dtype}")
+
+                # Convert tensor to numpy array
+                if isinstance(tensor_image, torch.Tensor):
+                    image_np = tensor_image.cpu().numpy()
+                else:
+                    image_np = np.array(tensor_image)
+
+                # Handle different tensor shapes
+                # Case: (1, 1, H, W, 3) or (1, H, W, 3) - batch or multi-batch
+                if len(image_np.shape) > 3:
+                    # Remove batch dimensions until we get to (H, W, 3)
+                    while len(image_np.shape) > 3:
+                        image_np = image_np[0]
+
+                # If values are in [0, 1] range, convert to [0, 255]
+                if image_np.dtype == np.float32 or image_np.dtype == np.float64:
+                    if image_np.max() <= 1.0:
+                        image_np = (image_np * 255).astype(np.uint8)
+
+                # Ensure image is in the right format (HWC with RGB channels)
+                if len(image_np.shape) == 3 and image_np.shape[2] == 3:
+                    pil_image = Image.fromarray(image_np)
+                    img_byte_arr = io.BytesIO()
+                    pil_image.save(img_byte_arr, format='PNG')
+                    image = img_byte_arr.getvalue()
+                else:
+                    return web.json_response({"error": f"Cannot handle this data shape: {image_np.shape}, {image_np.dtype}"}, status=400)
+            except Exception as e:
+                logger.error(f"Error processing image data: {str(e)}", exc_info=True)
+                return web.json_response({"error": f"Error processing image: {str(e)}"}, status=400)
 
             # Get the lora stack from the metadata
             lora_stack = metadata_dict.get("loras", "")
@@ -834,10 +881,6 @@ class RecipeRoutes:
 
             recipe_name = " ".join(recipe_name_parts)
 
-            # Read the image
-            with open(latest_image_path, 'rb') as f:
-                image = f.read()
-
             # Create recipes directory if it doesn't exist
             recipes_dir = self.recipe_scanner.recipes_dir
             os.makedirs(recipes_dir, exist_ok=True)