fix(network): scope connectivity cooldown by destination

This commit is contained in:
pixelpaws
2026-04-20 15:20:57 +08:00
parent 5a7f4dc88b
commit 7ab271c752
3 changed files with 154 additions and 42 deletions

View File

@@ -6,6 +6,7 @@ import asyncio
import errno import errno
import logging import logging
import socket import socket
from dataclasses import dataclass
from datetime import datetime, timedelta from datetime import datetime, timedelta
from typing import Any from typing import Any
@@ -49,68 +50,118 @@ class ConnectivityGuard:
if hasattr(self, "_initialized"): if hasattr(self, "_initialized"):
return return
self._initialized = True self._initialized = True
self.online = True self._default_destination = "__global__"
self.failure_count = 0 self._destination_states: dict[str, _DestinationState] = {
self.cooldown_until: datetime | None = None self._default_destination: _DestinationState()
}
self.base_backoff_seconds = 30 self.base_backoff_seconds = 30
self.max_backoff_seconds = 300 self.max_backoff_seconds = 300
self.failure_threshold = 3 self.failure_threshold = 3
@property
def online(self) -> bool:
    """Return the ``online`` flag of the default destination's state.

    The state is looked up with ``_state_for_destination(None)``, so this
    attribute always refers to the default scope rather than a named host.
    """
    default_state = self._state_for_destination(None)
    return default_state.online

@online.setter
def online(self, value: bool) -> None:
    """Store *value* as the ``online`` flag of the default destination."""
    default_state = self._state_for_destination(None)
    default_state.online = value
@property
def failure_count(self) -> int:
    """Return the failure counter of the default destination's state.

    The state is looked up with ``_state_for_destination(None)``, so this
    attribute always refers to the default scope rather than a named host.
    """
    default_state = self._state_for_destination(None)
    return default_state.failure_count

@failure_count.setter
def failure_count(self, value: int) -> None:
    """Overwrite the failure counter of the default destination."""
    default_state = self._state_for_destination(None)
    default_state.failure_count = value
@property
def cooldown_until(self) -> datetime | None:
return self._state_for_destination(None).cooldown_until
@cooldown_until.setter
def cooldown_until(self, value: datetime | None) -> None:
self._state_for_destination(None).cooldown_until = value
def _now(self) -> datetime:
    """Return the current local time as a naive ``datetime``.

    The cooldown checks in this class obtain the time through this method
    instead of calling ``datetime.now()`` themselves.
    """
    current_time = datetime.now()
    return current_time
def in_cooldown(self) -> bool: def _normalize_destination(self, destination: str | None) -> str:
if self.cooldown_until is None: if destination is None or not destination.strip():
return self._default_destination
return destination.lower().strip()
def _state_for_destination(self, destination: str | None) -> "_DestinationState":
destination_key = self._normalize_destination(destination)
if destination_key not in self._destination_states:
self._destination_states[destination_key] = _DestinationState()
return self._destination_states[destination_key]
def in_cooldown(self, destination: str | None = None) -> bool:
state = self._state_for_destination(destination)
if state.cooldown_until is None:
return False return False
return self._now() < self.cooldown_until return self._now() < state.cooldown_until
def cooldown_remaining_seconds(self) -> float: def cooldown_remaining_seconds(self, destination: str | None = None) -> float:
if self.cooldown_until is None: state = self._state_for_destination(destination)
if state.cooldown_until is None:
return 0.0 return 0.0
return max(0.0, (self.cooldown_until - self._now()).total_seconds()) return max(0.0, (state.cooldown_until - self._now()).total_seconds())
def should_block_request(self) -> bool: def should_block_request(self, destination: str | None = None) -> bool:
return self.in_cooldown() return self.in_cooldown(destination)
def register_success(self) -> None: def register_success(self, destination: str | None = None) -> None:
was_offline = (not self.online) or self.cooldown_until is not None destination_key = self._normalize_destination(destination)
self.online = True state = self._state_for_destination(destination_key)
self.failure_count = 0 was_offline = (not state.online) or state.cooldown_until is not None
self.cooldown_until = None state.online = True
state.failure_count = 0
state.cooldown_until = None
if was_offline: if was_offline:
logger.info("Connectivity restored; requests resumed.") logger.info(
"Connectivity restored for destination '%s'; requests resumed.",
destination_key,
)
def register_network_failure(self, exc: Exception) -> None: def register_network_failure(
self.online = False self, exc: Exception, destination: str | None = None
self.failure_count += 1 ) -> None:
destination_key = self._normalize_destination(destination)
state = self._state_for_destination(destination_key)
state.online = False
state.failure_count += 1
if self.failure_count < self.failure_threshold: if state.failure_count < self.failure_threshold:
logger.debug( logger.debug(
"Network failure tracked (%d/%d): %s", "Network failure tracked for destination '%s' (%d/%d): %s",
self.failure_count, destination_key,
state.failure_count,
self.failure_threshold, self.failure_threshold,
exc, exc,
) )
return return
retry_step = self.failure_count - self.failure_threshold retry_step = state.failure_count - self.failure_threshold
backoff = min( backoff = min(
self.max_backoff_seconds, self.max_backoff_seconds,
self.base_backoff_seconds * (2**retry_step), self.base_backoff_seconds * (2**retry_step),
) )
should_log_warning = not self.in_cooldown() should_log_warning = not self.in_cooldown(destination_key)
self.cooldown_until = self._now() + timedelta(seconds=backoff) state.cooldown_until = self._now() + timedelta(seconds=backoff)
if should_log_warning: if should_log_warning:
logger.warning( logger.warning(
"Connectivity offline; enter cooldown for %ss after %d network failures.", "Connectivity offline for destination '%s'; enter cooldown for %ss after %d network failures.",
destination_key,
int(backoff), int(backoff),
self.failure_count, state.failure_count,
) )
else: else:
logger.debug( logger.debug(
"Cooldown still active; failure_count=%d, backoff=%ss.", "Cooldown still active for destination '%s'; failure_count=%d, backoff=%ss.",
self.failure_count, destination_key,
state.failure_count,
int(backoff), int(backoff),
) )
@@ -145,3 +196,9 @@ class ConnectivityGuard:
return False return False
@dataclass
class _DestinationState:
online: bool = True
failure_count: int = 0
cooldown_until: datetime | None = None

