diff --git a/locales/de.json b/locales/de.json index 8584a12d..bd795d45 100644 --- a/locales/de.json +++ b/locales/de.json @@ -263,6 +263,19 @@ "red": "civitai.red (uneingeschränkt)" } }, + "downloadBackend": { + "label": "Download-Backend", + "help": "Wähle aus, wie Modelldateien heruntergeladen werden. Python verwendet den eingebauten Downloader. aria2 verwendet den experimentellen externen Downloader-Prozess.", + "options": { + "python": "Python (integriert)", + "aria2": "aria2 (experimentell)" + } + }, + "aria2cPath": { + "label": "aria2c-Pfad", + "help": "Optionaler Pfad zur ausführbaren aria2c-Datei. Leer lassen, um aria2c aus dem System-PATH zu verwenden.", + "placeholder": "Leer lassen, um aria2c aus dem PATH zu verwenden" + }, "civitaiHostBanner": { "title": "Civitai-Host-Einstellung verfügbar", "content": "Civitai verwendet jetzt civitai.com für SFW-Inhalte und civitai.red für uneingeschränkte Inhalte. In den Einstellungen können Sie ändern, welche Seite standardmäßig geöffnet wird.", @@ -278,6 +291,7 @@ }, "sections": { "contentFiltering": "Inhaltsfilterung", + "downloads": "Downloads", "videoSettings": "Video-Einstellungen", "layoutSettings": "Layout-Einstellungen", "misc": "Verschiedenes", diff --git a/locales/en.json b/locales/en.json index e4272a48..a7e41e0f 100644 --- a/locales/en.json +++ b/locales/en.json @@ -263,6 +263,19 @@ "red": "civitai.red (unrestricted)" } }, + "downloadBackend": { + "label": "Download backend", + "help": "Choose how model files are downloaded. Python uses the built-in downloader. aria2 uses the experimental external downloader process.", + "options": { + "python": "Python (built-in)", + "aria2": "aria2 (experimental)" + } + }, + "aria2cPath": { + "label": "aria2c path", + "help": "Optional path to the aria2c executable. Leave empty to use aria2c from your system PATH.", + "placeholder": "Leave empty to use aria2c from PATH" + }, "civitaiHostBanner": { "title": "Civitai host preference available", "content": "Civitai now uses civitai.com for SFW content and civitai.red for unrestricted content. You can change which site opens by default in Settings.", @@ -278,6 +291,7 @@ }, "sections": { "contentFiltering": "Content Filtering", + "downloads": "Downloads", "videoSettings": "Video Settings", "layoutSettings": "Layout Settings", "misc": "Miscellaneous", diff --git a/locales/es.json b/locales/es.json index 45ba1491..8154b625 100644 --- a/locales/es.json +++ b/locales/es.json @@ -263,6 +263,19 @@ "red": "civitai.red (sin restricciones)" } }, + "downloadBackend": { + "label": "Backend de descarga", + "help": "Elige cómo se descargan los archivos del modelo. Python usa el descargador integrado. aria2 usa el proceso externo experimental de descarga.", + "options": { + "python": "Python (integrado)", + "aria2": "aria2 (experimental)" + } + }, + "aria2cPath": { + "label": "Ruta de aria2c", + "help": "Ruta opcional al ejecutable aria2c. Déjalo vacío para usar aria2c desde el PATH del sistema.", + "placeholder": "Déjalo vacío para usar aria2c desde el PATH" + }, "civitaiHostBanner": { "title": "Preferencia de host de Civitai disponible", "content": "Civitai ahora usa civitai.com para contenido SFW y civitai.red para contenido sin restricciones. Puedes cambiar en Ajustes qué sitio se abre por defecto.", @@ -278,6 +291,7 @@ }, "sections": { "contentFiltering": "Filtrado de contenido", + "downloads": "Descargas", "videoSettings": "Configuración de video", "layoutSettings": "Configuración de diseño", "misc": "Varios", diff --git a/locales/fr.json b/locales/fr.json index 15ac157c..fd156e34 100644 --- a/locales/fr.json +++ b/locales/fr.json @@ -263,6 +263,19 @@ "red": "civitai.red (sans restriction)" } }, + "downloadBackend": { + "label": "Moteur de téléchargement", + "help": "Choisissez comment les fichiers de modèles sont téléchargés. Python utilise le téléchargeur intégré. aria2 utilise le processus externe expérimental de téléchargement.", + "options": { + "python": "Python (intégré)", + "aria2": "aria2 (expérimental)" + } + }, + "aria2cPath": { + "label": "Chemin vers aria2c", + "help": "Chemin facultatif vers l’exécutable aria2c. Laissez vide pour utiliser aria2c depuis le PATH système.", + "placeholder": "Laisser vide pour utiliser aria2c depuis le PATH" + }, "civitaiHostBanner": { "title": "Préférence d’hôte Civitai disponible", "content": "Civitai utilise désormais civitai.com pour le contenu SFW et civitai.red pour le contenu sans restriction. Vous pouvez modifier dans les paramètres le site ouvert par défaut.", @@ -278,6 +291,7 @@ }, "sections": { "contentFiltering": "Filtrage du contenu", + "downloads": "Téléchargements", "videoSettings": "Paramètres vidéo", "layoutSettings": "Paramètres d'affichage", "misc": "Divers", diff --git a/locales/he.json b/locales/he.json index 8dfdab61..9cd794d5 100644 --- a/locales/he.json +++ b/locales/he.json @@ -263,6 +263,19 @@ "red": "civitai.red (ללא הגבלות)" } }, + "downloadBackend": { + "label": "מנגנון הורדה", + "help": "בחר כיצד יורדים קבצי המודל. Python משתמש במוריד המובנה. aria2 משתמש בתהליך הורדה חיצוני ניסיוני.", + "options": { + "python": "Python (מובנה)", + "aria2": "aria2 (ניסיוני)" + } + }, + "aria2cPath": { + "label": "נתיב aria2c", + "help": "נתיב אופציונלי לקובץ ההפעלה aria2c. השאר ריק כדי להשתמש ב-aria2c מתוך ה-PATH של המערכת.", + "placeholder": "השאר ריק כדי להשתמש ב-aria2c מתוך ה-PATH" + }, "civitaiHostBanner": { "title": "העדפת מארח Civitai זמינה", "content": "Civitai משתמש כעת ב-civitai.com עבור תוכן SFW וב-civitai.red עבור תוכן ללא הגבלות. ניתן לשנות בהגדרות איזה אתר ייפתח כברירת מחדל.", @@ -278,6 +291,7 @@ }, "sections": { "contentFiltering": "סינון תוכן", + "downloads": "הורדות", "videoSettings": "הגדרות וידאו", "layoutSettings": "הגדרות פריסה", "misc": "שונות", diff --git a/locales/ja.json b/locales/ja.json index e1d4448d..f5d286c2 100644 --- a/locales/ja.json +++ b/locales/ja.json @@ -263,6 +263,19 @@ "red": "civitai.red(制限なし)" } }, + "downloadBackend": { + "label": "ダウンロードバックエンド", + "help": "モデルファイルのダウンロード方法を選択します。Python は内蔵ダウンローダーを使用し、aria2 は実験的な外部ダウンローダープロセスを使用します。", + "options": { + "python": "Python(内蔵)", + "aria2": "aria2(実験的)" + } + }, + "aria2cPath": { + "label": "aria2c のパス", + "help": "aria2c 実行ファイルへの任意のパスです。空欄のままにすると、システム PATH 上の aria2c を使用します。", + "placeholder": "空欄のままにすると PATH 上の aria2c を使用します" + }, "civitaiHostBanner": { "title": "Civitai ホスト設定を利用できます", "content": "Civitai は現在、SFW コンテンツには civitai.com、制限なしコンテンツには civitai.red を使用しています。設定で既定で開くサイトを変更できます。", @@ -278,6 +291,7 @@ }, "sections": { "contentFiltering": "コンテンツフィルタリング", + "downloads": "ダウンロード", "videoSettings": "動画設定", "layoutSettings": "レイアウト設定", "misc": "その他", diff --git a/locales/ko.json b/locales/ko.json index 5a3f2ec0..2e1f40ca 100644 --- a/locales/ko.json +++ b/locales/ko.json @@ -263,6 +263,19 @@ "red": "civitai.red(무제한)" } }, + "downloadBackend": { + "label": "다운로드 백엔드", + "help": "모델 파일을 다운로드하는 방식을 선택합니다. Python은 내장 다운로더를 사용하고, aria2는 실험적인 외부 다운로더 프로세스를 사용합니다.", + "options": { + "python": "Python(내장)", + "aria2": "aria2(실험적)" + } + }, + "aria2cPath": { + "label": "aria2c 경로", + "help": "aria2c 실행 파일의 선택적 경로입니다. 비워 두면 시스템 PATH의 aria2c를 사용합니다.", + "placeholder": "비워 두면 PATH의 aria2c를 사용합니다" + }, "civitaiHostBanner": { "title": "Civitai 호스트 기본 설정 사용 가능", "content": "이제 Civitai는 SFW 콘텐츠에 civitai.com을, 무제한 콘텐츠에 civitai.red를 사용합니다. 설정에서 기본으로 열 사이트를 변경할 수 있습니다.", @@ -278,6 +291,7 @@ }, "sections": { "contentFiltering": "콘텐츠 필터링", + "downloads": "다운로드", "videoSettings": "비디오 설정", "layoutSettings": "레이아웃 설정", "misc": "기타", diff --git a/locales/ru.json b/locales/ru.json index cf0612fe..58eccf53 100644 --- a/locales/ru.json +++ b/locales/ru.json @@ -263,6 +263,19 @@ "red": "civitai.red (без ограничений)" } }, + "downloadBackend": { + "label": "Бэкенд загрузки", + "help": "Выберите способ загрузки файлов моделей. Python использует встроенный загрузчик. aria2 использует экспериментальный внешний процесс загрузки.", + "options": { + "python": "Python (встроенный)", + "aria2": "aria2 (экспериментальный)" + } + }, + "aria2cPath": { + "label": "Путь к aria2c", + "help": "Необязательный путь к исполняемому файлу aria2c. Оставьте пустым, чтобы использовать aria2c из системного PATH.", + "placeholder": "Оставьте пустым, чтобы использовать aria2c из PATH" + }, "civitaiHostBanner": { "title": "Доступна настройка хоста Civitai", "content": "Теперь Civitai использует civitai.com для контента SFW и civitai.red для контента без ограничений. В настройках можно изменить, какой сайт открывать по умолчанию.", @@ -278,6 +291,7 @@ }, "sections": { "contentFiltering": "Фильтрация контента", + "downloads": "Загрузки", "videoSettings": "Настройки видео", "layoutSettings": "Настройки макета", "misc": "Разное", diff --git a/locales/zh-CN.json b/locales/zh-CN.json index 898f550a..17644c63 100644 --- a/locales/zh-CN.json +++ b/locales/zh-CN.json @@ -263,6 +263,19 @@ "red": "civitai.red(无限制)" } }, + "downloadBackend": { + "label": "下载后端", + "help": "选择模型文件的下载方式。Python 使用内置下载器。aria2 使用实验性的外部下载进程。", + "options": { + "python": "Python(内置)", + "aria2": "aria2(实验性)" + } + }, + "aria2cPath": { + "label": "aria2c 路径", + "help": "可选的 aria2c 可执行文件路径。留空则使用系统 PATH 中的 aria2c。", + "placeholder": "留空则使用 PATH 中的 aria2c" + }, "civitaiHostBanner": { "title": "已提供 Civitai 站点偏好设置", "content": "Civitai 现在使用 civitai.com 提供 SFW 内容,使用 civitai.red 提供无限制内容。你可以在设置中更改默认打开的站点。", @@ -278,6 +291,7 @@ }, "sections": { "contentFiltering": "内容过滤", + "downloads": "下载", "videoSettings": "视频设置", "layoutSettings": "布局设置", "misc": "其他", diff --git a/locales/zh-TW.json b/locales/zh-TW.json index ef4077ba..c2139ef0 100644 --- a/locales/zh-TW.json +++ b/locales/zh-TW.json @@ -263,6 +263,19 @@ "red": "civitai.red(無限制)" } }, + "downloadBackend": { + "label": "下載後端", + "help": "選擇模型檔案的下載方式。Python 使用內建下載器。aria2 使用實驗性的外部下載程序。", + "options": { + "python": "Python(內建)", + "aria2": "aria2(實驗性)" + } + }, + "aria2cPath": { + "label": "aria2c 路徑", + "help": "可選的 aria2c 可執行檔路徑。留空則使用系統 PATH 中的 aria2c。", + "placeholder": "留空則使用 PATH 中的 aria2c" + }, "civitaiHostBanner": { "title": "已提供 Civitai 站點偏好設定", "content": "Civitai 現在使用 civitai.com 提供 SFW 內容,使用 civitai.red 提供無限制內容。你可以在設定中變更預設開啟的站點。", @@ -278,6 +291,7 @@ }, "sections": { "contentFiltering": "內容過濾", + "downloads": "下載", "videoSettings": "影片設定", "layoutSettings": "版面設定", "misc": "其他", diff --git a/py/services/aria2_downloader.py b/py/services/aria2_downloader.py new file mode 100644 index 00000000..d1b1b018 --- /dev/null +++ b/py/services/aria2_downloader.py @@ -0,0 +1,497 @@ +from __future__ import annotations + +import asyncio +import json +import logging +import os +import secrets +import shutil +import socket +from dataclasses import dataclass +from datetime import datetime +from pathlib import Path +from typing import Any, Dict, Optional, Tuple + +import aiohttp + +from .downloader import DownloadProgress, get_downloader +from .settings_manager import get_settings_manager + +logger = logging.getLogger(__name__) + +CIVITAI_DOWNLOAD_URL_PREFIXES = ( + "https://civitai.com/api/download/", + "https://civitai.red/api/download/", +) + + +class Aria2Error(RuntimeError): + """Raised when aria2 integration fails.""" + + +@dataclass +class Aria2Transfer: + """Track an aria2 download registered by the Python coordinator.""" + + gid: str + save_path: str + + +class Aria2Downloader: + """Manage an aria2 RPC daemon for experimental model downloads.""" + + _instance = None + _lock = asyncio.Lock() + + @classmethod + async def get_instance(cls) -> "Aria2Downloader": + async with cls._lock: + if cls._instance is None: + cls._instance = cls() + return cls._instance + + def __init__(self) -> None: + if hasattr(self, "_initialized"): + return + + self._initialized = True + self._process: Optional[asyncio.subprocess.Process] = None + self._rpc_port: Optional[int] = None + self._rpc_secret = "" + self._rpc_url = "" + self._rpc_session: Optional[aiohttp.ClientSession] = None + self._rpc_session_lock = asyncio.Lock() + self._process_lock = asyncio.Lock() + self._transfers: Dict[str, Aria2Transfer] = {} + self._poll_interval = 0.5 + + @property + def is_running(self) -> bool: + return self._process is not None and self._process.returncode is None + + async def download_file( + self, + url: str, + save_path: str, + *, + download_id: str, + progress_callback=None, + headers: Optional[Dict[str, str]] = None, + ) -> Tuple[bool, str]: + """Download a file using aria2 RPC and wait for completion.""" + + await self._ensure_process() + save_path = os.path.abspath(save_path) + save_dir = os.path.dirname(save_path) + out_name = os.path.basename(save_path) + + Path(save_dir).mkdir(parents=True, exist_ok=True) + + resolved_url = url + request_headers = headers + if headers and url.startswith(CIVITAI_DOWNLOAD_URL_PREFIXES): + resolved_url = await self._resolve_authenticated_redirect_url(url, headers) + if resolved_url != url: + request_headers = None + logger.debug( + "Resolved Civitai download %s to signed URL for aria2", + download_id, + ) + + options: Dict[str, str] = { + "dir": save_dir, + "out": out_name, + "continue": "true", + "max-connection-per-server": "4", + "split": "4", + "min-split-size": "1M", + "allow-overwrite": "true", + "auto-file-renaming": "false", + "file-allocation": "none", + } + if request_headers: + options["header"] = [ + f"{key}: {value}" for key, value in request_headers.items() + ] + + logger.debug( + "Submitting aria2 download %s -> %s (auth=%s, civitai_signed=%s)", + download_id, + save_path, + bool(request_headers), + resolved_url != url, + ) + + try: + gid = await self._rpc_call("aria2.addUri", [[resolved_url], options]) + except Exception as exc: + raise Aria2Error(f"Failed to schedule aria2 download: {exc}") from exc + + logger.debug("aria2 accepted download %s with gid %s", download_id, gid) + + self._transfers[download_id] = Aria2Transfer(gid=gid, save_path=save_path) + + try: + while True: + status = await self.get_status(download_id) + if status is None: + return False, "aria2 download not found" + + snapshot = self._build_progress_snapshot(status) + if progress_callback is not None: + await self._dispatch_progress(progress_callback, snapshot) + + state = status.get("status", "") + if state == "complete": + completed_path = self._resolve_completed_path(status, save_path) + return True, completed_path + if state == "error": + return False, status.get("errorMessage") or "aria2 download failed" + if state == "removed": + return False, "Download was cancelled" + + await asyncio.sleep(self._poll_interval) + finally: + self._transfers.pop(download_id, None) + + async def get_status(self, download_id: str) -> Optional[Dict[str, Any]]: + """Return the raw aria2 status payload for a known download.""" + + transfer = self._transfers.get(download_id) + if transfer is None: + return None + + keys = [ + "gid", + "status", + "totalLength", + "completedLength", + "downloadSpeed", + "errorMessage", + "files", + ] + try: + status = await self._rpc_call("aria2.tellStatus", [transfer.gid, keys]) + except Exception as exc: + raise Aria2Error(f"Failed to query aria2 download status: {exc}") from exc + + if isinstance(status, dict): + return status + return None + + async def has_transfer(self, download_id: str) -> bool: + return download_id in self._transfers + + async def pause_download(self, download_id: str) -> Dict[str, Any]: + transfer = self._transfers.get(download_id) + if transfer is None: + return {"success": False, "error": "Download task not found"} + + try: + await self._rpc_call("aria2.forcePause", [transfer.gid]) + except Exception as exc: + return {"success": False, "error": str(exc)} + + return {"success": True, "message": "Download paused successfully"} + + async def resume_download(self, download_id: str) -> Dict[str, Any]: + transfer = self._transfers.get(download_id) + if transfer is None: + return {"success": False, "error": "Download task not found"} + + try: + await self._rpc_call("aria2.unpause", [transfer.gid]) + except Exception as exc: + return {"success": False, "error": str(exc)} + + return {"success": True, "message": "Download resumed successfully"} + + async def cancel_download(self, download_id: str) -> Dict[str, Any]: + transfer = self._transfers.get(download_id) + if transfer is None: + return {"success": False, "error": "Download task not found"} + + try: + await self._rpc_call("aria2.forceRemove", [transfer.gid]) + except Exception as exc: + return {"success": False, "error": str(exc)} + + return {"success": True, "message": "Download cancelled successfully"} + + async def close(self) -> None: + """Shut down the RPC process and session.""" + + if self._rpc_session is not None: + await self._rpc_session.close() + self._rpc_session = None + + process = self._process + self._process = None + self._transfers.clear() + + if process is None: + return + + if process.returncode is None: + process.terminate() + try: + await asyncio.wait_for(process.wait(), timeout=5.0) + except asyncio.TimeoutError: + process.kill() + await process.wait() + + async def _dispatch_progress(self, callback, snapshot: DownloadProgress) -> None: + try: + result = callback(snapshot, snapshot) + except TypeError: + result = callback(snapshot.percent_complete) + + if asyncio.iscoroutine(result): + await result + elif hasattr(result, "__await__"): + await result + + def _build_progress_snapshot(self, status: Dict[str, Any]) -> DownloadProgress: + completed = self._parse_int(status.get("completedLength")) + total = self._parse_int(status.get("totalLength")) + speed = float(self._parse_int(status.get("downloadSpeed"))) + percent = 0.0 + if total > 0: + percent = (completed / total) * 100.0 + + return DownloadProgress( + percent_complete=max(0.0, min(percent, 100.0)), + bytes_downloaded=completed, + total_bytes=total or None, + bytes_per_second=speed, + timestamp=datetime.now().timestamp(), + ) + + def _resolve_completed_path(self, status: Dict[str, Any], default_path: str) -> str: + files = status.get("files") + if isinstance(files, list) and files: + first = files[0] + if isinstance(first, dict): + candidate = first.get("path") + if isinstance(candidate, str) and candidate: + return candidate + return default_path + + @staticmethod + def _parse_int(value: Any) -> int: + try: + return int(value) + except (TypeError, ValueError): + return 0 + + async def _resolve_authenticated_redirect_url( + self, + url: str, + headers: Dict[str, str], + ) -> str: + downloader = await get_downloader() + session = await downloader.session + request_headers = dict(downloader.default_headers) + request_headers.update(headers) + request_headers["Accept-Encoding"] = "identity" + + try: + async with session.get( + url, + headers=request_headers, + allow_redirects=False, + proxy=downloader.proxy_url, + ) as response: + if response.status in {301, 302, 303, 307, 308}: + location = response.headers.get("Location") + if location: + return location + raise Aria2Error( + "Authenticated Civitai redirect did not include a Location header" + ) + + if response.status == 200: + return url + + body = await response.text() + raise Aria2Error( + f"Failed to resolve authenticated Civitai redirect: status={response.status} body={body[:300]}" + ) + except aiohttp.ClientError as exc: + raise Aria2Error( + f"Failed to resolve authenticated Civitai redirect: {exc}" + ) from exc + + async def _ensure_process(self) -> None: + async with self._process_lock: + if self.is_running and await self._ping(): + return + + await self.close() + + executable = self._resolve_executable() + self._rpc_port = self._find_free_port() + self._rpc_secret = secrets.token_hex(16) + self._rpc_url = f"http://127.0.0.1:{self._rpc_port}/jsonrpc" + + command = [ + executable, + "--enable-rpc=true", + "--rpc-listen-all=false", + f"--rpc-listen-port={self._rpc_port}", + f"--rpc-secret={self._rpc_secret}", + "--check-certificate=true", + "--allow-overwrite=true", + "--auto-file-renaming=false", + "--file-allocation=none", + "--max-concurrent-downloads=5", + "--continue=true", + "--daemon=false", + "--quiet=true", + f"--stop-with-process={os.getpid()}", + ] + + logger.info("Starting aria2 RPC daemon from %s", executable) + self._process = await asyncio.create_subprocess_exec( + *command, + stdout=asyncio.subprocess.DEVNULL, + stderr=asyncio.subprocess.PIPE, + ) + + await self._wait_until_ready() + + def _resolve_executable(self) -> str: + settings = get_settings_manager() + configured_path = (settings.get("aria2c_path") or "").strip() + candidate = configured_path or "aria2c" + + resolved = shutil.which(candidate) + if resolved: + return resolved + + if configured_path and os.path.isfile(configured_path) and os.access( + configured_path, os.X_OK + ): + return configured_path + + raise Aria2Error( + "aria2c executable was not found. Install aria2 or configure aria2c_path." + ) + + async def _wait_until_ready(self) -> None: + assert self._process is not None + + start_time = asyncio.get_running_loop().time() + last_error = "" + while asyncio.get_running_loop().time() - start_time < 10.0: + if self._process.returncode is not None: + stderr_output = "" + if self._process.stderr is not None: + try: + stderr_output = ( + await asyncio.wait_for(self._process.stderr.read(), timeout=0.2) + ).decode("utf-8", errors="replace") + except Exception: + stderr_output = "" + raise Aria2Error( + f"aria2 RPC process exited early with code {self._process.returncode}: {stderr_output.strip()}" + ) + + try: + if await self._ping(): + return + except Exception as exc: # pragma: no cover - startup race + last_error = str(exc) + + await asyncio.sleep(0.2) + + raise Aria2Error( + f"Timed out waiting for aria2 RPC to become ready{': ' + last_error if last_error else ''}" + ) + + async def _ping(self) -> bool: + try: + result = await self._rpc_call("aria2.getVersion", []) + except Exception: + return False + + return isinstance(result, dict) + + async def _rpc_call(self, method: str, params: list[Any]) -> Any: + if not self._rpc_url: + raise Aria2Error("aria2 RPC endpoint is not initialized") + + session = await self._get_rpc_session() + payload = { + "jsonrpc": "2.0", + "id": secrets.token_hex(8), + "method": method, + "params": [f"token:{self._rpc_secret}", *params], + } + + async with session.post(self._rpc_url, json=payload) as response: + text = await response.text() + + try: + body = json.loads(text) + except json.JSONDecodeError: + body = None + + if body is None: + if response.status != 200: + raise Aria2Error( + f"aria2 RPC returned status {response.status} with non-JSON body: {text}" + ) + raise Aria2Error(f"Invalid aria2 RPC response: {text}") + + if "error" in body: + error = body["error"] or {} + code = error.get("code") if isinstance(error, dict) else None + message = error.get("message") if isinstance(error, dict) else str(error) + logger.error( + "aria2 RPC %s failed with HTTP %s, code=%s, message=%s", + method, + response.status, + code, + message, + ) + status_message = ( + f"aria2 RPC {method} failed with status {response.status}: {message}" + if response.status != 200 + else message + ) + raise Aria2Error(status_message or "Unknown aria2 RPC error") + + if response.status != 200: + logger.error( + "aria2 RPC %s returned unexpected HTTP status %s without error payload: %s", + method, + response.status, + body, + ) + raise Aria2Error( + f"aria2 RPC {method} returned unexpected status {response.status}" + ) + + return body.get("result") + + async def _get_rpc_session(self) -> aiohttp.ClientSession: + if self._rpc_session is None or self._rpc_session.closed: + async with self._rpc_session_lock: + if self._rpc_session is None or self._rpc_session.closed: + timeout = aiohttp.ClientTimeout(total=30) + self._rpc_session = aiohttp.ClientSession(timeout=timeout) + return self._rpc_session + + @staticmethod + def _find_free_port() -> int: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: + sock.bind(("127.0.0.1", 0)) + sock.listen(1) + return int(sock.getsockname()[1]) + + +async def get_aria2_downloader() -> Aria2Downloader: + """Get the singleton aria2 downloader.""" + + return await Aria2Downloader.get_instance() diff --git a/py/services/download_manager.py b/py/services/download_manager.py index 2fe07cff..ce727abd 100644 --- a/py/services/download_manager.py +++ b/py/services/download_manager.py @@ -5,6 +5,7 @@ import asyncio import inspect import shutil import zipfile +from concurrent.futures import ThreadPoolExecutor from collections import OrderedDict import uuid from typing import Dict, List, Optional, Set, Tuple @@ -25,6 +26,7 @@ from .service_registry import ServiceRegistry from .settings_manager import get_settings_manager from .metadata_service import get_default_metadata_provider, get_metadata_provider from .downloader import get_downloader, DownloadProgress, DownloadStreamControl +from .aria2_downloader import Aria2Error, get_aria2_downloader # Download to temporary file first import tempfile @@ -60,6 +62,59 @@ class DownloadManager: self._download_semaphore = asyncio.Semaphore(5) # Limit concurrent downloads self._download_tasks = {} # download_id -> asyncio.Task self._pause_events: Dict[str, DownloadStreamControl] = {} + self._archive_executor = ThreadPoolExecutor( + max_workers=2, thread_name_prefix="lm-archive" + ) + + @staticmethod + def _get_model_download_backend() -> str: + backend = (get_settings_manager().get("download_backend") or "python").strip() + return backend.lower() or "python" + + async def _download_model_file( + self, + download_url: str, + save_path: str, + *, + backend: str, + progress_callback, + use_auth: bool, + download_id: Optional[str], + pause_control: Optional[DownloadStreamControl], + ) -> Tuple[bool, str]: + if backend == "aria2": + if not download_id: + return False, "aria2 downloads require a tracked download_id" + + headers: Dict[str, str] = {} + if use_auth: + api_key = (get_settings_manager().get("civitai_api_key") or "").strip() + if api_key: + headers["Authorization"] = f"Bearer {api_key}" + + try: + aria2_downloader = await get_aria2_downloader() + return await aria2_downloader.download_file( + download_url, + save_path, + download_id=download_id, + progress_callback=progress_callback, + headers=headers or None, + ) + except Aria2Error as exc: + logger.error("aria2 download failed for %s: %s", download_url, exc) + return False, str(exc) + + download_kwargs = { + "progress_callback": progress_callback, + "use_auth": use_auth, + } + + if pause_control is not None: + download_kwargs["pause_event"] = pause_control + + downloader = await get_downloader() + return await downloader.download_file(download_url, save_path, **download_kwargs) async def _get_lora_scanner(self): """Get the lora scanner from registry""" @@ -126,6 +181,7 @@ class DownloadManager: "model_version_id": model_version_id, "progress": 0, "status": "queued", + "transfer_backend": self._get_model_download_backend(), "bytes_downloaded": 0, "total_bytes": None, "bytes_per_second": 0.0, @@ -240,6 +296,9 @@ class DownloadManager: tracking_callback, use_default_paths, task_id, + self._active_downloads.get(task_id, {}).get( + "transfer_backend", "python" + ), source, file_params, ) @@ -294,6 +353,7 @@ class DownloadManager: progress_callback, use_default_paths, download_id=None, + transfer_backend="python", source=None, file_params=None, ): @@ -696,16 +756,27 @@ class DownloadManager: logger.info(f"Creating EmbeddingMetadata for {file_name}") # 6. Start download process - result = await self._execute_download( - download_urls=download_urls, - save_dir=save_dir, - metadata=metadata, - version_info=version_info, - relative_path=relative_path, - progress_callback=progress_callback, - model_type=model_type, - download_id=download_id, - ) + execute_kwargs = { + "download_urls": download_urls, + "save_dir": save_dir, + "metadata": metadata, + "version_info": version_info, + "relative_path": relative_path, + "progress_callback": progress_callback, + "model_type": model_type, + "download_id": download_id, + } + execute_signature = inspect.signature(self._execute_download) + if ( + "transfer_backend" in execute_signature.parameters + or any( + parameter.kind == inspect.Parameter.VAR_KEYWORD + for parameter in execute_signature.parameters.values() + ) + ): + execute_kwargs["transfer_backend"] = transfer_backend + + result = await self._execute_download(**execute_kwargs) if result.get("success", False): resolved_model_id = ( @@ -965,6 +1036,7 @@ class DownloadManager: progress_callback=None, model_type: str = "lora", download_id: str = None, + transfer_backend: Optional[str] = None, ) -> Dict: """Execute the actual download process including preview images and model files""" metadata_entries: List = [] @@ -974,6 +1046,7 @@ class DownloadManager: preview_targets: List[str] = [] preview_path: str | None = None preview_nsfw_level = 0 + transfer_backend = (transfer_backend or self._get_model_download_backend()).lower() try: # Extract original filename details original_filename = os.path.basename(metadata.file_path) @@ -1136,32 +1209,37 @@ class DownloadManager: if progress_callback: await progress_callback(3) # 3% progress after preview download - # Download model file with progress tracking using downloader - downloader = await get_downloader() - if pause_control is not None: - pause_control.update_stall_timeout(downloader.stall_timeout) + # Download model file with progress tracking using the configured backend + downloader = None + if transfer_backend == "python": + downloader = await get_downloader() + if pause_control is not None: + pause_control.update_stall_timeout(downloader.stall_timeout) + if pause_control is not None and pause_control.is_paused(): + if download_id and download_id in self._active_downloads: + self._active_downloads[download_id]["status"] = "paused" + self._active_downloads[download_id]["bytes_per_second"] = 0.0 + await pause_control.wait() + if download_id and download_id in self._active_downloads: + self._active_downloads[download_id]["status"] = "downloading" last_error = None for download_url in download_urls: download_url = normalize_civitai_download_url(download_url) use_auth = download_url.startswith(CIVITAI_DOWNLOAD_URL_PREFIXES) - download_kwargs = { - "progress_callback": lambda progress, snapshot=None: ( + success, result = await self._download_model_file( + download_url, + save_path, + backend=transfer_backend, + progress_callback=lambda progress, snapshot=None: ( self._handle_download_progress( progress, progress_callback, snapshot, ) ), - "use_auth": use_auth, # Only use authentication for Civitai downloads - } - - if pause_control is not None: - download_kwargs["pause_event"] = pause_control - - success, result = await downloader.download_file( - download_url, - save_path, # Use full path instead of separate dir and filename - **download_kwargs, + use_auth=use_auth, + download_id=download_id, + pause_control=pause_control, ) if success: @@ -1401,7 +1479,8 @@ class DownloadManager: extracted_files.append(dest_path) return extracted_files - return await asyncio.to_thread(_extract_sync) + loop = asyncio.get_running_loop() + return await loop.run_in_executor(self._archive_executor, _extract_sync) async def _build_metadata_entries( self, base_metadata, file_paths: List[str] @@ -1511,8 +1590,28 @@ class DownloadManager: return {"success": False, "error": "Download task not found"} try: - # Get the task and cancel it task = self._download_tasks[download_id] + backend = ( + self._active_downloads.get(download_id, {}).get("transfer_backend") + or "python" + ) + + if backend == "aria2": + try: + aria2_downloader = await get_aria2_downloader() + cancel_result = await aria2_downloader.cancel_download(download_id) + if ( + not cancel_result.get("success") + and cancel_result.get("error") != "Download task not found" + ): + return cancel_result + except Exception as exc: + logger.warning( + "Failed to cancel aria2 transfer for %s, continuing with local task cancellation: %s", + download_id, + exc, + ) + task.cancel() pause_control = self._pause_events.get(download_id) @@ -1613,6 +1712,28 @@ class DownloadManager: pause_control.pause() + backend = ( + self._active_downloads.get(download_id, {}).get("transfer_backend") + or "python" + ) + if backend == "aria2": + try: + aria2_downloader = await get_aria2_downloader() + if await aria2_downloader.has_transfer(download_id): + result = await aria2_downloader.pause_download(download_id) + if not result.get("success"): + pause_control.resume() + return result + except Exception as exc: + pause_control.resume() + return {"success": False, "error": str(exc)} + + download_info = self._active_downloads.get(download_id) + if download_info is not None: + download_info["status"] = "paused" + download_info["bytes_per_second"] = 0.0 + return {"success": True, "message": "Download paused successfully"} + download_info = self._active_downloads.get(download_id) if download_info is not None: download_info["status"] = "paused" @@ -1631,6 +1752,28 @@ class DownloadManager: return {"success": False, "error": "Download is not paused"} download_info = self._active_downloads.get(download_id) + backend = ( + self._active_downloads.get(download_id, {}).get("transfer_backend") + or "python" + ) + if backend == "aria2": + try: + aria2_downloader = await get_aria2_downloader() + if await aria2_downloader.has_transfer(download_id): + result = await aria2_downloader.resume_download(download_id) + if not result.get("success"): + return result + except Exception as exc: + return {"success": False, "error": str(exc)} + + pause_control.resume() + + if download_info is not None: + if download_info.get("status") == "paused": + download_info["status"] = "downloading" + download_info.setdefault("bytes_per_second", 0.0) + return {"success": True, "message": "Download resumed successfully"} + force_reconnect = False if pause_control is not None: elapsed = pause_control.time_since_last_progress() diff --git a/py/services/settings_manager.py b/py/services/settings_manager.py index 8b6cfebc..08279982 100644 --- a/py/services/settings_manager.py +++ b/py/services/settings_manager.py @@ -55,6 +55,8 @@ DEFAULT_KEYS_CLEANUP_THRESHOLD = 10 DEFAULT_SETTINGS: Dict[str, Any] = { "civitai_api_key": "", "civitai_host": "civitai.com", + "download_backend": "python", + "aria2c_path": "", "use_portable_settings": False, "hash_chunk_size_mb": DEFAULT_HASH_CHUNK_SIZE_MB, "language": "en", diff --git a/static/js/managers/SettingsManager.js b/static/js/managers/SettingsManager.js index 1dcc5d2f..ab2906b2 100644 --- a/static/js/managers/SettingsManager.js +++ b/static/js/managers/SettingsManager.js @@ -807,6 +807,16 @@ export class SettingsManager { civitaiHostSelect.value = state.global.settings.civitai_host || 'civitai.com'; } + const downloadBackendSelect = document.getElementById('downloadBackend'); + if (downloadBackendSelect) { + downloadBackendSelect.value = state.global.settings.download_backend || 'python'; + } + + const aria2cPathInput = document.getElementById('aria2cPath'); + if (aria2cPathInput) { + aria2cPathInput.value = state.global.settings.aria2c_path || ''; + } + const recipesPathInput = document.getElementById('recipesPath'); if (recipesPathInput) { recipesPathInput.value = state.global.settings.recipes_path || ''; @@ -950,9 +960,36 @@ export class SettingsManager { languageSelect.value = currentLanguage; } + this.loadDownloadBackendSettings(); this.loadProxySettings(); } + loadDownloadBackendSettings() { + const downloadBackendSelect = document.getElementById('downloadBackend'); + const aria2PathSetting = document.getElementById('aria2PathSetting'); + const updateVisibility = () => { + if (!aria2PathSetting || !downloadBackendSelect) { + return; + } + aria2PathSetting.style.display = downloadBackendSelect.value === 'aria2' ? 'block' : 'none'; + }; + + if (downloadBackendSelect) { + downloadBackendSelect.value = state.global.settings.download_backend || 'python'; + downloadBackendSelect.onchange = () => { + updateVisibility(); + this.saveSelectSetting('downloadBackend', 'download_backend'); + }; + } + + const aria2cPathInput = document.getElementById('aria2cPath'); + if (aria2cPathInput) { + aria2cPathInput.value = state.global.settings.aria2c_path || ''; + } + + updateVisibility(); + } + setupPriorityTagInputs() { ['lora', 'checkpoint', 'embedding'].forEach((modelType) => { const textarea = document.getElementById(`${modelType}PriorityTagsInput`); diff --git a/static/js/state/index.js b/static/js/state/index.js index 4a55ab26..5206c3d4 100644 --- a/static/js/state/index.js +++ b/static/js/state/index.js @@ -6,6 +6,8 @@ import { DEFAULT_PATH_TEMPLATES, DEFAULT_PRIORITY_TAG_CONFIG } from '../utils/co const DEFAULT_SETTINGS_BASE = Object.freeze({ civitai_api_key: '', civitai_host: 'civitai.com', + download_backend: 'python', + aria2c_path: '', use_portable_settings: false, language: 'en', show_only_sfw: false, diff --git a/templates/components/modals/settings_modal.html b/templates/components/modals/settings_modal.html index c85a7cfc..607241e0 100644 --- a/templates/components/modals/settings_modal.html +++ b/templates/components/modals/settings_modal.html @@ -129,6 +129,43 @@ +