Merge pull request #648 from willmiao/fix-rate-limit-retry, see #647

feat(metadata): add rate limit retry support to metadata providers
This commit is contained in:
pixelpaws
2025-11-07 10:57:48 +08:00
committed by GitHub
4 changed files with 244 additions and 43 deletions

View File

@@ -2,11 +2,12 @@ import os
import logging
from .model_metadata_provider import (
ModelMetadataProvider,
ModelMetadataProviderManager,
ModelMetadataProviderManager,
SQLiteModelMetadataProvider,
CivitaiModelMetadataProvider,
CivArchiveModelMetadataProvider,
FallbackMetadataProvider
FallbackMetadataProvider,
RateLimitRetryingProvider,
)
from .settings_manager import get_settings_manager
from .metadata_archive_manager import MetadataArchiveManager
@@ -108,14 +109,24 @@ async def get_metadata_archive_manager():
base_path = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
return MetadataArchiveManager(base_path)
def _wrap_provider_with_rate_limit(provider_name: str | None, provider: ModelMetadataProvider) -> ModelMetadataProvider:
    """Wrap *provider* in a rate-limit retry adapter unless it already handles retries.

    FallbackMetadataProvider and RateLimitRetryingProvider do their own
    rate-limit handling, so they are returned unchanged to avoid double-retrying.
    """
    handles_rate_limits = isinstance(
        provider, (FallbackMetadataProvider, RateLimitRetryingProvider)
    )
    if handles_rate_limits:
        return provider
    return RateLimitRetryingProvider(provider, label=provider_name)
async def get_metadata_provider(provider_name: str = None):
    """Get a specific metadata provider or default provider with rate-limit handling.

    Args:
        provider_name: Optional provider key; falls back to the manager's
            default provider when omitted.

    Returns:
        The resolved provider, wrapped with rate-limit retry behavior when
        it does not already handle rate limiting itself.
    """
    provider_manager = await ModelMetadataProviderManager.get_instance()
    # Diff-residue fix: the old unwrapped early returns were removed so every
    # caller now goes through the rate-limit wrapper.
    provider = (
        provider_manager._get_provider(provider_name)
        if provider_name
        else provider_manager._get_provider()
    )
    return _wrap_provider_with_rate_limit(provider_name, provider)
async def get_default_metadata_provider():
"""Get the default metadata provider (fallback or single provider)"""

View File

@@ -41,6 +41,55 @@ def _require_aiosqlite() -> Any:
logger = logging.getLogger(__name__)
class _RateLimitRetryHelper:
    """Coordinate exponential backoff retries after rate limiting."""

    def __init__(
        self,
        *,
        retry_limit: int = 3,
        base_delay: float = 1.5,
        max_delay: float = 30.0,
        jitter_ratio: float = 0.2,
    ) -> None:
        # At least one attempt is always made; negative jitter is clamped away.
        self._retry_limit = max(1, retry_limit)
        self._base_delay = base_delay
        self._max_delay = max_delay
        self._jitter_ratio = max(0.0, jitter_ratio)

    async def run(self, label: str, func, *args, **kwargs):
        """Await ``func(*args, **kwargs)``, sleeping and retrying on RateLimitError.

        Re-raises the last RateLimitError (tagged with *label* if the error
        carries no provider) once the retry budget is exhausted.
        """
        failures = 0
        while True:
            try:
                return await func(*args, **kwargs)
            except RateLimitError as exc:
                failures += 1
                if failures >= self._retry_limit:
                    exc.provider = exc.provider or label
                    raise
                wait = self._calculate_delay(exc.retry_after, failures)
                logger.warning(
                    "Provider %s rate limited request; retrying in %.2fs (attempt %s/%s)",
                    label,
                    wait,
                    failures,
                    self._retry_limit,
                )
                await asyncio.sleep(wait)

    def _calculate_delay(self, retry_after: Optional[float], attempt: int) -> float:
        """Pick the next sleep interval, honoring an explicit retry-after hint."""
        def _clamp(value: float) -> float:
            # Keep every delay within [0, max_delay].
            return min(self._max_delay, max(0.0, value))

        if retry_after is not None:
            return _clamp(retry_after)
        backoff = self._base_delay * (2 ** max(0, attempt - 1))
        half_span = backoff * self._jitter_ratio
        if half_span > 0:
            backoff += random.uniform(-half_span, half_span)
        return _clamp(backoff)
