refactor: enhance model version retrieval logic in CivitaiClient, fixes #460

This commit is contained in:
Will Miao
2025-09-24 09:16:02 +08:00
parent 6f4453aaf3
commit 12c88835f2
3 changed files with 176 additions and 20 deletions

View File

@@ -1,4 +1,5 @@
import os
import copy
import logging
import asyncio
from typing import Optional, Dict, Tuple, List
@@ -189,31 +190,76 @@ class CivitaiClient:
)
if not success:
return None
model_versions = data.get('modelVersions', [])
# Step 2: Determine the version_id to use
target_version_id = version_id
if target_version_id is None:
target_version_id = model_versions[0].get('id')
# Step 3: Get detailed version info using the version_id
success, version = await downloader.make_request(
'GET',
f"{self.base_url}/model-versions/{target_version_id}",
use_auth=True
)
if not success:
if not model_versions:
logger.warning(f"No model versions found for model {model_id}")
return None
# Step 2: Determine the target version entry to use
target_version = None
if version_id is not None:
target_version = next(
(item for item in model_versions if item.get('id') == version_id),
None
)
if target_version is None:
logger.warning(
f"Version {version_id} not found for model {model_id}, defaulting to first version"
)
if target_version is None:
target_version = model_versions[0]
target_version_id = target_version.get('id')
# Step 3: Get detailed version info using the SHA256 hash
model_hash = None
for file_info in target_version.get('files', []):
if file_info.get('type') == 'Model' and file_info.get('primary'):
model_hash = file_info.get('hashes', {}).get('SHA256')
if model_hash:
break
version = None
if model_hash:
success, version = await downloader.make_request(
'GET',
f"{self.base_url}/model-versions/by-hash/{model_hash}",
use_auth=True
)
if not success:
logger.warning(
f"Failed to fetch version by hash for model {model_id} version {target_version_id}: {version}"
)
version = None
else:
logger.warning(
f"No primary model hash found for model {model_id} version {target_version_id}"
)
if version is None:
version = copy.deepcopy(target_version)
version.pop('index', None)
version['modelId'] = model_id
version['model'] = {
'name': data.get('name'),
'type': data.get('type'),
'nsfw': data.get('nsfw'),
'poi': data.get('poi')
}
# Step 4: Enrich version_info with model data
# Add description and tags from model data
version['model']['description'] = data.get("description")
version['model']['tags'] = data.get("tags", [])
model_info = version.get('model')
if not isinstance(model_info, dict):
model_info = {}
version['model'] = model_info
model_info['description'] = data.get("description")
model_info['tags'] = data.get("tags", [])
# Add creator from model data
version['creator'] = data.get("creator")
return version
# Case 3: Neither model_id nor version_id provided