feat: Enhance image retrieval in MetadataRegistry and update recipe routes to process images from metadata

This commit is contained in:
Will Miao
2025-04-18 09:24:48 +08:00
parent df6d56ce66
commit 91b4827c1d
2 changed files with 68 additions and 19 deletions

View File

@@ -239,6 +239,14 @@ class MetadataRegistry:
metadata = self.prompt_metadata[key]
if IMAGES in metadata and "first_decode" in metadata[IMAGES]:
return metadata[IMAGES]["first_decode"]["image"]
image_data = metadata[IMAGES]["first_decode"]["image"]
# If it's an image batch or tuple, handle various formats
if isinstance(image_data, (list, tuple)) and len(image_data) > 0:
# Return first element of list/tuple
return image_data[0]
# If it's a tensor, return as is for processing in the route handler
return image_data
return None

View File

@@ -1,5 +1,9 @@
import os
import time
import numpy as np
from PIL import Image
import torch
import io
import logging
from aiohttp import web
from typing import Dict
@@ -15,6 +19,7 @@ from ..metadata_collector import get_metadata # Add MetadataCollector import
from ..metadata_collector.metadata_processor import MetadataProcessor # Add MetadataProcessor import
from ..utils.utils import download_civitai_image
from ..services.service_registry import ServiceRegistry # Add ServiceRegistry import
from ..metadata_collector.metadata_registry import MetadataRegistry
logger = logging.getLogger(__name__)
@@ -657,8 +662,8 @@ class RecipeRoutes:
logger.error(f"Error retrieving base models: {e}", exc_info=True)
return web.json_response({
'success': False,
'error': str(e)
}, status=500)
'error': str(e)}
, status=500)
async def share_recipe(self, request: web.Request) -> web.Response:
"""Process a recipe image for sharing by adding metadata to EXIF"""
@@ -795,21 +800,61 @@ class RecipeRoutes:
if not metadata_dict:
return web.json_response({"error": "No generation metadata found"}, status=400)
# Find the latest image in the temp directory
temp_dir = config.temp_directory
image_files = []
# Get the most recent image from metadata registry instead of temp directory
metadata_registry = MetadataRegistry()
latest_image = metadata_registry.get_first_decoded_image()
for file in os.listdir(temp_dir):
if file.lower().endswith(('.png', '.jpg', '.jpeg', '.webp')):
file_path = os.path.join(temp_dir, file)
image_files.append((file_path, os.path.getmtime(file_path)))
if not latest_image:
return web.json_response({"error": "No recent images found to use for recipe. Try generating an image first."}, status=400)
if not image_files:
return web.json_response({"error": "No recent images found to use for recipe"}, status=400)
# Convert the image data to bytes - handle tuple and tensor cases
logger.debug(f"Image type: {type(latest_image)}")
# Sort by modification time (newest first)
image_files.sort(key=lambda x: x[1], reverse=True)
latest_image_path = image_files[0][0]
try:
# Handle the tuple case first
if isinstance(latest_image, tuple):
# Extract the tensor from the tuple
if len(latest_image) > 0:
tensor_image = latest_image[0]
else:
return web.json_response({"error": "Empty image tuple received"}, status=400)
else:
tensor_image = latest_image
# Get the shape info for debugging
if hasattr(tensor_image, 'shape'):
shape_info = tensor_image.shape
logger.debug(f"Tensor shape: {shape_info}, dtype: {tensor_image.dtype}")
# Convert tensor to numpy array
if isinstance(tensor_image, torch.Tensor):
image_np = tensor_image.cpu().numpy()
else:
image_np = np.array(tensor_image)
# Handle different tensor shapes
# Case: (1, 1, H, W, 3) or (1, H, W, 3) - batch or multi-batch
if len(image_np.shape) > 3:
# Remove batch dimensions until we get to (H, W, 3)
while len(image_np.shape) > 3:
image_np = image_np[0]
# If values are in [0, 1] range, convert to [0, 255]
if image_np.dtype == np.float32 or image_np.dtype == np.float64:
if image_np.max() <= 1.0:
image_np = (image_np * 255).astype(np.uint8)
# Ensure image is in the right format (HWC with RGB channels)
if len(image_np.shape) == 3 and image_np.shape[2] == 3:
pil_image = Image.fromarray(image_np)
img_byte_arr = io.BytesIO()
pil_image.save(img_byte_arr, format='PNG')
image = img_byte_arr.getvalue()
else:
return web.json_response({"error": f"Cannot handle this data shape: {image_np.shape}, {image_np.dtype}"}, status=400)
except Exception as e:
logger.error(f"Error processing image data: {str(e)}", exc_info=True)
return web.json_response({"error": f"Error processing image: {str(e)}"}, status=400)
# Get the lora stack from the metadata
lora_stack = metadata_dict.get("loras", "")
@@ -834,10 +879,6 @@ class RecipeRoutes:
recipe_name = " ".join(recipe_name_parts)
# Read the image
with open(latest_image_path, 'rb') as f:
image = f.read()
# Create recipes directory if it doesn't exist
recipes_dir = self.recipe_scanner.recipes_dir
os.makedirs(recipes_dir, exist_ok=True)