mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-24 22:52:12 -03:00
feat: Introduce generation parameter merging from request, Civitai, and embedded image metadata, and enhance ComfyUI metadata parsing.
This commit is contained in:
50
py/recipes/merger.py
Normal file
50
py/recipes/merger.py
Normal file
@@ -0,0 +1,50 @@
|
|||||||
|
from typing import Any, Dict, Optional
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
class GenParamsMerger:
|
||||||
|
"""Utility to merge generation parameters from multiple sources with priority."""
|
||||||
|
|
||||||
|
BLACKLISTED_KEYS = {"id", "url", "userId", "username", "createdAt", "updatedAt", "hash"}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def merge(
|
||||||
|
request_params: Optional[Dict[str, Any]] = None,
|
||||||
|
civitai_meta: Optional[Dict[str, Any]] = None,
|
||||||
|
embedded_metadata: Optional[Dict[str, Any]] = None
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Merge generation parameters from three sources.
|
||||||
|
|
||||||
|
Priority: request_params > civitai_meta > embedded_metadata
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request_params: Params provided directly in the import request
|
||||||
|
civitai_meta: Params from Civitai Image API 'meta' field
|
||||||
|
embedded_metadata: Params extracted from image EXIF/embedded metadata
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Merged parameters dictionary
|
||||||
|
"""
|
||||||
|
result = {}
|
||||||
|
|
||||||
|
# 1. Start with embedded metadata (lowest priority)
|
||||||
|
if embedded_metadata:
|
||||||
|
# If it's a full recipe metadata, we use its gen_params
|
||||||
|
if "gen_params" in embedded_metadata and isinstance(embedded_metadata["gen_params"], dict):
|
||||||
|
result.update(embedded_metadata["gen_params"])
|
||||||
|
else:
|
||||||
|
# Otherwise assume the dict itself contains gen_params
|
||||||
|
result.update(embedded_metadata)
|
||||||
|
|
||||||
|
# 2. Layer Civitai meta (medium priority)
|
||||||
|
if civitai_meta:
|
||||||
|
result.update(civitai_meta)
|
||||||
|
|
||||||
|
# 3. Layer request params (highest priority)
|
||||||
|
if request_params:
|
||||||
|
result.update(request_params)
|
||||||
|
|
||||||
|
# Filter out blacklisted keys
|
||||||
|
return {k: v for k, v in result.items() if k not in GenParamsMerger.BLACKLISTED_KEYS}
|
||||||
@@ -36,9 +36,6 @@ class ComfyMetadataParser(RecipeMetadataParser):
|
|||||||
# Find all LoraLoader nodes
|
# Find all LoraLoader nodes
|
||||||
lora_nodes = {k: v for k, v in data.items() if isinstance(v, dict) and v.get('class_type') == 'LoraLoader'}
|
lora_nodes = {k: v for k, v in data.items() if isinstance(v, dict) and v.get('class_type') == 'LoraLoader'}
|
||||||
|
|
||||||
if not lora_nodes:
|
|
||||||
return {"error": "No LoRA information found in this ComfyUI workflow", "loras": []}
|
|
||||||
|
|
||||||
# Process each LoraLoader node
|
# Process each LoraLoader node
|
||||||
for node_id, node in lora_nodes.items():
|
for node_id, node in lora_nodes.items():
|
||||||
if 'inputs' not in node or 'lora_name' not in node['inputs']:
|
if 'inputs' not in node or 'lora_name' not in node['inputs']:
|
||||||
|
|||||||
@@ -24,6 +24,8 @@ from ...services.recipes import (
|
|||||||
)
|
)
|
||||||
from ...services.metadata_service import get_default_metadata_provider
|
from ...services.metadata_service import get_default_metadata_provider
|
||||||
from ...utils.civitai_utils import rewrite_preview_url
|
from ...utils.civitai_utils import rewrite_preview_url
|
||||||
|
from ...utils.exif_utils import ExifUtils
|
||||||
|
from ...recipes.merger import GenParamsMerger
|
||||||
|
|
||||||
Logger = logging.Logger
|
Logger = logging.Logger
|
||||||
EnsureDependenciesCallable = Callable[[], Awaitable[None]]
|
EnsureDependenciesCallable = Callable[[], Awaitable[None]]
|
||||||
@@ -552,7 +554,41 @@ class RecipeManagementHandler:
|
|||||||
metadata["base_model"] = base_model_from_metadata
|
metadata["base_model"] = base_model_from_metadata
|
||||||
|
|
||||||
tags = self._parse_tags(params.get("tags"))
|
tags = self._parse_tags(params.get("tags"))
|
||||||
image_bytes, extension = await self._download_remote_media(image_url)
|
image_bytes, extension, civitai_meta = await self._download_remote_media(image_url)
|
||||||
|
|
||||||
|
# Extract embedded metadata from the downloaded image
|
||||||
|
embedded_metadata = None
|
||||||
|
try:
|
||||||
|
with tempfile.NamedTemporaryFile(suffix=extension, delete=False) as temp_img:
|
||||||
|
temp_img.write(image_bytes)
|
||||||
|
temp_img_path = temp_img.name
|
||||||
|
|
||||||
|
try:
|
||||||
|
raw_embedded = ExifUtils.extract_image_metadata(temp_img_path)
|
||||||
|
if raw_embedded:
|
||||||
|
# Try to parse it using standard parsers if it looks like a recipe
|
||||||
|
parser = self._analysis_service._recipe_parser_factory.create_parser(raw_embedded)
|
||||||
|
if parser:
|
||||||
|
parsed_embedded = await parser.parse_metadata(raw_embedded, recipe_scanner=recipe_scanner)
|
||||||
|
embedded_metadata = parsed_embedded
|
||||||
|
else:
|
||||||
|
# Fallback to raw string if no parser matches (might be simple params)
|
||||||
|
embedded_metadata = {"gen_params": {"raw_metadata": raw_embedded}}
|
||||||
|
finally:
|
||||||
|
if os.path.exists(temp_img_path):
|
||||||
|
os.unlink(temp_img_path)
|
||||||
|
except Exception as exc:
|
||||||
|
self._logger.warning("Failed to extract embedded metadata during import: %s", exc)
|
||||||
|
|
||||||
|
# Merge gen_params from all sources
|
||||||
|
merged_gen_params = GenParamsMerger.merge(
|
||||||
|
request_params=gen_params,
|
||||||
|
civitai_meta=civitai_meta,
|
||||||
|
embedded_metadata=embedded_metadata
|
||||||
|
)
|
||||||
|
|
||||||
|
if merged_gen_params:
|
||||||
|
metadata["gen_params"] = merged_gen_params
|
||||||
|
|
||||||
result = await self._persistence_service.save_recipe(
|
result = await self._persistence_service.save_recipe(
|
||||||
recipe_scanner=recipe_scanner,
|
recipe_scanner=recipe_scanner,
|
||||||
@@ -900,7 +936,7 @@ class RecipeManagementHandler:
|
|||||||
extension = ".webp" # Default to webp if unknown
|
extension = ".webp" # Default to webp if unknown
|
||||||
|
|
||||||
with open(temp_path, "rb") as file_obj:
|
with open(temp_path, "rb") as file_obj:
|
||||||
return file_obj.read(), extension
|
return file_obj.read(), extension, image_info.get("meta") if civitai_match and image_info else None
|
||||||
except RecipeDownloadError:
|
except RecipeDownloadError:
|
||||||
raise
|
raise
|
||||||
except RecipeValidationError:
|
except RecipeValidationError:
|
||||||
|
|||||||
@@ -91,6 +91,7 @@ class StubAnalysisService:
|
|||||||
self.remote_calls: List[Optional[str]] = []
|
self.remote_calls: List[Optional[str]] = []
|
||||||
self.local_calls: List[Optional[str]] = []
|
self.local_calls: List[Optional[str]] = []
|
||||||
self.result = SimpleNamespace(payload={"loras": []}, status=200)
|
self.result = SimpleNamespace(payload={"loras": []}, status=200)
|
||||||
|
self._recipe_parser_factory = None
|
||||||
StubAnalysisService.instances.append(self)
|
StubAnalysisService.instances.append(self)
|
||||||
|
|
||||||
async def analyze_uploaded_image(self, *, image_bytes: bytes | None, recipe_scanner) -> SimpleNamespace: # noqa: D401 - mirrors real signature
|
async def analyze_uploaded_image(self, *, image_bytes: bytes | None, recipe_scanner) -> SimpleNamespace: # noqa: D401 - mirrors real signature
|
||||||
@@ -527,3 +528,69 @@ async def test_share_and_download_recipe(monkeypatch, tmp_path: Path) -> None:
|
|||||||
assert body == b"stub"
|
assert body == b"stub"
|
||||||
|
|
||||||
download_path.unlink(missing_ok=True)
|
download_path.unlink(missing_ok=True)
|
||||||
|
async def test_import_remote_recipe_merges_metadata(monkeypatch, tmp_path: Path) -> None:
|
||||||
|
# 1. Mock Metadata Provider
|
||||||
|
class Provider:
|
||||||
|
async def get_model_version_info(self, model_version_id):
|
||||||
|
return {"baseModel": "Flux Provider"}, None
|
||||||
|
|
||||||
|
async def fake_get_default_metadata_provider():
|
||||||
|
return Provider()
|
||||||
|
|
||||||
|
monkeypatch.setattr(recipe_handlers, "get_default_metadata_provider", fake_get_default_metadata_provider)
|
||||||
|
|
||||||
|
# 2. Mock ExifUtils to return some embedded metadata
|
||||||
|
class MockExifUtils:
|
||||||
|
@staticmethod
|
||||||
|
def extract_image_metadata(path):
|
||||||
|
return "Recipe metadata: " + json.dumps({
|
||||||
|
"gen_params": {"prompt": "from embedded", "seed": 123}
|
||||||
|
})
|
||||||
|
|
||||||
|
monkeypatch.setattr(recipe_handlers, "ExifUtils", MockExifUtils)
|
||||||
|
|
||||||
|
# 3. Mock Parser Factory for StubAnalysisService
|
||||||
|
class MockParser:
|
||||||
|
async def parse_metadata(self, raw, recipe_scanner=None):
|
||||||
|
return json.loads(raw[len("Recipe metadata: "):])
|
||||||
|
|
||||||
|
class MockFactory:
|
||||||
|
def create_parser(self, raw):
|
||||||
|
if raw.startswith("Recipe metadata: "):
|
||||||
|
return MockParser()
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 4. Setup Harness and run test
|
||||||
|
async with recipe_harness(monkeypatch, tmp_path) as harness:
|
||||||
|
harness.analysis._recipe_parser_factory = MockFactory()
|
||||||
|
|
||||||
|
# Civitai meta via image_info
|
||||||
|
harness.civitai.image_info["1"] = {
|
||||||
|
"id": 1,
|
||||||
|
"url": "https://example.com/images/1.jpg",
|
||||||
|
"meta": {"prompt": "from civitai", "cfg": 7.0}
|
||||||
|
}
|
||||||
|
|
||||||
|
resources = []
|
||||||
|
response = await harness.client.get(
|
||||||
|
"/api/lm/recipes/import-remote",
|
||||||
|
params={
|
||||||
|
"image_url": "https://civitai.com/images/1",
|
||||||
|
"name": "Merged Recipe",
|
||||||
|
"resources": json.dumps(resources),
|
||||||
|
"gen_params": json.dumps({"prompt": "from request", "steps": 25}),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
payload = await response.json()
|
||||||
|
assert response.status == 200
|
||||||
|
|
||||||
|
call = harness.persistence.save_calls[-1]
|
||||||
|
metadata = call["metadata"]
|
||||||
|
gen_params = metadata["gen_params"]
|
||||||
|
|
||||||
|
# Priority: request (prompt=request, steps=25) > civitai (prompt=civitai, cfg=7.0) > embedded (prompt=embedded, seed=123)
|
||||||
|
assert gen_params["prompt"] == "from request"
|
||||||
|
assert gen_params["steps"] == 25
|
||||||
|
assert gen_params["cfg"] == 7.0
|
||||||
|
assert gen_params["seed"] == 123
|
||||||
|
|||||||
113
tests/services/test_comfy_metadata_parser.py
Normal file
113
tests/services/test_comfy_metadata_parser.py
Normal file
@@ -0,0 +1,113 @@
|
|||||||
|
import pytest
|
||||||
|
import json
|
||||||
|
from py.recipes.parsers.comfy import ComfyMetadataParser
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_parse_metadata_without_loras(monkeypatch):
|
||||||
|
checkpoint_info = {
|
||||||
|
"id": 2224012,
|
||||||
|
"modelId": 1908679,
|
||||||
|
"model": {"name": "SDXL Checkpoint", "type": "checkpoint"},
|
||||||
|
"name": "v1.0",
|
||||||
|
"images": [{"url": "https://image.civitai.com/checkpoints/original=true"}],
|
||||||
|
"baseModel": "sdxl",
|
||||||
|
"downloadUrl": "https://civitai.com/api/download/checkpoint",
|
||||||
|
}
|
||||||
|
|
||||||
|
async def fake_metadata_provider():
|
||||||
|
class Provider:
|
||||||
|
async def get_model_version_info(self, version_id):
|
||||||
|
assert version_id == "2224012"
|
||||||
|
return checkpoint_info, None
|
||||||
|
return Provider()
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"py.recipes.parsers.comfy.get_default_metadata_provider",
|
||||||
|
fake_metadata_provider,
|
||||||
|
)
|
||||||
|
|
||||||
|
parser = ComfyMetadataParser()
|
||||||
|
|
||||||
|
# User provided metadata
|
||||||
|
metadata_json = {
|
||||||
|
"resource-stack": {
|
||||||
|
"class_type": "CheckpointLoaderSimple",
|
||||||
|
"inputs": {"ckpt_name": "urn:air:sdxl:checkpoint:civitai:1908679@2224012"}
|
||||||
|
},
|
||||||
|
"6": {
|
||||||
|
"class_type": "smZ CLIPTextEncode",
|
||||||
|
"inputs": {"text": "Positive prompt content"},
|
||||||
|
"_meta": {"title": "Positive"}
|
||||||
|
},
|
||||||
|
"7": {
|
||||||
|
"class_type": "smZ CLIPTextEncode",
|
||||||
|
"inputs": {"text": "Negative prompt content"},
|
||||||
|
"_meta": {"title": "Negative"}
|
||||||
|
},
|
||||||
|
"11": {
|
||||||
|
"class_type": "KSampler",
|
||||||
|
"inputs": {
|
||||||
|
"sampler_name": "euler_ancestral",
|
||||||
|
"scheduler": "normal",
|
||||||
|
"seed": 904124997,
|
||||||
|
"steps": 35,
|
||||||
|
"cfg": 6,
|
||||||
|
"denoise": 0.1,
|
||||||
|
"model": ["resource-stack", 0],
|
||||||
|
"positive": ["6", 0],
|
||||||
|
"negative": ["7", 0],
|
||||||
|
"latent_image": ["21", 0]
|
||||||
|
},
|
||||||
|
"_meta": {"title": "KSampler"}
|
||||||
|
},
|
||||||
|
"extraMetadata": json.dumps({
|
||||||
|
"prompt": "One woman, (solo:1.3), ...",
|
||||||
|
"negativePrompt": "lowres, worst quality, ...",
|
||||||
|
"steps": 35,
|
||||||
|
"cfgScale": 6,
|
||||||
|
"sampler": "euler_ancestral",
|
||||||
|
"seed": 904124997,
|
||||||
|
"width": 1024,
|
||||||
|
"height": 1024
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
result = await parser.parse_metadata(json.dumps(metadata_json))
|
||||||
|
|
||||||
|
assert "error" not in result
|
||||||
|
assert result["loras"] == []
|
||||||
|
assert result["checkpoint"] is not None
|
||||||
|
assert int(result["checkpoint"]["modelId"]) == 1908679
|
||||||
|
assert int(result["checkpoint"]["id"]) == 2224012
|
||||||
|
assert result["gen_params"]["prompt"] == "One woman, (solo:1.3), ..."
|
||||||
|
assert result["gen_params"]["steps"] == 35
|
||||||
|
assert result["gen_params"]["size"] == "1024x1024"
|
||||||
|
assert result["from_comfy_metadata"] is True
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_parse_metadata_without_extra_metadata(monkeypatch):
|
||||||
|
async def fake_metadata_provider():
|
||||||
|
class Provider:
|
||||||
|
async def get_model_version_info(self, version_id):
|
||||||
|
return {"model": {"name": "Test"}, "id": version_id}, None
|
||||||
|
return Provider()
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"py.recipes.parsers.comfy.get_default_metadata_provider",
|
||||||
|
fake_metadata_provider,
|
||||||
|
)
|
||||||
|
|
||||||
|
parser = ComfyMetadataParser()
|
||||||
|
|
||||||
|
metadata_json = {
|
||||||
|
"node_1": {
|
||||||
|
"class_type": "CheckpointLoaderSimple",
|
||||||
|
"inputs": {"ckpt_name": "urn:air:sdxl:checkpoint:civitai:123@456"}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
result = await parser.parse_metadata(json.dumps(metadata_json))
|
||||||
|
|
||||||
|
assert "error" not in result
|
||||||
|
assert result["loras"] == []
|
||||||
|
assert result["checkpoint"]["id"] == "456"
|
||||||
59
tests/services/test_gen_params_merger.py
Normal file
59
tests/services/test_gen_params_merger.py
Normal file
@@ -0,0 +1,59 @@
|
|||||||
|
import pytest
|
||||||
|
from py.recipes.merger import GenParamsMerger
|
||||||
|
|
||||||
|
def test_merge_priority():
|
||||||
|
request_params = {"prompt": "from request", "steps": 20}
|
||||||
|
civitai_meta = {"prompt": "from civitai", "cfg": 7.0}
|
||||||
|
embedded_metadata = {"gen_params": {"prompt": "from embedded", "seed": 123}}
|
||||||
|
|
||||||
|
merged = GenParamsMerger.merge(request_params, civitai_meta, embedded_metadata)
|
||||||
|
|
||||||
|
assert merged["prompt"] == "from request"
|
||||||
|
assert merged["steps"] == 20
|
||||||
|
assert merged["cfg"] == 7.0
|
||||||
|
assert merged["seed"] == 123
|
||||||
|
|
||||||
|
def test_merge_no_request_params():
|
||||||
|
civitai_meta = {"prompt": "from civitai", "cfg": 7.0}
|
||||||
|
embedded_metadata = {"gen_params": {"prompt": "from embedded", "seed": 123}}
|
||||||
|
|
||||||
|
merged = GenParamsMerger.merge(None, civitai_meta, embedded_metadata)
|
||||||
|
|
||||||
|
assert merged["prompt"] == "from civitai"
|
||||||
|
assert merged["cfg"] == 7.0
|
||||||
|
assert merged["seed"] == 123
|
||||||
|
|
||||||
|
def test_merge_only_embedded():
|
||||||
|
embedded_metadata = {"gen_params": {"prompt": "from embedded", "seed": 123}}
|
||||||
|
|
||||||
|
merged = GenParamsMerger.merge(None, None, embedded_metadata)
|
||||||
|
|
||||||
|
assert merged["prompt"] == "from embedded"
|
||||||
|
assert merged["seed"] == 123
|
||||||
|
|
||||||
|
def test_merge_raw_embedded():
|
||||||
|
# Test when embedded metadata is just the gen_params themselves
|
||||||
|
embedded_metadata = {"prompt": "from raw embedded", "seed": 456}
|
||||||
|
|
||||||
|
merged = GenParamsMerger.merge(None, None, embedded_metadata)
|
||||||
|
|
||||||
|
assert merged["prompt"] == "from raw embedded"
|
||||||
|
assert merged["seed"] == 456
|
||||||
|
|
||||||
|
def test_merge_none_values():
|
||||||
|
merged = GenParamsMerger.merge(None, None, None)
|
||||||
|
assert merged == {}
|
||||||
|
|
||||||
|
def test_merge_filters_blacklisted_keys():
|
||||||
|
request_params = {"prompt": "test", "id": "should-be-removed"}
|
||||||
|
civitai_meta = {"cfg": 7, "url": "remove-me"}
|
||||||
|
embedded_metadata = {"seed": 123, "hash": "remove-also"}
|
||||||
|
|
||||||
|
merged = GenParamsMerger.merge(request_params, civitai_meta, embedded_metadata)
|
||||||
|
|
||||||
|
assert "prompt" in merged
|
||||||
|
assert "cfg" in merged
|
||||||
|
assert "seed" in merged
|
||||||
|
assert "id" not in merged
|
||||||
|
assert "url" not in merged
|
||||||
|
assert "hash" not in merged
|
||||||
Reference in New Issue
Block a user