class ModelMetadataProvider(ABC):
"""Base abstract class for all model metadata providers"""
@@ -390,6 +439,12 @@ class FallbackMetadataProvider(ModelMetadataProvider):
self._rate_limit_base_delay = rate_limit_base_delay
self._rate_limit_max_delay = rate_limit_max_delay
self._rate_limit_jitter_ratio = max(0.0, rate_limit_jitter_ratio)
self._rate_limit_helper = _RateLimitRetryHelper(
retry_limit=self._rate_limit_retry_limit,
base_delay=self._rate_limit_base_delay,
max_delay=self._rate_limit_max_delay,
jitter_ratio=self._rate_limit_jitter_ratio,
)
async def get_model_by_hash(self, model_hash: str) -> Tuple[Optional[Dict], Optional[str]]:
for provider, label in self._iter_providers():
@@ -485,44 +540,80 @@ class FallbackMetadataProvider(ModelMetadataProvider):
def _iter_providers(self):
    """Pair each configured provider with its label, preserving order."""
    return zip(self.providers, self._provider_labels)
async def _call_with_rate_limit(self, label: str, func, *args, **kwargs):
    """Run *func* through the shared rate-limit retry helper.

    Diff-residue fix: the truncated duplicate of the old multi-line
    signature was dropped; only the delegating implementation remains.
    """
    return await self._rate_limit_helper.run(label, func, *args, **kwargs)
class RateLimitRetryingProvider(ModelMetadataProvider):
    """Adapter that retries individual provider calls after rate limiting.

    Diff-residue fix: removed-code fragments of the old FallbackMetadataProvider
    retry loop that were interleaved into this class body have been dropped;
    only the clean adapter remains.
    """

    def __init__(
        self,
        provider: ModelMetadataProvider,
        label: Optional[str] = None,
        *,
        rate_limit_retry_limit: int = 3,
        rate_limit_base_delay: float = 1.5,
        rate_limit_max_delay: float = 30.0,
        rate_limit_jitter_ratio: float = 0.2,
    ) -> None:
        self._provider = provider
        # Fall back to the class name so log messages always have a label.
        self._label = label or provider.__class__.__name__
        self._rate_limit_helper = _RateLimitRetryHelper(
            retry_limit=rate_limit_retry_limit,
            base_delay=rate_limit_base_delay,
            max_delay=rate_limit_max_delay,
            jitter_ratio=rate_limit_jitter_ratio,
        )

    def __getattr__(self, item):
        # Delegate anything not overridden here to the wrapped provider so the
        # adapter stays a drop-in replacement.
        return getattr(self._provider, item)

    async def get_model_by_hash(self, model_hash: str) -> Tuple[Optional[Dict], Optional[str]]:
        return await self._rate_limit_helper.run(
            self._label,
            self._provider.get_model_by_hash,
            model_hash,
        )

    async def get_model_versions(self, model_id: str) -> Optional[Dict]:
        return await self._rate_limit_helper.run(
            self._label,
            self._provider.get_model_versions,
            model_id,
        )

    async def get_model_versions_bulk(
        self,
        model_ids: Sequence[int],
    ) -> Optional[Dict[int, Dict]]:
        return await self._rate_limit_helper.run(
            self._label,
            self._provider.get_model_versions_bulk,
            model_ids,
        )

    async def get_model_version(self, model_id: int = None, version_id: int = None) -> Optional[Dict]:
        return await self._rate_limit_helper.run(
            self._label,
            self._provider.get_model_version,
            model_id,
            version_id,
        )

    async def get_model_version_info(self, version_id: str) -> Tuple[Optional[Dict], Optional[str]]:
        return await self._rate_limit_helper.run(
            self._label,
            self._provider.get_model_version_info,
            version_id,
        )

    async def get_user_models(self, username: str) -> Optional[List[Dict]]:
        return await self._rate_limit_helper.run(
            self._label,
            self._provider.get_user_models,
            username,
        )
class ModelMetadataProviderManager:
"""Manager for selecting and using model metadata providers"""

