mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-21 21:22:11 -03:00
feat(metadata): add rate limit retry support to metadata providers
Add the `_RateLimitRetryHelper` and `RateLimitRetryingProvider` classes to handle rate limiting with exponential-backoff retries. Update `get_metadata_provider` to wrap the provider it returns in `RateLimitRetryingProvider` automatically, unless that provider is a `FallbackMetadataProvider` or is already wrapped. This improves reliability when external APIs return rate-limit errors: the failed call is retried automatically with configurable delays and jitter.
This commit is contained in:
@@ -2,11 +2,12 @@ import os
|
||||
import logging
|
||||
from .model_metadata_provider import (
|
||||
ModelMetadataProvider,
|
||||
ModelMetadataProviderManager,
|
||||
ModelMetadataProviderManager,
|
||||
SQLiteModelMetadataProvider,
|
||||
CivitaiModelMetadataProvider,
|
||||
CivArchiveModelMetadataProvider,
|
||||
FallbackMetadataProvider
|
||||
FallbackMetadataProvider,
|
||||
RateLimitRetryingProvider,
|
||||
)
|
||||
from .settings_manager import get_settings_manager
|
||||
from .metadata_archive_manager import MetadataArchiveManager
|
||||
@@ -108,14 +109,24 @@ async def get_metadata_archive_manager():
|
||||
base_path = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
|
||||
return MetadataArchiveManager(base_path)
|
||||
|
||||
def _wrap_provider_with_rate_limit(provider_name: str | None, provider: ModelMetadataProvider) -> ModelMetadataProvider:
    """Return ``provider`` with rate-limit retry handling applied exactly once.

    Args:
        provider_name: Label passed to the retrying adapter; may be ``None``.
        provider: The provider instance to protect.

    Returns:
        ``provider`` itself when it is a ``FallbackMetadataProvider`` or is
        already a ``RateLimitRetryingProvider``; otherwise a new
        ``RateLimitRetryingProvider`` wrapping it.
    """
    needs_no_wrapping = isinstance(provider, FallbackMetadataProvider) or isinstance(
        provider, RateLimitRetryingProvider
    )
    if not needs_no_wrapping:
        return RateLimitRetryingProvider(provider, label=provider_name)
    return provider
|
||||
|
||||
|
||||
async def get_metadata_provider(provider_name: str | None = None):
    """Get a specific metadata provider or default provider with rate-limit handling.

    Args:
        provider_name: Name of the provider to look up. When falsy (``None``
            or an empty string) the manager's default provider is requested
            instead.

    Returns:
        The provider obtained from ``ModelMetadataProviderManager``, passed
        through ``_wrap_provider_with_rate_limit`` before being returned.
    """
    provider_manager = await ModelMetadataProviderManager.get_instance()

    # Single lookup path: the result must always flow through the wrapper
    # below, so there are no early returns here.
    provider = (
        provider_manager._get_provider(provider_name)
        if provider_name
        else provider_manager._get_provider()
    )

    return _wrap_provider_with_rate_limit(provider_name, provider)
