Merge branch 'main' into codex/github-mention-fixnetwork-add-connectivityguard-to-short

This commit is contained in:
pixelpaws
2026-04-20 15:54:30 +08:00
committed by GitHub
28 changed files with 4469 additions and 194 deletions

View File

@@ -0,0 +1,570 @@
from __future__ import annotations
import asyncio
import json
import logging
import os
import secrets
import shutil
import socket
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, Optional, Tuple
import aiohttp
from .downloader import DownloadProgress, get_downloader
from .aria2_transfer_state import Aria2TransferStateStore
from .settings_manager import get_settings_manager
logger = logging.getLogger(__name__)
# Civitai API download endpoints. For these URLs the coordinator resolves the
# authenticated redirect itself (see Aria2Downloader._schedule_download) so
# aria2 can fetch the resulting signed URL without the auth headers.
CIVITAI_DOWNLOAD_URL_PREFIXES = (
    "https://civitai.com/api/download/",
    "https://civitai.red/api/download/",
)
class Aria2Error(RuntimeError):
    """Raised when aria2 integration fails (spawn, RPC, or scheduling errors)."""
@dataclass
class Aria2Transfer:
    """Track an aria2 download registered by the Python coordinator."""

    # aria2 GID identifying the transfer within the RPC daemon.
    gid: str
    # Destination path on disk (stored as an absolute path by callers).
    save_path: str
