From 8b85e083e2cd5b3e6d70a23b7c19d3a728e059a5 Mon Sep 17 00:00:00 2001 From: Will Miao Date: Wed, 25 Mar 2026 08:43:33 +0800 Subject: [PATCH] feat(recipe-parser): add SuiImage metadata format support - Add SuiImageParamsParser for sui_image_params JSON format - Register new parser in RecipeParserFactory - Fix metadata_provider auto-initialization when not ready - Add 10 test cases for SuiImageParamsParser Fixes batch import failure for images with sui_image_params metadata. --- py/recipes/factory.py | 8 + py/recipes/parsers/__init__.py | 2 + py/recipes/parsers/sui_image_params.py | 188 ++++++++++++++++ py/services/metadata_service.py | 24 ++- tests/services/test_recipe_services.py | 2 +- .../services/test_sui_image_params_parser.py | 202 ++++++++++++++++++ 6 files changed, 420 insertions(+), 6 deletions(-) create mode 100644 py/recipes/parsers/sui_image_params.py create mode 100644 tests/services/test_sui_image_params_parser.py diff --git a/py/recipes/factory.py b/py/recipes/factory.py index 6dbcee2b..963ea710 100644 --- a/py/recipes/factory.py +++ b/py/recipes/factory.py @@ -7,6 +7,7 @@ from .parsers import ( MetaFormatParser, AutomaticMetadataParser, CivitaiApiMetadataParser, + SuiImageParamsParser, ) from .base import RecipeMetadataParser @@ -55,6 +56,13 @@ class RecipeParserFactory: # If JSON parsing fails, move on to other parsers pass + # Try SuiImageParamsParser for SuiImage metadata format + try: + if SuiImageParamsParser().is_metadata_matching(metadata_str): + return SuiImageParamsParser() + except Exception: + pass + # Check other parsers that expect string input if RecipeFormatParser().is_metadata_matching(metadata_str): return RecipeFormatParser() diff --git a/py/recipes/parsers/__init__.py b/py/recipes/parsers/__init__.py index 436737a4..81d6419f 100644 --- a/py/recipes/parsers/__init__.py +++ b/py/recipes/parsers/__init__.py @@ -5,6 +5,7 @@ from .comfy import ComfyMetadataParser from .meta_format import MetaFormatParser from .automatic import AutomaticMetadataParser from .civitai_image import CivitaiApiMetadataParser +from .sui_image_params import SuiImageParamsParser __all__ = [ 'RecipeFormatParser', @@ -12,4 +13,5 @@ __all__ = [ 'MetaFormatParser', 'AutomaticMetadataParser', 'CivitaiApiMetadataParser', + 'SuiImageParamsParser', ] diff --git a/py/recipes/parsers/sui_image_params.py b/py/recipes/parsers/sui_image_params.py new file mode 100644 index 00000000..9a6a2b13 --- /dev/null +++ b/py/recipes/parsers/sui_image_params.py @@ -0,0 +1,188 @@ +"""Parser for SuiImage (Stable Diffusion WebUI) metadata format.""" + +import json +import logging +from typing import Dict, Any, Optional, List +from ..base import RecipeMetadataParser +from ...services.metadata_service import get_default_metadata_provider + +logger = logging.getLogger(__name__) + + +class SuiImageParamsParser(RecipeMetadataParser): + """Parser for SuiImage metadata JSON format. + + This format is used by some Stable Diffusion WebUI variants. + Structure: + { + "sui_image_params": { + "prompt": "...", + "negativeprompt": "...", + "model": "...", + "seed": ..., + "steps": ..., + ... + }, + "sui_models": [ + {"name": "...", "param": "model", "hash": "..."}, + ... + ], + "sui_extra_data": {...} + } + """ + + def is_metadata_matching(self, user_comment: str) -> bool: + """Check if the user comment matches the SuiImage metadata format""" + try: + data = json.loads(user_comment) + return isinstance(data, dict) and 'sui_image_params' in data + except (json.JSONDecodeError, TypeError): + return False + + async def parse_metadata(self, user_comment: str, recipe_scanner=None, civitai_client=None) -> Dict[str, Any]: + """Parse metadata from SuiImage metadata format""" + try: + metadata_provider = await get_default_metadata_provider() + + data = json.loads(user_comment) + params = data.get('sui_image_params', {}) + models = data.get('sui_models', []) + + # Extract prompt and negative prompt + prompt = params.get('prompt', '') + negative_prompt = params.get('negativeprompt', '') or params.get('negative_prompt', '') + + # Extract generation parameters + gen_params = {} + if prompt: + gen_params['prompt'] = prompt + if negative_prompt: + gen_params['negative_prompt'] = negative_prompt + + # Map standard parameters + param_mapping = { + 'steps': 'steps', + 'seed': 'seed', + 'cfgscale': 'cfg_scale', + 'cfg_scale': 'cfg_scale', + 'width': 'width', + 'height': 'height', + 'sampler': 'sampler', + 'scheduler': 'scheduler', + 'model': 'model', + 'vae': 'vae', + } + + for src_key, dest_key in param_mapping.items(): + if src_key in params and params[src_key] is not None: + gen_params[dest_key] = params[src_key] + + # Add size info if available + if 'width' in gen_params and 'height' in gen_params: + gen_params['size'] = f"{gen_params['width']}x{gen_params['height']}" + + # Process models - extract checkpoint and loras + loras: List[Dict[str, Any]] = [] + checkpoint: Optional[Dict[str, Any]] = None + + for model in models: + model_name = model.get('name', '') + param_type = model.get('param', '') + model_hash = model.get('hash', '') + + # Remove .safetensors extension for cleaner name + clean_name = model_name.replace('.safetensors', '') if model_name else '' + + # Check if this is a LoRA by looking at the name or param type + is_lora = 'lora' in model_name.lower() or param_type.lower().startswith('lora') + + if is_lora: + lora_entry = { + 'id': 0, + 'modelId': 0, + 'name': clean_name, + 'version': '', + 'type': 'lora', + 'weight': 1.0, + 'existsLocally': False, + 'localPath': None, + 'file_name': model_name, + 'hash': model_hash.replace('0x', '') if model_hash.startswith('0x') else model_hash, + 'thumbnailUrl': '/loras_static/images/no-preview.png', + 'baseModel': '', + 'size': 0, + 'downloadUrl': '', + 'isDeleted': False + } + + # Try to get additional info from metadata provider + if metadata_provider and model_hash: + try: + civitai_info = await metadata_provider.get_model_by_hash( + model_hash.replace('0x', '') if model_hash.startswith('0x') else model_hash + ) + if civitai_info: + lora_entry = await self.populate_lora_from_civitai( + lora_entry, civitai_info, recipe_scanner + ) + except Exception as e: + logger.debug(f"Error fetching info for LoRA {clean_name}: {e}") + + if lora_entry: + loras.append(lora_entry) + elif param_type == 'model' or 'lora' not in model_name.lower(): + # This is likely a checkpoint + checkpoint_entry = { + 'id': 0, + 'modelId': 0, + 'name': clean_name, + 'version': '', + 'type': 'checkpoint', + 'hash': model_hash.replace('0x', '') if model_hash.startswith('0x') else model_hash, + 'existsLocally': False, + 'localPath': None, + 'file_name': model_name, + 'thumbnailUrl': '/loras_static/images/no-preview.png', + 'baseModel': '', + 'size': 0, + 'downloadUrl': '', + 'isDeleted': False + } + + # Try to get additional info from metadata provider + if metadata_provider and model_hash: + try: + civitai_info = await metadata_provider.get_model_by_hash( + model_hash.replace('0x', '') if model_hash.startswith('0x') else model_hash + ) + if civitai_info: + checkpoint_entry = await self.populate_checkpoint_from_civitai( + checkpoint_entry, civitai_info + ) + except Exception as e: + logger.debug(f"Error fetching info for checkpoint {clean_name}: {e}") + + checkpoint = checkpoint_entry + + # Determine base model from loras or checkpoint + base_model = None + if loras: + base_models = [lora.get('baseModel') for lora in loras if lora.get('baseModel')] + if base_models: + from collections import Counter + base_model_counts = Counter(base_models) + base_model = base_model_counts.most_common(1)[0][0] + elif checkpoint and checkpoint.get('baseModel'): + base_model = checkpoint['baseModel'] + + return { + 'base_model': base_model, + 'loras': loras, + 'checkpoint': checkpoint, + 'gen_params': gen_params, + 'from_sui_image_params': True + } + + except Exception as e: + logger.error(f"Error parsing SuiImage metadata: {e}", exc_info=True) + return {"error": str(e), "loras": []} diff --git a/py/services/metadata_service.py b/py/services/metadata_service.py index 8909653a..95cb4d0d 100644 --- a/py/services/metadata_service.py +++ b/py/services/metadata_service.py @@ -122,11 +122,25 @@ async def get_metadata_provider(provider_name: str = None): provider_manager = await ModelMetadataProviderManager.get_instance() - provider = ( - provider_manager._get_provider(provider_name) - if provider_name - else provider_manager._get_provider() - ) + try: + provider = ( + provider_manager._get_provider(provider_name) + if provider_name + else provider_manager._get_provider() + ) + except ValueError as e: + # Provider not initialized, attempt to initialize + if "No default provider set" in str(e) or "not registered" in str(e): + logger.warning(f"Metadata provider not initialized ({e}), initializing now...") + await initialize_metadata_providers() + provider_manager = await ModelMetadataProviderManager.get_instance() + provider = ( + provider_manager._get_provider(provider_name) + if provider_name + else provider_manager._get_provider() + ) + else: + raise return _wrap_provider_with_rate_limit(provider_name, provider) diff --git a/tests/services/test_recipe_services.py b/tests/services/test_recipe_services.py index e4551de2..1db28271 100644 --- a/tests/services/test_recipe_services.py +++ b/tests/services/test_recipe_services.py @@ -492,7 +492,7 @@ async def test_analyze_remote_video(tmp_path): class DummyFactory: def create_parser(self, metadata): - async def parse_metadata(m, recipe_scanner): + async def parse_metadata(m, recipe_scanner=None, civitai_client=None): return {"loras": []} return SimpleNamespace(parse_metadata=parse_metadata) diff --git a/tests/services/test_sui_image_params_parser.py b/tests/services/test_sui_image_params_parser.py new file mode 100644 index 00000000..9bcaf04d --- /dev/null +++ b/tests/services/test_sui_image_params_parser.py @@ -0,0 +1,202 @@ +"""Tests for SuiImageParamsParser.""" + +import pytest +import json +from py.recipes.parsers import SuiImageParamsParser + + +class TestSuiImageParamsParser: + """Test cases for SuiImageParamsParser.""" + + def setup_method(self): + """Set up test fixtures.""" + self.parser = SuiImageParamsParser() + + def test_is_metadata_matching_positive(self): + """Test that parser correctly identifies SuiImage metadata format.""" + metadata = { + "sui_image_params": { + "prompt": "test prompt", + "model": "test_model" + } + } + metadata_str = json.dumps(metadata) + assert self.parser.is_metadata_matching(metadata_str) is True + + def test_is_metadata_matching_negative(self): + """Test that parser rejects non-SuiImage metadata.""" + # Missing sui_image_params key + metadata = { + "other_params": { + "prompt": "test prompt" + } + } + metadata_str = json.dumps(metadata) + assert self.parser.is_metadata_matching(metadata_str) is False + + def test_is_metadata_matching_invalid_json(self): + """Test that parser handles invalid JSON gracefully.""" + metadata_str = "not valid json" + assert self.parser.is_metadata_matching(metadata_str) is False + + @pytest.mark.asyncio + async def test_parse_metadata_extracts_basic_fields(self): + """Test parsing basic fields from SuiImage metadata.""" + metadata = { + "sui_image_params": { + "prompt": "beautiful landscape", + "negativeprompt": "ugly, blurry", + "steps": 30, + "seed": 12345, + "cfgscale": 7.5, + "width": 512, + "height": 768, + "sampler": "Euler a", + "scheduler": "normal" + }, + "sui_models": [], + "sui_extra_data": {} + } + metadata_str = json.dumps(metadata) + result = await self.parser.parse_metadata(metadata_str) + + assert result.get('gen_params', {}).get('prompt') == "beautiful landscape" + assert result.get('gen_params', {}).get('negative_prompt') == "ugly, blurry" + assert result.get('gen_params', {}).get('steps') == 30 + assert result.get('gen_params', {}).get('seed') == 12345 + assert result.get('gen_params', {}).get('cfg_scale') == 7.5 + assert result.get('gen_params', {}).get('width') == 512 + assert result.get('gen_params', {}).get('height') == 768 + assert result.get('gen_params', {}).get('size') == "512x768" + assert result.get('loras') == [] + + @pytest.mark.asyncio + async def test_parse_metadata_extracts_checkpoint(self): + """Test parsing checkpoint from sui_models.""" + metadata = { + "sui_image_params": { + "prompt": "test prompt", + "model": "checkpoint_model" + }, + "sui_models": [ + { + "name": "test_checkpoint.safetensors", + "param": "model", + "hash": "0x1234567890abcdef" + } + ], + "sui_extra_data": {} + } + metadata_str = json.dumps(metadata) + result = await self.parser.parse_metadata(metadata_str) + + checkpoint = result.get('checkpoint') + assert checkpoint is not None + assert checkpoint['type'] == 'checkpoint' + assert checkpoint['name'] == 'test_checkpoint' + assert checkpoint['hash'] == '1234567890abcdef' + + @pytest.mark.asyncio + async def test_parse_metadata_extracts_lora(self): + """Test parsing LoRA from sui_models.""" + metadata = { + "sui_image_params": { + "prompt": "test prompt" + }, + "sui_models": [ + { + "name": "test_lora.safetensors", + "param": "lora", + "hash": "0xabcdef1234567890" + } + ], + "sui_extra_data": {} + } + metadata_str = json.dumps(metadata) + result = await self.parser.parse_metadata(metadata_str) + + loras = result.get('loras') + assert len(loras) == 1 + assert loras[0]['type'] == 'lora' + assert loras[0]['name'] == 'test_lora' + assert loras[0]['file_name'] == 'test_lora.safetensors' + assert loras[0]['hash'] == 'abcdef1234567890' + + @pytest.mark.asyncio + async def test_parse_metadata_handles_lora_in_name(self): + """Test that LoRA is detected by 'lora' in name.""" + metadata = { + "sui_image_params": { + "prompt": "test prompt" + }, + "sui_models": [ + { + "name": "style_lora_v2.safetensors", + "param": "some_other_param", + "hash": "0x1111111111111111" + } + ], + "sui_extra_data": {} + } + metadata_str = json.dumps(metadata) + result = await self.parser.parse_metadata(metadata_str) + + loras = result.get('loras') + assert len(loras) == 1 + assert loras[0]['type'] == 'lora' + + @pytest.mark.asyncio + async def test_parse_metadata_empty_models(self): + """Test parsing with empty sui_models array.""" + metadata = { + "sui_image_params": { + "prompt": "test prompt", + "steps": 20 + }, + "sui_models": [], + "sui_extra_data": { + "date": "2024-01-01" + } + } + metadata_str = json.dumps(metadata) + result = await self.parser.parse_metadata(metadata_str) + + assert result.get('loras') == [] + assert result.get('checkpoint') is None + assert result.get('gen_params', {}).get('prompt') == "test prompt" + assert result.get('gen_params', {}).get('steps') == 20 + + @pytest.mark.asyncio + async def test_parse_metadata_alternative_field_names(self): + """Test parsing with alternative field names.""" + metadata = { + "sui_image_params": { + "prompt": "test prompt", + "negative_prompt": "bad quality", # Using underscore variant + "cfg_scale": 6.0 # Using underscore variant + }, + "sui_models": [], + "sui_extra_data": {} + } + metadata_str = json.dumps(metadata) + result = await self.parser.parse_metadata(metadata_str) + + assert result.get('gen_params', {}).get('negative_prompt') == "bad quality" + assert result.get('gen_params', {}).get('cfg_scale') == 6.0 + + @pytest.mark.asyncio + async def test_parse_metadata_error_handling(self): + """Test that parser handles malformed data gracefully.""" + # Missing required fields + metadata = { + "sui_image_params": {}, + "sui_models": [], + "sui_extra_data": {} + } + metadata_str = json.dumps(metadata) + result = await self.parser.parse_metadata(metadata_str) + + assert 'error' not in result + assert result.get('loras') == [] + # Empty params result in empty gen_params dict + assert result.get('gen_params') == {}