View File

@@ -18,6 +18,7 @@ from collections import deque
from dataclasses import dataclass from dataclasses import dataclass
from datetime import datetime, timedelta from datetime import datetime, timedelta
from email.utils import parsedate_to_datetime from email.utils import parsedate_to_datetime
from urllib.parse import urlparse
from typing import Optional, Dict, Tuple, Callable, Union, Awaitable from typing import Optional, Dict, Tuple, Callable, Union, Awaitable
from ..services.settings_manager import get_settings_manager from ..services.settings_manager import get_settings_manager
from .connectivity_guard import ( from .connectivity_guard import (
@@ -802,7 +803,8 @@ class Downloader:
Tuple[bool, Union[bytes, str], Optional[Dict]]: (success, content or error message, response headers if requested) Tuple[bool, Union[bytes, str], Optional[Dict]]: (success, content or error message, response headers if requested)
""" """
guard = await ConnectivityGuard.get_instance() guard = await ConnectivityGuard.get_instance()
if guard.should_block_request(): destination = self._guard_destination(url)
if guard.should_block_request(destination):
return False, OFFLINE_COOLDOWN_ERROR, None return False, OFFLINE_COOLDOWN_ERROR, None
try: try:
@@ -827,7 +829,7 @@ class Downloader:
) as response: ) as response:
if response.status == 200: if response.status == 200:
content = await response.read() content = await response.read()
guard.register_success() guard.register_success(destination)
if return_headers: if return_headers:
return True, content, dict(response.headers) return True, content, dict(response.headers)
else: else:
@@ -847,8 +849,8 @@ class Downloader:
except Exception as e: except Exception as e:
if guard.is_network_unreachable_error(e): if guard.is_network_unreachable_error(e):
guard.register_network_failure(e) guard.register_network_failure(e, destination)
if guard.should_block_request(): if guard.should_block_request(destination):
return False, OFFLINE_COOLDOWN_ERROR, None return False, OFFLINE_COOLDOWN_ERROR, None
logger.debug("Network unavailable during memory download: %s", e) logger.debug("Network unavailable during memory download: %s", e)
return False, str(e), None return False, str(e), None
@@ -873,7 +875,8 @@ class Downloader:
Tuple[bool, Union[Dict, str]]: (success, headers dict or error message) Tuple[bool, Union[Dict, str]]: (success, headers dict or error message)
""" """
guard = await ConnectivityGuard.get_instance() guard = await ConnectivityGuard.get_instance()
if guard.should_block_request(): destination = self._guard_destination(url)
if guard.should_block_request(destination):
return False, OFFLINE_COOLDOWN_ERROR return False, OFFLINE_COOLDOWN_ERROR
try: try:
@@ -897,15 +900,15 @@ class Downloader:
url, headers=headers, proxy=self.proxy_url url, headers=headers, proxy=self.proxy_url
) as response: ) as response:
if response.status == 200: if response.status == 200:
guard.register_success() guard.register_success(destination)
return True, dict(response.headers) return True, dict(response.headers)
else: else:
return False, f"Head request failed with status {response.status}" return False, f"Head request failed with status {response.status}"
except Exception as e: except Exception as e:
if guard.is_network_unreachable_error(e): if guard.is_network_unreachable_error(e):
guard.register_network_failure(e) guard.register_network_failure(e, destination)
if guard.should_block_request(): if guard.should_block_request(destination):
return False, OFFLINE_COOLDOWN_ERROR return False, OFFLINE_COOLDOWN_ERROR
logger.debug("Network unavailable during header probe: %s", e) logger.debug("Network unavailable during header probe: %s", e)
return False, str(e) return False, str(e)
@@ -934,7 +937,8 @@ class Downloader:
Tuple[bool, Union[Dict, str]]: (success, response data or error message) Tuple[bool, Union[Dict, str]]: (success, response data or error message)
""" """
guard = await ConnectivityGuard.get_instance() guard = await ConnectivityGuard.get_instance()
if guard.should_block_request(): destination = self._guard_destination(url)
if guard.should_block_request(destination):
return False, OFFLINE_COOLDOWN_ERROR return False, OFFLINE_COOLDOWN_ERROR
try: try:
@@ -960,7 +964,7 @@ class Downloader:
method, url, headers=headers, **kwargs method, url, headers=headers, **kwargs
) as response: ) as response:
if response.status == 200: if response.status == 200:
guard.register_success() guard.register_success(destination)
# Try to parse as JSON, fall back to text # Try to parse as JSON, fall back to text
try: try:
data = await response.json() data = await response.json()
@@ -992,8 +996,8 @@ class Downloader:
except Exception as e: except Exception as e:
if guard.is_network_unreachable_error(e): if guard.is_network_unreachable_error(e):
guard.register_network_failure(e) guard.register_network_failure(e, destination)
if guard.should_block_request(): if guard.should_block_request(destination):
return False, OFFLINE_COOLDOWN_ERROR return False, OFFLINE_COOLDOWN_ERROR
logger.debug("Network unavailable for %s %s: %s", method, url, e) logger.debug("Network unavailable for %s %s: %s", method, url, e)
return False, str(e) return False, str(e)
@@ -1047,6 +1051,14 @@ class Downloader:
delta = retry_datetime - datetime.now(tz=retry_datetime.tzinfo) delta = retry_datetime - datetime.now(tz=retry_datetime.tzinfo)
return max(0.0, delta.total_seconds()) return max(0.0, delta.total_seconds())
@staticmethod
def _guard_destination(url: str) -> str:
    """Build the per-destination connectivity guard scope from a request URL.

    Args:
        url: The URL about to be requested.

    Returns:
        The lower-cased host name of *url*, or ``"unknown"`` when the URL
        has no host component or cannot be parsed at all.
    """
    try:
        hostname = urlparse(url).hostname
    except ValueError:
        # urlparse rejects some malformed URLs (e.g. an unbalanced IPv6
        # bracket). The request helpers call this before entering their
        # own try block, so an exception here would escape them instead of
        # becoming a (False, error) result; use the catch-all scope.
        return "unknown"
    if not hostname:
        return "unknown"
    return hostname.lower()
