mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-22 21:52:11 -03:00
- Add RateLimitError import and _make_request wrapper method to handle rate limiting
- Update API methods to use _make_request wrapper instead of direct downloader calls
- Add explicit RateLimitError handling in API methods to properly propagate rate limit errors
- Add _extract_retry_after method to parse Retry-After headers
- Improve error handling by surfacing rate limit information to callers

These changes ensure that rate limiting from the Civitai API is properly detected and handled, allowing callers to implement appropriate backoff strategies when rate limits are encountered.
403 lines
12 KiB
Python
403 lines
12 KiB
Python
import copy
|
|
from unittest.mock import AsyncMock
|
|
|
|
import pytest
|
|
|
|
from py.services import civitai_client as civitai_client_module
|
|
from py.services.civitai_client import CivitaiClient
|
|
from py.services.errors import RateLimitError
|
|
from py.services.model_metadata_provider import ModelMetadataProviderManager
|
|
|
|
|
|
class DummyDownloader:
    """Recording stand-in for the downloader used by these tests.

    Every coroutine appends its arguments to a per-method call log and then
    returns a fixed, successful result.
    """

    def __init__(self):
        # One independent call log per entry point.
        self.download_calls, self.memory_calls, self.request_calls = [], [], []

    async def download_file(self, **kwargs):
        """Log the keyword arguments and echo back ``save_path`` as the result path."""
        self.download_calls.append(kwargs)
        destination = kwargs["save_path"]
        return True, destination

    async def download_to_memory(self, url, use_auth=False):
        """Log the call and return a canned body with a JPEG content type."""
        entry = {"url": url, "use_auth": use_auth}
        self.memory_calls.append(entry)
        body, headers = b"bytes", {"content-type": "image/jpeg"}
        return True, body, headers

    async def make_request(self, method, url, use_auth=True):
        """Log the call and return success with an empty payload."""
        entry = {"method": method, "url": url, "use_auth": use_auth}
        self.request_calls.append(entry)
        return True, {}
|
|
|
|
|
|
@pytest.fixture(autouse=True)
def reset_singletons():
    """Clear the cached singleton instances before and after every test."""

    def _clear():
        CivitaiClient._instance = None
        ModelMetadataProviderManager._instance = None

    _clear()
    yield
    _clear()
|
|
|
|
|
|
@pytest.fixture
def downloader(monkeypatch):
    """Return a DummyDownloader and patch ``get_downloader`` to resolve to it."""
    stub = DummyDownloader()
    fake_factory = AsyncMock(return_value=stub)
    monkeypatch.setattr(civitai_client_module, "get_downloader", fake_factory)
    return stub
|
|
|
|
|
|
async def test_download_file_uses_downloader(tmp_path, downloader):
    """download_file succeeds, targets the default filename and calls the downloader with auth."""
    target_dir = tmp_path / "files"
    target_dir.mkdir()
    client = await CivitaiClient.get_instance()

    outcome = await client.download_file(
        url="https://example.invalid/model",
        save_dir=str(target_dir),
        default_filename="model.safetensors",
    )
    success, path = outcome

    expected_path = str(target_dir / "model.safetensors")
    assert success is True
    assert path == expected_path
    first_call = downloader.download_calls[0]
    assert first_call["use_auth"] is True
