mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-21 21:22:11 -03:00
feat(metadata): add rate limit retry support to metadata providers
This commit is contained in:
@@ -2,11 +2,12 @@ import os
|
||||
import logging
|
||||
from .model_metadata_provider import (
|
||||
ModelMetadataProvider,
|
||||
ModelMetadataProviderManager,
|
||||
ModelMetadataProviderManager,
|
||||
SQLiteModelMetadataProvider,
|
||||
CivitaiModelMetadataProvider,
|
||||
CivArchiveModelMetadataProvider,
|
||||
FallbackMetadataProvider
|
||||
FallbackMetadataProvider,
|
||||
RateLimitRetryingProvider,
|
||||
)
|
||||
from .settings_manager import get_settings_manager
|
||||
from .metadata_archive_manager import MetadataArchiveManager
|
||||
@@ -108,14 +109,24 @@ async def get_metadata_archive_manager():
|
||||
base_path = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
|
||||
return MetadataArchiveManager(base_path)
|
||||
|
||||
def _wrap_provider_with_rate_limit(provider_name: str | None, provider: ModelMetadataProvider) -> ModelMetadataProvider:
    """Return *provider* wrapped in a ``RateLimitRetryingProvider`` when needed.

    Args:
        provider_name: Passed to the wrapper as its ``label``; may be ``None``.
        provider: The provider instance to inspect and possibly wrap.

    Returns:
        ``provider`` itself when it is already a ``FallbackMetadataProvider``
        or a ``RateLimitRetryingProvider``; otherwise a new
        ``RateLimitRetryingProvider`` around it.
    """
    # These two types are handed back untouched so they are never double-wrapped.
    passthrough_types = (FallbackMetadataProvider, RateLimitRetryingProvider)
    if not isinstance(provider, passthrough_types):
        return RateLimitRetryingProvider(provider, label=provider_name)
    return provider
|
||||
|
||||
|
||||
async def get_metadata_provider(provider_name: str | None = None):
    """Get a specific metadata provider or default provider with rate-limit handling.

    Args:
        provider_name: Name handed to the provider manager. When falsy, the
            manager is asked for its default provider instead.

    Returns:
        The provider resolved by ``ModelMetadataProviderManager``, passed
        through ``_wrap_provider_with_rate_limit`` (which leaves
        ``FallbackMetadataProvider`` / ``RateLimitRetryingProvider`` instances
        untouched and wraps anything else).
    """
    provider_manager = await ModelMetadataProviderManager.get_instance()

    # Resolve first, wrap second: returning the manager's provider directly
    # would bypass the rate-limit retry wrapper entirely.
    provider = (
        provider_manager._get_provider(provider_name)
        if provider_name
        else provider_manager._get_provider()
    )

    return _wrap_provider_with_rate_limit(provider_name, provider)
