feat(metadata): add rate limit retry support to metadata providers

Add RateLimitRetryingProvider and _RateLimitRetryHelper classes to handle rate limiting with exponential backoff retries. Update get_metadata_provider function to automatically wrap providers with rate limit handling. This improves reliability when external APIs return rate limit errors by implementing automatic retries with configurable delays and jitter.
This commit is contained in:
Will Miao
2025-11-07 09:18:59 +08:00
parent c3932538e1
commit 1bb5d0b072
4 changed files with 244 additions and 43 deletions

View File

@@ -2,11 +2,12 @@ import os
import logging import logging
from .model_metadata_provider import ( from .model_metadata_provider import (
ModelMetadataProvider, ModelMetadataProvider,
ModelMetadataProviderManager, ModelMetadataProviderManager,
SQLiteModelMetadataProvider, SQLiteModelMetadataProvider,
CivitaiModelMetadataProvider, CivitaiModelMetadataProvider,
CivArchiveModelMetadataProvider, CivArchiveModelMetadataProvider,
FallbackMetadataProvider FallbackMetadataProvider,
RateLimitRetryingProvider,
) )
from .settings_manager import get_settings_manager from .settings_manager import get_settings_manager
from .metadata_archive_manager import MetadataArchiveManager from .metadata_archive_manager import MetadataArchiveManager
@@ -108,14 +109,24 @@ async def get_metadata_archive_manager():
base_path = os.path.dirname(os.path.dirname(os.path.dirname(__file__))) base_path = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
return MetadataArchiveManager(base_path) return MetadataArchiveManager(base_path)
def _wrap_provider_with_rate_limit(provider_name: str | None, provider: ModelMetadataProvider) -> ModelMetadataProvider:
    """Return *provider* wrapped with rate-limit retries, unless it already handles them.

    Fallback providers and already-wrapped providers manage rate limiting
    themselves, so they are returned untouched.
    """
    handles_rate_limits = (FallbackMetadataProvider, RateLimitRetryingProvider)
    if isinstance(provider, handles_rate_limits):
        return provider
    return RateLimitRetryingProvider(provider, label=provider_name)
async def get_metadata_provider(provider_name: str = None): async def get_metadata_provider(provider_name: str = None):
"""Get a specific metadata provider or default provider""" """Get a specific metadata provider or default provider with rate-limit handling."""
provider_manager = await ModelMetadataProviderManager.get_instance() provider_manager = await ModelMetadataProviderManager.get_instance()
if provider_name: provider = (
return provider_manager._get_provider(provider_name) provider_manager._get_provider(provider_name)
if provider_name
return provider_manager._get_provider() else provider_manager._get_provider()
)
return _wrap_provider_with_rate_limit(provider_name, provider)
async def get_default_metadata_provider(): async def get_default_metadata_provider():
"""Get the default metadata provider (fallback or single provider)""" """Get the default metadata provider (fallback or single provider)"""

View File

@@ -41,6 +41,55 @@ def _require_aiosqlite() -> Any:
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class _RateLimitRetryHelper:
    """Coordinate exponential backoff retries after rate limiting."""

    def __init__(
        self,
        *,
        retry_limit: int = 3,
        base_delay: float = 1.5,
        max_delay: float = 30.0,
        jitter_ratio: float = 0.2,
    ) -> None:
        # At least one attempt is always made; negative jitter makes no sense.
        self._retry_limit = max(1, retry_limit)
        self._base_delay = base_delay
        self._max_delay = max_delay
        self._jitter_ratio = max(0.0, jitter_ratio)

    async def run(self, label: str, func, *args, **kwargs):
        """Await ``func(*args, **kwargs)``, retrying after each RateLimitError.

        Once the retry budget is exhausted the final RateLimitError is
        re-raised, tagged with *label* if it carries no provider yet.
        """
        for failures in range(1, self._retry_limit + 1):
            try:
                return await func(*args, **kwargs)
            except RateLimitError as exc:
                if failures >= self._retry_limit:
                    # Out of retries: attribute the error before giving up.
                    exc.provider = exc.provider or label
                    raise
                delay = self._calculate_delay(exc.retry_after, failures)
                logger.warning(
                    "Provider %s rate limited request; retrying in %.2fs (attempt %s/%s)",
                    label,
                    delay,
                    failures,
                    self._retry_limit,
                )
                await asyncio.sleep(delay)

    def _calculate_delay(self, retry_after: float | None, attempt: int) -> float:
        """Pick the next sleep: honour a server-provided retry-after, else
        jittered exponential backoff, clamped to [0, max_delay]."""
        if retry_after is not None:
            return min(self._max_delay, max(0.0, retry_after))
        delay = self._base_delay * (2 ** max(0, attempt - 1))
        jitter = delay * self._jitter_ratio
        if jitter > 0:
            delay += random.uniform(-jitter, jitter)
        return min(self._max_delay, max(0.0, delay))