|
|
|
|
|
|
async def test_get_model_by_hash_enriches_metadata(monkeypatch, downloader):
    """A hash lookup is expected to carry the model's description, tags and creator,
    and to have ``comfy`` stripped from dict-shaped image metadata."""
    version_payload = {
        "modelId": 123,
        "model": {"description": "", "tags": []},
        "creator": {},
        "images": [
            {"meta": {"comfy": {"foo": "bar"}, "other": "keep"}},
            # Non-dict metadata must not break the lookup.
            {"meta": "not-a-dict"},
        ],
    }
    model_payload = {"description": "desc", "tags": ["tag"], "creator": {"username": "user"}}

    async def fake_make_request(method, url, use_auth=True):
        # Deep copies keep the fixtures pristine: a shallow ``dict.copy()`` still
        # shares the nested "model"/"images" objects, so any in-place edit made
        # by the client would leak back into the payloads defined above.
        if url.endswith("by-hash/hash"):
            return True, copy.deepcopy(version_payload)
        if url.endswith("/models/123"):
            return True, copy.deepcopy(model_payload)
        return False, "unexpected"

    downloader.make_request = fake_make_request

    client = await CivitaiClient.get_instance()

    result, error = await client.get_model_by_hash("hash")

    assert error is None
    assert result["model"]["description"] == "desc"
    assert result["model"]["tags"] == ["tag"]
    assert result["creator"] == {"username": "user"}
    assert "comfy" not in result["images"][0]["meta"]
    assert result["images"][0]["meta"]["other"] == "keep"
|
|
|
|
|
|
async def test_get_model_by_hash_handles_not_found(monkeypatch, downloader):
    """A failed "not found" request yields no result and the "Model not found" error."""

    async def always_missing(method, url, use_auth=True):
        return False, "not found"

    downloader.make_request = always_missing
    client = await CivitaiClient.get_instance()

    response = await client.get_model_by_hash("missing")
    result, error = response

    assert result is None
    assert error == "Model not found"
|
|
|
|
|
|
async def test_get_model_by_hash_propagates_rate_limit(monkeypatch, downloader):
    """A RateLimitError payload from the downloader is raised to the caller."""

    async def rate_limited(method, url, use_auth=True):
        limit_error = RateLimitError("limited", retry_after=4)
        return False, limit_error

    downloader.make_request = rate_limited
    client = await CivitaiClient.get_instance()

    with pytest.raises(RateLimitError) as caught:
        await client.get_model_by_hash("limited")

    raised = caught.value
    assert raised.retry_after == 4
    assert raised.provider == "civitai_api"
|
|
|
|
|
|
async def test_download_preview_image_writes_file(tmp_path, downloader):
    """A successful preview download writes the returned bytes to the target path."""
    destination = tmp_path / "preview" / "image.jpg"
    client = await CivitaiClient.get_instance()

    ok = await client.download_preview_image(
        "https://example.invalid/preview",
        str(destination),
    )

    assert ok is True
    assert destination.exists()
    written = destination.read_bytes()
    assert written == b"bytes"
|
|
|
|
|
|
async def test_download_preview_image_failure(monkeypatch, tmp_path, downloader):
    """download_preview_image reports False when the in-memory download fails."""

    async def failing_download(url, use_auth=False):
        return False, b"", {}

    downloader.download_to_memory = failing_download

    client = await CivitaiClient.get_instance()
    # Keep the target inside pytest's per-test sandbox rather than a hard-coded
    # "/tmp" path: portable across platforms and never touches shared state.
    target = str(tmp_path / "ignored.jpg")

    success = await client.download_preview_image("https://example.invalid/preview", target)

    assert success is False
|
|
|
|
|
|
async def test_get_model_versions_success(monkeypatch, downloader):
    """get_model_versions returns the versions, type and name from the response."""

    async def model_response(method, url, use_auth=True):
        body = {"modelVersions": [{"id": 1}], "type": "LORA", "name": "Model"}
        return True, body

    downloader.make_request = model_response
    client = await CivitaiClient.get_instance()

    result = await client.get_model_versions("123")

    # Built separately from the fake's body so the comparison is not by identity.
    expected = {"modelVersions": [{"id": 1}], "type": "LORA", "name": "Model"}
    assert result == expected