class Aria2Downloader:
    """Manage an aria2 RPC daemon for experimental model downloads.

    Lazily spawns a local ``aria2c`` process bound to a random localhost port
    with a random RPC secret, then drives it over JSON-RPC
    (``aria2.addUri`` / ``aria2.tellStatus`` / pause / unpause / remove).
    Used as a process-wide singleton via :meth:`get_instance`.
    """

    # Singleton instance plus the asyncio lock guarding its creation.
    _instance = None
    _lock = asyncio.Lock()

    @classmethod
    async def get_instance(cls) -> "Aria2Downloader":
        """Return the shared downloader, creating it on first use."""
        async with cls._lock:
            if cls._instance is None:
                cls._instance = cls()
            return cls._instance

    def __init__(self) -> None:
        # Defensive re-init guard; get_instance() only constructs once under _lock.
        if hasattr(self, "_initialized"):
            return
        self._initialized = True
        # Handle to the spawned aria2c daemon, if any.
        self._process: Optional[asyncio.subprocess.Process] = None
        # RPC endpoint details; populated by _ensure_process().
        self._rpc_port: Optional[int] = None
        self._rpc_secret = ""
        self._rpc_url = ""
        # Lazily-created HTTP session for JSON-RPC calls.
        self._rpc_session: Optional[aiohttp.ClientSession] = None
        self._rpc_session_lock = asyncio.Lock()
        self._process_lock = asyncio.Lock()
        # download_id -> in-memory transfer bookkeeping.
        self._transfers: Dict[str, Aria2Transfer] = {}
        # Seconds between status polls while waiting for completion.
        self._poll_interval = 0.5
        # On-disk state used to recover transfers after a restart.
        self._state_store = Aria2TransferStateStore()

    @property
    def is_running(self) -> bool:
        # True while the daemon process exists and has not exited.
        return self._process is not None and self._process.returncode is None

    async def download_file(
        self,
        url: str,
        save_path: str,
        *,
        download_id: str,
        progress_callback=None,
        headers: Optional[Dict[str, str]] = None,
    ) -> Tuple[bool, str]:
        """Download a file using aria2 RPC and wait for completion.

        Args:
            url: Source URL to download.
            save_path: Destination file path (made absolute here).
            download_id: Caller-supplied id used for tracking/persistence.
            progress_callback: Optional sync/async callback; see
                :meth:`_dispatch_progress` for the supported signatures.
            headers: Optional extra request headers (e.g. auth).

        Returns:
            ``(True, final_path)`` on success, ``(False, error_message)``
            otherwise.
        """
        await self._ensure_process()
        save_path = os.path.abspath(save_path)
        transfer = self._transfers.get(download_id)
        # Reuse an existing transfer (e.g. one restored after a restart) only
        # when it targets the same destination; otherwise schedule a new one.
        if transfer is None or os.path.abspath(transfer.save_path) != save_path:
            gid = await self._schedule_download(
                url,
                save_path,
                download_id=download_id,
                headers=headers,
            )
            transfer = Aria2Transfer(gid=gid, save_path=save_path)
            self._transfers[download_id] = transfer
        try:
            # Poll aria2 until the transfer reaches a terminal state.
            while True:
                status = await self.get_status(download_id)
                if status is None:
                    return False, "aria2 download not found"
                snapshot = self._build_progress_snapshot(status)
                if progress_callback is not None:
                    await self._dispatch_progress(progress_callback, snapshot)
                state = status.get("status", "")
                if state == "complete":
                    completed_path = self._resolve_completed_path(status, save_path)
                    return True, completed_path
                if state == "error":
                    return False, status.get("errorMessage") or "aria2 download failed"
                if state == "removed":
                    return False, "Download was cancelled"
                await asyncio.sleep(self._poll_interval)
        finally:
            # Drop the in-memory mapping once the wait loop exits for any reason.
            self._transfers.pop(download_id, None)

    async def _schedule_download(
        self,
        url: str,
        save_path: str,
        *,
        download_id: str,
        headers: Optional[Dict[str, str]] = None,
    ) -> str:
        """Submit a new download to aria2 and persist its recovery state.

        Returns:
            The aria2 GID assigned to the queued transfer.

        Raises:
            Aria2Error: If aria2 rejects the request.
        """
        save_dir = os.path.dirname(save_path)
        out_name = os.path.basename(save_path)
        Path(save_dir).mkdir(parents=True, exist_ok=True)
        resolved_url = url
        request_headers = headers
        # Civitai API downloads redirect to a signed URL; resolve that redirect
        # here so aria2 fetches the signed URL directly, without auth headers.
        if headers and url.startswith(CIVITAI_DOWNLOAD_URL_PREFIXES):
            resolved_url = await self._resolve_authenticated_redirect_url(url, headers)
            if resolved_url != url:
                request_headers = None
                logger.debug(
                    "Resolved Civitai download %s to signed URL for aria2",
                    download_id,
                )
        options: Dict[str, Any] = {
            "dir": save_dir,
            "out": out_name,
            "continue": "true",
            "max-connection-per-server": "4",
            "split": "4",
            "min-split-size": "1M",
            "allow-overwrite": "true",
            "auto-file-renaming": "false",
            "file-allocation": "none",
        }
        if request_headers:
            # aria2 expects one "Name: value" string per header.
            options["header"] = [
                f"{key}: {value}" for key, value in request_headers.items()
            ]
        logger.debug(
            "Submitting aria2 download %s -> %s (auth=%s, civitai_signed=%s)",
            download_id,
            save_path,
            bool(request_headers),
            resolved_url != url,
        )
        try:
            gid = await self._rpc_call("aria2.addUri", [[resolved_url], options])
        except Exception as exc:
            raise Aria2Error(f"Failed to schedule aria2 download: {exc}") from exc
        logger.debug("aria2 accepted download %s with gid %s", download_id, gid)
        # Persist enough metadata to restore this transfer after a restart.
        await self._state_store.upsert(
            download_id,
            {
                "gid": gid,
                "save_path": save_path,
                "status": "downloading",
                "url": url,
            },
        )
        return gid

    async def get_status(self, download_id: str) -> Optional[Dict[str, Any]]:
        """Return the raw aria2 status payload for a known download.

        Returns ``None`` when the id is not tracked in memory or the RPC
        answer is not a dict. Raises :class:`Aria2Error` on RPC failure.
        """
        transfer = self._transfers.get(download_id)
        if transfer is None:
            return None
        keys = [
            "gid",
            "status",
            "totalLength",
            "completedLength",
            "downloadSpeed",
            "errorMessage",
            "files",
        ]
        try:
            status = await self._rpc_call("aria2.tellStatus", [transfer.gid, keys])
        except Exception as exc:
            raise Aria2Error(f"Failed to query aria2 download status: {exc}") from exc
        if isinstance(status, dict):
            return status
        return None

    async def get_status_by_gid(self, gid: str) -> Optional[Dict[str, Any]]:
        """Query aria2 status by raw GID; ``None`` when aria2 no longer knows it."""
        keys = [
            "gid",
            "status",
            "totalLength",
            "completedLength",
            "downloadSpeed",
            "errorMessage",
            "files",
        ]
        try:
            status = await self._rpc_call("aria2.tellStatus", [gid, keys])
        except Exception as exc:
            message = str(exc)
            # aria2 reports unknown GIDs via an error message; map that to None.
            if "cannot be found" in message.lower() or "not found" in message.lower():
                return None
            raise Aria2Error(f"Failed to query aria2 download status: {exc}") from exc
        if isinstance(status, dict):
            return status
        return None

    async def restore_transfer(self, download_id: str, gid: str, save_path: str) -> None:
        """Re-register a persisted transfer (e.g. after restart) for polling."""
        await self._ensure_process()
        self._transfers[download_id] = Aria2Transfer(
            gid=gid,
            save_path=os.path.abspath(save_path),
        )

    async def reassign_transfer(
        self, from_download_id: str, to_download_id: str
    ) -> Optional[Aria2Transfer]:
        """Move in-memory transfer tracking from one download id to another.

        Returns the transfer, or ``None`` if the source id is unknown.
        """
        transfer = self._transfers.get(from_download_id)
        if transfer is None:
            return None
        self._transfers[to_download_id] = transfer
        if from_download_id != to_download_id:
            self._transfers.pop(from_download_id, None)
        return transfer

    async def has_transfer(self, download_id: str) -> bool:
        """Whether the id maps to an in-memory (not merely persisted) transfer."""
        return download_id in self._transfers

    async def pause_download(self, download_id: str) -> Dict[str, Any]:
        """Force-pause a tracked download; returns a success/error payload."""
        transfer = self._transfers.get(download_id)
        if transfer is None:
            return {"success": False, "error": "Download task not found"}
        try:
            await self._rpc_call("aria2.forcePause", [transfer.gid])
        except Exception as exc:
            return {"success": False, "error": str(exc)}
        await self._state_store.upsert(download_id, {"status": "paused"})
        return {"success": True, "message": "Download paused successfully"}

    async def resume_download(self, download_id: str) -> Dict[str, Any]:
        """Unpause a tracked download; returns a success/error payload."""
        transfer = self._transfers.get(download_id)
        if transfer is None:
            return {"success": False, "error": "Download task not found"}
        try:
            await self._rpc_call("aria2.unpause", [transfer.gid])
        except Exception as exc:
            return {"success": False, "error": str(exc)}
        await self._state_store.upsert(download_id, {"status": "downloading"})
        return {"success": True, "message": "Download resumed successfully"}

    async def cancel_download(self, download_id: str) -> Dict[str, Any]:
        """Force-remove a tracked download and drop its persisted state."""
        transfer = self._transfers.get(download_id)
        if transfer is None:
            return {"success": False, "error": "Download task not found"}
        try:
            await self._rpc_call("aria2.forceRemove", [transfer.gid])
        except Exception as exc:
            return {"success": False, "error": str(exc)}
        await self._state_store.remove(download_id)
        return {"success": True, "message": "Download cancelled successfully"}

    async def close(self) -> None:
        """Shut down the RPC process and session."""
        if self._rpc_session is not None:
            await self._rpc_session.close()
            self._rpc_session = None
        process = self._process
        self._process = None
        self._transfers.clear()
        if process is None:
            return
        if process.returncode is None:
            # Terminate gracefully first; kill after a short grace period.
            process.terminate()
            try:
                await asyncio.wait_for(process.wait(), timeout=5.0)
            except asyncio.TimeoutError:
                process.kill()
                await process.wait()

    async def _dispatch_progress(self, callback, snapshot: DownloadProgress) -> None:
        """Invoke a progress callback, supporting sync/async and legacy shapes.

        Tries ``callback(progress, progress)`` first, then falls back to the
        single-percent form ``callback(percent)``. Awaits the result if it is
        awaitable.
        """
        try:
            result = callback(snapshot, snapshot)
        except TypeError:
            # NOTE(review): a TypeError raised *inside* a two-argument callback
            # also lands here and silently retriggers the legacy path — confirm
            # callers never rely on that exception propagating.
            result = callback(snapshot.percent_complete)
        if asyncio.iscoroutine(result):
            await result
        elif hasattr(result, "__await__"):
            await result

    def _build_progress_snapshot(self, status: Dict[str, Any]) -> DownloadProgress:
        """Translate an aria2 status payload into a DownloadProgress snapshot."""
        completed = self._parse_int(status.get("completedLength"))
        total = self._parse_int(status.get("totalLength"))
        speed = float(self._parse_int(status.get("downloadSpeed")))
        percent = 0.0
        if total > 0:
            percent = (completed / total) * 100.0
        return DownloadProgress(
            # Clamp to [0, 100] in case aria2 reports inconsistent lengths.
            percent_complete=max(0.0, min(percent, 100.0)),
            bytes_downloaded=completed,
            total_bytes=total or None,
            bytes_per_second=speed,
            timestamp=datetime.now().timestamp(),
        )

    def _resolve_completed_path(self, status: Dict[str, Any], default_path: str) -> str:
        """Prefer the path aria2 reports for the first file; fall back to ours."""
        files = status.get("files")
        if isinstance(files, list) and files:
            first = files[0]
            if isinstance(first, dict):
                candidate = first.get("path")
                if isinstance(candidate, str) and candidate:
                    return candidate
        return default_path

    @staticmethod
    def _parse_int(value: Any) -> int:
        """Coerce aria2's numeric fields (which may be strings) to int, else 0."""
        try:
            return int(value)
        except (TypeError, ValueError):
            return 0

    async def _resolve_authenticated_redirect_url(
        self,
        url: str,
        headers: Dict[str, str],
    ) -> str:
        """Follow one authenticated Civitai redirect and return the signed URL.

        Returns the original URL when the server answers 200 directly.

        Raises:
            Aria2Error: On a redirect without a Location header, an unexpected
                status, or an aiohttp client error.
        """
        downloader = await get_downloader()
        session = await downloader.session
        request_headers = dict(downloader.default_headers)
        request_headers.update(headers)
        # Request an uncompressed body while probing for the redirect.
        request_headers["Accept-Encoding"] = "identity"
        try:
            async with session.get(
                url,
                headers=request_headers,
                allow_redirects=False,
                proxy=downloader.proxy_url,
            ) as response:
                if response.status in {301, 302, 303, 307, 308}:
                    location = response.headers.get("Location")
                    if location:
                        return location
                    raise Aria2Error(
                        "Authenticated Civitai redirect did not include a Location header"
                    )
                if response.status == 200:
                    return url
                body = await response.text()
                raise Aria2Error(
                    f"Failed to resolve authenticated Civitai redirect: status={response.status} body={body[:300]}"
                )
        except aiohttp.ClientError as exc:
            raise Aria2Error(
                f"Failed to resolve authenticated Civitai redirect: {exc}"
            ) from exc

    async def _ensure_process(self) -> None:
        """Start (or restart) the aria2 daemon unless a responsive one exists."""
        async with self._process_lock:
            if self.is_running and await self._ping():
                return
            # Tear down any stale process/session before relaunching.
            await self.close()
            executable = self._resolve_executable()
            self._rpc_port = self._find_free_port()
            self._rpc_secret = secrets.token_hex(16)
            self._rpc_url = f"http://127.0.0.1:{self._rpc_port}/jsonrpc"
            command = [
                executable,
                "--enable-rpc=true",
                "--rpc-listen-all=false",
                f"--rpc-listen-port={self._rpc_port}",
                f"--rpc-secret={self._rpc_secret}",
                "--check-certificate=true",
                "--allow-overwrite=true",
                "--auto-file-renaming=false",
                "--file-allocation=none",
                "--max-concurrent-downloads=5",
                "--continue=true",
                "--daemon=false",
                "--quiet=true",
                # Tie the daemon's lifetime to this Python process.
                f"--stop-with-process={os.getpid()}",
            ]
            logger.info("Starting aria2 RPC daemon from %s", executable)
            self._process = await asyncio.create_subprocess_exec(
                *command,
                stdout=asyncio.subprocess.DEVNULL,
                stderr=asyncio.subprocess.PIPE,
            )
            await self._wait_until_ready()

    def _resolve_executable(self) -> str:
        """Locate the aria2c binary via settings or PATH.

        Raises:
            Aria2Error: When no usable executable is found.
        """
        settings = get_settings_manager()
        configured_path = (settings.get("aria2c_path") or "").strip()
        candidate = configured_path or "aria2c"
        resolved = shutil.which(candidate)
        if resolved:
            return resolved
        # Accept a configured path directly when it points at an executable
        # file even though shutil.which() did not resolve it.
        if configured_path and os.path.isfile(configured_path) and os.access(
            configured_path, os.X_OK
        ):
            return configured_path
        raise Aria2Error(
            "aria2c executable was not found. Install aria2 or configure aria2c_path."
        )

    async def _wait_until_ready(self) -> None:
        """Block until the freshly-spawned daemon answers RPC (10s budget).

        Raises:
            Aria2Error: If the process exits early or never becomes ready.
        """
        assert self._process is not None
        start_time = asyncio.get_running_loop().time()
        last_error = ""
        while asyncio.get_running_loop().time() - start_time < 10.0:
            if self._process.returncode is not None:
                # The daemon died before answering; surface any stderr it left.
                stderr_output = ""
                if self._process.stderr is not None:
                    try:
                        stderr_output = (
                            await asyncio.wait_for(self._process.stderr.read(), timeout=0.2)
                        ).decode("utf-8", errors="replace")
                    except Exception:
                        stderr_output = ""
                raise Aria2Error(
                    f"aria2 RPC process exited early with code {self._process.returncode}: {stderr_output.strip()}"
                )
            try:
                if await self._ping():
                    return
            except Exception as exc:  # pragma: no cover - startup race
                last_error = str(exc)
            await asyncio.sleep(0.2)
        raise Aria2Error(
            f"Timed out waiting for aria2 RPC to become ready{': ' + last_error if last_error else ''}"
        )

    async def _ping(self) -> bool:
        """Lightweight liveness probe via ``aria2.getVersion``."""
        try:
            result = await self._rpc_call("aria2.getVersion", [])
        except Exception:
            return False
        return isinstance(result, dict)

    async def _rpc_call(self, method: str, params: list[Any]) -> Any:
        """Perform a JSON-RPC call against the daemon and return its result.

        Raises:
            Aria2Error: For an uninitialized endpoint, non-JSON bodies, RPC
                error payloads, or unexpected HTTP status codes.
        """
        if not self._rpc_url:
            raise Aria2Error("aria2 RPC endpoint is not initialized")
        session = await self._get_rpc_session()
        payload = {
            "jsonrpc": "2.0",
            "id": secrets.token_hex(8),
            "method": method,
            # aria2 authenticates via a "token:<secret>" first parameter.
            "params": [f"token:{self._rpc_secret}", *params],
        }
        async with session.post(self._rpc_url, json=payload) as response:
            text = await response.text()
            try:
                body = json.loads(text)
            except json.JSONDecodeError:
                body = None
            if body is None:
                if response.status != 200:
                    raise Aria2Error(
                        f"aria2 RPC returned status {response.status} with non-JSON body: {text}"
                    )
                raise Aria2Error(f"Invalid aria2 RPC response: {text}")
            if "error" in body:
                error = body["error"] or {}
                code = error.get("code") if isinstance(error, dict) else None
                message = error.get("message") if isinstance(error, dict) else str(error)
                logger.error(
                    "aria2 RPC %s failed with HTTP %s, code=%s, message=%s",
                    method,
                    response.status,
                    code,
                    message,
                )
                status_message = (
                    f"aria2 RPC {method} failed with status {response.status}: {message}"
                    if response.status != 200
                    else message
                )
                raise Aria2Error(status_message or "Unknown aria2 RPC error")
            if response.status != 200:
                logger.error(
                    "aria2 RPC %s returned unexpected HTTP status %s without error payload: %s",
                    method,
                    response.status,
                    body,
                )
                raise Aria2Error(
                    f"aria2 RPC {method} returned unexpected status {response.status}"
                )
            return body.get("result")

    async def _get_rpc_session(self) -> aiohttp.ClientSession:
        """Return the shared RPC session, creating it on demand (double-checked)."""
        if self._rpc_session is None or self._rpc_session.closed:
            async with self._rpc_session_lock:
                if self._rpc_session is None or self._rpc_session.closed:
                    timeout = aiohttp.ClientTimeout(total=30)
                    self._rpc_session = aiohttp.ClientSession(timeout=timeout)
        return self._rpc_session

    @staticmethod
    def _find_free_port() -> int:
        """Pick an ephemeral localhost port for the RPC listener.

        NOTE(review): the port is released before aria2 binds it, so a rare
        race with another process grabbing it is possible.
        """
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
            sock.bind(("127.0.0.1", 0))
            sock.listen(1)
            return int(sock.getsockname()[1])
