diff --git a/py/services/civitai_client.py b/py/services/civitai_client.py index e860994b..5912dc07 100644 --- a/py/services/civitai_client.py +++ b/py/services/civitai_client.py @@ -5,6 +5,7 @@ import os from typing import Optional, Dict, Tuple, List from .model_metadata_provider import CivitaiModelMetadataProvider, ModelMetadataProviderManager from .downloader import get_downloader +from .errors import RateLimitError logger = logging.getLogger(__name__) @@ -33,6 +34,29 @@ class CivitaiClient: self.base_url = "https://civitai.com/api/v1" + async def _make_request( + self, + method: str, + url: str, + *, + use_auth: bool = False, + **kwargs, + ) -> Tuple[bool, Dict | str]: + """Wrapper around downloader.make_request that surfaces rate limits.""" + + downloader = await get_downloader() + success, result = await downloader.make_request( + method, + url, + use_auth=use_auth, + **kwargs, + ) + if not success and isinstance(result, RateLimitError): + if result.provider is None: + result.provider = "civitai_api" + raise result + return success, result + @staticmethod def _remove_comfy_metadata(model_version: Optional[Dict]) -> None: """Remove Comfy-specific metadata from model version images.""" @@ -79,8 +103,7 @@ class CivitaiClient: async def get_model_by_hash(self, model_hash: str) -> Tuple[Optional[Dict], Optional[str]]: try: - downloader = await get_downloader() - success, result = await downloader.make_request( + success, result = await self._make_request( 'GET', f"{self.base_url}/model-versions/by-hash/{model_hash}", use_auth=True @@ -90,7 +113,7 @@ class CivitaiClient: model_id = result.get('modelId') if model_id: # Fetch additional model metadata - success_model, data = await downloader.make_request( + success_model, data = await self._make_request( 'GET', f"{self.base_url}/models/{model_id}", use_auth=True @@ -113,6 +136,8 @@ class CivitaiClient: # Other error cases logger.error(f"Failed to fetch model info for {model_hash[:10]}: {result}") return None, str(result) + except RateLimitError: + raise except Exception as e: logger.error(f"API Error: {str(e)}") return None, str(e) @@ -138,8 +163,7 @@ class CivitaiClient: async def get_model_versions(self, model_id: str) -> List[Dict]: """Get all versions of a model with local availability info""" try: - downloader = await get_downloader() - success, result = await downloader.make_request( + success, result = await self._make_request( 'GET', f"{self.base_url}/models/{model_id}", use_auth=True @@ -152,6 +176,8 @@ class CivitaiClient: 'name': result.get('name', '') } return None + except RateLimitError: + raise except Exception as e: logger.error(f"Error fetching model versions: {e}") return None @@ -159,23 +185,23 @@ class CivitaiClient: async def get_model_version(self, model_id: int = None, version_id: int = None) -> Optional[Dict]: """Get specific model version with additional metadata.""" try: - downloader = await get_downloader() - if model_id is None and version_id is not None: - return await self._get_version_by_id_only(downloader, version_id) + return await self._get_version_by_id_only(version_id) if model_id is not None: - return await self._get_version_with_model_id(downloader, model_id, version_id) + return await self._get_version_with_model_id(model_id, version_id) logger.error("Either model_id or version_id must be provided") return None + except RateLimitError: + raise except Exception as e: logger.error(f"Error fetching model version: {e}") return None - async def _get_version_by_id_only(self, downloader, version_id: int) -> Optional[Dict]: - version = await self._fetch_version_by_id(downloader, version_id) + async def _get_version_by_id_only(self, version_id: int) -> Optional[Dict]: + version = await self._fetch_version_by_id(version_id) if version is None: return None @@ -184,15 +210,15 @@ class CivitaiClient: logger.error(f"No modelId found in version {version_id}") return None - model_data = await self._fetch_model_data(downloader, model_id) + model_data = await self._fetch_model_data(model_id) if model_data: self._enrich_version_with_model_data(version, model_data) self._remove_comfy_metadata(version) return version - async def _get_version_with_model_id(self, downloader, model_id: int, version_id: Optional[int]) -> Optional[Dict]: - model_data = await self._fetch_model_data(downloader, model_id) + async def _get_version_with_model_id(self, model_id: int, version_id: Optional[int]) -> Optional[Dict]: + model_data = await self._fetch_model_data(model_id) if not model_data: return None @@ -201,12 +227,12 @@ class CivitaiClient: return None target_version_id = target_version.get('id') - version = await self._fetch_version_by_id(downloader, target_version_id) if target_version_id else None + version = await self._fetch_version_by_id(target_version_id) if target_version_id else None if version is None: model_hash = self._extract_primary_model_hash(target_version) if model_hash: - version = await self._fetch_version_by_hash(downloader, model_hash) + version = await self._fetch_version_by_hash(model_hash) else: logger.warning( f"No primary model hash found for model {model_id} version {target_version_id}" @@ -219,8 +245,8 @@ class CivitaiClient: self._remove_comfy_metadata(version) return version - async def _fetch_model_data(self, downloader, model_id: int) -> Optional[Dict]: - success, data = await downloader.make_request( + async def _fetch_model_data(self, model_id: int) -> Optional[Dict]: + success, data = await self._make_request( 'GET', f"{self.base_url}/models/{model_id}", use_auth=True @@ -230,11 +256,11 @@ class CivitaiClient: logger.warning(f"Failed to fetch model data for model {model_id}") return None - async def _fetch_version_by_id(self, downloader, version_id: Optional[int]) -> Optional[Dict]: + async def _fetch_version_by_id(self, version_id: Optional[int]) -> Optional[Dict]: if version_id is None: return None - success, version = await downloader.make_request( + success, version = await self._make_request( 'GET', f"{self.base_url}/model-versions/{version_id}", use_auth=True @@ -245,11 +271,11 @@ class CivitaiClient: logger.warning(f"Failed to fetch version by id {version_id}") return None - async def _fetch_version_by_hash(self, downloader, model_hash: Optional[str]) -> Optional[Dict]: + async def _fetch_version_by_hash(self, model_hash: Optional[str]) -> Optional[Dict]: if not model_hash: return None - success, version = await downloader.make_request( + success, version = await self._make_request( 'GET', f"{self.base_url}/model-versions/by-hash/{model_hash}", use_auth=True @@ -323,11 +349,10 @@ class CivitaiClient: - An error message if there was an error, or None on success """ try: - downloader = await get_downloader() url = f"{self.base_url}/model-versions/{version_id}" logger.debug(f"Resolving DNS for model version info: {url}") - success, result = await downloader.make_request( + success, result = await self._make_request( 'GET', url, use_auth=True @@ -347,6 +372,8 @@ class CivitaiClient: # Other error cases logger.error(f"Failed to fetch model info for {version_id}: {result}") return None, str(result) + except RateLimitError: + raise except Exception as e: error_msg = f"Error fetching model version info: {e}" logger.error(error_msg) @@ -362,11 +389,10 @@ class CivitaiClient: Optional[Dict]: The image data or None if not found """ try: - downloader = await get_downloader() url = f"{self.base_url}/images?imageId={image_id}&nsfw=X" logger.debug(f"Fetching image info for ID: {image_id}") - success, result = await downloader.make_request( + success, result = await self._make_request( 'GET', url, use_auth=True @@ -381,6 +407,8 @@ class CivitaiClient: logger.error(f"Failed to fetch image info for ID: {image_id}: {result}") return None + except RateLimitError: + raise except Exception as e: error_msg = f"Error fetching image info: {e}" logger.error(error_msg) @@ -392,9 +420,8 @@ class CivitaiClient: return None try: - downloader = await get_downloader() url = f"{self.base_url}/models?username={username}" - success, result = await downloader.make_request( + success, result = await self._make_request( 'GET', url, use_auth=True @@ -416,6 +443,8 @@ class CivitaiClient: self._remove_comfy_metadata(version) return items + except RateLimitError: + raise except Exception as exc: # pragma: no cover - defensive logging logger.error("Error fetching models for %s: %s", username, exc) return None diff --git a/py/services/downloader.py b/py/services/downloader.py index dafef78e..775ba5d8 100644 --- a/py/services/downloader.py +++ b/py/services/downloader.py @@ -17,8 +17,10 @@ import aiohttp from collections import deque from dataclasses import dataclass from datetime import datetime, timedelta +from email.utils import parsedate_to_datetime from typing import Optional, Dict, Tuple, Callable, Union, Awaitable from ..services.settings_manager import get_settings_manager +from .errors import RateLimitError logger = logging.getLogger(__name__) @@ -587,6 +589,19 @@ class Downloader: return False, "Access forbidden" elif response.status == 404: return False, "Resource not found" + elif response.status == 429: + retry_after = self._extract_retry_after(response.headers) + error_msg = "Request rate limited" + logger.warning( + "Rate limit encountered for %s %s; retry_after=%s", + method, + url, + retry_after, + ) + return False, RateLimitError( + error_msg, + retry_after=retry_after, + ) else: return False, f"Request failed with status {response.status}" @@ -608,6 +623,38 @@ class Downloader: await self._create_session() logger.info("HTTP session refreshed due to settings change") + @staticmethod + def _extract_retry_after(headers) -> Optional[float]: + """Parse the Retry-After header into seconds.""" + if not headers: + return None + + header_value = headers.get("Retry-After") + if not header_value: + return None + + header_value = header_value.strip() + if not header_value: + return None + + if header_value.isdigit(): + try: + seconds = float(header_value) + except ValueError: + return None + return max(0.0, seconds) + + try: + retry_datetime = parsedate_to_datetime(header_value) + except (TypeError, ValueError): + return None + + if retry_datetime.tzinfo is None: + return None + + delta = retry_datetime - datetime.now(tz=retry_datetime.tzinfo) + return max(0.0, delta.total_seconds()) + # Global instance accessor async def get_downloader() -> Downloader: diff --git a/py/services/errors.py b/py/services/errors.py new file mode 100644 index 00000000..54478381 --- /dev/null +++ b/py/services/errors.py @@ -0,0 +1,21 @@ +"""Common service-level exception types.""" + +from __future__ import annotations + +from typing import Optional + + +class RateLimitError(RuntimeError): + """Raised when a remote provider rejects a request due to rate limiting.""" + + def __init__( + self, + message: str, + *, + retry_after: Optional[float] = None, + provider: Optional[str] = None, + ) -> None: + super().__init__(message) + self.retry_after = retry_after + self.provider = provider + diff --git a/py/services/metadata_service.py b/py/services/metadata_service.py index 730e45b0..302e6551 100644 --- a/py/services/metadata_service.py +++ b/py/services/metadata_service.py @@ -1,6 +1,7 @@ import os import logging from .model_metadata_provider import ( + ModelMetadataProvider, ModelMetadataProviderManager, SQLiteModelMetadataProvider, CivitaiModelMetadataProvider, @@ -68,10 +69,10 @@ async def initialize_metadata_providers(): # Set up fallback provider based on available providers if len(providers) > 1: # Always use Civitai API (it has better metadata), then CivArchive API, then Archive DB - ordered_providers = [] - ordered_providers.extend([p[1] for p in providers if p[0] == 'civitai_api']) - ordered_providers.extend([p[1] for p in providers if p[0] == 'civarchive_api']) - ordered_providers.extend([p[1] for p in providers if p[0] == 'sqlite']) + ordered_providers: list[tuple[str, ModelMetadataProvider]] = [] + ordered_providers.extend([p for p in providers if p[0] == 'civitai_api']) + ordered_providers.extend([p for p in providers if p[0] == 'civarchive_api']) + ordered_providers.extend([p for p in providers if p[0] == 'sqlite']) if ordered_providers: fallback_provider = FallbackMetadataProvider(ordered_providers) diff --git a/py/services/metadata_sync_service.py b/py/services/metadata_sync_service.py index 485ddc30..b9b192ff 100644 --- a/py/services/metadata_sync_service.py +++ b/py/services/metadata_sync_service.py @@ -10,6 +10,7 @@ from typing import Any, Awaitable, Callable, Dict, Iterable, Optional from ..services.settings_manager import SettingsManager from ..utils.model_utils import determine_base_model +from .errors import RateLimitError logger = logging.getLogger(__name__) @@ -205,6 +206,9 @@ class MetadataSyncService: for provider_name, provider in provider_attempts: try: civitai_metadata_candidate, error = await provider.get_model_by_hash(sha256) + except RateLimitError as exc: + exc.provider = exc.provider or (provider_name or provider.__class__.__name__) + raise except Exception as exc: # pragma: no cover - defensive logging logger.error("Provider %s failed for hash %s: %s", provider_name, sha256, exc) civitai_metadata_candidate, error = None, str(exc) @@ -299,6 +303,16 @@ class MetadataSyncService: error_msg = f"Error fetching metadata - Missing key: {exc} in model_data={model_data}" logger.error(error_msg) return False, error_msg + except RateLimitError as exc: + provider_label = exc.provider or "metadata provider" + wait_hint = ( + f"; retry after approximately {int(exc.retry_after)}s" + if exc.retry_after and exc.retry_after > 0 + else "" + ) + error_msg = f"Rate limited by {provider_label}{wait_hint}" + logger.warning(error_msg) + return False, error_msg except Exception as exc: # pragma: no cover - error path error_msg = f"Error fetching metadata: {exc}" logger.error(error_msg, exc_info=True) diff --git a/py/services/model_metadata_provider.py b/py/services/model_metadata_provider.py index 73d9b7d8..2c580c59 100644 --- a/py/services/model_metadata_provider.py +++ b/py/services/model_metadata_provider.py @@ -1,8 +1,11 @@ from abc import ABC, abstractmethod +import asyncio import json import logging -from typing import Optional, Dict, Tuple, Any, List +import random +from typing import Optional, Dict, Tuple, Any, List, Sequence from .downloader import get_downloader +from .errors import RateLimitError try: from bs4 import BeautifulSoup @@ -350,64 +353,166 @@ class SQLiteModelMetadataProvider(ModelMetadataProvider): class FallbackMetadataProvider(ModelMetadataProvider): """Try providers in order, return first successful result.""" - def __init__(self, providers: list): - self.providers = providers + + def __init__( + self, + providers: Sequence[ModelMetadataProvider | Tuple[str, ModelMetadataProvider]], + *, + rate_limit_retry_limit: int = 3, + rate_limit_base_delay: float = 1.5, + rate_limit_max_delay: float = 30.0, + rate_limit_jitter_ratio: float = 0.2, + ) -> None: + self.providers: List[ModelMetadataProvider] = [] + self._provider_labels: List[str] = [] + + for entry in providers: + if isinstance(entry, tuple) and len(entry) == 2: + name, provider = entry + else: + provider = entry + name = provider.__class__.__name__ + self.providers.append(provider) + self._provider_labels.append(str(name)) + + self._rate_limit_retry_limit = max(1, rate_limit_retry_limit) + self._rate_limit_base_delay = rate_limit_base_delay + self._rate_limit_max_delay = rate_limit_max_delay + self._rate_limit_jitter_ratio = max(0.0, rate_limit_jitter_ratio) async def get_model_by_hash(self, model_hash: str) -> Tuple[Optional[Dict], Optional[str]]: - for provider in self.providers: + for provider, label in self._iter_providers(): try: - result, error = await provider.get_model_by_hash(model_hash) + result, error = await self._call_with_rate_limit( + label, + provider.get_model_by_hash, + model_hash, + ) if result: return result, error + except RateLimitError as exc: + exc.provider = exc.provider or label + raise exc except Exception as e: - logger.debug(f"Provider failed for get_model_by_hash: {e}") + logger.debug("Provider %s failed for get_model_by_hash: %s", label, e) continue return None, "Model not found" async def get_model_versions(self, model_id: str) -> Optional[Dict]: - for provider in self.providers: + for provider, label in self._iter_providers(): try: - result = await provider.get_model_versions(model_id) + result = await self._call_with_rate_limit( + label, + provider.get_model_versions, + model_id, + ) if result: return result + except RateLimitError as exc: + exc.provider = exc.provider or label + raise exc except Exception as e: - logger.debug(f"Provider failed for get_model_versions: {e}") + logger.debug("Provider %s failed for get_model_versions: %s", label, e) continue return None async def get_model_version(self, model_id: int = None, version_id: int = None) -> Optional[Dict]: - for provider in self.providers: + for provider, label in self._iter_providers(): try: - result = await provider.get_model_version(model_id, version_id) + result = await self._call_with_rate_limit( + label, + provider.get_model_version, + model_id, + version_id, + ) if result: return result + except RateLimitError as exc: + exc.provider = exc.provider or label + raise exc except Exception as e: - logger.debug(f"Provider failed for get_model_version: {e}") + logger.debug("Provider %s failed for get_model_version: %s", label, e) continue return None async def get_model_version_info(self, version_id: str) -> Tuple[Optional[Dict], Optional[str]]: - for provider in self.providers: + for provider, label in self._iter_providers(): try: - result, error = await provider.get_model_version_info(version_id) + result, error = await self._call_with_rate_limit( + label, + provider.get_model_version_info, + version_id, + ) if result: return result, error + except RateLimitError as exc: + exc.provider = exc.provider or label + raise exc except Exception as e: - logger.debug(f"Provider failed for get_model_version_info: {e}") + logger.debug("Provider %s failed for get_model_version_info: %s", label, e) continue return None, "No provider could retrieve the data" async def get_user_models(self, username: str) -> Optional[List[Dict]]: - for provider in self.providers: + for provider, label in self._iter_providers(): try: - result = await provider.get_user_models(username) + result = await self._call_with_rate_limit( + label, + provider.get_user_models, + username, + ) if result is not None: return result + except RateLimitError as exc: + exc.provider = exc.provider or label + raise exc except Exception as e: - logger.debug(f"Provider failed for get_user_models: {e}") + logger.debug("Provider %s failed for get_user_models: %s", label, e) continue return None + def _iter_providers(self): + return zip(self.providers, self._provider_labels) + + async def _call_with_rate_limit( + self, + label: str, + func, + *args, + **kwargs, + ): + attempt = 0 + while True: + try: + return await func(*args, **kwargs) + except RateLimitError as exc: + attempt += 1 + if attempt >= self._rate_limit_retry_limit: + exc.provider = exc.provider or label + raise exc + delay = self._calculate_rate_limit_delay(exc.retry_after, attempt) + logger.warning( + "Provider %s rate limited request; retrying in %.2fs (attempt %s/%s)", + label, + delay, + attempt, + self._rate_limit_retry_limit, + ) + await asyncio.sleep(delay) + except Exception: + raise + + def _calculate_rate_limit_delay(self, retry_after: Optional[float], attempt: int) -> float: + if retry_after is not None: + return min(self._rate_limit_max_delay, max(0.0, retry_after)) + + base_delay = self._rate_limit_base_delay * (2 ** max(0, attempt - 1)) + jitter_span = base_delay * self._rate_limit_jitter_ratio + if jitter_span > 0: + base_delay += random.uniform(-jitter_span, jitter_span) + + return min(self._rate_limit_max_delay, max(0.0, base_delay)) + class ModelMetadataProviderManager: """Manager for selecting and using model metadata providers""" diff --git a/tests/services/test_civitai_client.py b/tests/services/test_civitai_client.py index aaf0a0a9..2421a07d 100644 --- a/tests/services/test_civitai_client.py +++ b/tests/services/test_civitai_client.py @@ -5,6 +5,7 @@ import pytest from py.services import civitai_client as civitai_client_module from py.services.civitai_client import CivitaiClient +from py.services.errors import RateLimitError from py.services.model_metadata_provider import ModelMetadataProviderManager @@ -106,6 +107,21 @@ async def test_get_model_by_hash_handles_not_found(monkeypatch, downloader): assert error == "Model not found" +async def test_get_model_by_hash_propagates_rate_limit(monkeypatch, downloader): + async def fake_make_request(method, url, use_auth=True): + return False, RateLimitError("limited", retry_after=4) + + downloader.make_request = fake_make_request + + client = await CivitaiClient.get_instance() + + with pytest.raises(RateLimitError) as exc_info: + await client.get_model_by_hash("limited") + + assert exc_info.value.retry_after == 4 + assert exc_info.value.provider == "civitai_api" + + async def test_download_preview_image_writes_file(tmp_path, downloader): client = await CivitaiClient.get_instance() target = tmp_path / "preview" / "image.jpg" diff --git a/tests/services/test_metadata_sync_service.py b/tests/services/test_metadata_sync_service.py index cd3ade58..09d56f5a 100644 --- a/tests/services/test_metadata_sync_service.py +++ b/tests/services/test_metadata_sync_service.py @@ -4,6 +4,7 @@ from unittest.mock import AsyncMock import pytest +from py.services.errors import RateLimitError from py.services.metadata_sync_service import MetadataSyncService @@ -340,6 +341,37 @@ async def test_fetch_and_update_model_falls_back_to_sqlite_after_civarchive_fail helpers.metadata_manager.save_metadata.assert_awaited() +@pytest.mark.asyncio +async def test_fetch_and_update_model_returns_rate_limit_error(tmp_path): + rate_error = RateLimitError("limited", retry_after=7) + default_provider = SimpleNamespace( + get_model_by_hash=AsyncMock(side_effect=rate_error), + get_model_version=AsyncMock(), + ) + helpers = build_service(default_provider=default_provider) + + model_path = tmp_path / "model.safetensors" + model_data = { + "file_path": str(model_path), + "model_name": "Local", + } + update_cache = AsyncMock() + + ok, error = await helpers.service.fetch_and_update_model( + sha256="deadbeef", + file_path=str(model_path), + model_data=model_data, + update_cache_func=update_cache, + ) + + assert ok is False + assert error is not None and "Rate limited" in error + assert "7" in error + helpers.metadata_manager.save_metadata.assert_not_awaited() + update_cache.assert_not_awaited() + helpers.provider_selector.assert_not_awaited() + + @pytest.mark.asyncio async def test_relink_metadata_fetches_version_and_updates_sha(tmp_path): provider = SimpleNamespace( diff --git a/tests/services/test_model_metadata_provider.py b/tests/services/test_model_metadata_provider.py new file mode 100644 index 00000000..cb1761de --- /dev/null +++ b/tests/services/test_model_metadata_provider.py @@ -0,0 +1,82 @@ +from unittest.mock import AsyncMock + +import pytest + +from py.services import model_metadata_provider as provider_module +from py.services.errors import RateLimitError +from py.services.model_metadata_provider import FallbackMetadataProvider + + +class RateLimitThenSuccessProvider: + def __init__(self) -> None: + self.calls = 0 + + async def get_model_by_hash(self, model_hash: str): + self.calls += 1 + if self.calls == 1: + raise RateLimitError("limited", retry_after=1.0) + return {"id": "ok"}, None + + +class AlwaysRateLimitedProvider: + def __init__(self) -> None: + self.calls = 0 + + async def get_model_by_hash(self, model_hash: str): + self.calls += 1 + raise RateLimitError("limited") + + +class TrackingProvider: + def __init__(self) -> None: + self.calls = 0 + + async def get_model_by_hash(self, model_hash: str): + self.calls += 1 + return {"id": "secondary"}, None + + +@pytest.mark.asyncio +async def test_fallback_retries_same_provider_on_rate_limit(monkeypatch): + sleep_mock = AsyncMock() + monkeypatch.setattr(provider_module.asyncio, "sleep", sleep_mock) + monkeypatch.setattr(provider_module.random, "uniform", lambda *_: 0.0) + + primary = RateLimitThenSuccessProvider() + secondary = TrackingProvider() + + fallback = FallbackMetadataProvider( + [("primary", primary), ("secondary", secondary)], + ) + + result, error = await fallback.get_model_by_hash("abc") + + assert error is None + assert result == {"id": "ok"} + assert primary.calls == 2 + assert secondary.calls == 0 + sleep_mock.assert_awaited_once() + assert sleep_mock.await_args_list[0].args[0] == pytest.approx(1.0, rel=0.0, abs=1e-6) + + +@pytest.mark.asyncio +async def test_fallback_respects_retry_limit(monkeypatch): + sleep_mock = AsyncMock() + monkeypatch.setattr(provider_module.asyncio, "sleep", sleep_mock) + monkeypatch.setattr(provider_module.random, "uniform", lambda *_: 0.0) + + primary = AlwaysRateLimitedProvider() + secondary = TrackingProvider() + + fallback = FallbackMetadataProvider( + [("primary", primary), ("secondary", secondary)], + rate_limit_retry_limit=2, + ) + + with pytest.raises(RateLimitError) as exc_info: + await fallback.get_model_by_hash("abc") + + assert exc_info.value.provider == "primary" + assert primary.calls == 2 + assert secondary.calls == 0 + sleep_mock.assert_awaited_once()