feat(civitai): add rate limiting support and error handling

- Add RateLimitError import and _make_request wrapper method to handle rate limiting
- Update API methods to use _make_request wrapper instead of direct downloader calls
- Add explicit RateLimitError handling in API methods to properly propagate rate limit errors
- Add _extract_retry_after method to parse Retry-After headers
- Improve error handling by surfacing rate limit information to callers

These changes ensure that rate limiting from the Civitai API is properly detected and handled, allowing callers to implement appropriate backoff strategies when rate limits are encountered.
This commit is contained in:
Will Miao
2025-10-14 20:52:01 +08:00
committed by pixelpaws
parent 1454991d6d
commit 3fde474583
9 changed files with 397 additions and 50 deletions

View File

@@ -1,8 +1,11 @@
from abc import ABC, abstractmethod
import asyncio
import json
import logging
from typing import Optional, Dict, Tuple, Any, List
import random
from typing import Optional, Dict, Tuple, Any, List, Sequence
from .downloader import get_downloader
from .errors import RateLimitError
try:
from bs4 import BeautifulSoup
@@ -350,64 +353,166 @@ class SQLiteModelMetadataProvider(ModelMetadataProvider):
class FallbackMetadataProvider(ModelMetadataProvider):
"""Try providers in order, return first successful result."""
def __init__(self, providers: list):
self.providers = providers
def __init__(
self,
providers: Sequence[ModelMetadataProvider | Tuple[str, ModelMetadataProvider]],
*,
rate_limit_retry_limit: int = 3,
rate_limit_base_delay: float = 1.5,
rate_limit_max_delay: float = 30.0,
rate_limit_jitter_ratio: float = 0.2,
) -> None:
self.providers: List[ModelMetadataProvider] = []
self._provider_labels: List[str] = []
for entry in providers:
if isinstance(entry, tuple) and len(entry) == 2:
name, provider = entry
else:
provider = entry
name = provider.__class__.__name__
self.providers.append(provider)
self._provider_labels.append(str(name))
self._rate_limit_retry_limit = max(1, rate_limit_retry_limit)
self._rate_limit_base_delay = rate_limit_base_delay
self._rate_limit_max_delay = rate_limit_max_delay
self._rate_limit_jitter_ratio = max(0.0, rate_limit_jitter_ratio)
async def get_model_by_hash(self, model_hash: str) -> Tuple[Optional[Dict], Optional[str]]:
for provider in self.providers:
for provider, label in self._iter_providers():
try:
result, error = await provider.get_model_by_hash(model_hash)
result, error = await self._call_with_rate_limit(
label,
provider.get_model_by_hash,
model_hash,
)
if result:
return result, error
except RateLimitError as exc:
exc.provider = exc.provider or label
raise exc
except Exception as e:
logger.debug(f"Provider failed for get_model_by_hash: {e}")
logger.debug("Provider %s failed for get_model_by_hash: %s", label, e)
continue
return None, "Model not found"
async def get_model_versions(self, model_id: str) -> Optional[Dict]:
for provider in self.providers:
for provider, label in self._iter_providers():
try:
result = await provider.get_model_versions(model_id)
result = await self._call_with_rate_limit(
label,
provider.get_model_versions,
model_id,
)
if result:
return result
except RateLimitError as exc:
exc.provider = exc.provider or label
raise exc
except Exception as e:
logger.debug(f"Provider failed for get_model_versions: {e}")
logger.debug("Provider %s failed for get_model_versions: %s", label, e)
continue
return None
async def get_model_version(self, model_id: int = None, version_id: int = None) -> Optional[Dict]:
for provider in self.providers:
for provider, label in self._iter_providers():
try:
result = await provider.get_model_version(model_id, version_id)
result = await self._call_with_rate_limit(
label,
provider.get_model_version,
model_id,
version_id,
)
if result:
return result
except RateLimitError as exc:
exc.provider = exc.provider or label
raise exc
except Exception as e:
logger.debug(f"Provider failed for get_model_version: {e}")
logger.debug("Provider %s failed for get_model_version: %s", label, e)
continue
return None
async def get_model_version_info(self, version_id: str) -> Tuple[Optional[Dict], Optional[str]]:
for provider in self.providers:
for provider, label in self._iter_providers():
try:
result, error = await provider.get_model_version_info(version_id)
result, error = await self._call_with_rate_limit(
label,
provider.get_model_version_info,
version_id,
)
if result:
return result, error
except RateLimitError as exc:
exc.provider = exc.provider or label
raise exc
except Exception as e:
logger.debug(f"Provider failed for get_model_version_info: {e}")
logger.debug("Provider %s failed for get_model_version_info: %s", label, e)
continue
return None, "No provider could retrieve the data"
async def get_user_models(self, username: str) -> Optional[List[Dict]]:
for provider in self.providers:
for provider, label in self._iter_providers():
try:
result = await provider.get_user_models(username)
result = await self._call_with_rate_limit(
label,
provider.get_user_models,
username,
)
if result is not None:
return result
except RateLimitError as exc:
exc.provider = exc.provider or label
raise exc
except Exception as e:
logger.debug(f"Provider failed for get_user_models: {e}")
logger.debug("Provider %s failed for get_user_models: %s", label, e)
continue
return None
def _iter_providers(self):
return zip(self.providers, self._provider_labels)
async def _call_with_rate_limit(
self,
label: str,
func,
*args,
**kwargs,
):
attempt = 0
while True:
try:
return await func(*args, **kwargs)
except RateLimitError as exc:
attempt += 1
if attempt >= self._rate_limit_retry_limit:
exc.provider = exc.provider or label
raise exc
delay = self._calculate_rate_limit_delay(exc.retry_after, attempt)
logger.warning(
"Provider %s rate limited request; retrying in %.2fs (attempt %s/%s)",
label,
delay,
attempt,
self._rate_limit_retry_limit,
)
await asyncio.sleep(delay)
except Exception:
raise
def _calculate_rate_limit_delay(self, retry_after: Optional[float], attempt: int) -> float:
if retry_after is not None:
return min(self._rate_limit_max_delay, max(0.0, retry_after))
base_delay = self._rate_limit_base_delay * (2 ** max(0, attempt - 1))
jitter_span = base_delay * self._rate_limit_jitter_ratio
if jitter_span > 0:
base_delay += random.uniform(-jitter_span, jitter_span)
return min(self._rate_limit_max_delay, max(0.0, base_delay))
class ModelMetadataProviderManager:
"""Manager for selecting and using model metadata providers"""