|
||||
|
||||
async def get_default_metadata_provider():
|
||||
"""Get the default metadata provider (fallback or single provider)"""
|
||||
|
||||
@@ -41,6 +41,55 @@ def _require_aiosqlite() -> Any:
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class _RateLimitRetryHelper:
    """Coordinate exponential backoff retries after rate limiting.

    The helper awaits a coroutine function and, whenever it raises
    ``RateLimitError``, sleeps and tries again until the attempt budget is
    used up. Any other exception propagates immediately.
    """

    def __init__(
        self,
        *,
        retry_limit: int = 3,
        base_delay: float = 1.5,
        max_delay: float = 30.0,
        jitter_ratio: float = 0.2,
    ) -> None:
        """Store the retry policy, clamping every value to a usable range.

        Args:
            retry_limit: Total number of rate-limited attempts tolerated
                before the error is re-raised; values below 1 become 1.
            base_delay: Delay in seconds used for the first backoff step;
                negative values become 0.
            max_delay: Upper bound in seconds for any computed delay;
                negative values become 0.
            jitter_ratio: Fraction of the backoff delay used as the +/-
                random jitter span; negative values become 0.
        """
        self._retry_limit = max(1, retry_limit)
        # Clamp the delays like the other settings so a misconfigured
        # negative value can never produce a negative sleep duration.
        self._base_delay = max(0.0, base_delay)
        self._max_delay = max(0.0, max_delay)
        self._jitter_ratio = max(0.0, jitter_ratio)

    async def run(self, label: str, func, *args, **kwargs):
        """Await ``func(*args, **kwargs)``, retrying on ``RateLimitError``.

        Args:
            label: Provider name used in log messages and written to
                ``exc.provider`` when the error carries no provider.
            func: Coroutine function to call.
            *args: Positional arguments forwarded to ``func``.
            **kwargs: Keyword arguments forwarded to ``func``.

        Returns:
            Whatever ``func`` returns on the first attempt that succeeds.

        Raises:
            RateLimitError: Re-raised once the number of rate-limited
                attempts reaches the retry limit.
        """
        attempt = 0
        while True:
            try:
                return await func(*args, **kwargs)
            except RateLimitError as exc:
                attempt += 1
                if attempt >= self._retry_limit:
                    # Make sure the final error names a provider before it
                    # leaves the helper.
                    exc.provider = exc.provider or label
                    raise

                delay = self._calculate_delay(exc.retry_after, attempt)
                logger.warning(
                    "Provider %s rate limited request; retrying in %.2fs (attempt %s/%s)",
                    label,
                    delay,
                    attempt,
                    self._retry_limit,
                )
                await asyncio.sleep(delay)

    def _calculate_delay(self, retry_after: Optional[float], attempt: int) -> float:
        """Return the non-negative number of seconds to wait before retrying.

        Args:
            retry_after: Delay supplied with the rate-limit error, if any.
                When present it takes precedence over the backoff schedule.
            attempt: 1-based count of rate-limited attempts so far.

        Returns:
            ``retry_after`` clamped to ``[0, max_delay]`` when given;
            otherwise ``base_delay * 2 ** (attempt - 1)`` with random jitter
            applied, clamped to the same range.
        """
        if retry_after is not None:
            return min(self._max_delay, max(0.0, retry_after))

        base_delay = self._base_delay * (2 ** max(0, attempt - 1))
        jitter_span = base_delay * self._jitter_ratio
        if jitter_span > 0:
            base_delay += random.uniform(-jitter_span, jitter_span)

        return min(self._max_delay, max(0.0, base_delay))
|
||||
|
||||
class ModelMetadataProvider(ABC):
|
||||
"""Base abstract class for all model metadata providers"""
|
||||
|
||||
@@ -390,6 +439,12 @@ class FallbackMetadataProvider(ModelMetadataProvider):
|
||||
self._rate_limit_base_delay = rate_limit_base_delay
|
||||
self._rate_limit_max_delay = rate_limit_max_delay
|
||||
self._rate_limit_jitter_ratio = max(0.0, rate_limit_jitter_ratio)
|
||||
self._rate_limit_helper = _RateLimitRetryHelper(
|
||||
retry_limit=self._rate_limit_retry_limit,
|
||||
base_delay=self._rate_limit_base_delay,
|
||||
max_delay=self._rate_limit_max_delay,
|
||||
jitter_ratio=self._rate_limit_jitter_ratio,
|
||||
)
|
||||
|
||||
async def get_model_by_hash(self, model_hash: str) -> Tuple[Optional[Dict], Optional[str]]:
|
||||
for provider, label in self._iter_providers():
|
||||
@@ -485,44 +540,80 @@ class FallbackMetadataProvider(ModelMetadataProvider):
|
||||
def _iter_providers(self):
    """Return an iterator of ``(provider, label)`` pairs, matched by position."""
    providers, labels = self.providers, self._provider_labels
    return zip(providers, labels)
|
||||
|
||||
async def _call_with_rate_limit(self, label: str, func, *args, **kwargs):
    """Await ``func(*args, **kwargs)`` through this instance's retry helper.

    Args:
        label: Provider label forwarded to the helper.
        func: Coroutine function to call.
        *args: Positional arguments forwarded to ``func``.
        **kwargs: Keyword arguments forwarded to ``func``.

    Returns:
        Whatever ``self._rate_limit_helper.run`` returns.
    """
    return await self._rate_limit_helper.run(label, func, *args, **kwargs)