# Global instance accessor # Global instance accessor
async def get_downloader() -> Downloader: async def get_downloader() -> Downloader:

View File

@@ -38,6 +38,26 @@ async def test_connectivity_guard_enters_cooldown_after_threshold():
assert guard.cooldown_remaining_seconds() > 0 assert guard.cooldown_remaining_seconds() > 0
async def test_connectivity_guard_scopes_cooldown_to_destination():
    """Failures for one host must not block requests to a different host."""
    guard = await ConnectivityGuard.get_instance()
    failing_host = "civitai.com"
    healthy_host = "api.github.com"

    # Three failures against the same host; the assertions below expect this
    # to be enough to start its cooldown (assumes the guard starts clean).
    failures = (
        OSError(errno.ENETUNREACH, "unreachable"),
        asyncio.TimeoutError("timeout"),
        ConnectionRefusedError("refused"),
    )
    for failure in failures:
        guard.register_network_failure(failure, failing_host)

    assert guard.should_block_request(failing_host) is True
    assert guard.should_block_request(healthy_host) is False

    # A success on the failing host lifts its cooldown again.
    guard.register_success(failing_host)
    assert guard.should_block_request(failing_host) is False
async def test_connectivity_guard_recovers_after_success(): async def test_connectivity_guard_recovers_after_success():
guard = await ConnectivityGuard.get_instance() guard = await ConnectivityGuard.get_instance()
guard.online = False guard.online = False
@@ -72,3 +92,26 @@ async def test_downloader_short_circuits_all_request_helpers_during_cooldown():
ok, payload = await downloader.get_response_headers("https://example.invalid") ok, payload = await downloader.get_response_headers("https://example.invalid")
assert ok is False assert ok is False
assert payload == OFFLINE_COOLDOWN_ERROR assert payload == OFFLINE_COOLDOWN_ERROR
async def test_downloader_only_short_circuits_requests_for_same_destination():
    """The downloader refuses a cooled-down host but leaves other hosts alone."""
    guard = await ConnectivityGuard.get_instance()
    blocked_host = "example.invalid"

    # Three failures against the same host; the assertions below expect this
    # to be enough to start its cooldown (assumes the guard starts clean).
    failures = (
        OSError(errno.ENETUNREACH, "unreachable"),
        asyncio.TimeoutError("timeout"),
        ConnectionRefusedError("refused"),
    )
    for failure in failures:
        guard.register_network_failure(failure, blocked_host)

    downloader = Downloader()
    ok, payload = await downloader.make_request("GET", "https://example.invalid")

    assert ok is False
    assert payload == OFFLINE_COOLDOWN_ERROR

    # A URL on another host resolves to a scope that is not in cooldown.
    other_destination = downloader._guard_destination("https://example.com")
    assert guard.should_block_request(other_destination) is False