class ModelMetadataProvider(ABC): class ModelMetadataProvider(ABC):
"""Base abstract class for all model metadata providers""" """Base abstract class for all model metadata providers"""
@@ -390,6 +439,12 @@ class FallbackMetadataProvider(ModelMetadataProvider):
self._rate_limit_base_delay = rate_limit_base_delay self._rate_limit_base_delay = rate_limit_base_delay
self._rate_limit_max_delay = rate_limit_max_delay self._rate_limit_max_delay = rate_limit_max_delay
self._rate_limit_jitter_ratio = max(0.0, rate_limit_jitter_ratio) self._rate_limit_jitter_ratio = max(0.0, rate_limit_jitter_ratio)
self._rate_limit_helper = _RateLimitRetryHelper(
retry_limit=self._rate_limit_retry_limit,
base_delay=self._rate_limit_base_delay,
max_delay=self._rate_limit_max_delay,
jitter_ratio=self._rate_limit_jitter_ratio,
)
async def get_model_by_hash(self, model_hash: str) -> Tuple[Optional[Dict], Optional[str]]: async def get_model_by_hash(self, model_hash: str) -> Tuple[Optional[Dict], Optional[str]]:
for provider, label in self._iter_providers(): for provider, label in self._iter_providers():
@@ -485,44 +540,80 @@ class FallbackMetadataProvider(ModelMetadataProvider):
def _iter_providers(self): def _iter_providers(self):
return zip(self.providers, self._provider_labels) return zip(self.providers, self._provider_labels)
async def _call_with_rate_limit( async def _call_with_rate_limit(self, label: str, func, *args, **kwargs):
return await self._rate_limit_helper.run(label, func, *args, **kwargs)
class RateLimitRetryingProvider(ModelMetadataProvider):
"""Adapter that retries individual provider calls after rate limiting."""
def __init__(
self, self,
label: str, provider: ModelMetadataProvider,
func, label: Optional[str] = None,
*args, *,
**kwargs, rate_limit_retry_limit: int = 3,
): rate_limit_base_delay: float = 1.5,
attempt = 0 rate_limit_max_delay: float = 30.0,
while True: rate_limit_jitter_ratio: float = 0.2,
try: ) -> None:
return await func(*args, **kwargs) self._provider = provider
except RateLimitError as exc: self._label = label or provider.__class__.__name__
attempt += 1 self._rate_limit_helper = _RateLimitRetryHelper(
if attempt >= self._rate_limit_retry_limit: retry_limit=rate_limit_retry_limit,
exc.provider = exc.provider or label base_delay=rate_limit_base_delay,
raise exc max_delay=rate_limit_max_delay,
delay = self._calculate_rate_limit_delay(exc.retry_after, attempt) jitter_ratio=rate_limit_jitter_ratio,
logger.warning( )
"Provider %s rate limited request; retrying in %.2fs (attempt %s/%s)",
label,
delay,
attempt,
self._rate_limit_retry_limit,
)
await asyncio.sleep(delay)
except Exception:
raise
def _calculate_rate_limit_delay(self, retry_after: Optional[float], attempt: int) -> float: def __getattr__(self, item):
if retry_after is not None: return getattr(self._provider, item)
return min(self._rate_limit_max_delay, max(0.0, retry_after))
base_delay = self._rate_limit_base_delay * (2 ** max(0, attempt - 1)) async def get_model_by_hash(self, model_hash: str) -> Tuple[Optional[Dict], Optional[str]]:
jitter_span = base_delay * self._rate_limit_jitter_ratio return await self._rate_limit_helper.run(
if jitter_span > 0: self._label,
base_delay += random.uniform(-jitter_span, jitter_span) self._provider.get_model_by_hash,
model_hash,
)
return min(self._rate_limit_max_delay, max(0.0, base_delay)) async def get_model_versions(self, model_id: str) -> Optional[Dict]:
return await self._rate_limit_helper.run(
self._label,
self._provider.get_model_versions,
model_id,
)
async def get_model_versions_bulk(self, model_ids: Sequence[int]) -> Optional[Dict[int, Dict]]:
    """Bulk-fetch version data for *model_ids*, retrying when rate limited."""
    fetch = self._provider.get_model_versions_bulk
    return await self._rate_limit_helper.run(self._label, fetch, model_ids)
async def get_model_version(self, model_id: int = None, version_id: int = None) -> Optional[Dict]:
    """Look up a single model version, retrying when rate limited."""
    fetch = self._provider.get_model_version
    return await self._rate_limit_helper.run(self._label, fetch, model_id, version_id)
async def get_model_version_info(self, version_id: str) -> Tuple[Optional[Dict], Optional[str]]:
    """Fetch detailed info for *version_id*, retrying when rate limited."""
    fetch = self._provider.get_model_version_info
    return await self._rate_limit_helper.run(self._label, fetch, version_id)
