feat(civitai): add rate limiting support and error handling

- Add RateLimitError import and _make_request wrapper method to handle rate limiting
- Update API methods to use _make_request wrapper instead of direct downloader calls
- Add explicit RateLimitError handling in API methods to properly propagate rate limit errors
- Add _extract_retry_after method to parse Retry-After headers
- Improve error handling by surfacing rate limit information to callers

These changes ensure that rate limiting from the Civitai API is properly detected and handled, allowing callers to implement appropriate backoff strategies when rate limits are encountered.
This commit is contained in:
Will Miao
2025-10-14 20:52:01 +08:00
committed by pixelpaws
parent 1454991d6d
commit 3fde474583
9 changed files with 397 additions and 50 deletions

View File

@@ -5,6 +5,7 @@ import os
from typing import Optional, Dict, Tuple, List
from .model_metadata_provider import CivitaiModelMetadataProvider, ModelMetadataProviderManager
from .downloader import get_downloader
from .errors import RateLimitError
logger = logging.getLogger(__name__)
@@ -33,6 +34,29 @@ class CivitaiClient:
self.base_url = "https://civitai.com/api/v1"
async def _make_request(
self,
method: str,
url: str,
*,
use_auth: bool = False,
**kwargs,
) -> Tuple[bool, Dict | str]:
"""Wrapper around downloader.make_request that surfaces rate limits."""
downloader = await get_downloader()
success, result = await downloader.make_request(
method,
url,
use_auth=use_auth,
**kwargs,
)
if not success and isinstance(result, RateLimitError):
if result.provider is None:
result.provider = "civitai_api"
raise result
return success, result
@staticmethod
def _remove_comfy_metadata(model_version: Optional[Dict]) -> None:
"""Remove Comfy-specific metadata from model version images."""
@@ -79,8 +103,7 @@ class CivitaiClient:
async def get_model_by_hash(self, model_hash: str) -> Tuple[Optional[Dict], Optional[str]]:
try:
downloader = await get_downloader()
success, result = await downloader.make_request(
success, result = await self._make_request(
'GET',
f"{self.base_url}/model-versions/by-hash/{model_hash}",
use_auth=True
@@ -90,7 +113,7 @@ class CivitaiClient:
model_id = result.get('modelId')
if model_id:
# Fetch additional model metadata
success_model, data = await downloader.make_request(
success_model, data = await self._make_request(
'GET',
f"{self.base_url}/models/{model_id}",
use_auth=True
@@ -113,6 +136,8 @@ class CivitaiClient:
# Other error cases
logger.error(f"Failed to fetch model info for {model_hash[:10]}: {result}")
return None, str(result)
except RateLimitError:
raise
except Exception as e:
logger.error(f"API Error: {str(e)}")
return None, str(e)
@@ -138,8 +163,7 @@ class CivitaiClient:
async def get_model_versions(self, model_id: str) -> List[Dict]:
"""Get all versions of a model with local availability info"""
try:
downloader = await get_downloader()
success, result = await downloader.make_request(
success, result = await self._make_request(
'GET',
f"{self.base_url}/models/{model_id}",
use_auth=True
@@ -152,6 +176,8 @@ class CivitaiClient:
'name': result.get('name', '')
}
return None
except RateLimitError:
raise
except Exception as e:
logger.error(f"Error fetching model versions: {e}")
return None
@@ -159,23 +185,23 @@ class CivitaiClient:
async def get_model_version(self, model_id: int = None, version_id: int = None) -> Optional[Dict]:
"""Get specific model version with additional metadata."""
try:
downloader = await get_downloader()
if model_id is None and version_id is not None:
return await self._get_version_by_id_only(downloader, version_id)
return await self._get_version_by_id_only(version_id)
if model_id is not None:
return await self._get_version_with_model_id(downloader, model_id, version_id)
return await self._get_version_with_model_id(model_id, version_id)
logger.error("Either model_id or version_id must be provided")
return None
except RateLimitError:
raise
except Exception as e:
logger.error(f"Error fetching model version: {e}")
return None
async def _get_version_by_id_only(self, downloader, version_id: int) -> Optional[Dict]:
version = await self._fetch_version_by_id(downloader, version_id)
async def _get_version_by_id_only(self, version_id: int) -> Optional[Dict]:
version = await self._fetch_version_by_id(version_id)
if version is None:
return None
@@ -184,15 +210,15 @@ class CivitaiClient:
logger.error(f"No modelId found in version {version_id}")
return None
model_data = await self._fetch_model_data(downloader, model_id)
model_data = await self._fetch_model_data(model_id)
if model_data:
self._enrich_version_with_model_data(version, model_data)
self._remove_comfy_metadata(version)
return version
async def _get_version_with_model_id(self, downloader, model_id: int, version_id: Optional[int]) -> Optional[Dict]:
model_data = await self._fetch_model_data(downloader, model_id)
async def _get_version_with_model_id(self, model_id: int, version_id: Optional[int]) -> Optional[Dict]:
model_data = await self._fetch_model_data(model_id)
if not model_data:
return None
@@ -201,12 +227,12 @@ class CivitaiClient:
return None
target_version_id = target_version.get('id')
version = await self._fetch_version_by_id(downloader, target_version_id) if target_version_id else None
version = await self._fetch_version_by_id(target_version_id) if target_version_id else None
if version is None:
model_hash = self._extract_primary_model_hash(target_version)
if model_hash:
version = await self._fetch_version_by_hash(downloader, model_hash)
version = await self._fetch_version_by_hash(model_hash)
else:
logger.warning(
f"No primary model hash found for model {model_id} version {target_version_id}"
@@ -219,8 +245,8 @@ class CivitaiClient:
self._remove_comfy_metadata(version)
return version
async def _fetch_model_data(self, downloader, model_id: int) -> Optional[Dict]:
success, data = await downloader.make_request(
async def _fetch_model_data(self, model_id: int) -> Optional[Dict]:
success, data = await self._make_request(
'GET',
f"{self.base_url}/models/{model_id}",
use_auth=True
@@ -230,11 +256,11 @@ class CivitaiClient:
logger.warning(f"Failed to fetch model data for model {model_id}")
return None
async def _fetch_version_by_id(self, downloader, version_id: Optional[int]) -> Optional[Dict]:
async def _fetch_version_by_id(self, version_id: Optional[int]) -> Optional[Dict]:
if version_id is None:
return None
success, version = await downloader.make_request(
success, version = await self._make_request(
'GET',
f"{self.base_url}/model-versions/{version_id}",
use_auth=True
@@ -245,11 +271,11 @@ class CivitaiClient:
logger.warning(f"Failed to fetch version by id {version_id}")
return None
async def _fetch_version_by_hash(self, downloader, model_hash: Optional[str]) -> Optional[Dict]:
async def _fetch_version_by_hash(self, model_hash: Optional[str]) -> Optional[Dict]:
if not model_hash:
return None
success, version = await downloader.make_request(
success, version = await self._make_request(
'GET',
f"{self.base_url}/model-versions/by-hash/{model_hash}",
use_auth=True
@@ -323,11 +349,10 @@ class CivitaiClient:
- An error message if there was an error, or None on success
"""
try:
downloader = await get_downloader()
url = f"{self.base_url}/model-versions/{version_id}"
logger.debug(f"Resolving DNS for model version info: {url}")
success, result = await downloader.make_request(
success, result = await self._make_request(
'GET',
url,
use_auth=True
@@ -347,6 +372,8 @@ class CivitaiClient:
# Other error cases
logger.error(f"Failed to fetch model info for {version_id}: {result}")
return None, str(result)
except RateLimitError:
raise
except Exception as e:
error_msg = f"Error fetching model version info: {e}"
logger.error(error_msg)
@@ -362,11 +389,10 @@ class CivitaiClient:
Optional[Dict]: The image data or None if not found
"""
try:
downloader = await get_downloader()
url = f"{self.base_url}/images?imageId={image_id}&nsfw=X"
logger.debug(f"Fetching image info for ID: {image_id}")
success, result = await downloader.make_request(
success, result = await self._make_request(
'GET',
url,
use_auth=True
@@ -381,6 +407,8 @@ class CivitaiClient:
logger.error(f"Failed to fetch image info for ID: {image_id}: {result}")
return None
except RateLimitError:
raise
except Exception as e:
error_msg = f"Error fetching image info: {e}"
logger.error(error_msg)
@@ -392,9 +420,8 @@ class CivitaiClient:
return None
try:
downloader = await get_downloader()
url = f"{self.base_url}/models?username={username}"
success, result = await downloader.make_request(
success, result = await self._make_request(
'GET',
url,
use_auth=True
@@ -416,6 +443,8 @@ class CivitaiClient:
self._remove_comfy_metadata(version)
return items
except RateLimitError:
raise
except Exception as exc: # pragma: no cover - defensive logging
logger.error("Error fetching models for %s: %s", username, exc)
return None

View File

@@ -17,8 +17,10 @@ import aiohttp
from collections import deque
from dataclasses import dataclass
from datetime import datetime, timedelta
from email.utils import parsedate_to_datetime
from typing import Optional, Dict, Tuple, Callable, Union, Awaitable
from ..services.settings_manager import get_settings_manager
from .errors import RateLimitError
logger = logging.getLogger(__name__)
@@ -587,6 +589,19 @@ class Downloader:
return False, "Access forbidden"
elif response.status == 404:
return False, "Resource not found"
elif response.status == 429:
retry_after = self._extract_retry_after(response.headers)
error_msg = "Request rate limited"
logger.warning(
"Rate limit encountered for %s %s; retry_after=%s",
method,
url,
retry_after,
)
return False, RateLimitError(
error_msg,
retry_after=retry_after,
)
else:
return False, f"Request failed with status {response.status}"
@@ -608,6 +623,38 @@ class Downloader:
await self._create_session()
logger.info("HTTP session refreshed due to settings change")
@staticmethod
def _extract_retry_after(headers) -> Optional[float]:
"""Parse the Retry-After header into seconds."""
if not headers:
return None
header_value = headers.get("Retry-After")
if not header_value:
return None
header_value = header_value.strip()
if not header_value:
return None
if header_value.isdigit():
try:
seconds = float(header_value)
except ValueError:
return None
return max(0.0, seconds)
try:
retry_datetime = parsedate_to_datetime(header_value)
except (TypeError, ValueError):
return None
if retry_datetime.tzinfo is None:
return None
delta = retry_datetime - datetime.now(tz=retry_datetime.tzinfo)
return max(0.0, delta.total_seconds())
# Global instance accessor
async def get_downloader() -> Downloader:

21
py/services/errors.py Normal file
View File

@@ -0,0 +1,21 @@
"""Common service-level exception types."""
from __future__ import annotations
from typing import Optional
class RateLimitError(RuntimeError):
"""Raised when a remote provider rejects a request due to rate limiting."""
def __init__(
self,
message: str,
*,
retry_after: Optional[float] = None,
provider: Optional[str] = None,
) -> None:
super().__init__(message)
self.retry_after = retry_after
self.provider = provider

View File

@@ -1,6 +1,7 @@
import os
import logging
from .model_metadata_provider import (
ModelMetadataProvider,
ModelMetadataProviderManager,
SQLiteModelMetadataProvider,
CivitaiModelMetadataProvider,
@@ -68,10 +69,10 @@ async def initialize_metadata_providers():
# Set up fallback provider based on available providers
if len(providers) > 1:
# Always use Civitai API (it has better metadata), then CivArchive API, then Archive DB
ordered_providers = []
ordered_providers.extend([p[1] for p in providers if p[0] == 'civitai_api'])
ordered_providers.extend([p[1] for p in providers if p[0] == 'civarchive_api'])
ordered_providers.extend([p[1] for p in providers if p[0] == 'sqlite'])
ordered_providers: list[tuple[str, ModelMetadataProvider]] = []
ordered_providers.extend([p for p in providers if p[0] == 'civitai_api'])
ordered_providers.extend([p for p in providers if p[0] == 'civarchive_api'])
ordered_providers.extend([p for p in providers if p[0] == 'sqlite'])
if ordered_providers:
fallback_provider = FallbackMetadataProvider(ordered_providers)

View File

@@ -10,6 +10,7 @@ from typing import Any, Awaitable, Callable, Dict, Iterable, Optional
from ..services.settings_manager import SettingsManager
from ..utils.model_utils import determine_base_model
from .errors import RateLimitError
logger = logging.getLogger(__name__)
@@ -205,6 +206,9 @@ class MetadataSyncService:
for provider_name, provider in provider_attempts:
try:
civitai_metadata_candidate, error = await provider.get_model_by_hash(sha256)
except RateLimitError as exc:
exc.provider = exc.provider or (provider_name or provider.__class__.__name__)
raise
except Exception as exc: # pragma: no cover - defensive logging
logger.error("Provider %s failed for hash %s: %s", provider_name, sha256, exc)
civitai_metadata_candidate, error = None, str(exc)
@@ -299,6 +303,16 @@ class MetadataSyncService:
error_msg = f"Error fetching metadata - Missing key: {exc} in model_data={model_data}"
logger.error(error_msg)
return False, error_msg
except RateLimitError as exc:
provider_label = exc.provider or "metadata provider"
wait_hint = (
f"; retry after approximately {int(exc.retry_after)}s"
if exc.retry_after and exc.retry_after > 0
else ""
)
error_msg = f"Rate limited by {provider_label}{wait_hint}"
logger.warning(error_msg)
return False, error_msg
except Exception as exc: # pragma: no cover - error path
error_msg = f"Error fetching metadata: {exc}"
logger.error(error_msg, exc_info=True)

View File

@@ -1,8 +1,11 @@
from abc import ABC, abstractmethod
import asyncio
import json
import logging
from typing import Optional, Dict, Tuple, Any, List
import random
from typing import Optional, Dict, Tuple, Any, List, Sequence
from .downloader import get_downloader
from .errors import RateLimitError
try:
from bs4 import BeautifulSoup
@@ -350,64 +353,166 @@ class SQLiteModelMetadataProvider(ModelMetadataProvider):
class FallbackMetadataProvider(ModelMetadataProvider):
"""Try providers in order, return first successful result."""
def __init__(self, providers: list):
self.providers = providers
def __init__(
self,
providers: Sequence[ModelMetadataProvider | Tuple[str, ModelMetadataProvider]],
*,
rate_limit_retry_limit: int = 3,
rate_limit_base_delay: float = 1.5,
rate_limit_max_delay: float = 30.0,
rate_limit_jitter_ratio: float = 0.2,
) -> None:
self.providers: List[ModelMetadataProvider] = []
self._provider_labels: List[str] = []
for entry in providers:
if isinstance(entry, tuple) and len(entry) == 2:
name, provider = entry
else:
provider = entry
name = provider.__class__.__name__
self.providers.append(provider)
self._provider_labels.append(str(name))
self._rate_limit_retry_limit = max(1, rate_limit_retry_limit)
self._rate_limit_base_delay = rate_limit_base_delay
self._rate_limit_max_delay = rate_limit_max_delay
self._rate_limit_jitter_ratio = max(0.0, rate_limit_jitter_ratio)
async def get_model_by_hash(self, model_hash: str) -> Tuple[Optional[Dict], Optional[str]]:
for provider in self.providers:
for provider, label in self._iter_providers():
try:
result, error = await provider.get_model_by_hash(model_hash)
result, error = await self._call_with_rate_limit(
label,
provider.get_model_by_hash,
model_hash,
)
if result:
return result, error
except RateLimitError as exc:
exc.provider = exc.provider or label
raise exc
except Exception as e:
logger.debug(f"Provider failed for get_model_by_hash: {e}")
logger.debug("Provider %s failed for get_model_by_hash: %s", label, e)
continue
return None, "Model not found"
async def get_model_versions(self, model_id: str) -> Optional[Dict]:
for provider in self.providers:
for provider, label in self._iter_providers():
try:
result = await provider.get_model_versions(model_id)
result = await self._call_with_rate_limit(
label,
provider.get_model_versions,
model_id,
)
if result:
return result
except RateLimitError as exc:
exc.provider = exc.provider or label
raise exc
except Exception as e:
logger.debug(f"Provider failed for get_model_versions: {e}")
logger.debug("Provider %s failed for get_model_versions: %s", label, e)
continue
return None
async def get_model_version(self, model_id: int = None, version_id: int = None) -> Optional[Dict]:
for provider in self.providers:
for provider, label in self._iter_providers():
try:
result = await provider.get_model_version(model_id, version_id)
result = await self._call_with_rate_limit(
label,
provider.get_model_version,
model_id,
version_id,
)
if result:
return result
except RateLimitError as exc:
exc.provider = exc.provider or label
raise exc
except Exception as e:
logger.debug(f"Provider failed for get_model_version: {e}")
logger.debug("Provider %s failed for get_model_version: %s", label, e)
continue
return None
async def get_model_version_info(self, version_id: str) -> Tuple[Optional[Dict], Optional[str]]:
for provider in self.providers:
for provider, label in self._iter_providers():
try:
result, error = await provider.get_model_version_info(version_id)
result, error = await self._call_with_rate_limit(
label,
provider.get_model_version_info,
version_id,
)
if result:
return result, error
except RateLimitError as exc:
exc.provider = exc.provider or label
raise exc
except Exception as e:
logger.debug(f"Provider failed for get_model_version_info: {e}")
logger.debug("Provider %s failed for get_model_version_info: %s", label, e)
continue
return None, "No provider could retrieve the data"
async def get_user_models(self, username: str) -> Optional[List[Dict]]:
for provider in self.providers:
for provider, label in self._iter_providers():
try:
result = await provider.get_user_models(username)
result = await self._call_with_rate_limit(
label,
provider.get_user_models,
username,
)
if result is not None:
return result
except RateLimitError as exc:
exc.provider = exc.provider or label
raise exc
except Exception as e:
logger.debug(f"Provider failed for get_user_models: {e}")
logger.debug("Provider %s failed for get_user_models: %s", label, e)
continue
return None
def _iter_providers(self):
return zip(self.providers, self._provider_labels)
async def _call_with_rate_limit(
self,
label: str,
func,
*args,
**kwargs,
):
attempt = 0
while True:
try:
return await func(*args, **kwargs)
except RateLimitError as exc:
attempt += 1
if attempt >= self._rate_limit_retry_limit:
exc.provider = exc.provider or label
raise exc
delay = self._calculate_rate_limit_delay(exc.retry_after, attempt)
logger.warning(
"Provider %s rate limited request; retrying in %.2fs (attempt %s/%s)",
label,
delay,
attempt,
self._rate_limit_retry_limit,
)
await asyncio.sleep(delay)
except Exception:
raise
def _calculate_rate_limit_delay(self, retry_after: Optional[float], attempt: int) -> float:
if retry_after is not None:
return min(self._rate_limit_max_delay, max(0.0, retry_after))
base_delay = self._rate_limit_base_delay * (2 ** max(0, attempt - 1))
jitter_span = base_delay * self._rate_limit_jitter_ratio
if jitter_span > 0:
base_delay += random.uniform(-jitter_span, jitter_span)
return min(self._rate_limit_max_delay, max(0.0, base_delay))
class ModelMetadataProviderManager:
"""Manager for selecting and using model metadata providers"""