feat(metadata): add rate limit retry support to metadata providers

Add RateLimitRetryingProvider and _RateLimitRetryHelper classes to handle rate limiting with exponential backoff retries. Update get_metadata_provider function to automatically wrap providers with rate limit handling. This improves reliability when external APIs return rate limit errors by implementing automatic retries with configurable delays and jitter.
This commit is contained in:
Will Miao
2025-11-07 09:18:59 +08:00
parent c3932538e1
commit 1bb5d0b072
4 changed files with 244 additions and 43 deletions

View File

@@ -2,11 +2,12 @@ import os
import logging
from .model_metadata_provider import (
ModelMetadataProvider,
ModelMetadataProviderManager,
ModelMetadataProviderManager,
SQLiteModelMetadataProvider,
CivitaiModelMetadataProvider,
CivArchiveModelMetadataProvider,
FallbackMetadataProvider
FallbackMetadataProvider,
RateLimitRetryingProvider,
)
from .settings_manager import get_settings_manager
from .metadata_archive_manager import MetadataArchiveManager
@@ -108,14 +109,24 @@ async def get_metadata_archive_manager():
base_path = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
return MetadataArchiveManager(base_path)
def _wrap_provider_with_rate_limit(provider_name: str | None, provider: ModelMetadataProvider) -> ModelMetadataProvider:
    """Return ``provider`` wrapped in a ``RateLimitRetryingProvider`` when needed.

    Providers that are already a ``FallbackMetadataProvider`` or a
    ``RateLimitRetryingProvider`` are handed back untouched so they are not
    wrapped a second time.

    Args:
        provider_name: Label forwarded to the wrapper; may be ``None``.
        provider: The provider instance to (possibly) wrap.

    Returns:
        Either ``provider`` itself or a new ``RateLimitRetryingProvider``
        around it.
    """
    needs_no_wrapper = isinstance(provider, FallbackMetadataProvider) or isinstance(
        provider, RateLimitRetryingProvider
    )
    if not needs_no_wrapper:
        return RateLimitRetryingProvider(provider, label=provider_name)
    return provider
async def get_metadata_provider(provider_name: str | None = None):
    """Get a specific metadata provider or default provider with rate-limit handling.

    Args:
        provider_name: Name handed to the provider manager's lookup. When it
            is falsy (``None`` or empty), the manager's argument-less lookup
            is used instead.

    Returns:
        The result of ``_wrap_provider_with_rate_limit`` applied to the
        provider the manager returned.
    """
    provider_manager = await ModelMetadataProviderManager.get_instance()
    # Resolve the provider first and return exactly once, so that every path
    # goes through the rate-limit wrapper below.
    provider = (
        provider_manager._get_provider(provider_name)
        if provider_name
        else provider_manager._get_provider()
    )
    return _wrap_provider_with_rate_limit(provider_name, provider)
async def get_default_metadata_provider():
"""Get the default metadata provider (fallback or single provider)"""

View File

@@ -41,6 +41,55 @@ def _require_aiosqlite() -> Any:
logger = logging.getLogger(__name__)
class _RateLimitRetryHelper:
    """Retry an awaitable call with exponential backoff when it is rate limited."""

    def __init__(
        self,
        *,
        retry_limit: int = 3,
        base_delay: float = 1.5,
        max_delay: float = 30.0,
        jitter_ratio: float = 0.2,
    ) -> None:
        """Store the retry policy.

        Args:
            retry_limit: Number of rate-limited attempts after which the error
                is re-raised; values below 1 are raised to 1.
            base_delay: Delay in seconds used for the first backoff step.
            max_delay: Upper bound applied to every computed delay.
            jitter_ratio: Fraction of the backoff delay used as the random
                jitter span; negative values are treated as 0.
        """
        self._base_delay = base_delay
        self._max_delay = max_delay
        self._retry_limit = max(1, retry_limit)
        self._jitter_ratio = max(0.0, jitter_ratio)

    async def run(self, label: str, func, *args, **kwargs):
        """Await ``func(*args, **kwargs)``, retrying on ``RateLimitError``.

        Args:
            label: Provider label used in log messages and assigned to the
                error's ``provider`` attribute when that attribute is falsy.
            func: Callable returning an awaitable.
            *args: Positional arguments forwarded to ``func``.
            **kwargs: Keyword arguments forwarded to ``func``.

        Returns:
            Whatever the awaited call returns.

        Raises:
            RateLimitError: Re-raised once the count of rate-limited attempts
                reaches the configured retry limit.
        """
        failures = 0
        while True:
            try:
                outcome = await func(*args, **kwargs)
            except RateLimitError as rate_limit_exc:
                failures += 1
                if failures < self._retry_limit:
                    wait_seconds = self._calculate_delay(
                        rate_limit_exc.retry_after, failures
                    )
                    logger.warning(
                        "Provider %s rate limited request; retrying in %.2fs (attempt %s/%s)",
                        label,
                        wait_seconds,
                        failures,
                        self._retry_limit,
                    )
                    await asyncio.sleep(wait_seconds)
                    continue
                # Out of attempts: tag the error with the label and propagate.
                rate_limit_exc.provider = rate_limit_exc.provider or label
                raise
            else:
                return outcome

    def _calculate_delay(self, retry_after: Optional[float], attempt: int) -> float:
        """Return the number of seconds to wait before the next attempt.

        An explicit ``retry_after`` takes precedence; otherwise the delay is
        ``base_delay * 2 ** (attempt - 1)`` with random jitter applied. The
        result is always clamped to the range ``[0, max_delay]``.
        """
        if retry_after is None:
            candidate = self._base_delay * (2 ** max(0, attempt - 1))
            spread = candidate * self._jitter_ratio
            if spread > 0:
                candidate += random.uniform(-spread, spread)
        else:
            candidate = retry_after
        return min(self._max_delay, max(0.0, candidate))