|
||||
|
||||
async def get_default_metadata_provider():
|
||||
"""Get the default metadata provider (fallback or single provider)"""
|
||||
|
||||
@@ -41,6 +41,55 @@ def _require_aiosqlite() -> Any:
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class _RateLimitRetryHelper:
    """Coordinate exponential backoff retries after rate limiting."""

    def __init__(
        self,
        *,
        retry_limit: int = 3,
        base_delay: float = 1.5,
        max_delay: float = 30.0,
        jitter_ratio: float = 0.2,
    ) -> None:
        """Store the backoff settings.

        Args:
            retry_limit: Number of rate-limited attempts tolerated before the
                error is re-raised; clamped to at least 1.
            base_delay: Delay in seconds used for the first backoff step.
            max_delay: Upper bound applied to every computed delay.
            jitter_ratio: Fraction of the backoff delay used as the +/- random
                spread; clamped to be non-negative.
        """
        self._base_delay = base_delay
        self._max_delay = max_delay
        self._retry_limit = max(1, retry_limit)
        self._jitter_ratio = max(0.0, jitter_ratio)

    async def run(self, label: str, func, *args, **kwargs):
        """Await ``func(*args, **kwargs)``, retrying while it raises ``RateLimitError``.

        Args:
            label: Provider name used in the log message and written to the
                exception's ``provider`` attribute when that is falsy.
            func: Callable returning an awaitable.
            *args: Positional arguments forwarded to ``func``.
            **kwargs: Keyword arguments forwarded to ``func``.

        Returns:
            Whatever the awaited ``func`` call returns.

        Raises:
            RateLimitError: Re-raised once the number of rate-limited attempts
                reaches the retry limit.
        """
        failures = 0
        while True:
            try:
                outcome = await func(*args, **kwargs)
            except RateLimitError as exc:
                failures += 1
                if failures >= self._retry_limit:
                    # Tag the error with this provider unless already tagged.
                    exc.provider = exc.provider or label
                    raise

                pause = self._calculate_delay(exc.retry_after, failures)
                logger.warning(
                    "Provider %s rate limited request; retrying in %.2fs (attempt %s/%s)",
                    label,
                    pause,
                    failures,
                    self._retry_limit,
                )
                await asyncio.sleep(pause)
            else:
                return outcome

    def _calculate_delay(self, retry_after: Optional[float], attempt: int) -> float:
        """Return the number of seconds to wait before the next attempt.

        When ``retry_after`` is given it is used as-is; otherwise the delay is
        ``base_delay * 2 ** (attempt - 1)`` with a random +/- jitter. Either
        way the result is clamped to the range ``[0.0, max_delay]``.
        """
        if retry_after is None:
            doublings = max(0, attempt - 1)
            candidate = self._base_delay * (2 ** doublings)
            spread = candidate * self._jitter_ratio
            if spread > 0:
                candidate += random.uniform(-spread, spread)
        else:
            candidate = retry_after

        return min(self._max_delay, max(0.0, candidate))
|
||||
|
||||
class ModelMetadataProvider(ABC):
|
||||
"""Base abstract class for all model metadata providers"""
|
||||
|
||||
@@ -390,6 +439,12 @@ class FallbackMetadataProvider(ModelMetadataProvider):
|
||||
self._rate_limit_base_delay = rate_limit_base_delay
|
||||
self._rate_limit_max_delay = rate_limit_max_delay
|
||||
self._rate_limit_jitter_ratio = max(0.0, rate_limit_jitter_ratio)
|
||||
self._rate_limit_helper = _RateLimitRetryHelper(
|
||||
retry_limit=self._rate_limit_retry_limit,
|
||||
base_delay=self._rate_limit_base_delay,
|
||||
max_delay=self._rate_limit_max_delay,
|
||||
jitter_ratio=self._rate_limit_jitter_ratio,
|
||||
)
|
||||
|
||||
async def get_model_by_hash(self, model_hash: str) -> Tuple[Optional[Dict], Optional[str]]:
|
||||
for provider, label in self._iter_providers():
|
||||
@@ -485,44 +540,80 @@ class FallbackMetadataProvider(ModelMetadataProvider):
|
||||
def _iter_providers(self):
    """Return an iterator of ``(provider, label)`` pairs.

    Pairs ``self.providers`` with ``self._provider_labels`` position by
    position.
    """
    providers = self.providers
    labels = self._provider_labels
    return zip(providers, labels)
|
||||
|
||||
async def _call_with_rate_limit(self, label: str, func, *args, **kwargs):
    """Run ``func`` through this instance's rate-limit retry helper.

    Args:
        label: Provider label forwarded to the helper.
        func: Callable returning an awaitable.
        *args: Positional arguments forwarded to ``func``.
        **kwargs: Keyword arguments forwarded to ``func``.

    Returns:
        The result of awaiting ``self._rate_limit_helper.run(...)``.
    """
    return await self._rate_limit_helper.run(label, func, *args, **kwargs)
