diff --git a/py/services/civitai_client.py b/py/services/civitai_client.py index bb6004ed..4ec3ff0b 100644 --- a/py/services/civitai_client.py +++ b/py/services/civitai_client.py @@ -32,6 +32,24 @@ class CivitaiClient: self._initialized = True self.base_url = "https://civitai.com/api/v1" + + @staticmethod + def _remove_comfy_metadata(model_version: Optional[Dict]) -> None: + """Remove Comfy-specific metadata from model version images.""" + if not isinstance(model_version, dict): + return + + images = model_version.get("images") + if not isinstance(images, list): + return + + for image in images: + if not isinstance(image, dict): + continue + + meta = image.get("meta") + if isinstance(meta, dict) and "comfy" in meta: + meta.pop("comfy", None) async def download_file(self, url: str, save_dir: str, default_filename: str, progress_callback=None) -> Tuple[bool, str]: """Download file with resumable downloads and retry mechanism @@ -81,10 +99,11 @@ class CivitaiClient: # Enrich version_info with model data result['model']['description'] = data.get("description") result['model']['tags'] = data.get("tags", []) - + # Add creator from model data result['creator'] = data.get("creator") - + + self._remove_comfy_metadata(result) return result, None # Handle specific error cases @@ -177,7 +196,8 @@ class CivitaiClient: version['model']['description'] = model_data.get("description") version['model']['tags'] = model_data.get("tags", []) version['creator'] = model_data.get("creator") - + + self._remove_comfy_metadata(version) return version # Case 2: model_id is provided (with or without version_id) @@ -260,6 +280,7 @@ class CivitaiClient: # Add creator from model data version['creator'] = data.get("creator") + self._remove_comfy_metadata(version) return version # Case 3: Neither model_id nor version_id provided @@ -295,6 +316,7 @@ class CivitaiClient: if success: logger.debug(f"Successfully fetched model version info for: {version_id}") + self._remove_comfy_metadata(result) return result, None # Handle specific error cases diff --git a/tests/services/test_civitai_client.py b/tests/services/test_civitai_client.py index f5283443..c6241478 100644 --- a/tests/services/test_civitai_client.py +++ b/tests/services/test_civitai_client.py @@ -63,6 +63,10 @@ async def test_get_model_by_hash_enriches_metadata(monkeypatch, downloader): "modelId": 123, "model": {"description": "", "tags": []}, "creator": {}, + "images": [ + {"meta": {"comfy": {"foo": "bar"}, "other": "keep"}}, + {"meta": "not-a-dict"}, + ], } model_payload = {"description": "desc", "tags": ["tag"], "creator": {"username": "user"}} @@ -83,6 +87,8 @@ async def test_get_model_by_hash_enriches_metadata(monkeypatch, downloader): assert result["model"]["description"] == "desc" assert result["model"]["tags"] == ["tag"] assert result["creator"] == {"username": "user"} + assert "comfy" not in result["images"][0]["meta"] + assert result["images"][0]["meta"]["other"] == "keep" async def test_get_model_by_hash_handles_not_found(monkeypatch, downloader): @@ -144,6 +150,7 @@ async def test_get_model_version_by_version_id(monkeypatch, downloader): "modelId": 321, "model": {"description": ""}, "files": [], + "images": [{"meta": {"comfy": {"foo": "bar"}, "other": "keep"}}], } if url.endswith("/models/321"): return True, {"description": "desc", "tags": ["tag"], "creator": {"username": "user"}} @@ -158,6 +165,8 @@ async def test_get_model_version_by_version_id(monkeypatch, downloader): assert result["model"]["description"] == "desc" assert result["model"]["tags"] == ["tag"] assert result["creator"] == {"username": "user"} + assert "comfy" not in result["images"][0]["meta"] + assert result["images"][0]["meta"]["other"] == "keep" async def test_get_model_version_requires_identifier(monkeypatch, downloader): @@ -181,7 +190,7 @@ async def test_get_model_version_info_handles_not_found(monkeypatch, downloader) async def test_get_model_version_info_success(monkeypatch, downloader): - expected = {"id": 55} + expected = {"id": 55, "images": [{"meta": {"comfy": {"foo": "bar"}, "other": "keep"}}]} async def fake_make_request(method, url, use_auth=True): return True, expected @@ -194,6 +203,8 @@ async def test_get_model_version_info_success(monkeypatch, downloader): assert result == expected assert error is None + assert "comfy" not in result["images"][0]["meta"] + assert result["images"][0]["meta"]["other"] == "keep" async def test_get_image_info_returns_first_item(monkeypatch, downloader):