From df9367059877b7a2159410244ff80f84c1753991 Mon Sep 17 00:00:00 2001 From: Will Miao Date: Sat, 29 Nov 2025 08:46:38 +0800 Subject: [PATCH] feat: add checkpoint metadata to EXIF recipe data Add support for storing checkpoint information in image EXIF metadata. The checkpoint data is simplified and includes fields like model ID, version, name, hash, and base model. This allows for better tracking of AI model checkpoints used in image generation workflows. --- py/utils/exif_utils.py | 27 +++++++++++++-- tests/utils/test_exif_utils.py | 61 ++++++++++++++++++++++++++++++++++ 2 files changed, 86 insertions(+), 2 deletions(-) create mode 100644 tests/utils/test_exif_utils.py diff --git a/py/utils/exif_utils.py b/py/utils/exif_utils.py index ff96703c..5c853a7d 100644 --- a/py/utils/exif_utils.py +++ b/py/utils/exif_utils.py @@ -140,6 +140,28 @@ class ExifUtils: if metadata: # Remove any existing recipe metadata metadata = ExifUtils.remove_recipe_metadata(metadata) + + # Prepare checkpoint data + checkpoint_data = recipe_data.get("checkpoint") or {} + simplified_checkpoint = None + if isinstance(checkpoint_data, dict) and checkpoint_data: + simplified_checkpoint = { + "type": checkpoint_data.get("type", "checkpoint"), + "modelId": checkpoint_data.get("modelId", 0), + "modelVersionId": checkpoint_data.get("modelVersionId") + or checkpoint_data.get("id", 0), + "modelName": checkpoint_data.get( + "modelName", checkpoint_data.get("name", "") + ), + "modelVersionName": checkpoint_data.get( + "modelVersionName", checkpoint_data.get("version", "") + ), + "hash": checkpoint_data.get("hash", "").lower() + if checkpoint_data.get("hash") + else "", + "file_name": checkpoint_data.get("file_name", ""), + "baseModel": checkpoint_data.get("baseModel", ""), + } # Prepare simplified loras data simplified_loras = [] @@ -160,7 +182,8 @@ class ExifUtils: 'base_model': recipe_data.get('base_model', ''), 'loras': simplified_loras, 'gen_params': recipe_data.get('gen_params', {}), - 'tags': recipe_data.get('tags', []) + 'tags': recipe_data.get('tags', []), + **({'checkpoint': simplified_checkpoint} if simplified_checkpoint else {}) } # Convert to JSON string @@ -359,4 +382,4 @@ class ExifUtils: return f.read(), os.path.splitext(image_data)[1] except Exception: return image_data, '.jpg' # Last resort fallback - return image_data, '.jpg' \ No newline at end of file + return image_data, '.jpg' diff --git a/tests/utils/test_exif_utils.py b/tests/utils/test_exif_utils.py new file mode 100644 index 00000000..9e84d7a5 --- /dev/null +++ b/tests/utils/test_exif_utils.py @@ -0,0 +1,61 @@ +import json + +from py.utils.exif_utils import ExifUtils + + +def test_append_recipe_metadata_includes_checkpoint(monkeypatch, tmp_path): + captured = {} + + monkeypatch.setattr( + ExifUtils, "extract_image_metadata", staticmethod(lambda _path: None) + ) + + def fake_update_image_metadata(image_path, metadata): + captured["path"] = image_path + captured["metadata"] = metadata + return image_path + + monkeypatch.setattr( + ExifUtils, "update_image_metadata", staticmethod(fake_update_image_metadata) + ) + + checkpoint = { + "type": "checkpoint", + "modelId": 827184, + "modelVersionId": 2167369, + "modelName": "WAI-illustrious-SDXL", + "modelVersionName": "v15.0", + "hash": "ABC123", + "file_name": "WAI-illustrious-SDXL", + "baseModel": "Illustrious", + } + + recipe_data = { + "title": "Semi-realism", + "base_model": "Illustrious", + "loras": [], + "tags": [], + "checkpoint": checkpoint, + } + + image_path = tmp_path / "image.webp" + image_path.write_bytes(b"") + + ExifUtils.append_recipe_metadata(str(image_path), recipe_data) + + assert captured["path"] == str(image_path) + assert captured["metadata"].startswith("Recipe metadata: ") + + payload = json.loads(captured["metadata"].split("Recipe metadata: ", 1)[1]) + + assert payload["checkpoint"] == { + "type": "checkpoint", + "modelId": 827184, + "modelVersionId": 2167369, + "modelName": "WAI-illustrious-SDXL", + "modelVersionName": "v15.0", + "hash": "abc123", + "file_name": "WAI-illustrious-SDXL", + "baseModel": "Illustrious", + } + assert payload["base_model"] == "Illustrious"