async def get_aria2_downloader() -> Aria2Downloader:
    """Return the process-wide :class:`Aria2Downloader` singleton."""
    downloader = await Aria2Downloader.get_instance()
    return downloader

View File

@@ -0,0 +1,108 @@
from __future__ import annotations
import asyncio
import json
import os
from copy import deepcopy
from typing import Any, Dict, Optional
from ..utils.cache_paths import get_cache_base_dir
def get_aria2_state_path() -> str:
    """Path of the JSON file persisting aria2 transfer state.

    Ensures the containing ``aria2`` directory exists under the cache base.
    """
    aria2_dir = os.path.join(get_cache_base_dir(create=True), "aria2")
    os.makedirs(aria2_dir, exist_ok=True)
    return os.path.join(aria2_dir, "downloads.json")
class Aria2TransferStateStore:
    """Persist aria2 transfer metadata needed for restart recovery.

    Entries are keyed by download id and stored as a single JSON document;
    writes are atomic (temp file + rename) and serialized per state file.
    """

    # One asyncio lock per state-file path so stores sharing a file serialize.
    _locks_by_path: Dict[str, asyncio.Lock] = {}

    def __init__(self, state_path: Optional[str] = None) -> None:
        self._state_path = os.path.abspath(state_path or get_aria2_state_path())
        self._lock = self._locks_by_path.setdefault(self._state_path, asyncio.Lock())

    def _read_all_unlocked(self) -> Dict[str, Dict[str, Any]]:
        """Load the raw state file; missing or malformed files yield an empty map."""
        try:
            with open(self._state_path, "r", encoding="utf-8") as handle:
                raw = json.load(handle)
        except (FileNotFoundError, json.JSONDecodeError):
            return {}
        if not isinstance(raw, dict):
            return {}
        # Keep only well-formed (str id -> dict payload) entries.
        return {
            key: value
            for key, value in raw.items()
            if isinstance(key, str) and isinstance(value, dict)
        }

    def _write_all_unlocked(self, data: Dict[str, Dict[str, Any]]) -> None:
        """Atomically replace the state file via a temp-file rename."""
        parent = os.path.dirname(self._state_path)
        if parent:
            os.makedirs(parent, exist_ok=True)
        scratch = f"{self._state_path}.tmp"
        with open(scratch, "w", encoding="utf-8") as handle:
            json.dump(data, handle, ensure_ascii=True, indent=2, sort_keys=True)
        os.replace(scratch, self._state_path)

    async def load_all(self) -> Dict[str, Dict[str, Any]]:
        """Return a deep copy of every persisted entry."""
        async with self._lock:
            return deepcopy(self._read_all_unlocked())

    async def get(self, download_id: str) -> Optional[Dict[str, Any]]:
        """Return a deep copy of one entry, or ``None`` when absent."""
        async with self._lock:
            return deepcopy(self._read_all_unlocked().get(download_id))

    async def upsert(self, download_id: str, payload: Dict[str, Any]) -> Dict[str, Any]:
        """Merge *payload* into the stored entry and return the merged copy."""
        async with self._lock:
            records = self._read_all_unlocked()
            merged = {**records.get(download_id, {}), **payload}
            records[download_id] = merged
            self._write_all_unlocked(records)
            return deepcopy(merged)

    async def remove(self, download_id: str) -> None:
        """Delete the entry for *download_id*; a miss leaves the file untouched."""
        async with self._lock:
            records = self._read_all_unlocked()
            if download_id in records:
                del records[download_id]
                self._write_all_unlocked(records)

    async def find_by_save_path(
        self, save_path: str, *, exclude_download_id: Optional[str] = None
    ) -> Optional[Dict[str, Any]]:
        """Find the first entry targeting *save_path* (path-normalized).

        Returns a copy of the entry with its ``download_id`` included, or
        ``None`` when nothing matches.
        """
        target = os.path.abspath(save_path)
        async with self._lock:
            for did, entry in self._read_all_unlocked().items():
                if exclude_download_id and did == exclude_download_id:
                    continue
                stored = entry.get("save_path")
                if isinstance(stored, str) and os.path.abspath(stored) == target:
                    match = dict(entry)
                    match["download_id"] = did
                    return match
            return None

    async def reassign(self, from_download_id: str, to_download_id: str) -> Optional[Dict[str, Any]]:
        """Move an entry to a new download id, persisting the change.

        Returns a copy of the moved entry, or ``None`` if the source is absent.
        """
        async with self._lock:
            records = self._read_all_unlocked()
            source = records.get(from_download_id)
            if source is None:
                return None
            moved = dict(source)
            moved["download_id"] = to_download_id
            records[to_download_id] = moved
            if to_download_id != from_download_id:
                records.pop(from_download_id, None)
            self._write_all_unlocked(records)
            return deepcopy(moved)