class ModelMetadataProvider(ABC):
"""Base abstract class for all model metadata providers"""
@@ -390,6 +439,12 @@ class FallbackMetadataProvider(ModelMetadataProvider):
self._rate_limit_base_delay = rate_limit_base_delay
self._rate_limit_max_delay = rate_limit_max_delay
self._rate_limit_jitter_ratio = max(0.0, rate_limit_jitter_ratio)
self._rate_limit_helper = _RateLimitRetryHelper(
retry_limit=self._rate_limit_retry_limit,
base_delay=self._rate_limit_base_delay,
max_delay=self._rate_limit_max_delay,
jitter_ratio=self._rate_limit_jitter_ratio,
)
async def get_model_by_hash(self, model_hash: str) -> Tuple[Optional[Dict], Optional[str]]:
for provider, label in self._iter_providers():
@@ -485,44 +540,80 @@ class FallbackMetadataProvider(ModelMetadataProvider):
def _iter_providers(self):
    """Pair each entry of ``self.providers`` with its entry in ``self._provider_labels``.

    Returns:
        A ``zip`` iterator of ``(provider, label)`` tuples.
    """
    providers = self.providers
    labels = self._provider_labels
    return zip(providers, labels)
async def _call_with_rate_limit(self, label: str, func, *args, **kwargs):
    """Run ``func`` through this provider's rate-limit retry helper.

    Args:
        label: Provider label passed on to the helper.
        func: Callable returning an awaitable.
        *args: Positional arguments forwarded to ``func``.
        **kwargs: Keyword arguments forwarded to ``func``.

    Returns:
        The value returned by ``self._rate_limit_helper.run``.
    """
    return await self._rate_limit_helper.run(label, func, *args, **kwargs)
class RateLimitRetryingProvider(ModelMetadataProvider):
    """Adapter that retries individual provider calls after rate limiting.

    Each provider method is forwarded to the wrapped provider through a
    ``_RateLimitRetryHelper``. Any other attribute is looked up on the wrapped
    provider via ``__getattr__``.
    """

    def __init__(
        self,
        provider: ModelMetadataProvider,
        label: Optional[str] = None,
        *,
        rate_limit_retry_limit: int = 3,
        rate_limit_base_delay: float = 1.5,
        rate_limit_max_delay: float = 30.0,
        rate_limit_jitter_ratio: float = 0.2,
    ) -> None:
        """Wrap ``provider``.

        Args:
            provider: The provider whose calls are retried.
            label: Name passed to the retry helper; defaults to the wrapped
                provider's class name when falsy.
            rate_limit_retry_limit: Forwarded to the helper as ``retry_limit``.
            rate_limit_base_delay: Forwarded to the helper as ``base_delay``.
            rate_limit_max_delay: Forwarded to the helper as ``max_delay``.
            rate_limit_jitter_ratio: Forwarded to the helper as ``jitter_ratio``.
        """
        self._provider = provider
        self._label = label or provider.__class__.__name__
        self._rate_limit_helper = _RateLimitRetryHelper(
            retry_limit=rate_limit_retry_limit,
            base_delay=rate_limit_base_delay,
            max_delay=rate_limit_max_delay,
            jitter_ratio=rate_limit_jitter_ratio,
        )

    def __getattr__(self, item):
        """Delegate attributes not found on the wrapper to the wrapped provider."""
        # __getattr__ only runs after normal lookup fails. If ``_provider``
        # itself is missing (instance created without __init__, e.g. during
        # copy or unpickling), reading ``self._provider`` here would re-enter
        # this method and recurse without bound.
        if item == "_provider":
            raise AttributeError(item)
        return getattr(self._provider, item)

    async def get_model_by_hash(self, model_hash: str) -> Tuple[Optional[Dict], Optional[str]]:
        """Forward ``get_model_by_hash`` to the wrapped provider with retries."""
        return await self._rate_limit_helper.run(
            self._label,
            self._provider.get_model_by_hash,
            model_hash,
        )

    async def get_model_versions(self, model_id: str) -> Optional[Dict]:
        """Forward ``get_model_versions`` to the wrapped provider with retries."""
        return await self._rate_limit_helper.run(
            self._label,
            self._provider.get_model_versions,
            model_id,
        )

    async def get_model_versions_bulk(
        self,
        model_ids: Sequence[int],
    ) -> Optional[Dict[int, Dict]]:
        """Forward ``get_model_versions_bulk`` to the wrapped provider with retries."""
        return await self._rate_limit_helper.run(
            self._label,
            self._provider.get_model_versions_bulk,
            model_ids,
        )

    async def get_model_version(self, model_id: int = None, version_id: int = None) -> Optional[Dict]:
        """Forward ``get_model_version`` to the wrapped provider with retries."""
        return await self._rate_limit_helper.run(
            self._label,
            self._provider.get_model_version,
            model_id,
            version_id,
        )

    async def get_model_version_info(self, version_id: str) -> Tuple[Optional[Dict], Optional[str]]:
        """Forward ``get_model_version_info`` to the wrapped provider with retries."""
        return await self._rate_limit_helper.run(
            self._label,
            self._provider.get_model_version_info,
            version_id,
        )

    async def get_user_models(self, username: str) -> Optional[List[Dict]]:
        """Forward ``get_user_models`` to the wrapped provider with retries."""
        return await self._rate_limit_helper.run(
            self._label,
            self._provider.get_user_models,
            username,
        )
class ModelMetadataProviderManager:
"""Manager for selecting and using model metadata providers"""