View File

@@ -0,0 +1,62 @@
from types import SimpleNamespace
from unittest.mock import AsyncMock
import pytest
from py.services import metadata_service
from py.services.model_metadata_provider import (
FallbackMetadataProvider,
ModelMetadataProvider,
RateLimitRetryingProvider,
)
class DummyProvider(ModelMetadataProvider):
    """Stub provider whose every lookup reports 'not found'."""

    async def get_model_by_hash(self, model_hash: str):
        return (None, None)

    async def get_model_versions(self, model_id: str):
        return None

    async def get_model_versions_bulk(self, model_ids):
        return None

    async def get_model_version(self, model_id: int = None, version_id: int = None):
        return None

    async def get_model_version_info(self, version_id: str):
        return (None, None)

    async def get_user_models(self, username: str):
        return None
@pytest.mark.asyncio
async def test_get_metadata_provider_wraps_non_fallback(monkeypatch):
    """A plain provider comes back wrapped in the rate-limit retry adapter."""
    inner = DummyProvider()
    fake_manager = SimpleNamespace(_get_provider=lambda _name=None: inner)
    monkeypatch.setattr(
        metadata_service.ModelMetadataProviderManager,
        "get_instance",
        AsyncMock(return_value=fake_manager),
    )

    wrapped = await metadata_service.get_metadata_provider("dummy")

    assert isinstance(wrapped, RateLimitRetryingProvider)
    assert wrapped is not inner
@pytest.mark.asyncio
async def test_get_metadata_provider_returns_fallback_as_is(monkeypatch):
    """A FallbackMetadataProvider must not be double-wrapped."""
    fallback_provider = FallbackMetadataProvider([("dummy", DummyProvider())])
    fake_manager = SimpleNamespace(_get_provider=lambda _name=None: fallback_provider)
    monkeypatch.setattr(
        metadata_service.ModelMetadataProviderManager,
        "get_instance",
        AsyncMock(return_value=fake_manager),
    )

    resolved = await metadata_service.get_metadata_provider()

    assert resolved is fallback_provider

View File

@@ -4,7 +4,10 @@ import pytest
from py.services import model_metadata_provider as provider_module
from py.services.errors import RateLimitError
from py.services.model_metadata_provider import FallbackMetadataProvider
from py.services.model_metadata_provider import (
FallbackMetadataProvider,
RateLimitRetryingProvider,
)
class RateLimitThenSuccessProvider:
@@ -80,3 +83,37 @@ async def test_fallback_respects_retry_limit(monkeypatch):
assert primary.calls == 2
assert secondary.calls == 0
sleep_mock.assert_awaited_once()
@pytest.mark.asyncio
async def test_rate_limit_retrying_provider_retries(monkeypatch):
    """One rate-limit failure is slept through and the retry succeeds."""
    sleeper = AsyncMock()
    monkeypatch.setattr(provider_module.asyncio, "sleep", sleeper)
    # Zero jitter keeps the backoff deterministic for the assertion.
    monkeypatch.setattr(provider_module.random, "uniform", lambda *_: 0.0)

    flaky = RateLimitThenSuccessProvider()
    adapter = RateLimitRetryingProvider(flaky, label="inner", rate_limit_base_delay=0.1)

    result, error = await adapter.get_model_by_hash("abc")

    assert error is None
    assert result == {"id": "ok"}
    assert flaky.calls == 2
    sleeper.assert_awaited_once()
@pytest.mark.asyncio
async def test_rate_limit_retrying_provider_respects_limit(monkeypatch):
    """Exhausting the retry budget re-raises with the adapter's label."""
    sleeper = AsyncMock()
    monkeypatch.setattr(provider_module.asyncio, "sleep", sleeper)
    monkeypatch.setattr(provider_module.random, "uniform", lambda *_: 0.0)

    limited = AlwaysRateLimitedProvider()
    adapter = RateLimitRetryingProvider(limited, label="inner", rate_limit_retry_limit=2)

    with pytest.raises(RateLimitError) as caught:
        await adapter.get_model_by_hash("abc")

    assert caught.value.provider == "inner"
    assert limited.calls == 2
    sleeper.assert_awaited_once()