View File

@@ -6,6 +6,7 @@ import asyncio
import errno
import logging
import socket
from dataclasses import dataclass
from datetime import datetime, timedelta
from typing import Any
@@ -49,68 +50,118 @@ class ConnectivityGuard:
if hasattr(self, "_initialized"):
return
self._initialized = True
self.online = True
self.failure_count = 0
self.cooldown_until: datetime | None = None
self._default_destination = "__global__"
self._destination_states: dict[str, _DestinationState] = {
self._default_destination: _DestinationState()
}
self.base_backoff_seconds = 30
self.max_backoff_seconds = 300
self.failure_threshold = 3
@property
def online(self) -> bool:
return self._state_for_destination(None).online
@online.setter
def online(self, value: bool) -> None:
self._state_for_destination(None).online = value
@property
def failure_count(self) -> int:
return self._state_for_destination(None).failure_count
@failure_count.setter
def failure_count(self, value: int) -> None:
self._state_for_destination(None).failure_count = value
@property
def cooldown_until(self) -> datetime | None:
return self._state_for_destination(None).cooldown_until
@cooldown_until.setter
def cooldown_until(self, value: datetime | None) -> None:
self._state_for_destination(None).cooldown_until = value
def _now(self) -> datetime:
return datetime.now()
def in_cooldown(self) -> bool:
if self.cooldown_until is None:
def _normalize_destination(self, destination: str | None) -> str:
if destination is None or not destination.strip():
return self._default_destination
return destination.lower().strip()
def _state_for_destination(self, destination: str | None) -> "_DestinationState":
destination_key = self._normalize_destination(destination)
if destination_key not in self._destination_states:
self._destination_states[destination_key] = _DestinationState()
return self._destination_states[destination_key]
def in_cooldown(self, destination: str | None = None) -> bool:
state = self._state_for_destination(destination)
if state.cooldown_until is None:
return False
return self._now() < self.cooldown_until
return self._now() < state.cooldown_until
def cooldown_remaining_seconds(self) -> float:
if self.cooldown_until is None:
def cooldown_remaining_seconds(self, destination: str | None = None) -> float:
state = self._state_for_destination(destination)
if state.cooldown_until is None:
return 0.0
return max(0.0, (self.cooldown_until - self._now()).total_seconds())
return max(0.0, (state.cooldown_until - self._now()).total_seconds())
def should_block_request(self) -> bool:
return self.in_cooldown()
def should_block_request(self, destination: str | None = None) -> bool:
return self.in_cooldown(destination)
def register_success(self) -> None:
was_offline = (not self.online) or self.cooldown_until is not None
self.online = True
self.failure_count = 0
self.cooldown_until = None
def register_success(self, destination: str | None = None) -> None:
destination_key = self._normalize_destination(destination)
state = self._state_for_destination(destination_key)
was_offline = (not state.online) or state.cooldown_until is not None
state.online = True
state.failure_count = 0
state.cooldown_until = None
if was_offline:
logger.info("Connectivity restored; requests resumed.")
logger.info(
"Connectivity restored for destination '%s'; requests resumed.",
destination_key,
)
def register_network_failure(self, exc: Exception) -> None:
self.online = False
self.failure_count += 1
def register_network_failure(
self, exc: Exception, destination: str | None = None
) -> None:
destination_key = self._normalize_destination(destination)
state = self._state_for_destination(destination_key)
state.online = False
state.failure_count += 1
if self.failure_count < self.failure_threshold:
if state.failure_count < self.failure_threshold:
logger.debug(
"Network failure tracked (%d/%d): %s",
self.failure_count,
"Network failure tracked for destination '%s' (%d/%d): %s",
destination_key,
state.failure_count,
self.failure_threshold,
exc,
)
return
retry_step = self.failure_count - self.failure_threshold
retry_step = state.failure_count - self.failure_threshold
backoff = min(
self.max_backoff_seconds,
self.base_backoff_seconds * (2**retry_step),
)
should_log_warning = not self.in_cooldown()
self.cooldown_until = self._now() + timedelta(seconds=backoff)
should_log_warning = not self.in_cooldown(destination_key)
state.cooldown_until = self._now() + timedelta(seconds=backoff)
if should_log_warning:
logger.warning(
"Connectivity offline; enter cooldown for %ss after %d network failures.",
"Connectivity offline for destination '%s'; enter cooldown for %ss after %d network failures.",
destination_key,
int(backoff),
self.failure_count,
state.failure_count,
)
else:
logger.debug(
"Cooldown still active; failure_count=%d, backoff=%ss.",
self.failure_count,
"Cooldown still active for destination '%s'; failure_count=%d, backoff=%ss.",
destination_key,
state.failure_count,
int(backoff),
)
@@ -145,3 +196,9 @@ class ConnectivityGuard:
return False
@dataclass
class _DestinationState:
    """Connectivity bookkeeping tracked independently per destination."""

    # Whether the destination is currently considered reachable.
    online: bool = True
    # Consecutive network failures recorded for this destination.
    failure_count: int = 0
    # Requests are blocked until this moment; None means no active cooldown.
    cooldown_until: datetime | None = None