|
||||
|
||||
|
||||
class RateLimitRetryingProvider(ModelMetadataProvider):
    """Adapter that retries individual provider calls after rate limiting.

    Every metadata lookup defined below is forwarded to the wrapped provider
    through a ``_RateLimitRetryHelper``. Any other attribute is looked up on
    the wrapped provider via ``__getattr__``.
    """

    def __init__(
        self,
        provider: ModelMetadataProvider,
        label: Optional[str] = None,
        *,
        rate_limit_retry_limit: int = 3,
        rate_limit_base_delay: float = 1.5,
        rate_limit_max_delay: float = 30.0,
        rate_limit_jitter_ratio: float = 0.2,
    ) -> None:
        """Wrap *provider* with rate-limit retry handling.

        Args:
            provider: The provider whose calls are retried.
            label: Name passed to the retry helper on every call; defaults to
                the wrapped provider's class name when falsy.
            rate_limit_retry_limit: Forwarded to the helper as ``retry_limit``.
            rate_limit_base_delay: Forwarded to the helper as ``base_delay``.
            rate_limit_max_delay: Forwarded to the helper as ``max_delay``.
            rate_limit_jitter_ratio: Forwarded to the helper as ``jitter_ratio``.
        """
        self._provider = provider
        self._label = label or provider.__class__.__name__
        self._rate_limit_helper = _RateLimitRetryHelper(
            retry_limit=rate_limit_retry_limit,
            base_delay=rate_limit_base_delay,
            max_delay=rate_limit_max_delay,
            jitter_ratio=rate_limit_jitter_ratio,
        )

    def __getattr__(self, item):
        """Delegate attributes not found on the wrapper to the wrapped provider.

        Raises:
            AttributeError: If the wrapped provider has not been assigned yet
                or does not have the attribute either.
        """
        # __getattr__ only runs after normal lookup failed. Reading
        # self._provider here while it is still unset (e.g. copy/pickle probing
        # __setstate__ on a bare instance) would re-enter __getattr__ forever,
        # so go through the instance dict and fail with AttributeError instead.
        try:
            provider = self.__dict__["_provider"]
        except KeyError:
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {item!r}"
            ) from None
        return getattr(provider, item)

    async def get_model_by_hash(self, model_hash: str) -> Tuple[Optional[Dict], Optional[str]]:
        """Call the wrapped provider's ``get_model_by_hash`` with rate-limit retries."""
        return await self._rate_limit_helper.run(
            self._label,
            self._provider.get_model_by_hash,
            model_hash,
        )

    async def get_model_versions(self, model_id: str) -> Optional[Dict]:
        """Call the wrapped provider's ``get_model_versions`` with rate-limit retries."""
        return await self._rate_limit_helper.run(
            self._label,
            self._provider.get_model_versions,
            model_id,
        )

    async def get_model_versions_bulk(
        self,
        model_ids: Sequence[int],
    ) -> Optional[Dict[int, Dict]]:
        """Call the wrapped provider's ``get_model_versions_bulk`` with rate-limit retries."""
        return await self._rate_limit_helper.run(
            self._label,
            self._provider.get_model_versions_bulk,
            model_ids,
        )

    async def get_model_version(self, model_id: int = None, version_id: int = None) -> Optional[Dict]:
        """Call the wrapped provider's ``get_model_version`` with rate-limit retries."""
        return await self._rate_limit_helper.run(
            self._label,
            self._provider.get_model_version,
            model_id,
            version_id,
        )

    async def get_model_version_info(self, version_id: str) -> Tuple[Optional[Dict], Optional[str]]:
        """Call the wrapped provider's ``get_model_version_info`` with rate-limit retries."""
        return await self._rate_limit_helper.run(
            self._label,
            self._provider.get_model_version_info,
            version_id,
        )

    async def get_user_models(self, username: str) -> Optional[List[Dict]]:
        """Call the wrapped provider's ``get_user_models`` with rate-limit retries."""
        return await self._rate_limit_helper.run(
            self._label,
            self._provider.get_user_models,
            username,
        )
