feat: Enhance image retrieval in MetadataRegistry and update recipe routes to process images from metadata

This commit is contained in:
Will Miao
2025-04-18 09:24:48 +08:00
parent df6d56ce66
commit 91b4827c1d
2 changed files with 68 additions and 19 deletions

View File

@@ -239,6 +239,14 @@ class MetadataRegistry:
metadata = self.prompt_metadata[key] metadata = self.prompt_metadata[key]
if IMAGES in metadata and "first_decode" in metadata[IMAGES]: if IMAGES in metadata and "first_decode" in metadata[IMAGES]:
return metadata[IMAGES]["first_decode"]["image"] image_data = metadata[IMAGES]["first_decode"]["image"]
# If it's an image batch or tuple, handle various formats
if isinstance(image_data, (list, tuple)) and len(image_data) > 0:
# Return first element of list/tuple
return image_data[0]
# If it's a tensor, return as is for processing in the route handler
return image_data
return None return None

View File

@@ -1,5 +1,9 @@
import os import os
import time import time
import numpy as np
from PIL import Image
import torch
import io
import logging import logging
from aiohttp import web from aiohttp import web
from typing import Dict from typing import Dict
@@ -15,6 +19,7 @@ from ..metadata_collector import get_metadata # Add MetadataCollector import
from ..metadata_collector.metadata_processor import MetadataProcessor # Add MetadataProcessor import from ..metadata_collector.metadata_processor import MetadataProcessor # Add MetadataProcessor import
from ..utils.utils import download_civitai_image from ..utils.utils import download_civitai_image
from ..services.service_registry import ServiceRegistry # Add ServiceRegistry import from ..services.service_registry import ServiceRegistry # Add ServiceRegistry import
from ..metadata_collector.metadata_registry import MetadataRegistry
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -657,8 +662,8 @@ class RecipeRoutes:
logger.error(f"Error retrieving base models: {e}", exc_info=True) logger.error(f"Error retrieving base models: {e}", exc_info=True)
return web.json_response({ return web.json_response({
'success': False, 'success': False,
'error': str(e) 'error': str(e)}
}, status=500) , status=500)
async def share_recipe(self, request: web.Request) -> web.Response: async def share_recipe(self, request: web.Request) -> web.Response:
"""Process a recipe image for sharing by adding metadata to EXIF""" """Process a recipe image for sharing by adding metadata to EXIF"""
@@ -795,21 +800,61 @@ class RecipeRoutes:
if not metadata_dict: if not metadata_dict:
return web.json_response({"error": "No generation metadata found"}, status=400) return web.json_response({"error": "No generation metadata found"}, status=400)
# Find the latest image in the temp directory # Get the most recent image from metadata registry instead of temp directory
temp_dir = config.temp_directory metadata_registry = MetadataRegistry()
image_files = [] latest_image = metadata_registry.get_first_decoded_image()
for file in os.listdir(temp_dir): if not latest_image:
if file.lower().endswith(('.png', '.jpg', '.jpeg', '.webp')): return web.json_response({"error": "No recent images found to use for recipe. Try generating an image first."}, status=400)
file_path = os.path.join(temp_dir, file)
image_files.append((file_path, os.path.getmtime(file_path)))
if not image_files: # Convert the image data to bytes - handle tuple and tensor cases
return web.json_response({"error": "No recent images found to use for recipe"}, status=400) logger.debug(f"Image type: {type(latest_image)}")
# Sort by modification time (newest first) try:
image_files.sort(key=lambda x: x[1], reverse=True) # Handle the tuple case first
latest_image_path = image_files[0][0] if isinstance(latest_image, tuple):
# Extract the tensor from the tuple
if len(latest_image) > 0:
tensor_image = latest_image[0]
else:
return web.json_response({"error": "Empty image tuple received"}, status=400)
else:
tensor_image = latest_image
# Get the shape info for debugging
if hasattr(tensor_image, 'shape'):
shape_info = tensor_image.shape
logger.debug(f"Tensor shape: {shape_info}, dtype: {tensor_image.dtype}")
# Convert tensor to numpy array
if isinstance(tensor_image, torch.Tensor):
image_np = tensor_image.cpu().numpy()
else:
image_np = np.array(tensor_image)
# Handle different tensor shapes
# Case: (1, 1, H, W, 3) or (1, H, W, 3) - batch or multi-batch
if len(image_np.shape) > 3:
# Remove batch dimensions until we get to (H, W, 3)
while len(image_np.shape) > 3:
image_np = image_np[0]
# If values are in [0, 1] range, convert to [0, 255]
if image_np.dtype == np.float32 or image_np.dtype == np.float64:
if image_np.max() <= 1.0:
image_np = (image_np * 255).astype(np.uint8)
# Ensure image is in the right format (HWC with RGB channels)
if len(image_np.shape) == 3 and image_np.shape[2] == 3:
pil_image = Image.fromarray(image_np)
img_byte_arr = io.BytesIO()
pil_image.save(img_byte_arr, format='PNG')
image = img_byte_arr.getvalue()
else:
return web.json_response({"error": f"Cannot handle this data shape: {image_np.shape}, {image_np.dtype}"}, status=400)
except Exception as e:
logger.error(f"Error processing image data: {str(e)}", exc_info=True)
return web.json_response({"error": f"Error processing image: {str(e)}"}, status=400)
# Get the lora stack from the metadata # Get the lora stack from the metadata
lora_stack = metadata_dict.get("loras", "") lora_stack = metadata_dict.get("loras", "")
@@ -834,10 +879,6 @@ class RecipeRoutes:
recipe_name = " ".join(recipe_name_parts) recipe_name = " ".join(recipe_name_parts)
# Read the image
with open(latest_image_path, 'rb') as f:
image = f.read()
# Create recipes directory if it doesn't exist # Create recipes directory if it doesn't exist
recipes_dir = self.recipe_scanner.recipes_dir recipes_dir = self.recipe_scanner.recipes_dir
os.makedirs(recipes_dir, exist_ok=True) os.makedirs(recipes_dir, exist_ok=True)