async def get_user_models(self, username: str) -> Optional[List[Dict]]:
    """List models belonging to *username*, retrying when rate limited."""
    fetch = self._provider.get_user_models
    return await self._rate_limit_helper.run(self._label, fetch, username)
class ModelMetadataProviderManager: class ModelMetadataProviderManager:
"""Manager for selecting and using model metadata providers""" """Manager for selecting and using model metadata providers"""

View File

@@ -0,0 +1,62 @@
from types import SimpleNamespace
from unittest.mock import AsyncMock
import pytest
from py.services import metadata_service
from py.services.model_metadata_provider import (
FallbackMetadataProvider,
ModelMetadataProvider,
RateLimitRetryingProvider,
)
class DummyProvider(ModelMetadataProvider):
    """Minimal concrete provider for tests: every lookup returns empty results."""

    async def get_model_by_hash(self, model_hash: str):
        # (metadata, error) pair — both absent.
        return None, None

    async def get_model_versions(self, model_id: str):
        return None

    async def get_model_versions_bulk(self, model_ids):
        return None

    async def get_model_version(self, model_id: int = None, version_id: int = None):
        return None

    async def get_model_version_info(self, version_id: str):
        # (info, error) pair — both absent.
        return None, None

    async def get_user_models(self, username: str):
        return None
@pytest.mark.asyncio
async def test_get_metadata_provider_wraps_non_fallback(monkeypatch):
    """A plain provider should come back wrapped in the retry adapter."""
    inner = DummyProvider()
    manager_stub = SimpleNamespace(_get_provider=lambda _name=None: inner)
    monkeypatch.setattr(
        metadata_service.ModelMetadataProviderManager,
        "get_instance",
        AsyncMock(return_value=manager_stub),
    )

    wrapped = await metadata_service.get_metadata_provider("dummy")

    assert isinstance(wrapped, RateLimitRetryingProvider)
    assert wrapped is not inner
@pytest.mark.asyncio
async def test_get_metadata_provider_returns_fallback_as_is(monkeypatch):
    """A FallbackMetadataProvider must be returned without extra wrapping."""
    fallback_provider = FallbackMetadataProvider([("dummy", DummyProvider())])
    manager_stub = SimpleNamespace(_get_provider=lambda _name=None: fallback_provider)
    monkeypatch.setattr(
        metadata_service.ModelMetadataProviderManager,
        "get_instance",
        AsyncMock(return_value=manager_stub),
    )

    result = await metadata_service.get_metadata_provider()

    assert result is fallback_provider

View File

@@ -4,7 +4,10 @@ import pytest
from py.services import model_metadata_provider as provider_module from py.services import model_metadata_provider as provider_module
from py.services.errors import RateLimitError from py.services.errors import RateLimitError
from py.services.model_metadata_provider import FallbackMetadataProvider from py.services.model_metadata_provider import (
FallbackMetadataProvider,
RateLimitRetryingProvider,
)
class RateLimitThenSuccessProvider: class RateLimitThenSuccessProvider:
@@ -80,3 +83,37 @@ async def test_fallback_respects_retry_limit(monkeypatch):
assert primary.calls == 2 assert primary.calls == 2
assert secondary.calls == 0 assert secondary.calls == 0
sleep_mock.assert_awaited_once() sleep_mock.assert_awaited_once()
@pytest.mark.asyncio
async def test_rate_limit_retrying_provider_retries(monkeypatch):
    """One rate-limit response triggers a retry and the second call succeeds."""
    sleeper = AsyncMock()
    monkeypatch.setattr(provider_module.asyncio, "sleep", sleeper)
    # Zero out jitter so the computed delay is deterministic.
    monkeypatch.setattr(provider_module.random, "uniform", lambda *_: 0.0)
    flaky = RateLimitThenSuccessProvider()
    retrying = RateLimitRetryingProvider(flaky, label="inner", rate_limit_base_delay=0.1)

    result, error = await retrying.get_model_by_hash("abc")

    assert error is None
    assert result == {"id": "ok"}
    assert flaky.calls == 2
    sleeper.assert_awaited_once()
@pytest.mark.asyncio
async def test_rate_limit_retrying_provider_respects_limit(monkeypatch):
    """Once the retry budget is spent, the RateLimitError propagates, tagged with the label."""
    sleeper = AsyncMock()
    monkeypatch.setattr(provider_module.asyncio, "sleep", sleeper)
    monkeypatch.setattr(provider_module.random, "uniform", lambda *_: 0.0)
    always_limited = AlwaysRateLimitedProvider()
    retrying = RateLimitRetryingProvider(always_limited, label="inner", rate_limit_retry_limit=2)

    with pytest.raises(RateLimitError) as exc_info:
        await retrying.get_model_by_hash("abc")

    assert exc_info.value.provider == "inner"
    assert always_limited.calls == 2
    sleeper.assert_awaited_once()