mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-06-09 12:39:23 -03:00
When certifi is available, pass its CA bundle path as --ca-certificate to the aria2c subprocess so that aria2 downloads use the same certificate store as Python aiohttp downloads. Graceful fallback when certifi is not installed.
603 lines
20 KiB
Python
603 lines
20 KiB
Python
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import json
|
|
import logging
|
|
import os
|
|
import secrets
|
|
import shutil
|
|
import socket
|
|
from dataclasses import dataclass
|
|
from datetime import datetime
|
|
from pathlib import Path
|
|
from typing import Any, Dict, Optional, Tuple
|
|
|
|
import aiohttp
|
|
|
|
from .downloader import DownloadProgress, get_downloader, is_ssl_cert_verify_error
|
|
from .aria2_transfer_state import Aria2TransferStateStore
|
|
from .settings_manager import get_settings_manager
|
|
|
|
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
|
|
|
|
def _try_certifi_ca_path() -> str | None:
    """Return the certifi CA bundle path if available, else None.

    Returns:
        The path reported by ``certifi.where()`` when certifi can be imported
        and that path is an existing file; otherwise ``None``.
    """
    # Keep the try body minimal: only the optional import is expected to fail.
    try:
        import certifi  # type: ignore[import-untyped]
    except ImportError:
        logger.debug("aria2 --ca-certificate: certifi not available")
        return None

    path = certifi.where()
    if os.path.isfile(path):
        logger.debug(
            "aria2 --ca-certificate: using certifi CA bundle at %s", path
        )
        return path

    # certifi is installed but its bundle is missing on disk; say so instead
    # of reporting that certifi itself is unavailable.
    logger.debug(
        "aria2 --ca-certificate: certifi CA bundle not found at %s", path
    )
    return None
|
|
|
|
|
|
# URL prefixes treated as Civitai download endpoints. When a download URL
# starts with one of these and request headers are supplied,
# ``Aria2Downloader._schedule_download`` first resolves the authenticated
# redirect and hands the resulting URL to aria2 instead.
CIVITAI_DOWNLOAD_URL_PREFIXES = (
    "https://civitai.com/api/download/",
    "https://civitai.red/api/download/",
)
|
|
|
|
|
|
class Aria2Error(RuntimeError):
    """Error type used by this module for any failure in the aria2 integration.

    Raised, for example, when the aria2c executable cannot be located, the
    RPC daemon fails to start, or an RPC call returns an error.
    """
|
|
|
|
|
|
@dataclass
class Aria2Transfer:
    """In-memory record of one aria2 download known to the Python coordinator."""

    # Identifier returned by aria2 for the download (used in later RPC calls).
    gid: str
    # Destination file path requested for the download.
    save_path: str
