From 7ab271c752f037cbdc161ad100be4ebf9d582cc9 Mon Sep 17 00:00:00 2001 From: pixelpaws Date: Mon, 20 Apr 2026 15:20:57 +0800 Subject: [PATCH] fix(network): scope connectivity cooldown by destination --- py/services/connectivity_guard.py | 117 ++++++++++++++++------ py/services/downloader.py | 36 ++++--- tests/services/test_connectivity_guard.py | 43 ++++++++ 3 files changed, 154 insertions(+), 42 deletions(-) diff --git a/py/services/connectivity_guard.py b/py/services/connectivity_guard.py index 05de8004..1f60d5df 100644 --- a/py/services/connectivity_guard.py +++ b/py/services/connectivity_guard.py @@ -6,6 +6,7 @@ import asyncio import errno import logging import socket +from dataclasses import dataclass from datetime import datetime, timedelta from typing import Any @@ -49,68 +50,118 @@ class ConnectivityGuard: if hasattr(self, "_initialized"): return self._initialized = True - self.online = True - self.failure_count = 0 - self.cooldown_until: datetime | None = None + self._default_destination = "__global__" + self._destination_states: dict[str, _DestinationState] = { + self._default_destination: _DestinationState() + } self.base_backoff_seconds = 30 self.max_backoff_seconds = 300 self.failure_threshold = 3 + @property + def online(self) -> bool: + return self._state_for_destination(None).online + + @online.setter + def online(self, value: bool) -> None: + self._state_for_destination(None).online = value + + @property + def failure_count(self) -> int: + return self._state_for_destination(None).failure_count + + @failure_count.setter + def failure_count(self, value: int) -> None: + self._state_for_destination(None).failure_count = value + + @property + def cooldown_until(self) -> datetime | None: + return self._state_for_destination(None).cooldown_until + + @cooldown_until.setter + def cooldown_until(self, value: datetime | None) -> None: + self._state_for_destination(None).cooldown_until = value + def _now(self) -> datetime: return datetime.now() - def in_cooldown(self) -> bool: - if self.cooldown_until is None: + def _normalize_destination(self, destination: str | None) -> str: + if destination is None or not destination.strip(): + return self._default_destination + return destination.lower().strip() + + def _state_for_destination(self, destination: str | None) -> "_DestinationState": + destination_key = self._normalize_destination(destination) + if destination_key not in self._destination_states: + self._destination_states[destination_key] = _DestinationState() + return self._destination_states[destination_key] + + def in_cooldown(self, destination: str | None = None) -> bool: + state = self._state_for_destination(destination) + if state.cooldown_until is None: return False - return self._now() < self.cooldown_until + return self._now() < state.cooldown_until - def cooldown_remaining_seconds(self) -> float: - if self.cooldown_until is None: + def cooldown_remaining_seconds(self, destination: str | None = None) -> float: + state = self._state_for_destination(destination) + if state.cooldown_until is None: return 0.0 - return max(0.0, (self.cooldown_until - self._now()).total_seconds()) + return max(0.0, (state.cooldown_until - self._now()).total_seconds()) - def should_block_request(self) -> bool: - return self.in_cooldown() + def should_block_request(self, destination: str | None = None) -> bool: + return self.in_cooldown(destination) - def register_success(self) -> None: - was_offline = (not self.online) or self.cooldown_until is not None - self.online = True - self.failure_count = 0 - self.cooldown_until = None + def register_success(self, destination: str | None = None) -> None: + destination_key = self._normalize_destination(destination) + state = self._state_for_destination(destination_key) + was_offline = (not state.online) or state.cooldown_until is not None + state.online = True + state.failure_count = 0 + state.cooldown_until = None if was_offline: - logger.info("Connectivity restored; requests resumed.") + logger.info( + "Connectivity restored for destination '%s'; requests resumed.", + destination_key, + ) - def register_network_failure(self, exc: Exception) -> None: - self.online = False - self.failure_count += 1 + def register_network_failure( + self, exc: Exception, destination: str | None = None + ) -> None: + destination_key = self._normalize_destination(destination) + state = self._state_for_destination(destination_key) + state.online = False + state.failure_count += 1 - if self.failure_count < self.failure_threshold: + if state.failure_count < self.failure_threshold: logger.debug( - "Network failure tracked (%d/%d): %s", - self.failure_count, + "Network failure tracked for destination '%s' (%d/%d): %s", + destination_key, + state.failure_count, self.failure_threshold, exc, ) return - retry_step = self.failure_count - self.failure_threshold + retry_step = state.failure_count - self.failure_threshold backoff = min( self.max_backoff_seconds, self.base_backoff_seconds * (2**retry_step), ) - should_log_warning = not self.in_cooldown() - self.cooldown_until = self._now() + timedelta(seconds=backoff) + should_log_warning = not self.in_cooldown(destination_key) + state.cooldown_until = self._now() + timedelta(seconds=backoff) if should_log_warning: logger.warning( - "Connectivity offline; enter cooldown for %ss after %d network failures.", + "Connectivity offline for destination '%s'; enter cooldown for %ss after %d network failures.", + destination_key, int(backoff), - self.failure_count, + state.failure_count, ) else: logger.debug( - "Cooldown still active; failure_count=%d, backoff=%ss.", - self.failure_count, + "Cooldown still active for destination '%s'; failure_count=%d, backoff=%ss.", + destination_key, + state.failure_count, int(backoff), ) @@ -145,3 +196,9 @@ class ConnectivityGuard: return False + +@dataclass +class _DestinationState: + online: bool = True + failure_count: int = 0 + cooldown_until: datetime | None = None diff --git a/py/services/downloader.py b/py/services/downloader.py index 1be64542..20fe5851 100644 --- a/py/services/downloader.py +++ b/py/services/downloader.py @@ -18,6 +18,7 @@ from collections import deque from dataclasses import dataclass from datetime import datetime, timedelta from email.utils import parsedate_to_datetime +from urllib.parse import urlparse from typing import Optional, Dict, Tuple, Callable, Union, Awaitable from ..services.settings_manager import get_settings_manager from .connectivity_guard import ( @@ -802,7 +803,8 @@ class Downloader: Tuple[bool, Union[bytes, str], Optional[Dict]]: (success, content or error message, response headers if requested) """ guard = await ConnectivityGuard.get_instance() - if guard.should_block_request(): + destination = self._guard_destination(url) + if guard.should_block_request(destination): return False, OFFLINE_COOLDOWN_ERROR, None try: @@ -827,7 +829,7 @@ class Downloader: ) as response: if response.status == 200: content = await response.read() - guard.register_success() + guard.register_success(destination) if return_headers: return True, content, dict(response.headers) else: @@ -847,8 +849,8 @@ class Downloader: except Exception as e: if guard.is_network_unreachable_error(e): - guard.register_network_failure(e) - if guard.should_block_request(): + guard.register_network_failure(e, destination) + if guard.should_block_request(destination): return False, OFFLINE_COOLDOWN_ERROR, None logger.debug("Network unavailable during memory download: %s", e) return False, str(e), None @@ -873,7 +875,8 @@ class Downloader: Tuple[bool, Union[Dict, str]]: (success, headers dict or error message) """ guard = await ConnectivityGuard.get_instance() - if guard.should_block_request(): + destination = self._guard_destination(url) + if guard.should_block_request(destination): return False, OFFLINE_COOLDOWN_ERROR try: @@ -897,15 +900,15 @@ class Downloader: url, headers=headers, proxy=self.proxy_url ) as response: if response.status == 200: - guard.register_success() + guard.register_success(destination) return True, dict(response.headers) else: return False, f"Head request failed with status {response.status}" except Exception as e: if guard.is_network_unreachable_error(e): - guard.register_network_failure(e) - if guard.should_block_request(): + guard.register_network_failure(e, destination) + if guard.should_block_request(destination): return False, OFFLINE_COOLDOWN_ERROR logger.debug("Network unavailable during header probe: %s", e) return False, str(e) @@ -934,7 +937,8 @@ class Downloader: Tuple[bool, Union[Dict, str]]: (success, response data or error message) """ guard = await ConnectivityGuard.get_instance() - if guard.should_block_request(): + destination = self._guard_destination(url) + if guard.should_block_request(destination): return False, OFFLINE_COOLDOWN_ERROR try: @@ -960,7 +964,7 @@ class Downloader: method, url, headers=headers, **kwargs ) as response: if response.status == 200: - guard.register_success() + guard.register_success(destination) # Try to parse as JSON, fall back to text try: data = await response.json() @@ -992,8 +996,8 @@ class Downloader: except Exception as e: if guard.is_network_unreachable_error(e): - guard.register_network_failure(e) - if guard.should_block_request(): + guard.register_network_failure(e, destination) + if guard.should_block_request(destination): return False, OFFLINE_COOLDOWN_ERROR logger.debug("Network unavailable for %s %s: %s", method, url, e) return False, str(e) @@ -1047,6 +1051,14 @@ class Downloader: delta = retry_datetime - datetime.now(tz=retry_datetime.tzinfo) return max(0.0, delta.total_seconds()) + @staticmethod + def _guard_destination(url: str) -> str: + """Build per-destination connectivity guard scope from request URL.""" + parsed_url = urlparse(url) + if parsed_url.hostname: + return parsed_url.hostname.lower() + return "unknown" + # Global instance accessor async def get_downloader() -> Downloader: diff --git a/tests/services/test_connectivity_guard.py b/tests/services/test_connectivity_guard.py index 66a9bb6a..b7196d2e 100644 --- a/tests/services/test_connectivity_guard.py +++ b/tests/services/test_connectivity_guard.py @@ -38,6 +38,26 @@ async def test_connectivity_guard_enters_cooldown_after_threshold(): assert guard.cooldown_remaining_seconds() > 0 +async def test_connectivity_guard_scopes_cooldown_to_destination(): + guard = await ConnectivityGuard.get_instance() + + destination_a = "civitai.com" + destination_b = "api.github.com" + + guard.register_network_failure( + OSError(errno.ENETUNREACH, "unreachable"), + destination_a, + ) + guard.register_network_failure(asyncio.TimeoutError("timeout"), destination_a) + guard.register_network_failure(ConnectionRefusedError("refused"), destination_a) + + assert guard.should_block_request(destination_a) is True + assert guard.should_block_request(destination_b) is False + + guard.register_success(destination_a) + assert guard.should_block_request(destination_a) is False + + async def test_connectivity_guard_recovers_after_success(): guard = await ConnectivityGuard.get_instance() guard.online = False @@ -72,3 +92,26 @@ async def test_downloader_short_circuits_all_request_helpers_during_cooldown(): ok, payload = await downloader.get_response_headers("https://example.invalid") assert ok is False assert payload == OFFLINE_COOLDOWN_ERROR + + +async def test_downloader_only_short_circuits_requests_for_same_destination(): + guard = await ConnectivityGuard.get_instance() + guard.register_network_failure( + OSError(errno.ENETUNREACH, "unreachable"), + "example.invalid", + ) + guard.register_network_failure(asyncio.TimeoutError("timeout"), "example.invalid") + guard.register_network_failure( + ConnectionRefusedError("refused"), + "example.invalid", + ) + + downloader = Downloader() + ok, payload = await downloader.make_request("GET", "https://example.invalid") + assert ok is False + assert payload == OFFLINE_COOLDOWN_ERROR + + assert ( + guard.should_block_request(downloader._guard_destination("https://example.com")) + is False + )