|
|
|
|
|
|
async def test_get_model_version_by_version_id(monkeypatch, downloader):
    """A version-id lookup is expected to carry model description, tags and creator,
    with ``comfy`` removed from image metadata."""

    async def routed_request(method, url, use_auth=True):
        if url.endswith("/model-versions/7"):
            version_body = {
                "modelId": 321,
                "model": {"description": ""},
                "files": [],
                "images": [{"meta": {"comfy": {"foo": "bar"}, "other": "keep"}}],
            }
            return True, version_body
        if url.endswith("/models/321"):
            model_body = {"description": "desc", "tags": ["tag"], "creator": {"username": "user"}}
            return True, model_body
        return False, "unexpected"

    downloader.make_request = routed_request
    client = await CivitaiClient.get_instance()

    result = await client.get_model_version(version_id=7)

    model_section = result["model"]
    assert model_section["description"] == "desc"
    assert model_section["tags"] == ["tag"]
    assert result["creator"] == {"username": "user"}
    first_meta = result["images"][0]["meta"]
    assert "comfy" not in first_meta
    assert first_meta["other"] == "keep"
|
|
|
|
|
|
async def test_get_model_version_with_model_id_prefers_version_endpoint(monkeypatch, downloader):
    """With both ids given, the model endpoint is requested first and the
    version endpoint second; the result carries the model-level fields."""
    requests = []

    primary_file = {
        "type": "Model",
        "primary": True,
        "hashes": {"SHA256": "hash7"},
    }
    model_payload = {
        "modelVersions": [{"id": 7, "files": [primary_file]}],
        "description": "desc",
        "tags": ["tag"],
        "creator": {"username": "user"},
        "name": "Model",
        "type": "LORA",
        "nsfw": False,
        "poi": False,
    }
    version_payload = {
        "id": 7,
        "modelId": 99,
        "model": {},
        "files": [],
        "images": [],
    }

    # URL suffix -> payload served for it (each response is a fresh deep copy).
    routes = {
        "/models/99": model_payload,
        "/model-versions/7": version_payload,
    }

    async def routed_request(method, url, use_auth=True):
        requests.append(url)
        for suffix, payload in routes.items():
            if url.endswith(suffix):
                return True, copy.deepcopy(payload)
        return False, "unexpected"

    downloader.make_request = routed_request
    client = await CivitaiClient.get_instance()

    result = await client.get_model_version(model_id=99, version_id=7)

    assert result["id"] == 7
    assert result["model"]["description"] == "desc"
    assert result["model"]["tags"] == ["tag"]
    assert result["creator"] == {"username": "user"}
    first_url, second_url = requests[0], requests[1]
    assert first_url.endswith("/models/99")
    assert second_url.endswith("/model-versions/7")
|
|
|
|
|
|
async def test_get_model_version_with_model_id_fallbacks_to_hash(monkeypatch, downloader):
    """When the version endpoint fails, the next request is the by-hash endpoint
    for the primary file's SHA256."""
    requests = []

    primary_file = {
        "type": "Model",
        "primary": True,
        "hashes": {"SHA256": "hash7"},
    }
    model_payload = {
        "modelVersions": [{"id": 7, "files": [primary_file]}],
        "description": "desc",
        "tags": ["tag"],
        "creator": {"username": "user"},
        "name": "Model",
        "type": "LORA",
        "nsfw": False,
        "poi": False,
    }
    version_payload = {
        "id": 7,
        "modelId": 99,
        "files": [],
        "images": [],
    }

    # URL suffix -> (success flag, body); the direct version endpoint is the one that fails.
    responses = {
        "/models/99": (True, model_payload),
        "/model-versions/7": (False, "boom"),
        "/model-versions/by-hash/hash7": (True, version_payload),
    }

    async def routed_request(method, url, use_auth=True):
        requests.append(url)
        for suffix, (ok, body) in responses.items():
            if url.endswith(suffix):
                return ok, copy.deepcopy(body)
        return False, "unexpected"

    downloader.make_request = routed_request
    client = await CivitaiClient.get_instance()

    result = await client.get_model_version(model_id=99, version_id=7)

    assert result["id"] == 7
    assert result["model"]["description"] == "desc"
    assert result["model"]["tags"] == ["tag"]
    assert result["creator"] == {"username": "user"}
    assert requests[1].endswith("/model-versions/7")
    assert requests[2].endswith("/model-versions/by-hash/hash7")