|
||||
|
||||
class ModelMetadataProviderManager:
|
||||
"""Manager for selecting and using model metadata providers"""
|
||||
|
||||
62
tests/services/test_metadata_service.py
Normal file
62
tests/services/test_metadata_service.py
Normal file
@@ -0,0 +1,62 @@
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
|
||||
from py.services import metadata_service
|
||||
from py.services.model_metadata_provider import (
|
||||
FallbackMetadataProvider,
|
||||
ModelMetadataProvider,
|
||||
RateLimitRetryingProvider,
|
||||
)
|
||||
|
||||
|
||||
class DummyProvider(ModelMetadataProvider):
    """Inert provider stub: every lookup reports nothing found."""

    async def get_model_by_hash(self, model_hash: str):
        """Return an empty ``(result, error)`` pair."""
        return (None, None)

    async def get_model_versions(self, model_id: str):
        """Return ``None`` for any model id."""
        return None

    async def get_model_versions_bulk(self, model_ids):
        """Return ``None`` for any collection of model ids."""
        return None

    async def get_model_version(self, model_id: int = None, version_id: int = None):
        """Return ``None`` regardless of the ids given."""
        return None

    async def get_model_version_info(self, version_id: str):
        """Return an empty ``(result, error)`` pair."""
        return (None, None)

    async def get_user_models(self, username: str):
        """Return ``None`` for any username."""
        return None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_metadata_provider_wraps_non_fallback(monkeypatch):
    """A plain provider from the manager comes back wrapped in RateLimitRetryingProvider."""
    raw_provider = DummyProvider()

    def _get_provider(_name=None):
        return raw_provider

    fake_manager = SimpleNamespace(_get_provider=_get_provider)
    get_instance_stub = AsyncMock(return_value=fake_manager)
    monkeypatch.setattr(
        metadata_service.ModelMetadataProviderManager,
        "get_instance",
        get_instance_stub,
    )

    result = await metadata_service.get_metadata_provider("dummy")

    assert isinstance(result, RateLimitRetryingProvider)
    assert result is not raw_provider
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_metadata_provider_returns_fallback_as_is(monkeypatch):
    """A FallbackMetadataProvider from the manager is returned without wrapping."""
    fallback_provider = FallbackMetadataProvider([("dummy", DummyProvider())])

    def _get_provider(_name=None):
        return fallback_provider

    fake_manager = SimpleNamespace(_get_provider=_get_provider)
    get_instance_stub = AsyncMock(return_value=fake_manager)
    monkeypatch.setattr(
        metadata_service.ModelMetadataProviderManager,
        "get_instance",
        get_instance_stub,
    )

    result = await metadata_service.get_metadata_provider()

    assert result is fallback_provider
|
||||
@@ -4,7 +4,10 @@ import pytest
|
||||
|
||||
from py.services import model_metadata_provider as provider_module
|
||||
from py.services.errors import RateLimitError
|
||||
from py.services.model_metadata_provider import FallbackMetadataProvider
|
||||
from py.services.model_metadata_provider import (
|
||||
FallbackMetadataProvider,
|
||||
RateLimitRetryingProvider,
|
||||
)
|
||||
|
||||
|
||||
class RateLimitThenSuccessProvider:
|
||||
@@ -80,3 +83,37 @@ async def test_fallback_respects_retry_limit(monkeypatch):
|
||||
assert primary.calls == 2
|
||||
assert secondary.calls == 0
|
||||
sleep_mock.assert_awaited_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_rate_limit_retrying_provider_retries(monkeypatch):
    """The wrapper calls the inner provider twice, with one awaited sleep, and returns its result."""
    # Patch sleep and jitter so the retry neither waits nor varies between runs.
    fake_sleep = AsyncMock()
    monkeypatch.setattr(provider_module.asyncio, "sleep", fake_sleep)
    monkeypatch.setattr(provider_module.random, "uniform", lambda *_: 0.0)

    flaky = RateLimitThenSuccessProvider()
    retrying = RateLimitRetryingProvider(
        flaky,
        label="inner",
        rate_limit_base_delay=0.1,
    )

    payload, failure = await retrying.get_model_by_hash("abc")

    assert failure is None
    assert payload == {"id": "ok"}
    assert flaky.calls == 2
    fake_sleep.assert_awaited_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_rate_limit_retrying_provider_respects_limit(monkeypatch):
    """With a retry limit of 2 the wrapper gives up after two calls and labels the error."""
    # Patch sleep and jitter so the retry neither waits nor varies between runs.
    fake_sleep = AsyncMock()
    monkeypatch.setattr(provider_module.asyncio, "sleep", fake_sleep)
    monkeypatch.setattr(provider_module.random, "uniform", lambda *_: 0.0)

    always_limited = AlwaysRateLimitedProvider()
    retrying = RateLimitRetryingProvider(
        always_limited,
        label="inner",
        rate_limit_retry_limit=2,
    )

    with pytest.raises(RateLimitError) as raised:
        await retrying.get_model_by_hash("abc")

    assert raised.value.provider == "inner"
    assert always_limited.calls == 2
    fake_sleep.assert_awaited_once()
|
||||
|
||||
Reference in New Issue
Block a user