From c533a8e7bf9cfca0d532e9bb3266d624285f0a46 Mon Sep 17 00:00:00 2001 From: Will Miao Date: Thu, 20 Nov 2025 16:31:48 +0800 Subject: [PATCH] feat: enhance Civitai metadata handling and image URL processing - Import rewrite_preview_url utility for optimized image URL handling - Update thumbnail URL processing for both LoRA and checkpoint entries to use rewritten URLs - Expand checkpoint metadata with modelId, file size, SHA256 hash, and file name - Improve error handling and data validation for Civitai API responses - Maintain backward compatibility with existing data structures --- py/recipes/base.py | 76 ++++++++++++----- py/recipes/parsers/civitai_image.py | 54 ++++++++++-- tests/services/test_civitai_image_parser.py | 95 +++++++++++++++++++++ 3 files changed, 193 insertions(+), 32 deletions(-) diff --git a/py/recipes/base.py b/py/recipes/base.py index 3897c683..43534348 100644 --- a/py/recipes/base.py +++ b/py/recipes/base.py @@ -8,6 +8,7 @@ from typing import Dict, List, Any, Optional, Tuple from abc import ABC, abstractmethod from ..config import config from ..utils.constants import VALID_LORA_TYPES +from ..utils.civitai_utils import rewrite_preview_url logger = logging.getLogger(__name__) @@ -78,7 +79,7 @@ class RecipeMetadataParser(ABC): # Update model name if available if 'model' in civitai_info and 'name' in civitai_info['model']: lora_entry['name'] = civitai_info['model']['name'] - + lora_entry['id'] = civitai_info.get('id') lora_entry['modelId'] = civitai_info.get('modelId') @@ -88,7 +89,10 @@ class RecipeMetadataParser(ABC): # Get thumbnail URL from first image if 'images' in civitai_info and civitai_info['images']: - lora_entry['thumbnailUrl'] = civitai_info['images'][0].get('url', '') + image_url = civitai_info['images'][0].get('url') + if image_url: + rewritten_image_url, _ = rewrite_preview_url(image_url, media_type='image') + lora_entry['thumbnailUrl'] = rewritten_image_url or image_url # Get base model current_base_model = civitai_info.get('baseModel', '') @@ -151,33 +155,59 @@ class RecipeMetadataParser(ABC): Args: checkpoint: The checkpoint entry to populate - civitai_info: The response from Civitai API + civitai_info: The response from Civitai API or a (data, error_msg) tuple Returns: The populated checkpoint dict """ try: - if civitai_info and civitai_info.get("error") != "Model not found": - # Update model name if available - if 'model' in civitai_info and 'name' in civitai_info['model']: - checkpoint['name'] = civitai_info['model']['name'] - - # Update version if available - if 'name' in civitai_info: - checkpoint['version'] = civitai_info.get('name', '') - - # Get thumbnail URL from first image - if 'images' in civitai_info and civitai_info['images']: - checkpoint['thumbnailUrl'] = civitai_info['images'][0].get('url', '') - - # Get base model - checkpoint['baseModel'] = civitai_info.get('baseModel', '') - - # Get download URL - checkpoint['downloadUrl'] = civitai_info.get('downloadUrl', '') - else: - # Model not found or deleted + civitai_data, error_msg = ( + (civitai_info, None) + if not isinstance(civitai_info, tuple) + else civitai_info + ) + + if not civitai_data or error_msg == "Model not found": checkpoint['isDeleted'] = True + return checkpoint + + if 'model' in civitai_data and 'name' in civitai_data['model']: + checkpoint['name'] = civitai_data['model']['name'] + + if 'name' in civitai_data: + checkpoint['version'] = civitai_data.get('name', '') + + if 'images' in civitai_data and civitai_data['images']: + image_url = civitai_data['images'][0].get('url') + if image_url: + rewritten_image_url, _ = rewrite_preview_url(image_url, media_type='image') + checkpoint['thumbnailUrl'] = rewritten_image_url or image_url + + checkpoint['baseModel'] = civitai_data.get('baseModel', '') + checkpoint['downloadUrl'] = civitai_data.get('downloadUrl', '') + + checkpoint['modelId'] = civitai_data.get('modelId', checkpoint.get('modelId', 0)) + + if 'files' in civitai_data: + model_file = next( + ( + file + for file in civitai_data.get('files', []) + if file.get('type') == 'Model' + ), + None, + ) + + if model_file: + checkpoint['size'] = model_file.get('sizeKB', 0) * 1024 + + sha256 = model_file.get('hashes', {}).get('SHA256') + if sha256: + checkpoint['hash'] = sha256.lower() + + file_name = model_file.get('name', '') + if file_name: + checkpoint['file_name'] = os.path.splitext(file_name)[0] except Exception as e: logger.error(f"Error populating checkpoint from Civitai info: {e}") diff --git a/py/recipes/parsers/civitai_image.py b/py/recipes/parsers/civitai_image.py index 409c5fa3..8c1e1f8c 100644 --- a/py/recipes/parsers/civitai_image.py +++ b/py/recipes/parsers/civitai_image.py @@ -50,6 +50,7 @@ class CivitaiApiMetadataParser(RecipeMetadataParser): result = { 'base_model': None, 'loras': [], + 'model': None, 'gen_params': {}, 'from_civitai_image': True } @@ -174,13 +175,48 @@ class CivitaiApiMetadataParser(RecipeMetadataParser): # Process civitaiResources array if "civitaiResources" in metadata and isinstance(metadata["civitaiResources"], list): for resource in metadata["civitaiResources"]: - # Get unique identifier for deduplication + # Get resource type and identifier + resource_type = str(resource.get("type") or "").lower() version_id = str(resource.get("modelVersionId", "")) - + + if resource_type == "checkpoint": + checkpoint_entry = { + 'id': resource.get("modelVersionId", 0), + 'modelId': resource.get("modelId", 0), + 'name': resource.get("modelName", "Unknown Checkpoint"), + 'version': resource.get("modelVersionName", ""), + 'type': resource.get("type", "checkpoint"), + 'existsLocally': False, + 'localPath': None, + 'file_name': resource.get("modelName", ""), + 'hash': resource.get("hash", "") or "", + 'thumbnailUrl': '/loras_static/images/no-preview.png', + 'baseModel': '', + 'size': 0, + 'downloadUrl': '', + 'isDeleted': False + } + + if version_id and metadata_provider: + try: + civitai_info = await metadata_provider.get_model_version_info(version_id) + + checkpoint_entry = await self.populate_checkpoint_from_civitai( + checkpoint_entry, + civitai_info + ) + except Exception as e: + logger.error(f"Error fetching Civitai info for checkpoint version {version_id}: {e}") + + if result["model"] is None: + result["model"] = checkpoint_entry + + continue + # Skip if we've already added this LoRA if version_id and version_id in added_loras: continue - + # Initialize lora entry lora_entry = { 'id': resource.get("modelVersionId", 0), @@ -196,31 +232,31 @@ class CivitaiApiMetadataParser(RecipeMetadataParser): 'downloadUrl': '', 'isDeleted': False } - + # Try to get info from Civitai if modelVersionId is available if version_id and metadata_provider: try: # Use get_model_version_info instead of get_model_version civitai_info = await metadata_provider.get_model_version_info(version_id) - + populated_entry = await self.populate_lora_from_civitai( lora_entry, civitai_info, recipe_scanner, base_model_counts ) - + if populated_entry is None: continue # Skip invalid LoRA types - + lora_entry = populated_entry except Exception as e: logger.error(f"Error fetching Civitai info for model version {version_id}: {e}") - + # Track this LoRA in our deduplication dict if version_id: added_loras[version_id] = len(result["loras"]) - + result["loras"].append(lora_entry) # Process additionalResources array diff --git a/tests/services/test_civitai_image_parser.py b/tests/services/test_civitai_image_parser.py index 54353336..e222765b 100644 --- a/tests/services/test_civitai_image_parser.py +++ b/tests/services/test_civitai_image_parser.py @@ -59,3 +59,98 @@ async def test_parse_metadata_creates_loras_from_hashes(monkeypatch): "绪儿 厚涂构图光影质感增强V3", } + +@pytest.mark.asyncio +async def test_parse_metadata_populates_checkpoint_and_rewrites_thumbnails(monkeypatch): + checkpoint_info = { + "id": 222, + "modelId": 111, + "model": {"name": "Checkpoint Example", "type": "checkpoint"}, + "name": "Checkpoint Version", + "images": [{"url": "https://image.civitai.com/checkpoints/original=true"}], + "baseModel": "Illustrious", + "downloadUrl": "https://civitai.com/checkpoint/download", + "files": [ + { + "type": "Model", + "primary": True, + "sizeKB": 1024, + "name": "Checkpoint Example.safetensors", + "hashes": {"SHA256": "FFAA0011"}, + } + ], + } + + lora_info = { + "id": 444, + "modelId": 333, + "model": {"name": "Example Lora Model", "type": "lora"}, + "name": "Example Lora Version", + "images": [{"url": "https://image.civitai.com/loras/original=true"}], + "baseModel": "Illustrious", + "downloadUrl": "https://civitai.com/lora/download", + "files": [ + { + "type": "Model", + "primary": True, + "sizeKB": 512, + "hashes": {"SHA256": "abc123"}, + } + ], + } + + async def fake_metadata_provider(): + class Provider: + async def get_model_version_info(self, version_id): + if version_id == "222": + return checkpoint_info, None + if version_id == "444": + return lora_info, None + return None, "Model not found" + + return Provider() + + monkeypatch.setattr( + "py.recipes.parsers.civitai_image.get_default_metadata_provider", + fake_metadata_provider, + ) + + parser = CivitaiApiMetadataParser() + + metadata = { + "prompt": "test prompt", + "negativePrompt": "test negative prompt", + "civitaiResources": [ + { + "type": "checkpoint", + "modelId": 111, + "modelVersionId": 222, + "modelName": "Checkpoint Example", + "modelVersionName": "Checkpoint Version", + }, + { + "type": "lora", + "modelId": 333, + "modelVersionId": 444, + "modelName": "Example Lora", + "modelVersionName": "Lora Version", + "weight": 0.7, + }, + ], + } + + result = await parser.parse_metadata(metadata) + + assert result["model"] is not None + assert result["model"]["name"] == "Checkpoint Example" + assert result["model"]["type"] == "checkpoint" + assert result["model"]["thumbnailUrl"] == "https://image.civitai.com/checkpoints/width=450,optimized=true" + assert result["model"]["modelId"] == 111 + assert result["model"]["size"] == 1024 * 1024 + assert result["model"]["hash"] == "ffaa0011" + assert result["model"]["file_name"] == "Checkpoint Example" + + assert result["loras"] + assert result["loras"][0]["name"] == "Example Lora Model" + assert result["loras"][0]["thumbnailUrl"] == "https://image.civitai.com/loras/width=450,optimized=true" + assert result["loras"][0]["hash"] == "abc123"