diff --git a/py/routes/handlers/model_handlers.py b/py/routes/handlers/model_handlers.py index 18bb5e1b..efc16001 100644 --- a/py/routes/handlers/model_handlers.py +++ b/py/routes/handlers/model_handlers.py @@ -16,6 +16,10 @@ import jinja2 from ...config import config from ...services.download_coordinator import DownloadCoordinator +from ...services.connectivity_guard import ( + OFFLINE_FRIENDLY_MESSAGE, + is_expected_offline_error, +) from ...services.metadata_sync_service import MetadataSyncService from ...services.model_file_service import ModelMoveService from ...services.preview_asset_service import PreviewAssetService @@ -504,6 +508,11 @@ class ModelManagementHandler: formatted_metadata = await self._service.format_response(model_data) return web.json_response({"success": True, "metadata": formatted_metadata}) except Exception as exc: + if is_expected_offline_error(str(exc)): + return web.json_response( + {"success": False, "error": OFFLINE_FRIENDLY_MESSAGE}, + status=503, + ) self._logger.error("Error fetching from CivitAI: %s", exc, exc_info=True) return web.json_response({"success": False, "error": str(exc)}, status=500) @@ -550,6 +559,11 @@ class ModelManagementHandler: } ) except Exception as exc: + if is_expected_offline_error(str(exc)): + return web.json_response( + {"success": False, "error": OFFLINE_FRIENDLY_MESSAGE}, + status=503, + ) self._logger.error("Error re-linking to CivitAI: %s", exc, exc_info=True) return web.json_response({"success": False, "error": str(exc)}, status=500) @@ -1858,6 +1872,11 @@ class ModelUpdateHandler: status=429, ) except Exception as exc: # pragma: no cover - defensive log + if is_expected_offline_error(str(exc)): + return web.json_response( + {"success": False, "error": OFFLINE_FRIENDLY_MESSAGE}, + status=503, + ) self._logger.error("Failed to fetch license info: %s", exc, exc_info=True) return web.json_response({"success": False, "error": str(exc)}, status=500) @@ -1946,9 +1965,12 @@ class ModelUpdateHandler: {"success": False, "error": str(exc) or "Rate limited"}, status=429 ) except Exception as exc: # pragma: no cover - defensive logging - self._logger.error( - "Failed to refresh model updates: %s", exc, exc_info=True - ) + if is_expected_offline_error(str(exc)): + return web.json_response( + {"success": False, "error": OFFLINE_FRIENDLY_MESSAGE}, + status=503, + ) + self._logger.error("Failed to refresh model updates: %s", exc, exc_info=True) return web.json_response({"success": False, "error": str(exc)}, status=500) serialized_records = [] diff --git a/py/services/civitai_client.py b/py/services/civitai_client.py index 466eaa2b..6b4ef91d 100644 --- a/py/services/civitai_client.py +++ b/py/services/civitai_client.py @@ -3,6 +3,11 @@ import copy import logging import os from typing import Any, Optional, Dict, Tuple, List, Sequence +from .connectivity_guard import ( + OFFLINE_FRIENDLY_MESSAGE, + is_expected_offline_error, + is_offline_cooldown_error, +) from .model_metadata_provider import ( CivitaiModelMetadataProvider, ModelMetadataProviderManager, @@ -65,6 +70,8 @@ class CivitaiClient: if result.provider is None: result.provider = "civitai_api" raise result + if not success and is_offline_cooldown_error(result): + return False, OFFLINE_FRIENDLY_MESSAGE return success, result @staticmethod @@ -124,6 +131,8 @@ class CivitaiClient: ) if not success: message = str(version) + if is_expected_offline_error(message): + return None, OFFLINE_FRIENDLY_MESSAGE if "not found" in message.lower(): return None, "Model not found" @@ -164,6 +173,9 @@ class CivitaiClient: return True return False except Exception as e: + if is_expected_offline_error(str(e)): + logger.debug("Preview download skipped due to offline state.") + return False logger.error(f"Download Error: {str(e)}") return False @@ -207,6 +219,9 @@ class CivitaiClient: message = self._extract_error_message(result) if message and "not found" in message.lower(): raise ResourceNotFoundError(f"Resource not found for model {model_id}") + if is_expected_offline_error(message): + logger.info("Civitai request skipped: %s", OFFLINE_FRIENDLY_MESSAGE) + return None if message: raise RuntimeError(message) return None @@ -357,6 +372,8 @@ class CivitaiClient: ) if success: return data + if is_expected_offline_error(data): + return None logger.warning(f"Failed to fetch model data for model {model_id}") return None @@ -371,6 +388,8 @@ class CivitaiClient: ) if success: return version + if is_expected_offline_error(version): + return None logger.warning(f"Failed to fetch version by id {version_id}") return None @@ -386,6 +405,8 @@ class CivitaiClient: ) if success: return version + if is_expected_offline_error(version): + return None logger.warning(f"Failed to fetch version by hash {model_hash}") return None @@ -473,6 +494,8 @@ class CivitaiClient: return result, None # Handle specific error cases + if is_expected_offline_error(result): + return None, OFFLINE_FRIENDLY_MESSAGE if "not found" in str(result): error_msg = f"Model not found" logger.warning(f"Model version not found: {version_id} - {error_msg}") @@ -507,6 +530,8 @@ class CivitaiClient: success, result = await self._make_request("GET", url, use_auth=True) if not success: + if is_expected_offline_error(result): + return None logger.error( "Failed to fetch image info for ID %s from civitai.red: %s", image_id, @@ -566,6 +591,9 @@ class CivitaiClient: ) if not success: + if is_expected_offline_error(result): + logger.info("User model fetch skipped: %s", OFFLINE_FRIENDLY_MESSAGE) + return None logger.error("Failed to fetch models for %s: %s", username, result) return None diff --git a/py/services/connectivity_guard.py b/py/services/connectivity_guard.py new file mode 100644 index 00000000..1f60d5df --- /dev/null +++ b/py/services/connectivity_guard.py @@ -0,0 +1,204 @@ +"""In-memory connectivity guard to suppress repeated network retries when offline.""" + +from __future__ import annotations + +import asyncio +import errno +import logging +import socket +from dataclasses import dataclass +from datetime import datetime, timedelta +from typing import Any + +import aiohttp + +logger = logging.getLogger(__name__) + +OFFLINE_COOLDOWN_ERROR = "offline_cooldown" +OFFLINE_FRIENDLY_MESSAGE = "Network offline, will retry automatically later" + + +def is_offline_cooldown_error(value: Any) -> bool: + """Return True when a response payload represents guard short-circuit.""" + return isinstance(value, str) and value == OFFLINE_COOLDOWN_ERROR + + +def is_expected_offline_error(value: Any) -> bool: + """Return True when payload is an expected offline-related result.""" + if is_offline_cooldown_error(value): + return True + if not isinstance(value, str): + return False + normalized = value.lower() + return "network offline" in normalized or "offline" in normalized + + +class ConnectivityGuard: + """Tracks network failures and gates outbound requests during cooldown.""" + + _instance: "ConnectivityGuard | None" = None + _instance_lock = asyncio.Lock() + + @classmethod + async def get_instance(cls) -> "ConnectivityGuard": + async with cls._instance_lock: + if cls._instance is None: + cls._instance = cls() + return cls._instance + + def __init__(self) -> None: + if hasattr(self, "_initialized"): + return + self._initialized = True + self._default_destination = "__global__" + self._destination_states: dict[str, _DestinationState] = { + self._default_destination: _DestinationState() + } + self.base_backoff_seconds = 30 + self.max_backoff_seconds = 300 + self.failure_threshold = 3 + + @property + def online(self) -> bool: + return self._state_for_destination(None).online + + @online.setter + def online(self, value: bool) -> None: + self._state_for_destination(None).online = value + + @property + def failure_count(self) -> int: + return self._state_for_destination(None).failure_count + + @failure_count.setter + def failure_count(self, value: int) -> None: + self._state_for_destination(None).failure_count = value + + @property + def cooldown_until(self) -> datetime | None: + return self._state_for_destination(None).cooldown_until + + @cooldown_until.setter + def cooldown_until(self, value: datetime | None) -> None: + self._state_for_destination(None).cooldown_until = value + + def _now(self) -> datetime: + return datetime.now() + + def _normalize_destination(self, destination: str | None) -> str: + if destination is None or not destination.strip(): + return self._default_destination + return destination.lower().strip() + + def _state_for_destination(self, destination: str | None) -> "_DestinationState": + destination_key = self._normalize_destination(destination) + if destination_key not in self._destination_states: + self._destination_states[destination_key] = _DestinationState() + return self._destination_states[destination_key] + + def in_cooldown(self, destination: str | None = None) -> bool: + state = self._state_for_destination(destination) + if state.cooldown_until is None: + return False + return self._now() < state.cooldown_until + + def cooldown_remaining_seconds(self, destination: str | None = None) -> float: + state = self._state_for_destination(destination) + if state.cooldown_until is None: + return 0.0 + return max(0.0, (state.cooldown_until - self._now()).total_seconds()) + + def should_block_request(self, destination: str | None = None) -> bool: + return self.in_cooldown(destination) + + def register_success(self, destination: str | None = None) -> None: + destination_key = self._normalize_destination(destination) + state = self._state_for_destination(destination_key) + was_offline = (not state.online) or state.cooldown_until is not None + state.online = True + state.failure_count = 0 + state.cooldown_until = None + if was_offline: + logger.info( + "Connectivity restored for destination '%s'; requests resumed.", + destination_key, + ) + + def register_network_failure( + self, exc: Exception, destination: str | None = None + ) -> None: + destination_key = self._normalize_destination(destination) + state = self._state_for_destination(destination_key) + state.online = False + state.failure_count += 1 + + if state.failure_count < self.failure_threshold: + logger.debug( + "Network failure tracked for destination '%s' (%d/%d): %s", + destination_key, + state.failure_count, + self.failure_threshold, + exc, + ) + return + + retry_step = state.failure_count - self.failure_threshold + backoff = min( + self.max_backoff_seconds, + self.base_backoff_seconds * (2**retry_step), + ) + should_log_warning = not self.in_cooldown(destination_key) + state.cooldown_until = self._now() + timedelta(seconds=backoff) + + if should_log_warning: + logger.warning( + "Connectivity offline for destination '%s'; enter cooldown for %ss after %d network failures.", + destination_key, + int(backoff), + state.failure_count, + ) + else: + logger.debug( + "Cooldown still active for destination '%s'; failure_count=%d, backoff=%ss.", + destination_key, + state.failure_count, + int(backoff), + ) + + @staticmethod + def is_network_unreachable_error(exc: Exception) -> bool: + """Return whether the exception should count as connectivity failure.""" + if isinstance(exc, asyncio.CancelledError): + return False + + if isinstance( + exc, + ( + asyncio.TimeoutError, + TimeoutError, + ConnectionRefusedError, + socket.gaierror, + aiohttp.ServerTimeoutError, + aiohttp.ConnectionTimeoutError, + aiohttp.ClientConnectorError, + aiohttp.ClientConnectionError, + ), + ): + return True + + if isinstance(exc, OSError) and exc.errno in { + errno.ENETUNREACH, + errno.EHOSTUNREACH, + errno.ETIMEDOUT, + errno.ECONNREFUSED, + }: + return True + + return False + + +@dataclass +class _DestinationState: + online: bool = True + failure_count: int = 0 + cooldown_until: datetime | None = None diff --git a/py/services/downloader.py b/py/services/downloader.py index 55b6b40e..20fe5851 100644 --- a/py/services/downloader.py +++ b/py/services/downloader.py @@ -18,8 +18,13 @@ from collections import deque from dataclasses import dataclass from datetime import datetime, timedelta from email.utils import parsedate_to_datetime +from urllib.parse import urlparse from typing import Optional, Dict, Tuple, Callable, Union, Awaitable from ..services.settings_manager import get_settings_manager +from .connectivity_guard import ( + OFFLINE_COOLDOWN_ERROR, + ConnectivityGuard, +) from .errors import RateLimitError logger = logging.getLogger(__name__) @@ -797,6 +802,11 @@ class Downloader: Returns: Tuple[bool, Union[bytes, str], Optional[Dict]]: (success, content or error message, response headers if requested) """ + guard = await ConnectivityGuard.get_instance() + destination = self._guard_destination(url) + if guard.should_block_request(destination): + return False, OFFLINE_COOLDOWN_ERROR, None + try: session = await self.session # Debug log for proxy mode at request time @@ -819,6 +829,7 @@ class Downloader: ) as response: if response.status == 200: content = await response.read() + guard.register_success(destination) if return_headers: return True, content, dict(response.headers) else: @@ -837,6 +848,12 @@ class Downloader: return False, error_msg, None except Exception as e: + if guard.is_network_unreachable_error(e): + guard.register_network_failure(e, destination) + if guard.should_block_request(destination): + return False, OFFLINE_COOLDOWN_ERROR, None + logger.debug("Network unavailable during memory download: %s", e) + return False, str(e), None logger.error(f"Error downloading to memory from {url}: {e}") return False, str(e), None @@ -857,6 +874,11 @@ class Downloader: Returns: Tuple[bool, Union[Dict, str]]: (success, headers dict or error message) """ + guard = await ConnectivityGuard.get_instance() + destination = self._guard_destination(url) + if guard.should_block_request(destination): + return False, OFFLINE_COOLDOWN_ERROR + try: session = await self.session # Debug log for proxy mode at request time @@ -878,11 +900,18 @@ class Downloader: url, headers=headers, proxy=self.proxy_url ) as response: if response.status == 200: + guard.register_success(destination) return True, dict(response.headers) else: return False, f"Head request failed with status {response.status}" except Exception as e: + if guard.is_network_unreachable_error(e): + guard.register_network_failure(e, destination) + if guard.should_block_request(destination): + return False, OFFLINE_COOLDOWN_ERROR + logger.debug("Network unavailable during header probe: %s", e) + return False, str(e) logger.error(f"Error getting headers from {url}: {e}") return False, str(e) @@ -907,6 +936,11 @@ class Downloader: Returns: Tuple[bool, Union[Dict, str]]: (success, response data or error message) """ + guard = await ConnectivityGuard.get_instance() + destination = self._guard_destination(url) + if guard.should_block_request(destination): + return False, OFFLINE_COOLDOWN_ERROR + try: session = await self.session # Debug log for proxy mode at request time @@ -930,6 +964,7 @@ class Downloader: method, url, headers=headers, **kwargs ) as response: if response.status == 200: + guard.register_success(destination) # Try to parse as JSON, fall back to text try: data = await response.json() @@ -960,6 +995,12 @@ class Downloader: return False, f"Request failed with status {response.status}" except Exception as e: + if guard.is_network_unreachable_error(e): + guard.register_network_failure(e, destination) + if guard.should_block_request(destination): + return False, OFFLINE_COOLDOWN_ERROR + logger.debug("Network unavailable for %s %s: %s", method, url, e) + return False, str(e) logger.error(f"Error making {method} request to {url}: {e}") return False, str(e) @@ -1010,6 +1051,14 @@ class Downloader: delta = retry_datetime - datetime.now(tz=retry_datetime.tzinfo) return max(0.0, delta.total_seconds()) + @staticmethod + def _guard_destination(url: str) -> str: + """Build per-destination connectivity guard scope from request URL.""" + parsed_url = urlparse(url) + if parsed_url.hostname: + return parsed_url.hostname.lower() + return "unknown" + # Global instance accessor async def get_downloader() -> Downloader: diff --git a/py/services/metadata_sync_service.py b/py/services/metadata_sync_service.py index 1c14eaf0..37433934 100644 --- a/py/services/metadata_sync_service.py +++ b/py/services/metadata_sync_service.py @@ -11,6 +11,7 @@ from typing import Any, Awaitable, Callable, Dict, Iterable, Optional from ..services.settings_manager import SettingsManager from ..utils.civitai_utils import resolve_license_payload from ..utils.model_utils import determine_base_model +from .connectivity_guard import OFFLINE_FRIENDLY_MESSAGE, is_expected_offline_error from .errors import RateLimitError logger = logging.getLogger(__name__) @@ -274,11 +275,18 @@ class MetadataSyncService: else "No provider returned metadata" ) + resolved_error = last_error or default_error + if is_expected_offline_error(resolved_error): + resolved_error = OFFLINE_FRIENDLY_MESSAGE + error_msg = ( - f"Error fetching metadata: {last_error or default_error} " + f"Error fetching metadata: {resolved_error} " f"(model_name={model_data.get('model_name', '')})" ) - logger.error(error_msg) + if is_expected_offline_error(resolved_error): + logger.info(error_msg) + else: + logger.error(error_msg) return False, error_msg model_data["from_civitai"] = True @@ -347,6 +355,9 @@ class MetadataSyncService: return False, error_msg except Exception as exc: # pragma: no cover - error path error_msg = f"Error fetching metadata: {exc}" + if is_expected_offline_error(str(exc)): + logger.info(OFFLINE_FRIENDLY_MESSAGE) + return False, OFFLINE_FRIENDLY_MESSAGE logger.error(error_msg, exc_info=True) return False, error_msg diff --git a/tests/services/test_civitai_client.py b/tests/services/test_civitai_client.py index d98871ed..038ccba4 100644 --- a/tests/services/test_civitai_client.py +++ b/tests/services/test_civitai_client.py @@ -5,6 +5,7 @@ import pytest from py.services import civitai_client as civitai_client_module from py.services.civitai_client import CivitaiClient +from py.services.connectivity_guard import OFFLINE_COOLDOWN_ERROR, OFFLINE_FRIENDLY_MESSAGE from py.services.errors import RateLimitError, ResourceNotFoundError from py.services.model_metadata_provider import ModelMetadataProviderManager @@ -115,6 +116,20 @@ async def test_get_model_by_hash_handles_not_found(monkeypatch, downloader): assert error == "Model not found" +async def test_get_model_by_hash_handles_offline_cooldown(downloader): + async def fake_make_request(method, url, use_auth=True, **kwargs): + return False, OFFLINE_COOLDOWN_ERROR + + downloader.make_request = fake_make_request + + client = await CivitaiClient.get_instance() + + result, error = await client.get_model_by_hash("missing") + + assert result is None + assert error == OFFLINE_FRIENDLY_MESSAGE + + async def test_get_model_by_hash_propagates_rate_limit(monkeypatch, downloader): async def fake_make_request(method, url, use_auth=True, **kwargs): return False, RateLimitError("limited", retry_after=4) diff --git a/tests/services/test_connectivity_guard.py b/tests/services/test_connectivity_guard.py new file mode 100644 index 00000000..712a9f6e --- /dev/null +++ b/tests/services/test_connectivity_guard.py @@ -0,0 +1,124 @@ +import asyncio +import errno +from datetime import datetime, timedelta + +import pytest + +from py.services.connectivity_guard import ( + OFFLINE_COOLDOWN_ERROR, + ConnectivityGuard, +) +from py.services.downloader import Downloader + + +@pytest.fixture(autouse=True) +def reset_connectivity_guard_singleton(): + ConnectivityGuard._instance = None + yield + ConnectivityGuard._instance = None + + +async def test_connectivity_guard_enters_cooldown_after_threshold(): + guard = await ConnectivityGuard.get_instance() + + assert guard.online is True + assert guard.should_block_request() is False + + guard.register_network_failure(OSError(errno.ENETUNREACH, "unreachable")) + guard.register_network_failure(asyncio.TimeoutError("timeout")) + + assert guard.should_block_request() is False + assert guard.failure_count == 2 + + guard.register_network_failure(ConnectionRefusedError("refused")) + + assert guard.online is False + assert guard.failure_count == 3 + assert guard.should_block_request() is True + assert guard.cooldown_remaining_seconds() > 0 + + +async def test_connectivity_guard_scopes_cooldown_to_destination(): + guard = await ConnectivityGuard.get_instance() + + destination_a = "civitai.com" + destination_b = "api.github.com" + + guard.register_network_failure( + OSError(errno.ENETUNREACH, "unreachable"), + destination_a, + ) + guard.register_network_failure(asyncio.TimeoutError("timeout"), destination_a) + guard.register_network_failure(ConnectionRefusedError("refused"), destination_a) + + assert guard.should_block_request(destination_a) is True + assert guard.should_block_request(destination_b) is False + + guard.register_success(destination_a) + assert guard.should_block_request(destination_a) is False + + +async def test_connectivity_guard_recovers_after_success(): + guard = await ConnectivityGuard.get_instance() + guard.online = False + guard.failure_count = 5 + guard.cooldown_until = datetime.now() + timedelta(seconds=90) + + guard.register_success() + + assert guard.online is True + assert guard.failure_count == 0 + assert guard.cooldown_until is None + assert guard.should_block_request() is False + + +async def test_downloader_short_circuits_all_request_helpers_during_cooldown(): + guard = await ConnectivityGuard.get_instance() + destination = "example.invalid" + guard.register_network_failure( + OSError(errno.ENETUNREACH, "unreachable"), + destination, + ) + guard.register_network_failure(asyncio.TimeoutError("timeout"), destination) + guard.register_network_failure( + ConnectionRefusedError("refused"), + destination, + ) + + downloader = Downloader() + + ok, payload = await downloader.make_request("GET", f"https://{destination}") + assert ok is False + assert payload == OFFLINE_COOLDOWN_ERROR + + ok, payload, headers = await downloader.download_to_memory(f"https://{destination}") + assert ok is False + assert payload == OFFLINE_COOLDOWN_ERROR + assert headers is None + + ok, payload = await downloader.get_response_headers(f"https://{destination}") + assert ok is False + assert payload == OFFLINE_COOLDOWN_ERROR + + +async def test_downloader_only_short_circuits_requests_for_same_destination(): + guard = await ConnectivityGuard.get_instance() + guard.register_network_failure( + OSError(errno.ENETUNREACH, "unreachable"), + "example.invalid", + ) + guard.register_network_failure(asyncio.TimeoutError("timeout"), "example.invalid") + guard.register_network_failure( + ConnectionRefusedError("refused"), + "example.invalid", + ) + + downloader = Downloader() + ok, payload = await downloader.make_request("GET", "https://example.invalid") + assert ok is False + assert payload == OFFLINE_COOLDOWN_ERROR + + assert ( + guard.should_block_request(downloader._guard_destination("https://example.com")) + is False + ) diff --git a/tests/services/test_metadata_sync_service.py b/tests/services/test_metadata_sync_service.py index d9e4a90c..c89a6af5 100644 --- a/tests/services/test_metadata_sync_service.py +++ b/tests/services/test_metadata_sync_service.py @@ -4,6 +4,7 @@ from unittest.mock import AsyncMock import pytest +from py.services.connectivity_guard import OFFLINE_COOLDOWN_ERROR, OFFLINE_FRIENDLY_MESSAGE from py.services.errors import RateLimitError from py.services.metadata_sync_service import MetadataSyncService @@ -243,17 +244,32 @@ async def test_fetch_and_update_model_handles_missing_remote_metadata(tmp_path): assert not ok assert "Model not found" in error - assert model_data["from_civitai"] is False - assert model_data["civitai_deleted"] is True - helpers.metadata_manager.hydrate_model_data.assert_not_awaited() - assert model_data["hydrated"] is True - helpers.metadata_manager.save_metadata.assert_awaited_once() - call_args = helpers.metadata_manager.save_metadata.await_args - assert call_args.args[0].endswith("model.safetensors") - assert "folder" not in call_args.args[1] - assert call_args.args[1]["hydrated"] is True +@pytest.mark.asyncio +async def test_fetch_and_update_model_returns_friendly_offline_message(tmp_path): + helpers = build_service() + helpers.default_provider.get_model_by_hash.return_value = (None, OFFLINE_COOLDOWN_ERROR) + + model_path = tmp_path / "model.safetensors" + model_data = { + "model_name": "Local", + "folder": "root", + "file_path": str(model_path), + } + update_cache = AsyncMock(return_value=True) + + ok, error = await helpers.service.fetch_and_update_model( + sha256="abc", + file_path=str(model_path), + model_data=model_data, + update_cache_func=update_cache, + ) + + assert ok is False + assert error is not None + assert OFFLINE_FRIENDLY_MESSAGE in error + update_cache.assert_not_awaited() @pytest.mark.asyncio