diff --git a/py/services/metadata_service.py b/py/services/metadata_service.py index 302e6551..1a48ceaf 100644 --- a/py/services/metadata_service.py +++ b/py/services/metadata_service.py @@ -2,11 +2,12 @@ import os import logging from .model_metadata_provider import ( ModelMetadataProvider, - ModelMetadataProviderManager, + ModelMetadataProviderManager, SQLiteModelMetadataProvider, CivitaiModelMetadataProvider, CivArchiveModelMetadataProvider, - FallbackMetadataProvider + FallbackMetadataProvider, + RateLimitRetryingProvider, ) from .settings_manager import get_settings_manager from .metadata_archive_manager import MetadataArchiveManager @@ -108,14 +109,24 @@ async def get_metadata_archive_manager(): base_path = os.path.dirname(os.path.dirname(os.path.dirname(__file__))) return MetadataArchiveManager(base_path) +def _wrap_provider_with_rate_limit(provider_name: str | None, provider: ModelMetadataProvider) -> ModelMetadataProvider: + if isinstance(provider, (FallbackMetadataProvider, RateLimitRetryingProvider)): + return provider + return RateLimitRetryingProvider(provider, label=provider_name) + + async def get_metadata_provider(provider_name: str = None): - """Get a specific metadata provider or default provider""" + """Get a specific metadata provider or default provider with rate-limit handling.""" + provider_manager = await ModelMetadataProviderManager.get_instance() - - if provider_name: - return provider_manager._get_provider(provider_name) - - return provider_manager._get_provider() + + provider = ( + provider_manager._get_provider(provider_name) + if provider_name + else provider_manager._get_provider() + ) + + return _wrap_provider_with_rate_limit(provider_name, provider) async def get_default_metadata_provider(): """Get the default metadata provider (fallback or single provider)""" diff --git a/py/services/model_metadata_provider.py b/py/services/model_metadata_provider.py index 9cc06300..afd8459b 100644 --- a/py/services/model_metadata_provider.py +++ b/py/services/model_metadata_provider.py @@ -41,6 +41,55 @@ def _require_aiosqlite() -> Any: logger = logging.getLogger(__name__) + +class _RateLimitRetryHelper: + """Coordinate exponential backoff retries after rate limiting.""" + + def __init__( + self, + *, + retry_limit: int = 3, + base_delay: float = 1.5, + max_delay: float = 30.0, + jitter_ratio: float = 0.2, + ) -> None: + self._retry_limit = max(1, retry_limit) + self._base_delay = base_delay + self._max_delay = max_delay + self._jitter_ratio = max(0.0, jitter_ratio) + + async def run(self, label: str, func, *args, **kwargs): + attempt = 0 + while True: + try: + return await func(*args, **kwargs) + except RateLimitError as exc: + attempt += 1 + if attempt >= self._retry_limit: + exc.provider = exc.provider or label + raise + + delay = self._calculate_delay(exc.retry_after, attempt) + logger.warning( + "Provider %s rate limited request; retrying in %.2fs (attempt %s/%s)", + label, + delay, + attempt, + self._retry_limit, + ) + await asyncio.sleep(delay) + + def _calculate_delay(self, retry_after: Optional[float], attempt: int) -> float: + if retry_after is not None: + return min(self._max_delay, max(0.0, retry_after)) + + base_delay = self._base_delay * (2 ** max(0, attempt - 1)) + jitter_span = base_delay * self._jitter_ratio + if jitter_span > 0: + base_delay += random.uniform(-jitter_span, jitter_span) + + return min(self._max_delay, max(0.0, base_delay)) + class ModelMetadataProvider(ABC): """Base abstract class for all model metadata providers""" @@ -390,6 +439,12 @@ class FallbackMetadataProvider(ModelMetadataProvider): self._rate_limit_base_delay = rate_limit_base_delay self._rate_limit_max_delay = rate_limit_max_delay self._rate_limit_jitter_ratio = max(0.0, rate_limit_jitter_ratio) + self._rate_limit_helper = _RateLimitRetryHelper( + retry_limit=self._rate_limit_retry_limit, + base_delay=self._rate_limit_base_delay, + max_delay=self._rate_limit_max_delay, + jitter_ratio=self._rate_limit_jitter_ratio, + ) async def get_model_by_hash(self, model_hash: str) -> Tuple[Optional[Dict], Optional[str]]: for provider, label in self._iter_providers(): @@ -485,44 +540,80 @@ class FallbackMetadataProvider(ModelMetadataProvider): def _iter_providers(self): return zip(self.providers, self._provider_labels) - async def _call_with_rate_limit( + async def _call_with_rate_limit(self, label: str, func, *args, **kwargs): + return await self._rate_limit_helper.run(label, func, *args, **kwargs) + + +class RateLimitRetryingProvider(ModelMetadataProvider): + """Adapter that retries individual provider calls after rate limiting.""" + + def __init__( self, - label: str, - func, - *args, - **kwargs, - ): - attempt = 0 - while True: - try: - return await func(*args, **kwargs) - except RateLimitError as exc: - attempt += 1 - if attempt >= self._rate_limit_retry_limit: - exc.provider = exc.provider or label - raise exc - delay = self._calculate_rate_limit_delay(exc.retry_after, attempt) - logger.warning( - "Provider %s rate limited request; retrying in %.2fs (attempt %s/%s)", - label, - delay, - attempt, - self._rate_limit_retry_limit, - ) - await asyncio.sleep(delay) - except Exception: - raise + provider: ModelMetadataProvider, + label: Optional[str] = None, + *, + rate_limit_retry_limit: int = 3, + rate_limit_base_delay: float = 1.5, + rate_limit_max_delay: float = 30.0, + rate_limit_jitter_ratio: float = 0.2, + ) -> None: + self._provider = provider + self._label = label or provider.__class__.__name__ + self._rate_limit_helper = _RateLimitRetryHelper( + retry_limit=rate_limit_retry_limit, + base_delay=rate_limit_base_delay, + max_delay=rate_limit_max_delay, + jitter_ratio=rate_limit_jitter_ratio, + ) - def _calculate_rate_limit_delay(self, retry_after: Optional[float], attempt: int) -> float: - if retry_after is not None: - return min(self._rate_limit_max_delay, max(0.0, retry_after)) + def __getattr__(self, item): + return getattr(self._provider, item) - base_delay = self._rate_limit_base_delay * (2 ** max(0, attempt - 1)) - jitter_span = base_delay * self._rate_limit_jitter_ratio - if jitter_span > 0: - base_delay += random.uniform(-jitter_span, jitter_span) + async def get_model_by_hash(self, model_hash: str) -> Tuple[Optional[Dict], Optional[str]]: + return await self._rate_limit_helper.run( + self._label, + self._provider.get_model_by_hash, + model_hash, + ) - return min(self._rate_limit_max_delay, max(0.0, base_delay)) + async def get_model_versions(self, model_id: str) -> Optional[Dict]: + return await self._rate_limit_helper.run( + self._label, + self._provider.get_model_versions, + model_id, + ) + + async def get_model_versions_bulk( + self, + model_ids: Sequence[int], + ) -> Optional[Dict[int, Dict]]: + return await self._rate_limit_helper.run( + self._label, + self._provider.get_model_versions_bulk, + model_ids, + ) + + async def get_model_version(self, model_id: int = None, version_id: int = None) -> Optional[Dict]: + return await self._rate_limit_helper.run( + self._label, + self._provider.get_model_version, + model_id, + version_id, + ) + + async def get_model_version_info(self, version_id: str) -> Tuple[Optional[Dict], Optional[str]]: + return await self._rate_limit_helper.run( + self._label, + self._provider.get_model_version_info, + version_id, + ) + + async def get_user_models(self, username: str) -> Optional[List[Dict]]: + return await self._rate_limit_helper.run( + self._label, + self._provider.get_user_models, + username, + ) class ModelMetadataProviderManager: """Manager for selecting and using model metadata providers""" diff --git a/tests/services/test_metadata_service.py b/tests/services/test_metadata_service.py new file mode 100644 index 00000000..b4cc8112 --- /dev/null +++ b/tests/services/test_metadata_service.py @@ -0,0 +1,62 @@ +from types import SimpleNamespace +from unittest.mock import AsyncMock + +import pytest + +from py.services import metadata_service +from py.services.model_metadata_provider import ( + FallbackMetadataProvider, + ModelMetadataProvider, + RateLimitRetryingProvider, +) + + +class DummyProvider(ModelMetadataProvider): + async def get_model_by_hash(self, model_hash: str): + return None, None + + async def get_model_versions(self, model_id: str): + return None + + async def get_model_versions_bulk(self, model_ids): + return None + + async def get_model_version(self, model_id: int = None, version_id: int = None): + return None + + async def get_model_version_info(self, version_id: str): + return None, None + + async def get_user_models(self, username: str): + return None + + +@pytest.mark.asyncio +async def test_get_metadata_provider_wraps_non_fallback(monkeypatch): + provider = DummyProvider() + dummy_manager = SimpleNamespace(_get_provider=lambda _name=None: provider) + monkeypatch.setattr( + metadata_service.ModelMetadataProviderManager, + "get_instance", + AsyncMock(return_value=dummy_manager), + ) + + wrapped = await metadata_service.get_metadata_provider("dummy") + + assert isinstance(wrapped, RateLimitRetryingProvider) + assert wrapped is not provider + + +@pytest.mark.asyncio +async def test_get_metadata_provider_returns_fallback_as_is(monkeypatch): + fallback = FallbackMetadataProvider([("dummy", DummyProvider())]) + dummy_manager = SimpleNamespace(_get_provider=lambda _name=None: fallback) + monkeypatch.setattr( + metadata_service.ModelMetadataProviderManager, + "get_instance", + AsyncMock(return_value=dummy_manager), + ) + + provider = await metadata_service.get_metadata_provider() + + assert provider is fallback diff --git a/tests/services/test_model_metadata_provider.py b/tests/services/test_model_metadata_provider.py index cb1761de..6c9dd093 100644 --- a/tests/services/test_model_metadata_provider.py +++ b/tests/services/test_model_metadata_provider.py @@ -4,7 +4,10 @@ import pytest from py.services import model_metadata_provider as provider_module from py.services.errors import RateLimitError -from py.services.model_metadata_provider import FallbackMetadataProvider +from py.services.model_metadata_provider import ( + FallbackMetadataProvider, + RateLimitRetryingProvider, +) class RateLimitThenSuccessProvider: @@ -80,3 +83,37 @@ async def test_fallback_respects_retry_limit(monkeypatch): assert primary.calls == 2 assert secondary.calls == 0 sleep_mock.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_rate_limit_retrying_provider_retries(monkeypatch): + sleep_mock = AsyncMock() + monkeypatch.setattr(provider_module.asyncio, "sleep", sleep_mock) + monkeypatch.setattr(provider_module.random, "uniform", lambda *_: 0.0) + + inner = RateLimitThenSuccessProvider() + wrapper = RateLimitRetryingProvider(inner, label="inner", rate_limit_base_delay=0.1) + + result, error = await wrapper.get_model_by_hash("abc") + + assert error is None + assert result == {"id": "ok"} + assert inner.calls == 2 + sleep_mock.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_rate_limit_retrying_provider_respects_limit(monkeypatch): + sleep_mock = AsyncMock() + monkeypatch.setattr(provider_module.asyncio, "sleep", sleep_mock) + monkeypatch.setattr(provider_module.random, "uniform", lambda *_: 0.0) + + inner = AlwaysRateLimitedProvider() + wrapper = RateLimitRetryingProvider(inner, label="inner", rate_limit_retry_limit=2) + + with pytest.raises(RateLimitError) as exc_info: + await wrapper.get_model_by_hash("abc") + + assert exc_info.value.provider == "inner" + assert inner.calls == 2 + sleep_mock.assert_awaited_once()