|
|
|
|
|
|
async def test_get_model_version_with_model_id_builds_from_model_data(monkeypatch, downloader):
    """When both version endpoints fail, the result is still populated from the model payload."""
    sole_version = {"id": 7, "files": [], "name": "v1"}
    model_payload = {
        "modelVersions": [sole_version],
        "description": "desc",
        "tags": ["tag"],
        "creator": {"username": "user"},
        "name": "Model",
        "type": "LORA",
        "nsfw": False,
        "poi": False,
    }

    async def routed_request(method, url, use_auth=True):
        if url.endswith("/models/99"):
            return True, copy.deepcopy(model_payload)
        # Both the direct version endpoint and any by-hash lookup fail.
        version_failed = url.endswith("/model-versions/7") or "/model-versions/by-hash/" in url
        if version_failed:
            return False, "boom"
        return False, "unexpected"

    downloader.make_request = routed_request
    client = await CivitaiClient.get_instance()

    result = await client.get_model_version(model_id=99, version_id=7)

    assert result["modelId"] == 99
    model_section = result["model"]
    assert model_section["name"] == "Model"
    assert model_section["type"] == "LORA"
    assert model_section["description"] == "desc"
    assert model_section["tags"] == ["tag"]
    assert result["creator"] == {"username": "user"}
|
|
|
|
|
|
async def test_get_model_version_requires_identifier(monkeypatch, downloader):
    """Calling get_model_version with no identifiers returns None."""
    civitai = await CivitaiClient.get_instance()

    outcome = await civitai.get_model_version()

    assert outcome is None
|
|
|
|
|
|
async def test_get_model_version_info_handles_not_found(monkeypatch, downloader):
    """A failed "not found" request yields no result and the "Model not found" error."""

    async def always_missing(method, url, use_auth=True):
        return False, "not found"

    downloader.make_request = always_missing
    client = await CivitaiClient.get_instance()

    response = await client.get_model_version_info("55")
    result, error = response

    assert result is None
    assert error == "Model not found"
|
|
|
|
|
|
async def test_get_model_version_info_success(monkeypatch, downloader):
    """A successful version-info lookup returns the payload without an error
    and with ``comfy`` removed from image metadata."""
    payload = {"id": 55, "images": [{"meta": {"comfy": {"foo": "bar"}, "other": "keep"}}]}

    async def fake_make_request(method, url, use_auth=True):
        # Serve a deep copy: returning ``payload`` itself would let a later
        # comparison against it succeed purely by object identity, whatever
        # the client did to the data.
        return True, copy.deepcopy(payload)

    downloader.make_request = fake_make_request

    client = await CivitaiClient.get_instance()

    result, error = await client.get_model_version_info("55")

    assert error is None
    # Check concrete fields instead of comparing against the served object.
    assert result["id"] == 55
    assert "comfy" not in result["images"][0]["meta"]
    assert result["images"][0]["meta"]["other"] == "keep"
|
|
|
|
|
|
async def test_get_image_info_returns_first_item(monkeypatch, downloader):
    """get_image_info returns the first entry of the response's ``items`` list."""

    async def two_items(method, url, use_auth=True):
        listing = {"items": [{"id": 1}, {"id": 2}]}
        return True, listing

    downloader.make_request = two_items
    client = await CivitaiClient.get_instance()

    image = await client.get_image_info("42")

    assert image == {"id": 1}
|
|
|
|
|
|
async def test_get_image_info_handles_missing(monkeypatch, downloader):
    """get_image_info returns None when the response's ``items`` list is empty."""

    async def no_items(method, url, use_auth=True):
        listing = {"items": []}
        return True, listing

    downloader.make_request = no_items
    client = await CivitaiClient.get_instance()

    image = await client.get_image_info("42")

    assert image is None
|