Compare commits

...

3 Commits

Author SHA1 Message Date
Will Miao
761108bfd1 fix(download): restore aria2 resume lifecycle 2026-04-20 09:52:48 +08:00
Will Miao
24dd3a777c fix(settings): align modal form control widths 2026-04-19 21:59:33 +08:00
Will Miao
1c530ea013 feat(download): add experimental aria2 backend 2026-04-19 21:46:09 +08:00
23 changed files with 3982 additions and 128 deletions

View File

@@ -263,6 +263,19 @@
"red": "civitai.red (uneingeschränkt)"
}
},
"downloadBackend": {
"label": "Download-Backend",
"help": "Wähle aus, wie Modelldateien heruntergeladen werden. Python verwendet den eingebauten Downloader. aria2 verwendet den experimentellen externen Downloader-Prozess.",
"options": {
"python": "Python (integriert)",
"aria2": "aria2 (experimentell)"
}
},
"aria2cPath": {
"label": "aria2c-Pfad",
"help": "Optionaler Pfad zur ausführbaren aria2c-Datei. Leer lassen, um aria2c aus dem System-PATH zu verwenden.",
"placeholder": "Leer lassen, um aria2c aus dem PATH zu verwenden"
},
"civitaiHostBanner": {
"title": "Civitai-Host-Einstellung verfügbar",
"content": "Civitai verwendet jetzt civitai.com für SFW-Inhalte und civitai.red für uneingeschränkte Inhalte. In den Einstellungen können Sie ändern, welche Seite standardmäßig geöffnet wird.",
@@ -278,6 +291,7 @@
},
"sections": {
"contentFiltering": "Inhaltsfilterung",
"downloads": "Downloads",
"videoSettings": "Video-Einstellungen",
"layoutSettings": "Layout-Einstellungen",
"misc": "Verschiedenes",

View File

@@ -263,6 +263,19 @@
"red": "civitai.red (unrestricted)"
}
},
"downloadBackend": {
"label": "Download backend",
"help": "Choose how model files are downloaded. Python uses the built-in downloader. aria2 uses the experimental external downloader process.",
"options": {
"python": "Python (built-in)",
"aria2": "aria2 (experimental)"
}
},
"aria2cPath": {
"label": "aria2c path",
"help": "Optional path to the aria2c executable. Leave empty to use aria2c from your system PATH.",
"placeholder": "Leave empty to use aria2c from PATH"
},
"civitaiHostBanner": {
"title": "Civitai host preference available",
"content": "Civitai now uses civitai.com for SFW content and civitai.red for unrestricted content. You can change which site opens by default in Settings.",
@@ -278,6 +291,7 @@
},
"sections": {
"contentFiltering": "Content Filtering",
"downloads": "Downloads",
"videoSettings": "Video Settings",
"layoutSettings": "Layout Settings",
"misc": "Miscellaneous",

View File

@@ -263,6 +263,19 @@
"red": "civitai.red (sin restricciones)"
}
},
"downloadBackend": {
"label": "Backend de descarga",
"help": "Elige cómo se descargan los archivos del modelo. Python usa el descargador integrado. aria2 usa el proceso externo experimental de descarga.",
"options": {
"python": "Python (integrado)",
"aria2": "aria2 (experimental)"
}
},
"aria2cPath": {
"label": "Ruta de aria2c",
"help": "Ruta opcional al ejecutable aria2c. Déjalo vacío para usar aria2c desde el PATH del sistema.",
"placeholder": "Déjalo vacío para usar aria2c desde el PATH"
},
"civitaiHostBanner": {
"title": "Preferencia de host de Civitai disponible",
"content": "Civitai ahora usa civitai.com para contenido SFW y civitai.red para contenido sin restricciones. Puedes cambiar en Ajustes qué sitio se abre por defecto.",
@@ -278,6 +291,7 @@
},
"sections": {
"contentFiltering": "Filtrado de contenido",
"downloads": "Descargas",
"videoSettings": "Configuración de video",
"layoutSettings": "Configuración de diseño",
"misc": "Varios",

View File

@@ -263,6 +263,19 @@
"red": "civitai.red (sans restriction)"
}
},
"downloadBackend": {
"label": "Moteur de téléchargement",
"help": "Choisissez comment les fichiers de modèles sont téléchargés. Python utilise le téléchargeur intégré. aria2 utilise le processus externe expérimental de téléchargement.",
"options": {
"python": "Python (intégré)",
"aria2": "aria2 (expérimental)"
}
},
"aria2cPath": {
"label": "Chemin vers aria2c",
"help": "Chemin facultatif vers lexécutable aria2c. Laissez vide pour utiliser aria2c depuis le PATH système.",
"placeholder": "Laisser vide pour utiliser aria2c depuis le PATH"
},
"civitaiHostBanner": {
"title": "Préférence dhôte Civitai disponible",
"content": "Civitai utilise désormais civitai.com pour le contenu SFW et civitai.red pour le contenu sans restriction. Vous pouvez modifier dans les paramètres le site ouvert par défaut.",
@@ -278,6 +291,7 @@
},
"sections": {
"contentFiltering": "Filtrage du contenu",
"downloads": "Téléchargements",
"videoSettings": "Paramètres vidéo",
"layoutSettings": "Paramètres d'affichage",
"misc": "Divers",

View File

@@ -263,6 +263,19 @@
"red": "civitai.red (ללא הגבלות)"
}
},
"downloadBackend": {
"label": "מנגנון הורדה",
"help": "בחר כיצד יורדים קבצי המודל. Python משתמש במוריד המובנה. aria2 משתמש בתהליך הורדה חיצוני ניסיוני.",
"options": {
"python": "Python (מובנה)",
"aria2": "aria2 (ניסיוני)"
}
},
"aria2cPath": {
"label": "נתיב aria2c",
"help": "נתיב אופציונלי לקובץ ההפעלה aria2c. השאר ריק כדי להשתמש ב-aria2c מתוך ה-PATH של המערכת.",
"placeholder": "השאר ריק כדי להשתמש ב-aria2c מתוך ה-PATH"
},
"civitaiHostBanner": {
"title": "העדפת מארח Civitai זמינה",
"content": "Civitai משתמש כעת ב-civitai.com עבור תוכן SFW וב-civitai.red עבור תוכן ללא הגבלות. ניתן לשנות בהגדרות איזה אתר ייפתח כברירת מחדל.",
@@ -278,6 +291,7 @@
},
"sections": {
"contentFiltering": "סינון תוכן",
"downloads": "הורדות",
"videoSettings": "הגדרות וידאו",
"layoutSettings": "הגדרות פריסה",
"misc": "שונות",

View File

@@ -263,6 +263,19 @@
"red": "civitai.red制限なし"
}
},
"downloadBackend": {
"label": "ダウンロードバックエンド",
"help": "モデルファイルのダウンロード方法を選択します。Python は内蔵ダウンローダーを使用し、aria2 は実験的な外部ダウンローダープロセスを使用します。",
"options": {
"python": "Python内蔵",
"aria2": "aria2実験的"
}
},
"aria2cPath": {
"label": "aria2c のパス",
"help": "aria2c 実行ファイルへの任意のパスです。空欄のままにすると、システム PATH 上の aria2c を使用します。",
"placeholder": "空欄のままにすると PATH 上の aria2c を使用します"
},
"civitaiHostBanner": {
"title": "Civitai ホスト設定を利用できます",
"content": "Civitai は現在、SFW コンテンツには civitai.com、制限なしコンテンツには civitai.red を使用しています。設定で既定で開くサイトを変更できます。",
@@ -278,6 +291,7 @@
},
"sections": {
"contentFiltering": "コンテンツフィルタリング",
"downloads": "ダウンロード",
"videoSettings": "動画設定",
"layoutSettings": "レイアウト設定",
"misc": "その他",

View File

@@ -263,6 +263,19 @@
"red": "civitai.red(무제한)"
}
},
"downloadBackend": {
"label": "다운로드 백엔드",
"help": "모델 파일을 다운로드하는 방식을 선택합니다. Python은 내장 다운로더를 사용하고, aria2는 실험적인 외부 다운로더 프로세스를 사용합니다.",
"options": {
"python": "Python(내장)",
"aria2": "aria2(실험적)"
}
},
"aria2cPath": {
"label": "aria2c 경로",
"help": "aria2c 실행 파일의 선택적 경로입니다. 비워 두면 시스템 PATH의 aria2c를 사용합니다.",
"placeholder": "비워 두면 PATH의 aria2c를 사용합니다"
},
"civitaiHostBanner": {
"title": "Civitai 호스트 기본 설정 사용 가능",
"content": "이제 Civitai는 SFW 콘텐츠에 civitai.com을, 무제한 콘텐츠에 civitai.red를 사용합니다. 설정에서 기본으로 열 사이트를 변경할 수 있습니다.",
@@ -278,6 +291,7 @@
},
"sections": {
"contentFiltering": "콘텐츠 필터링",
"downloads": "다운로드",
"videoSettings": "비디오 설정",
"layoutSettings": "레이아웃 설정",
"misc": "기타",

View File

@@ -263,6 +263,19 @@
"red": "civitai.red (без ограничений)"
}
},
"downloadBackend": {
"label": "Бэкенд загрузки",
"help": "Выберите способ загрузки файлов моделей. Python использует встроенный загрузчик. aria2 использует экспериментальный внешний процесс загрузки.",
"options": {
"python": "Python (встроенный)",
"aria2": "aria2 (экспериментальный)"
}
},
"aria2cPath": {
"label": "Путь к aria2c",
"help": "Необязательный путь к исполняемому файлу aria2c. Оставьте пустым, чтобы использовать aria2c из системного PATH.",
"placeholder": "Оставьте пустым, чтобы использовать aria2c из PATH"
},
"civitaiHostBanner": {
"title": "Доступна настройка хоста Civitai",
"content": "Теперь Civitai использует civitai.com для контента SFW и civitai.red для контента без ограничений. В настройках можно изменить, какой сайт открывать по умолчанию.",
@@ -278,6 +291,7 @@
},
"sections": {
"contentFiltering": "Фильтрация контента",
"downloads": "Загрузки",
"videoSettings": "Настройки видео",
"layoutSettings": "Настройки макета",
"misc": "Разное",

View File

@@ -263,6 +263,19 @@
"red": "civitai.red无限制"
}
},
"downloadBackend": {
"label": "下载后端",
"help": "选择模型文件的下载方式。Python 使用内置下载器。aria2 使用实验性的外部下载进程。",
"options": {
"python": "Python内置",
"aria2": "aria2实验性"
}
},
"aria2cPath": {
"label": "aria2c 路径",
"help": "可选的 aria2c 可执行文件路径。留空则使用系统 PATH 中的 aria2c。",
"placeholder": "留空则使用 PATH 中的 aria2c"
},
"civitaiHostBanner": {
"title": "已提供 Civitai 站点偏好设置",
"content": "Civitai 现在使用 civitai.com 提供 SFW 内容,使用 civitai.red 提供无限制内容。你可以在设置中更改默认打开的站点。",
@@ -278,6 +291,7 @@
},
"sections": {
"contentFiltering": "内容过滤",
"downloads": "下载",
"videoSettings": "视频设置",
"layoutSettings": "布局设置",
"misc": "其他",

View File

@@ -263,6 +263,19 @@
"red": "civitai.red無限制"
}
},
"downloadBackend": {
"label": "下載後端",
"help": "選擇模型檔案的下載方式。Python 使用內建下載器。aria2 使用實驗性的外部下載程序。",
"options": {
"python": "Python內建",
"aria2": "aria2實驗性"
}
},
"aria2cPath": {
"label": "aria2c 路徑",
"help": "可選的 aria2c 可執行檔路徑。留空則使用系統 PATH 中的 aria2c。",
"placeholder": "留空則使用 PATH 中的 aria2c"
},
"civitaiHostBanner": {
"title": "已提供 Civitai 站點偏好設定",
"content": "Civitai 現在使用 civitai.com 提供 SFW 內容,使用 civitai.red 提供無限制內容。你可以在設定中變更預設開啟的站點。",
@@ -278,6 +291,7 @@
},
"sections": {
"contentFiltering": "內容過濾",
"downloads": "下載",
"videoSettings": "影片設定",
"layoutSettings": "版面設定",
"misc": "其他",

View File

@@ -0,0 +1,570 @@
from __future__ import annotations
import asyncio
import json
import logging
import os
import secrets
import shutil
import socket
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, Optional, Tuple
import aiohttp
from .downloader import DownloadProgress, get_downloader
from .aria2_transfer_state import Aria2TransferStateStore
from .settings_manager import get_settings_manager
logger = logging.getLogger(__name__)
CIVITAI_DOWNLOAD_URL_PREFIXES = (
"https://civitai.com/api/download/",
"https://civitai.red/api/download/",
)
class Aria2Error(RuntimeError):
"""Raised when aria2 integration fails."""
@dataclass
class Aria2Transfer:
"""Track an aria2 download registered by the Python coordinator."""
gid: str
save_path: str
class Aria2Downloader:
"""Manage an aria2 RPC daemon for experimental model downloads."""
_instance = None
_lock = asyncio.Lock()
@classmethod
async def get_instance(cls) -> "Aria2Downloader":
async with cls._lock:
if cls._instance is None:
cls._instance = cls()
return cls._instance
def __init__(self) -> None:
if hasattr(self, "_initialized"):
return
self._initialized = True
self._process: Optional[asyncio.subprocess.Process] = None
self._rpc_port: Optional[int] = None
self._rpc_secret = ""
self._rpc_url = ""
self._rpc_session: Optional[aiohttp.ClientSession] = None
self._rpc_session_lock = asyncio.Lock()
self._process_lock = asyncio.Lock()
self._transfers: Dict[str, Aria2Transfer] = {}
self._poll_interval = 0.5
self._state_store = Aria2TransferStateStore()
@property
def is_running(self) -> bool:
return self._process is not None and self._process.returncode is None
async def download_file(
self,
url: str,
save_path: str,
*,
download_id: str,
progress_callback=None,
headers: Optional[Dict[str, str]] = None,
) -> Tuple[bool, str]:
"""Download a file using aria2 RPC and wait for completion."""
await self._ensure_process()
save_path = os.path.abspath(save_path)
transfer = self._transfers.get(download_id)
if transfer is None or os.path.abspath(transfer.save_path) != save_path:
gid = await self._schedule_download(
url,
save_path,
download_id=download_id,
headers=headers,
)
transfer = Aria2Transfer(gid=gid, save_path=save_path)
self._transfers[download_id] = transfer
try:
while True:
status = await self.get_status(download_id)
if status is None:
return False, "aria2 download not found"
snapshot = self._build_progress_snapshot(status)
if progress_callback is not None:
await self._dispatch_progress(progress_callback, snapshot)
state = status.get("status", "")
if state == "complete":
completed_path = self._resolve_completed_path(status, save_path)
return True, completed_path
if state == "error":
return False, status.get("errorMessage") or "aria2 download failed"
if state == "removed":
return False, "Download was cancelled"
await asyncio.sleep(self._poll_interval)
finally:
self._transfers.pop(download_id, None)
async def _schedule_download(
self,
url: str,
save_path: str,
*,
download_id: str,
headers: Optional[Dict[str, str]] = None,
) -> str:
save_dir = os.path.dirname(save_path)
out_name = os.path.basename(save_path)
Path(save_dir).mkdir(parents=True, exist_ok=True)
resolved_url = url
request_headers = headers
if headers and url.startswith(CIVITAI_DOWNLOAD_URL_PREFIXES):
resolved_url = await self._resolve_authenticated_redirect_url(url, headers)
if resolved_url != url:
request_headers = None
logger.debug(
"Resolved Civitai download %s to signed URL for aria2",
download_id,
)
options: Dict[str, str] = {
"dir": save_dir,
"out": out_name,
"continue": "true",
"max-connection-per-server": "4",
"split": "4",
"min-split-size": "1M",
"allow-overwrite": "true",
"auto-file-renaming": "false",
"file-allocation": "none",
}
if request_headers:
options["header"] = [
f"{key}: {value}" for key, value in request_headers.items()
]
logger.debug(
"Submitting aria2 download %s -> %s (auth=%s, civitai_signed=%s)",
download_id,
save_path,
bool(request_headers),
resolved_url != url,
)
try:
gid = await self._rpc_call("aria2.addUri", [[resolved_url], options])
except Exception as exc:
raise Aria2Error(f"Failed to schedule aria2 download: {exc}") from exc
logger.debug("aria2 accepted download %s with gid %s", download_id, gid)
await self._state_store.upsert(
download_id,
{
"gid": gid,
"save_path": save_path,
"status": "downloading",
"url": url,
},
)
return gid
async def get_status(self, download_id: str) -> Optional[Dict[str, Any]]:
"""Return the raw aria2 status payload for a known download."""
transfer = self._transfers.get(download_id)
if transfer is None:
return None
keys = [
"gid",
"status",
"totalLength",
"completedLength",
"downloadSpeed",
"errorMessage",
"files",
]
try:
status = await self._rpc_call("aria2.tellStatus", [transfer.gid, keys])
except Exception as exc:
raise Aria2Error(f"Failed to query aria2 download status: {exc}") from exc
if isinstance(status, dict):
return status
return None
async def get_status_by_gid(self, gid: str) -> Optional[Dict[str, Any]]:
keys = [
"gid",
"status",
"totalLength",
"completedLength",
"downloadSpeed",
"errorMessage",
"files",
]
try:
status = await self._rpc_call("aria2.tellStatus", [gid, keys])
except Exception as exc:
message = str(exc)
if "cannot be found" in message.lower() or "not found" in message.lower():
return None
raise Aria2Error(f"Failed to query aria2 download status: {exc}") from exc
if isinstance(status, dict):
return status
return None
async def restore_transfer(self, download_id: str, gid: str, save_path: str) -> None:
await self._ensure_process()
self._transfers[download_id] = Aria2Transfer(
gid=gid,
save_path=os.path.abspath(save_path),
)
async def reassign_transfer(
self, from_download_id: str, to_download_id: str
) -> Optional[Aria2Transfer]:
transfer = self._transfers.get(from_download_id)
if transfer is None:
return None
self._transfers[to_download_id] = transfer
if from_download_id != to_download_id:
self._transfers.pop(from_download_id, None)
return transfer
async def has_transfer(self, download_id: str) -> bool:
return download_id in self._transfers
async def pause_download(self, download_id: str) -> Dict[str, Any]:
transfer = self._transfers.get(download_id)
if transfer is None:
return {"success": False, "error": "Download task not found"}
try:
await self._rpc_call("aria2.forcePause", [transfer.gid])
except Exception as exc:
return {"success": False, "error": str(exc)}
await self._state_store.upsert(download_id, {"status": "paused"})
return {"success": True, "message": "Download paused successfully"}
async def resume_download(self, download_id: str) -> Dict[str, Any]:
transfer = self._transfers.get(download_id)
if transfer is None:
return {"success": False, "error": "Download task not found"}
try:
await self._rpc_call("aria2.unpause", [transfer.gid])
except Exception as exc:
return {"success": False, "error": str(exc)}
await self._state_store.upsert(download_id, {"status": "downloading"})
return {"success": True, "message": "Download resumed successfully"}
async def cancel_download(self, download_id: str) -> Dict[str, Any]:
transfer = self._transfers.get(download_id)
if transfer is None:
return {"success": False, "error": "Download task not found"}
try:
await self._rpc_call("aria2.forceRemove", [transfer.gid])
except Exception as exc:
return {"success": False, "error": str(exc)}
await self._state_store.remove(download_id)
return {"success": True, "message": "Download cancelled successfully"}
async def close(self) -> None:
"""Shut down the RPC process and session."""
if self._rpc_session is not None:
await self._rpc_session.close()
self._rpc_session = None
process = self._process
self._process = None
self._transfers.clear()
if process is None:
return
if process.returncode is None:
process.terminate()
try:
await asyncio.wait_for(process.wait(), timeout=5.0)
except asyncio.TimeoutError:
process.kill()
await process.wait()
async def _dispatch_progress(self, callback, snapshot: DownloadProgress) -> None:
try:
result = callback(snapshot, snapshot)
except TypeError:
result = callback(snapshot.percent_complete)
if asyncio.iscoroutine(result):
await result
elif hasattr(result, "__await__"):
await result
def _build_progress_snapshot(self, status: Dict[str, Any]) -> DownloadProgress:
completed = self._parse_int(status.get("completedLength"))
total = self._parse_int(status.get("totalLength"))
speed = float(self._parse_int(status.get("downloadSpeed")))
percent = 0.0
if total > 0:
percent = (completed / total) * 100.0
return DownloadProgress(
percent_complete=max(0.0, min(percent, 100.0)),
bytes_downloaded=completed,
total_bytes=total or None,
bytes_per_second=speed,
timestamp=datetime.now().timestamp(),
)
def _resolve_completed_path(self, status: Dict[str, Any], default_path: str) -> str:
files = status.get("files")
if isinstance(files, list) and files:
first = files[0]
if isinstance(first, dict):
candidate = first.get("path")
if isinstance(candidate, str) and candidate:
return candidate
return default_path
@staticmethod
def _parse_int(value: Any) -> int:
try:
return int(value)
except (TypeError, ValueError):
return 0
async def _resolve_authenticated_redirect_url(
self,
url: str,
headers: Dict[str, str],
) -> str:
downloader = await get_downloader()
session = await downloader.session
request_headers = dict(downloader.default_headers)
request_headers.update(headers)
request_headers["Accept-Encoding"] = "identity"
try:
async with session.get(
url,
headers=request_headers,
allow_redirects=False,
proxy=downloader.proxy_url,
) as response:
if response.status in {301, 302, 303, 307, 308}:
location = response.headers.get("Location")
if location:
return location
raise Aria2Error(
"Authenticated Civitai redirect did not include a Location header"
)
if response.status == 200:
return url
body = await response.text()
raise Aria2Error(
f"Failed to resolve authenticated Civitai redirect: status={response.status} body={body[:300]}"
)
except aiohttp.ClientError as exc:
raise Aria2Error(
f"Failed to resolve authenticated Civitai redirect: {exc}"
) from exc
async def _ensure_process(self) -> None:
async with self._process_lock:
if self.is_running and await self._ping():
return
await self.close()
executable = self._resolve_executable()
self._rpc_port = self._find_free_port()
self._rpc_secret = secrets.token_hex(16)
self._rpc_url = f"http://127.0.0.1:{self._rpc_port}/jsonrpc"
command = [
executable,
"--enable-rpc=true",
"--rpc-listen-all=false",
f"--rpc-listen-port={self._rpc_port}",
f"--rpc-secret={self._rpc_secret}",
"--check-certificate=true",
"--allow-overwrite=true",
"--auto-file-renaming=false",
"--file-allocation=none",
"--max-concurrent-downloads=5",
"--continue=true",
"--daemon=false",
"--quiet=true",
f"--stop-with-process={os.getpid()}",
]
logger.info("Starting aria2 RPC daemon from %s", executable)
self._process = await asyncio.create_subprocess_exec(
*command,
stdout=asyncio.subprocess.DEVNULL,
stderr=asyncio.subprocess.PIPE,
)
await self._wait_until_ready()
def _resolve_executable(self) -> str:
settings = get_settings_manager()
configured_path = (settings.get("aria2c_path") or "").strip()
candidate = configured_path or "aria2c"
resolved = shutil.which(candidate)
if resolved:
return resolved
if configured_path and os.path.isfile(configured_path) and os.access(
configured_path, os.X_OK
):
return configured_path
raise Aria2Error(
"aria2c executable was not found. Install aria2 or configure aria2c_path."
)
async def _wait_until_ready(self) -> None:
assert self._process is not None
start_time = asyncio.get_running_loop().time()
last_error = ""
while asyncio.get_running_loop().time() - start_time < 10.0:
if self._process.returncode is not None:
stderr_output = ""
if self._process.stderr is not None:
try:
stderr_output = (
await asyncio.wait_for(self._process.stderr.read(), timeout=0.2)
).decode("utf-8", errors="replace")
except Exception:
stderr_output = ""
raise Aria2Error(
f"aria2 RPC process exited early with code {self._process.returncode}: {stderr_output.strip()}"
)
try:
if await self._ping():
return
except Exception as exc: # pragma: no cover - startup race
last_error = str(exc)
await asyncio.sleep(0.2)
raise Aria2Error(
f"Timed out waiting for aria2 RPC to become ready{': ' + last_error if last_error else ''}"
)
async def _ping(self) -> bool:
try:
result = await self._rpc_call("aria2.getVersion", [])
except Exception:
return False
return isinstance(result, dict)
async def _rpc_call(self, method: str, params: list[Any]) -> Any:
if not self._rpc_url:
raise Aria2Error("aria2 RPC endpoint is not initialized")
session = await self._get_rpc_session()
payload = {
"jsonrpc": "2.0",
"id": secrets.token_hex(8),
"method": method,
"params": [f"token:{self._rpc_secret}", *params],
}
async with session.post(self._rpc_url, json=payload) as response:
text = await response.text()
try:
body = json.loads(text)
except json.JSONDecodeError:
body = None
if body is None:
if response.status != 200:
raise Aria2Error(
f"aria2 RPC returned status {response.status} with non-JSON body: {text}"
)
raise Aria2Error(f"Invalid aria2 RPC response: {text}")
if "error" in body:
error = body["error"] or {}
code = error.get("code") if isinstance(error, dict) else None
message = error.get("message") if isinstance(error, dict) else str(error)
logger.error(
"aria2 RPC %s failed with HTTP %s, code=%s, message=%s",
method,
response.status,
code,
message,
)
status_message = (
f"aria2 RPC {method} failed with status {response.status}: {message}"
if response.status != 200
else message
)
raise Aria2Error(status_message or "Unknown aria2 RPC error")
if response.status != 200:
logger.error(
"aria2 RPC %s returned unexpected HTTP status %s without error payload: %s",
method,
response.status,
body,
)
raise Aria2Error(
f"aria2 RPC {method} returned unexpected status {response.status}"
)
return body.get("result")
async def _get_rpc_session(self) -> aiohttp.ClientSession:
if self._rpc_session is None or self._rpc_session.closed:
async with self._rpc_session_lock:
if self._rpc_session is None or self._rpc_session.closed:
timeout = aiohttp.ClientTimeout(total=30)
self._rpc_session = aiohttp.ClientSession(timeout=timeout)
return self._rpc_session
@staticmethod
def _find_free_port() -> int:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
sock.bind(("127.0.0.1", 0))
sock.listen(1)
return int(sock.getsockname()[1])
async def get_aria2_downloader() -> Aria2Downloader:
"""Get the singleton aria2 downloader."""
return await Aria2Downloader.get_instance()

View File

@@ -0,0 +1,108 @@
from __future__ import annotations
import asyncio
import json
import os
from copy import deepcopy
from typing import Any, Dict, Optional
from ..utils.cache_paths import get_cache_base_dir
def get_aria2_state_path() -> str:
base_dir = get_cache_base_dir(create=True)
state_dir = os.path.join(base_dir, "aria2")
os.makedirs(state_dir, exist_ok=True)
return os.path.join(state_dir, "downloads.json")
class Aria2TransferStateStore:
"""Persist aria2 transfer metadata needed for restart recovery."""
_locks_by_path: Dict[str, asyncio.Lock] = {}
def __init__(self, state_path: Optional[str] = None) -> None:
self._state_path = os.path.abspath(state_path or get_aria2_state_path())
self._lock = self._locks_by_path.setdefault(self._state_path, asyncio.Lock())
def _read_all_unlocked(self) -> Dict[str, Dict[str, Any]]:
try:
with open(self._state_path, "r", encoding="utf-8") as handle:
data = json.load(handle)
except FileNotFoundError:
return {}
except json.JSONDecodeError:
return {}
if not isinstance(data, dict):
return {}
normalized: Dict[str, Dict[str, Any]] = {}
for download_id, entry in data.items():
if isinstance(download_id, str) and isinstance(entry, dict):
normalized[download_id] = entry
return normalized
def _write_all_unlocked(self, data: Dict[str, Dict[str, Any]]) -> None:
directory = os.path.dirname(self._state_path)
if directory:
os.makedirs(directory, exist_ok=True)
temp_path = f"{self._state_path}.tmp"
with open(temp_path, "w", encoding="utf-8") as handle:
json.dump(data, handle, ensure_ascii=True, indent=2, sort_keys=True)
os.replace(temp_path, self._state_path)
async def load_all(self) -> Dict[str, Dict[str, Any]]:
async with self._lock:
return deepcopy(self._read_all_unlocked())
async def get(self, download_id: str) -> Optional[Dict[str, Any]]:
async with self._lock:
return deepcopy(self._read_all_unlocked().get(download_id))
async def upsert(self, download_id: str, payload: Dict[str, Any]) -> Dict[str, Any]:
async with self._lock:
data = self._read_all_unlocked()
current = data.get(download_id, {})
current.update(payload)
data[download_id] = current
self._write_all_unlocked(data)
return deepcopy(current)
async def remove(self, download_id: str) -> None:
async with self._lock:
data = self._read_all_unlocked()
if download_id in data:
del data[download_id]
self._write_all_unlocked(data)
async def find_by_save_path(
self, save_path: str, *, exclude_download_id: Optional[str] = None
) -> Optional[Dict[str, Any]]:
normalized_target = os.path.abspath(save_path)
async with self._lock:
data = self._read_all_unlocked()
for download_id, entry in data.items():
if exclude_download_id and download_id == exclude_download_id:
continue
candidate = entry.get("save_path")
if isinstance(candidate, str) and os.path.abspath(candidate) == normalized_target:
result = dict(entry)
result["download_id"] = download_id
return result
return None
async def reassign(self, from_download_id: str, to_download_id: str) -> Optional[Dict[str, Any]]:
async with self._lock:
data = self._read_all_unlocked()
existing = data.get(from_download_id)
if existing is None:
return None
updated = dict(existing)
updated["download_id"] = to_download_id
data[to_download_id] = updated
if from_download_id != to_download_id:
data.pop(from_download_id, None)
self._write_all_unlocked(data)
return deepcopy(updated)

File diff suppressed because it is too large Load Diff

View File

@@ -55,6 +55,8 @@ DEFAULT_KEYS_CLEANUP_THRESHOLD = 10
DEFAULT_SETTINGS: Dict[str, Any] = {
"civitai_api_key": "",
"civitai_host": "civitai.com",
"download_backend": "python",
"aria2c_path": "",
"use_portable_settings": False,
"hash_chunk_size_mb": DEFAULT_HASH_CHUNK_SIZE_MB,
"language": "en",

View File

@@ -346,11 +346,13 @@
.api-key-input input {
width: 100%;
padding: 6px 40px 6px 10px; /* Add left padding */
height: 20px;
height: 32px;
box-sizing: border-box;
border-radius: var(--border-radius-xs);
border: 1px solid var(--border-color);
background-color: var(--lora-surface);
color: var(--text-color);
font-size: 0.95em;
}
.api-key-input .toggle-visibility {
@@ -379,7 +381,8 @@
.text-input-wrapper input {
width: 100%;
padding: 6px 10px;
height: 20px;
height: 32px;
box-sizing: border-box;
border-radius: var(--border-radius-xs);
border: 1px solid var(--border-color);
background-color: var(--lora-surface);
@@ -760,10 +763,12 @@
}
.setting-control {
width: 60%; /* Decreased slightly from 65% */
flex: 0 0 60%;
max-width: 60%;
margin-bottom: 0;
display: flex;
justify-content: flex-end; /* Right-align all controls */
min-width: 0;
}
/* Select Control Styles */
@@ -773,6 +778,13 @@
justify-content: flex-end;
}
.setting-control select,
.setting-control input[type="text"],
.setting-control input[type="password"],
.setting-control input[type="number"] {
font-size: 0.95em;
}
.select-control select {
width: 100%;
max-width: 100%; /* Increased from 200px */
@@ -781,8 +793,8 @@
border: 1px solid var(--border-color);
background-color: var(--lora-surface);
color: var(--text-color);
font-size: 0.95em;
height: 32px;
box-sizing: border-box;
}
/* Fix dark theme select dropdown text color */
@@ -888,8 +900,8 @@ input:checked + .toggle-slider:before {
border: 1px solid var(--border-color);
background-color: var(--lora-surface);
color: var(--text-color);
font-size: 0.95em;
height: 32px;
box-sizing: border-box;
}
/* Add warning text style for settings */

View File

@@ -807,6 +807,16 @@ export class SettingsManager {
civitaiHostSelect.value = state.global.settings.civitai_host || 'civitai.com';
}
const downloadBackendSelect = document.getElementById('downloadBackend');
if (downloadBackendSelect) {
downloadBackendSelect.value = state.global.settings.download_backend || 'python';
}
const aria2cPathInput = document.getElementById('aria2cPath');
if (aria2cPathInput) {
aria2cPathInput.value = state.global.settings.aria2c_path || '';
}
const recipesPathInput = document.getElementById('recipesPath');
if (recipesPathInput) {
recipesPathInput.value = state.global.settings.recipes_path || '';
@@ -950,9 +960,36 @@ export class SettingsManager {
languageSelect.value = currentLanguage;
}
this.loadDownloadBackendSettings();
this.loadProxySettings();
}
loadDownloadBackendSettings() {
const downloadBackendSelect = document.getElementById('downloadBackend');
const aria2PathSetting = document.getElementById('aria2PathSetting');
const updateVisibility = () => {
if (!aria2PathSetting || !downloadBackendSelect) {
return;
}
aria2PathSetting.style.display = downloadBackendSelect.value === 'aria2' ? 'block' : 'none';
};
if (downloadBackendSelect) {
downloadBackendSelect.value = state.global.settings.download_backend || 'python';
downloadBackendSelect.onchange = () => {
updateVisibility();
this.saveSelectSetting('downloadBackend', 'download_backend');
};
}
const aria2cPathInput = document.getElementById('aria2cPath');
if (aria2cPathInput) {
aria2cPathInput.value = state.global.settings.aria2c_path || '';
}
updateVisibility();
}
setupPriorityTagInputs() {
['lora', 'checkpoint', 'embedding'].forEach((modelType) => {
const textarea = document.getElementById(`${modelType}PriorityTagsInput`);

View File

@@ -6,6 +6,8 @@ import { DEFAULT_PATH_TEMPLATES, DEFAULT_PRIORITY_TAG_CONFIG } from '../utils/co
const DEFAULT_SETTINGS_BASE = Object.freeze({
civitai_api_key: '',
civitai_host: 'civitai.com',
download_backend: 'python',
aria2c_path: '',
use_portable_settings: false,
language: 'en',
show_only_sfw: false,

View File

@@ -129,6 +129,43 @@
</div>
</div>
<div class="settings-subsection">
<div class="settings-subsection-header">
<h4>{{ t('settings.sections.downloads') }}</h4>
</div>
<div class="setting-item">
<div class="setting-row">
<div class="setting-info">
<label for="downloadBackend">{{ t('settings.downloadBackend.label') }}</label>
<i class="fas fa-info-circle info-icon" data-tooltip="{{ t('settings.downloadBackend.help') }}"></i>
</div>
<div class="setting-control select-control">
<select id="downloadBackend" onchange="settingsManager.saveSelectSetting('downloadBackend', 'download_backend')">
<option value="python">{{ t('settings.downloadBackend.options.python') }}</option>
<option value="aria2">{{ t('settings.downloadBackend.options.aria2') }}</option>
</select>
</div>
</div>
</div>
<div class="setting-item" id="aria2PathSetting" style="display: none;">
<div class="setting-row">
<div class="setting-info">
<label for="aria2cPath">{{ t('settings.aria2cPath.label') }}</label>
<i class="fas fa-info-circle info-icon" data-tooltip="{{ t('settings.aria2cPath.help') }}"></i>
</div>
<div class="setting-control">
<div class="text-input-wrapper">
<input type="text"
id="aria2cPath"
placeholder="{{ t('settings.aria2cPath.placeholder') }}"
onblur="settingsManager.saveInputSetting('aria2cPath', 'aria2c_path')"
onkeydown="if(event.key === 'Enter') { this.blur(); }" />
</div>
</div>
</div>
</div>
</div>
<!-- Backup -->
<div class="settings-subsection">
<div class="settings-subsection-header">

View File

@@ -305,4 +305,39 @@ describe('SettingsManager library controls', () => {
'success',
);
});
it('loads download backend settings and toggles the aria2 path field', () => {
const manager = createManager();
document.body.innerHTML = `
<select id="downloadBackend">
<option value="python">Python</option>
<option value="aria2">aria2</option>
</select>
<div id="aria2PathSetting" style="display: none;"></div>
<input id="aria2cPath" />
`;
state.global.settings = {
download_backend: 'aria2',
aria2c_path: '/usr/bin/aria2c',
};
const saveSpy = vi.spyOn(manager, 'saveSelectSetting').mockResolvedValue();
manager.loadDownloadBackendSettings();
const backendSelect = document.getElementById('downloadBackend');
const aria2PathSetting = document.getElementById('aria2PathSetting');
const aria2cPath = document.getElementById('aria2cPath');
expect(backendSelect.value).toBe('aria2');
expect(aria2cPath.value).toBe('/usr/bin/aria2c');
expect(aria2PathSetting.style.display).toBe('block');
backendSelect.value = 'python';
backendSelect.onchange();
expect(aria2PathSetting.style.display).toBe('none');
expect(saveSpy).toHaveBeenCalledWith('downloadBackend', 'download_backend');
});
});

View File

@@ -0,0 +1,354 @@
from __future__ import annotations
import asyncio
from pathlib import Path
from unittest.mock import AsyncMock
import pytest
from py.services.aria2_downloader import Aria2Downloader, Aria2Error
from py.services.aria2_transfer_state import Aria2TransferStateStore
from py.services import aria2_transfer_state
@pytest.fixture(autouse=True)
def isolate_aria2_state(monkeypatch, tmp_path):
state_path = tmp_path / "cache" / "aria2" / "downloads.json"
monkeypatch.setattr(
aria2_transfer_state,
"get_aria2_state_path",
lambda: str(state_path),
)
@pytest.mark.asyncio
async def test_download_file_polls_until_complete(tmp_path, monkeypatch):
downloader = Aria2Downloader()
downloader._rpc_url = "http://127.0.0.1/jsonrpc"
downloader._rpc_secret = "secret"
save_path = tmp_path / "downloads" / "model.safetensors"
progress_events = []
rpc_calls = []
statuses = iter(
[
{
"gid": "gid-1",
"status": "active",
"completedLength": "5",
"totalLength": "10",
"downloadSpeed": "25",
},
{
"gid": "gid-1",
"status": "complete",
"completedLength": "10",
"totalLength": "10",
"downloadSpeed": "0",
"files": [{"path": str(save_path)}],
},
]
)
async def fake_rpc_call(method, params):
rpc_calls.append((method, params))
if method == "aria2.addUri":
return "gid-1"
if method == "aria2.tellStatus":
return next(statuses)
raise AssertionError(f"Unexpected RPC method: {method}")
monkeypatch.setattr(downloader, "_ensure_process", AsyncMock())
monkeypatch.setattr(
downloader,
"_resolve_authenticated_redirect_url",
AsyncMock(
return_value="https://signed.example.com/model.safetensors?token=abc"
),
)
monkeypatch.setattr(downloader, "_rpc_call", fake_rpc_call)
monkeypatch.setattr("py.services.aria2_downloader.asyncio.sleep", AsyncMock())
async def progress_callback(progress, snapshot=None):
progress_events.append(snapshot.percent_complete if snapshot else progress)
success, result = await downloader.download_file(
"https://civitai.com/api/download/models/123",
str(save_path),
download_id="download-1",
progress_callback=progress_callback,
headers={"Authorization": "Bearer token"},
)
assert success is True
assert result == str(save_path)
assert progress_events == [50.0, 100.0]
assert downloader._transfers == {}
assert rpc_calls[0][0] == "aria2.addUri"
assert rpc_calls[0][1][0] == [
"https://signed.example.com/model.safetensors?token=abc"
]
assert rpc_calls[0][1][1]["out"] == "model.safetensors"
assert "header" not in rpc_calls[0][1][1]
@pytest.mark.asyncio
async def test_transfer_state_store_shares_lock_and_preserves_concurrent_updates(tmp_path):
state_path = tmp_path / "cache" / "aria2" / "downloads.json"
store_a = Aria2TransferStateStore(str(state_path))
store_b = Aria2TransferStateStore(str(state_path))
assert store_a._lock is store_b._lock
await asyncio.gather(
store_a.upsert("download-1", {"status": "downloading", "gid": "gid-1"}),
store_b.upsert("download-2", {"status": "paused", "gid": "gid-2"}),
)
assert await store_a.get("download-1") == {"status": "downloading", "gid": "gid-1"}
assert await store_b.get("download-2") == {"status": "paused", "gid": "gid-2"}
@pytest.mark.asyncio
async def test_download_file_keeps_auth_headers_when_civitai_does_not_redirect(
tmp_path, monkeypatch
):
downloader = Aria2Downloader()
downloader._rpc_url = "http://127.0.0.1/jsonrpc"
downloader._rpc_secret = "secret"
save_path = tmp_path / "downloads" / "model.safetensors"
rpc_calls = []
statuses = iter(
[
{
"gid": "gid-1",
"status": "complete",
"completedLength": "10",
"totalLength": "10",
"downloadSpeed": "0",
"files": [{"path": str(save_path)}],
},
]
)
async def fake_rpc_call(method, params):
rpc_calls.append((method, params))
if method == "aria2.addUri":
return "gid-1"
if method == "aria2.tellStatus":
return next(statuses)
raise AssertionError(f"Unexpected RPC method: {method}")
monkeypatch.setattr(downloader, "_ensure_process", AsyncMock())
monkeypatch.setattr(
downloader,
"_resolve_authenticated_redirect_url",
AsyncMock(return_value="https://civitai.com/api/download/models/123"),
)
monkeypatch.setattr(downloader, "_rpc_call", fake_rpc_call)
monkeypatch.setattr("py.services.aria2_downloader.asyncio.sleep", AsyncMock())
success, result = await downloader.download_file(
"https://civitai.com/api/download/models/123",
str(save_path),
download_id="download-1",
headers={"Authorization": "Bearer token"},
)
assert success is True
assert result == str(save_path)
assert rpc_calls[0][1][0] == ["https://civitai.com/api/download/models/123"]
assert rpc_calls[0][1][1]["header"] == ["Authorization: Bearer token"]
@pytest.mark.asyncio
async def test_pause_resume_cancel_forward_to_rpc(monkeypatch):
downloader = Aria2Downloader()
downloader._transfers["download-1"] = type(
"Transfer", (), {"gid": "gid-1", "save_path": "/tmp/model.safetensors"}
)()
calls = []
async def fake_rpc_call(method, params):
calls.append((method, params))
return "gid-1"
monkeypatch.setattr(downloader, "_rpc_call", fake_rpc_call)
pause_result = await downloader.pause_download("download-1")
resume_result = await downloader.resume_download("download-1")
cancel_result = await downloader.cancel_download("download-1")
assert pause_result["success"] is True
assert resume_result["success"] is True
assert cancel_result["success"] is True
assert calls == [
("aria2.forcePause", ["gid-1"]),
("aria2.unpause", ["gid-1"]),
("aria2.forceRemove", ["gid-1"]),
]
@pytest.mark.asyncio
async def test_download_file_reuses_existing_transfer_without_add_uri(
tmp_path, monkeypatch
):
downloader = Aria2Downloader()
downloader._rpc_url = "http://127.0.0.1/jsonrpc"
downloader._rpc_secret = "secret"
save_path = tmp_path / "downloads" / "model.safetensors"
downloader._transfers["download-1"] = type(
"Transfer", (), {"gid": "gid-1", "save_path": str(save_path)}
)()
rpc_calls = []
statuses = iter(
[
{
"gid": "gid-1",
"status": "active",
"completedLength": "5",
"totalLength": "10",
"downloadSpeed": "25",
},
{
"gid": "gid-1",
"status": "complete",
"completedLength": "10",
"totalLength": "10",
"downloadSpeed": "0",
"files": [{"path": str(save_path)}],
},
]
)
async def fake_rpc_call(method, params):
rpc_calls.append((method, params))
if method == "aria2.tellStatus":
return next(statuses)
raise AssertionError(f"Unexpected RPC method: {method}")
monkeypatch.setattr(downloader, "_ensure_process", AsyncMock())
monkeypatch.setattr(downloader, "_rpc_call", fake_rpc_call)
monkeypatch.setattr("py.services.aria2_downloader.asyncio.sleep", AsyncMock())
success, result = await downloader.download_file(
"https://example.com/model.safetensors",
str(save_path),
download_id="download-1",
)
assert success is True
assert result == str(save_path)
assert [call[0] for call in rpc_calls] == ["aria2.tellStatus", "aria2.tellStatus"]
def test_build_progress_snapshot_normalizes_numeric_fields():
downloader = Aria2Downloader()
snapshot = downloader._build_progress_snapshot(
{
"completedLength": "75",
"totalLength": "100",
"downloadSpeed": "512",
}
)
assert snapshot.percent_complete == 75.0
assert snapshot.bytes_downloaded == 75
assert snapshot.total_bytes == 100
assert snapshot.bytes_per_second == 512.0
def test_resolve_executable_raises_when_binary_missing(monkeypatch):
downloader = Aria2Downloader()
settings = type("Settings", (), {"get": lambda self, key, default=None: ""})()
monkeypatch.setattr("py.services.aria2_downloader.get_settings_manager", lambda: settings)
monkeypatch.setattr("py.services.aria2_downloader.shutil.which", lambda _: None)
with pytest.raises(Aria2Error):
downloader._resolve_executable()
@pytest.mark.asyncio
async def test_rpc_call_surfaces_json_error_on_non_200(monkeypatch):
downloader = Aria2Downloader()
downloader._rpc_url = "http://127.0.0.1:6800/jsonrpc"
downloader._rpc_secret = "secret"
class FakeResponse:
status = 400
async def text(self):
return (
'{"jsonrpc":"2.0","id":"x","error":{"code":1,"message":"Unauthorized"}}'
)
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc, tb):
return False
class FakeSession:
def post(self, _url, json=None):
return FakeResponse()
monkeypatch.setattr(downloader, "_get_rpc_session", AsyncMock(return_value=FakeSession()))
with pytest.raises(Aria2Error) as exc_info:
await downloader._rpc_call("aria2.addUri", [["https://example.com/file"]])
assert "Unauthorized" in str(exc_info.value)
assert "aria2.addUri" in str(exc_info.value)
@pytest.mark.asyncio
async def test_resolve_authenticated_redirect_url_returns_location(monkeypatch):
downloader = Aria2Downloader()
class FakeResponse:
status = 307
headers = {"Location": "https://signed.example.com/file.safetensors"}
async def text(self):
return ""
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc, tb):
return False
class FakeSession:
def get(self, _url, headers=None, allow_redirects=False, proxy=None):
return FakeResponse()
class FakeDownloader:
default_headers = {"User-Agent": "ComfyUI-LoRA-Manager/1.0"}
proxy_url = None
@property
def session(self):
async def _session():
return FakeSession()
return _session()
fake_downloader = FakeDownloader()
monkeypatch.setattr(
"py.services.aria2_downloader.get_downloader",
AsyncMock(return_value=fake_downloader),
)
result = await downloader._resolve_authenticated_redirect_url(
"https://civitai.com/api/download/models/123",
{"Authorization": "Bearer token"},
)
assert result == "https://signed.example.com/file.safetensors"

View File

@@ -10,6 +10,7 @@ import pytest
from py.services.download_manager import DownloadManager
from py.services import download_manager
from py.services import aria2_transfer_state
from py.services.service_registry import ServiceRegistry
from py.services.settings_manager import SettingsManager, get_settings_manager
@@ -46,6 +47,16 @@ def isolate_settings(monkeypatch, tmp_path):
monkeypatch.setattr(SettingsManager, "_save_settings", lambda self: None)
@pytest.fixture(autouse=True)
def isolate_aria2_state(monkeypatch, tmp_path):
state_path = tmp_path / "cache" / "aria2" / "downloads.json"
monkeypatch.setattr(
aria2_transfer_state,
"get_aria2_state_path",
lambda: str(state_path),
)
@pytest.fixture(autouse=True)
def stub_metadata(monkeypatch):
class _StubMetadata:
@@ -179,6 +190,7 @@ async def test_successful_download_uses_defaults(
progress_callback,
model_type,
download_id,
transfer_backend=None,
):
captured.update(
{
@@ -268,6 +280,7 @@ async def test_download_uses_active_mirrors(
progress_callback,
model_type,
download_id,
transfer_backend=None,
):
captured["download_urls"] = download_urls
return {"success": True}
@@ -288,6 +301,644 @@ async def test_download_uses_active_mirrors(
assert captured["download_urls"] == ["https://mirror.example/file.safetensors"]
@pytest.mark.asyncio
async def test_pause_resume_cancel_delegate_to_aria2_backend(monkeypatch):
manager = DownloadManager()
task = asyncio.create_task(asyncio.sleep(60))
manager._download_tasks["download-1"] = task
manager._pause_events["download-1"] = download_manager.DownloadStreamControl()
manager._active_downloads["download-1"] = {
"transfer_backend": "aria2",
"status": "downloading",
}
class DummyAria2Downloader:
def __init__(self):
self.calls = []
async def pause_download(self, download_id):
self.calls.append(("pause", download_id))
return {"success": True, "message": "paused"}
async def resume_download(self, download_id):
self.calls.append(("resume", download_id))
return {"success": True, "message": "resumed"}
async def cancel_download(self, download_id):
self.calls.append(("cancel", download_id))
return {"success": True, "message": "cancelled"}
async def has_transfer(self, download_id):
self.calls.append(("has_transfer", download_id))
return True
dummy_aria2 = DummyAria2Downloader()
monkeypatch.setattr(
download_manager,
"get_aria2_downloader",
AsyncMock(return_value=dummy_aria2),
)
pause_result = await manager.pause_download("download-1")
assert pause_result["success"] is True
assert manager._active_downloads["download-1"]["status"] == "paused"
resume_result = await manager.resume_download("download-1")
assert resume_result["success"] is True
assert manager._active_downloads["download-1"]["status"] == "downloading"
cancel_result = await manager.cancel_download("download-1")
assert cancel_result["success"] is True
assert task.cancelled() or task.done()
assert dummy_aria2.calls == [
("has_transfer", "download-1"),
("pause", "download-1"),
("has_transfer", "download-1"),
("resume", "download-1"),
("cancel", "download-1"),
]
@pytest.mark.asyncio
async def test_cancel_allows_queued_aria2_task_without_transfer(monkeypatch):
manager = DownloadManager()
started = asyncio.Event()
async def blocked_task():
started.set()
await asyncio.sleep(60)
task = asyncio.create_task(blocked_task())
await started.wait()
manager._download_tasks["download-queued"] = task
manager._active_downloads["download-queued"] = {
"transfer_backend": "aria2",
"status": "queued",
}
class DummyAria2Downloader:
async def cancel_download(self, download_id):
return {"success": False, "error": "Download task not found"}
monkeypatch.setattr(
download_manager,
"get_aria2_downloader",
AsyncMock(return_value=DummyAria2Downloader()),
)
result = await manager.cancel_download("download-queued")
assert result["success"] is True
assert task.cancelled() or task.done()
@pytest.mark.asyncio
async def test_pause_resume_queued_aria2_task_without_transfer(monkeypatch):
manager = DownloadManager()
task = asyncio.create_task(asyncio.sleep(60))
manager._download_tasks["download-queued"] = task
manager._pause_events["download-queued"] = download_manager.DownloadStreamControl()
manager._active_downloads["download-queued"] = {
"transfer_backend": "aria2",
"status": "waiting",
"bytes_per_second": 12.0,
}
class DummyAria2Downloader:
def __init__(self):
self.calls = []
async def has_transfer(self, download_id):
self.calls.append(("has_transfer", download_id))
return False
async def pause_download(self, download_id):
self.calls.append(("pause", download_id))
return {"success": True, "message": "paused"}
async def resume_download(self, download_id):
self.calls.append(("resume", download_id))
return {"success": True, "message": "resumed"}
dummy_aria2 = DummyAria2Downloader()
monkeypatch.setattr(
download_manager,
"get_aria2_downloader",
AsyncMock(return_value=dummy_aria2),
)
pause_result = await manager.pause_download("download-queued")
assert pause_result == {"success": True, "message": "Download paused successfully"}
assert manager._active_downloads["download-queued"]["status"] == "paused"
assert manager._pause_events["download-queued"].is_paused() is True
resume_result = await manager.resume_download("download-queued")
assert resume_result == {"success": True, "message": "Download resumed successfully"}
assert manager._active_downloads["download-queued"]["status"] == "downloading"
assert manager._pause_events["download-queued"].is_set() is True
assert dummy_aria2.calls == [
("has_transfer", "download-queued"),
("has_transfer", "download-queued"),
]
task.cancel()
with pytest.raises(asyncio.CancelledError):
await task
@pytest.mark.asyncio
async def test_resume_download_restores_persisted_aria2_task(monkeypatch, tmp_path):
manager = DownloadManager()
save_dir = tmp_path / "downloads"
save_dir.mkdir()
save_path = save_dir / "file.safetensors"
save_path.write_text("partial")
(save_dir / "file.safetensors.aria2").write_text("control")
await manager._aria2_state_store.upsert(
"download-1",
{
"download_id": "download-1",
"transfer_backend": "aria2",
"status": "paused",
"save_dir": str(save_dir),
"relative_path": "",
"use_default_paths": False,
"save_path": str(save_path),
"file_path": str(save_path),
"model_id": 12,
"model_version_id": 34,
},
)
created = {}
async def fake_download_with_semaphore(
self,
task_id,
model_id,
model_version_id,
save_dir,
relative_path,
progress_callback=None,
use_default_paths=False,
source=None,
file_params=None,
):
created.update(
{
"task_id": task_id,
"model_id": model_id,
"model_version_id": model_version_id,
"save_dir": save_dir,
}
)
return {"success": True}
class DummyAria2Downloader:
def __init__(self):
self.calls = []
async def get_status_by_gid(self, gid):
return None
async def has_transfer(self, download_id):
self.calls.append(("has_transfer", download_id))
return False
async def resume_download(self, download_id):
self.calls.append(("resume", download_id))
return {"success": True, "message": "resumed"}
async def restore_transfer(self, download_id, gid, save_path):
self.calls.append(("restore_transfer", download_id, gid, save_path))
dummy_aria2 = DummyAria2Downloader()
monkeypatch.setattr(
download_manager, "_download_with_semaphore", None, raising=False
)
monkeypatch.setattr(
DownloadManager,
"_download_with_semaphore",
fake_download_with_semaphore,
)
monkeypatch.setattr(
download_manager,
"get_aria2_downloader",
AsyncMock(return_value=dummy_aria2),
)
result = await manager.resume_download("download-1")
await asyncio.sleep(0)
assert result == {"success": True, "message": "Download resumed successfully"}
assert created["task_id"] == "download-1"
assert created["model_version_id"] == 34
assert manager._active_downloads["download-1"]["status"] == "downloading"
assert manager._pause_events["download-1"].is_set() is True
@pytest.mark.asyncio
async def test_get_active_downloads_restores_persisted_aria2_entries(monkeypatch, tmp_path):
manager = DownloadManager()
save_dir = tmp_path / "downloads"
save_dir.mkdir()
save_path = save_dir / "file.safetensors"
save_path.write_text("partial")
(save_dir / "file.safetensors.aria2").write_text("control")
await manager._aria2_state_store.upsert(
"download-1",
{
"download_id": "download-1",
"transfer_backend": "aria2",
"status": "paused",
"save_path": str(save_path),
"file_path": str(save_path),
"model_id": 12,
"model_version_id": 34,
},
)
class DummyAria2Downloader:
async def get_status_by_gid(self, gid):
return None
monkeypatch.setattr(
download_manager,
"get_aria2_downloader",
AsyncMock(return_value=DummyAria2Downloader()),
)
downloads = await manager.get_active_downloads()
assert downloads["downloads"] == [
{
"download_id": "download-1",
"model_id": 12,
"model_version_id": 34,
"progress": 0,
"status": "paused",
"error": None,
"bytes_downloaded": 0,
"total_bytes": None,
"bytes_per_second": 0.0,
}
]
@pytest.mark.asyncio
async def test_get_active_downloads_restores_orphaned_aria2_partial_as_paused(
monkeypatch, tmp_path
):
manager = DownloadManager()
save_dir = tmp_path / "downloads"
save_dir.mkdir()
save_path = save_dir / "file.safetensors"
save_path.write_text("partial")
(save_dir / "file.safetensors.aria2").write_text("control")
await manager._aria2_state_store.upsert(
"download-1",
{
"download_id": "download-1",
"transfer_backend": "aria2",
"status": "downloading",
"save_path": str(save_path),
"file_path": str(save_path),
"model_id": 12,
"model_version_id": 34,
"gid": "missing-gid",
},
)
class DummyAria2Downloader:
async def get_status_by_gid(self, gid):
return None
monkeypatch.setattr(
download_manager,
"get_aria2_downloader",
AsyncMock(return_value=DummyAria2Downloader()),
)
downloads = await manager.get_active_downloads()
persisted = await manager._aria2_state_store.get("download-1")
assert downloads["downloads"] == [
{
"download_id": "download-1",
"model_id": 12,
"model_version_id": 34,
"progress": 0,
"status": "paused",
"error": None,
"bytes_downloaded": 0,
"total_bytes": None,
"bytes_per_second": 0.0,
}
]
assert manager._pause_events["download-1"].is_paused() is True
assert persisted["status"] == "paused"
@pytest.mark.asyncio
async def test_get_active_downloads_restarts_from_resume_context_for_active_restored_aria2(
monkeypatch, tmp_path
):
manager = DownloadManager()
save_dir = tmp_path / "downloads"
save_dir.mkdir()
save_path = save_dir / "file.safetensors"
save_path.write_text("partial")
await manager._aria2_state_store.upsert(
"download-1",
{
"download_id": "download-1",
"transfer_backend": "aria2",
"status": "downloading",
"save_path": str(save_path),
"file_path": str(save_path),
"model_id": 12,
"model_version_id": 34,
"gid": "gid-1",
"resume_context": {
"version_info": {
"id": 34,
"modelId": 12,
"model": {"id": 12, "type": "LoRA", "tags": ["fantasy"]},
"images": [],
},
"file_info": {
"name": "file.safetensors",
"type": "Model",
"primary": True,
"downloadUrl": "https://example.com/file.safetensors",
},
"model_type": "lora",
"relative_path": "",
"save_dir": str(save_dir),
"download_urls": ["https://example.com/file.safetensors"],
},
},
)
restarted = {}
class DummyAria2Downloader:
async def get_status_by_gid(self, gid):
return {"gid": gid, "status": "active"}
async def restore_transfer(self, download_id, gid, restored_path):
return None
monkeypatch.setattr(
download_manager,
"get_aria2_downloader",
AsyncMock(return_value=DummyAria2Downloader()),
)
async def fake_resume_restored_aria2_download(self, download_id, record):
restarted.update(
{
"download_id": download_id,
"model_id": record.get("model_id"),
"model_version_id": record.get("model_version_id"),
"save_dir": record.get("save_dir"),
"resume_context": record.get("resume_context"),
}
)
return {"success": True}
monkeypatch.setattr(
DownloadManager,
"_resume_restored_aria2_download",
fake_resume_restored_aria2_download,
)
execute_original = AsyncMock(side_effect=AssertionError("should not refetch metadata"))
monkeypatch.setattr(
DownloadManager,
"_execute_original_download",
execute_original,
)
downloads = await manager.get_active_downloads()
assert downloads["downloads"][0]["status"] == "downloading"
restarted_task = manager._download_tasks["download-1"]
await restarted_task
assert restarted["download_id"] == "download-1"
assert restarted["model_id"] == 12
assert restarted["model_version_id"] == 34
assert restarted["save_dir"] is None
assert restarted["resume_context"]["model_type"] == "lora"
assert execute_original.await_count == 0
@pytest.mark.asyncio
async def test_get_active_downloads_restores_persisted_aria2_without_initial_save_path(
monkeypatch, tmp_path
):
manager = DownloadManager()
save_dir = tmp_path / "downloads"
save_dir.mkdir()
save_path = save_dir / "file.safetensors"
save_path.write_text("partial")
(save_dir / "file.safetensors.aria2").write_text("control")
await manager._aria2_state_store.upsert(
"download-1",
{
"download_id": "download-1",
"transfer_backend": "aria2",
"status": "paused",
"model_id": 12,
"model_version_id": 34,
"resume_context": {
"version_info": {
"id": 34,
"modelId": 12,
"model": {"id": 12, "type": "LoRA"},
"images": [],
},
"file_info": {
"name": "file.safetensors",
"type": "Model",
"primary": True,
"downloadUrl": "https://example.com/file.safetensors",
},
"model_type": "lora",
"relative_path": "",
"save_dir": str(save_dir),
"download_urls": ["https://example.com/file.safetensors"],
},
},
)
class DummyAria2Downloader:
async def get_status_by_gid(self, gid):
return None
monkeypatch.setattr(
download_manager,
"get_aria2_downloader",
AsyncMock(return_value=DummyAria2Downloader()),
)
downloads = await manager.get_active_downloads()
persisted = await manager._aria2_state_store.get("download-1")
assert downloads["downloads"] == [
{
"download_id": "download-1",
"model_id": 12,
"model_version_id": 34,
"progress": 0,
"status": "paused",
"error": None,
"bytes_downloaded": 0,
"total_bytes": None,
"bytes_per_second": 0.0,
}
]
assert manager._active_downloads["download-1"]["file_path"] == str(save_path)
assert persisted["save_path"] == str(save_path)
assert persisted["file_path"] == str(save_path)
@pytest.mark.asyncio
async def test_resume_restored_aria2_download_updates_terminal_status_and_cleanup(monkeypatch):
manager = DownloadManager()
manager._active_downloads["download-1"] = {
"transfer_backend": "aria2",
"status": "paused",
"model_id": 12,
"model_version_id": 34,
"bytes_per_second": 10.0,
}
persist_state = AsyncMock()
cleanup_record = AsyncMock(return_value=None)
execute_download = AsyncMock(return_value={"success": True})
record_history = AsyncMock(return_value=None)
sync_version = AsyncMock(return_value=None)
monkeypatch.setattr(manager, "_persist_aria2_state", persist_state)
monkeypatch.setattr(manager, "_cleanup_download_record", cleanup_record)
monkeypatch.setattr(manager, "_execute_download", execute_download)
monkeypatch.setattr(manager, "_record_downloaded_version_history", record_history)
monkeypatch.setattr(manager, "_sync_downloaded_version", sync_version)
scheduled_tasks = []
original_create_task = asyncio.create_task
def tracking_create_task(coro):
task = original_create_task(coro)
scheduled_tasks.append(task)
return task
monkeypatch.setattr(download_manager.asyncio, "create_task", tracking_create_task)
result = await manager._resume_restored_aria2_download(
"download-1",
{
"download_id": "download-1",
"save_path": "/tmp/file.safetensors",
"file_path": "/tmp/file.safetensors",
"model_id": 12,
"model_version_id": 34,
"resume_context": {
"version_info": {
"id": 34,
"modelId": 12,
"model": {"id": 12},
"images": [],
},
"file_info": {
"name": "file.safetensors",
"downloadUrl": "https://example.com/file.safetensors",
},
"model_type": "lora",
"relative_path": "",
"save_dir": "/tmp",
"download_urls": ["https://example.com/file.safetensors"],
},
},
)
assert result == {"success": True}
assert manager._active_downloads["download-1"]["status"] == "completed"
assert manager._active_downloads["download-1"]["bytes_per_second"] == 0.0
assert persist_state.await_count == 2
assert len(scheduled_tasks) == 1
await asyncio.gather(*scheduled_tasks)
cleanup_record.assert_awaited_once_with("download-1")
@pytest.mark.asyncio
async def test_download_uses_captured_backend_when_settings_change(
monkeypatch, scanners, metadata_provider, tmp_path
):
manager = DownloadManager()
settings = get_settings_manager()
settings.settings["download_backend"] = "aria2"
semaphore = asyncio.Semaphore(0)
manager._download_semaphore = semaphore
captured = {}
async def fake_execute_original_download(
self,
model_id,
model_version_id,
save_dir,
relative_path,
progress_callback,
use_default_paths,
download_id=None,
transfer_backend="python",
source=None,
file_params=None,
):
captured["transfer_backend"] = transfer_backend
return {"success": True}
monkeypatch.setattr(
DownloadManager,
"_execute_original_download",
fake_execute_original_download,
)
download_task = asyncio.create_task(
manager.download_from_civitai(
model_version_id=99,
save_dir=str(tmp_path),
use_default_paths=True,
progress_callback=None,
source=None,
)
)
await asyncio.sleep(0)
assert len(manager._active_downloads) == 1
download_id = next(iter(manager._active_downloads))
assert manager._active_downloads[download_id]["transfer_backend"] == "aria2"
settings.settings["download_backend"] = "python"
semaphore.release()
result = await download_task
assert result["success"] is True
assert captured["transfer_backend"] == "aria2"
@pytest.mark.asyncio
async def test_download_aborts_when_version_exists(
monkeypatch, scanners, metadata_provider

File diff suppressed because it is too large Load Diff

View File

@@ -147,6 +147,11 @@ def test_environment_variable_overrides_settings(tmp_path, monkeypatch):
assert mgr.get("civitai_api_key") == "secret"
def test_default_download_backend_is_python(manager):
assert manager.get("download_backend") == "python"
assert manager.get("aria2c_path") == ""
def _create_manager_with_settings(
tmp_path, monkeypatch, initial_settings, *, save_spy=None
):