From 8e51f0f19fcd41bf254b4268d90e61920b60d767 Mon Sep 17 00:00:00 2001 From: pixelpaws Date: Thu, 9 Oct 2025 10:56:25 +0800 Subject: [PATCH] fix(civitai): improve model version retrieval --- py/services/civitai_client.py | 273 ++++++++++++++------------ tests/services/test_civitai_client.py | 153 +++++++++++++++ 2 files changed, 299 insertions(+), 127 deletions(-) diff --git a/py/services/civitai_client.py b/py/services/civitai_client.py index 4ec3ff0b..5598d7d7 100644 --- a/py/services/civitai_client.py +++ b/py/services/civitai_client.py @@ -1,7 +1,7 @@ -import os +import asyncio import copy import logging -import asyncio +import os from typing import Optional, Dict, Tuple, List from .model_metadata_provider import CivitaiModelMetadataProvider, ModelMetadataProviderManager from .downloader import get_downloader @@ -157,141 +157,160 @@ class CivitaiClient: return None async def get_model_version(self, model_id: int = None, version_id: int = None) -> Optional[Dict]: - """Get specific model version with additional metadata - - Args: - model_id: The Civitai model ID (optional if version_id is provided) - version_id: Optional specific version ID to retrieve - - Returns: - Optional[Dict]: The model version data with additional fields or None if not found - """ + """Get specific model version with additional metadata.""" try: downloader = await get_downloader() - - # Case 1: Only version_id is provided + if model_id is None and version_id is not None: - # First get the version info to extract model_id - success, version = await downloader.make_request( - 'GET', - f"{self.base_url}/model-versions/{version_id}", - use_auth=True - ) - if not success: - return None - - model_id = version.get('modelId') - if not model_id: - logger.error(f"No modelId found in version {version_id}") - return None - - # Now get the model data for additional metadata - success, model_data = await downloader.make_request( - 'GET', - f"{self.base_url}/models/{model_id}", - use_auth=True - ) - if success: - # Enrich version with model data - version['model']['description'] = model_data.get("description") - version['model']['tags'] = model_data.get("tags", []) - version['creator'] = model_data.get("creator") + return await self._get_version_by_id_only(downloader, version_id) - self._remove_comfy_metadata(version) - return version - - # Case 2: model_id is provided (with or without version_id) - elif model_id is not None: - # Step 1: Get model data to find version_id if not provided and get additional metadata - success, data = await downloader.make_request( - 'GET', - f"{self.base_url}/models/{model_id}", - use_auth=True - ) - if not success: - return None + if model_id is not None: + return await self._get_version_with_model_id(downloader, model_id, version_id) - model_versions = data.get('modelVersions', []) - if not model_versions: - logger.warning(f"No model versions found for model {model_id}") - return None + logger.error("Either model_id or version_id must be provided") + return None - # Step 2: Determine the target version entry to use - target_version = None - if version_id is not None: - target_version = next( - (item for item in model_versions if item.get('id') == version_id), - None - ) - if target_version is None: - logger.warning( - f"Version {version_id} not found for model {model_id}, defaulting to first version" - ) - if target_version is None: - target_version = model_versions[0] - - target_version_id = target_version.get('id') - - # Step 3: Get detailed version info using the SHA256 hash - model_hash = None - for file_info in target_version.get('files', []): - if file_info.get('type') == 'Model' and file_info.get('primary'): - model_hash = file_info.get('hashes', {}).get('SHA256') - if model_hash: - break - - version = None - if model_hash: - success, version = await downloader.make_request( - 'GET', - f"{self.base_url}/model-versions/by-hash/{model_hash}", - use_auth=True - ) - if not success: - logger.warning( - f"Failed to fetch version by hash for model {model_id} version {target_version_id}: {version}" - ) - version = None - else: - logger.warning( - f"No primary model hash found for model {model_id} version {target_version_id}" - ) - - if version is None: - version = copy.deepcopy(target_version) - version.pop('index', None) - version['modelId'] = model_id - version['model'] = { - 'name': data.get('name'), - 'type': data.get('type'), - 'nsfw': data.get('nsfw'), - 'poi': data.get('poi') - } - - # Step 4: Enrich version_info with model data - # Add description and tags from model data - model_info = version.get('model') - if not isinstance(model_info, dict): - model_info = {} - version['model'] = model_info - model_info['description'] = data.get("description") - model_info['tags'] = data.get("tags", []) - - # Add creator from model data - version['creator'] = data.get("creator") - - self._remove_comfy_metadata(version) - return version - - # Case 3: Neither model_id nor version_id provided - else: - logger.error("Either model_id or version_id must be provided") - return None - except Exception as e: logger.error(f"Error fetching model version: {e}") return None + async def _get_version_by_id_only(self, downloader, version_id: int) -> Optional[Dict]: + version = await self._fetch_version_by_id(downloader, version_id) + if version is None: + return None + + model_id = version.get('modelId') + if not model_id: + logger.error(f"No modelId found in version {version_id}") + return None + + model_data = await self._fetch_model_data(downloader, model_id) + if model_data: + self._enrich_version_with_model_data(version, model_data) + + self._remove_comfy_metadata(version) + return version + + async def _get_version_with_model_id(self, downloader, model_id: int, version_id: Optional[int]) -> Optional[Dict]: + model_data = await self._fetch_model_data(downloader, model_id) + if not model_data: + return None + + target_version = self._select_target_version(model_data, model_id, version_id) + if target_version is None: + return None + + target_version_id = target_version.get('id') + version = await self._fetch_version_by_id(downloader, target_version_id) if target_version_id else None + + if version is None: + model_hash = self._extract_primary_model_hash(target_version) + if model_hash: + version = await self._fetch_version_by_hash(downloader, model_hash) + else: + logger.warning( + f"No primary model hash found for model {model_id} version {target_version_id}" + ) + + if version is None: + version = self._build_version_from_model_data(target_version, model_id, model_data) + + self._enrich_version_with_model_data(version, model_data) + self._remove_comfy_metadata(version) + return version + + async def _fetch_model_data(self, downloader, model_id: int) -> Optional[Dict]: + success, data = await downloader.make_request( + 'GET', + f"{self.base_url}/models/{model_id}", + use_auth=True + ) + if success: + return data + logger.warning(f"Failed to fetch model data for model {model_id}") + return None + + async def _fetch_version_by_id(self, downloader, version_id: Optional[int]) -> Optional[Dict]: + if version_id is None: + return None + + success, version = await downloader.make_request( + 'GET', + f"{self.base_url}/model-versions/{version_id}", + use_auth=True + ) + if success: + return version + + logger.warning(f"Failed to fetch version by id {version_id}") + return None + + async def _fetch_version_by_hash(self, downloader, model_hash: Optional[str]) -> Optional[Dict]: + if not model_hash: + return None + + success, version = await downloader.make_request( + 'GET', + f"{self.base_url}/model-versions/by-hash/{model_hash}", + use_auth=True + ) + if success: + return version + + logger.warning(f"Failed to fetch version by hash {model_hash}") + return None + + def _select_target_version(self, model_data: Dict, model_id: int, version_id: Optional[int]) -> Optional[Dict]: + model_versions = model_data.get('modelVersions', []) + if not model_versions: + logger.warning(f"No model versions found for model {model_id}") + return None + + if version_id is not None: + target_version = next( + (item for item in model_versions if item.get('id') == version_id), + None + ) + if target_version is None: + logger.warning( + f"Version {version_id} not found for model {model_id}, defaulting to first version" + ) + return model_versions[0] + return target_version + + return model_versions[0] + + def _extract_primary_model_hash(self, version_entry: Dict) -> Optional[str]: + for file_info in version_entry.get('files', []): + if file_info.get('type') == 'Model' and file_info.get('primary'): + hashes = file_info.get('hashes', {}) + model_hash = hashes.get('SHA256') + if model_hash: + return model_hash + return None + + def _build_version_from_model_data(self, version_entry: Dict, model_id: int, model_data: Dict) -> Dict: + version = copy.deepcopy(version_entry) + version.pop('index', None) + version['modelId'] = model_id + version['model'] = { + 'name': model_data.get('name'), + 'type': model_data.get('type'), + 'nsfw': model_data.get('nsfw'), + 'poi': model_data.get('poi') + } + return version + + def _enrich_version_with_model_data(self, version: Dict, model_data: Dict) -> None: + model_info = version.get('model') + if not isinstance(model_info, dict): + model_info = {} + version['model'] = model_info + + model_info['description'] = model_data.get("description") + model_info['tags'] = model_data.get("tags", []) + version['creator'] = model_data.get("creator") + async def get_model_version_info(self, version_id: str) -> Tuple[Optional[Dict], Optional[str]]: """Fetch model version metadata from Civitai diff --git a/tests/services/test_civitai_client.py b/tests/services/test_civitai_client.py index c6241478..aaf0a0a9 100644 --- a/tests/services/test_civitai_client.py +++ b/tests/services/test_civitai_client.py @@ -1,3 +1,4 @@ +import copy from unittest.mock import AsyncMock import pytest @@ -169,6 +170,158 @@ async def test_get_model_version_by_version_id(monkeypatch, downloader): assert result["images"][0]["meta"]["other"] == "keep" +async def test_get_model_version_with_model_id_prefers_version_endpoint(monkeypatch, downloader): + requests = [] + + model_payload = { + "modelVersions": [ + { + "id": 7, + "files": [ + { + "type": "Model", + "primary": True, + "hashes": {"SHA256": "hash7"}, + } + ], + } + ], + "description": "desc", + "tags": ["tag"], + "creator": {"username": "user"}, + "name": "Model", + "type": "LORA", + "nsfw": False, + "poi": False, + } + + version_payload = { + "id": 7, + "modelId": 99, + "model": {}, + "files": [], + "images": [], + } + + async def fake_make_request(method, url, use_auth=True): + requests.append(url) + if url.endswith("/models/99"): + return True, copy.deepcopy(model_payload) + if url.endswith("/model-versions/7"): + return True, copy.deepcopy(version_payload) + return False, "unexpected" + + downloader.make_request = fake_make_request + + client = await CivitaiClient.get_instance() + + result = await client.get_model_version(model_id=99, version_id=7) + + assert result["id"] == 7 + assert result["model"]["description"] == "desc" + assert result["model"]["tags"] == ["tag"] + assert result["creator"] == {"username": "user"} + assert requests[0].endswith("/models/99") + assert requests[1].endswith("/model-versions/7") + + +async def test_get_model_version_with_model_id_fallbacks_to_hash(monkeypatch, downloader): + requests = [] + + model_payload = { + "modelVersions": [ + { + "id": 7, + "files": [ + { + "type": "Model", + "primary": True, + "hashes": {"SHA256": "hash7"}, + } + ], + } + ], + "description": "desc", + "tags": ["tag"], + "creator": {"username": "user"}, + "name": "Model", + "type": "LORA", + "nsfw": False, + "poi": False, + } + + version_payload = { + "id": 7, + "modelId": 99, + "files": [], + "images": [], + } + + async def fake_make_request(method, url, use_auth=True): + requests.append(url) + if url.endswith("/models/99"): + return True, copy.deepcopy(model_payload) + if url.endswith("/model-versions/7"): + return False, "boom" + if url.endswith("/model-versions/by-hash/hash7"): + return True, copy.deepcopy(version_payload) + return False, "unexpected" + + downloader.make_request = fake_make_request + + client = await CivitaiClient.get_instance() + + result = await client.get_model_version(model_id=99, version_id=7) + + assert result["id"] == 7 + assert result["model"]["description"] == "desc" + assert result["model"]["tags"] == ["tag"] + assert result["creator"] == {"username": "user"} + assert requests[1].endswith("/model-versions/7") + assert requests[2].endswith("/model-versions/by-hash/hash7") + + +async def test_get_model_version_with_model_id_builds_from_model_data(monkeypatch, downloader): + model_payload = { + "modelVersions": [ + { + "id": 7, + "files": [], + "name": "v1", + } + ], + "description": "desc", + "tags": ["tag"], + "creator": {"username": "user"}, + "name": "Model", + "type": "LORA", + "nsfw": False, + "poi": False, + } + + async def fake_make_request(method, url, use_auth=True): + if url.endswith("/models/99"): + return True, copy.deepcopy(model_payload) + if url.endswith("/model-versions/7"): + return False, "boom" + if "/model-versions/by-hash/" in url: + return False, "boom" + return False, "unexpected" + + downloader.make_request = fake_make_request + + client = await CivitaiClient.get_instance() + + result = await client.get_model_version(model_id=99, version_id=7) + + assert result["modelId"] == 99 + assert result["model"]["name"] == "Model" + assert result["model"]["type"] == "LORA" + assert result["model"]["description"] == "desc" + assert result["model"]["tags"] == ["tag"] + assert result["creator"] == {"username": "user"} + + async def test_get_model_version_requires_identifier(monkeypatch, downloader): client = await CivitaiClient.get_instance() result = await client.get_model_version()