mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-25 15:15:44 -03:00
feat(civitai): add rate limiting support and error handling
- Add RateLimitError import and _make_request wrapper method to handle rate limiting - Update API methods to use _make_request wrapper instead of direct downloader calls - Add explicit RateLimitError handling in API methods to properly propagate rate limit errors - Add _extract_retry_after method to parse Retry-After headers - Improve error handling by surfacing rate limit information to callers These changes ensure that rate limiting from the Civitai API is properly detected and handled, allowing callers to implement appropriate backoff strategies when rate limits are encountered.
This commit is contained in:
@@ -5,6 +5,7 @@ import os
|
|||||||
from typing import Optional, Dict, Tuple, List
|
from typing import Optional, Dict, Tuple, List
|
||||||
from .model_metadata_provider import CivitaiModelMetadataProvider, ModelMetadataProviderManager
|
from .model_metadata_provider import CivitaiModelMetadataProvider, ModelMetadataProviderManager
|
||||||
from .downloader import get_downloader
|
from .downloader import get_downloader
|
||||||
|
from .errors import RateLimitError
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -33,6 +34,29 @@ class CivitaiClient:
|
|||||||
|
|
||||||
self.base_url = "https://civitai.com/api/v1"
|
self.base_url = "https://civitai.com/api/v1"
|
||||||
|
|
||||||
|
async def _make_request(
|
||||||
|
self,
|
||||||
|
method: str,
|
||||||
|
url: str,
|
||||||
|
*,
|
||||||
|
use_auth: bool = False,
|
||||||
|
**kwargs,
|
||||||
|
) -> Tuple[bool, Dict | str]:
|
||||||
|
"""Wrapper around downloader.make_request that surfaces rate limits."""
|
||||||
|
|
||||||
|
downloader = await get_downloader()
|
||||||
|
success, result = await downloader.make_request(
|
||||||
|
method,
|
||||||
|
url,
|
||||||
|
use_auth=use_auth,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
if not success and isinstance(result, RateLimitError):
|
||||||
|
if result.provider is None:
|
||||||
|
result.provider = "civitai_api"
|
||||||
|
raise result
|
||||||
|
return success, result
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _remove_comfy_metadata(model_version: Optional[Dict]) -> None:
|
def _remove_comfy_metadata(model_version: Optional[Dict]) -> None:
|
||||||
"""Remove Comfy-specific metadata from model version images."""
|
"""Remove Comfy-specific metadata from model version images."""
|
||||||
@@ -79,8 +103,7 @@ class CivitaiClient:
|
|||||||
|
|
||||||
async def get_model_by_hash(self, model_hash: str) -> Tuple[Optional[Dict], Optional[str]]:
|
async def get_model_by_hash(self, model_hash: str) -> Tuple[Optional[Dict], Optional[str]]:
|
||||||
try:
|
try:
|
||||||
downloader = await get_downloader()
|
success, result = await self._make_request(
|
||||||
success, result = await downloader.make_request(
|
|
||||||
'GET',
|
'GET',
|
||||||
f"{self.base_url}/model-versions/by-hash/{model_hash}",
|
f"{self.base_url}/model-versions/by-hash/{model_hash}",
|
||||||
use_auth=True
|
use_auth=True
|
||||||
@@ -90,7 +113,7 @@ class CivitaiClient:
|
|||||||
model_id = result.get('modelId')
|
model_id = result.get('modelId')
|
||||||
if model_id:
|
if model_id:
|
||||||
# Fetch additional model metadata
|
# Fetch additional model metadata
|
||||||
success_model, data = await downloader.make_request(
|
success_model, data = await self._make_request(
|
||||||
'GET',
|
'GET',
|
||||||
f"{self.base_url}/models/{model_id}",
|
f"{self.base_url}/models/{model_id}",
|
||||||
use_auth=True
|
use_auth=True
|
||||||
@@ -113,6 +136,8 @@ class CivitaiClient:
|
|||||||
# Other error cases
|
# Other error cases
|
||||||
logger.error(f"Failed to fetch model info for {model_hash[:10]}: {result}")
|
logger.error(f"Failed to fetch model info for {model_hash[:10]}: {result}")
|
||||||
return None, str(result)
|
return None, str(result)
|
||||||
|
except RateLimitError:
|
||||||
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"API Error: {str(e)}")
|
logger.error(f"API Error: {str(e)}")
|
||||||
return None, str(e)
|
return None, str(e)
|
||||||
@@ -138,8 +163,7 @@ class CivitaiClient:
|
|||||||
async def get_model_versions(self, model_id: str) -> List[Dict]:
|
async def get_model_versions(self, model_id: str) -> List[Dict]:
|
||||||
"""Get all versions of a model with local availability info"""
|
"""Get all versions of a model with local availability info"""
|
||||||
try:
|
try:
|
||||||
downloader = await get_downloader()
|
success, result = await self._make_request(
|
||||||
success, result = await downloader.make_request(
|
|
||||||
'GET',
|
'GET',
|
||||||
f"{self.base_url}/models/{model_id}",
|
f"{self.base_url}/models/{model_id}",
|
||||||
use_auth=True
|
use_auth=True
|
||||||
@@ -152,6 +176,8 @@ class CivitaiClient:
|
|||||||
'name': result.get('name', '')
|
'name': result.get('name', '')
|
||||||
}
|
}
|
||||||
return None
|
return None
|
||||||
|
except RateLimitError:
|
||||||
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error fetching model versions: {e}")
|
logger.error(f"Error fetching model versions: {e}")
|
||||||
return None
|
return None
|
||||||
@@ -159,23 +185,23 @@ class CivitaiClient:
|
|||||||
async def get_model_version(self, model_id: int = None, version_id: int = None) -> Optional[Dict]:
|
async def get_model_version(self, model_id: int = None, version_id: int = None) -> Optional[Dict]:
|
||||||
"""Get specific model version with additional metadata."""
|
"""Get specific model version with additional metadata."""
|
||||||
try:
|
try:
|
||||||
downloader = await get_downloader()
|
|
||||||
|
|
||||||
if model_id is None and version_id is not None:
|
if model_id is None and version_id is not None:
|
||||||
return await self._get_version_by_id_only(downloader, version_id)
|
return await self._get_version_by_id_only(version_id)
|
||||||
|
|
||||||
if model_id is not None:
|
if model_id is not None:
|
||||||
return await self._get_version_with_model_id(downloader, model_id, version_id)
|
return await self._get_version_with_model_id(model_id, version_id)
|
||||||
|
|
||||||
logger.error("Either model_id or version_id must be provided")
|
logger.error("Either model_id or version_id must be provided")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
except RateLimitError:
|
||||||
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error fetching model version: {e}")
|
logger.error(f"Error fetching model version: {e}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def _get_version_by_id_only(self, downloader, version_id: int) -> Optional[Dict]:
|
async def _get_version_by_id_only(self, version_id: int) -> Optional[Dict]:
|
||||||
version = await self._fetch_version_by_id(downloader, version_id)
|
version = await self._fetch_version_by_id(version_id)
|
||||||
if version is None:
|
if version is None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@@ -184,15 +210,15 @@ class CivitaiClient:
|
|||||||
logger.error(f"No modelId found in version {version_id}")
|
logger.error(f"No modelId found in version {version_id}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
model_data = await self._fetch_model_data(downloader, model_id)
|
model_data = await self._fetch_model_data(model_id)
|
||||||
if model_data:
|
if model_data:
|
||||||
self._enrich_version_with_model_data(version, model_data)
|
self._enrich_version_with_model_data(version, model_data)
|
||||||
|
|
||||||
self._remove_comfy_metadata(version)
|
self._remove_comfy_metadata(version)
|
||||||
return version
|
return version
|
||||||
|
|
||||||
async def _get_version_with_model_id(self, downloader, model_id: int, version_id: Optional[int]) -> Optional[Dict]:
|
async def _get_version_with_model_id(self, model_id: int, version_id: Optional[int]) -> Optional[Dict]:
|
||||||
model_data = await self._fetch_model_data(downloader, model_id)
|
model_data = await self._fetch_model_data(model_id)
|
||||||
if not model_data:
|
if not model_data:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@@ -201,12 +227,12 @@ class CivitaiClient:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
target_version_id = target_version.get('id')
|
target_version_id = target_version.get('id')
|
||||||
version = await self._fetch_version_by_id(downloader, target_version_id) if target_version_id else None
|
version = await self._fetch_version_by_id(target_version_id) if target_version_id else None
|
||||||
|
|
||||||
if version is None:
|
if version is None:
|
||||||
model_hash = self._extract_primary_model_hash(target_version)
|
model_hash = self._extract_primary_model_hash(target_version)
|
||||||
if model_hash:
|
if model_hash:
|
||||||
version = await self._fetch_version_by_hash(downloader, model_hash)
|
version = await self._fetch_version_by_hash(model_hash)
|
||||||
else:
|
else:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"No primary model hash found for model {model_id} version {target_version_id}"
|
f"No primary model hash found for model {model_id} version {target_version_id}"
|
||||||
@@ -219,8 +245,8 @@ class CivitaiClient:
|
|||||||
self._remove_comfy_metadata(version)
|
self._remove_comfy_metadata(version)
|
||||||
return version
|
return version
|
||||||
|
|
||||||
async def _fetch_model_data(self, downloader, model_id: int) -> Optional[Dict]:
|
async def _fetch_model_data(self, model_id: int) -> Optional[Dict]:
|
||||||
success, data = await downloader.make_request(
|
success, data = await self._make_request(
|
||||||
'GET',
|
'GET',
|
||||||
f"{self.base_url}/models/{model_id}",
|
f"{self.base_url}/models/{model_id}",
|
||||||
use_auth=True
|
use_auth=True
|
||||||
@@ -230,11 +256,11 @@ class CivitaiClient:
|
|||||||
logger.warning(f"Failed to fetch model data for model {model_id}")
|
logger.warning(f"Failed to fetch model data for model {model_id}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def _fetch_version_by_id(self, downloader, version_id: Optional[int]) -> Optional[Dict]:
|
async def _fetch_version_by_id(self, version_id: Optional[int]) -> Optional[Dict]:
|
||||||
if version_id is None:
|
if version_id is None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
success, version = await downloader.make_request(
|
success, version = await self._make_request(
|
||||||
'GET',
|
'GET',
|
||||||
f"{self.base_url}/model-versions/{version_id}",
|
f"{self.base_url}/model-versions/{version_id}",
|
||||||
use_auth=True
|
use_auth=True
|
||||||
@@ -245,11 +271,11 @@ class CivitaiClient:
|
|||||||
logger.warning(f"Failed to fetch version by id {version_id}")
|
logger.warning(f"Failed to fetch version by id {version_id}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def _fetch_version_by_hash(self, downloader, model_hash: Optional[str]) -> Optional[Dict]:
|
async def _fetch_version_by_hash(self, model_hash: Optional[str]) -> Optional[Dict]:
|
||||||
if not model_hash:
|
if not model_hash:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
success, version = await downloader.make_request(
|
success, version = await self._make_request(
|
||||||
'GET',
|
'GET',
|
||||||
f"{self.base_url}/model-versions/by-hash/{model_hash}",
|
f"{self.base_url}/model-versions/by-hash/{model_hash}",
|
||||||
use_auth=True
|
use_auth=True
|
||||||
@@ -323,11 +349,10 @@ class CivitaiClient:
|
|||||||
- An error message if there was an error, or None on success
|
- An error message if there was an error, or None on success
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
downloader = await get_downloader()
|
|
||||||
url = f"{self.base_url}/model-versions/{version_id}"
|
url = f"{self.base_url}/model-versions/{version_id}"
|
||||||
|
|
||||||
logger.debug(f"Resolving DNS for model version info: {url}")
|
logger.debug(f"Resolving DNS for model version info: {url}")
|
||||||
success, result = await downloader.make_request(
|
success, result = await self._make_request(
|
||||||
'GET',
|
'GET',
|
||||||
url,
|
url,
|
||||||
use_auth=True
|
use_auth=True
|
||||||
@@ -347,6 +372,8 @@ class CivitaiClient:
|
|||||||
# Other error cases
|
# Other error cases
|
||||||
logger.error(f"Failed to fetch model info for {version_id}: {result}")
|
logger.error(f"Failed to fetch model info for {version_id}: {result}")
|
||||||
return None, str(result)
|
return None, str(result)
|
||||||
|
except RateLimitError:
|
||||||
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error_msg = f"Error fetching model version info: {e}"
|
error_msg = f"Error fetching model version info: {e}"
|
||||||
logger.error(error_msg)
|
logger.error(error_msg)
|
||||||
@@ -362,11 +389,10 @@ class CivitaiClient:
|
|||||||
Optional[Dict]: The image data or None if not found
|
Optional[Dict]: The image data or None if not found
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
downloader = await get_downloader()
|
|
||||||
url = f"{self.base_url}/images?imageId={image_id}&nsfw=X"
|
url = f"{self.base_url}/images?imageId={image_id}&nsfw=X"
|
||||||
|
|
||||||
logger.debug(f"Fetching image info for ID: {image_id}")
|
logger.debug(f"Fetching image info for ID: {image_id}")
|
||||||
success, result = await downloader.make_request(
|
success, result = await self._make_request(
|
||||||
'GET',
|
'GET',
|
||||||
url,
|
url,
|
||||||
use_auth=True
|
use_auth=True
|
||||||
@@ -381,6 +407,8 @@ class CivitaiClient:
|
|||||||
|
|
||||||
logger.error(f"Failed to fetch image info for ID: {image_id}: {result}")
|
logger.error(f"Failed to fetch image info for ID: {image_id}: {result}")
|
||||||
return None
|
return None
|
||||||
|
except RateLimitError:
|
||||||
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error_msg = f"Error fetching image info: {e}"
|
error_msg = f"Error fetching image info: {e}"
|
||||||
logger.error(error_msg)
|
logger.error(error_msg)
|
||||||
@@ -392,9 +420,8 @@ class CivitaiClient:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
downloader = await get_downloader()
|
|
||||||
url = f"{self.base_url}/models?username={username}"
|
url = f"{self.base_url}/models?username={username}"
|
||||||
success, result = await downloader.make_request(
|
success, result = await self._make_request(
|
||||||
'GET',
|
'GET',
|
||||||
url,
|
url,
|
||||||
use_auth=True
|
use_auth=True
|
||||||
@@ -416,6 +443,8 @@ class CivitaiClient:
|
|||||||
self._remove_comfy_metadata(version)
|
self._remove_comfy_metadata(version)
|
||||||
|
|
||||||
return items
|
return items
|
||||||
|
except RateLimitError:
|
||||||
|
raise
|
||||||
except Exception as exc: # pragma: no cover - defensive logging
|
except Exception as exc: # pragma: no cover - defensive logging
|
||||||
logger.error("Error fetching models for %s: %s", username, exc)
|
logger.error("Error fetching models for %s: %s", username, exc)
|
||||||
return None
|
return None
|
||||||
|
|||||||
@@ -17,8 +17,10 @@ import aiohttp
|
|||||||
from collections import deque
|
from collections import deque
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
|
from email.utils import parsedate_to_datetime
|
||||||
from typing import Optional, Dict, Tuple, Callable, Union, Awaitable
|
from typing import Optional, Dict, Tuple, Callable, Union, Awaitable
|
||||||
from ..services.settings_manager import get_settings_manager
|
from ..services.settings_manager import get_settings_manager
|
||||||
|
from .errors import RateLimitError
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -587,6 +589,19 @@ class Downloader:
|
|||||||
return False, "Access forbidden"
|
return False, "Access forbidden"
|
||||||
elif response.status == 404:
|
elif response.status == 404:
|
||||||
return False, "Resource not found"
|
return False, "Resource not found"
|
||||||
|
elif response.status == 429:
|
||||||
|
retry_after = self._extract_retry_after(response.headers)
|
||||||
|
error_msg = "Request rate limited"
|
||||||
|
logger.warning(
|
||||||
|
"Rate limit encountered for %s %s; retry_after=%s",
|
||||||
|
method,
|
||||||
|
url,
|
||||||
|
retry_after,
|
||||||
|
)
|
||||||
|
return False, RateLimitError(
|
||||||
|
error_msg,
|
||||||
|
retry_after=retry_after,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
return False, f"Request failed with status {response.status}"
|
return False, f"Request failed with status {response.status}"
|
||||||
|
|
||||||
@@ -608,6 +623,38 @@ class Downloader:
|
|||||||
await self._create_session()
|
await self._create_session()
|
||||||
logger.info("HTTP session refreshed due to settings change")
|
logger.info("HTTP session refreshed due to settings change")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _extract_retry_after(headers) -> Optional[float]:
|
||||||
|
"""Parse the Retry-After header into seconds."""
|
||||||
|
if not headers:
|
||||||
|
return None
|
||||||
|
|
||||||
|
header_value = headers.get("Retry-After")
|
||||||
|
if not header_value:
|
||||||
|
return None
|
||||||
|
|
||||||
|
header_value = header_value.strip()
|
||||||
|
if not header_value:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if header_value.isdigit():
|
||||||
|
try:
|
||||||
|
seconds = float(header_value)
|
||||||
|
except ValueError:
|
||||||
|
return None
|
||||||
|
return max(0.0, seconds)
|
||||||
|
|
||||||
|
try:
|
||||||
|
retry_datetime = parsedate_to_datetime(header_value)
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
return None
|
||||||
|
|
||||||
|
if retry_datetime.tzinfo is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
delta = retry_datetime - datetime.now(tz=retry_datetime.tzinfo)
|
||||||
|
return max(0.0, delta.total_seconds())
|
||||||
|
|
||||||
|
|
||||||
# Global instance accessor
|
# Global instance accessor
|
||||||
async def get_downloader() -> Downloader:
|
async def get_downloader() -> Downloader:
|
||||||
|
|||||||
21
py/services/errors.py
Normal file
21
py/services/errors.py
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
"""Common service-level exception types."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
|
||||||
|
class RateLimitError(RuntimeError):
|
||||||
|
"""Raised when a remote provider rejects a request due to rate limiting."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
message: str,
|
||||||
|
*,
|
||||||
|
retry_after: Optional[float] = None,
|
||||||
|
provider: Optional[str] = None,
|
||||||
|
) -> None:
|
||||||
|
super().__init__(message)
|
||||||
|
self.retry_after = retry_after
|
||||||
|
self.provider = provider
|
||||||
|
|
||||||
@@ -1,6 +1,7 @@
|
|||||||
import os
|
import os
|
||||||
import logging
|
import logging
|
||||||
from .model_metadata_provider import (
|
from .model_metadata_provider import (
|
||||||
|
ModelMetadataProvider,
|
||||||
ModelMetadataProviderManager,
|
ModelMetadataProviderManager,
|
||||||
SQLiteModelMetadataProvider,
|
SQLiteModelMetadataProvider,
|
||||||
CivitaiModelMetadataProvider,
|
CivitaiModelMetadataProvider,
|
||||||
@@ -68,10 +69,10 @@ async def initialize_metadata_providers():
|
|||||||
# Set up fallback provider based on available providers
|
# Set up fallback provider based on available providers
|
||||||
if len(providers) > 1:
|
if len(providers) > 1:
|
||||||
# Always use Civitai API (it has better metadata), then CivArchive API, then Archive DB
|
# Always use Civitai API (it has better metadata), then CivArchive API, then Archive DB
|
||||||
ordered_providers = []
|
ordered_providers: list[tuple[str, ModelMetadataProvider]] = []
|
||||||
ordered_providers.extend([p[1] for p in providers if p[0] == 'civitai_api'])
|
ordered_providers.extend([p for p in providers if p[0] == 'civitai_api'])
|
||||||
ordered_providers.extend([p[1] for p in providers if p[0] == 'civarchive_api'])
|
ordered_providers.extend([p for p in providers if p[0] == 'civarchive_api'])
|
||||||
ordered_providers.extend([p[1] for p in providers if p[0] == 'sqlite'])
|
ordered_providers.extend([p for p in providers if p[0] == 'sqlite'])
|
||||||
|
|
||||||
if ordered_providers:
|
if ordered_providers:
|
||||||
fallback_provider = FallbackMetadataProvider(ordered_providers)
|
fallback_provider = FallbackMetadataProvider(ordered_providers)
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ from typing import Any, Awaitable, Callable, Dict, Iterable, Optional
|
|||||||
|
|
||||||
from ..services.settings_manager import SettingsManager
|
from ..services.settings_manager import SettingsManager
|
||||||
from ..utils.model_utils import determine_base_model
|
from ..utils.model_utils import determine_base_model
|
||||||
|
from .errors import RateLimitError
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -205,6 +206,9 @@ class MetadataSyncService:
|
|||||||
for provider_name, provider in provider_attempts:
|
for provider_name, provider in provider_attempts:
|
||||||
try:
|
try:
|
||||||
civitai_metadata_candidate, error = await provider.get_model_by_hash(sha256)
|
civitai_metadata_candidate, error = await provider.get_model_by_hash(sha256)
|
||||||
|
except RateLimitError as exc:
|
||||||
|
exc.provider = exc.provider or (provider_name or provider.__class__.__name__)
|
||||||
|
raise
|
||||||
except Exception as exc: # pragma: no cover - defensive logging
|
except Exception as exc: # pragma: no cover - defensive logging
|
||||||
logger.error("Provider %s failed for hash %s: %s", provider_name, sha256, exc)
|
logger.error("Provider %s failed for hash %s: %s", provider_name, sha256, exc)
|
||||||
civitai_metadata_candidate, error = None, str(exc)
|
civitai_metadata_candidate, error = None, str(exc)
|
||||||
@@ -299,6 +303,16 @@ class MetadataSyncService:
|
|||||||
error_msg = f"Error fetching metadata - Missing key: {exc} in model_data={model_data}"
|
error_msg = f"Error fetching metadata - Missing key: {exc} in model_data={model_data}"
|
||||||
logger.error(error_msg)
|
logger.error(error_msg)
|
||||||
return False, error_msg
|
return False, error_msg
|
||||||
|
except RateLimitError as exc:
|
||||||
|
provider_label = exc.provider or "metadata provider"
|
||||||
|
wait_hint = (
|
||||||
|
f"; retry after approximately {int(exc.retry_after)}s"
|
||||||
|
if exc.retry_after and exc.retry_after > 0
|
||||||
|
else ""
|
||||||
|
)
|
||||||
|
error_msg = f"Rate limited by {provider_label}{wait_hint}"
|
||||||
|
logger.warning(error_msg)
|
||||||
|
return False, error_msg
|
||||||
except Exception as exc: # pragma: no cover - error path
|
except Exception as exc: # pragma: no cover - error path
|
||||||
error_msg = f"Error fetching metadata: {exc}"
|
error_msg = f"Error fetching metadata: {exc}"
|
||||||
logger.error(error_msg, exc_info=True)
|
logger.error(error_msg, exc_info=True)
|
||||||
|
|||||||
@@ -1,8 +1,11 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
import asyncio
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
from typing import Optional, Dict, Tuple, Any, List
|
import random
|
||||||
|
from typing import Optional, Dict, Tuple, Any, List, Sequence
|
||||||
from .downloader import get_downloader
|
from .downloader import get_downloader
|
||||||
|
from .errors import RateLimitError
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from bs4 import BeautifulSoup
|
from bs4 import BeautifulSoup
|
||||||
@@ -350,64 +353,166 @@ class SQLiteModelMetadataProvider(ModelMetadataProvider):
|
|||||||
|
|
||||||
class FallbackMetadataProvider(ModelMetadataProvider):
|
class FallbackMetadataProvider(ModelMetadataProvider):
|
||||||
"""Try providers in order, return first successful result."""
|
"""Try providers in order, return first successful result."""
|
||||||
def __init__(self, providers: list):
|
|
||||||
self.providers = providers
|
def __init__(
|
||||||
|
self,
|
||||||
|
providers: Sequence[ModelMetadataProvider | Tuple[str, ModelMetadataProvider]],
|
||||||
|
*,
|
||||||
|
rate_limit_retry_limit: int = 3,
|
||||||
|
rate_limit_base_delay: float = 1.5,
|
||||||
|
rate_limit_max_delay: float = 30.0,
|
||||||
|
rate_limit_jitter_ratio: float = 0.2,
|
||||||
|
) -> None:
|
||||||
|
self.providers: List[ModelMetadataProvider] = []
|
||||||
|
self._provider_labels: List[str] = []
|
||||||
|
|
||||||
|
for entry in providers:
|
||||||
|
if isinstance(entry, tuple) and len(entry) == 2:
|
||||||
|
name, provider = entry
|
||||||
|
else:
|
||||||
|
provider = entry
|
||||||
|
name = provider.__class__.__name__
|
||||||
|
self.providers.append(provider)
|
||||||
|
self._provider_labels.append(str(name))
|
||||||
|
|
||||||
|
self._rate_limit_retry_limit = max(1, rate_limit_retry_limit)
|
||||||
|
self._rate_limit_base_delay = rate_limit_base_delay
|
||||||
|
self._rate_limit_max_delay = rate_limit_max_delay
|
||||||
|
self._rate_limit_jitter_ratio = max(0.0, rate_limit_jitter_ratio)
|
||||||
|
|
||||||
async def get_model_by_hash(self, model_hash: str) -> Tuple[Optional[Dict], Optional[str]]:
|
async def get_model_by_hash(self, model_hash: str) -> Tuple[Optional[Dict], Optional[str]]:
|
||||||
for provider in self.providers:
|
for provider, label in self._iter_providers():
|
||||||
try:
|
try:
|
||||||
result, error = await provider.get_model_by_hash(model_hash)
|
result, error = await self._call_with_rate_limit(
|
||||||
|
label,
|
||||||
|
provider.get_model_by_hash,
|
||||||
|
model_hash,
|
||||||
|
)
|
||||||
if result:
|
if result:
|
||||||
return result, error
|
return result, error
|
||||||
|
except RateLimitError as exc:
|
||||||
|
exc.provider = exc.provider or label
|
||||||
|
raise exc
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.debug(f"Provider failed for get_model_by_hash: {e}")
|
logger.debug("Provider %s failed for get_model_by_hash: %s", label, e)
|
||||||
continue
|
continue
|
||||||
return None, "Model not found"
|
return None, "Model not found"
|
||||||
|
|
||||||
async def get_model_versions(self, model_id: str) -> Optional[Dict]:
|
async def get_model_versions(self, model_id: str) -> Optional[Dict]:
|
||||||
for provider in self.providers:
|
for provider, label in self._iter_providers():
|
||||||
try:
|
try:
|
||||||
result = await provider.get_model_versions(model_id)
|
result = await self._call_with_rate_limit(
|
||||||
|
label,
|
||||||
|
provider.get_model_versions,
|
||||||
|
model_id,
|
||||||
|
)
|
||||||
if result:
|
if result:
|
||||||
return result
|
return result
|
||||||
|
except RateLimitError as exc:
|
||||||
|
exc.provider = exc.provider or label
|
||||||
|
raise exc
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.debug(f"Provider failed for get_model_versions: {e}")
|
logger.debug("Provider %s failed for get_model_versions: %s", label, e)
|
||||||
continue
|
continue
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def get_model_version(self, model_id: int = None, version_id: int = None) -> Optional[Dict]:
|
async def get_model_version(self, model_id: int = None, version_id: int = None) -> Optional[Dict]:
|
||||||
for provider in self.providers:
|
for provider, label in self._iter_providers():
|
||||||
try:
|
try:
|
||||||
result = await provider.get_model_version(model_id, version_id)
|
result = await self._call_with_rate_limit(
|
||||||
|
label,
|
||||||
|
provider.get_model_version,
|
||||||
|
model_id,
|
||||||
|
version_id,
|
||||||
|
)
|
||||||
if result:
|
if result:
|
||||||
return result
|
return result
|
||||||
|
except RateLimitError as exc:
|
||||||
|
exc.provider = exc.provider or label
|
||||||
|
raise exc
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.debug(f"Provider failed for get_model_version: {e}")
|
logger.debug("Provider %s failed for get_model_version: %s", label, e)
|
||||||
continue
|
continue
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def get_model_version_info(self, version_id: str) -> Tuple[Optional[Dict], Optional[str]]:
|
async def get_model_version_info(self, version_id: str) -> Tuple[Optional[Dict], Optional[str]]:
|
||||||
for provider in self.providers:
|
for provider, label in self._iter_providers():
|
||||||
try:
|
try:
|
||||||
result, error = await provider.get_model_version_info(version_id)
|
result, error = await self._call_with_rate_limit(
|
||||||
|
label,
|
||||||
|
provider.get_model_version_info,
|
||||||
|
version_id,
|
||||||
|
)
|
||||||
if result:
|
if result:
|
||||||
return result, error
|
return result, error
|
||||||
|
except RateLimitError as exc:
|
||||||
|
exc.provider = exc.provider or label
|
||||||
|
raise exc
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.debug(f"Provider failed for get_model_version_info: {e}")
|
logger.debug("Provider %s failed for get_model_version_info: %s", label, e)
|
||||||
continue
|
continue
|
||||||
return None, "No provider could retrieve the data"
|
return None, "No provider could retrieve the data"
|
||||||
|
|
||||||
async def get_user_models(self, username: str) -> Optional[List[Dict]]:
|
async def get_user_models(self, username: str) -> Optional[List[Dict]]:
|
||||||
for provider in self.providers:
|
for provider, label in self._iter_providers():
|
||||||
try:
|
try:
|
||||||
result = await provider.get_user_models(username)
|
result = await self._call_with_rate_limit(
|
||||||
|
label,
|
||||||
|
provider.get_user_models,
|
||||||
|
username,
|
||||||
|
)
|
||||||
if result is not None:
|
if result is not None:
|
||||||
return result
|
return result
|
||||||
|
except RateLimitError as exc:
|
||||||
|
exc.provider = exc.provider or label
|
||||||
|
raise exc
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.debug(f"Provider failed for get_user_models: {e}")
|
logger.debug("Provider %s failed for get_user_models: %s", label, e)
|
||||||
continue
|
continue
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def _iter_providers(self):
|
||||||
|
return zip(self.providers, self._provider_labels)
|
||||||
|
|
||||||
|
async def _call_with_rate_limit(
|
||||||
|
self,
|
||||||
|
label: str,
|
||||||
|
func,
|
||||||
|
*args,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
attempt = 0
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
return await func(*args, **kwargs)
|
||||||
|
except RateLimitError as exc:
|
||||||
|
attempt += 1
|
||||||
|
if attempt >= self._rate_limit_retry_limit:
|
||||||
|
exc.provider = exc.provider or label
|
||||||
|
raise exc
|
||||||
|
delay = self._calculate_rate_limit_delay(exc.retry_after, attempt)
|
||||||
|
logger.warning(
|
||||||
|
"Provider %s rate limited request; retrying in %.2fs (attempt %s/%s)",
|
||||||
|
label,
|
||||||
|
delay,
|
||||||
|
attempt,
|
||||||
|
self._rate_limit_retry_limit,
|
||||||
|
)
|
||||||
|
await asyncio.sleep(delay)
|
||||||
|
except Exception:
|
||||||
|
raise
|
||||||
|
|
||||||
|
def _calculate_rate_limit_delay(self, retry_after: Optional[float], attempt: int) -> float:
|
||||||
|
if retry_after is not None:
|
||||||
|
return min(self._rate_limit_max_delay, max(0.0, retry_after))
|
||||||
|
|
||||||
|
base_delay = self._rate_limit_base_delay * (2 ** max(0, attempt - 1))
|
||||||
|
jitter_span = base_delay * self._rate_limit_jitter_ratio
|
||||||
|
if jitter_span > 0:
|
||||||
|
base_delay += random.uniform(-jitter_span, jitter_span)
|
||||||
|
|
||||||
|
return min(self._rate_limit_max_delay, max(0.0, base_delay))
|
||||||
|
|
||||||
class ModelMetadataProviderManager:
|
class ModelMetadataProviderManager:
|
||||||
"""Manager for selecting and using model metadata providers"""
|
"""Manager for selecting and using model metadata providers"""
|
||||||
|
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import pytest
|
|||||||
|
|
||||||
from py.services import civitai_client as civitai_client_module
|
from py.services import civitai_client as civitai_client_module
|
||||||
from py.services.civitai_client import CivitaiClient
|
from py.services.civitai_client import CivitaiClient
|
||||||
|
from py.services.errors import RateLimitError
|
||||||
from py.services.model_metadata_provider import ModelMetadataProviderManager
|
from py.services.model_metadata_provider import ModelMetadataProviderManager
|
||||||
|
|
||||||
|
|
||||||
@@ -106,6 +107,21 @@ async def test_get_model_by_hash_handles_not_found(monkeypatch, downloader):
|
|||||||
assert error == "Model not found"
|
assert error == "Model not found"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_get_model_by_hash_propagates_rate_limit(monkeypatch, downloader):
|
||||||
|
async def fake_make_request(method, url, use_auth=True):
|
||||||
|
return False, RateLimitError("limited", retry_after=4)
|
||||||
|
|
||||||
|
downloader.make_request = fake_make_request
|
||||||
|
|
||||||
|
client = await CivitaiClient.get_instance()
|
||||||
|
|
||||||
|
with pytest.raises(RateLimitError) as exc_info:
|
||||||
|
await client.get_model_by_hash("limited")
|
||||||
|
|
||||||
|
assert exc_info.value.retry_after == 4
|
||||||
|
assert exc_info.value.provider == "civitai_api"
|
||||||
|
|
||||||
|
|
||||||
async def test_download_preview_image_writes_file(tmp_path, downloader):
|
async def test_download_preview_image_writes_file(tmp_path, downloader):
|
||||||
client = await CivitaiClient.get_instance()
|
client = await CivitaiClient.get_instance()
|
||||||
target = tmp_path / "preview" / "image.jpg"
|
target = tmp_path / "preview" / "image.jpg"
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ from unittest.mock import AsyncMock
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
from py.services.errors import RateLimitError
|
||||||
from py.services.metadata_sync_service import MetadataSyncService
|
from py.services.metadata_sync_service import MetadataSyncService
|
||||||
|
|
||||||
|
|
||||||
@@ -340,6 +341,37 @@ async def test_fetch_and_update_model_falls_back_to_sqlite_after_civarchive_fail
|
|||||||
helpers.metadata_manager.save_metadata.assert_awaited()
|
helpers.metadata_manager.save_metadata.assert_awaited()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_fetch_and_update_model_returns_rate_limit_error(tmp_path):
|
||||||
|
rate_error = RateLimitError("limited", retry_after=7)
|
||||||
|
default_provider = SimpleNamespace(
|
||||||
|
get_model_by_hash=AsyncMock(side_effect=rate_error),
|
||||||
|
get_model_version=AsyncMock(),
|
||||||
|
)
|
||||||
|
helpers = build_service(default_provider=default_provider)
|
||||||
|
|
||||||
|
model_path = tmp_path / "model.safetensors"
|
||||||
|
model_data = {
|
||||||
|
"file_path": str(model_path),
|
||||||
|
"model_name": "Local",
|
||||||
|
}
|
||||||
|
update_cache = AsyncMock()
|
||||||
|
|
||||||
|
ok, error = await helpers.service.fetch_and_update_model(
|
||||||
|
sha256="deadbeef",
|
||||||
|
file_path=str(model_path),
|
||||||
|
model_data=model_data,
|
||||||
|
update_cache_func=update_cache,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert ok is False
|
||||||
|
assert error is not None and "Rate limited" in error
|
||||||
|
assert "7" in error
|
||||||
|
helpers.metadata_manager.save_metadata.assert_not_awaited()
|
||||||
|
update_cache.assert_not_awaited()
|
||||||
|
helpers.provider_selector.assert_not_awaited()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_relink_metadata_fetches_version_and_updates_sha(tmp_path):
|
async def test_relink_metadata_fetches_version_and_updates_sha(tmp_path):
|
||||||
provider = SimpleNamespace(
|
provider = SimpleNamespace(
|
||||||
|
|||||||
82
tests/services/test_model_metadata_provider.py
Normal file
82
tests/services/test_model_metadata_provider.py
Normal file
@@ -0,0 +1,82 @@
|
|||||||
|
from unittest.mock import AsyncMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from py.services import model_metadata_provider as provider_module
|
||||||
|
from py.services.errors import RateLimitError
|
||||||
|
from py.services.model_metadata_provider import FallbackMetadataProvider
|
||||||
|
|
||||||
|
|
||||||
|
class RateLimitThenSuccessProvider:
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.calls = 0
|
||||||
|
|
||||||
|
async def get_model_by_hash(self, model_hash: str):
|
||||||
|
self.calls += 1
|
||||||
|
if self.calls == 1:
|
||||||
|
raise RateLimitError("limited", retry_after=1.0)
|
||||||
|
return {"id": "ok"}, None
|
||||||
|
|
||||||
|
|
||||||
|
class AlwaysRateLimitedProvider:
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.calls = 0
|
||||||
|
|
||||||
|
async def get_model_by_hash(self, model_hash: str):
|
||||||
|
self.calls += 1
|
||||||
|
raise RateLimitError("limited")
|
||||||
|
|
||||||
|
|
||||||
|
class TrackingProvider:
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.calls = 0
|
||||||
|
|
||||||
|
async def get_model_by_hash(self, model_hash: str):
|
||||||
|
self.calls += 1
|
||||||
|
return {"id": "secondary"}, None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_fallback_retries_same_provider_on_rate_limit(monkeypatch):
|
||||||
|
sleep_mock = AsyncMock()
|
||||||
|
monkeypatch.setattr(provider_module.asyncio, "sleep", sleep_mock)
|
||||||
|
monkeypatch.setattr(provider_module.random, "uniform", lambda *_: 0.0)
|
||||||
|
|
||||||
|
primary = RateLimitThenSuccessProvider()
|
||||||
|
secondary = TrackingProvider()
|
||||||
|
|
||||||
|
fallback = FallbackMetadataProvider(
|
||||||
|
[("primary", primary), ("secondary", secondary)],
|
||||||
|
)
|
||||||
|
|
||||||
|
result, error = await fallback.get_model_by_hash("abc")
|
||||||
|
|
||||||
|
assert error is None
|
||||||
|
assert result == {"id": "ok"}
|
||||||
|
assert primary.calls == 2
|
||||||
|
assert secondary.calls == 0
|
||||||
|
sleep_mock.assert_awaited_once()
|
||||||
|
assert sleep_mock.await_args_list[0].args[0] == pytest.approx(1.0, rel=0.0, abs=1e-6)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_fallback_respects_retry_limit(monkeypatch):
|
||||||
|
sleep_mock = AsyncMock()
|
||||||
|
monkeypatch.setattr(provider_module.asyncio, "sleep", sleep_mock)
|
||||||
|
monkeypatch.setattr(provider_module.random, "uniform", lambda *_: 0.0)
|
||||||
|
|
||||||
|
primary = AlwaysRateLimitedProvider()
|
||||||
|
secondary = TrackingProvider()
|
||||||
|
|
||||||
|
fallback = FallbackMetadataProvider(
|
||||||
|
[("primary", primary), ("secondary", secondary)],
|
||||||
|
rate_limit_retry_limit=2,
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(RateLimitError) as exc_info:
|
||||||
|
await fallback.get_model_by_hash("abc")
|
||||||
|
|
||||||
|
assert exc_info.value.provider == "primary"
|
||||||
|
assert primary.calls == 2
|
||||||
|
assert secondary.calls == 0
|
||||||
|
sleep_mock.assert_awaited_once()
|
||||||
Reference in New Issue
Block a user