File diff suppressed because it is too large Load Diff

View File

@@ -18,6 +18,7 @@ from collections import deque
from dataclasses import dataclass
from datetime import datetime, timedelta
from email.utils import parsedate_to_datetime
from urllib.parse import urlparse
from typing import Optional, Dict, Tuple, Callable, Union, Awaitable
from ..services.settings_manager import get_settings_manager
from .connectivity_guard import (
@@ -828,7 +829,7 @@ class Downloader:
) as response:
if response.status == 200:
content = await response.read()
guard.register_success()
guard.register_success(destination)
if return_headers:
return True, content, dict(response.headers)
else:
@@ -874,7 +875,8 @@ class Downloader:
Tuple[bool, Union[Dict, str]]: (success, headers dict or error message)
"""
guard = await ConnectivityGuard.get_instance()
if guard.should_block_request():
destination = self._guard_destination(url)
if guard.should_block_request(destination):
return False, OFFLINE_COOLDOWN_ERROR
try:
@@ -898,15 +900,15 @@ class Downloader:
url, headers=headers, proxy=self.proxy_url
) as response:
if response.status == 200:
guard.register_success()
guard.register_success(destination)
return True, dict(response.headers)
else:
return False, f"Head request failed with status {response.status}"
except Exception as e:
if guard.is_network_unreachable_error(e):
guard.register_network_failure(e)
if guard.should_block_request():
guard.register_network_failure(e, destination)
if guard.should_block_request(destination):
return False, OFFLINE_COOLDOWN_ERROR
logger.debug("Network unavailable during header probe: %s", e)
return False, str(e)
@@ -935,7 +937,8 @@ class Downloader:
Tuple[bool, Union[Dict, str]]: (success, response data or error message)
"""
guard = await ConnectivityGuard.get_instance()
if guard.should_block_request():
destination = self._guard_destination(url)
if guard.should_block_request(destination):
return False, OFFLINE_COOLDOWN_ERROR
try:
@@ -961,7 +964,7 @@ class Downloader:
method, url, headers=headers, **kwargs
) as response:
if response.status == 200:
guard.register_success()
guard.register_success(destination)
# Try to parse as JSON, fall back to text
try:
data = await response.json()
@@ -993,8 +996,8 @@ class Downloader:
except Exception as e:
if guard.is_network_unreachable_error(e):
guard.register_network_failure(e)
if guard.should_block_request():
guard.register_network_failure(e, destination)
if guard.should_block_request(destination):
return False, OFFLINE_COOLDOWN_ERROR
logger.debug("Network unavailable for %s %s: %s", method, url, e)
return False, str(e)
@@ -1048,6 +1051,14 @@ class Downloader:
delta = retry_datetime - datetime.now(tz=retry_datetime.tzinfo)
return max(0.0, delta.total_seconds())
@staticmethod
def _guard_destination(url: str) -> str:
"""Build per-destination connectivity guard scope from request URL."""
parsed_url = urlparse(url)
if parsed_url.hostname:
return parsed_url.hostname.lower()
return "unknown"
# Global instance accessor
async def get_downloader() -> Downloader:

