mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-25 07:05:43 -03:00
feat: Enhance image retrieval in MetadataRegistry and update recipe routes to process images from metadata
This commit is contained in:
@@ -239,6 +239,14 @@ class MetadataRegistry:
|
|||||||
|
|
||||||
metadata = self.prompt_metadata[key]
|
metadata = self.prompt_metadata[key]
|
||||||
if IMAGES in metadata and "first_decode" in metadata[IMAGES]:
|
if IMAGES in metadata and "first_decode" in metadata[IMAGES]:
|
||||||
return metadata[IMAGES]["first_decode"]["image"]
|
image_data = metadata[IMAGES]["first_decode"]["image"]
|
||||||
|
|
||||||
|
# If it's an image batch or tuple, handle various formats
|
||||||
|
if isinstance(image_data, (list, tuple)) and len(image_data) > 0:
|
||||||
|
# Return first element of list/tuple
|
||||||
|
return image_data[0]
|
||||||
|
|
||||||
|
# If it's a tensor, return as is for processing in the route handler
|
||||||
|
return image_data
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|||||||
@@ -1,5 +1,9 @@
|
|||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
|
import numpy as np
|
||||||
|
from PIL import Image
|
||||||
|
import torch
|
||||||
|
import io
|
||||||
import logging
|
import logging
|
||||||
from aiohttp import web
|
from aiohttp import web
|
||||||
from typing import Dict
|
from typing import Dict
|
||||||
@@ -15,6 +19,7 @@ from ..metadata_collector import get_metadata # Add MetadataCollector import
|
|||||||
from ..metadata_collector.metadata_processor import MetadataProcessor # Add MetadataProcessor import
|
from ..metadata_collector.metadata_processor import MetadataProcessor # Add MetadataProcessor import
|
||||||
from ..utils.utils import download_civitai_image
|
from ..utils.utils import download_civitai_image
|
||||||
from ..services.service_registry import ServiceRegistry # Add ServiceRegistry import
|
from ..services.service_registry import ServiceRegistry # Add ServiceRegistry import
|
||||||
|
from ..metadata_collector.metadata_registry import MetadataRegistry
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -657,8 +662,8 @@ class RecipeRoutes:
|
|||||||
logger.error(f"Error retrieving base models: {e}", exc_info=True)
|
logger.error(f"Error retrieving base models: {e}", exc_info=True)
|
||||||
return web.json_response({
|
return web.json_response({
|
||||||
'success': False,
|
'success': False,
|
||||||
'error': str(e)
|
'error': str(e)}
|
||||||
}, status=500)
|
, status=500)
|
||||||
|
|
||||||
async def share_recipe(self, request: web.Request) -> web.Response:
|
async def share_recipe(self, request: web.Request) -> web.Response:
|
||||||
"""Process a recipe image for sharing by adding metadata to EXIF"""
|
"""Process a recipe image for sharing by adding metadata to EXIF"""
|
||||||
@@ -795,21 +800,61 @@ class RecipeRoutes:
|
|||||||
if not metadata_dict:
|
if not metadata_dict:
|
||||||
return web.json_response({"error": "No generation metadata found"}, status=400)
|
return web.json_response({"error": "No generation metadata found"}, status=400)
|
||||||
|
|
||||||
# Find the latest image in the temp directory
|
# Get the most recent image from metadata registry instead of temp directory
|
||||||
temp_dir = config.temp_directory
|
metadata_registry = MetadataRegistry()
|
||||||
image_files = []
|
latest_image = metadata_registry.get_first_decoded_image()
|
||||||
|
|
||||||
for file in os.listdir(temp_dir):
|
if not latest_image:
|
||||||
if file.lower().endswith(('.png', '.jpg', '.jpeg', '.webp')):
|
return web.json_response({"error": "No recent images found to use for recipe. Try generating an image first."}, status=400)
|
||||||
file_path = os.path.join(temp_dir, file)
|
|
||||||
image_files.append((file_path, os.path.getmtime(file_path)))
|
|
||||||
|
|
||||||
if not image_files:
|
# Convert the image data to bytes - handle tuple and tensor cases
|
||||||
return web.json_response({"error": "No recent images found to use for recipe"}, status=400)
|
logger.debug(f"Image type: {type(latest_image)}")
|
||||||
|
|
||||||
# Sort by modification time (newest first)
|
try:
|
||||||
image_files.sort(key=lambda x: x[1], reverse=True)
|
# Handle the tuple case first
|
||||||
latest_image_path = image_files[0][0]
|
if isinstance(latest_image, tuple):
|
||||||
|
# Extract the tensor from the tuple
|
||||||
|
if len(latest_image) > 0:
|
||||||
|
tensor_image = latest_image[0]
|
||||||
|
else:
|
||||||
|
return web.json_response({"error": "Empty image tuple received"}, status=400)
|
||||||
|
else:
|
||||||
|
tensor_image = latest_image
|
||||||
|
|
||||||
|
# Get the shape info for debugging
|
||||||
|
if hasattr(tensor_image, 'shape'):
|
||||||
|
shape_info = tensor_image.shape
|
||||||
|
logger.debug(f"Tensor shape: {shape_info}, dtype: {tensor_image.dtype}")
|
||||||
|
|
||||||
|
# Convert tensor to numpy array
|
||||||
|
if isinstance(tensor_image, torch.Tensor):
|
||||||
|
image_np = tensor_image.cpu().numpy()
|
||||||
|
else:
|
||||||
|
image_np = np.array(tensor_image)
|
||||||
|
|
||||||
|
# Handle different tensor shapes
|
||||||
|
# Case: (1, 1, H, W, 3) or (1, H, W, 3) - batch or multi-batch
|
||||||
|
if len(image_np.shape) > 3:
|
||||||
|
# Remove batch dimensions until we get to (H, W, 3)
|
||||||
|
while len(image_np.shape) > 3:
|
||||||
|
image_np = image_np[0]
|
||||||
|
|
||||||
|
# If values are in [0, 1] range, convert to [0, 255]
|
||||||
|
if image_np.dtype == np.float32 or image_np.dtype == np.float64:
|
||||||
|
if image_np.max() <= 1.0:
|
||||||
|
image_np = (image_np * 255).astype(np.uint8)
|
||||||
|
|
||||||
|
# Ensure image is in the right format (HWC with RGB channels)
|
||||||
|
if len(image_np.shape) == 3 and image_np.shape[2] == 3:
|
||||||
|
pil_image = Image.fromarray(image_np)
|
||||||
|
img_byte_arr = io.BytesIO()
|
||||||
|
pil_image.save(img_byte_arr, format='PNG')
|
||||||
|
image = img_byte_arr.getvalue()
|
||||||
|
else:
|
||||||
|
return web.json_response({"error": f"Cannot handle this data shape: {image_np.shape}, {image_np.dtype}"}, status=400)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error processing image data: {str(e)}", exc_info=True)
|
||||||
|
return web.json_response({"error": f"Error processing image: {str(e)}"}, status=400)
|
||||||
|
|
||||||
# Get the lora stack from the metadata
|
# Get the lora stack from the metadata
|
||||||
lora_stack = metadata_dict.get("loras", "")
|
lora_stack = metadata_dict.get("loras", "")
|
||||||
@@ -834,10 +879,6 @@ class RecipeRoutes:
|
|||||||
|
|
||||||
recipe_name = " ".join(recipe_name_parts)
|
recipe_name = " ".join(recipe_name_parts)
|
||||||
|
|
||||||
# Read the image
|
|
||||||
with open(latest_image_path, 'rb') as f:
|
|
||||||
image = f.read()
|
|
||||||
|
|
||||||
# Create recipes directory if it doesn't exist
|
# Create recipes directory if it doesn't exist
|
||||||
recipes_dir = self.recipe_scanner.recipes_dir
|
recipes_dir = self.recipe_scanner.recipes_dir
|
||||||
os.makedirs(recipes_dir, exist_ok=True)
|
os.makedirs(recipes_dir, exist_ok=True)
|
||||||
|
|||||||
Reference in New Issue
Block a user