# Mirror of https://github.com/willmiao/ComfyUI-Lora-Manager.git
# Synced 2026-05-07 08:56:43 -03:00
# 571 lines, 19 KiB, Python
from __future__ import annotations

import asyncio
import inspect
import json
import logging
import os
import secrets
import shutil
import socket
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, Optional, Tuple

import aiohttp

from .aria2_transfer_state import Aria2TransferStateStore
from .downloader import DownloadProgress, get_downloader
from .settings_manager import get_settings_manager
|
|
|
|
# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# URL prefixes treated as Civitai download endpoints.  Kept as a tuple so it
# can be handed straight to ``str.startswith``.
CIVITAI_DOWNLOAD_URL_PREFIXES = (
    "https://civitai.com/api/download/",
    "https://civitai.red/api/download/",
)
|
|
|
|
|
|
class Aria2Error(RuntimeError):
    """Signal a failure in the aria2 integration."""
|
|
|
|
|
|
@dataclass
class Aria2Transfer:
    """Record of one aria2 download registered by the Python coordinator."""

    # Identifier aria2 returned when the download was submitted.
    gid: str
    # Path the download was scheduled to be written to.
    save_path: str
|
|
|
|
|
|
class Aria2Downloader:
    """Manage an aria2 RPC daemon for experimental model downloads.

    Instances are normally obtained through :meth:`get_instance`, which hands
    out one shared object.  The downloader starts ``aria2c`` as a child
    process with JSON-RPC enabled on a loopback port, submits downloads via
    RPC and polls their status until they finish.
    """

    _instance: Optional["Aria2Downloader"] = None
    _lock = asyncio.Lock()

    # Fields requested from ``aria2.tellStatus``; shared by every status query
    # so the two lookup methods cannot drift apart.
    _STATUS_KEYS = (
        "gid",
        "status",
        "totalLength",
        "completedLength",
        "downloadSpeed",
        "errorMessage",
        "files",
    )

    @classmethod
    async def get_instance(cls) -> "Aria2Downloader":
        """Return the shared downloader, creating it on first use."""
        async with cls._lock:
            if cls._instance is None:
                cls._instance = cls()
            return cls._instance

    def __init__(self) -> None:
        """Initialise process, RPC and transfer bookkeeping state."""
        if hasattr(self, "_initialized"):
            return

        self._initialized = True
        self._process: Optional[asyncio.subprocess.Process] = None
        self._rpc_port: Optional[int] = None
        self._rpc_secret = ""
        self._rpc_url = ""
        self._rpc_session: Optional[aiohttp.ClientSession] = None
        self._rpc_session_lock = asyncio.Lock()
        self._process_lock = asyncio.Lock()
        # download_id -> transfer currently tracked by this process
        self._transfers: Dict[str, Aria2Transfer] = {}
        # seconds between status polls in download_file()
        self._poll_interval = 0.5
        self._state_store = Aria2TransferStateStore()

    @property
    def is_running(self) -> bool:
        """Whether a child process exists and has not exited yet."""
        return self._process is not None and self._process.returncode is None

    async def download_file(
        self,
        url: str,
        save_path: str,
        *,
        download_id: str,
        progress_callback=None,
        headers: Optional[Dict[str, str]] = None,
    ) -> Tuple[bool, str]:
        """Download a file using aria2 RPC and wait for completion.

        A transfer already registered under ``download_id`` with the same
        absolute ``save_path`` is reused; otherwise a new download is
        scheduled.

        Args:
            url: Source URL.
            save_path: Destination path; converted to an absolute path.
            download_id: Key used to track the transfer.
            progress_callback: Optional callable invoked with a progress
                snapshot on every poll (see :meth:`_dispatch_progress`).
            headers: Optional HTTP headers for the request.

        Returns:
            ``(True, path)`` when aria2 reports ``complete``, otherwise
            ``(False, message)`` for ``error``, ``removed`` or an unknown
            transfer.

        Raises:
            Aria2Error: If the daemon cannot be started, the download cannot
                be scheduled, or a status query fails.
        """
        await self._ensure_process()
        save_path = os.path.abspath(save_path)

        transfer = self._transfers.get(download_id)
        if transfer is None or os.path.abspath(transfer.save_path) != save_path:
            gid = await self._schedule_download(
                url,
                save_path,
                download_id=download_id,
                headers=headers,
            )
            transfer = Aria2Transfer(gid=gid, save_path=save_path)
            self._transfers[download_id] = transfer

        try:
            while True:
                status = await self.get_status(download_id)
                if status is None:
                    return False, "aria2 download not found"

                snapshot = self._build_progress_snapshot(status)
                if progress_callback is not None:
                    await self._dispatch_progress(progress_callback, snapshot)

                state = status.get("status", "")
                if state == "complete":
                    return True, self._resolve_completed_path(status, save_path)
                if state == "error":
                    return False, status.get("errorMessage") or "aria2 download failed"
                if state == "removed":
                    return False, "Download was cancelled"

                await asyncio.sleep(self._poll_interval)
        finally:
            # The transfer is no longer tracked once this call returns or raises.
            self._transfers.pop(download_id, None)

    async def _schedule_download(
        self,
        url: str,
        save_path: str,
        *,
        download_id: str,
        headers: Optional[Dict[str, str]] = None,
    ) -> str:
        """Submit a download to aria2 and persist its state.

        Creates the destination directory, resolves Civitai download URLs to
        their redirect target when headers are supplied, calls
        ``aria2.addUri`` and records the transfer in the state store.

        Returns:
            The gid returned by aria2.

        Raises:
            Aria2Error: If redirect resolution or the ``aria2.addUri`` call
                fails.
        """
        save_dir = os.path.dirname(save_path)
        out_name = os.path.basename(save_path)

        Path(save_dir).mkdir(parents=True, exist_ok=True)

        resolved_url = url
        request_headers = headers
        if headers and url.startswith(CIVITAI_DOWNLOAD_URL_PREFIXES):
            resolved_url = await self._resolve_authenticated_redirect_url(url, headers)
            if resolved_url != url:
                # The redirect target is fetched without the caller's headers
                # (presumably the signed URL carries its own authorization).
                request_headers = None
                logger.debug(
                    "Resolved Civitai download %s to signed URL for aria2",
                    download_id,
                )

        # "header" holds a list of strings, hence Any for the value type.
        options: Dict[str, Any] = {
            "dir": save_dir,
            "out": out_name,
            "continue": "true",
            "max-connection-per-server": "4",
            "split": "4",
            "min-split-size": "1M",
            "allow-overwrite": "true",
            "auto-file-renaming": "false",
            "file-allocation": "none",
        }
        if request_headers:
            options["header"] = [
                f"{key}: {value}" for key, value in request_headers.items()
            ]

        logger.debug(
            "Submitting aria2 download %s -> %s (auth=%s, civitai_signed=%s)",
            download_id,
            save_path,
            bool(request_headers),
            resolved_url != url,
        )

        try:
            gid = await self._rpc_call("aria2.addUri", [[resolved_url], options])
        except Exception as exc:
            raise Aria2Error(f"Failed to schedule aria2 download: {exc}") from exc

        logger.debug("aria2 accepted download %s with gid %s", download_id, gid)
        await self._state_store.upsert(
            download_id,
            {
                "gid": gid,
                "save_path": save_path,
                "status": "downloading",
                # The original (unresolved) URL is what gets persisted.
                "url": url,
            },
        )
        return gid

    async def get_status(self, download_id: str) -> Optional[Dict[str, Any]]:
        """Return the raw aria2 status payload for a known download.

        Returns ``None`` when ``download_id`` is not tracked or aria2 answers
        with something other than a dict.

        Raises:
            Aria2Error: If the ``aria2.tellStatus`` call fails for any reason.
        """
        transfer = self._transfers.get(download_id)
        if transfer is None:
            return None

        try:
            status = await self._rpc_call(
                "aria2.tellStatus", [transfer.gid, list(self._STATUS_KEYS)]
            )
        except Exception as exc:
            raise Aria2Error(f"Failed to query aria2 download status: {exc}") from exc

        return status if isinstance(status, dict) else None

    async def get_status_by_gid(self, gid: str) -> Optional[Dict[str, Any]]:
        """Return the raw aria2 status payload for ``gid``.

        Unlike :meth:`get_status`, a failure whose message contains
        "cannot be found" or "not found" (case-insensitive) yields ``None``
        instead of raising.

        Raises:
            Aria2Error: For any other ``aria2.tellStatus`` failure.
        """
        try:
            status = await self._rpc_call(
                "aria2.tellStatus", [gid, list(self._STATUS_KEYS)]
            )
        except Exception as exc:
            message = str(exc).lower()
            if "cannot be found" in message or "not found" in message:
                return None
            raise Aria2Error(f"Failed to query aria2 download status: {exc}") from exc

        return status if isinstance(status, dict) else None

    async def restore_transfer(self, download_id: str, gid: str, save_path: str) -> None:
        """Register an existing aria2 gid under ``download_id``.

        The daemon is ensured first because (re)starting it clears the
        in-memory transfer table.
        """
        await self._ensure_process()
        self._transfers[download_id] = Aria2Transfer(
            gid=gid,
            save_path=os.path.abspath(save_path),
        )

    async def reassign_transfer(
        self, from_download_id: str, to_download_id: str
    ) -> Optional[Aria2Transfer]:
        """Move a tracked transfer to a different download id.

        Returns the transfer, or ``None`` if ``from_download_id`` is unknown.
        """
        transfer = self._transfers.get(from_download_id)
        if transfer is None:
            return None

        self._transfers[to_download_id] = transfer
        if from_download_id != to_download_id:
            self._transfers.pop(from_download_id, None)
        return transfer

    async def has_transfer(self, download_id: str) -> bool:
        """Return whether ``download_id`` is currently tracked."""
        return download_id in self._transfers

    async def pause_download(self, download_id: str) -> Dict[str, Any]:
        """Pause a tracked download via ``aria2.forcePause``.

        Returns a dict with ``success`` plus either ``message`` or ``error``.
        """
        transfer = self._transfers.get(download_id)
        if transfer is None:
            return {"success": False, "error": "Download task not found"}

        try:
            await self._rpc_call("aria2.forcePause", [transfer.gid])
        except Exception as exc:
            return {"success": False, "error": str(exc)}

        await self._state_store.upsert(download_id, {"status": "paused"})
        return {"success": True, "message": "Download paused successfully"}

    async def resume_download(self, download_id: str) -> Dict[str, Any]:
        """Resume a tracked download via ``aria2.unpause``.

        Returns a dict with ``success`` plus either ``message`` or ``error``.
        """
        transfer = self._transfers.get(download_id)
        if transfer is None:
            return {"success": False, "error": "Download task not found"}

        try:
            await self._rpc_call("aria2.unpause", [transfer.gid])
        except Exception as exc:
            return {"success": False, "error": str(exc)}

        await self._state_store.upsert(download_id, {"status": "downloading"})
        return {"success": True, "message": "Download resumed successfully"}

    async def cancel_download(self, download_id: str) -> Dict[str, Any]:
        """Cancel a tracked download via ``aria2.forceRemove``.

        On success the persisted state entry is removed.  Returns a dict with
        ``success`` plus either ``message`` or ``error``.
        """
        transfer = self._transfers.get(download_id)
        if transfer is None:
            return {"success": False, "error": "Download task not found"}

        try:
            await self._rpc_call("aria2.forceRemove", [transfer.gid])
        except Exception as exc:
            return {"success": False, "error": str(exc)}

        await self._state_store.remove(download_id)
        return {"success": True, "message": "Download cancelled successfully"}

    async def close(self) -> None:
        """Shut down the RPC process and session.

        Closes the RPC HTTP session, forgets all tracked transfers and
        terminates the child process, escalating to ``kill`` if it has not
        exited within five seconds.
        """
        if self._rpc_session is not None:
            await self._rpc_session.close()
            self._rpc_session = None

        process = self._process
        self._process = None
        self._transfers.clear()

        if process is None or process.returncode is not None:
            return

        try:
            process.terminate()
        except ProcessLookupError:
            # The child exited between the returncode check and the signal.
            return

        try:
            await asyncio.wait_for(process.wait(), timeout=5.0)
        except asyncio.TimeoutError:
            try:
                process.kill()
            except ProcessLookupError:
                return
            await process.wait()

    async def _dispatch_progress(self, callback, snapshot: DownloadProgress) -> None:
        """Invoke ``callback`` with the snapshot, awaiting the result if needed.

        The callback is first called as ``callback(snapshot, snapshot)``; if
        that raises ``TypeError`` it is called again with only
        ``snapshot.percent_complete``.  Note that a ``TypeError`` raised inside
        the callback body also triggers this fallback.
        """
        try:
            result = callback(snapshot, snapshot)
        except TypeError:
            result = callback(snapshot.percent_complete)

        if inspect.isawaitable(result):
            await result

    def _build_progress_snapshot(self, status: Dict[str, Any]) -> DownloadProgress:
        """Convert an aria2 status payload into a ``DownloadProgress``."""
        completed = self._parse_int(status.get("completedLength"))
        total = self._parse_int(status.get("totalLength"))
        speed = float(self._parse_int(status.get("downloadSpeed")))

        # Total may be 0/unknown; report 0% rather than dividing by zero.
        percent = (completed / total) * 100.0 if total > 0 else 0.0

        return DownloadProgress(
            percent_complete=max(0.0, min(percent, 100.0)),
            bytes_downloaded=completed,
            total_bytes=total or None,
            bytes_per_second=speed,
            timestamp=datetime.now().timestamp(),
        )

    def _resolve_completed_path(self, status: Dict[str, Any], default_path: str) -> str:
        """Return the first file path in ``status``, else ``default_path``."""
        files = status.get("files")
        if isinstance(files, list) and files:
            first = files[0]
            if isinstance(first, dict):
                candidate = first.get("path")
                if isinstance(candidate, str) and candidate:
                    return candidate
        return default_path

    @staticmethod
    def _parse_int(value: Any) -> int:
        """Convert ``value`` with ``int()``; return 0 if that is not possible."""
        try:
            return int(value)
        except (TypeError, ValueError):
            return 0

    async def _resolve_authenticated_redirect_url(
        self,
        url: str,
        headers: Dict[str, str],
    ) -> str:
        """Request ``url`` with auth headers and return its redirect target.

        The request is sent through the shared downloader's session without
        following redirects.

        Returns:
            The ``Location`` header for a 301/302/303/307/308 response, or
            ``url`` itself for a 200 response.

        Raises:
            Aria2Error: On a redirect without ``Location``, any other status
                code, or an ``aiohttp.ClientError``.
        """
        downloader = await get_downloader()
        session = await downloader.session
        request_headers = dict(downloader.default_headers)
        request_headers.update(headers)
        request_headers["Accept-Encoding"] = "identity"

        try:
            async with session.get(
                url,
                headers=request_headers,
                allow_redirects=False,
                proxy=downloader.proxy_url,
            ) as response:
                if response.status in {301, 302, 303, 307, 308}:
                    location = response.headers.get("Location")
                    if location:
                        return location
                    raise Aria2Error(
                        "Authenticated Civitai redirect did not include a Location header"
                    )

                if response.status == 200:
                    return url

                body = await response.text()
                raise Aria2Error(
                    f"Failed to resolve authenticated Civitai redirect: status={response.status} body={body[:300]}"
                )
        except aiohttp.ClientError as exc:
            raise Aria2Error(
                f"Failed to resolve authenticated Civitai redirect: {exc}"
            ) from exc

    async def _ensure_process(self) -> None:
        """Make sure a responsive aria2 daemon is running.

        If the current process is alive and answers a ping nothing happens.
        Otherwise everything is closed and a fresh daemon is started on a new
        loopback port with a new RPC secret.

        Raises:
            Aria2Error: If the executable cannot be found or the daemon does
                not become ready.
        """
        async with self._process_lock:
            if self.is_running and await self._ping():
                return

            await self.close()

            executable = self._resolve_executable()
            self._rpc_port = self._find_free_port()
            self._rpc_secret = secrets.token_hex(16)
            self._rpc_url = f"http://127.0.0.1:{self._rpc_port}/jsonrpc"

            command = [
                executable,
                "--enable-rpc=true",
                "--rpc-listen-all=false",
                f"--rpc-listen-port={self._rpc_port}",
                f"--rpc-secret={self._rpc_secret}",
                "--check-certificate=true",
                "--allow-overwrite=true",
                "--auto-file-renaming=false",
                "--file-allocation=none",
                "--max-concurrent-downloads=5",
                "--continue=true",
                "--daemon=false",
                "--quiet=true",
                # Ties the daemon's lifetime to this Python process.
                f"--stop-with-process={os.getpid()}",
            ]

            logger.info("Starting aria2 RPC daemon from %s", executable)
            self._process = await asyncio.create_subprocess_exec(
                *command,
                stdout=asyncio.subprocess.DEVNULL,
                stderr=asyncio.subprocess.PIPE,
            )

            try:
                await self._wait_until_ready()
            except BaseException:
                # Do not leave a half-started daemon or its RPC session behind
                # when startup fails or is cancelled.
                await self.close()
                raise

    def _resolve_executable(self) -> str:
        """Locate the ``aria2c`` executable.

        Uses the ``aria2c_path`` setting when present, otherwise ``aria2c``
        looked up via ``shutil.which``.

        Raises:
            Aria2Error: If no executable can be found.
        """
        settings = get_settings_manager()
        configured_path = (settings.get("aria2c_path") or "").strip()
        candidate = configured_path or "aria2c"

        resolved = shutil.which(candidate)
        if resolved:
            return resolved

        if (
            configured_path
            and os.path.isfile(configured_path)
            and os.access(configured_path, os.X_OK)
        ):
            return configured_path

        raise Aria2Error(
            "aria2c executable was not found. Install aria2 or configure aria2c_path."
        )

    async def _wait_until_ready(self) -> None:
        """Poll the daemon until it answers RPC or ten seconds have passed.

        Raises:
            Aria2Error: If the process exits early (its stderr is included in
                the message) or the wait times out.
        """
        assert self._process is not None

        loop = asyncio.get_running_loop()
        start_time = loop.time()
        last_error = ""
        while loop.time() - start_time < 10.0:
            if self._process.returncode is not None:
                stderr_output = ""
                if self._process.stderr is not None:
                    try:
                        stderr_output = (
                            await asyncio.wait_for(self._process.stderr.read(), timeout=0.2)
                        ).decode("utf-8", errors="replace")
                    except Exception:
                        # Best effort only: the exit code is still reported.
                        stderr_output = ""
                raise Aria2Error(
                    f"aria2 RPC process exited early with code {self._process.returncode}: {stderr_output.strip()}"
                )

            try:
                if await self._ping():
                    return
            except Exception as exc:  # pragma: no cover - startup race
                last_error = str(exc)

            await asyncio.sleep(0.2)

        raise Aria2Error(
            f"Timed out waiting for aria2 RPC to become ready{': ' + last_error if last_error else ''}"
        )

    async def _ping(self) -> bool:
        """Return ``True`` if ``aria2.getVersion`` succeeds with a dict result."""
        try:
            result = await self._rpc_call("aria2.getVersion", [])
        except Exception:
            return False

        return isinstance(result, dict)

    async def _rpc_call(self, method: str, params: list[Any]) -> Any:
        """Send one JSON-RPC request to the daemon and return its ``result``.

        The RPC secret is prepended to ``params`` as a ``token:`` entry.

        Raises:
            Aria2Error: If the endpoint is not initialised, the response is
                not a JSON object, the response carries an ``error`` member,
                or the HTTP status is not 200.
        """
        if not self._rpc_url:
            raise Aria2Error("aria2 RPC endpoint is not initialized")

        session = await self._get_rpc_session()
        payload = {
            "jsonrpc": "2.0",
            "id": secrets.token_hex(8),
            "method": method,
            "params": [f"token:{self._rpc_secret}", *params],
        }

        async with session.post(self._rpc_url, json=payload) as response:
            text = await response.text()

        try:
            body = json.loads(text)
        except json.JSONDecodeError:
            body = None

        # Anything that is not a JSON object (unparseable text, null, a list,
        # a bare string ...) is rejected here so the lookups below are safe.
        if not isinstance(body, dict):
            if response.status != 200:
                raise Aria2Error(
                    f"aria2 RPC returned status {response.status} with non-JSON body: {text}"
                )
            raise Aria2Error(f"Invalid aria2 RPC response: {text}")

        if "error" in body:
            error = body["error"] or {}
            code = error.get("code") if isinstance(error, dict) else None
            message = error.get("message") if isinstance(error, dict) else str(error)
            logger.error(
                "aria2 RPC %s failed with HTTP %s, code=%s, message=%s",
                method,
                response.status,
                code,
                message,
            )
            status_message = (
                f"aria2 RPC {method} failed with status {response.status}: {message}"
                if response.status != 200
                else message
            )
            raise Aria2Error(status_message or "Unknown aria2 RPC error")

        if response.status != 200:
            logger.error(
                "aria2 RPC %s returned unexpected HTTP status %s without error payload: %s",
                method,
                response.status,
                body,
            )
            raise Aria2Error(
                f"aria2 RPC {method} returned unexpected status {response.status}"
            )

        return body.get("result")

    async def _get_rpc_session(self) -> aiohttp.ClientSession:
        """Return the RPC HTTP session, creating one if missing or closed."""
        if self._rpc_session is None or self._rpc_session.closed:
            async with self._rpc_session_lock:
                # Re-check after acquiring the lock: another task may have
                # created the session while this one was waiting.
                if self._rpc_session is None or self._rpc_session.closed:
                    timeout = aiohttp.ClientTimeout(total=30)
                    self._rpc_session = aiohttp.ClientSession(timeout=timeout)
        return self._rpc_session

    @staticmethod
    def _find_free_port() -> int:
        """Return a loopback TCP port number chosen by the operating system.

        The probing socket is closed before returning, so the port is only
        known to have been free at the time of the check.
        """
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
            sock.bind(("127.0.0.1", 0))
            sock.listen(1)
            return int(sock.getsockname()[1])
|
|
|
|
|
|
async def get_aria2_downloader() -> Aria2Downloader:
    """Return the shared :class:`Aria2Downloader` singleton."""
    instance = await Aria2Downloader.get_instance()
    return instance
|