From 12c88835f2285c14e3901db073a0ba9a65c0fc8d Mon Sep 17 00:00:00 2001 From: Will Miao <13051207myq@gmail.com> Date: Wed, 24 Sep 2025 09:16:02 +0800 Subject: [PATCH] refactor: enhance model version retrieval logic in CivitaiClient, fixes #460 --- py/services/civitai_client.py | 84 ++++++++++---- refs/civitai_api_model_by_modelId.json | 110 +++++++++++++++++++ tests/routes/test_base_model_routes_smoke.py | 2 +- 3 files changed, 176 insertions(+), 20 deletions(-) create mode 100644 refs/civitai_api_model_by_modelId.json diff --git a/py/services/civitai_client.py b/py/services/civitai_client.py index ccadbcb5..bb6004ed 100644 --- a/py/services/civitai_client.py +++ b/py/services/civitai_client.py @@ -1,4 +1,5 @@ import os +import copy import logging import asyncio from typing import Optional, Dict, Tuple, List @@ -189,31 +190,76 @@ class CivitaiClient: ) if not success: return None - + model_versions = data.get('modelVersions', []) - - # Step 2: Determine the version_id to use - target_version_id = version_id - if target_version_id is None: - target_version_id = model_versions[0].get('id') - - # Step 3: Get detailed version info using the version_id - success, version = await downloader.make_request( - 'GET', - f"{self.base_url}/model-versions/{target_version_id}", - use_auth=True - ) - if not success: + if not model_versions: + logger.warning(f"No model versions found for model {model_id}") return None - + + # Step 2: Determine the target version entry to use + target_version = None + if version_id is not None: + target_version = next( + (item for item in model_versions if item.get('id') == version_id), + None + ) + if target_version is None: + logger.warning( + f"Version {version_id} not found for model {model_id}, defaulting to first version" + ) + if target_version is None: + target_version = model_versions[0] + + target_version_id = target_version.get('id') + + # Step 3: Get detailed version info using the SHA256 hash + model_hash = None + for file_info in target_version.get('files', []): + if file_info.get('type') == 'Model' and file_info.get('primary'): + model_hash = file_info.get('hashes', {}).get('SHA256') + if model_hash: + break + + version = None + if model_hash: + success, version = await downloader.make_request( + 'GET', + f"{self.base_url}/model-versions/by-hash/{model_hash}", + use_auth=True + ) + if not success: + logger.warning( + f"Failed to fetch version by hash for model {model_id} version {target_version_id}: {version}" + ) + version = None + else: + logger.warning( + f"No primary model hash found for model {model_id} version {target_version_id}" + ) + + if version is None: + version = copy.deepcopy(target_version) + version.pop('index', None) + version['modelId'] = model_id + version['model'] = { + 'name': data.get('name'), + 'type': data.get('type'), + 'nsfw': data.get('nsfw'), + 'poi': data.get('poi') + } + # Step 4: Enrich version_info with model data # Add description and tags from model data - version['model']['description'] = data.get("description") - version['model']['tags'] = data.get("tags", []) - + model_info = version.get('model') + if not isinstance(model_info, dict): + model_info = {} + version['model'] = model_info + model_info['description'] = data.get("description") + model_info['tags'] = data.get("tags", []) + # Add creator from model data version['creator'] = data.get("creator") - + return version # Case 3: Neither model_id nor version_id provided diff --git a/refs/civitai_api_model_by_modelId.json b/refs/civitai_api_model_by_modelId.json new file mode 100644 index 00000000..2cd20f20 --- /dev/null +++ b/refs/civitai_api_model_by_modelId.json @@ -0,0 +1,110 @@ +{ + "id": 1231067, + "name": "Vivid Impressions Storybook Style", + "description": "