mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-21 21:22:11 -03:00
Merge pull request #543 from willmiao/codex/refactor-get_model_version-logic-and-add-tests, fixes #540
fix: improve Civitai model version retrieval
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
import os
|
||||
import asyncio
|
||||
import copy
|
||||
import logging
|
||||
import asyncio
|
||||
import os
|
||||
from typing import Optional, Dict, Tuple, List
|
||||
from .model_metadata_provider import CivitaiModelMetadataProvider, ModelMetadataProviderManager
|
||||
from .downloader import get_downloader
|
||||
@@ -157,141 +157,160 @@ class CivitaiClient:
|
||||
return None
|
||||
|
||||
async def get_model_version(self, model_id: int = None, version_id: int = None) -> Optional[Dict]:
|
||||
"""Get specific model version with additional metadata
|
||||
|
||||
Args:
|
||||
model_id: The Civitai model ID (optional if version_id is provided)
|
||||
version_id: Optional specific version ID to retrieve
|
||||
|
||||
Returns:
|
||||
Optional[Dict]: The model version data with additional fields or None if not found
|
||||
"""
|
||||
"""Get specific model version with additional metadata."""
|
||||
try:
|
||||
downloader = await get_downloader()
|
||||
|
||||
# Case 1: Only version_id is provided
|
||||
|
||||
if model_id is None and version_id is not None:
|
||||
# First get the version info to extract model_id
|
||||
success, version = await downloader.make_request(
|
||||
'GET',
|
||||
f"{self.base_url}/model-versions/{version_id}",
|
||||
use_auth=True
|
||||
)
|
||||
if not success:
|
||||
return None
|
||||
|
||||
model_id = version.get('modelId')
|
||||
if not model_id:
|
||||
logger.error(f"No modelId found in version {version_id}")
|
||||
return None
|
||||
|
||||
# Now get the model data for additional metadata
|
||||
success, model_data = await downloader.make_request(
|
||||
'GET',
|
||||
f"{self.base_url}/models/{model_id}",
|
||||
use_auth=True
|
||||
)
|
||||
if success:
|
||||
# Enrich version with model data
|
||||
version['model']['description'] = model_data.get("description")
|
||||
version['model']['tags'] = model_data.get("tags", [])
|
||||
version['creator'] = model_data.get("creator")
|
||||
return await self._get_version_by_id_only(downloader, version_id)
|
||||
|
||||
self._remove_comfy_metadata(version)
|
||||
return version
|
||||
|
||||
# Case 2: model_id is provided (with or without version_id)
|
||||
elif model_id is not None:
|
||||
# Step 1: Get model data to find version_id if not provided and get additional metadata
|
||||
success, data = await downloader.make_request(
|
||||
'GET',
|
||||
f"{self.base_url}/models/{model_id}",
|
||||
use_auth=True
|
||||
)
|
||||
if not success:
|
||||
return None
|
||||
if model_id is not None:
|
||||
return await self._get_version_with_model_id(downloader, model_id, version_id)
|
||||
|
||||
model_versions = data.get('modelVersions', [])
|
||||
if not model_versions:
|
||||
logger.warning(f"No model versions found for model {model_id}")
|
||||
return None
|
||||
logger.error("Either model_id or version_id must be provided")
|
||||
return None
|
||||
|
||||
# Step 2: Determine the target version entry to use
|
||||
target_version = None
|
||||
if version_id is not None:
|
||||
target_version = next(
|
||||
(item for item in model_versions if item.get('id') == version_id),
|
||||
None
|
||||
)
|
||||
if target_version is None:
|
||||
logger.warning(
|
||||
f"Version {version_id} not found for model {model_id}, defaulting to first version"
|
||||
)
|
||||
if target_version is None:
|
||||
target_version = model_versions[0]
|
||||
|
||||
target_version_id = target_version.get('id')
|
||||
|
||||
# Step 3: Get detailed version info using the SHA256 hash
|
||||
model_hash = None
|
||||
for file_info in target_version.get('files', []):
|
||||
if file_info.get('type') == 'Model' and file_info.get('primary'):
|
||||
model_hash = file_info.get('hashes', {}).get('SHA256')
|
||||
if model_hash:
|
||||
break
|
||||
|
||||
version = None
|
||||
if model_hash:
|
||||
success, version = await downloader.make_request(
|
||||
'GET',
|
||||
f"{self.base_url}/model-versions/by-hash/{model_hash}",
|
||||
use_auth=True
|
||||
)
|
||||
if not success:
|
||||
logger.warning(
|
||||
f"Failed to fetch version by hash for model {model_id} version {target_version_id}: {version}"
|
||||
)
|
||||
version = None
|
||||
else:
|
||||
logger.warning(
|
||||
f"No primary model hash found for model {model_id} version {target_version_id}"
|
||||
)
|
||||
|
||||
if version is None:
|
||||
version = copy.deepcopy(target_version)
|
||||
version.pop('index', None)
|
||||
version['modelId'] = model_id
|
||||
version['model'] = {
|
||||
'name': data.get('name'),
|
||||
'type': data.get('type'),
|
||||
'nsfw': data.get('nsfw'),
|
||||
'poi': data.get('poi')
|
||||
}
|
||||
|
||||
# Step 4: Enrich version_info with model data
|
||||
# Add description and tags from model data
|
||||
model_info = version.get('model')
|
||||
if not isinstance(model_info, dict):
|
||||
model_info = {}
|
||||
version['model'] = model_info
|
||||
model_info['description'] = data.get("description")
|
||||
model_info['tags'] = data.get("tags", [])
|
||||
|
||||
# Add creator from model data
|
||||
version['creator'] = data.get("creator")
|
||||
|
||||
self._remove_comfy_metadata(version)
|
||||
return version
|
||||
|
||||
# Case 3: Neither model_id nor version_id provided
|
||||
else:
|
||||
logger.error("Either model_id or version_id must be provided")
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching model version: {e}")
|
||||
return None
|
||||
|
||||
async def _get_version_by_id_only(self, downloader, version_id: int) -> Optional[Dict]:
|
||||
version = await self._fetch_version_by_id(downloader, version_id)
|
||||
if version is None:
|
||||
return None
|
||||
|
||||
model_id = version.get('modelId')
|
||||
if not model_id:
|
||||
logger.error(f"No modelId found in version {version_id}")
|
||||
return None
|
||||
|
||||
model_data = await self._fetch_model_data(downloader, model_id)
|
||||
if model_data:
|
||||
self._enrich_version_with_model_data(version, model_data)
|
||||
|
||||
self._remove_comfy_metadata(version)
|
||||
return version
|
||||
|
||||
async def _get_version_with_model_id(self, downloader, model_id: int, version_id: Optional[int]) -> Optional[Dict]:
|
||||
model_data = await self._fetch_model_data(downloader, model_id)
|
||||
if not model_data:
|
||||
return None
|
||||
|
||||
target_version = self._select_target_version(model_data, model_id, version_id)
|
||||
if target_version is None:
|
||||
return None
|
||||
|
||||
target_version_id = target_version.get('id')
|
||||
version = await self._fetch_version_by_id(downloader, target_version_id) if target_version_id else None
|
||||
|
||||
if version is None:
|
||||
model_hash = self._extract_primary_model_hash(target_version)
|
||||
if model_hash:
|
||||
version = await self._fetch_version_by_hash(downloader, model_hash)
|
||||
else:
|
||||
logger.warning(
|
||||
f"No primary model hash found for model {model_id} version {target_version_id}"
|
||||
)
|
||||
|
||||
if version is None:
|
||||
version = self._build_version_from_model_data(target_version, model_id, model_data)
|
||||
|
||||
self._enrich_version_with_model_data(version, model_data)
|
||||
self._remove_comfy_metadata(version)
|
||||
return version
|
||||
|
||||
async def _fetch_model_data(self, downloader, model_id: int) -> Optional[Dict]:
|
||||
success, data = await downloader.make_request(
|
||||
'GET',
|
||||
f"{self.base_url}/models/{model_id}",
|
||||
use_auth=True
|
||||
)
|
||||
if success:
|
||||
return data
|
||||
logger.warning(f"Failed to fetch model data for model {model_id}")
|
||||
return None
|
||||
|
||||
async def _fetch_version_by_id(self, downloader, version_id: Optional[int]) -> Optional[Dict]:
|
||||
if version_id is None:
|
||||
return None
|
||||
|
||||
success, version = await downloader.make_request(
|
||||
'GET',
|
||||
f"{self.base_url}/model-versions/{version_id}",
|
||||
use_auth=True
|
||||
)
|
||||
if success:
|
||||
return version
|
||||
|
||||
logger.warning(f"Failed to fetch version by id {version_id}")
|
||||
return None
|
||||
|
||||
async def _fetch_version_by_hash(self, downloader, model_hash: Optional[str]) -> Optional[Dict]:
|
||||
if not model_hash:
|
||||
return None
|
||||
|
||||
success, version = await downloader.make_request(
|
||||
'GET',
|
||||
f"{self.base_url}/model-versions/by-hash/{model_hash}",
|
||||
use_auth=True
|
||||
)
|
||||
if success:
|
||||
return version
|
||||
|
||||
logger.warning(f"Failed to fetch version by hash {model_hash}")
|
||||
return None
|
||||
|
||||
def _select_target_version(self, model_data: Dict, model_id: int, version_id: Optional[int]) -> Optional[Dict]:
|
||||
model_versions = model_data.get('modelVersions', [])
|
||||
if not model_versions:
|
||||
logger.warning(f"No model versions found for model {model_id}")
|
||||
return None
|
||||
|
||||
if version_id is not None:
|
||||
target_version = next(
|
||||
(item for item in model_versions if item.get('id') == version_id),
|
||||
None
|
||||
)
|
||||
if target_version is None:
|
||||
logger.warning(
|
||||
f"Version {version_id} not found for model {model_id}, defaulting to first version"
|
||||
)
|
||||
return model_versions[0]
|
||||
return target_version
|
||||
|
||||
return model_versions[0]
|
||||
|
||||
def _extract_primary_model_hash(self, version_entry: Dict) -> Optional[str]:
|
||||
for file_info in version_entry.get('files', []):
|
||||
if file_info.get('type') == 'Model' and file_info.get('primary'):
|
||||
hashes = file_info.get('hashes', {})
|
||||
model_hash = hashes.get('SHA256')
|
||||
if model_hash:
|
||||
return model_hash
|
||||
return None
|
||||
|
||||
def _build_version_from_model_data(self, version_entry: Dict, model_id: int, model_data: Dict) -> Dict:
|
||||
version = copy.deepcopy(version_entry)
|
||||
version.pop('index', None)
|
||||
version['modelId'] = model_id
|
||||
version['model'] = {
|
||||
'name': model_data.get('name'),
|
||||
'type': model_data.get('type'),
|
||||
'nsfw': model_data.get('nsfw'),
|
||||
'poi': model_data.get('poi')
|
||||
}
|
||||
return version
|
||||
|
||||
def _enrich_version_with_model_data(self, version: Dict, model_data: Dict) -> None:
|
||||
model_info = version.get('model')
|
||||
if not isinstance(model_info, dict):
|
||||
model_info = {}
|
||||
version['model'] = model_info
|
||||
|
||||
model_info['description'] = model_data.get("description")
|
||||
model_info['tags'] = model_data.get("tags", [])
|
||||
version['creator'] = model_data.get("creator")
|
||||
|
||||
async def get_model_version_info(self, version_id: str) -> Tuple[Optional[Dict], Optional[str]]:
|
||||
"""Fetch model version metadata from Civitai
|
||||
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import copy
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
@@ -169,6 +170,158 @@ async def test_get_model_version_by_version_id(monkeypatch, downloader):
|
||||
assert result["images"][0]["meta"]["other"] == "keep"
|
||||
|
||||
|
||||
async def test_get_model_version_with_model_id_prefers_version_endpoint(monkeypatch, downloader):
|
||||
requests = []
|
||||
|
||||
model_payload = {
|
||||
"modelVersions": [
|
||||
{
|
||||
"id": 7,
|
||||
"files": [
|
||||
{
|
||||
"type": "Model",
|
||||
"primary": True,
|
||||
"hashes": {"SHA256": "hash7"},
|
||||
}
|
||||
],
|
||||
}
|
||||
],
|
||||
"description": "desc",
|
||||
"tags": ["tag"],
|
||||
"creator": {"username": "user"},
|
||||
"name": "Model",
|
||||
"type": "LORA",
|
||||
"nsfw": False,
|
||||
"poi": False,
|
||||
}
|
||||
|
||||
version_payload = {
|
||||
"id": 7,
|
||||
"modelId": 99,
|
||||
"model": {},
|
||||
"files": [],
|
||||
"images": [],
|
||||
}
|
||||
|
||||
async def fake_make_request(method, url, use_auth=True):
|
||||
requests.append(url)
|
||||
if url.endswith("/models/99"):
|
||||
return True, copy.deepcopy(model_payload)
|
||||
if url.endswith("/model-versions/7"):
|
||||
return True, copy.deepcopy(version_payload)
|
||||
return False, "unexpected"
|
||||
|
||||
downloader.make_request = fake_make_request
|
||||
|
||||
client = await CivitaiClient.get_instance()
|
||||
|
||||
result = await client.get_model_version(model_id=99, version_id=7)
|
||||
|
||||
assert result["id"] == 7
|
||||
assert result["model"]["description"] == "desc"
|
||||
assert result["model"]["tags"] == ["tag"]
|
||||
assert result["creator"] == {"username": "user"}
|
||||
assert requests[0].endswith("/models/99")
|
||||
assert requests[1].endswith("/model-versions/7")
|
||||
|
||||
|
||||
async def test_get_model_version_with_model_id_fallbacks_to_hash(monkeypatch, downloader):
|
||||
requests = []
|
||||
|
||||
model_payload = {
|
||||
"modelVersions": [
|
||||
{
|
||||
"id": 7,
|
||||
"files": [
|
||||
{
|
||||
"type": "Model",
|
||||
"primary": True,
|
||||
"hashes": {"SHA256": "hash7"},
|
||||
}
|
||||
],
|
||||
}
|
||||
],
|
||||
"description": "desc",
|
||||
"tags": ["tag"],
|
||||
"creator": {"username": "user"},
|
||||
"name": "Model",
|
||||
"type": "LORA",
|
||||
"nsfw": False,
|
||||
"poi": False,
|
||||
}
|
||||
|
||||
version_payload = {
|
||||
"id": 7,
|
||||
"modelId": 99,
|
||||
"files": [],
|
||||
"images": [],
|
||||
}
|
||||
|
||||
async def fake_make_request(method, url, use_auth=True):
|
||||
requests.append(url)
|
||||
if url.endswith("/models/99"):
|
||||
return True, copy.deepcopy(model_payload)
|
||||
if url.endswith("/model-versions/7"):
|
||||
return False, "boom"
|
||||
if url.endswith("/model-versions/by-hash/hash7"):
|
||||
return True, copy.deepcopy(version_payload)
|
||||
return False, "unexpected"
|
||||
|
||||
downloader.make_request = fake_make_request
|
||||
|
||||
client = await CivitaiClient.get_instance()
|
||||
|
||||
result = await client.get_model_version(model_id=99, version_id=7)
|
||||
|
||||
assert result["id"] == 7
|
||||
assert result["model"]["description"] == "desc"
|
||||
assert result["model"]["tags"] == ["tag"]
|
||||
assert result["creator"] == {"username": "user"}
|
||||
assert requests[1].endswith("/model-versions/7")
|
||||
assert requests[2].endswith("/model-versions/by-hash/hash7")
|
||||
|
||||
|
||||
async def test_get_model_version_with_model_id_builds_from_model_data(monkeypatch, downloader):
|
||||
model_payload = {
|
||||
"modelVersions": [
|
||||
{
|
||||
"id": 7,
|
||||
"files": [],
|
||||
"name": "v1",
|
||||
}
|
||||
],
|
||||
"description": "desc",
|
||||
"tags": ["tag"],
|
||||
"creator": {"username": "user"},
|
||||
"name": "Model",
|
||||
"type": "LORA",
|
||||
"nsfw": False,
|
||||
"poi": False,
|
||||
}
|
||||
|
||||
async def fake_make_request(method, url, use_auth=True):
|
||||
if url.endswith("/models/99"):
|
||||
return True, copy.deepcopy(model_payload)
|
||||
if url.endswith("/model-versions/7"):
|
||||
return False, "boom"
|
||||
if "/model-versions/by-hash/" in url:
|
||||
return False, "boom"
|
||||
return False, "unexpected"
|
||||
|
||||
downloader.make_request = fake_make_request
|
||||
|
||||
client = await CivitaiClient.get_instance()
|
||||
|
||||
result = await client.get_model_version(model_id=99, version_id=7)
|
||||
|
||||
assert result["modelId"] == 99
|
||||
assert result["model"]["name"] == "Model"
|
||||
assert result["model"]["type"] == "LORA"
|
||||
assert result["model"]["description"] == "desc"
|
||||
assert result["model"]["tags"] == ["tag"]
|
||||
assert result["creator"] == {"username": "user"}
|
||||
|
||||
|
||||
async def test_get_model_version_requires_identifier(monkeypatch, downloader):
|
||||
client = await CivitaiClient.get_instance()
|
||||
result = await client.get_model_version()
|
||||
|
||||
Reference in New Issue
Block a user