Files
ComfyUI-Lora-Manager/py/services/aria2_downloader.py
2026-04-20 09:52:48 +08:00

571 lines
19 KiB
Python

from __future__ import annotations
import asyncio
import json
import logging
import os
import secrets
import shutil
import socket
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, Optional, Tuple
import aiohttp
from .downloader import DownloadProgress, get_downloader
from .aria2_transfer_state import Aria2TransferStateStore
from .settings_manager import get_settings_manager
logger = logging.getLogger(__name__)
CIVITAI_DOWNLOAD_URL_PREFIXES = (
"https://civitai.com/api/download/",
"https://civitai.red/api/download/",
)
class Aria2Error(RuntimeError):
"""Raised when aria2 integration fails."""
@dataclass
class Aria2Transfer:
"""Track an aria2 download registered by the Python coordinator."""
gid: str
save_path: str
class Aria2Downloader:
"""Manage an aria2 RPC daemon for experimental model downloads."""
_instance = None
_lock = asyncio.Lock()
@classmethod
async def get_instance(cls) -> "Aria2Downloader":
async with cls._lock:
if cls._instance is None:
cls._instance = cls()
return cls._instance
def __init__(self) -> None:
if hasattr(self, "_initialized"):
return
self._initialized = True
self._process: Optional[asyncio.subprocess.Process] = None
self._rpc_port: Optional[int] = None
self._rpc_secret = ""
self._rpc_url = ""
self._rpc_session: Optional[aiohttp.ClientSession] = None
self._rpc_session_lock = asyncio.Lock()
self._process_lock = asyncio.Lock()
self._transfers: Dict[str, Aria2Transfer] = {}
self._poll_interval = 0.5
self._state_store = Aria2TransferStateStore()
@property
def is_running(self) -> bool:
return self._process is not None and self._process.returncode is None
async def download_file(
self,
url: str,
save_path: str,
*,
download_id: str,
progress_callback=None,
headers: Optional[Dict[str, str]] = None,
) -> Tuple[bool, str]:
"""Download a file using aria2 RPC and wait for completion."""
await self._ensure_process()
save_path = os.path.abspath(save_path)
transfer = self._transfers.get(download_id)
if transfer is None or os.path.abspath(transfer.save_path) != save_path:
gid = await self._schedule_download(
url,
save_path,
download_id=download_id,
headers=headers,
)
transfer = Aria2Transfer(gid=gid, save_path=save_path)
self._transfers[download_id] = transfer
try:
while True:
status = await self.get_status(download_id)
if status is None:
return False, "aria2 download not found"
snapshot = self._build_progress_snapshot(status)
if progress_callback is not None:
await self._dispatch_progress(progress_callback, snapshot)
state = status.get("status", "")
if state == "complete":
completed_path = self._resolve_completed_path(status, save_path)
return True, completed_path
if state == "error":
return False, status.get("errorMessage") or "aria2 download failed"
if state == "removed":
return False, "Download was cancelled"
await asyncio.sleep(self._poll_interval)
finally:
self._transfers.pop(download_id, None)
async def _schedule_download(
self,
url: str,
save_path: str,
*,
download_id: str,
headers: Optional[Dict[str, str]] = None,
) -> str:
save_dir = os.path.dirname(save_path)
out_name = os.path.basename(save_path)
Path(save_dir).mkdir(parents=True, exist_ok=True)
resolved_url = url
request_headers = headers
if headers and url.startswith(CIVITAI_DOWNLOAD_URL_PREFIXES):
resolved_url = await self._resolve_authenticated_redirect_url(url, headers)
if resolved_url != url:
request_headers = None
logger.debug(
"Resolved Civitai download %s to signed URL for aria2",
download_id,
)
options: Dict[str, str] = {
"dir": save_dir,
"out": out_name,
"continue": "true",
"max-connection-per-server": "4",
"split": "4",
"min-split-size": "1M",
"allow-overwrite": "true",
"auto-file-renaming": "false",
"file-allocation": "none",
}
if request_headers:
options["header"] = [
f"{key}: {value}" for key, value in request_headers.items()
]
logger.debug(
"Submitting aria2 download %s -> %s (auth=%s, civitai_signed=%s)",
download_id,
save_path,
bool(request_headers),
resolved_url != url,
)
try:
gid = await self._rpc_call("aria2.addUri", [[resolved_url], options])
except Exception as exc:
raise Aria2Error(f"Failed to schedule aria2 download: {exc}") from exc
logger.debug("aria2 accepted download %s with gid %s", download_id, gid)
await self._state_store.upsert(
download_id,
{
"gid": gid,
"save_path": save_path,
"status": "downloading",
"url": url,
},
)
return gid
async def get_status(self, download_id: str) -> Optional[Dict[str, Any]]:
"""Return the raw aria2 status payload for a known download."""
transfer = self._transfers.get(download_id)
if transfer is None:
return None
keys = [
"gid",
"status",
"totalLength",
"completedLength",
"downloadSpeed",
"errorMessage",
"files",
]
try:
status = await self._rpc_call("aria2.tellStatus", [transfer.gid, keys])
except Exception as exc:
raise Aria2Error(f"Failed to query aria2 download status: {exc}") from exc
if isinstance(status, dict):
return status
return None
async def get_status_by_gid(self, gid: str) -> Optional[Dict[str, Any]]:
keys = [
"gid",
"status",
"totalLength",
"completedLength",
"downloadSpeed",
"errorMessage",
"files",
]
try:
status = await self._rpc_call("aria2.tellStatus", [gid, keys])
except Exception as exc:
message = str(exc)
if "cannot be found" in message.lower() or "not found" in message.lower():
return None
raise Aria2Error(f"Failed to query aria2 download status: {exc}") from exc
if isinstance(status, dict):
return status
return None
async def restore_transfer(self, download_id: str, gid: str, save_path: str) -> None:
await self._ensure_process()
self._transfers[download_id] = Aria2Transfer(
gid=gid,
save_path=os.path.abspath(save_path),
)
async def reassign_transfer(
self, from_download_id: str, to_download_id: str
) -> Optional[Aria2Transfer]:
transfer = self._transfers.get(from_download_id)
if transfer is None:
return None
self._transfers[to_download_id] = transfer
if from_download_id != to_download_id:
self._transfers.pop(from_download_id, None)
return transfer
async def has_transfer(self, download_id: str) -> bool:
return download_id in self._transfers
async def pause_download(self, download_id: str) -> Dict[str, Any]:
transfer = self._transfers.get(download_id)
if transfer is None:
return {"success": False, "error": "Download task not found"}
try:
await self._rpc_call("aria2.forcePause", [transfer.gid])
except Exception as exc:
return {"success": False, "error": str(exc)}
await self._state_store.upsert(download_id, {"status": "paused"})
return {"success": True, "message": "Download paused successfully"}
async def resume_download(self, download_id: str) -> Dict[str, Any]:
transfer = self._transfers.get(download_id)
if transfer is None:
return {"success": False, "error": "Download task not found"}
try:
await self._rpc_call("aria2.unpause", [transfer.gid])
except Exception as exc:
return {"success": False, "error": str(exc)}
await self._state_store.upsert(download_id, {"status": "downloading"})
return {"success": True, "message": "Download resumed successfully"}
async def cancel_download(self, download_id: str) -> Dict[str, Any]:
transfer = self._transfers.get(download_id)
if transfer is None:
return {"success": False, "error": "Download task not found"}
try:
await self._rpc_call("aria2.forceRemove", [transfer.gid])
except Exception as exc:
return {"success": False, "error": str(exc)}
await self._state_store.remove(download_id)
return {"success": True, "message": "Download cancelled successfully"}
async def close(self) -> None:
"""Shut down the RPC process and session."""
if self._rpc_session is not None:
await self._rpc_session.close()
self._rpc_session = None
process = self._process
self._process = None
self._transfers.clear()
if process is None:
return
if process.returncode is None:
process.terminate()
try:
await asyncio.wait_for(process.wait(), timeout=5.0)
except asyncio.TimeoutError:
process.kill()
await process.wait()
async def _dispatch_progress(self, callback, snapshot: DownloadProgress) -> None:
try:
result = callback(snapshot, snapshot)
except TypeError:
result = callback(snapshot.percent_complete)
if asyncio.iscoroutine(result):
await result
elif hasattr(result, "__await__"):
await result
def _build_progress_snapshot(self, status: Dict[str, Any]) -> DownloadProgress:
completed = self._parse_int(status.get("completedLength"))
total = self._parse_int(status.get("totalLength"))
speed = float(self._parse_int(status.get("downloadSpeed")))
percent = 0.0
if total > 0:
percent = (completed / total) * 100.0
return DownloadProgress(
percent_complete=max(0.0, min(percent, 100.0)),
bytes_downloaded=completed,
total_bytes=total or None,
bytes_per_second=speed,
timestamp=datetime.now().timestamp(),
)
def _resolve_completed_path(self, status: Dict[str, Any], default_path: str) -> str:
files = status.get("files")
if isinstance(files, list) and files:
first = files[0]
if isinstance(first, dict):
candidate = first.get("path")
if isinstance(candidate, str) and candidate:
return candidate
return default_path
@staticmethod
def _parse_int(value: Any) -> int:
try:
return int(value)
except (TypeError, ValueError):
return 0
async def _resolve_authenticated_redirect_url(
self,
url: str,
headers: Dict[str, str],
) -> str:
downloader = await get_downloader()
session = await downloader.session
request_headers = dict(downloader.default_headers)
request_headers.update(headers)
request_headers["Accept-Encoding"] = "identity"
try:
async with session.get(
url,
headers=request_headers,
allow_redirects=False,
proxy=downloader.proxy_url,
) as response:
if response.status in {301, 302, 303, 307, 308}:
location = response.headers.get("Location")
if location:
return location
raise Aria2Error(
"Authenticated Civitai redirect did not include a Location header"
)
if response.status == 200:
return url
body = await response.text()
raise Aria2Error(
f"Failed to resolve authenticated Civitai redirect: status={response.status} body={body[:300]}"
)
except aiohttp.ClientError as exc:
raise Aria2Error(
f"Failed to resolve authenticated Civitai redirect: {exc}"
) from exc
async def _ensure_process(self) -> None:
async with self._process_lock:
if self.is_running and await self._ping():
return
await self.close()
executable = self._resolve_executable()
self._rpc_port = self._find_free_port()
self._rpc_secret = secrets.token_hex(16)
self._rpc_url = f"http://127.0.0.1:{self._rpc_port}/jsonrpc"
command = [
executable,
"--enable-rpc=true",
"--rpc-listen-all=false",
f"--rpc-listen-port={self._rpc_port}",
f"--rpc-secret={self._rpc_secret}",
"--check-certificate=true",
"--allow-overwrite=true",
"--auto-file-renaming=false",
"--file-allocation=none",
"--max-concurrent-downloads=5",
"--continue=true",
"--daemon=false",
"--quiet=true",
f"--stop-with-process={os.getpid()}",
]
logger.info("Starting aria2 RPC daemon from %s", executable)
self._process = await asyncio.create_subprocess_exec(
*command,
stdout=asyncio.subprocess.DEVNULL,
stderr=asyncio.subprocess.PIPE,
)
await self._wait_until_ready()
def _resolve_executable(self) -> str:
settings = get_settings_manager()
configured_path = (settings.get("aria2c_path") or "").strip()
candidate = configured_path or "aria2c"
resolved = shutil.which(candidate)
if resolved:
return resolved
if configured_path and os.path.isfile(configured_path) and os.access(
configured_path, os.X_OK
):
return configured_path
raise Aria2Error(
"aria2c executable was not found. Install aria2 or configure aria2c_path."
)
async def _wait_until_ready(self) -> None:
assert self._process is not None
start_time = asyncio.get_running_loop().time()
last_error = ""
while asyncio.get_running_loop().time() - start_time < 10.0:
if self._process.returncode is not None:
stderr_output = ""
if self._process.stderr is not None:
try:
stderr_output = (
await asyncio.wait_for(self._process.stderr.read(), timeout=0.2)
).decode("utf-8", errors="replace")
except Exception:
stderr_output = ""
raise Aria2Error(
f"aria2 RPC process exited early with code {self._process.returncode}: {stderr_output.strip()}"
)
try:
if await self._ping():
return
except Exception as exc: # pragma: no cover - startup race
last_error = str(exc)
await asyncio.sleep(0.2)
raise Aria2Error(
f"Timed out waiting for aria2 RPC to become ready{': ' + last_error if last_error else ''}"
)
async def _ping(self) -> bool:
try:
result = await self._rpc_call("aria2.getVersion", [])
except Exception:
return False
return isinstance(result, dict)
async def _rpc_call(self, method: str, params: list[Any]) -> Any:
if not self._rpc_url:
raise Aria2Error("aria2 RPC endpoint is not initialized")
session = await self._get_rpc_session()
payload = {
"jsonrpc": "2.0",
"id": secrets.token_hex(8),
"method": method,
"params": [f"token:{self._rpc_secret}", *params],
}
async with session.post(self._rpc_url, json=payload) as response:
text = await response.text()
try:
body = json.loads(text)
except json.JSONDecodeError:
body = None
if body is None:
if response.status != 200:
raise Aria2Error(
f"aria2 RPC returned status {response.status} with non-JSON body: {text}"
)
raise Aria2Error(f"Invalid aria2 RPC response: {text}")
if "error" in body:
error = body["error"] or {}
code = error.get("code") if isinstance(error, dict) else None
message = error.get("message") if isinstance(error, dict) else str(error)
logger.error(
"aria2 RPC %s failed with HTTP %s, code=%s, message=%s",
method,
response.status,
code,
message,
)
status_message = (
f"aria2 RPC {method} failed with status {response.status}: {message}"
if response.status != 200
else message
)
raise Aria2Error(status_message or "Unknown aria2 RPC error")
if response.status != 200:
logger.error(
"aria2 RPC %s returned unexpected HTTP status %s without error payload: %s",
method,
response.status,
body,
)
raise Aria2Error(
f"aria2 RPC {method} returned unexpected status {response.status}"
)
return body.get("result")
async def _get_rpc_session(self) -> aiohttp.ClientSession:
if self._rpc_session is None or self._rpc_session.closed:
async with self._rpc_session_lock:
if self._rpc_session is None or self._rpc_session.closed:
timeout = aiohttp.ClientTimeout(total=30)
self._rpc_session = aiohttp.ClientSession(timeout=timeout)
return self._rpc_session
@staticmethod
def _find_free_port() -> int:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
sock.bind(("127.0.0.1", 0))
sock.listen(1)
return int(sock.getsockname()[1])
async def get_aria2_downloader() -> Aria2Downloader:
"""Get the singleton aria2 downloader."""
return await Aria2Downloader.get_instance()