feat(recipe-parser): add SuiImage metadata format support

- Add SuiImageParamsParser for sui_image_params JSON format
- Register new parser in RecipeParserFactory
- Fix metadata_provider auto-initialization when not ready
- Add 10 test cases for SuiImageParamsParser

Fixes batch import failure for images with sui_image_params metadata.
This commit is contained in:
Will Miao
2026-03-25 08:43:33 +08:00
parent 9112cd3b62
commit 8b85e083e2
6 changed files with 420 additions and 6 deletions

View File

@@ -7,6 +7,7 @@ from .parsers import (
MetaFormatParser, MetaFormatParser,
AutomaticMetadataParser, AutomaticMetadataParser,
CivitaiApiMetadataParser, CivitaiApiMetadataParser,
SuiImageParamsParser,
) )
from .base import RecipeMetadataParser from .base import RecipeMetadataParser
@@ -55,6 +56,13 @@ class RecipeParserFactory:
# If JSON parsing fails, move on to other parsers # If JSON parsing fails, move on to other parsers
pass pass
# Try SuiImageParamsParser for SuiImage metadata format
try:
if SuiImageParamsParser().is_metadata_matching(metadata_str):
return SuiImageParamsParser()
except Exception:
pass
# Check other parsers that expect string input # Check other parsers that expect string input
if RecipeFormatParser().is_metadata_matching(metadata_str): if RecipeFormatParser().is_metadata_matching(metadata_str):
return RecipeFormatParser() return RecipeFormatParser()

View File

