diff --git a/py/utils/exif_utils.py b/py/utils/exif_utils.py index a39f7880..23a9325c 100644 --- a/py/utils/exif_utils.py +++ b/py/utils/exif_utils.py @@ -15,17 +15,33 @@ class ExifUtils: def extract_user_comment(image_path: str) -> Optional[str]: """Extract UserComment field from image EXIF data""" try: - exif_dict = piexif.load(image_path) - - if piexif.ExifIFD.UserComment in exif_dict.get('Exif', {}): - user_comment = exif_dict['Exif'][piexif.ExifIFD.UserComment] - if isinstance(user_comment, bytes): - if user_comment.startswith(b'UNICODE\0'): - user_comment = user_comment[8:].decode('utf-16be') - else: - user_comment = user_comment.decode('utf-8', errors='ignore') - return user_comment - return None + # First try to open as image to check format + with Image.open(image_path) as img: + if img.format not in ['JPEG', 'TIFF']: + # For non-JPEG/TIFF images, try to get EXIF through PIL + exif = img._getexif() + if exif and piexif.ExifIFD.UserComment in exif: + user_comment = exif[piexif.ExifIFD.UserComment] + if isinstance(user_comment, bytes): + if user_comment.startswith(b'UNICODE\0'): + return user_comment[8:].decode('utf-16be') + return user_comment.decode('utf-8', errors='ignore') + return user_comment + return None + + # For JPEG/TIFF, use piexif + exif_dict = piexif.load(image_path) + + if piexif.ExifIFD.UserComment in exif_dict.get('Exif', {}): + user_comment = exif_dict['Exif'][piexif.ExifIFD.UserComment] + if isinstance(user_comment, bytes): + if user_comment.startswith(b'UNICODE\0'): + user_comment = user_comment[8:].decode('utf-16be') + else: + user_comment = user_comment.decode('utf-8', errors='ignore') + return user_comment + return None + except Exception as e: logger.error(f"Error extracting EXIF data from {image_path}: {e}") return None diff --git a/py/utils/recipe_parsers.py b/py/utils/recipe_parsers.py index 86c3cc57..ba68cf88 100644 --- a/py/utils/recipe_parsers.py +++ b/py/utils/recipe_parsers.py @@ -2,12 +2,24 @@ import json import logging import os import re -from typing import Dict, List, Any, Optional +from typing import Dict, List, Any, Optional, Tuple from abc import ABC, abstractmethod from ..config import config logger = logging.getLogger(__name__) +# Constants for generation parameters +GEN_PARAM_KEYS = [ + 'prompt', + 'negative_prompt', + 'steps', + 'sampler', + 'cfg_scale', + 'seed', + 'size', + 'clip_skip', +] + class RecipeMetadataParser(ABC): """Interface for parsing recipe metadata from image user comments""" @@ -128,10 +140,17 @@ class RecipeFormatParser(RecipeMetadataParser): logger.info(f"Found {len(loras)} loras in recipe metadata") + # Filter gen_params to only include recognized keys + filtered_gen_params = {} + if 'gen_params' in recipe_metadata: + for key, value in recipe_metadata['gen_params'].items(): + if key in GEN_PARAM_KEYS: + filtered_gen_params[key] = value + return { 'base_model': recipe_metadata.get('base_model', ''), 'loras': loras, - 'gen_params': recipe_metadata.get('gen_params', {}), + 'gen_params': filtered_gen_params, 'tags': recipe_metadata.get('tags', []), 'title': recipe_metadata.get('title', ''), 'from_recipe_metadata': True @@ -251,17 +270,10 @@ class StandardMetadataParser(RecipeMetadataParser): base_model = max(base_model_counts.items(), key=lambda x: x[1])[0] # Extract generation parameters for recipe metadata - gen_params = { - 'prompt': metadata.get('prompt', ''), - 'negative_prompt': metadata.get('negative_prompt', ''), - 'checkpoint': checkpoint, - 'steps': metadata.get('steps', ''), - 'sampler': metadata.get('sampler', ''), - 'cfg_scale': metadata.get('cfg_scale', ''), - 'seed': metadata.get('seed', ''), - 'size': metadata.get('size', ''), - 'clip_skip': metadata.get('clip_skip', '') - } + gen_params = {} + for key in GEN_PARAM_KEYS: + if key in metadata: + gen_params[key] = metadata.get(key, '') return { 'base_model': base_model, @@ -330,6 +342,168 @@ class StandardMetadataParser(RecipeMetadataParser): return {"prompt": user_comment, "loras": [], "checkpoint": None} +class A1111MetadataParser(RecipeMetadataParser): + """Parser for images with A1111 metadata format (Lora hashes)""" + + METADATA_MARKER = r'Lora hashes:' + LORA_PATTERN = r']+)>' + LORA_HASH_PATTERN = r'([^:]+): ([a-f0-9]+)' + + def is_metadata_matching(self, user_comment: str) -> bool: + """Check if the user comment matches the A1111 metadata format""" + return 'Lora hashes:' in user_comment + + async def parse_metadata(self, user_comment: str, recipe_scanner=None, civitai_client=None) -> Dict[str, Any]: + """Parse metadata from images with A1111 metadata format""" + try: + # Extract prompt and negative prompt + parts = user_comment.split('Negative prompt:', 1) + prompt = parts[0].strip() + + # Initialize metadata + metadata = {"prompt": prompt, "loras": []} + + # Extract negative prompt and parameters + if len(parts) > 1: + negative_and_params = parts[1] + + # Extract negative prompt + if "Steps:" in negative_and_params: + neg_prompt = negative_and_params.split("Steps:", 1)[0].strip() + metadata["negative_prompt"] = neg_prompt + + # Extract key-value parameters (Steps, Sampler, CFG scale, etc.) + param_pattern = r'([A-Za-z ]+): ([^,]+)' + params = re.findall(param_pattern, negative_and_params) + for key, value in params: + clean_key = key.strip().lower().replace(' ', '_') + metadata[clean_key] = value.strip() + + # Extract LoRA information from prompt + lora_weights = {} + lora_matches = re.findall(self.LORA_PATTERN, prompt) + for lora_name, weight in lora_matches: + lora_weights[lora_name.strip()] = float(weight.strip()) + + # Remove LoRA patterns from prompt + metadata["prompt"] = re.sub(self.LORA_PATTERN, '', prompt).strip() + + # Extract LoRA hashes + lora_hashes = {} + if 'Lora hashes:' in user_comment: + lora_hash_section = user_comment.split('Lora hashes:', 1)[1].strip() + if lora_hash_section.startswith('"'): + lora_hash_section = lora_hash_section[1:].split('"', 1)[0] + hash_matches = re.findall(self.LORA_HASH_PATTERN, lora_hash_section) + for lora_name, hash_value in hash_matches: + # Remove any leading comma and space from lora name + clean_name = lora_name.strip().lstrip(',').strip() + lora_hashes[clean_name] = hash_value.strip() + + # Process LoRAs and collect base models + base_model_counts = {} + loras = [] + + # Process each LoRA with hash and weight + for lora_name, hash_value in lora_hashes.items(): + weight = lora_weights.get(lora_name, 1.0) + + # Initialize lora entry with default values + lora_entry = { + 'name': lora_name, + 'type': 'lora', + 'weight': weight, + 'existsLocally': False, + 'localPath': None, + 'file_name': lora_name, + 'hash': hash_value, + 'thumbnailUrl': '/loras_static/images/no-preview.png', + 'baseModel': '', + 'size': 0, + 'downloadUrl': '', + 'isDeleted': False + } + + # Get info from Civitai by hash + if civitai_client: + try: + civitai_info = await civitai_client.get_model_by_hash(hash_value) + if civitai_info and civitai_info.get("error") != "Model not found": + # Get model version ID + lora_entry['id'] = civitai_info.get('modelVersionId', '') + + # Get model name and version + lora_entry['name'] = civitai_info.get('modelName', lora_name) + lora_entry['version'] = civitai_info.get('modelVersionName', '') + + # Get thumbnail URL + if 'images' in civitai_info and civitai_info['images']: + lora_entry['thumbnailUrl'] = civitai_info['images'][0].get('url', '') + + # Get base model and update counts + current_base_model = civitai_info.get('baseModel', '') + lora_entry['baseModel'] = current_base_model + if current_base_model: + base_model_counts[current_base_model] = base_model_counts.get(current_base_model, 0) + 1 + + # Get download URL + lora_entry['downloadUrl'] = civitai_info.get('downloadUrl', '') + + # Get file name and size from Civitai + if 'files' in civitai_info: + model_file = next((file for file in civitai_info.get('files', []) + if file.get('type') == 'Model'), None) + if model_file: + file_name = model_file.get('name', '') + lora_entry['file_name'] = os.path.splitext(file_name)[0] if file_name else lora_name + lora_entry['size'] = model_file.get('sizeKB', 0) * 1024 + # Update hash to sha256 + lora_entry['hash'] = model_file.get('hashes', {}).get('SHA256', hash_value).lower() + + # Check if exists locally with sha256 hash + if recipe_scanner and lora_entry['hash']: + lora_scanner = recipe_scanner._lora_scanner + exists_locally = lora_scanner.has_lora_hash(lora_entry['hash']) + if exists_locally: + lora_cache = await lora_scanner.get_cached_data() + lora_item = next((item for item in lora_cache.raw_data if item['sha256'] == lora_entry['hash']), None) + if lora_item: + lora_entry['existsLocally'] = True + lora_entry['localPath'] = lora_item['file_path'] + lora_entry['thumbnailUrl'] = config.get_preview_static_url(lora_item['preview_url']) + + except Exception as e: + logger.error(f"Error fetching Civitai info for LoRA hash {hash_value}: {e}") + + loras.append(lora_entry) + + # Set base_model to the most common one from civitai_info + base_model = None + if base_model_counts: + base_model = max(base_model_counts.items(), key=lambda x: x[1])[0] + + # Extract generation parameters for recipe metadata + gen_params = {} + for key in GEN_PARAM_KEYS: + if key in metadata: + gen_params[key] = metadata.get(key, '') + + # Add model information if available + if 'model' in metadata: + gen_params['checkpoint'] = metadata['model'] + + return { + 'base_model': base_model, + 'loras': loras, + 'gen_params': gen_params, + 'raw_metadata': metadata + } + + except Exception as e: + logger.error(f"Error parsing A1111 metadata: {e}", exc_info=True) + return {"error": str(e), "loras": []} + + class RecipeParserFactory: """Factory for creating recipe metadata parsers""" @@ -345,11 +519,14 @@ class RecipeParserFactory: Appropriate RecipeMetadataParser implementation """ if RecipeFormatParser().is_metadata_matching(user_comment): - print("RecipeFormatParser") + logger.info("RecipeFormatParser") return RecipeFormatParser() elif StandardMetadataParser().is_metadata_matching(user_comment): - print("StandardMetadataParser") - return StandardMetadataParser() + logger.info("StandardMetadataParser") + return StandardMetadataParser() + elif A1111MetadataParser().is_metadata_matching(user_comment): + logger.info("A1111MetadataParser") + return A1111MetadataParser() else: - print("None") + logger.info("No parser found for this image") return None \ No newline at end of file