|
||||
|
||||
|
||||
class RateLimitRetryingProvider(ModelMetadataProvider):
    """Adapter that retries individual provider calls after rate limiting.

    Every provider method defined here forwards to the same-named method on
    the wrapped provider, routed through a ``_RateLimitRetryHelper``. Any
    other attribute is looked up on the wrapped provider.
    """

    def __init__(
        self,
        provider: ModelMetadataProvider,
        label: Optional[str] = None,
        *,
        rate_limit_retry_limit: int = 3,
        rate_limit_base_delay: float = 1.5,
        rate_limit_max_delay: float = 30.0,
        rate_limit_jitter_ratio: float = 0.2,
    ) -> None:
        """Wrap ``provider`` and configure its retry policy.

        Args:
            provider: The provider whose calls should be retried.
            label: Name passed to the retry helper; defaults to the wrapped
                provider's class name when falsy.
            rate_limit_retry_limit: Passed to the helper as ``retry_limit``.
            rate_limit_base_delay: Passed to the helper as ``base_delay``.
            rate_limit_max_delay: Passed to the helper as ``max_delay``.
            rate_limit_jitter_ratio: Passed to the helper as ``jitter_ratio``.
        """
        self._provider = provider
        self._label = label or provider.__class__.__name__
        self._rate_limit_helper = _RateLimitRetryHelper(
            retry_limit=rate_limit_retry_limit,
            base_delay=rate_limit_base_delay,
            max_delay=rate_limit_max_delay,
            jitter_ratio=rate_limit_jitter_ratio,
        )

    def __getattr__(self, item):
        """Delegate attributes not found on the adapter to the wrapped provider.

        Raises:
            AttributeError: If ``_provider`` itself is missing, or the
                wrapped provider does not have ``item``.
        """
        # __getattr__ only runs after normal lookup fails. If "_provider" is
        # what is missing (instance created without __init__, e.g. by copy or
        # pickle), delegating would call __getattr__ again without end.
        if item == "_provider":
            raise AttributeError(item)
        return getattr(self._provider, item)

    async def get_model_by_hash(self, model_hash: str) -> Tuple[Optional[Dict], Optional[str]]:
        """Call the wrapped provider's ``get_model_by_hash`` with retries."""
        return await self._rate_limit_helper.run(
            self._label,
            self._provider.get_model_by_hash,
            model_hash,
        )

    async def get_model_versions(self, model_id: str) -> Optional[Dict]:
        """Call the wrapped provider's ``get_model_versions`` with retries."""
        return await self._rate_limit_helper.run(
            self._label,
            self._provider.get_model_versions,
            model_id,
        )

    async def get_model_versions_bulk(
        self,
        model_ids: Sequence[int],
    ) -> Optional[Dict[int, Dict]]:
        """Call the wrapped provider's ``get_model_versions_bulk`` with retries."""
        return await self._rate_limit_helper.run(
            self._label,
            self._provider.get_model_versions_bulk,
            model_ids,
        )

    async def get_model_version(self, model_id: int = None, version_id: int = None) -> Optional[Dict]:
        """Call the wrapped provider's ``get_model_version`` with retries.

        Both arguments are forwarded positionally, in this order.
        """
        return await self._rate_limit_helper.run(
            self._label,
            self._provider.get_model_version,
            model_id,
            version_id,
        )

    async def get_model_version_info(self, version_id: str) -> Tuple[Optional[Dict], Optional[str]]:
        """Call the wrapped provider's ``get_model_version_info`` with retries."""
        return await self._rate_limit_helper.run(
            self._label,
            self._provider.get_model_version_info,
            version_id,
        )

    async def get_user_models(self, username: str) -> Optional[List[Dict]]:
        """Call the wrapped provider's ``get_user_models`` with retries."""
        return await self._rate_limit_helper.run(
            self._label,
            self._provider.get_user_models,
            username,
        )
|
||||
|
||||
class ModelMetadataProviderManager:
|
||||
"""Manager for selecting and using model metadata providers"""
|
||||
|
||||
Reference in New Issue
Block a user