View File

@@ -55,6 +55,8 @@ DEFAULT_KEYS_CLEANUP_THRESHOLD = 10
DEFAULT_SETTINGS: Dict[str, Any] = {
"civitai_api_key": "",
"civitai_host": "civitai.com",
"download_backend": "python",
"aria2c_path": "",
"use_portable_settings": False,
"hash_chunk_size_mb": DEFAULT_HASH_CHUNK_SIZE_MB,
"language": "en",
@@ -761,34 +763,29 @@ class SettingsManager:
if self._preserve_disk_template:
return
folder_paths = self.settings.get("folder_paths", {})
updated = False
def _check_and_auto_set(key: str, setting_key: str) -> bool:
"""Repair default roots when empty or no longer present."""
current = self.settings.get(setting_key, "")
candidates = folder_paths.get(key, [])
if not isinstance(candidates, list) or not candidates:
primary_candidates = self._get_valid_root_candidates(key)
if not primary_candidates:
return False
# Filter valid path strings
valid_paths = [p for p in candidates if isinstance(p, str) and p.strip()]
if not valid_paths:
allowed_roots = self._get_allowed_roots(key)
if current and current in allowed_roots:
return False
if current in valid_paths:
return False
self.settings[setting_key] = valid_paths[0]
self.settings[setting_key] = primary_candidates[0]
if current:
logger.info(
"Repaired stale %s from '%s' to '%s'",
"Repaired stale %s from '%s' to '%s' because it is not present in primary or extra roots",
setting_key,
current,
valid_paths[0],
primary_candidates[0],
)
else:
logger.info("Auto-set %s to '%s'", setting_key, valid_paths[0])
logger.info("Auto-set %s to '%s'", setting_key, primary_candidates[0])
return True
# Process all model types
@@ -811,6 +808,33 @@ class SettingsManager:
else:
self._save_settings()
def _get_valid_root_candidates(self, key: str) -> List[str]:
"""Return stable root candidates, preferring primary roots over extra roots."""
candidates: List[str] = []
seen: set[str] = set()
for mapping_key in ("folder_paths", "extra_folder_paths"):
raw_paths = self.settings.get(mapping_key, {})
if not isinstance(raw_paths, Mapping):
continue
values = raw_paths.get(key, [])
if not isinstance(values, list):
continue
for value in values:
if not isinstance(value, str):
continue
normalized = value.strip()
if not normalized or normalized in seen:
continue
seen.add(normalized)
candidates.append(normalized)
return candidates
def _get_allowed_roots(self, key: str) -> set[str]:
    """All roots (primary plus extra) considered valid for a model type."""
    return {root for root in self._get_valid_root_candidates(key)}
def _check_environment_variables(self) -> None:
"""Check for environment variables and update settings if needed"""
env_api_key = os.environ.get("CIVITAI_API_KEY")