diff --git a/py/services/civitai_client.py b/py/services/civitai_client.py index 56a9adcb..b4c9561f 100644 --- a/py/services/civitai_client.py +++ b/py/services/civitai_client.py @@ -2,6 +2,7 @@ import asyncio import copy import logging import os +from collections import OrderedDict from typing import Any, Optional, Dict, Tuple, List, Sequence from .connectivity_guard import ( OFFLINE_FRIENDLY_MESSAGE, @@ -45,6 +46,14 @@ class CivitaiClient: self._initialized = True self.base_url = "https://civitai.red/api/v1" + # In-memory cache to avoid redundant get_model_version_info calls + # within the same import/scan flow. Only successful results are cached. + # Uses OrderedDict with LRU eviction at MAX_CACHE_ENTRIES to prevent + # unbounded growth in long-running server processes. + self._version_info_cache: OrderedDict[ + str, Tuple[Optional[Dict], Optional[str]] + ] = OrderedDict() + self._MAX_CACHE_ENTRIES = 500 def _build_image_info_url(self, image_id: str) -> str: return f"{self.base_url}/images?imageId={image_id}&nsfw=X" @@ -57,22 +66,57 @@ class CivitaiClient: use_auth: bool = False, **kwargs, ) -> Tuple[bool, Dict | str]: - """Wrapper around downloader.make_request that surfaces rate limits.""" + """Wrapper around downloader.make_request that surfaces rate limits, + with retry for transient server errors (5xx, Cloudflare 524, network flakiness).""" - downloader = await get_downloader() - success, result = await downloader.make_request( - method, - url, - use_auth=use_auth, - **kwargs, - ) - if not success and isinstance(result, RateLimitError): - if result.provider is None: - result.provider = "civitai_api" - raise result - if not success and is_offline_cooldown_error(result): - return False, OFFLINE_FRIENDLY_MESSAGE - return success, result + max_retries = 3 + for attempt in range(max_retries): + downloader = await get_downloader() + success, result = await downloader.make_request( + method, + url, + use_auth=use_auth, + **kwargs, + ) + if success: + return True, 
result + + if isinstance(result, RateLimitError): + if result.provider is None: + result.provider = "civitai_api" + raise result + + if is_offline_cooldown_error(result): + return False, OFFLINE_FRIENDLY_MESSAGE + + # Transient server error — retry with exponential backoff + if self._is_transient_server_error(str(result)): + if attempt < max_retries - 1: + wait = 2**attempt # 1s, then 2s; the final attempt is not followed by a wait + logger.info( + "Transient error on %s %s, retrying in %ds " + "(attempt %d/%d): %s", + method, + url, + wait, + attempt + 1, + max_retries, + result, + ) + await asyncio.sleep(wait) + continue + logger.warning( + "All %d retries exhausted for %s %s: %s", + max_retries, + method, + url, + result, + ) + return False, result + + return False, result + + return False, "Unexpected error in _make_request" @staticmethod def _remove_comfy_metadata(model_version: Optional[Dict]) -> None: @@ -512,6 +556,14 @@ class CivitaiClient: - The model version data or None if not found - An error message if there was an error, or None on success """ + # In-memory cache avoids redundant API calls within the same + # import/scan flow (e.g. _resolve_base_model_from_checkpoint + # followed by _resolve_and_populate_checkpoint with the same id). + if version_id in self._version_info_cache: + logger.debug("Cache hit for model version info: %s", version_id) + self._version_info_cache.move_to_end(version_id) # LRU bump + return self._version_info_cache[version_id] + try: url = f"{self.base_url}/model-versions/{version_id}" @@ -521,6 +573,11 @@ class CivitaiClient: if success: logger.debug("Successfully fetched model version info for: %s", version_id) self._remove_comfy_metadata(result) + self._version_info_cache[version_id] = (result, None) + self._version_info_cache.move_to_end(version_id) + # Evict oldest entry when over capacity + if len(self._version_info_cache) > self._MAX_CACHE_ENTRIES: + self._version_info_cache.popitem(last=False) return result, None # Handle specific error cases