fix(civitai): improve model version retrieval

This commit is contained in:
pixelpaws
2025-10-09 10:56:25 +08:00
parent f0e246b4ac
commit 8e51f0f19f
2 changed files with 299 additions and 127 deletions

View File

@@ -1,7 +1,7 @@
import os import asyncio
import copy import copy
import logging import logging
import asyncio import os
from typing import Optional, Dict, Tuple, List from typing import Optional, Dict, Tuple, List
from .model_metadata_provider import CivitaiModelMetadataProvider, ModelMetadataProviderManager from .model_metadata_provider import CivitaiModelMetadataProvider, ModelMetadataProviderManager
from .downloader import get_downloader from .downloader import get_downloader
@@ -157,27 +157,26 @@ class CivitaiClient:
return None return None
async def get_model_version(self, model_id: int = None, version_id: int = None) -> Optional[Dict]: async def get_model_version(self, model_id: int = None, version_id: int = None) -> Optional[Dict]:
"""Get specific model version with additional metadata """Get specific model version with additional metadata."""
Args:
model_id: The Civitai model ID (optional if version_id is provided)
version_id: Optional specific version ID to retrieve
Returns:
Optional[Dict]: The model version data with additional fields or None if not found
"""
try: try:
downloader = await get_downloader() downloader = await get_downloader()
# Case 1: Only version_id is provided
if model_id is None and version_id is not None: if model_id is None and version_id is not None:
# First get the version info to extract model_id return await self._get_version_by_id_only(downloader, version_id)
success, version = await downloader.make_request(
'GET', if model_id is not None:
f"{self.base_url}/model-versions/{version_id}", return await self._get_version_with_model_id(downloader, model_id, version_id)
use_auth=True
) logger.error("Either model_id or version_id must be provided")
if not success: return None
except Exception as e:
logger.error(f"Error fetching model version: {e}")
return None
async def _get_version_by_id_only(self, downloader, version_id: int) -> Optional[Dict]:
version = await self._fetch_version_by_id(downloader, version_id)
if version is None:
return None return None
model_id = version.get('modelId') model_id = version.get('modelId')
@@ -185,39 +184,88 @@ class CivitaiClient:
logger.error(f"No modelId found in version {version_id}") logger.error(f"No modelId found in version {version_id}")
return None return None
# Now get the model data for additional metadata model_data = await self._fetch_model_data(downloader, model_id)
success, model_data = await downloader.make_request( if model_data:
'GET', self._enrich_version_with_model_data(version, model_data)
f"{self.base_url}/models/{model_id}",
use_auth=True
)
if success:
# Enrich version with model data
version['model']['description'] = model_data.get("description")
version['model']['tags'] = model_data.get("tags", [])
version['creator'] = model_data.get("creator")
self._remove_comfy_metadata(version) self._remove_comfy_metadata(version)
return version return version
# Case 2: model_id is provided (with or without version_id) async def _get_version_with_model_id(self, downloader, model_id: int, version_id: Optional[int]) -> Optional[Dict]:
elif model_id is not None: model_data = await self._fetch_model_data(downloader, model_id)
# Step 1: Get model data to find version_id if not provided and get additional metadata if not model_data:
return None
target_version = self._select_target_version(model_data, model_id, version_id)
if target_version is None:
return None
target_version_id = target_version.get('id')
version = await self._fetch_version_by_id(downloader, target_version_id) if target_version_id else None
if version is None:
model_hash = self._extract_primary_model_hash(target_version)
if model_hash:
version = await self._fetch_version_by_hash(downloader, model_hash)
else:
logger.warning(
f"No primary model hash found for model {model_id} version {target_version_id}"
)
if version is None:
version = self._build_version_from_model_data(target_version, model_id, model_data)
self._enrich_version_with_model_data(version, model_data)
self._remove_comfy_metadata(version)
return version
async def _fetch_model_data(self, downloader, model_id: int) -> Optional[Dict]:
success, data = await downloader.make_request( success, data = await downloader.make_request(
'GET', 'GET',
f"{self.base_url}/models/{model_id}", f"{self.base_url}/models/{model_id}",
use_auth=True use_auth=True
) )
if not success: if success:
return data
logger.warning(f"Failed to fetch model data for model {model_id}")
return None return None
model_versions = data.get('modelVersions', []) async def _fetch_version_by_id(self, downloader, version_id: Optional[int]) -> Optional[Dict]:
if version_id is None:
return None
success, version = await downloader.make_request(
'GET',
f"{self.base_url}/model-versions/{version_id}",
use_auth=True
)
if success:
return version
logger.warning(f"Failed to fetch version by id {version_id}")
return None
async def _fetch_version_by_hash(self, downloader, model_hash: Optional[str]) -> Optional[Dict]:
if not model_hash:
return None
success, version = await downloader.make_request(
'GET',
f"{self.base_url}/model-versions/by-hash/{model_hash}",
use_auth=True
)
if success:
return version
logger.warning(f"Failed to fetch version by hash {model_hash}")
return None
def _select_target_version(self, model_data: Dict, model_id: int, version_id: Optional[int]) -> Optional[Dict]:
model_versions = model_data.get('modelVersions', [])
if not model_versions: if not model_versions:
logger.warning(f"No model versions found for model {model_id}") logger.warning(f"No model versions found for model {model_id}")
return None return None
# Step 2: Determine the target version entry to use
target_version = None
if version_id is not None: if version_id is not None:
target_version = next( target_version = next(
(item for item in model_versions if item.get('id') == version_id), (item for item in model_versions if item.get('id') == version_id),
@@ -227,70 +275,41 @@ class CivitaiClient:
logger.warning( logger.warning(
f"Version {version_id} not found for model {model_id}, defaulting to first version" f"Version {version_id} not found for model {model_id}, defaulting to first version"
) )
if target_version is None: return model_versions[0]
target_version = model_versions[0] return target_version
target_version_id = target_version.get('id') return model_versions[0]
# Step 3: Get detailed version info using the SHA256 hash def _extract_primary_model_hash(self, version_entry: Dict) -> Optional[str]:
model_hash = None for file_info in version_entry.get('files', []):
for file_info in target_version.get('files', []):
if file_info.get('type') == 'Model' and file_info.get('primary'): if file_info.get('type') == 'Model' and file_info.get('primary'):
model_hash = file_info.get('hashes', {}).get('SHA256') hashes = file_info.get('hashes', {})
model_hash = hashes.get('SHA256')
if model_hash: if model_hash:
break return model_hash
return None
version = None def _build_version_from_model_data(self, version_entry: Dict, model_id: int, model_data: Dict) -> Dict:
if model_hash: version = copy.deepcopy(version_entry)
success, version = await downloader.make_request(
'GET',
f"{self.base_url}/model-versions/by-hash/{model_hash}",
use_auth=True
)
if not success:
logger.warning(
f"Failed to fetch version by hash for model {model_id} version {target_version_id}: {version}"
)
version = None
else:
logger.warning(
f"No primary model hash found for model {model_id} version {target_version_id}"
)
if version is None:
version = copy.deepcopy(target_version)
version.pop('index', None) version.pop('index', None)
version['modelId'] = model_id version['modelId'] = model_id
version['model'] = { version['model'] = {
'name': data.get('name'), 'name': model_data.get('name'),
'type': data.get('type'), 'type': model_data.get('type'),
'nsfw': data.get('nsfw'), 'nsfw': model_data.get('nsfw'),
'poi': data.get('poi') 'poi': model_data.get('poi')
} }
return version
# Step 4: Enrich version_info with model data def _enrich_version_with_model_data(self, version: Dict, model_data: Dict) -> None:
# Add description and tags from model data
model_info = version.get('model') model_info = version.get('model')
if not isinstance(model_info, dict): if not isinstance(model_info, dict):
model_info = {} model_info = {}
version['model'] = model_info version['model'] = model_info
model_info['description'] = data.get("description")
model_info['tags'] = data.get("tags", [])
# Add creator from model data model_info['description'] = model_data.get("description")
version['creator'] = data.get("creator") model_info['tags'] = model_data.get("tags", [])
version['creator'] = model_data.get("creator")
self._remove_comfy_metadata(version)
return version
# Case 3: Neither model_id nor version_id provided
else:
logger.error("Either model_id or version_id must be provided")
return None
except Exception as e:
logger.error(f"Error fetching model version: {e}")
return None
async def get_model_version_info(self, version_id: str) -> Tuple[Optional[Dict], Optional[str]]: async def get_model_version_info(self, version_id: str) -> Tuple[Optional[Dict], Optional[str]]:
"""Fetch model version metadata from Civitai """Fetch model version metadata from Civitai

View File

@@ -1,3 +1,4 @@
import copy
from unittest.mock import AsyncMock from unittest.mock import AsyncMock
import pytest import pytest
@@ -169,6 +170,158 @@ async def test_get_model_version_by_version_id(monkeypatch, downloader):
assert result["images"][0]["meta"]["other"] == "keep" assert result["images"][0]["meta"]["other"] == "keep"
async def test_get_model_version_with_model_id_prefers_version_endpoint(monkeypatch, downloader):
requests = []
model_payload = {
"modelVersions": [
{
"id": 7,
"files": [
{
"type": "Model",
"primary": True,
"hashes": {"SHA256": "hash7"},
}
],
}
],
"description": "desc",
"tags": ["tag"],
"creator": {"username": "user"},
"name": "Model",
"type": "LORA",
"nsfw": False,
"poi": False,
}
version_payload = {
"id": 7,
"modelId": 99,
"model": {},
"files": [],
"images": [],
}
async def fake_make_request(method, url, use_auth=True):
requests.append(url)
if url.endswith("/models/99"):
return True, copy.deepcopy(model_payload)
if url.endswith("/model-versions/7"):
return True, copy.deepcopy(version_payload)
return False, "unexpected"
downloader.make_request = fake_make_request
client = await CivitaiClient.get_instance()
result = await client.get_model_version(model_id=99, version_id=7)
assert result["id"] == 7
assert result["model"]["description"] == "desc"
assert result["model"]["tags"] == ["tag"]
assert result["creator"] == {"username": "user"}
assert requests[0].endswith("/models/99")
assert requests[1].endswith("/model-versions/7")
async def test_get_model_version_with_model_id_fallbacks_to_hash(monkeypatch, downloader):
requests = []
model_payload = {
"modelVersions": [
{
"id": 7,
"files": [
{
"type": "Model",
"primary": True,
"hashes": {"SHA256": "hash7"},
}
],
}
],
"description": "desc",
"tags": ["tag"],
"creator": {"username": "user"},
"name": "Model",
"type": "LORA",
"nsfw": False,
"poi": False,
}
version_payload = {
"id": 7,
"modelId": 99,
"files": [],
"images": [],
}
async def fake_make_request(method, url, use_auth=True):
requests.append(url)
if url.endswith("/models/99"):
return True, copy.deepcopy(model_payload)
if url.endswith("/model-versions/7"):
return False, "boom"
if url.endswith("/model-versions/by-hash/hash7"):
return True, copy.deepcopy(version_payload)
return False, "unexpected"
downloader.make_request = fake_make_request
client = await CivitaiClient.get_instance()
result = await client.get_model_version(model_id=99, version_id=7)
assert result["id"] == 7
assert result["model"]["description"] == "desc"
assert result["model"]["tags"] == ["tag"]
assert result["creator"] == {"username": "user"}
assert requests[1].endswith("/model-versions/7")
assert requests[2].endswith("/model-versions/by-hash/hash7")
async def test_get_model_version_with_model_id_builds_from_model_data(monkeypatch, downloader):
model_payload = {
"modelVersions": [
{
"id": 7,
"files": [],
"name": "v1",
}
],
"description": "desc",
"tags": ["tag"],
"creator": {"username": "user"},
"name": "Model",
"type": "LORA",
"nsfw": False,
"poi": False,
}
async def fake_make_request(method, url, use_auth=True):
if url.endswith("/models/99"):
return True, copy.deepcopy(model_payload)
if url.endswith("/model-versions/7"):
return False, "boom"
if "/model-versions/by-hash/" in url:
return False, "boom"
return False, "unexpected"
downloader.make_request = fake_make_request
client = await CivitaiClient.get_instance()
result = await client.get_model_version(model_id=99, version_id=7)
assert result["modelId"] == 99
assert result["model"]["name"] == "Model"
assert result["model"]["type"] == "LORA"
assert result["model"]["description"] == "desc"
assert result["model"]["tags"] == ["tag"]
assert result["creator"] == {"username": "user"}
async def test_get_model_version_requires_identifier(monkeypatch, downloader): async def test_get_model_version_requires_identifier(monkeypatch, downloader):
client = await CivitaiClient.get_instance() client = await CivitaiClient.get_instance()
result = await client.get_model_version() result = await client.get_model_version()