fix(civitai): improve model version retrieval

This commit is contained in:
pixelpaws
2025-10-09 10:56:25 +08:00
parent f0e246b4ac
commit 8e51f0f19f
2 changed files with 299 additions and 127 deletions

View File

@@ -1,3 +1,4 @@
import copy
from unittest.mock import AsyncMock
import pytest
@@ -169,6 +170,158 @@ async def test_get_model_version_by_version_id(monkeypatch, downloader):
assert result["images"][0]["meta"]["other"] == "keep"
async def test_get_model_version_with_model_id_prefers_version_endpoint(monkeypatch, downloader):
requests = []
model_payload = {
"modelVersions": [
{
"id": 7,
"files": [
{
"type": "Model",
"primary": True,
"hashes": {"SHA256": "hash7"},
}
],
}
],
"description": "desc",
"tags": ["tag"],
"creator": {"username": "user"},
"name": "Model",
"type": "LORA",
"nsfw": False,
"poi": False,
}
version_payload = {
"id": 7,
"modelId": 99,
"model": {},
"files": [],
"images": [],
}
async def fake_make_request(method, url, use_auth=True):
requests.append(url)
if url.endswith("/models/99"):
return True, copy.deepcopy(model_payload)
if url.endswith("/model-versions/7"):
return True, copy.deepcopy(version_payload)
return False, "unexpected"
downloader.make_request = fake_make_request
client = await CivitaiClient.get_instance()
result = await client.get_model_version(model_id=99, version_id=7)
assert result["id"] == 7
assert result["model"]["description"] == "desc"
assert result["model"]["tags"] == ["tag"]
assert result["creator"] == {"username": "user"}
assert requests[0].endswith("/models/99")
assert requests[1].endswith("/model-versions/7")
async def test_get_model_version_with_model_id_fallbacks_to_hash(monkeypatch, downloader):
requests = []
model_payload = {
"modelVersions": [
{
"id": 7,
"files": [
{
"type": "Model",
"primary": True,
"hashes": {"SHA256": "hash7"},
}
],
}
],
"description": "desc",
"tags": ["tag"],
"creator": {"username": "user"},
"name": "Model",
"type": "LORA",
"nsfw": False,
"poi": False,
}
version_payload = {
"id": 7,
"modelId": 99,
"files": [],
"images": [],
}
async def fake_make_request(method, url, use_auth=True):
requests.append(url)
if url.endswith("/models/99"):
return True, copy.deepcopy(model_payload)
if url.endswith("/model-versions/7"):
return False, "boom"
if url.endswith("/model-versions/by-hash/hash7"):
return True, copy.deepcopy(version_payload)
return False, "unexpected"
downloader.make_request = fake_make_request
client = await CivitaiClient.get_instance()
result = await client.get_model_version(model_id=99, version_id=7)
assert result["id"] == 7
assert result["model"]["description"] == "desc"
assert result["model"]["tags"] == ["tag"]
assert result["creator"] == {"username": "user"}
assert requests[1].endswith("/model-versions/7")
assert requests[2].endswith("/model-versions/by-hash/hash7")
async def test_get_model_version_with_model_id_builds_from_model_data(monkeypatch, downloader):
model_payload = {
"modelVersions": [
{
"id": 7,
"files": [],
"name": "v1",
}
],
"description": "desc",
"tags": ["tag"],
"creator": {"username": "user"},
"name": "Model",
"type": "LORA",
"nsfw": False,
"poi": False,
}
async def fake_make_request(method, url, use_auth=True):
if url.endswith("/models/99"):
return True, copy.deepcopy(model_payload)
if url.endswith("/model-versions/7"):
return False, "boom"
if "/model-versions/by-hash/" in url:
return False, "boom"
return False, "unexpected"
downloader.make_request = fake_make_request
client = await CivitaiClient.get_instance()
result = await client.get_model_version(model_id=99, version_id=7)
assert result["modelId"] == 99
assert result["model"]["name"] == "Model"
assert result["model"]["type"] == "LORA"
assert result["model"]["description"] == "desc"
assert result["model"]["tags"] == ["tag"]
assert result["creator"] == {"username": "user"}
async def test_get_model_version_requires_identifier(monkeypatch, downloader):
client = await CivitaiClient.get_instance()
result = await client.get_model_version()