@@ -5,6 +5,7 @@ from .comfy import ComfyMetadataParser
from .meta_format import MetaFormatParser from .meta_format import MetaFormatParser
from .automatic import AutomaticMetadataParser from .automatic import AutomaticMetadataParser
from .civitai_image import CivitaiApiMetadataParser from .civitai_image import CivitaiApiMetadataParser
from .sui_image_params import SuiImageParamsParser
__all__ = [ __all__ = [
'RecipeFormatParser', 'RecipeFormatParser',
@@ -12,4 +13,5 @@ __all__ = [
'MetaFormatParser', 'MetaFormatParser',
'AutomaticMetadataParser', 'AutomaticMetadataParser',
'CivitaiApiMetadataParser', 'CivitaiApiMetadataParser',
'SuiImageParamsParser',
] ]

View File

@@ -0,0 +1,188 @@
"""Parser for SuiImage (Stable Diffusion WebUI) metadata format."""
import json
import logging
from typing import Dict, Any, Optional, List
from ..base import RecipeMetadataParser
from ...services.metadata_service import get_default_metadata_provider
logger = logging.getLogger(__name__)
class SuiImageParamsParser(RecipeMetadataParser):
    """Parser for SuiImage metadata JSON format.

    This format is used by some Stable Diffusion WebUI variants.

    Structure:
    {
        "sui_image_params": {
            "prompt": "...",
            "negativeprompt": "...",
            "model": "...",
            "seed": ...,
            "steps": ...,
            ...
        },
        "sui_models": [
            {"name": "...", "param": "model", "hash": "..."},
            ...
        ],
        "sui_extra_data": {...}
    }
    """

    # Key read from ``sui_image_params`` -> key written to ``gen_params``.
    # Both spellings of the CFG scale key are accepted; when both are present
    # the later mapping entry (``cfg_scale``) wins.
    _SUI_PARAM_MAPPING = {
        'steps': 'steps',
        'seed': 'seed',
        'cfgscale': 'cfg_scale',
        'cfg_scale': 'cfg_scale',
        'width': 'width',
        'height': 'height',
        'sampler': 'sampler',
        'scheduler': 'scheduler',
        'model': 'model',
        'vae': 'vae',
    }

    # Placeholder thumbnail used until provider data replaces it.
    _SUI_NO_PREVIEW = '/loras_static/images/no-preview.png'

    def is_metadata_matching(self, user_comment: str) -> bool:
        """Check if the user comment matches the SuiImage metadata format.

        Returns True only when ``user_comment`` is JSON that decodes to an
        object containing a ``sui_image_params`` key; any input that cannot be
        decoded yields False instead of raising.
        """
        try:
            data = json.loads(user_comment)
        except (ValueError, TypeError):
            # json.JSONDecodeError is a ValueError subclass; TypeError covers
            # non-string input such as None.
            return False
        return isinstance(data, dict) and 'sui_image_params' in data

    @staticmethod
    def _as_text(value: Any) -> str:
        """Return ``value`` if it is a string, otherwise an empty string."""
        return value if isinstance(value, str) else ''

    @staticmethod
    def _normalize_hash(raw_hash: Any) -> str:
        """Return the hash without a leading ``0x``; non-strings become ''."""
        if not isinstance(raw_hash, str):
            return ''
        if raw_hash[:2].lower() == '0x':
            return raw_hash[2:]
        return raw_hash

    @classmethod
    def _extract_gen_params(cls, params: Dict[str, Any]) -> Dict[str, Any]:
        """Build the ``gen_params`` dict from the ``sui_image_params`` object."""
        gen_params: Dict[str, Any] = {}

        prompt = params.get('prompt', '')
        negative_prompt = params.get('negativeprompt', '') or params.get('negative_prompt', '')
        if prompt:
            gen_params['prompt'] = prompt
        if negative_prompt:
            gen_params['negative_prompt'] = negative_prompt

        for src_key, dest_key in cls._SUI_PARAM_MAPPING.items():
            value = params.get(src_key)
            if value is not None:
                gen_params[dest_key] = value

        # Add size info if available
        if 'width' in gen_params and 'height' in gen_params:
            gen_params['size'] = f"{gen_params['width']}x{gen_params['height']}"

        return gen_params

    async def _build_lora_entry(
        self,
        model_name: str,
        model_hash: str,
        metadata_provider,
        recipe_scanner,
    ) -> Optional[Dict[str, Any]]:
        """Create a LoRA entry and enrich it through the provider if possible.

        Lookup failures are logged at debug level and the un-enriched entry is
        returned. The result of ``populate_lora_from_civitai`` is passed
        through unchanged (presumably it may be falsy; the caller checks).
        """
        # Remove .safetensors extension for cleaner name
        clean_name = model_name.replace('.safetensors', '') if model_name else ''
        lora_entry: Optional[Dict[str, Any]] = {
            'id': 0,
            'modelId': 0,
            'name': clean_name,
            'version': '',
            'type': 'lora',
            'weight': 1.0,
            'existsLocally': False,
            'localPath': None,
            'file_name': model_name,
            'hash': model_hash,
            'thumbnailUrl': self._SUI_NO_PREVIEW,
            'baseModel': '',
            'size': 0,
            'downloadUrl': '',
            'isDeleted': False
        }

        if metadata_provider and model_hash:
            try:
                civitai_info = await metadata_provider.get_model_by_hash(model_hash)
                if civitai_info:
                    lora_entry = await self.populate_lora_from_civitai(
                        lora_entry, civitai_info, recipe_scanner
                    )
            except Exception as e:
                # Best effort: enrichment must not abort the whole parse.
                logger.debug("Error fetching info for LoRA %s: %s", clean_name, e)

        return lora_entry

    async def _build_checkpoint_entry(
        self,
        model_name: str,
        model_hash: str,
        metadata_provider,
    ) -> Dict[str, Any]:
        """Create a checkpoint entry and enrich it through the provider if possible.

        Lookup failures are logged at debug level and the un-enriched entry is
        returned.
        """
        # Remove .safetensors extension for cleaner name
        clean_name = model_name.replace('.safetensors', '') if model_name else ''
        checkpoint_entry = {
            'id': 0,
            'modelId': 0,
            'name': clean_name,
            'version': '',
            'type': 'checkpoint',
            'hash': model_hash,
            'existsLocally': False,
            'localPath': None,
            'file_name': model_name,
            'thumbnailUrl': self._SUI_NO_PREVIEW,
            'baseModel': '',
            'size': 0,
            'downloadUrl': '',
            'isDeleted': False
        }

        if metadata_provider and model_hash:
            try:
                civitai_info = await metadata_provider.get_model_by_hash(model_hash)
                if civitai_info:
                    checkpoint_entry = await self.populate_checkpoint_from_civitai(
                        checkpoint_entry, civitai_info
                    )
            except Exception as e:
                # Best effort: enrichment must not abort the whole parse.
                logger.debug("Error fetching info for checkpoint %s: %s", clean_name, e)

        return checkpoint_entry

    async def parse_metadata(self, user_comment: str, recipe_scanner=None, civitai_client=None) -> Dict[str, Any]:
        """Parse metadata from SuiImage metadata format.

        Args:
            user_comment: JSON text containing ``sui_image_params`` and,
                optionally, ``sui_models``.
            recipe_scanner: Forwarded to ``populate_lora_from_civitai``.
            civitai_client: Not used by this parser; accepted so the signature
                matches the other parsers' call convention.

        Returns:
            A dict with ``base_model``, ``loras``, ``checkpoint``,
            ``gen_params`` and ``from_sui_image_params``. On any failure the
            error is logged and ``{"error": <message>, "loras": []}`` is
            returned instead of raising.
        """
        try:
            metadata_provider = await get_default_metadata_provider()

            data = json.loads(user_comment)
            # ``or`` also covers an explicit JSON null for either key.
            params = data.get('sui_image_params') or {}
            models = data.get('sui_models') or []

            gen_params = self._extract_gen_params(params)

            # Process models - extract checkpoint and loras
            loras: List[Dict[str, Any]] = []
            checkpoint_source = None  # (file name, normalized hash) of the chosen checkpoint
            checkpoint_is_explicit = False

            for model in models:
                if not isinstance(model, dict):
                    logger.debug("Skipping malformed sui_models entry: %r", model)
                    continue

                model_name = self._as_text(model.get('name'))
                param_type = self._as_text(model.get('param')).lower()
                model_hash = self._normalize_hash(model.get('hash'))

                # Check if this is a LoRA by looking at the name or param type
                if 'lora' in model_name.lower() or param_type.startswith('lora'):
                    lora_entry = await self._build_lora_entry(
                        model_name, model_hash, metadata_provider, recipe_scanner
                    )
                    if lora_entry:
                        loras.append(lora_entry)
                    continue

                # Any other entry is only *likely* a checkpoint. An entry
                # tagged param == "model" is authoritative and must not be
                # overwritten by later untagged/other entries (e.g. a VAE).
                is_explicit = param_type == 'model'
                if is_explicit or not checkpoint_is_explicit:
                    checkpoint_source = (model_name, model_hash)
                    checkpoint_is_explicit = is_explicit

            # Enrich only the entry that was finally selected, so discarded
            # candidates do not cost a provider lookup.
            checkpoint: Optional[Dict[str, Any]] = None
            if checkpoint_source is not None:
                checkpoint = await self._build_checkpoint_entry(
                    checkpoint_source[0], checkpoint_source[1], metadata_provider
                )

            # Determine base model from loras (most common), else checkpoint
            base_model = None
            base_models = [lora.get('baseModel') for lora in loras if lora.get('baseModel')]
            if base_models:
                from collections import Counter
                base_model = Counter(base_models).most_common(1)[0][0]
            elif checkpoint and checkpoint.get('baseModel'):
                # Also reached when LoRAs exist but none reports a base model.
                base_model = checkpoint['baseModel']

            return {
                'base_model': base_model,
                'loras': loras,
                'checkpoint': checkpoint,
                'gen_params': gen_params,
                'from_sui_image_params': True
            }

        except Exception as e:
            logger.error("Error parsing SuiImage metadata: %s", e, exc_info=True)
            return {"error": str(e), "loras": []}

View File

@@ -122,11 +122,25 @@ async def get_metadata_provider(provider_name: str = None):
provider_manager = await ModelMetadataProviderManager.get_instance() provider_manager = await ModelMetadataProviderManager.get_instance()
try:
provider = ( provider = (
provider_manager._get_provider(provider_name) provider_manager._get_provider(provider_name)
if provider_name if provider_name
else provider_manager._get_provider() else provider_manager._get_provider()
) )
except ValueError as e:
# Provider not initialized, attempt to initialize
if "No default provider set" in str(e) or "not registered" in str(e):
logger.warning(f"Metadata provider not initialized ({e}), initializing now...")
await initialize_metadata_providers()
provider_manager = await ModelMetadataProviderManager.get_instance()
provider = (
provider_manager._get_provider(provider_name)
if provider_name
else provider_manager._get_provider()
)
else:
raise
return _wrap_provider_with_rate_limit(provider_name, provider) return _wrap_provider_with_rate_limit(provider_name, provider)

View File

@@ -492,7 +492,7 @@ async def test_analyze_remote_video(tmp_path):
class DummyFactory: class DummyFactory:
def create_parser(self, metadata): def create_parser(self, metadata):
async def parse_metadata(m, recipe_scanner): async def parse_metadata(m, recipe_scanner=None, civitai_client=None):
return {"loras": []} return {"loras": []}
return SimpleNamespace(parse_metadata=parse_metadata) return SimpleNamespace(parse_metadata=parse_metadata)

View File

@@ -0,0 +1,202 @@
"""Tests for SuiImageParamsParser."""
import pytest
import json
from py.recipes.parsers import SuiImageParamsParser
class TestSuiImageParamsParser:
    """Test cases for SuiImageParamsParser."""

    @pytest.fixture(autouse=True)
    def _stub_metadata_provider(self, monkeypatch):
        """Replace the provider lookup so the tests stay hermetic.

        ``parse_metadata`` awaits ``get_default_metadata_provider()`` and, for
        models that carry a hash, queries it. Returning ``None`` makes the
        parser skip those lookups, so the results below depend only on the
        JSON given to it.
        """
        import sys

        async def _no_provider():
            return None

        # Patch the name in the module that defines the parser, because the
        # parser imported the function by name.
        parser_module = sys.modules[SuiImageParamsParser.__module__]
        monkeypatch.setattr(parser_module, "get_default_metadata_provider", _no_provider)

    def setup_method(self):
        """Set up test fixtures."""
        self.parser = SuiImageParamsParser()

    def test_is_metadata_matching_positive(self):
        """Test that parser correctly identifies SuiImage metadata format."""
        metadata = {
            "sui_image_params": {
                "prompt": "test prompt",
                "model": "test_model"
            }
        }
        metadata_str = json.dumps(metadata)
        assert self.parser.is_metadata_matching(metadata_str) is True

    def test_is_metadata_matching_negative(self):
        """Test that parser rejects non-SuiImage metadata."""
        # Missing sui_image_params key
        metadata = {
            "other_params": {
                "prompt": "test prompt"
            }
        }
        metadata_str = json.dumps(metadata)
        assert self.parser.is_metadata_matching(metadata_str) is False

    def test_is_metadata_matching_invalid_json(self):
        """Test that parser handles invalid JSON gracefully."""
        metadata_str = "not valid json"
        assert self.parser.is_metadata_matching(metadata_str) is False

    @pytest.mark.asyncio
    async def test_parse_metadata_extracts_basic_fields(self):
        """Test parsing basic fields from SuiImage metadata."""
        metadata = {
            "sui_image_params": {
                "prompt": "beautiful landscape",
                "negativeprompt": "ugly, blurry",
                "steps": 30,
                "seed": 12345,
                "cfgscale": 7.5,
                "width": 512,
                "height": 768,
                "sampler": "Euler a",
                "scheduler": "normal"
            },
            "sui_models": [],
            "sui_extra_data": {}
        }
        metadata_str = json.dumps(metadata)
        result = await self.parser.parse_metadata(metadata_str)

        gen_params = result.get('gen_params', {})
        assert gen_params.get('prompt') == "beautiful landscape"
        assert gen_params.get('negative_prompt') == "ugly, blurry"
        assert gen_params.get('steps') == 30
        assert gen_params.get('seed') == 12345
        assert gen_params.get('cfg_scale') == 7.5
        assert gen_params.get('width') == 512
        assert gen_params.get('height') == 768
        assert gen_params.get('size') == "512x768"
        assert result.get('loras') == []

    @pytest.mark.asyncio
    async def test_parse_metadata_extracts_checkpoint(self):
        """Test parsing checkpoint from sui_models."""
        metadata = {
            "sui_image_params": {
                "prompt": "test prompt",
                "model": "checkpoint_model"
            },
            "sui_models": [
                {
                    "name": "test_checkpoint.safetensors",
                    "param": "model",
                    "hash": "0x1234567890abcdef"
                }
            ],
            "sui_extra_data": {}
        }
        metadata_str = json.dumps(metadata)
        result = await self.parser.parse_metadata(metadata_str)

        checkpoint = result.get('checkpoint')
        assert checkpoint is not None
        assert checkpoint['type'] == 'checkpoint'
        assert checkpoint['name'] == 'test_checkpoint'
        assert checkpoint['hash'] == '1234567890abcdef'

    @pytest.mark.asyncio
    async def test_parse_metadata_extracts_lora(self):
        """Test parsing LoRA from sui_models."""
        metadata = {
            "sui_image_params": {
                "prompt": "test prompt"
            },
            "sui_models": [
                {
                    "name": "test_lora.safetensors",
                    "param": "lora",
                    "hash": "0xabcdef1234567890"
                }
            ],
            "sui_extra_data": {}
        }
        metadata_str = json.dumps(metadata)
        result = await self.parser.parse_metadata(metadata_str)

        loras = result.get('loras')
        assert len(loras) == 1
        assert loras[0]['type'] == 'lora'
        assert loras[0]['name'] == 'test_lora'
        assert loras[0]['file_name'] == 'test_lora.safetensors'
        assert loras[0]['hash'] == 'abcdef1234567890'

    @pytest.mark.asyncio
    async def test_parse_metadata_handles_lora_in_name(self):
        """Test that LoRA is detected by 'lora' in name."""
        metadata = {
            "sui_image_params": {
                "prompt": "test prompt"
            },
            "sui_models": [
                {
                    "name": "style_lora_v2.safetensors",
                    "param": "some_other_param",
                    "hash": "0x1111111111111111"
                }
            ],
            "sui_extra_data": {}
        }
        metadata_str = json.dumps(metadata)
        result = await self.parser.parse_metadata(metadata_str)

        loras = result.get('loras')
        assert len(loras) == 1
        assert loras[0]['type'] == 'lora'

    @pytest.mark.asyncio
    async def test_parse_metadata_empty_models(self):
        """Test parsing with empty sui_models array."""
        metadata = {
            "sui_image_params": {
                "prompt": "test prompt",
                "steps": 20
            },
            "sui_models": [],
            "sui_extra_data": {
                "date": "2024-01-01"
            }
        }
        metadata_str = json.dumps(metadata)
        result = await self.parser.parse_metadata(metadata_str)

        assert result.get('loras') == []
        assert result.get('checkpoint') is None
        assert result.get('gen_params', {}).get('prompt') == "test prompt"
        assert result.get('gen_params', {}).get('steps') == 20

    @pytest.mark.asyncio
    async def test_parse_metadata_alternative_field_names(self):
        """Test parsing with alternative field names."""
        metadata = {
            "sui_image_params": {
                "prompt": "test prompt",
                "negative_prompt": "bad quality",  # Using underscore variant
                "cfg_scale": 6.0  # Using underscore variant
            },
            "sui_models": [],
            "sui_extra_data": {}
        }
        metadata_str = json.dumps(metadata)
        result = await self.parser.parse_metadata(metadata_str)

        assert result.get('gen_params', {}).get('negative_prompt') == "bad quality"
        assert result.get('gen_params', {}).get('cfg_scale') == 6.0

    @pytest.mark.asyncio
    async def test_parse_metadata_error_handling(self):
        """Test that parser handles malformed data gracefully."""
        # Missing required fields
        metadata = {
            "sui_image_params": {},
            "sui_models": [],
            "sui_extra_data": {}
        }
        metadata_str = json.dumps(metadata)
        result = await self.parser.parse_metadata(metadata_str)

        assert 'error' not in result
        assert result.get('loras') == []
        # Empty params result in empty gen_params dict
        assert result.get('gen_params') == {}