mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-24 06:32:12 -03:00
feat(civitai): add rate limiting support and error handling
- Add RateLimitError import and _make_request wrapper method to handle rate limiting - Update API methods to use _make_request wrapper instead of direct downloader calls - Add explicit RateLimitError handling in API methods to properly propagate rate limit errors - Add _extract_retry_after method to parse Retry-After headers - Improve error handling by surfacing rate limit information to callers These changes ensure that rate limiting from the Civitai API is properly detected and handled, allowing callers to implement appropriate backoff strategies when rate limits are encountered.
This commit is contained in:
@@ -5,6 +5,7 @@ import pytest
|
||||
|
||||
from py.services import civitai_client as civitai_client_module
|
||||
from py.services.civitai_client import CivitaiClient
|
||||
from py.services.errors import RateLimitError
|
||||
from py.services.model_metadata_provider import ModelMetadataProviderManager
|
||||
|
||||
|
||||
@@ -106,6 +107,21 @@ async def test_get_model_by_hash_handles_not_found(monkeypatch, downloader):
|
||||
assert error == "Model not found"
|
||||
|
||||
|
||||
async def test_get_model_by_hash_propagates_rate_limit(monkeypatch, downloader):
|
||||
async def fake_make_request(method, url, use_auth=True):
|
||||
return False, RateLimitError("limited", retry_after=4)
|
||||
|
||||
downloader.make_request = fake_make_request
|
||||
|
||||
client = await CivitaiClient.get_instance()
|
||||
|
||||
with pytest.raises(RateLimitError) as exc_info:
|
||||
await client.get_model_by_hash("limited")
|
||||
|
||||
assert exc_info.value.retry_after == 4
|
||||
assert exc_info.value.provider == "civitai_api"
|
||||
|
||||
|
||||
async def test_download_preview_image_writes_file(tmp_path, downloader):
|
||||
client = await CivitaiClient.get_instance()
|
||||
target = tmp_path / "preview" / "image.jpg"
|
||||
|
||||
@@ -4,6 +4,7 @@ from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
|
||||
from py.services.errors import RateLimitError
|
||||
from py.services.metadata_sync_service import MetadataSyncService
|
||||
|
||||
|
||||
@@ -340,6 +341,37 @@ async def test_fetch_and_update_model_falls_back_to_sqlite_after_civarchive_fail
|
||||
helpers.metadata_manager.save_metadata.assert_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fetch_and_update_model_returns_rate_limit_error(tmp_path):
|
||||
rate_error = RateLimitError("limited", retry_after=7)
|
||||
default_provider = SimpleNamespace(
|
||||
get_model_by_hash=AsyncMock(side_effect=rate_error),
|
||||
get_model_version=AsyncMock(),
|
||||
)
|
||||
helpers = build_service(default_provider=default_provider)
|
||||
|
||||
model_path = tmp_path / "model.safetensors"
|
||||
model_data = {
|
||||
"file_path": str(model_path),
|
||||
"model_name": "Local",
|
||||
}
|
||||
update_cache = AsyncMock()
|
||||
|
||||
ok, error = await helpers.service.fetch_and_update_model(
|
||||
sha256="deadbeef",
|
||||
file_path=str(model_path),
|
||||
model_data=model_data,
|
||||
update_cache_func=update_cache,
|
||||
)
|
||||
|
||||
assert ok is False
|
||||
assert error is not None and "Rate limited" in error
|
||||
assert "7" in error
|
||||
helpers.metadata_manager.save_metadata.assert_not_awaited()
|
||||
update_cache.assert_not_awaited()
|
||||
helpers.provider_selector.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_relink_metadata_fetches_version_and_updates_sha(tmp_path):
|
||||
provider = SimpleNamespace(
|
||||
|
||||
82
tests/services/test_model_metadata_provider.py
Normal file
82
tests/services/test_model_metadata_provider.py
Normal file
@@ -0,0 +1,82 @@
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
|
||||
from py.services import model_metadata_provider as provider_module
|
||||
from py.services.errors import RateLimitError
|
||||
from py.services.model_metadata_provider import FallbackMetadataProvider
|
||||
|
||||
|
||||
class RateLimitThenSuccessProvider:
|
||||
def __init__(self) -> None:
|
||||
self.calls = 0
|
||||
|
||||
async def get_model_by_hash(self, model_hash: str):
|
||||
self.calls += 1
|
||||
if self.calls == 1:
|
||||
raise RateLimitError("limited", retry_after=1.0)
|
||||
return {"id": "ok"}, None
|
||||
|
||||
|
||||
class AlwaysRateLimitedProvider:
|
||||
def __init__(self) -> None:
|
||||
self.calls = 0
|
||||
|
||||
async def get_model_by_hash(self, model_hash: str):
|
||||
self.calls += 1
|
||||
raise RateLimitError("limited")
|
||||
|
||||
|
||||
class TrackingProvider:
|
||||
def __init__(self) -> None:
|
||||
self.calls = 0
|
||||
|
||||
async def get_model_by_hash(self, model_hash: str):
|
||||
self.calls += 1
|
||||
return {"id": "secondary"}, None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fallback_retries_same_provider_on_rate_limit(monkeypatch):
|
||||
sleep_mock = AsyncMock()
|
||||
monkeypatch.setattr(provider_module.asyncio, "sleep", sleep_mock)
|
||||
monkeypatch.setattr(provider_module.random, "uniform", lambda *_: 0.0)
|
||||
|
||||
primary = RateLimitThenSuccessProvider()
|
||||
secondary = TrackingProvider()
|
||||
|
||||
fallback = FallbackMetadataProvider(
|
||||
[("primary", primary), ("secondary", secondary)],
|
||||
)
|
||||
|
||||
result, error = await fallback.get_model_by_hash("abc")
|
||||
|
||||
assert error is None
|
||||
assert result == {"id": "ok"}
|
||||
assert primary.calls == 2
|
||||
assert secondary.calls == 0
|
||||
sleep_mock.assert_awaited_once()
|
||||
assert sleep_mock.await_args_list[0].args[0] == pytest.approx(1.0, rel=0.0, abs=1e-6)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fallback_respects_retry_limit(monkeypatch):
|
||||
sleep_mock = AsyncMock()
|
||||
monkeypatch.setattr(provider_module.asyncio, "sleep", sleep_mock)
|
||||
monkeypatch.setattr(provider_module.random, "uniform", lambda *_: 0.0)
|
||||
|
||||
primary = AlwaysRateLimitedProvider()
|
||||
secondary = TrackingProvider()
|
||||
|
||||
fallback = FallbackMetadataProvider(
|
||||
[("primary", primary), ("secondary", secondary)],
|
||||
rate_limit_retry_limit=2,
|
||||
)
|
||||
|
||||
with pytest.raises(RateLimitError) as exc_info:
|
||||
await fallback.get_model_by_hash("abc")
|
||||
|
||||
assert exc_info.value.provider == "primary"
|
||||
assert primary.calls == 2
|
||||
assert secondary.calls == 0
|
||||
sleep_mock.assert_awaited_once()
|
||||
Reference in New Issue
Block a user