diff --git a/locales/de.json b/locales/de.json index 8584a12d..bd795d45 100644 --- a/locales/de.json +++ b/locales/de.json @@ -263,6 +263,19 @@ "red": "civitai.red (uneingeschränkt)" } }, + "downloadBackend": { + "label": "Download-Backend", + "help": "Wähle aus, wie Modelldateien heruntergeladen werden. Python verwendet den eingebauten Downloader. aria2 verwendet den experimentellen externen Downloader-Prozess.", + "options": { + "python": "Python (integriert)", + "aria2": "aria2 (experimentell)" + } + }, + "aria2cPath": { + "label": "aria2c-Pfad", + "help": "Optionaler Pfad zur ausführbaren aria2c-Datei. Leer lassen, um aria2c aus dem System-PATH zu verwenden.", + "placeholder": "Leer lassen, um aria2c aus dem PATH zu verwenden" + }, "civitaiHostBanner": { "title": "Civitai-Host-Einstellung verfügbar", "content": "Civitai verwendet jetzt civitai.com für SFW-Inhalte und civitai.red für uneingeschränkte Inhalte. In den Einstellungen können Sie ändern, welche Seite standardmäßig geöffnet wird.", @@ -278,6 +291,7 @@ }, "sections": { "contentFiltering": "Inhaltsfilterung", + "downloads": "Downloads", "videoSettings": "Video-Einstellungen", "layoutSettings": "Layout-Einstellungen", "misc": "Verschiedenes", diff --git a/locales/en.json b/locales/en.json index e4272a48..a7e41e0f 100644 --- a/locales/en.json +++ b/locales/en.json @@ -263,6 +263,19 @@ "red": "civitai.red (unrestricted)" } }, + "downloadBackend": { + "label": "Download backend", + "help": "Choose how model files are downloaded. Python uses the built-in downloader. aria2 uses the experimental external downloader process.", + "options": { + "python": "Python (built-in)", + "aria2": "aria2 (experimental)" + } + }, + "aria2cPath": { + "label": "aria2c path", + "help": "Optional path to the aria2c executable. Leave empty to use aria2c from your system PATH.", + "placeholder": "Leave empty to use aria2c from PATH" + }, "civitaiHostBanner": { "title": "Civitai host preference available", "content": "Civitai now uses civitai.com for SFW content and civitai.red for unrestricted content. You can change which site opens by default in Settings.", @@ -278,6 +291,7 @@ }, "sections": { "contentFiltering": "Content Filtering", + "downloads": "Downloads", "videoSettings": "Video Settings", "layoutSettings": "Layout Settings", "misc": "Miscellaneous", diff --git a/locales/es.json b/locales/es.json index 45ba1491..8154b625 100644 --- a/locales/es.json +++ b/locales/es.json @@ -263,6 +263,19 @@ "red": "civitai.red (sin restricciones)" } }, + "downloadBackend": { + "label": "Backend de descarga", + "help": "Elige cómo se descargan los archivos del modelo. Python usa el descargador integrado. aria2 usa el proceso externo experimental de descarga.", + "options": { + "python": "Python (integrado)", + "aria2": "aria2 (experimental)" + } + }, + "aria2cPath": { + "label": "Ruta de aria2c", + "help": "Ruta opcional al ejecutable aria2c. Déjalo vacío para usar aria2c desde el PATH del sistema.", + "placeholder": "Déjalo vacío para usar aria2c desde el PATH" + }, "civitaiHostBanner": { "title": "Preferencia de host de Civitai disponible", "content": "Civitai ahora usa civitai.com para contenido SFW y civitai.red para contenido sin restricciones. Puedes cambiar en Ajustes qué sitio se abre por defecto.", @@ -278,6 +291,7 @@ }, "sections": { "contentFiltering": "Filtrado de contenido", + "downloads": "Descargas", "videoSettings": "Configuración de video", "layoutSettings": "Configuración de diseño", "misc": "Varios", diff --git a/locales/fr.json b/locales/fr.json index 15ac157c..fd156e34 100644 --- a/locales/fr.json +++ b/locales/fr.json @@ -263,6 +263,19 @@ "red": "civitai.red (sans restriction)" } }, + "downloadBackend": { + "label": "Moteur de téléchargement", + "help": "Choisissez comment les fichiers de modèles sont téléchargés. Python utilise le téléchargeur intégré. aria2 utilise le processus externe expérimental de téléchargement.", + "options": { + "python": "Python (intégré)", + "aria2": "aria2 (expérimental)" + } + }, + "aria2cPath": { + "label": "Chemin vers aria2c", + "help": "Chemin facultatif vers l’exécutable aria2c. Laissez vide pour utiliser aria2c depuis le PATH système.", + "placeholder": "Laisser vide pour utiliser aria2c depuis le PATH" + }, "civitaiHostBanner": { "title": "Préférence d’hôte Civitai disponible", "content": "Civitai utilise désormais civitai.com pour le contenu SFW et civitai.red pour le contenu sans restriction. Vous pouvez modifier dans les paramètres le site ouvert par défaut.", @@ -278,6 +291,7 @@ }, "sections": { "contentFiltering": "Filtrage du contenu", + "downloads": "Téléchargements", "videoSettings": "Paramètres vidéo", "layoutSettings": "Paramètres d'affichage", "misc": "Divers", diff --git a/locales/he.json b/locales/he.json index 8dfdab61..9cd794d5 100644 --- a/locales/he.json +++ b/locales/he.json @@ -263,6 +263,19 @@ "red": "civitai.red (ללא הגבלות)" } }, + "downloadBackend": { + "label": "מנגנון הורדה", + "help": "בחר כיצד יורדים קבצי המודל. Python משתמש במוריד המובנה. aria2 משתמש בתהליך הורדה חיצוני ניסיוני.", + "options": { + "python": "Python (מובנה)", + "aria2": "aria2 (ניסיוני)" + } + }, + "aria2cPath": { + "label": "נתיב aria2c", + "help": "נתיב אופציונלי לקובץ ההפעלה aria2c. השאר ריק כדי להשתמש ב-aria2c מתוך ה-PATH של המערכת.", + "placeholder": "השאר ריק כדי להשתמש ב-aria2c מתוך ה-PATH" + }, "civitaiHostBanner": { "title": "העדפת מארח Civitai זמינה", "content": "Civitai משתמש כעת ב-civitai.com עבור תוכן SFW וב-civitai.red עבור תוכן ללא הגבלות. ניתן לשנות בהגדרות איזה אתר ייפתח כברירת מחדל.", @@ -278,6 +291,7 @@ }, "sections": { "contentFiltering": "סינון תוכן", + "downloads": "הורדות", "videoSettings": "הגדרות וידאו", "layoutSettings": "הגדרות פריסה", "misc": "שונות", diff --git a/locales/ja.json b/locales/ja.json index e1d4448d..f5d286c2 100644 --- a/locales/ja.json +++ b/locales/ja.json @@ -263,6 +263,19 @@ "red": "civitai.red(制限なし)" } }, + "downloadBackend": { + "label": "ダウンロードバックエンド", + "help": "モデルファイルのダウンロード方法を選択します。Python は内蔵ダウンローダーを使用し、aria2 は実験的な外部ダウンローダープロセスを使用します。", + "options": { + "python": "Python(内蔵)", + "aria2": "aria2(実験的)" + } + }, + "aria2cPath": { + "label": "aria2c のパス", + "help": "aria2c 実行ファイルへの任意のパスです。空欄のままにすると、システム PATH 上の aria2c を使用します。", + "placeholder": "空欄のままにすると PATH 上の aria2c を使用します" + }, "civitaiHostBanner": { "title": "Civitai ホスト設定を利用できます", "content": "Civitai は現在、SFW コンテンツには civitai.com、制限なしコンテンツには civitai.red を使用しています。設定で既定で開くサイトを変更できます。", @@ -278,6 +291,7 @@ }, "sections": { "contentFiltering": "コンテンツフィルタリング", + "downloads": "ダウンロード", "videoSettings": "動画設定", "layoutSettings": "レイアウト設定", "misc": "その他", diff --git a/locales/ko.json b/locales/ko.json index 5a3f2ec0..2e1f40ca 100644 --- a/locales/ko.json +++ b/locales/ko.json @@ -263,6 +263,19 @@ "red": "civitai.red(무제한)" } }, + "downloadBackend": { + "label": "다운로드 백엔드", + "help": "모델 파일을 다운로드하는 방식을 선택합니다. Python은 내장 다운로더를 사용하고, aria2는 실험적인 외부 다운로더 프로세스를 사용합니다.", + "options": { + "python": "Python(내장)", + "aria2": "aria2(실험적)" + } + }, + "aria2cPath": { + "label": "aria2c 경로", + "help": "aria2c 실행 파일의 선택적 경로입니다. 비워 두면 시스템 PATH의 aria2c를 사용합니다.", + "placeholder": "비워 두면 PATH의 aria2c를 사용합니다" + }, "civitaiHostBanner": { "title": "Civitai 호스트 기본 설정 사용 가능", "content": "이제 Civitai는 SFW 콘텐츠에 civitai.com을, 무제한 콘텐츠에 civitai.red를 사용합니다. 설정에서 기본으로 열 사이트를 변경할 수 있습니다.", @@ -278,6 +291,7 @@ }, "sections": { "contentFiltering": "콘텐츠 필터링", + "downloads": "다운로드", "videoSettings": "비디오 설정", "layoutSettings": "레이아웃 설정", "misc": "기타", diff --git a/locales/ru.json b/locales/ru.json index cf0612fe..58eccf53 100644 --- a/locales/ru.json +++ b/locales/ru.json @@ -263,6 +263,19 @@ "red": "civitai.red (без ограничений)" } }, + "downloadBackend": { + "label": "Бэкенд загрузки", + "help": "Выберите способ загрузки файлов моделей. Python использует встроенный загрузчик. aria2 использует экспериментальный внешний процесс загрузки.", + "options": { + "python": "Python (встроенный)", + "aria2": "aria2 (экспериментальный)" + } + }, + "aria2cPath": { + "label": "Путь к aria2c", + "help": "Необязательный путь к исполняемому файлу aria2c. Оставьте пустым, чтобы использовать aria2c из системного PATH.", + "placeholder": "Оставьте пустым, чтобы использовать aria2c из PATH" + }, "civitaiHostBanner": { "title": "Доступна настройка хоста Civitai", "content": "Теперь Civitai использует civitai.com для контента SFW и civitai.red для контента без ограничений. В настройках можно изменить, какой сайт открывать по умолчанию.", @@ -278,6 +291,7 @@ }, "sections": { "contentFiltering": "Фильтрация контента", + "downloads": "Загрузки", "videoSettings": "Настройки видео", "layoutSettings": "Настройки макета", "misc": "Разное", diff --git a/locales/zh-CN.json b/locales/zh-CN.json index 898f550a..17644c63 100644 --- a/locales/zh-CN.json +++ b/locales/zh-CN.json @@ -263,6 +263,19 @@ "red": "civitai.red(无限制)" } }, + "downloadBackend": { + "label": "下载后端", + "help": "选择模型文件的下载方式。Python 使用内置下载器。aria2 使用实验性的外部下载进程。", + "options": { + "python": "Python(内置)", + "aria2": "aria2(实验性)" + } + }, + "aria2cPath": { + "label": "aria2c 路径", + "help": "可选的 aria2c 可执行文件路径。留空则使用系统 PATH 中的 aria2c。", + "placeholder": "留空则使用 PATH 中的 aria2c" + }, "civitaiHostBanner": { "title": "已提供 Civitai 站点偏好设置", "content": "Civitai 现在使用 civitai.com 提供 SFW 内容,使用 civitai.red 提供无限制内容。你可以在设置中更改默认打开的站点。", @@ -278,6 +291,7 @@ }, "sections": { "contentFiltering": "内容过滤", + "downloads": "下载", "videoSettings": "视频设置", "layoutSettings": "布局设置", "misc": "其他", diff --git a/locales/zh-TW.json b/locales/zh-TW.json index ef4077ba..c2139ef0 100644 --- a/locales/zh-TW.json +++ b/locales/zh-TW.json @@ -263,6 +263,19 @@ "red": "civitai.red(無限制)" } }, + "downloadBackend": { + "label": "下載後端", + "help": "選擇模型檔案的下載方式。Python 使用內建下載器。aria2 使用實驗性的外部下載程序。", + "options": { + "python": "Python(內建)", + "aria2": "aria2(實驗性)" + } + }, + "aria2cPath": { + "label": "aria2c 路徑", + "help": "可選的 aria2c 可執行檔路徑。留空則使用系統 PATH 中的 aria2c。", + "placeholder": "留空則使用 PATH 中的 aria2c" + }, "civitaiHostBanner": { "title": "已提供 Civitai 站點偏好設定", "content": "Civitai 現在使用 civitai.com 提供 SFW 內容,使用 civitai.red 提供無限制內容。你可以在設定中變更預設開啟的站點。", @@ -278,6 +291,7 @@ }, "sections": { "contentFiltering": "內容過濾", + "downloads": "下載", "videoSettings": "影片設定", "layoutSettings": "版面設定", "misc": "其他", diff --git a/py/services/aria2_downloader.py b/py/services/aria2_downloader.py new file mode 100644 index 00000000..d1b1b018 --- /dev/null +++ b/py/services/aria2_downloader.py @@ -0,0 +1,497 @@ +from __future__ import annotations + +import asyncio +import json +import logging +import os +import secrets +import shutil +import socket +from dataclasses import dataclass +from datetime import datetime +from pathlib import Path +from typing import Any, Dict, Optional, Tuple + +import aiohttp + +from .downloader import DownloadProgress, get_downloader +from .settings_manager import get_settings_manager + +logger = logging.getLogger(__name__) + +CIVITAI_DOWNLOAD_URL_PREFIXES = ( + "https://civitai.com/api/download/", + "https://civitai.red/api/download/", +) + + +class Aria2Error(RuntimeError): + """Raised when aria2 integration fails.""" + + +@dataclass +class Aria2Transfer: + """Track an aria2 download registered by the Python coordinator.""" + + gid: str + save_path: str + + +class Aria2Downloader: + """Manage an aria2 RPC daemon for experimental model downloads.""" + + _instance = None + _lock = asyncio.Lock() + + @classmethod + async def get_instance(cls) -> "Aria2Downloader": + async with cls._lock: + if cls._instance is None: + cls._instance = cls() + return cls._instance + + def __init__(self) -> None: + if hasattr(self, "_initialized"): + return + + self._initialized = True + self._process: Optional[asyncio.subprocess.Process] = None + self._rpc_port: Optional[int] = None + self._rpc_secret = "" + self._rpc_url = "" + self._rpc_session: Optional[aiohttp.ClientSession] = None + self._rpc_session_lock = asyncio.Lock() + self._process_lock = asyncio.Lock() + self._transfers: Dict[str, Aria2Transfer] = {} + self._poll_interval = 0.5 + + @property + def is_running(self) -> bool: + return self._process is not None and self._process.returncode is None + + async def download_file( + self, + url: str, + save_path: str, + *, + download_id: str, + progress_callback=None, + headers: Optional[Dict[str, str]] = None, + ) -> Tuple[bool, str]: + """Download a file using aria2 RPC and wait for completion.""" + + await self._ensure_process() + save_path = os.path.abspath(save_path) + save_dir = os.path.dirname(save_path) + out_name = os.path.basename(save_path) + + Path(save_dir).mkdir(parents=True, exist_ok=True) + + resolved_url = url + request_headers = headers + if headers and url.startswith(CIVITAI_DOWNLOAD_URL_PREFIXES): + resolved_url = await self._resolve_authenticated_redirect_url(url, headers) + if resolved_url != url: + request_headers = None + logger.debug( + "Resolved Civitai download %s to signed URL for aria2", + download_id, + ) + + options: Dict[str, str] = { + "dir": save_dir, + "out": out_name, + "continue": "true", + "max-connection-per-server": "4", + "split": "4", + "min-split-size": "1M", + "allow-overwrite": "true", + "auto-file-renaming": "false", + "file-allocation": "none", + } + if request_headers: + options["header"] = [ + f"{key}: {value}" for key, value in request_headers.items() + ] + + logger.debug( + "Submitting aria2 download %s -> %s (auth=%s, civitai_signed=%s)", + download_id, + save_path, + bool(request_headers), + resolved_url != url, + ) + + try: + gid = await self._rpc_call("aria2.addUri", [[resolved_url], options]) + except Exception as exc: + raise Aria2Error(f"Failed to schedule aria2 download: {exc}") from exc + + logger.debug("aria2 accepted download %s with gid %s", download_id, gid) + + self._transfers[download_id] = Aria2Transfer(gid=gid, save_path=save_path) + + try: + while True: + status = await self.get_status(download_id) + if status is None: + return False, "aria2 download not found" + + snapshot = self._build_progress_snapshot(status) + if progress_callback is not None: + await self._dispatch_progress(progress_callback, snapshot) + + state = status.get("status", "") + if state == "complete": + completed_path = self._resolve_completed_path(status, save_path) + return True, completed_path + if state == "error": + return False, status.get("errorMessage") or "aria2 download failed" + if state == "removed": + return False, "Download was cancelled" + + await asyncio.sleep(self._poll_interval) + finally: + self._transfers.pop(download_id, None) + + async def get_status(self, download_id: str) -> Optional[Dict[str, Any]]: + """Return the raw aria2 status payload for a known download.""" + + transfer = self._transfers.get(download_id) + if transfer is None: + return None + + keys = [ + "gid", + "status", + "totalLength", + "completedLength", + "downloadSpeed", + "errorMessage", + "files", + ] + try: + status = await self._rpc_call("aria2.tellStatus", [transfer.gid, keys]) + except Exception as exc: + raise Aria2Error(f"Failed to query aria2 download status: {exc}") from exc + + if isinstance(status, dict): + return status + return None + + async def has_transfer(self, download_id: str) -> bool: + return download_id in self._transfers + + async def pause_download(self, download_id: str) -> Dict[str, Any]: + transfer = self._transfers.get(download_id) + if transfer is None: + return {"success": False, "error": "Download task not found"} + + try: + await self._rpc_call("aria2.forcePause", [transfer.gid]) + except Exception as exc: + return {"success": False, "error": str(exc)} + + return {"success": True, "message": "Download paused successfully"} + + async def resume_download(self, download_id: str) -> Dict[str, Any]: + transfer = self._transfers.get(download_id) + if transfer is None: + return {"success": False, "error": "Download task not found"} + + try: + await self._rpc_call("aria2.unpause", [transfer.gid]) + except Exception as exc: + return {"success": False, "error": str(exc)} + + return {"success": True, "message": "Download resumed successfully"} + + async def cancel_download(self, download_id: str) -> Dict[str, Any]: + transfer = self._transfers.get(download_id) + if transfer is None: + return {"success": False, "error": "Download task not found"} + + try: + await self._rpc_call("aria2.forceRemove", [transfer.gid]) + except Exception as exc: + return {"success": False, "error": str(exc)} + + return {"success": True, "message": "Download cancelled successfully"} + + async def close(self) -> None: + """Shut down the RPC process and session.""" + + if self._rpc_session is not None: + await self._rpc_session.close() + self._rpc_session = None + + process = self._process + self._process = None + self._transfers.clear() + + if process is None: + return + + if process.returncode is None: + process.terminate() + try: + await asyncio.wait_for(process.wait(), timeout=5.0) + except asyncio.TimeoutError: + process.kill() + await process.wait() + + async def _dispatch_progress(self, callback, snapshot: DownloadProgress) -> None: + try: + result = callback(snapshot, snapshot) + except TypeError: + result = callback(snapshot.percent_complete) + + if asyncio.iscoroutine(result): + await result + elif hasattr(result, "__await__"): + await result + + def _build_progress_snapshot(self, status: Dict[str, Any]) -> DownloadProgress: + completed = self._parse_int(status.get("completedLength")) + total = self._parse_int(status.get("totalLength")) + speed = float(self._parse_int(status.get("downloadSpeed"))) + percent = 0.0 + if total > 0: + percent = (completed / total) * 100.0 + + return DownloadProgress( + percent_complete=max(0.0, min(percent, 100.0)), + bytes_downloaded=completed, + total_bytes=total or None, + bytes_per_second=speed, + timestamp=datetime.now().timestamp(), + ) + + def _resolve_completed_path(self, status: Dict[str, Any], default_path: str) -> str: + files = status.get("files") + if isinstance(files, list) and files: + first = files[0] + if isinstance(first, dict): + candidate = first.get("path") + if isinstance(candidate, str) and candidate: + return candidate + return default_path + + @staticmethod + def _parse_int(value: Any) -> int: + try: + return int(value) + except (TypeError, ValueError): + return 0 + + async def _resolve_authenticated_redirect_url( + self, + url: str, + headers: Dict[str, str], + ) -> str: + downloader = await get_downloader() + session = await downloader.session + request_headers = dict(downloader.default_headers) + request_headers.update(headers) + request_headers["Accept-Encoding"] = "identity" + + try: + async with session.get( + url, + headers=request_headers, + allow_redirects=False, + proxy=downloader.proxy_url, + ) as response: + if response.status in {301, 302, 303, 307, 308}: + location = response.headers.get("Location") + if location: + return location + raise Aria2Error( + "Authenticated Civitai redirect did not include a Location header" + ) + + if response.status == 200: + return url + + body = await response.text() + raise Aria2Error( + f"Failed to resolve authenticated Civitai redirect: status={response.status} body={body[:300]}" + ) + except aiohttp.ClientError as exc: + raise Aria2Error( + f"Failed to resolve authenticated Civitai redirect: {exc}" + ) from exc + + async def _ensure_process(self) -> None: + async with self._process_lock: + if self.is_running and await self._ping(): + return + + await self.close() + + executable = self._resolve_executable() + self._rpc_port = self._find_free_port() + self._rpc_secret = secrets.token_hex(16) + self._rpc_url = f"http://127.0.0.1:{self._rpc_port}/jsonrpc" + + command = [ + executable, + "--enable-rpc=true", + "--rpc-listen-all=false", + f"--rpc-listen-port={self._rpc_port}", + f"--rpc-secret={self._rpc_secret}", + "--check-certificate=true", + "--allow-overwrite=true", + "--auto-file-renaming=false", + "--file-allocation=none", + "--max-concurrent-downloads=5", + "--continue=true", + "--daemon=false", + "--quiet=true", + f"--stop-with-process={os.getpid()}", + ] + + logger.info("Starting aria2 RPC daemon from %s", executable) + self._process = await asyncio.create_subprocess_exec( + *command, + stdout=asyncio.subprocess.DEVNULL, + stderr=asyncio.subprocess.PIPE, + ) + + await self._wait_until_ready() + + def _resolve_executable(self) -> str: + settings = get_settings_manager() + configured_path = (settings.get("aria2c_path") or "").strip() + candidate = configured_path or "aria2c" + + resolved = shutil.which(candidate) + if resolved: + return resolved + + if configured_path and os.path.isfile(configured_path) and os.access( + configured_path, os.X_OK + ): + return configured_path + + raise Aria2Error( + "aria2c executable was not found. Install aria2 or configure aria2c_path." + ) + + async def _wait_until_ready(self) -> None: + assert self._process is not None + + start_time = asyncio.get_running_loop().time() + last_error = "" + while asyncio.get_running_loop().time() - start_time < 10.0: + if self._process.returncode is not None: + stderr_output = "" + if self._process.stderr is not None: + try: + stderr_output = ( + await asyncio.wait_for(self._process.stderr.read(), timeout=0.2) + ).decode("utf-8", errors="replace") + except Exception: + stderr_output = "" + raise Aria2Error( + f"aria2 RPC process exited early with code {self._process.returncode}: {stderr_output.strip()}" + ) + + try: + if await self._ping(): + return + except Exception as exc: # pragma: no cover - startup race + last_error = str(exc) + + await asyncio.sleep(0.2) + + raise Aria2Error( + f"Timed out waiting for aria2 RPC to become ready{': ' + last_error if last_error else ''}" + ) + + async def _ping(self) -> bool: + try: + result = await self._rpc_call("aria2.getVersion", []) + except Exception: + return False + + return isinstance(result, dict) + + async def _rpc_call(self, method: str, params: list[Any]) -> Any: + if not self._rpc_url: + raise Aria2Error("aria2 RPC endpoint is not initialized") + + session = await self._get_rpc_session() + payload = { + "jsonrpc": "2.0", + "id": secrets.token_hex(8), + "method": method, + "params": [f"token:{self._rpc_secret}", *params], + } + + async with session.post(self._rpc_url, json=payload) as response: + text = await response.text() + + try: + body = json.loads(text) + except json.JSONDecodeError: + body = None + + if body is None: + if response.status != 200: + raise Aria2Error( + f"aria2 RPC returned status {response.status} with non-JSON body: {text}" + ) + raise Aria2Error(f"Invalid aria2 RPC response: {text}") + + if "error" in body: + error = body["error"] or {} + code = error.get("code") if isinstance(error, dict) else None + message = error.get("message") if isinstance(error, dict) else str(error) + logger.error( + "aria2 RPC %s failed with HTTP %s, code=%s, message=%s", + method, + response.status, + code, + message, + ) + status_message = ( + f"aria2 RPC {method} failed with status {response.status}: {message}" + if response.status != 200 + else message + ) + raise Aria2Error(status_message or "Unknown aria2 RPC error") + + if response.status != 200: + logger.error( + "aria2 RPC %s returned unexpected HTTP status %s without error payload: %s", + method, + response.status, + body, + ) + raise Aria2Error( + f"aria2 RPC {method} returned unexpected status {response.status}" + ) + + return body.get("result") + + async def _get_rpc_session(self) -> aiohttp.ClientSession: + if self._rpc_session is None or self._rpc_session.closed: + async with self._rpc_session_lock: + if self._rpc_session is None or self._rpc_session.closed: + timeout = aiohttp.ClientTimeout(total=30) + self._rpc_session = aiohttp.ClientSession(timeout=timeout) + return self._rpc_session + + @staticmethod + def _find_free_port() -> int: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: + sock.bind(("127.0.0.1", 0)) + sock.listen(1) + return int(sock.getsockname()[1]) + + +async def get_aria2_downloader() -> Aria2Downloader: + """Get the singleton aria2 downloader.""" + + return await Aria2Downloader.get_instance() diff --git a/py/services/download_manager.py b/py/services/download_manager.py index 2fe07cff..ce727abd 100644 --- a/py/services/download_manager.py +++ b/py/services/download_manager.py @@ -5,6 +5,7 @@ import asyncio import inspect import shutil import zipfile +from concurrent.futures import ThreadPoolExecutor from collections import OrderedDict import uuid from typing import Dict, List, Optional, Set, Tuple @@ -25,6 +26,7 @@ from .service_registry import ServiceRegistry from .settings_manager import get_settings_manager from .metadata_service import get_default_metadata_provider, get_metadata_provider from .downloader import get_downloader, DownloadProgress, DownloadStreamControl +from .aria2_downloader import Aria2Error, get_aria2_downloader # Download to temporary file first import tempfile @@ -60,6 +62,59 @@ class DownloadManager: self._download_semaphore = asyncio.Semaphore(5) # Limit concurrent downloads self._download_tasks = {} # download_id -> asyncio.Task self._pause_events: Dict[str, DownloadStreamControl] = {} + self._archive_executor = ThreadPoolExecutor( + max_workers=2, thread_name_prefix="lm-archive" + ) + + @staticmethod + def _get_model_download_backend() -> str: + backend = (get_settings_manager().get("download_backend") or "python").strip() + return backend.lower() or "python" + + async def _download_model_file( + self, + download_url: str, + save_path: str, + *, + backend: str, + progress_callback, + use_auth: bool, + download_id: Optional[str], + pause_control: Optional[DownloadStreamControl], + ) -> Tuple[bool, str]: + if backend == "aria2": + if not download_id: + return False, "aria2 downloads require a tracked download_id" + + headers: Dict[str, str] = {} + if use_auth: + api_key = (get_settings_manager().get("civitai_api_key") or "").strip() + if api_key: + headers["Authorization"] = f"Bearer {api_key}" + + try: + aria2_downloader = await get_aria2_downloader() + return await aria2_downloader.download_file( + download_url, + save_path, + download_id=download_id, + progress_callback=progress_callback, + headers=headers or None, + ) + except Aria2Error as exc: + logger.error("aria2 download failed for %s: %s", download_url, exc) + return False, str(exc) + + download_kwargs = { + "progress_callback": progress_callback, + "use_auth": use_auth, + } + + if pause_control is not None: + download_kwargs["pause_event"] = pause_control + + downloader = await get_downloader() + return await downloader.download_file(download_url, save_path, **download_kwargs) async def _get_lora_scanner(self): """Get the lora scanner from registry""" @@ -126,6 +181,7 @@ class DownloadManager: "model_version_id": model_version_id, "progress": 0, "status": "queued", + "transfer_backend": self._get_model_download_backend(), "bytes_downloaded": 0, "total_bytes": None, "bytes_per_second": 0.0, @@ -240,6 +296,9 @@ class DownloadManager: tracking_callback, use_default_paths, task_id, + self._active_downloads.get(task_id, {}).get( + "transfer_backend", "python" + ), source, file_params, ) @@ -294,6 +353,7 @@ class DownloadManager: progress_callback, use_default_paths, download_id=None, + transfer_backend="python", source=None, file_params=None, ): @@ -696,16 +756,27 @@ class DownloadManager: logger.info(f"Creating EmbeddingMetadata for {file_name}") # 6. Start download process - result = await self._execute_download( - download_urls=download_urls, - save_dir=save_dir, - metadata=metadata, - version_info=version_info, - relative_path=relative_path, - progress_callback=progress_callback, - model_type=model_type, - download_id=download_id, - ) + execute_kwargs = { + "download_urls": download_urls, + "save_dir": save_dir, + "metadata": metadata, + "version_info": version_info, + "relative_path": relative_path, + "progress_callback": progress_callback, + "model_type": model_type, + "download_id": download_id, + } + execute_signature = inspect.signature(self._execute_download) + if ( + "transfer_backend" in execute_signature.parameters + or any( + parameter.kind == inspect.Parameter.VAR_KEYWORD + for parameter in execute_signature.parameters.values() + ) + ): + execute_kwargs["transfer_backend"] = transfer_backend + + result = await self._execute_download(**execute_kwargs) if result.get("success", False): resolved_model_id = ( @@ -965,6 +1036,7 @@ class DownloadManager: progress_callback=None, model_type: str = "lora", download_id: str = None, + transfer_backend: Optional[str] = None, ) -> Dict: """Execute the actual download process including preview images and model files""" metadata_entries: List = [] @@ -974,6 +1046,7 @@ class DownloadManager: preview_targets: List[str] = [] preview_path: str | None = None preview_nsfw_level = 0 + transfer_backend = (transfer_backend or self._get_model_download_backend()).lower() try: # Extract original filename details original_filename = os.path.basename(metadata.file_path) @@ -1136,32 +1209,37 @@ class DownloadManager: if progress_callback: await progress_callback(3) # 3% progress after preview download - # Download model file with progress tracking using downloader - downloader = await get_downloader() - if pause_control is not None: - pause_control.update_stall_timeout(downloader.stall_timeout) + # Download model file with progress tracking using the configured backend + downloader = None + if transfer_backend == "python": + downloader = await get_downloader() + if pause_control is not None: + pause_control.update_stall_timeout(downloader.stall_timeout) + if pause_control is not None and pause_control.is_paused(): + if download_id and download_id in self._active_downloads: + self._active_downloads[download_id]["status"] = "paused" + self._active_downloads[download_id]["bytes_per_second"] = 0.0 + await pause_control.wait() + if download_id and download_id in self._active_downloads: + self._active_downloads[download_id]["status"] = "downloading" last_error = None for download_url in download_urls: download_url = normalize_civitai_download_url(download_url) use_auth = download_url.startswith(CIVITAI_DOWNLOAD_URL_PREFIXES) - download_kwargs = { - "progress_callback": lambda progress, snapshot=None: ( + success, result = await self._download_model_file( + download_url, + save_path, + backend=transfer_backend, + progress_callback=lambda progress, snapshot=None: ( self._handle_download_progress( progress, progress_callback, snapshot, ) ), - "use_auth": use_auth, # Only use authentication for Civitai downloads - } - - if pause_control is not None: - download_kwargs["pause_event"] = pause_control - - success, result = await downloader.download_file( - download_url, - save_path, # Use full path instead of separate dir and filename - **download_kwargs, + use_auth=use_auth, + download_id=download_id, + pause_control=pause_control, ) if success: @@ -1401,7 +1479,8 @@ class DownloadManager: extracted_files.append(dest_path) return extracted_files - return await asyncio.to_thread(_extract_sync) + loop = asyncio.get_running_loop() + return await loop.run_in_executor(self._archive_executor, _extract_sync) async def _build_metadata_entries( self, base_metadata, file_paths: List[str] @@ -1511,8 +1590,28 @@ class DownloadManager: return {"success": False, "error": "Download task not found"} try: - # Get the task and cancel it task = self._download_tasks[download_id] + backend = ( + self._active_downloads.get(download_id, {}).get("transfer_backend") + or "python" + ) + + if backend == "aria2": + try: + aria2_downloader = await get_aria2_downloader() + cancel_result = await aria2_downloader.cancel_download(download_id) + if ( + not cancel_result.get("success") + and cancel_result.get("error") != "Download task not found" + ): + return cancel_result + except Exception as exc: + logger.warning( + "Failed to cancel aria2 transfer for %s, continuing with local task cancellation: %s", + download_id, + exc, + ) + task.cancel() pause_control = self._pause_events.get(download_id) @@ -1613,6 +1712,28 @@ class DownloadManager: pause_control.pause() + backend = ( + self._active_downloads.get(download_id, {}).get("transfer_backend") + or "python" + ) + if backend == "aria2": + try: + aria2_downloader = await get_aria2_downloader() + if await aria2_downloader.has_transfer(download_id): + result = await aria2_downloader.pause_download(download_id) + if not result.get("success"): + pause_control.resume() + return result + except Exception as exc: + pause_control.resume() + return {"success": False, "error": str(exc)} + + download_info = self._active_downloads.get(download_id) + if download_info is not None: + download_info["status"] = "paused" + download_info["bytes_per_second"] = 0.0 + return {"success": True, "message": "Download paused successfully"} + download_info = self._active_downloads.get(download_id) if download_info is not None: download_info["status"] = "paused" @@ -1631,6 +1752,28 @@ class DownloadManager: return {"success": False, "error": "Download is not paused"} download_info = self._active_downloads.get(download_id) + backend = ( + self._active_downloads.get(download_id, {}).get("transfer_backend") + or "python" + ) + if backend == "aria2": + try: + aria2_downloader = await get_aria2_downloader() + if await aria2_downloader.has_transfer(download_id): + result = await aria2_downloader.resume_download(download_id) + if not result.get("success"): + return result + except Exception as exc: + return {"success": False, "error": str(exc)} + + pause_control.resume() + + if download_info is not None: + if download_info.get("status") == "paused": + download_info["status"] = "downloading" + download_info.setdefault("bytes_per_second", 0.0) + return {"success": True, "message": "Download resumed successfully"} + force_reconnect = False if pause_control is not None: elapsed = pause_control.time_since_last_progress() diff --git a/py/services/settings_manager.py b/py/services/settings_manager.py index 8b6cfebc..08279982 100644 --- a/py/services/settings_manager.py +++ b/py/services/settings_manager.py @@ -55,6 +55,8 @@ DEFAULT_KEYS_CLEANUP_THRESHOLD = 10 DEFAULT_SETTINGS: Dict[str, Any] = { "civitai_api_key": "", "civitai_host": "civitai.com", + "download_backend": "python", + "aria2c_path": "", "use_portable_settings": False, "hash_chunk_size_mb": DEFAULT_HASH_CHUNK_SIZE_MB, "language": "en", diff --git a/static/js/managers/SettingsManager.js b/static/js/managers/SettingsManager.js index 1dcc5d2f..ab2906b2 100644 --- a/static/js/managers/SettingsManager.js +++ b/static/js/managers/SettingsManager.js @@ -807,6 +807,16 @@ export class SettingsManager { civitaiHostSelect.value = state.global.settings.civitai_host || 'civitai.com'; } + const downloadBackendSelect = document.getElementById('downloadBackend'); + if (downloadBackendSelect) { + downloadBackendSelect.value = state.global.settings.download_backend || 'python'; + } + + const aria2cPathInput = document.getElementById('aria2cPath'); + if (aria2cPathInput) { + aria2cPathInput.value = state.global.settings.aria2c_path || ''; + } + const recipesPathInput = document.getElementById('recipesPath'); if (recipesPathInput) { recipesPathInput.value = state.global.settings.recipes_path || ''; @@ -950,9 +960,36 @@ export class SettingsManager { languageSelect.value = currentLanguage; } + this.loadDownloadBackendSettings(); this.loadProxySettings(); } + loadDownloadBackendSettings() { + const downloadBackendSelect = document.getElementById('downloadBackend'); + const aria2PathSetting = document.getElementById('aria2PathSetting'); + const updateVisibility = () => { + if (!aria2PathSetting || !downloadBackendSelect) { + return; + } + aria2PathSetting.style.display = downloadBackendSelect.value === 'aria2' ? 'block' : 'none'; + }; + + if (downloadBackendSelect) { + downloadBackendSelect.value = state.global.settings.download_backend || 'python'; + downloadBackendSelect.onchange = () => { + updateVisibility(); + this.saveSelectSetting('downloadBackend', 'download_backend'); + }; + } + + const aria2cPathInput = document.getElementById('aria2cPath'); + if (aria2cPathInput) { + aria2cPathInput.value = state.global.settings.aria2c_path || ''; + } + + updateVisibility(); + } + setupPriorityTagInputs() { ['lora', 'checkpoint', 'embedding'].forEach((modelType) => { const textarea = document.getElementById(`${modelType}PriorityTagsInput`); diff --git a/static/js/state/index.js b/static/js/state/index.js index 4a55ab26..5206c3d4 100644 --- a/static/js/state/index.js +++ b/static/js/state/index.js @@ -6,6 +6,8 @@ import { DEFAULT_PATH_TEMPLATES, DEFAULT_PRIORITY_TAG_CONFIG } from '../utils/co const DEFAULT_SETTINGS_BASE = Object.freeze({ civitai_api_key: '', civitai_host: 'civitai.com', + download_backend: 'python', + aria2c_path: '', use_portable_settings: false, language: 'en', show_only_sfw: false, diff --git a/templates/components/modals/settings_modal.html b/templates/components/modals/settings_modal.html index c85a7cfc..607241e0 100644 --- a/templates/components/modals/settings_modal.html +++ b/templates/components/modals/settings_modal.html @@ -129,6 +129,43 @@ +
+
+

{{ t('settings.sections.downloads') }}

+
+
+
+
+ + +
+
+ +
+
+
+ +
+
diff --git a/tests/frontend/managers/settingsManager.library.test.js b/tests/frontend/managers/settingsManager.library.test.js index 18cd2c24..2b8eb9a0 100644 --- a/tests/frontend/managers/settingsManager.library.test.js +++ b/tests/frontend/managers/settingsManager.library.test.js @@ -305,4 +305,39 @@ describe('SettingsManager library controls', () => { 'success', ); }); + + it('loads download backend settings and toggles the aria2 path field', () => { + const manager = createManager(); + document.body.innerHTML = ` + + + + `; + + state.global.settings = { + download_backend: 'aria2', + aria2c_path: '/usr/bin/aria2c', + }; + + const saveSpy = vi.spyOn(manager, 'saveSelectSetting').mockResolvedValue(); + + manager.loadDownloadBackendSettings(); + + const backendSelect = document.getElementById('downloadBackend'); + const aria2PathSetting = document.getElementById('aria2PathSetting'); + const aria2cPath = document.getElementById('aria2cPath'); + + expect(backendSelect.value).toBe('aria2'); + expect(aria2cPath.value).toBe('/usr/bin/aria2c'); + expect(aria2PathSetting.style.display).toBe('block'); + + backendSelect.value = 'python'; + backendSelect.onchange(); + + expect(aria2PathSetting.style.display).toBe('none'); + expect(saveSpy).toHaveBeenCalledWith('downloadBackend', 'download_backend'); + }); }); diff --git a/tests/services/test_aria2_downloader.py b/tests/services/test_aria2_downloader.py new file mode 100644 index 00000000..606e1d56 --- /dev/null +++ b/tests/services/test_aria2_downloader.py @@ -0,0 +1,269 @@ +from __future__ import annotations + +from pathlib import Path +from unittest.mock import AsyncMock + +import pytest + +from py.services.aria2_downloader import Aria2Downloader, Aria2Error + + +@pytest.mark.asyncio +async def test_download_file_polls_until_complete(tmp_path, monkeypatch): + downloader = Aria2Downloader() + downloader._rpc_url = "http://127.0.0.1/jsonrpc" + downloader._rpc_secret = "secret" + + save_path = tmp_path / "downloads" / "model.safetensors" + progress_events = [] + rpc_calls = [] + statuses = iter( + [ + { + "gid": "gid-1", + "status": "active", + "completedLength": "5", + "totalLength": "10", + "downloadSpeed": "25", + }, + { + "gid": "gid-1", + "status": "complete", + "completedLength": "10", + "totalLength": "10", + "downloadSpeed": "0", + "files": [{"path": str(save_path)}], + }, + ] + ) + + async def fake_rpc_call(method, params): + rpc_calls.append((method, params)) + if method == "aria2.addUri": + return "gid-1" + if method == "aria2.tellStatus": + return next(statuses) + raise AssertionError(f"Unexpected RPC method: {method}") + + monkeypatch.setattr(downloader, "_ensure_process", AsyncMock()) + monkeypatch.setattr( + downloader, + "_resolve_authenticated_redirect_url", + AsyncMock( + return_value="https://signed.example.com/model.safetensors?token=abc" + ), + ) + monkeypatch.setattr(downloader, "_rpc_call", fake_rpc_call) + monkeypatch.setattr("py.services.aria2_downloader.asyncio.sleep", AsyncMock()) + + async def progress_callback(progress, snapshot=None): + progress_events.append(snapshot.percent_complete if snapshot else progress) + + success, result = await downloader.download_file( + "https://civitai.com/api/download/models/123", + str(save_path), + download_id="download-1", + progress_callback=progress_callback, + headers={"Authorization": "Bearer token"}, + ) + + assert success is True + assert result == str(save_path) + assert progress_events == [50.0, 100.0] + assert downloader._transfers == {} + assert rpc_calls[0][0] == "aria2.addUri" + assert rpc_calls[0][1][0] == [ + "https://signed.example.com/model.safetensors?token=abc" + ] + assert rpc_calls[0][1][1]["out"] == "model.safetensors" + assert "header" not in rpc_calls[0][1][1] + + +@pytest.mark.asyncio +async def test_download_file_keeps_auth_headers_when_civitai_does_not_redirect( + tmp_path, monkeypatch +): + downloader = Aria2Downloader() + downloader._rpc_url = "http://127.0.0.1/jsonrpc" + downloader._rpc_secret = "secret" + + save_path = tmp_path / "downloads" / "model.safetensors" + rpc_calls = [] + statuses = iter( + [ + { + "gid": "gid-1", + "status": "complete", + "completedLength": "10", + "totalLength": "10", + "downloadSpeed": "0", + "files": [{"path": str(save_path)}], + }, + ] + ) + + async def fake_rpc_call(method, params): + rpc_calls.append((method, params)) + if method == "aria2.addUri": + return "gid-1" + if method == "aria2.tellStatus": + return next(statuses) + raise AssertionError(f"Unexpected RPC method: {method}") + + monkeypatch.setattr(downloader, "_ensure_process", AsyncMock()) + monkeypatch.setattr( + downloader, + "_resolve_authenticated_redirect_url", + AsyncMock(return_value="https://civitai.com/api/download/models/123"), + ) + monkeypatch.setattr(downloader, "_rpc_call", fake_rpc_call) + monkeypatch.setattr("py.services.aria2_downloader.asyncio.sleep", AsyncMock()) + + success, result = await downloader.download_file( + "https://civitai.com/api/download/models/123", + str(save_path), + download_id="download-1", + headers={"Authorization": "Bearer token"}, + ) + + assert success is True + assert result == str(save_path) + assert rpc_calls[0][1][0] == ["https://civitai.com/api/download/models/123"] + assert rpc_calls[0][1][1]["header"] == ["Authorization: Bearer token"] + + +@pytest.mark.asyncio +async def test_pause_resume_cancel_forward_to_rpc(monkeypatch): + downloader = Aria2Downloader() + downloader._transfers["download-1"] = type( + "Transfer", (), {"gid": "gid-1", "save_path": "/tmp/model.safetensors"} + )() + + calls = [] + + async def fake_rpc_call(method, params): + calls.append((method, params)) + return "gid-1" + + monkeypatch.setattr(downloader, "_rpc_call", fake_rpc_call) + + pause_result = await downloader.pause_download("download-1") + resume_result = await downloader.resume_download("download-1") + cancel_result = await downloader.cancel_download("download-1") + + assert pause_result["success"] is True + assert resume_result["success"] is True + assert cancel_result["success"] is True + assert calls == [ + ("aria2.forcePause", ["gid-1"]), + ("aria2.unpause", ["gid-1"]), + ("aria2.forceRemove", ["gid-1"]), + ] + + +def test_build_progress_snapshot_normalizes_numeric_fields(): + downloader = Aria2Downloader() + + snapshot = downloader._build_progress_snapshot( + { + "completedLength": "75", + "totalLength": "100", + "downloadSpeed": "512", + } + ) + + assert snapshot.percent_complete == 75.0 + assert snapshot.bytes_downloaded == 75 + assert snapshot.total_bytes == 100 + assert snapshot.bytes_per_second == 512.0 + + +def test_resolve_executable_raises_when_binary_missing(monkeypatch): + downloader = Aria2Downloader() + settings = type("Settings", (), {"get": lambda self, key, default=None: ""})() + + monkeypatch.setattr("py.services.aria2_downloader.get_settings_manager", lambda: settings) + monkeypatch.setattr("py.services.aria2_downloader.shutil.which", lambda _: None) + + with pytest.raises(Aria2Error): + downloader._resolve_executable() + + +@pytest.mark.asyncio +async def test_rpc_call_surfaces_json_error_on_non_200(monkeypatch): + downloader = Aria2Downloader() + downloader._rpc_url = "http://127.0.0.1:6800/jsonrpc" + downloader._rpc_secret = "secret" + + class FakeResponse: + status = 400 + + async def text(self): + return ( + '{"jsonrpc":"2.0","id":"x","error":{"code":1,"message":"Unauthorized"}}' + ) + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc, tb): + return False + + class FakeSession: + def post(self, _url, json=None): + return FakeResponse() + + monkeypatch.setattr(downloader, "_get_rpc_session", AsyncMock(return_value=FakeSession())) + + with pytest.raises(Aria2Error) as exc_info: + await downloader._rpc_call("aria2.addUri", [["https://example.com/file"]]) + + assert "Unauthorized" in str(exc_info.value) + assert "aria2.addUri" in str(exc_info.value) + + +@pytest.mark.asyncio +async def test_resolve_authenticated_redirect_url_returns_location(monkeypatch): + downloader = Aria2Downloader() + + class FakeResponse: + status = 307 + headers = {"Location": "https://signed.example.com/file.safetensors"} + + async def text(self): + return "" + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc, tb): + return False + + class FakeSession: + def get(self, _url, headers=None, allow_redirects=False, proxy=None): + return FakeResponse() + + class FakeDownloader: + default_headers = {"User-Agent": "ComfyUI-LoRA-Manager/1.0"} + proxy_url = None + + @property + def session(self): + async def _session(): + return FakeSession() + + return _session() + + fake_downloader = FakeDownloader() + + monkeypatch.setattr( + "py.services.aria2_downloader.get_downloader", + AsyncMock(return_value=fake_downloader), + ) + + result = await downloader._resolve_authenticated_redirect_url( + "https://civitai.com/api/download/models/123", + {"Authorization": "Bearer token"}, + ) + + assert result == "https://signed.example.com/file.safetensors" diff --git a/tests/services/test_download_manager_basic.py b/tests/services/test_download_manager_basic.py index ac801212..10fae8d7 100644 --- a/tests/services/test_download_manager_basic.py +++ b/tests/services/test_download_manager_basic.py @@ -179,6 +179,7 @@ async def test_successful_download_uses_defaults( progress_callback, model_type, download_id, + transfer_backend=None, ): captured.update( { @@ -268,6 +269,7 @@ async def test_download_uses_active_mirrors( progress_callback, model_type, download_id, + transfer_backend=None, ): captured["download_urls"] = download_urls return {"success": True} @@ -288,6 +290,214 @@ async def test_download_uses_active_mirrors( assert captured["download_urls"] == ["https://mirror.example/file.safetensors"] +@pytest.mark.asyncio +async def test_pause_resume_cancel_delegate_to_aria2_backend(monkeypatch): + manager = DownloadManager() + + task = asyncio.create_task(asyncio.sleep(60)) + manager._download_tasks["download-1"] = task + manager._pause_events["download-1"] = download_manager.DownloadStreamControl() + manager._active_downloads["download-1"] = { + "transfer_backend": "aria2", + "status": "downloading", + } + + class DummyAria2Downloader: + def __init__(self): + self.calls = [] + + async def pause_download(self, download_id): + self.calls.append(("pause", download_id)) + return {"success": True, "message": "paused"} + + async def resume_download(self, download_id): + self.calls.append(("resume", download_id)) + return {"success": True, "message": "resumed"} + + async def cancel_download(self, download_id): + self.calls.append(("cancel", download_id)) + return {"success": True, "message": "cancelled"} + + async def has_transfer(self, download_id): + self.calls.append(("has_transfer", download_id)) + return True + + dummy_aria2 = DummyAria2Downloader() + monkeypatch.setattr( + download_manager, + "get_aria2_downloader", + AsyncMock(return_value=dummy_aria2), + ) + + pause_result = await manager.pause_download("download-1") + assert pause_result["success"] is True + assert manager._active_downloads["download-1"]["status"] == "paused" + + resume_result = await manager.resume_download("download-1") + assert resume_result["success"] is True + assert manager._active_downloads["download-1"]["status"] == "downloading" + + cancel_result = await manager.cancel_download("download-1") + assert cancel_result["success"] is True + assert task.cancelled() or task.done() + assert dummy_aria2.calls == [ + ("has_transfer", "download-1"), + ("pause", "download-1"), + ("has_transfer", "download-1"), + ("resume", "download-1"), + ("cancel", "download-1"), + ] + + +@pytest.mark.asyncio +async def test_cancel_allows_queued_aria2_task_without_transfer(monkeypatch): + manager = DownloadManager() + + started = asyncio.Event() + + async def blocked_task(): + started.set() + await asyncio.sleep(60) + + task = asyncio.create_task(blocked_task()) + await started.wait() + + manager._download_tasks["download-queued"] = task + manager._active_downloads["download-queued"] = { + "transfer_backend": "aria2", + "status": "queued", + } + + class DummyAria2Downloader: + async def cancel_download(self, download_id): + return {"success": False, "error": "Download task not found"} + + monkeypatch.setattr( + download_manager, + "get_aria2_downloader", + AsyncMock(return_value=DummyAria2Downloader()), + ) + + result = await manager.cancel_download("download-queued") + + assert result["success"] is True + assert task.cancelled() or task.done() + + +@pytest.mark.asyncio +async def test_pause_resume_queued_aria2_task_without_transfer(monkeypatch): + manager = DownloadManager() + + task = asyncio.create_task(asyncio.sleep(60)) + manager._download_tasks["download-queued"] = task + manager._pause_events["download-queued"] = download_manager.DownloadStreamControl() + manager._active_downloads["download-queued"] = { + "transfer_backend": "aria2", + "status": "waiting", + "bytes_per_second": 12.0, + } + + class DummyAria2Downloader: + def __init__(self): + self.calls = [] + + async def has_transfer(self, download_id): + self.calls.append(("has_transfer", download_id)) + return False + + async def pause_download(self, download_id): + self.calls.append(("pause", download_id)) + return {"success": True, "message": "paused"} + + async def resume_download(self, download_id): + self.calls.append(("resume", download_id)) + return {"success": True, "message": "resumed"} + + dummy_aria2 = DummyAria2Downloader() + monkeypatch.setattr( + download_manager, + "get_aria2_downloader", + AsyncMock(return_value=dummy_aria2), + ) + + pause_result = await manager.pause_download("download-queued") + assert pause_result == {"success": True, "message": "Download paused successfully"} + assert manager._active_downloads["download-queued"]["status"] == "paused" + assert manager._pause_events["download-queued"].is_paused() is True + + resume_result = await manager.resume_download("download-queued") + assert resume_result == {"success": True, "message": "Download resumed successfully"} + assert manager._active_downloads["download-queued"]["status"] == "downloading" + assert manager._pause_events["download-queued"].is_set() is True + assert dummy_aria2.calls == [ + ("has_transfer", "download-queued"), + ("has_transfer", "download-queued"), + ] + + task.cancel() + with pytest.raises(asyncio.CancelledError): + await task + + +@pytest.mark.asyncio +async def test_download_uses_captured_backend_when_settings_change( + monkeypatch, scanners, metadata_provider, tmp_path +): + manager = DownloadManager() + settings = get_settings_manager() + settings.settings["download_backend"] = "aria2" + + semaphore = asyncio.Semaphore(0) + manager._download_semaphore = semaphore + + captured = {} + + async def fake_execute_original_download( + self, + model_id, + model_version_id, + save_dir, + relative_path, + progress_callback, + use_default_paths, + download_id=None, + transfer_backend="python", + source=None, + file_params=None, + ): + captured["transfer_backend"] = transfer_backend + return {"success": True} + + monkeypatch.setattr( + DownloadManager, + "_execute_original_download", + fake_execute_original_download, + ) + + download_task = asyncio.create_task( + manager.download_from_civitai( + model_version_id=99, + save_dir=str(tmp_path), + use_default_paths=True, + progress_callback=None, + source=None, + ) + ) + + await asyncio.sleep(0) + assert len(manager._active_downloads) == 1 + download_id = next(iter(manager._active_downloads)) + assert manager._active_downloads[download_id]["transfer_backend"] == "aria2" + + settings.settings["download_backend"] = "python" + semaphore.release() + + result = await download_task + + assert result["success"] is True + assert captured["transfer_backend"] == "aria2" + + @pytest.mark.asyncio async def test_download_aborts_when_version_exists( monkeypatch, scanners, metadata_provider diff --git a/tests/services/test_download_manager_error.py b/tests/services/test_download_manager_error.py index fd154f3b..8f75e6ea 100644 --- a/tests/services/test_download_manager_error.py +++ b/tests/services/test_download_manager_error.py @@ -136,6 +136,190 @@ async def test_execute_download_retries_urls(monkeypatch, tmp_path): assert dummy_scanner.calls # ensure cache updated +@pytest.mark.asyncio +async def test_execute_download_uses_aria2_backend_for_model_files(monkeypatch, tmp_path): + manager = DownloadManager() + settings = get_settings_manager() + settings.settings["download_backend"] = "aria2" + settings.settings["civitai_api_key"] = "secret-key" + + save_dir = tmp_path / "downloads" + save_dir.mkdir() + target_path = save_dir / "file.safetensors" + + class DummyMetadata: + def __init__(self, path: Path): + self.file_path = str(path) + self.sha256 = "sha256" + self.file_name = path.stem + self.preview_url = None + + def generate_unique_filename(self, *_args, **_kwargs): + return os.path.basename(self.file_path) + + def update_file_info(self, _path): + return None + + def to_dict(self): + return {"file_path": self.file_path} + + class DummyAria2Downloader: + def __init__(self): + self.calls = [] + + async def download_file( + self, + url, + save_path, + *, + download_id, + progress_callback=None, + headers=None, + ): + self.calls.append( + { + "url": url, + "save_path": save_path, + "download_id": download_id, + "headers": headers, + } + ) + Path(save_path).write_text("content") + return True, save_path + + dummy_aria2 = DummyAria2Downloader() + + monkeypatch.setattr( + download_manager, + "get_aria2_downloader", + AsyncMock(return_value=dummy_aria2), + ) + monkeypatch.setattr( + download_manager, + "get_downloader", + AsyncMock(side_effect=AssertionError("python downloader should not be used")), + ) + + class DummyScanner: + async def add_model_to_cache(self, metadata_dict, relative_path): + return {"metadata": metadata_dict, "relative_path": relative_path} + + dummy_scanner = DummyScanner() + monkeypatch.setattr( + DownloadManager, "_get_lora_scanner", AsyncMock(return_value=dummy_scanner) + ) + monkeypatch.setattr( + DownloadManager, + "_get_checkpoint_scanner", + AsyncMock(return_value=dummy_scanner), + ) + monkeypatch.setattr( + ServiceRegistry, "get_embedding_scanner", AsyncMock(return_value=dummy_scanner) + ) + monkeypatch.setattr(MetadataManager, "save_metadata", AsyncMock(return_value=True)) + + result = await manager._execute_download( + download_urls=["https://civitai.com/api/download/models/1"], + save_dir=str(save_dir), + metadata=DummyMetadata(target_path), + version_info={"images": []}, + relative_path="", + progress_callback=None, + model_type="lora", + download_id="download-1", + ) + + assert result == {"success": True} + assert dummy_aria2.calls == [ + { + "url": "https://civitai.com/api/download/models/1", + "save_path": str(target_path), + "download_id": "download-1", + "headers": {"Authorization": "Bearer secret-key"}, + } + ] + + +@pytest.mark.asyncio +async def test_execute_download_allows_anonymous_civitai_with_aria2( + monkeypatch, tmp_path +): + manager = DownloadManager() + settings = get_settings_manager() + settings.settings["download_backend"] = "aria2" + settings.settings["civitai_api_key"] = "" + + save_dir = tmp_path / "downloads" + save_dir.mkdir() + target_path = save_dir / "file.safetensors" + + class DummyMetadata: + def __init__(self, path: Path): + self.file_path = str(path) + self.sha256 = "sha256" + self.file_name = path.stem + self.preview_url = None + + def generate_unique_filename(self, *_args, **_kwargs): + return os.path.basename(self.file_path) + + def update_file_info(self, _path): + return None + + def to_dict(self): + return {"file_path": self.file_path} + + class DummyAria2Downloader: + def __init__(self): + self.calls = [] + + async def download_file( + self, + url, + save_path, + *, + download_id, + progress_callback=None, + headers=None, + ): + self.calls.append({"url": url, "headers": headers, "download_id": download_id}) + Path(save_path).write_text("content") + return True, save_path + + dummy_aria2 = DummyAria2Downloader() + monkeypatch.setattr( + download_manager, + "get_aria2_downloader", + AsyncMock(return_value=dummy_aria2), + ) + + dummy_scanner = SimpleNamespace(add_model_to_cache=AsyncMock(return_value=None)) + monkeypatch.setattr( + DownloadManager, "_get_lora_scanner", AsyncMock(return_value=dummy_scanner) + ) + monkeypatch.setattr(MetadataManager, "save_metadata", AsyncMock(return_value=True)) + + result = await manager._execute_download( + download_urls=["https://civitai.com/api/download/models/1"], + save_dir=str(save_dir), + metadata=DummyMetadata(target_path), + version_info={"images": []}, + relative_path="", + progress_callback=None, + model_type="lora", + download_id="download-2", + ) + + assert result == {"success": True} + assert dummy_aria2.calls == [ + { + "url": "https://civitai.com/api/download/models/1", + "headers": None, + "download_id": "download-2", + } + ] + + @pytest.mark.asyncio async def test_execute_download_adjusts_checkpoint_sub_type(monkeypatch, tmp_path): """Test that checkpoint sub_type is adjusted during download.""" @@ -276,6 +460,13 @@ async def test_execute_download_extracts_zip_single_model(monkeypatch, tmp_path) monkeypatch.setattr( download_manager, "get_downloader", AsyncMock(return_value=DummyDownloader()) ) + + class ImmediateLoop: + async def run_in_executor(self, executor, func, *args): + return func(*args) + + monkeypatch.setattr(download_manager.asyncio, "get_running_loop", lambda: ImmediateLoop()) + dummy_scanner = SimpleNamespace(add_model_to_cache=AsyncMock(return_value=None)) monkeypatch.setattr( DownloadManager, "_get_lora_scanner", AsyncMock(return_value=dummy_scanner) @@ -344,6 +535,13 @@ async def test_execute_download_extracts_zip_multiple_models(monkeypatch, tmp_pa monkeypatch.setattr( download_manager, "get_downloader", AsyncMock(return_value=DummyDownloader()) ) + + class ImmediateLoop: + async def run_in_executor(self, executor, func, *args): + return func(*args) + + monkeypatch.setattr(download_manager.asyncio, "get_running_loop", lambda: ImmediateLoop()) + dummy_scanner = SimpleNamespace(add_model_to_cache=AsyncMock(return_value=None)) monkeypatch.setattr( DownloadManager, "_get_lora_scanner", AsyncMock(return_value=dummy_scanner) @@ -418,6 +616,13 @@ async def test_execute_download_extracts_zip_pt_embedding(monkeypatch, tmp_path) monkeypatch.setattr( download_manager, "get_downloader", AsyncMock(return_value=DummyDownloader()) ) + + class ImmediateLoop: + async def run_in_executor(self, executor, func, *args): + return func(*args) + + monkeypatch.setattr(download_manager.asyncio, "get_running_loop", lambda: ImmediateLoop()) + dummy_scanner = SimpleNamespace(add_model_to_cache=AsyncMock(return_value=None)) monkeypatch.setattr( ServiceRegistry, "get_embedding_scanner", AsyncMock(return_value=dummy_scanner) @@ -446,6 +651,36 @@ async def test_execute_download_extracts_zip_pt_embedding(monkeypatch, tmp_path) assert dummy_scanner.add_model_to_cache.await_count == 1 +@pytest.mark.asyncio +async def test_extract_model_files_from_archive_uses_executor(monkeypatch, tmp_path): + manager = DownloadManager() + archive_path = tmp_path / "bundle.zip" + with zipfile.ZipFile(archive_path, "w") as archive: + archive.writestr("inner/model.safetensors", b"model") + + captured = {} + + class ImmediateLoop: + async def run_in_executor(self, executor, func, *args): + captured["executor"] = executor + return func(*args) + + monkeypatch.setattr( + download_manager.asyncio, + "get_running_loop", + lambda: ImmediateLoop(), + ) + + extracted = await manager._extract_model_files_from_archive( + str(archive_path), + {".safetensors"}, + ) + + assert captured["executor"] is manager._archive_executor + assert len(extracted) == 1 + assert extracted[0].endswith("model.safetensors") + + @pytest.mark.asyncio async def test_pause_download_updates_state(): """Test that pause_download updates download state correctly.""" @@ -469,6 +704,233 @@ async def test_pause_download_updates_state(): assert manager._active_downloads[download_id]["bytes_per_second"] == 0.0 +@pytest.mark.asyncio +async def test_pause_download_reverts_local_pause_when_aria2_pause_fails(monkeypatch): + manager = DownloadManager() + + download_id = "dl" + manager._download_tasks[download_id] = object() + pause_control = DownloadStreamControl() + manager._pause_events[download_id] = pause_control + manager._active_downloads[download_id] = { + "transfer_backend": "aria2", + "status": "downloading", + "bytes_per_second": 42.0, + } + + class DummyAria2Downloader: + async def has_transfer(self, _download_id): + return True + + async def pause_download(self, _download_id): + return {"success": False, "error": "rpc failed"} + + monkeypatch.setattr( + download_manager, + "get_aria2_downloader", + AsyncMock(return_value=DummyAria2Downloader()), + ) + + result = await manager.pause_download(download_id) + + assert result == {"success": False, "error": "rpc failed"} + assert pause_control.is_set() is True + assert manager._active_downloads[download_id]["status"] == "downloading" + + +@pytest.mark.asyncio +async def test_pause_download_reverts_local_pause_when_aria2_probe_raises(monkeypatch): + manager = DownloadManager() + + download_id = "dl" + manager._download_tasks[download_id] = object() + pause_control = DownloadStreamControl() + manager._pause_events[download_id] = pause_control + manager._active_downloads[download_id] = { + "transfer_backend": "aria2", + "status": "downloading", + "bytes_per_second": 42.0, + } + + class DummyAria2Downloader: + async def has_transfer(self, _download_id): + raise RuntimeError("rpc unavailable") + + monkeypatch.setattr( + download_manager, + "get_aria2_downloader", + AsyncMock(return_value=DummyAria2Downloader()), + ) + + result = await manager.pause_download(download_id) + + assert result == {"success": False, "error": "rpc unavailable"} + assert pause_control.is_set() is True + assert manager._active_downloads[download_id]["status"] == "downloading" + + +@pytest.mark.asyncio +async def test_resume_download_returns_error_when_aria2_probe_raises(monkeypatch): + manager = DownloadManager() + + download_id = "dl" + pause_control = DownloadStreamControl() + pause_control.pause() + manager._pause_events[download_id] = pause_control + manager._active_downloads[download_id] = { + "transfer_backend": "aria2", + "status": "paused", + "bytes_per_second": 0.0, + } + + class DummyAria2Downloader: + async def has_transfer(self, _download_id): + raise RuntimeError("rpc unavailable") + + monkeypatch.setattr( + download_manager, + "get_aria2_downloader", + AsyncMock(return_value=DummyAria2Downloader()), + ) + + result = await manager.resume_download(download_id) + + assert result == {"success": False, "error": "rpc unavailable"} + assert pause_control.is_paused() is True + assert manager._active_downloads[download_id]["status"] == "paused" + + +@pytest.mark.asyncio +async def test_cancel_download_still_cancels_local_task_when_aria2_raises(monkeypatch): + manager = DownloadManager() + + started = asyncio.Event() + + async def blocked_task(): + started.set() + await asyncio.sleep(60) + + task = asyncio.create_task(blocked_task()) + await started.wait() + + download_id = "download-queued" + manager._download_tasks[download_id] = task + manager._active_downloads[download_id] = { + "transfer_backend": "aria2", + "status": "queued", + } + + class DummyAria2Downloader: + async def cancel_download(self, _download_id): + raise RuntimeError("rpc unavailable") + + monkeypatch.setattr( + download_manager, + "get_aria2_downloader", + AsyncMock(return_value=DummyAria2Downloader()), + ) + + result = await manager.cancel_download(download_id) + + assert result["success"] is True + assert task.cancelled() or task.done() + + +@pytest.mark.asyncio +async def test_execute_download_waits_for_paused_pre_transfer_gate(monkeypatch, tmp_path): + manager = DownloadManager() + + save_dir = tmp_path / "downloads" + save_dir.mkdir() + target_path = save_dir / "file.safetensors" + + class DummyMetadata: + def __init__(self, path: Path): + self.file_path = str(path) + self.sha256 = "sha256" + self.file_name = path.stem + self.preview_url = None + + def generate_unique_filename(self, *_args, **_kwargs): + return os.path.basename(self.file_path) + + def update_file_info(self, _path): + return None + + def to_dict(self): + return {"file_path": self.file_path} + + pause_control = DownloadStreamControl() + pause_control.pause() + manager._pause_events["download-1"] = pause_control + manager._active_downloads["download-1"] = { + "status": "downloading", + "bytes_per_second": 42.0, + } + + dummy_scanner = SimpleNamespace(add_model_to_cache=AsyncMock(return_value=None)) + monkeypatch.setattr( + DownloadManager, "_get_lora_scanner", AsyncMock(return_value=dummy_scanner) + ) + monkeypatch.setattr(MetadataManager, "save_metadata", AsyncMock(return_value=True)) + + started = asyncio.Event() + allow_finish = asyncio.Event() + captured = {"calls": 0} + + async def fake_download_model_file( + self, + download_url, + save_path, + *, + backend, + progress_callback, + use_auth, + download_id, + pause_control, + ): + captured["calls"] += 1 + started.set() + await allow_finish.wait() + Path(save_path).write_text("content") + return True, save_path + + monkeypatch.setattr( + DownloadManager, + "_download_model_file", + fake_download_model_file, + ) + + task = asyncio.create_task( + manager._execute_download( + download_urls=["https://civitai.com/api/download/models/1"], + save_dir=str(save_dir), + metadata=DummyMetadata(target_path), + version_info={"images": []}, + relative_path="", + progress_callback=None, + model_type="lora", + download_id="download-1", + transfer_backend="aria2", + ) + ) + + await asyncio.sleep(0) + assert started.is_set() is False + assert captured["calls"] == 0 + assert manager._active_downloads["download-1"]["status"] == "paused" + + pause_control.resume() + await asyncio.wait_for(started.wait(), timeout=1.0) + assert captured["calls"] == 1 + assert manager._active_downloads["download-1"]["status"] == "downloading" + + allow_finish.set() + result = await task + + assert result == {"success": True} + + @pytest.mark.asyncio async def test_pause_download_rejects_unknown_task(): """Test that pause_download rejects unknown download tasks.""" diff --git a/tests/services/test_settings_manager.py b/tests/services/test_settings_manager.py index abb1fb97..11c40089 100644 --- a/tests/services/test_settings_manager.py +++ b/tests/services/test_settings_manager.py @@ -147,6 +147,11 @@ def test_environment_variable_overrides_settings(tmp_path, monkeypatch): assert mgr.get("civitai_api_key") == "secret" +def test_default_download_backend_is_python(manager): + assert manager.get("download_backend") == "python" + assert manager.get("aria2c_path") == "" + + def _create_manager_with_settings( tmp_path, monkeypatch, initial_settings, *, save_spy=None ):