mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-23 22:22:11 -03:00
fix(civitai): improve model version retrieval
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
import copy
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
@@ -169,6 +170,158 @@ async def test_get_model_version_by_version_id(monkeypatch, downloader):
|
||||
assert result["images"][0]["meta"]["other"] == "keep"
|
||||
|
||||
|
||||
async def test_get_model_version_with_model_id_prefers_version_endpoint(monkeypatch, downloader):
|
||||
requests = []
|
||||
|
||||
model_payload = {
|
||||
"modelVersions": [
|
||||
{
|
||||
"id": 7,
|
||||
"files": [
|
||||
{
|
||||
"type": "Model",
|
||||
"primary": True,
|
||||
"hashes": {"SHA256": "hash7"},
|
||||
}
|
||||
],
|
||||
}
|
||||
],
|
||||
"description": "desc",
|
||||
"tags": ["tag"],
|
||||
"creator": {"username": "user"},
|
||||
"name": "Model",
|
||||
"type": "LORA",
|
||||
"nsfw": False,
|
||||
"poi": False,
|
||||
}
|
||||
|
||||
version_payload = {
|
||||
"id": 7,
|
||||
"modelId": 99,
|
||||
"model": {},
|
||||
"files": [],
|
||||
"images": [],
|
||||
}
|
||||
|
||||
async def fake_make_request(method, url, use_auth=True):
|
||||
requests.append(url)
|
||||
if url.endswith("/models/99"):
|
||||
return True, copy.deepcopy(model_payload)
|
||||
if url.endswith("/model-versions/7"):
|
||||
return True, copy.deepcopy(version_payload)
|
||||
return False, "unexpected"
|
||||
|
||||
downloader.make_request = fake_make_request
|
||||
|
||||
client = await CivitaiClient.get_instance()
|
||||
|
||||
result = await client.get_model_version(model_id=99, version_id=7)
|
||||
|
||||
assert result["id"] == 7
|
||||
assert result["model"]["description"] == "desc"
|
||||
assert result["model"]["tags"] == ["tag"]
|
||||
assert result["creator"] == {"username": "user"}
|
||||
assert requests[0].endswith("/models/99")
|
||||
assert requests[1].endswith("/model-versions/7")
|
||||
|
||||
|
||||
async def test_get_model_version_with_model_id_fallbacks_to_hash(monkeypatch, downloader):
|
||||
requests = []
|
||||
|
||||
model_payload = {
|
||||
"modelVersions": [
|
||||
{
|
||||
"id": 7,
|
||||
"files": [
|
||||
{
|
||||
"type": "Model",
|
||||
"primary": True,
|
||||
"hashes": {"SHA256": "hash7"},
|
||||
}
|
||||
],
|
||||
}
|
||||
],
|
||||
"description": "desc",
|
||||
"tags": ["tag"],
|
||||
"creator": {"username": "user"},
|
||||
"name": "Model",
|
||||
"type": "LORA",
|
||||
"nsfw": False,
|
||||
"poi": False,
|
||||
}
|
||||
|
||||
version_payload = {
|
||||
"id": 7,
|
||||
"modelId": 99,
|
||||
"files": [],
|
||||
"images": [],
|
||||
}
|
||||
|
||||
async def fake_make_request(method, url, use_auth=True):
|
||||
requests.append(url)
|
||||
if url.endswith("/models/99"):
|
||||
return True, copy.deepcopy(model_payload)
|
||||
if url.endswith("/model-versions/7"):
|
||||
return False, "boom"
|
||||
if url.endswith("/model-versions/by-hash/hash7"):
|
||||
return True, copy.deepcopy(version_payload)
|
||||
return False, "unexpected"
|
||||
|
||||
downloader.make_request = fake_make_request
|
||||
|
||||
client = await CivitaiClient.get_instance()
|
||||
|
||||
result = await client.get_model_version(model_id=99, version_id=7)
|
||||
|
||||
assert result["id"] == 7
|
||||
assert result["model"]["description"] == "desc"
|
||||
assert result["model"]["tags"] == ["tag"]
|
||||
assert result["creator"] == {"username": "user"}
|
||||
assert requests[1].endswith("/model-versions/7")
|
||||
assert requests[2].endswith("/model-versions/by-hash/hash7")
|
||||
|
||||
|
||||
async def test_get_model_version_with_model_id_builds_from_model_data(monkeypatch, downloader):
|
||||
model_payload = {
|
||||
"modelVersions": [
|
||||
{
|
||||
"id": 7,
|
||||
"files": [],
|
||||
"name": "v1",
|
||||
}
|
||||
],
|
||||
"description": "desc",
|
||||
"tags": ["tag"],
|
||||
"creator": {"username": "user"},
|
||||
"name": "Model",
|
||||
"type": "LORA",
|
||||
"nsfw": False,
|
||||
"poi": False,
|
||||
}
|
||||
|
||||
async def fake_make_request(method, url, use_auth=True):
|
||||
if url.endswith("/models/99"):
|
||||
return True, copy.deepcopy(model_payload)
|
||||
if url.endswith("/model-versions/7"):
|
||||
return False, "boom"
|
||||
if "/model-versions/by-hash/" in url:
|
||||
return False, "boom"
|
||||
return False, "unexpected"
|
||||
|
||||
downloader.make_request = fake_make_request
|
||||
|
||||
client = await CivitaiClient.get_instance()
|
||||
|
||||
result = await client.get_model_version(model_id=99, version_id=7)
|
||||
|
||||
assert result["modelId"] == 99
|
||||
assert result["model"]["name"] == "Model"
|
||||
assert result["model"]["type"] == "LORA"
|
||||
assert result["model"]["description"] == "desc"
|
||||
assert result["model"]["tags"] == ["tag"]
|
||||
assert result["creator"] == {"username": "user"}
|
||||
|
||||
|
||||
async def test_get_model_version_requires_identifier(monkeypatch, downloader):
|
||||
client = await CivitaiClient.get_instance()
|
||||
result = await client.get_model_version()
|
||||
|
||||
Reference in New Issue
Block a user