mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-25 07:05:43 -03:00
feat(recipe-parser): add SuiImage metadata format support
- Add SuiImageParamsParser for sui_image_params JSON format - Register new parser in RecipeParserFactory - Fix metadata_provider auto-initialization when not ready - Add 10 test cases for SuiImageParamsParser Fixes batch import failure for images with sui_image_params metadata.
This commit is contained in:
@@ -7,6 +7,7 @@ from .parsers import (
|
|||||||
MetaFormatParser,
|
MetaFormatParser,
|
||||||
AutomaticMetadataParser,
|
AutomaticMetadataParser,
|
||||||
CivitaiApiMetadataParser,
|
CivitaiApiMetadataParser,
|
||||||
|
SuiImageParamsParser,
|
||||||
)
|
)
|
||||||
from .base import RecipeMetadataParser
|
from .base import RecipeMetadataParser
|
||||||
|
|
||||||
@@ -55,6 +56,13 @@ class RecipeParserFactory:
|
|||||||
# If JSON parsing fails, move on to other parsers
|
# If JSON parsing fails, move on to other parsers
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
# Try SuiImageParamsParser for SuiImage metadata format
|
||||||
|
try:
|
||||||
|
if SuiImageParamsParser().is_metadata_matching(metadata_str):
|
||||||
|
return SuiImageParamsParser()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
# Check other parsers that expect string input
|
# Check other parsers that expect string input
|
||||||
if RecipeFormatParser().is_metadata_matching(metadata_str):
|
if RecipeFormatParser().is_metadata_matching(metadata_str):
|
||||||
return RecipeFormatParser()
|
return RecipeFormatParser()
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ from .comfy import ComfyMetadataParser
|
|||||||
from .meta_format import MetaFormatParser
|
from .meta_format import MetaFormatParser
|
||||||
from .automatic import AutomaticMetadataParser
|
from .automatic import AutomaticMetadataParser
|
||||||
from .civitai_image import CivitaiApiMetadataParser
|
from .civitai_image import CivitaiApiMetadataParser
|
||||||
|
from .sui_image_params import SuiImageParamsParser
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
'RecipeFormatParser',
|
'RecipeFormatParser',
|
||||||
@@ -12,4 +13,5 @@ __all__ = [
|
|||||||
'MetaFormatParser',
|
'MetaFormatParser',
|
||||||
'AutomaticMetadataParser',
|
'AutomaticMetadataParser',
|
||||||
'CivitaiApiMetadataParser',
|
'CivitaiApiMetadataParser',
|
||||||
|
'SuiImageParamsParser',
|
||||||
]
|
]
|
||||||
|
|||||||
188
py/recipes/parsers/sui_image_params.py
Normal file
188
py/recipes/parsers/sui_image_params.py
Normal file
@@ -0,0 +1,188 @@
|
|||||||
|
"""Parser for SuiImage (Stable Diffusion WebUI) metadata format."""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from typing import Dict, Any, Optional, List
|
||||||
|
from ..base import RecipeMetadataParser
|
||||||
|
from ...services.metadata_service import get_default_metadata_provider
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class SuiImageParamsParser(RecipeMetadataParser):
|
||||||
|
"""Parser for SuiImage metadata JSON format.
|
||||||
|
|
||||||
|
This format is used by some Stable Diffusion WebUI variants.
|
||||||
|
Structure:
|
||||||
|
{
|
||||||
|
"sui_image_params": {
|
||||||
|
"prompt": "...",
|
||||||
|
"negativeprompt": "...",
|
||||||
|
"model": "...",
|
||||||
|
"seed": ...,
|
||||||
|
"steps": ...,
|
||||||
|
...
|
||||||
|
},
|
||||||
|
"sui_models": [
|
||||||
|
{"name": "...", "param": "model", "hash": "..."},
|
||||||
|
...
|
||||||
|
],
|
||||||
|
"sui_extra_data": {...}
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
|
||||||
|
def is_metadata_matching(self, user_comment: str) -> bool:
|
||||||
|
"""Check if the user comment matches the SuiImage metadata format"""
|
||||||
|
try:
|
||||||
|
data = json.loads(user_comment)
|
||||||
|
return isinstance(data, dict) and 'sui_image_params' in data
|
||||||
|
except (json.JSONDecodeError, TypeError):
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def parse_metadata(self, user_comment: str, recipe_scanner=None, civitai_client=None) -> Dict[str, Any]:
|
||||||
|
"""Parse metadata from SuiImage metadata format"""
|
||||||
|
try:
|
||||||
|
metadata_provider = await get_default_metadata_provider()
|
||||||
|
|
||||||
|
data = json.loads(user_comment)
|
||||||
|
params = data.get('sui_image_params', {})
|
||||||
|
models = data.get('sui_models', [])
|
||||||
|
|
||||||
|
# Extract prompt and negative prompt
|
||||||
|
prompt = params.get('prompt', '')
|
||||||
|
negative_prompt = params.get('negativeprompt', '') or params.get('negative_prompt', '')
|
||||||
|
|
||||||
|
# Extract generation parameters
|
||||||
|
gen_params = {}
|
||||||
|
if prompt:
|
||||||
|
gen_params['prompt'] = prompt
|
||||||
|
if negative_prompt:
|
||||||
|
gen_params['negative_prompt'] = negative_prompt
|
||||||
|
|
||||||
|
# Map standard parameters
|
||||||
|
param_mapping = {
|
||||||
|
'steps': 'steps',
|
||||||
|
'seed': 'seed',
|
||||||
|
'cfgscale': 'cfg_scale',
|
||||||
|
'cfg_scale': 'cfg_scale',
|
||||||
|
'width': 'width',
|
||||||
|
'height': 'height',
|
||||||
|
'sampler': 'sampler',
|
||||||
|
'scheduler': 'scheduler',
|
||||||
|
'model': 'model',
|
||||||
|
'vae': 'vae',
|
||||||
|
}
|
||||||
|
|
||||||
|
for src_key, dest_key in param_mapping.items():
|
||||||
|
if src_key in params and params[src_key] is not None:
|
||||||
|
gen_params[dest_key] = params[src_key]
|
||||||
|
|
||||||
|
# Add size info if available
|
||||||
|
if 'width' in gen_params and 'height' in gen_params:
|
||||||
|
gen_params['size'] = f"{gen_params['width']}x{gen_params['height']}"
|
||||||
|
|
||||||
|
# Process models - extract checkpoint and loras
|
||||||
|
loras: List[Dict[str, Any]] = []
|
||||||
|
checkpoint: Optional[Dict[str, Any]] = None
|
||||||
|
|
||||||
|
for model in models:
|
||||||
|
model_name = model.get('name', '')
|
||||||
|
param_type = model.get('param', '')
|
||||||
|
model_hash = model.get('hash', '')
|
||||||
|
|
||||||
|
# Remove .safetensors extension for cleaner name
|
||||||
|
clean_name = model_name.replace('.safetensors', '') if model_name else ''
|
||||||
|
|
||||||
|
# Check if this is a LoRA by looking at the name or param type
|
||||||
|
is_lora = 'lora' in model_name.lower() or param_type.lower().startswith('lora')
|
||||||
|
|
||||||
|
if is_lora:
|
||||||
|
lora_entry = {
|
||||||
|
'id': 0,
|
||||||
|
'modelId': 0,
|
||||||
|
'name': clean_name,
|
||||||
|
'version': '',
|
||||||
|
'type': 'lora',
|
||||||
|
'weight': 1.0,
|
||||||
|
'existsLocally': False,
|
||||||
|
'localPath': None,
|
||||||
|
'file_name': model_name,
|
||||||
|
'hash': model_hash.replace('0x', '') if model_hash.startswith('0x') else model_hash,
|
||||||
|
'thumbnailUrl': '/loras_static/images/no-preview.png',
|
||||||
|
'baseModel': '',
|
||||||
|
'size': 0,
|
||||||
|
'downloadUrl': '',
|
||||||
|
'isDeleted': False
|
||||||
|
}
|
||||||
|
|
||||||
|
# Try to get additional info from metadata provider
|
||||||
|
if metadata_provider and model_hash:
|
||||||
|
try:
|
||||||
|
civitai_info = await metadata_provider.get_model_by_hash(
|
||||||
|
model_hash.replace('0x', '') if model_hash.startswith('0x') else model_hash
|
||||||
|
)
|
||||||
|
if civitai_info:
|
||||||
|
lora_entry = await self.populate_lora_from_civitai(
|
||||||
|
lora_entry, civitai_info, recipe_scanner
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f"Error fetching info for LoRA {clean_name}: {e}")
|
||||||
|
|
||||||
|
if lora_entry:
|
||||||
|
loras.append(lora_entry)
|
||||||
|
elif param_type == 'model' or 'lora' not in model_name.lower():
|
||||||
|
# This is likely a checkpoint
|
||||||
|
checkpoint_entry = {
|
||||||
|
'id': 0,
|
||||||
|
'modelId': 0,
|
||||||
|
'name': clean_name,
|
||||||
|
'version': '',
|
||||||
|
'type': 'checkpoint',
|
||||||
|
'hash': model_hash.replace('0x', '') if model_hash.startswith('0x') else model_hash,
|
||||||
|
'existsLocally': False,
|
||||||
|
'localPath': None,
|
||||||
|
'file_name': model_name,
|
||||||
|
'thumbnailUrl': '/loras_static/images/no-preview.png',
|
||||||
|
'baseModel': '',
|
||||||
|
'size': 0,
|
||||||
|
'downloadUrl': '',
|
||||||
|
'isDeleted': False
|
||||||
|
}
|
||||||
|
|
||||||
|
# Try to get additional info from metadata provider
|
||||||
|
if metadata_provider and model_hash:
|
||||||
|
try:
|
||||||
|
civitai_info = await metadata_provider.get_model_by_hash(
|
||||||
|
model_hash.replace('0x', '') if model_hash.startswith('0x') else model_hash
|
||||||
|
)
|
||||||
|
if civitai_info:
|
||||||
|
checkpoint_entry = await self.populate_checkpoint_from_civitai(
|
||||||
|
checkpoint_entry, civitai_info
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f"Error fetching info for checkpoint {clean_name}: {e}")
|
||||||
|
|
||||||
|
checkpoint = checkpoint_entry
|
||||||
|
|
||||||
|
# Determine base model from loras or checkpoint
|
||||||
|
base_model = None
|
||||||
|
if loras:
|
||||||
|
base_models = [lora.get('baseModel') for lora in loras if lora.get('baseModel')]
|
||||||
|
if base_models:
|
||||||
|
from collections import Counter
|
||||||
|
base_model_counts = Counter(base_models)
|
||||||
|
base_model = base_model_counts.most_common(1)[0][0]
|
||||||
|
elif checkpoint and checkpoint.get('baseModel'):
|
||||||
|
base_model = checkpoint['baseModel']
|
||||||
|
|
||||||
|
return {
|
||||||
|
'base_model': base_model,
|
||||||
|
'loras': loras,
|
||||||
|
'checkpoint': checkpoint,
|
||||||
|
'gen_params': gen_params,
|
||||||
|
'from_sui_image_params': True
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error parsing SuiImage metadata: {e}", exc_info=True)
|
||||||
|
return {"error": str(e), "loras": []}
|
||||||
@@ -122,11 +122,25 @@ async def get_metadata_provider(provider_name: str = None):
|
|||||||
|
|
||||||
provider_manager = await ModelMetadataProviderManager.get_instance()
|
provider_manager = await ModelMetadataProviderManager.get_instance()
|
||||||
|
|
||||||
|
try:
|
||||||
provider = (
|
provider = (
|
||||||
provider_manager._get_provider(provider_name)
|
provider_manager._get_provider(provider_name)
|
||||||
if provider_name
|
if provider_name
|
||||||
else provider_manager._get_provider()
|
else provider_manager._get_provider()
|
||||||
)
|
)
|
||||||
|
except ValueError as e:
|
||||||
|
# Provider not initialized, attempt to initialize
|
||||||
|
if "No default provider set" in str(e) or "not registered" in str(e):
|
||||||
|
logger.warning(f"Metadata provider not initialized ({e}), initializing now...")
|
||||||
|
await initialize_metadata_providers()
|
||||||
|
provider_manager = await ModelMetadataProviderManager.get_instance()
|
||||||
|
provider = (
|
||||||
|
provider_manager._get_provider(provider_name)
|
||||||
|
if provider_name
|
||||||
|
else provider_manager._get_provider()
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise
|
||||||
|
|
||||||
return _wrap_provider_with_rate_limit(provider_name, provider)
|
return _wrap_provider_with_rate_limit(provider_name, provider)
|
||||||
|
|
||||||
|
|||||||
@@ -492,7 +492,7 @@ async def test_analyze_remote_video(tmp_path):
|
|||||||
|
|
||||||
class DummyFactory:
|
class DummyFactory:
|
||||||
def create_parser(self, metadata):
|
def create_parser(self, metadata):
|
||||||
async def parse_metadata(m, recipe_scanner):
|
async def parse_metadata(m, recipe_scanner=None, civitai_client=None):
|
||||||
return {"loras": []}
|
return {"loras": []}
|
||||||
return SimpleNamespace(parse_metadata=parse_metadata)
|
return SimpleNamespace(parse_metadata=parse_metadata)
|
||||||
|
|
||||||
|
|||||||
202
tests/services/test_sui_image_params_parser.py
Normal file
202
tests/services/test_sui_image_params_parser.py
Normal file
@@ -0,0 +1,202 @@
|
|||||||
|
"""Tests for SuiImageParamsParser."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import json
|
||||||
|
from py.recipes.parsers import SuiImageParamsParser
|
||||||
|
|
||||||
|
|
||||||
|
class TestSuiImageParamsParser:
|
||||||
|
"""Test cases for SuiImageParamsParser."""
|
||||||
|
|
||||||
|
def setup_method(self):
|
||||||
|
"""Set up test fixtures."""
|
||||||
|
self.parser = SuiImageParamsParser()
|
||||||
|
|
||||||
|
def test_is_metadata_matching_positive(self):
|
||||||
|
"""Test that parser correctly identifies SuiImage metadata format."""
|
||||||
|
metadata = {
|
||||||
|
"sui_image_params": {
|
||||||
|
"prompt": "test prompt",
|
||||||
|
"model": "test_model"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
metadata_str = json.dumps(metadata)
|
||||||
|
assert self.parser.is_metadata_matching(metadata_str) is True
|
||||||
|
|
||||||
|
def test_is_metadata_matching_negative(self):
|
||||||
|
"""Test that parser rejects non-SuiImage metadata."""
|
||||||
|
# Missing sui_image_params key
|
||||||
|
metadata = {
|
||||||
|
"other_params": {
|
||||||
|
"prompt": "test prompt"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
metadata_str = json.dumps(metadata)
|
||||||
|
assert self.parser.is_metadata_matching(metadata_str) is False
|
||||||
|
|
||||||
|
def test_is_metadata_matching_invalid_json(self):
|
||||||
|
"""Test that parser handles invalid JSON gracefully."""
|
||||||
|
metadata_str = "not valid json"
|
||||||
|
assert self.parser.is_metadata_matching(metadata_str) is False
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_parse_metadata_extracts_basic_fields(self):
|
||||||
|
"""Test parsing basic fields from SuiImage metadata."""
|
||||||
|
metadata = {
|
||||||
|
"sui_image_params": {
|
||||||
|
"prompt": "beautiful landscape",
|
||||||
|
"negativeprompt": "ugly, blurry",
|
||||||
|
"steps": 30,
|
||||||
|
"seed": 12345,
|
||||||
|
"cfgscale": 7.5,
|
||||||
|
"width": 512,
|
||||||
|
"height": 768,
|
||||||
|
"sampler": "Euler a",
|
||||||
|
"scheduler": "normal"
|
||||||
|
},
|
||||||
|
"sui_models": [],
|
||||||
|
"sui_extra_data": {}
|
||||||
|
}
|
||||||
|
metadata_str = json.dumps(metadata)
|
||||||
|
result = await self.parser.parse_metadata(metadata_str)
|
||||||
|
|
||||||
|
assert result.get('gen_params', {}).get('prompt') == "beautiful landscape"
|
||||||
|
assert result.get('gen_params', {}).get('negative_prompt') == "ugly, blurry"
|
||||||
|
assert result.get('gen_params', {}).get('steps') == 30
|
||||||
|
assert result.get('gen_params', {}).get('seed') == 12345
|
||||||
|
assert result.get('gen_params', {}).get('cfg_scale') == 7.5
|
||||||
|
assert result.get('gen_params', {}).get('width') == 512
|
||||||
|
assert result.get('gen_params', {}).get('height') == 768
|
||||||
|
assert result.get('gen_params', {}).get('size') == "512x768"
|
||||||
|
assert result.get('loras') == []
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_parse_metadata_extracts_checkpoint(self):
|
||||||
|
"""Test parsing checkpoint from sui_models."""
|
||||||
|
metadata = {
|
||||||
|
"sui_image_params": {
|
||||||
|
"prompt": "test prompt",
|
||||||
|
"model": "checkpoint_model"
|
||||||
|
},
|
||||||
|
"sui_models": [
|
||||||
|
{
|
||||||
|
"name": "test_checkpoint.safetensors",
|
||||||
|
"param": "model",
|
||||||
|
"hash": "0x1234567890abcdef"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"sui_extra_data": {}
|
||||||
|
}
|
||||||
|
metadata_str = json.dumps(metadata)
|
||||||
|
result = await self.parser.parse_metadata(metadata_str)
|
||||||
|
|
||||||
|
checkpoint = result.get('checkpoint')
|
||||||
|
assert checkpoint is not None
|
||||||
|
assert checkpoint['type'] == 'checkpoint'
|
||||||
|
assert checkpoint['name'] == 'test_checkpoint'
|
||||||
|
assert checkpoint['hash'] == '1234567890abcdef'
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_parse_metadata_extracts_lora(self):
|
||||||
|
"""Test parsing LoRA from sui_models."""
|
||||||
|
metadata = {
|
||||||
|
"sui_image_params": {
|
||||||
|
"prompt": "test prompt"
|
||||||
|
},
|
||||||
|
"sui_models": [
|
||||||
|
{
|
||||||
|
"name": "test_lora.safetensors",
|
||||||
|
"param": "lora",
|
||||||
|
"hash": "0xabcdef1234567890"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"sui_extra_data": {}
|
||||||
|
}
|
||||||
|
metadata_str = json.dumps(metadata)
|
||||||
|
result = await self.parser.parse_metadata(metadata_str)
|
||||||
|
|
||||||
|
loras = result.get('loras')
|
||||||
|
assert len(loras) == 1
|
||||||
|
assert loras[0]['type'] == 'lora'
|
||||||
|
assert loras[0]['name'] == 'test_lora'
|
||||||
|
assert loras[0]['file_name'] == 'test_lora.safetensors'
|
||||||
|
assert loras[0]['hash'] == 'abcdef1234567890'
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_parse_metadata_handles_lora_in_name(self):
|
||||||
|
"""Test that LoRA is detected by 'lora' in name."""
|
||||||
|
metadata = {
|
||||||
|
"sui_image_params": {
|
||||||
|
"prompt": "test prompt"
|
||||||
|
},
|
||||||
|
"sui_models": [
|
||||||
|
{
|
||||||
|
"name": "style_lora_v2.safetensors",
|
||||||
|
"param": "some_other_param",
|
||||||
|
"hash": "0x1111111111111111"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"sui_extra_data": {}
|
||||||
|
}
|
||||||
|
metadata_str = json.dumps(metadata)
|
||||||
|
result = await self.parser.parse_metadata(metadata_str)
|
||||||
|
|
||||||
|
loras = result.get('loras')
|
||||||
|
assert len(loras) == 1
|
||||||
|
assert loras[0]['type'] == 'lora'
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_parse_metadata_empty_models(self):
|
||||||
|
"""Test parsing with empty sui_models array."""
|
||||||
|
metadata = {
|
||||||
|
"sui_image_params": {
|
||||||
|
"prompt": "test prompt",
|
||||||
|
"steps": 20
|
||||||
|
},
|
||||||
|
"sui_models": [],
|
||||||
|
"sui_extra_data": {
|
||||||
|
"date": "2024-01-01"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
metadata_str = json.dumps(metadata)
|
||||||
|
result = await self.parser.parse_metadata(metadata_str)
|
||||||
|
|
||||||
|
assert result.get('loras') == []
|
||||||
|
assert result.get('checkpoint') is None
|
||||||
|
assert result.get('gen_params', {}).get('prompt') == "test prompt"
|
||||||
|
assert result.get('gen_params', {}).get('steps') == 20
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_parse_metadata_alternative_field_names(self):
|
||||||
|
"""Test parsing with alternative field names."""
|
||||||
|
metadata = {
|
||||||
|
"sui_image_params": {
|
||||||
|
"prompt": "test prompt",
|
||||||
|
"negative_prompt": "bad quality", # Using underscore variant
|
||||||
|
"cfg_scale": 6.0 # Using underscore variant
|
||||||
|
},
|
||||||
|
"sui_models": [],
|
||||||
|
"sui_extra_data": {}
|
||||||
|
}
|
||||||
|
metadata_str = json.dumps(metadata)
|
||||||
|
result = await self.parser.parse_metadata(metadata_str)
|
||||||
|
|
||||||
|
assert result.get('gen_params', {}).get('negative_prompt') == "bad quality"
|
||||||
|
assert result.get('gen_params', {}).get('cfg_scale') == 6.0
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_parse_metadata_error_handling(self):
|
||||||
|
"""Test that parser handles malformed data gracefully."""
|
||||||
|
# Missing required fields
|
||||||
|
metadata = {
|
||||||
|
"sui_image_params": {},
|
||||||
|
"sui_models": [],
|
||||||
|
"sui_extra_data": {}
|
||||||
|
}
|
||||||
|
metadata_str = json.dumps(metadata)
|
||||||
|
result = await self.parser.parse_metadata(metadata_str)
|
||||||
|
|
||||||
|
assert 'error' not in result
|
||||||
|
assert result.get('loras') == []
|
||||||
|
# Empty params result in empty gen_params dict
|
||||||
|
assert result.get('gen_params') == {}
|
||||||
Reference in New Issue
Block a user