|
|
|
|
|
|
class Aria2Downloader:
    """Manage an aria2 RPC daemon for recommended model downloads.

    The instance spawns an ``aria2c`` subprocess with JSON-RPC enabled on a
    loopback port, submits downloads through that endpoint and polls it for
    progress. Active downloads are tracked in memory by ``download_id`` and
    their state is recorded in an ``Aria2TransferStateStore``.
    """

    _instance = None
    _lock = asyncio.Lock()

    # Fields requested from ``aria2.tellStatus`` by every status query.
    _STATUS_KEYS = (
        "gid",
        "status",
        "totalLength",
        "completedLength",
        "downloadSpeed",
        "errorMessage",
        "files",
    )

    @classmethod
    async def get_instance(cls) -> "Aria2Downloader":
        """Return the shared downloader, creating it on first use."""
        async with cls._lock:
            if cls._instance is None:
                cls._instance = cls()
            return cls._instance

    def __init__(self) -> None:
        """Initialise empty process/RPC state; a repeated call is a no-op."""
        if hasattr(self, "_initialized"):
            return

        self._initialized = True
        self._process: Optional[asyncio.subprocess.Process] = None
        self._rpc_port: Optional[int] = None
        self._rpc_secret = ""
        self._rpc_url = ""
        self._rpc_session: Optional[aiohttp.ClientSession] = None
        self._rpc_session_lock = asyncio.Lock()
        self._process_lock = asyncio.Lock()
        self._transfers: Dict[str, Aria2Transfer] = {}
        self._poll_interval = 0.5
        self._state_store = Aria2TransferStateStore()

    @property
    def is_running(self) -> bool:
        """Whether a spawned aria2 process exists and has not exited."""
        return self._process is not None and self._process.returncode is None

    async def download_file(
        self,
        url: str,
        save_path: str,
        *,
        download_id: str,
        progress_callback=None,
        headers: Optional[Dict[str, str]] = None,
    ) -> Tuple[bool, str]:
        """Download a file using aria2 RPC and wait for completion.

        A new aria2 download is scheduled unless a transfer with the same
        ``download_id`` and the same absolute ``save_path`` is already
        registered (e.g. via ``restore_transfer``), in which case that
        transfer is polled instead.

        Args:
            url: Source URL.
            save_path: Destination file path; converted to an absolute path.
            download_id: Key under which the transfer is tracked.
            progress_callback: Optional callable invoked with a progress
                snapshot on every poll; may be sync or async.
            headers: Optional HTTP headers to send with the request.

        Returns:
            ``(True, path)`` on completion, or ``(False, message)`` when the
            download errors, is removed, or can no longer be found.

        Raises:
            Aria2Error: If the daemon cannot be started, the download cannot
                be scheduled, or a status query fails.
        """
        await self._ensure_process()
        save_path = os.path.abspath(save_path)
        transfer = self._transfers.get(download_id)
        if transfer is None or os.path.abspath(transfer.save_path) != save_path:
            gid = await self._schedule_download(
                url,
                save_path,
                download_id=download_id,
                headers=headers,
            )
            transfer = Aria2Transfer(gid=gid, save_path=save_path)
            self._transfers[download_id] = transfer

        try:
            while True:
                status = await self.get_status(download_id)
                if status is None:
                    return False, "aria2 download not found"

                snapshot = self._build_progress_snapshot(status)
                if progress_callback is not None:
                    await self._dispatch_progress(progress_callback, snapshot)

                state = status.get("status", "")
                if state == "complete":
                    completed_path = self._resolve_completed_path(status, save_path)
                    return True, completed_path
                if state == "error":
                    return False, status.get("errorMessage") or "aria2 download failed"
                if state == "removed":
                    return False, "Download was cancelled"

                await asyncio.sleep(self._poll_interval)
        finally:
            # The in-memory registration only lives while this coroutine polls.
            self._transfers.pop(download_id, None)

    async def _schedule_download(
        self,
        url: str,
        save_path: str,
        *,
        download_id: str,
        headers: Optional[Dict[str, str]] = None,
    ) -> str:
        """Submit a download to aria2 and record it in the state store.

        Returns:
            The gid returned by ``aria2.addUri``.

        Raises:
            Aria2Error: If redirect resolution or the ``addUri`` call fails.
        """
        save_dir = os.path.dirname(save_path)
        out_name = os.path.basename(save_path)

        Path(save_dir).mkdir(parents=True, exist_ok=True)

        resolved_url = url
        request_headers = headers
        if headers and url.startswith(CIVITAI_DOWNLOAD_URL_PREFIXES):
            resolved_url = await self._resolve_authenticated_redirect_url(url, headers)
            if resolved_url != url:
                # The redirect target is used as-is; the original headers are
                # not forwarded to it.
                request_headers = None
                logger.debug(
                    "Resolved Civitai download %s to signed URL for aria2",
                    download_id,
                )

        # Values are strings except "header", which is a list of header lines.
        options: Dict[str, Any] = {
            "dir": save_dir,
            "out": out_name,
            "continue": "true",
            "max-connection-per-server": "4",
            "split": "4",
            "min-split-size": "1M",
            "allow-overwrite": "true",
            "auto-file-renaming": "false",
            "file-allocation": "none",
        }
        if request_headers:
            options["header"] = [
                f"{key}: {value}" for key, value in request_headers.items()
            ]

        logger.debug(
            "Submitting aria2 download %s -> %s (auth=%s, civitai_signed=%s)",
            download_id,
            save_path,
            bool(request_headers),
            resolved_url != url,
        )

        try:
            gid = await self._rpc_call("aria2.addUri", [[resolved_url], options])
        except Exception as exc:
            raise Aria2Error(f"Failed to schedule aria2 download: {exc}") from exc

        logger.debug("aria2 accepted download %s with gid %s", download_id, gid)
        await self._state_store.upsert(
            download_id,
            {
                "gid": gid,
                "save_path": save_path,
                "status": "downloading",
                "url": url,
            },
        )
        return gid

    async def get_status(self, download_id: str) -> Optional[Dict[str, Any]]:
        """Return the raw aria2 status payload for a known download.

        Returns:
            The status dict, or ``None`` when ``download_id`` is not
            registered or aria2 returned a non-dict result.

        Raises:
            Aria2Error: If the ``aria2.tellStatus`` call fails.
        """
        transfer = self._transfers.get(download_id)
        if transfer is None:
            return None

        try:
            status = await self._rpc_call(
                "aria2.tellStatus", [transfer.gid, list(self._STATUS_KEYS)]
            )
        except Exception as exc:
            raise Aria2Error(f"Failed to query aria2 download status: {exc}") from exc

        if isinstance(status, dict):
            return status
        return None

    async def get_status_by_gid(self, gid: str) -> Optional[Dict[str, Any]]:
        """Return the raw aria2 status payload for ``gid``.

        Unlike ``get_status``, a failure whose message contains
        "cannot be found" or "not found" yields ``None`` instead of raising.

        Raises:
            Aria2Error: For any other ``aria2.tellStatus`` failure.
        """
        try:
            status = await self._rpc_call(
                "aria2.tellStatus", [gid, list(self._STATUS_KEYS)]
            )
        except Exception as exc:
            message = str(exc).lower()
            if "cannot be found" in message or "not found" in message:
                return None
            raise Aria2Error(f"Failed to query aria2 download status: {exc}") from exc

        if isinstance(status, dict):
            return status
        return None

    async def restore_transfer(self, download_id: str, gid: str, save_path: str) -> None:
        """Register an existing aria2 gid under ``download_id``.

        The daemon is started first if it is not already running.
        """
        await self._ensure_process()
        self._transfers[download_id] = Aria2Transfer(
            gid=gid,
            save_path=os.path.abspath(save_path),
        )

    async def reassign_transfer(
        self, from_download_id: str, to_download_id: str
    ) -> Optional[Aria2Transfer]:
        """Move a registered transfer to a different download id.

        Returns:
            The moved transfer, or ``None`` if ``from_download_id`` is unknown.
        """
        transfer = self._transfers.get(from_download_id)
        if transfer is None:
            return None

        self._transfers[to_download_id] = transfer
        if from_download_id != to_download_id:
            self._transfers.pop(from_download_id, None)
        return transfer

    async def has_transfer(self, download_id: str) -> bool:
        """Return whether ``download_id`` is currently registered."""
        return download_id in self._transfers

    async def pause_download(self, download_id: str) -> Dict[str, Any]:
        """Pause a registered download via ``aria2.forcePause``.

        Returns:
            A dict with ``success`` and either ``message`` or ``error``.
        """
        transfer = self._transfers.get(download_id)
        if transfer is None:
            return {"success": False, "error": "Download task not found"}

        try:
            await self._rpc_call("aria2.forcePause", [transfer.gid])
        except Exception as exc:
            return {"success": False, "error": str(exc)}

        await self._state_store.upsert(download_id, {"status": "paused"})
        return {"success": True, "message": "Download paused successfully"}

    async def resume_download(self, download_id: str) -> Dict[str, Any]:
        """Resume a registered download via ``aria2.unpause``.

        Returns:
            A dict with ``success`` and either ``message`` or ``error``.
        """
        transfer = self._transfers.get(download_id)
        if transfer is None:
            return {"success": False, "error": "Download task not found"}

        try:
            await self._rpc_call("aria2.unpause", [transfer.gid])
        except Exception as exc:
            return {"success": False, "error": str(exc)}

        await self._state_store.upsert(download_id, {"status": "downloading"})
        return {"success": True, "message": "Download resumed successfully"}

    async def cancel_download(self, download_id: str) -> Dict[str, Any]:
        """Cancel a registered download via ``aria2.forceRemove``.

        On success the download's entry is removed from the state store.

        Returns:
            A dict with ``success`` and either ``message`` or ``error``.
        """
        transfer = self._transfers.get(download_id)
        if transfer is None:
            return {"success": False, "error": "Download task not found"}

        try:
            await self._rpc_call("aria2.forceRemove", [transfer.gid])
        except Exception as exc:
            return {"success": False, "error": str(exc)}

        await self._state_store.remove(download_id)
        return {"success": True, "message": "Download cancelled successfully"}

    async def close(self) -> None:
        """Shut down the RPC process and session.

        Also clears the in-memory transfer registry. A still-running process
        is terminated and, if it has not exited within 5 seconds, killed.
        """
        if self._rpc_session is not None:
            await self._rpc_session.close()
            self._rpc_session = None

        process = self._process
        self._process = None
        self._transfers.clear()

        if process is None:
            return

        if process.returncode is None:
            process.terminate()
            try:
                await asyncio.wait_for(process.wait(), timeout=5.0)
            except asyncio.TimeoutError:
                process.kill()
                await process.wait()

    async def _dispatch_progress(self, callback, snapshot: DownloadProgress) -> None:
        """Invoke ``callback`` with ``snapshot`` and await the result if needed.

        The callback is first called as ``callback(snapshot, snapshot)``; if
        that raises ``TypeError`` it is called again with only the percentage.
        """
        try:
            result = callback(snapshot, snapshot)
        except TypeError:
            result = callback(snapshot.percent_complete)

        if asyncio.iscoroutine(result) or hasattr(result, "__await__"):
            await result

    def _build_progress_snapshot(self, status: Dict[str, Any]) -> DownloadProgress:
        """Convert an aria2 status payload into a ``DownloadProgress``."""
        completed = self._parse_int(status.get("completedLength"))
        total = self._parse_int(status.get("totalLength"))
        speed = float(self._parse_int(status.get("downloadSpeed")))
        percent = 0.0
        if total > 0:
            percent = (completed / total) * 100.0

        return DownloadProgress(
            percent_complete=max(0.0, min(percent, 100.0)),
            bytes_downloaded=completed,
            total_bytes=total or None,
            bytes_per_second=speed,
            timestamp=datetime.now().timestamp(),
        )

    def _resolve_completed_path(self, status: Dict[str, Any], default_path: str) -> str:
        """Return the first file path in ``status``, else ``default_path``."""
        files = status.get("files")
        if isinstance(files, list) and files:
            first = files[0]
            if isinstance(first, dict):
                candidate = first.get("path")
                if isinstance(candidate, str) and candidate:
                    return candidate
        return default_path

    @staticmethod
    def _parse_int(value: Any) -> int:
        """Return ``int(value)``, or 0 when the conversion fails."""
        try:
            return int(value)
        except (TypeError, ValueError):
            return 0

    async def _resolve_authenticated_redirect_url(
        self,
        url: str,
        headers: Dict[str, str],
    ) -> str:
        """Request ``url`` with ``headers`` without following redirects.

        Returns:
            The ``Location`` header of a 301/302/303/307/308 response, or
            ``url`` itself when the response status is 200.

        Raises:
            Aria2Error: On a redirect without ``Location``, any other status,
                or an ``aiohttp.ClientError`` during the request.
        """
        downloader = await get_downloader()
        session = await downloader.session
        request_headers = dict(downloader.default_headers)
        request_headers.update(headers)
        request_headers["Accept-Encoding"] = "identity"

        try:
            async with session.get(
                url,
                headers=request_headers,
                allow_redirects=False,
                proxy=downloader.proxy_url,
            ) as response:
                if response.status in {301, 302, 303, 307, 308}:
                    location = response.headers.get("Location")
                    if location:
                        return location
                    raise Aria2Error(
                        "Authenticated Civitai redirect did not include a Location header"
                    )

                if response.status == 200:
                    return url

                body = await response.text()
                raise Aria2Error(
                    f"Failed to resolve authenticated Civitai redirect: status={response.status} body={body[:300]}"
                )
        except aiohttp.ClientError as exc:
            if is_ssl_cert_verify_error(exc):
                logger.error(
                    "SSL certificate verification failed during Civitai redirect "
                    "resolution for %s. This is usually caused by an outdated CA "
                    "certificate bundle. Recommended fixes:\n"
                    " 1. pip install --upgrade certifi\n"
                    " 2. pip install pip-system-certs",
                    url,
                )
            raise Aria2Error(
                f"Failed to resolve authenticated Civitai redirect: {exc}"
            ) from exc

    async def _ensure_process(self) -> None:
        """Start the aria2 RPC daemon unless one is running and answering.

        Any previous process/session is closed first. A fresh port and RPC
        secret are chosen for every start.

        Raises:
            Aria2Error: If the executable cannot be found or the daemon does
                not become ready.
        """
        async with self._process_lock:
            if self.is_running and await self._ping():
                return

            await self.close()

            executable = self._resolve_executable()
            self._rpc_port = self._find_free_port()
            self._rpc_secret = secrets.token_hex(16)
            self._rpc_url = f"http://127.0.0.1:{self._rpc_port}/jsonrpc"

            command = [
                executable,
                "--enable-rpc=true",
                "--rpc-listen-all=false",
                f"--rpc-listen-port={self._rpc_port}",
                f"--rpc-secret={self._rpc_secret}",
                "--check-certificate=true",
            ]
            # Point aria2 at certifi's CA bundle when available so it uses
            # the same certificate store as Python downloads.
            ca_cert = _try_certifi_ca_path()
            if ca_cert:
                command.append(f"--ca-certificate={ca_cert}")
            command.extend(
                [
                    "--allow-overwrite=true",
                    "--auto-file-renaming=false",
                    "--file-allocation=none",
                    "--max-concurrent-downloads=5",
                    "--continue=true",
                    "--daemon=false",
                    "--quiet=true",
                    f"--stop-with-process={os.getpid()}",
                ]
            )

            logger.info("Starting aria2 RPC daemon from %s", executable)
            self._process = await asyncio.create_subprocess_exec(
                *command,
                stdout=asyncio.subprocess.DEVNULL,
                stderr=asyncio.subprocess.PIPE,
            )

            try:
                await self._wait_until_ready()
            except BaseException:
                # Do not leave a half-started daemon (and a stale
                # ``is_running``) behind when startup fails or is cancelled.
                await self.close()
                raise

    def _resolve_executable(self) -> str:
        """Locate the aria2c executable.

        Uses the ``aria2c_path`` setting when set, otherwise ``aria2c``.

        Raises:
            Aria2Error: If no usable executable is found.
        """
        settings = get_settings_manager()
        configured_path = (settings.get("aria2c_path") or "").strip()
        candidate = configured_path or "aria2c"

        resolved = shutil.which(candidate)
        if resolved:
            return resolved

        if configured_path and os.path.isfile(configured_path) and os.access(
            configured_path, os.X_OK
        ):
            return configured_path

        raise Aria2Error(
            "aria2c executable was not found. Install aria2 or configure aria2c_path."
        )

    async def _wait_until_ready(self) -> None:
        """Poll the RPC endpoint for up to 10 seconds until it answers.

        Raises:
            Aria2Error: If no process has been started, the process exits
                early, or the endpoint does not answer in time.
        """
        process = self._process
        if process is None:
            # Explicit check rather than ``assert`` so it survives ``python -O``.
            raise Aria2Error("aria2 RPC process has not been started")

        loop = asyncio.get_running_loop()
        start_time = loop.time()
        last_error = ""
        while loop.time() - start_time < 10.0:
            if process.returncode is not None:
                stderr_output = ""
                if process.stderr is not None:
                    try:
                        stderr_output = (
                            await asyncio.wait_for(process.stderr.read(), timeout=0.2)
                        ).decode("utf-8", errors="replace")
                    except Exception:
                        # Best effort only: the exit code is still reported.
                        stderr_output = ""
                raise Aria2Error(
                    f"aria2 RPC process exited early with code {process.returncode}: {stderr_output.strip()}"
                )

            try:
                if await self._ping():
                    return
            except Exception as exc:  # pragma: no cover - startup race
                last_error = str(exc)

            await asyncio.sleep(0.2)

        raise Aria2Error(
            f"Timed out waiting for aria2 RPC to become ready{': ' + last_error if last_error else ''}"
        )

    async def _ping(self) -> bool:
        """Return True when ``aria2.getVersion`` succeeds with a dict result."""
        try:
            result = await self._rpc_call("aria2.getVersion", [])
        except Exception:
            return False

        return isinstance(result, dict)

    async def _rpc_call(self, method: str, params: list[Any]) -> Any:
        """Send one JSON-RPC request to aria2 and return its ``result``.

        The RPC secret is prepended to ``params`` as a ``token:`` entry.

        Raises:
            Aria2Error: If the endpoint is not initialised, the response is
                not a JSON object, it carries an ``error`` member, or the
                HTTP status is not 200.
        """
        if not self._rpc_url:
            raise Aria2Error("aria2 RPC endpoint is not initialized")

        session = await self._get_rpc_session()
        payload = {
            "jsonrpc": "2.0",
            "id": secrets.token_hex(8),
            "method": method,
            "params": [f"token:{self._rpc_secret}", *params],
        }

        async with session.post(self._rpc_url, json=payload) as response:
            text = await response.text()

        try:
            body = json.loads(text)
        except json.JSONDecodeError:
            body = None

        if body is None:
            if response.status != 200:
                raise Aria2Error(
                    f"aria2 RPC returned status {response.status} with non-JSON body: {text}"
                )
            raise Aria2Error(f"Invalid aria2 RPC response: {text}")

        if not isinstance(body, dict):
            # Valid JSON that is not an object (list, string, number) cannot
            # carry "error"/"result" members.
            raise Aria2Error(f"Invalid aria2 RPC response: {text}")

        if "error" in body:
            error = body["error"] or {}
            code = error.get("code") if isinstance(error, dict) else None
            message = error.get("message") if isinstance(error, dict) else str(error)
            logger.error(
                "aria2 RPC %s failed with HTTP %s, code=%s, message=%s",
                method,
                response.status,
                code,
                message,
            )
            status_message = (
                f"aria2 RPC {method} failed with status {response.status}: {message}"
                if response.status != 200
                else message
            )
            raise Aria2Error(status_message or "Unknown aria2 RPC error")

        if response.status != 200:
            logger.error(
                "aria2 RPC %s returned unexpected HTTP status %s without error payload: %s",
                method,
                response.status,
                body,
            )
            raise Aria2Error(
                f"aria2 RPC {method} returned unexpected status {response.status}"
            )

        return body.get("result")

    async def _get_rpc_session(self) -> aiohttp.ClientSession:
        """Return the RPC client session, creating one if missing or closed."""
        if self._rpc_session is None or self._rpc_session.closed:
            async with self._rpc_session_lock:
                # Re-check after acquiring the lock: another task may have
                # created the session while this one was waiting.
                if self._rpc_session is None or self._rpc_session.closed:
                    timeout = aiohttp.ClientTimeout(total=30)
                    self._rpc_session = aiohttp.ClientSession(timeout=timeout)
        return self._rpc_session

    @staticmethod
    def _find_free_port() -> int:
        """Return a loopback TCP port number assigned by the OS.

        The probing socket is closed before returning, so the port is not
        reserved for the caller.
        """
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
            sock.bind(("127.0.0.1", 0))
            sock.listen(1)
            return int(sock.getsockname()[1])
|
|
|
|
|
|
async def get_aria2_downloader() -> Aria2Downloader:
    """Return the process-wide ``Aria2Downloader`` instance.

    Delegates to ``Aria2Downloader.get_instance()``, which creates the
    instance on first use.
    """
    downloader = await Aria2Downloader.get_instance()
    return downloader
|