feat(download): add experimental aria2 backend

This commit is contained in:
Will Miao
2026-04-19 21:46:09 +08:00
parent 0ced53c059
commit 1c530ea013
21 changed files with 1867 additions and 28 deletions

View File

@@ -263,6 +263,19 @@
"red": "civitai.red (uneingeschränkt)" "red": "civitai.red (uneingeschränkt)"
} }
}, },
"downloadBackend": {
"label": "Download-Backend",
"help": "Wähle aus, wie Modelldateien heruntergeladen werden. Python verwendet den eingebauten Downloader. aria2 verwendet den experimentellen externen Downloader-Prozess.",
"options": {
"python": "Python (integriert)",
"aria2": "aria2 (experimentell)"
}
},
"aria2cPath": {
"label": "aria2c-Pfad",
"help": "Optionaler Pfad zur ausführbaren aria2c-Datei. Leer lassen, um aria2c aus dem System-PATH zu verwenden.",
"placeholder": "Leer lassen, um aria2c aus dem PATH zu verwenden"
},
"civitaiHostBanner": { "civitaiHostBanner": {
"title": "Civitai-Host-Einstellung verfügbar", "title": "Civitai-Host-Einstellung verfügbar",
"content": "Civitai verwendet jetzt civitai.com für SFW-Inhalte und civitai.red für uneingeschränkte Inhalte. In den Einstellungen können Sie ändern, welche Seite standardmäßig geöffnet wird.", "content": "Civitai verwendet jetzt civitai.com für SFW-Inhalte und civitai.red für uneingeschränkte Inhalte. In den Einstellungen können Sie ändern, welche Seite standardmäßig geöffnet wird.",
@@ -278,6 +291,7 @@
}, },
"sections": { "sections": {
"contentFiltering": "Inhaltsfilterung", "contentFiltering": "Inhaltsfilterung",
"downloads": "Downloads",
"videoSettings": "Video-Einstellungen", "videoSettings": "Video-Einstellungen",
"layoutSettings": "Layout-Einstellungen", "layoutSettings": "Layout-Einstellungen",
"misc": "Verschiedenes", "misc": "Verschiedenes",

View File

@@ -263,6 +263,19 @@
"red": "civitai.red (unrestricted)" "red": "civitai.red (unrestricted)"
} }
}, },
"downloadBackend": {
"label": "Download backend",
"help": "Choose how model files are downloaded. Python uses the built-in downloader. aria2 uses the experimental external downloader process.",
"options": {
"python": "Python (built-in)",
"aria2": "aria2 (experimental)"
}
},
"aria2cPath": {
"label": "aria2c path",
"help": "Optional path to the aria2c executable. Leave empty to use aria2c from your system PATH.",
"placeholder": "Leave empty to use aria2c from PATH"
},
"civitaiHostBanner": { "civitaiHostBanner": {
"title": "Civitai host preference available", "title": "Civitai host preference available",
"content": "Civitai now uses civitai.com for SFW content and civitai.red for unrestricted content. You can change which site opens by default in Settings.", "content": "Civitai now uses civitai.com for SFW content and civitai.red for unrestricted content. You can change which site opens by default in Settings.",
@@ -278,6 +291,7 @@
}, },
"sections": { "sections": {
"contentFiltering": "Content Filtering", "contentFiltering": "Content Filtering",
"downloads": "Downloads",
"videoSettings": "Video Settings", "videoSettings": "Video Settings",
"layoutSettings": "Layout Settings", "layoutSettings": "Layout Settings",
"misc": "Miscellaneous", "misc": "Miscellaneous",

View File

@@ -263,6 +263,19 @@
"red": "civitai.red (sin restricciones)" "red": "civitai.red (sin restricciones)"
} }
}, },
"downloadBackend": {
"label": "Backend de descarga",
"help": "Elige cómo se descargan los archivos del modelo. Python usa el descargador integrado. aria2 usa el proceso externo experimental de descarga.",
"options": {
"python": "Python (integrado)",
"aria2": "aria2 (experimental)"
}
},
"aria2cPath": {
"label": "Ruta de aria2c",
"help": "Ruta opcional al ejecutable aria2c. Déjalo vacío para usar aria2c desde el PATH del sistema.",
"placeholder": "Déjalo vacío para usar aria2c desde el PATH"
},
"civitaiHostBanner": { "civitaiHostBanner": {
"title": "Preferencia de host de Civitai disponible", "title": "Preferencia de host de Civitai disponible",
"content": "Civitai ahora usa civitai.com para contenido SFW y civitai.red para contenido sin restricciones. Puedes cambiar en Ajustes qué sitio se abre por defecto.", "content": "Civitai ahora usa civitai.com para contenido SFW y civitai.red para contenido sin restricciones. Puedes cambiar en Ajustes qué sitio se abre por defecto.",
@@ -278,6 +291,7 @@
}, },
"sections": { "sections": {
"contentFiltering": "Filtrado de contenido", "contentFiltering": "Filtrado de contenido",
"downloads": "Descargas",
"videoSettings": "Configuración de video", "videoSettings": "Configuración de video",
"layoutSettings": "Configuración de diseño", "layoutSettings": "Configuración de diseño",
"misc": "Varios", "misc": "Varios",

View File

@@ -263,6 +263,19 @@
"red": "civitai.red (sans restriction)" "red": "civitai.red (sans restriction)"
} }
}, },
"downloadBackend": {
"label": "Moteur de téléchargement",
"help": "Choisissez comment les fichiers de modèles sont téléchargés. Python utilise le téléchargeur intégré. aria2 utilise le processus externe expérimental de téléchargement.",
"options": {
"python": "Python (intégré)",
"aria2": "aria2 (expérimental)"
}
},
"aria2cPath": {
"label": "Chemin vers aria2c",
"help": "Chemin facultatif vers lexécutable aria2c. Laissez vide pour utiliser aria2c depuis le PATH système.",
"placeholder": "Laisser vide pour utiliser aria2c depuis le PATH"
},
"civitaiHostBanner": { "civitaiHostBanner": {
"title": "Préférence dhôte Civitai disponible", "title": "Préférence dhôte Civitai disponible",
"content": "Civitai utilise désormais civitai.com pour le contenu SFW et civitai.red pour le contenu sans restriction. Vous pouvez modifier dans les paramètres le site ouvert par défaut.", "content": "Civitai utilise désormais civitai.com pour le contenu SFW et civitai.red pour le contenu sans restriction. Vous pouvez modifier dans les paramètres le site ouvert par défaut.",
@@ -278,6 +291,7 @@
}, },
"sections": { "sections": {
"contentFiltering": "Filtrage du contenu", "contentFiltering": "Filtrage du contenu",
"downloads": "Téléchargements",
"videoSettings": "Paramètres vidéo", "videoSettings": "Paramètres vidéo",
"layoutSettings": "Paramètres d'affichage", "layoutSettings": "Paramètres d'affichage",
"misc": "Divers", "misc": "Divers",

View File

@@ -263,6 +263,19 @@
"red": "civitai.red (ללא הגבלות)" "red": "civitai.red (ללא הגבלות)"
} }
}, },
"downloadBackend": {
"label": "מנגנון הורדה",
"help": "בחר כיצד יורדים קבצי המודל. Python משתמש במוריד המובנה. aria2 משתמש בתהליך הורדה חיצוני ניסיוני.",
"options": {
"python": "Python (מובנה)",
"aria2": "aria2 (ניסיוני)"
}
},
"aria2cPath": {
"label": "נתיב aria2c",
"help": "נתיב אופציונלי לקובץ ההפעלה aria2c. השאר ריק כדי להשתמש ב-aria2c מתוך ה-PATH של המערכת.",
"placeholder": "השאר ריק כדי להשתמש ב-aria2c מתוך ה-PATH"
},
"civitaiHostBanner": { "civitaiHostBanner": {
"title": "העדפת מארח Civitai זמינה", "title": "העדפת מארח Civitai זמינה",
"content": "Civitai משתמש כעת ב-civitai.com עבור תוכן SFW וב-civitai.red עבור תוכן ללא הגבלות. ניתן לשנות בהגדרות איזה אתר ייפתח כברירת מחדל.", "content": "Civitai משתמש כעת ב-civitai.com עבור תוכן SFW וב-civitai.red עבור תוכן ללא הגבלות. ניתן לשנות בהגדרות איזה אתר ייפתח כברירת מחדל.",
@@ -278,6 +291,7 @@
}, },
"sections": { "sections": {
"contentFiltering": "סינון תוכן", "contentFiltering": "סינון תוכן",
"downloads": "הורדות",
"videoSettings": "הגדרות וידאו", "videoSettings": "הגדרות וידאו",
"layoutSettings": "הגדרות פריסה", "layoutSettings": "הגדרות פריסה",
"misc": "שונות", "misc": "שונות",

View File

@@ -263,6 +263,19 @@
"red": "civitai.red制限なし" "red": "civitai.red制限なし"
} }
}, },
"downloadBackend": {
"label": "ダウンロードバックエンド",
"help": "モデルファイルのダウンロード方法を選択します。Python は内蔵ダウンローダーを使用し、aria2 は実験的な外部ダウンローダープロセスを使用します。",
"options": {
"python": "Python内蔵",
"aria2": "aria2実験的"
}
},
"aria2cPath": {
"label": "aria2c のパス",
"help": "aria2c 実行ファイルへの任意のパスです。空欄のままにすると、システム PATH 上の aria2c を使用します。",
"placeholder": "空欄のままにすると PATH 上の aria2c を使用します"
},
"civitaiHostBanner": { "civitaiHostBanner": {
"title": "Civitai ホスト設定を利用できます", "title": "Civitai ホスト設定を利用できます",
"content": "Civitai は現在、SFW コンテンツには civitai.com、制限なしコンテンツには civitai.red を使用しています。設定で既定で開くサイトを変更できます。", "content": "Civitai は現在、SFW コンテンツには civitai.com、制限なしコンテンツには civitai.red を使用しています。設定で既定で開くサイトを変更できます。",
@@ -278,6 +291,7 @@
}, },
"sections": { "sections": {
"contentFiltering": "コンテンツフィルタリング", "contentFiltering": "コンテンツフィルタリング",
"downloads": "ダウンロード",
"videoSettings": "動画設定", "videoSettings": "動画設定",
"layoutSettings": "レイアウト設定", "layoutSettings": "レイアウト設定",
"misc": "その他", "misc": "その他",

View File

@@ -263,6 +263,19 @@
"red": "civitai.red(무제한)" "red": "civitai.red(무제한)"
} }
}, },
"downloadBackend": {
"label": "다운로드 백엔드",
"help": "모델 파일을 다운로드하는 방식을 선택합니다. Python은 내장 다운로더를 사용하고, aria2는 실험적인 외부 다운로더 프로세스를 사용합니다.",
"options": {
"python": "Python(내장)",
"aria2": "aria2(실험적)"
}
},
"aria2cPath": {
"label": "aria2c 경로",
"help": "aria2c 실행 파일의 선택적 경로입니다. 비워 두면 시스템 PATH의 aria2c를 사용합니다.",
"placeholder": "비워 두면 PATH의 aria2c를 사용합니다"
},
"civitaiHostBanner": { "civitaiHostBanner": {
"title": "Civitai 호스트 기본 설정 사용 가능", "title": "Civitai 호스트 기본 설정 사용 가능",
"content": "이제 Civitai는 SFW 콘텐츠에 civitai.com을, 무제한 콘텐츠에 civitai.red를 사용합니다. 설정에서 기본으로 열 사이트를 변경할 수 있습니다.", "content": "이제 Civitai는 SFW 콘텐츠에 civitai.com을, 무제한 콘텐츠에 civitai.red를 사용합니다. 설정에서 기본으로 열 사이트를 변경할 수 있습니다.",
@@ -278,6 +291,7 @@
}, },
"sections": { "sections": {
"contentFiltering": "콘텐츠 필터링", "contentFiltering": "콘텐츠 필터링",
"downloads": "다운로드",
"videoSettings": "비디오 설정", "videoSettings": "비디오 설정",
"layoutSettings": "레이아웃 설정", "layoutSettings": "레이아웃 설정",
"misc": "기타", "misc": "기타",

View File

@@ -263,6 +263,19 @@
"red": "civitai.red (без ограничений)" "red": "civitai.red (без ограничений)"
} }
}, },
"downloadBackend": {
"label": "Бэкенд загрузки",
"help": "Выберите способ загрузки файлов моделей. Python использует встроенный загрузчик. aria2 использует экспериментальный внешний процесс загрузки.",
"options": {
"python": "Python (встроенный)",
"aria2": "aria2 (экспериментальный)"
}
},
"aria2cPath": {
"label": "Путь к aria2c",
"help": "Необязательный путь к исполняемому файлу aria2c. Оставьте пустым, чтобы использовать aria2c из системного PATH.",
"placeholder": "Оставьте пустым, чтобы использовать aria2c из PATH"
},
"civitaiHostBanner": { "civitaiHostBanner": {
"title": "Доступна настройка хоста Civitai", "title": "Доступна настройка хоста Civitai",
"content": "Теперь Civitai использует civitai.com для контента SFW и civitai.red для контента без ограничений. В настройках можно изменить, какой сайт открывать по умолчанию.", "content": "Теперь Civitai использует civitai.com для контента SFW и civitai.red для контента без ограничений. В настройках можно изменить, какой сайт открывать по умолчанию.",
@@ -278,6 +291,7 @@
}, },
"sections": { "sections": {
"contentFiltering": "Фильтрация контента", "contentFiltering": "Фильтрация контента",
"downloads": "Загрузки",
"videoSettings": "Настройки видео", "videoSettings": "Настройки видео",
"layoutSettings": "Настройки макета", "layoutSettings": "Настройки макета",
"misc": "Разное", "misc": "Разное",

View File

@@ -263,6 +263,19 @@
"red": "civitai.red无限制" "red": "civitai.red无限制"
} }
}, },
"downloadBackend": {
"label": "下载后端",
"help": "选择模型文件的下载方式。Python 使用内置下载器。aria2 使用实验性的外部下载进程。",
"options": {
"python": "Python内置",
"aria2": "aria2实验性"
}
},
"aria2cPath": {
"label": "aria2c 路径",
"help": "可选的 aria2c 可执行文件路径。留空则使用系统 PATH 中的 aria2c。",
"placeholder": "留空则使用 PATH 中的 aria2c"
},
"civitaiHostBanner": { "civitaiHostBanner": {
"title": "已提供 Civitai 站点偏好设置", "title": "已提供 Civitai 站点偏好设置",
"content": "Civitai 现在使用 civitai.com 提供 SFW 内容,使用 civitai.red 提供无限制内容。你可以在设置中更改默认打开的站点。", "content": "Civitai 现在使用 civitai.com 提供 SFW 内容,使用 civitai.red 提供无限制内容。你可以在设置中更改默认打开的站点。",
@@ -278,6 +291,7 @@
}, },
"sections": { "sections": {
"contentFiltering": "内容过滤", "contentFiltering": "内容过滤",
"downloads": "下载",
"videoSettings": "视频设置", "videoSettings": "视频设置",
"layoutSettings": "布局设置", "layoutSettings": "布局设置",
"misc": "其他", "misc": "其他",

View File

@@ -263,6 +263,19 @@
"red": "civitai.red無限制" "red": "civitai.red無限制"
} }
}, },
"downloadBackend": {
"label": "下載後端",
"help": "選擇模型檔案的下載方式。Python 使用內建下載器。aria2 使用實驗性的外部下載程序。",
"options": {
"python": "Python內建",
"aria2": "aria2實驗性"
}
},
"aria2cPath": {
"label": "aria2c 路徑",
"help": "可選的 aria2c 可執行檔路徑。留空則使用系統 PATH 中的 aria2c。",
"placeholder": "留空則使用 PATH 中的 aria2c"
},
"civitaiHostBanner": { "civitaiHostBanner": {
"title": "已提供 Civitai 站點偏好設定", "title": "已提供 Civitai 站點偏好設定",
"content": "Civitai 現在使用 civitai.com 提供 SFW 內容,使用 civitai.red 提供無限制內容。你可以在設定中變更預設開啟的站點。", "content": "Civitai 現在使用 civitai.com 提供 SFW 內容,使用 civitai.red 提供無限制內容。你可以在設定中變更預設開啟的站點。",
@@ -278,6 +291,7 @@
}, },
"sections": { "sections": {
"contentFiltering": "內容過濾", "contentFiltering": "內容過濾",
"downloads": "下載",
"videoSettings": "影片設定", "videoSettings": "影片設定",
"layoutSettings": "版面設定", "layoutSettings": "版面設定",
"misc": "其他", "misc": "其他",

View File

@@ -0,0 +1,497 @@
from __future__ import annotations
import asyncio
import json
import logging
import os
import secrets
import shutil
import socket
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, Optional, Tuple
import aiohttp
from .downloader import DownloadProgress, get_downloader
from .settings_manager import get_settings_manager
logger = logging.getLogger(__name__)
CIVITAI_DOWNLOAD_URL_PREFIXES = (
"https://civitai.com/api/download/",
"https://civitai.red/api/download/",
)
class Aria2Error(RuntimeError):
"""Raised when aria2 integration fails."""
@dataclass
class Aria2Transfer:
"""Track an aria2 download registered by the Python coordinator."""
gid: str
save_path: str
class Aria2Downloader:
"""Manage an aria2 RPC daemon for experimental model downloads."""
_instance = None
_lock = asyncio.Lock()
@classmethod
async def get_instance(cls) -> "Aria2Downloader":
async with cls._lock:
if cls._instance is None:
cls._instance = cls()
return cls._instance
def __init__(self) -> None:
if hasattr(self, "_initialized"):
return
self._initialized = True
self._process: Optional[asyncio.subprocess.Process] = None
self._rpc_port: Optional[int] = None
self._rpc_secret = ""
self._rpc_url = ""
self._rpc_session: Optional[aiohttp.ClientSession] = None
self._rpc_session_lock = asyncio.Lock()
self._process_lock = asyncio.Lock()
self._transfers: Dict[str, Aria2Transfer] = {}
self._poll_interval = 0.5
@property
def is_running(self) -> bool:
return self._process is not None and self._process.returncode is None
async def download_file(
self,
url: str,
save_path: str,
*,
download_id: str,
progress_callback=None,
headers: Optional[Dict[str, str]] = None,
) -> Tuple[bool, str]:
"""Download a file using aria2 RPC and wait for completion."""
await self._ensure_process()
save_path = os.path.abspath(save_path)
save_dir = os.path.dirname(save_path)
out_name = os.path.basename(save_path)
Path(save_dir).mkdir(parents=True, exist_ok=True)
resolved_url = url
request_headers = headers
if headers and url.startswith(CIVITAI_DOWNLOAD_URL_PREFIXES):
resolved_url = await self._resolve_authenticated_redirect_url(url, headers)
if resolved_url != url:
request_headers = None
logger.debug(
"Resolved Civitai download %s to signed URL for aria2",
download_id,
)
options: Dict[str, str] = {
"dir": save_dir,
"out": out_name,
"continue": "true",
"max-connection-per-server": "4",
"split": "4",
"min-split-size": "1M",
"allow-overwrite": "true",
"auto-file-renaming": "false",
"file-allocation": "none",
}
if request_headers:
options["header"] = [
f"{key}: {value}" for key, value in request_headers.items()
]
logger.debug(
"Submitting aria2 download %s -> %s (auth=%s, civitai_signed=%s)",
download_id,
save_path,
bool(request_headers),
resolved_url != url,
)
try:
gid = await self._rpc_call("aria2.addUri", [[resolved_url], options])
except Exception as exc:
raise Aria2Error(f"Failed to schedule aria2 download: {exc}") from exc
logger.debug("aria2 accepted download %s with gid %s", download_id, gid)
self._transfers[download_id] = Aria2Transfer(gid=gid, save_path=save_path)
try:
while True:
status = await self.get_status(download_id)
if status is None:
return False, "aria2 download not found"
snapshot = self._build_progress_snapshot(status)
if progress_callback is not None:
await self._dispatch_progress(progress_callback, snapshot)
state = status.get("status", "")
if state == "complete":
completed_path = self._resolve_completed_path(status, save_path)
return True, completed_path
if state == "error":
return False, status.get("errorMessage") or "aria2 download failed"
if state == "removed":
return False, "Download was cancelled"
await asyncio.sleep(self._poll_interval)
finally:
self._transfers.pop(download_id, None)
async def get_status(self, download_id: str) -> Optional[Dict[str, Any]]:
"""Return the raw aria2 status payload for a known download."""
transfer = self._transfers.get(download_id)
if transfer is None:
return None
keys = [
"gid",
"status",
"totalLength",
"completedLength",
"downloadSpeed",
"errorMessage",
"files",
]
try:
status = await self._rpc_call("aria2.tellStatus", [transfer.gid, keys])
except Exception as exc:
raise Aria2Error(f"Failed to query aria2 download status: {exc}") from exc
if isinstance(status, dict):
return status
return None
async def has_transfer(self, download_id: str) -> bool:
return download_id in self._transfers
async def pause_download(self, download_id: str) -> Dict[str, Any]:
transfer = self._transfers.get(download_id)
if transfer is None:
return {"success": False, "error": "Download task not found"}
try:
await self._rpc_call("aria2.forcePause", [transfer.gid])
except Exception as exc:
return {"success": False, "error": str(exc)}
return {"success": True, "message": "Download paused successfully"}
async def resume_download(self, download_id: str) -> Dict[str, Any]:
transfer = self._transfers.get(download_id)
if transfer is None:
return {"success": False, "error": "Download task not found"}
try:
await self._rpc_call("aria2.unpause", [transfer.gid])
except Exception as exc:
return {"success": False, "error": str(exc)}
return {"success": True, "message": "Download resumed successfully"}
async def cancel_download(self, download_id: str) -> Dict[str, Any]:
transfer = self._transfers.get(download_id)
if transfer is None:
return {"success": False, "error": "Download task not found"}
try:
await self._rpc_call("aria2.forceRemove", [transfer.gid])
except Exception as exc:
return {"success": False, "error": str(exc)}
return {"success": True, "message": "Download cancelled successfully"}
async def close(self) -> None:
"""Shut down the RPC process and session."""
if self._rpc_session is not None:
await self._rpc_session.close()
self._rpc_session = None
process = self._process
self._process = None
self._transfers.clear()
if process is None:
return
if process.returncode is None:
process.terminate()
try:
await asyncio.wait_for(process.wait(), timeout=5.0)
except asyncio.TimeoutError:
process.kill()
await process.wait()
async def _dispatch_progress(self, callback, snapshot: DownloadProgress) -> None:
try:
result = callback(snapshot, snapshot)
except TypeError:
result = callback(snapshot.percent_complete)
if asyncio.iscoroutine(result):
await result
elif hasattr(result, "__await__"):
await result
def _build_progress_snapshot(self, status: Dict[str, Any]) -> DownloadProgress:
completed = self._parse_int(status.get("completedLength"))
total = self._parse_int(status.get("totalLength"))
speed = float(self._parse_int(status.get("downloadSpeed")))
percent = 0.0
if total > 0:
percent = (completed / total) * 100.0
return DownloadProgress(
percent_complete=max(0.0, min(percent, 100.0)),
bytes_downloaded=completed,
total_bytes=total or None,
bytes_per_second=speed,
timestamp=datetime.now().timestamp(),
)
def _resolve_completed_path(self, status: Dict[str, Any], default_path: str) -> str:
files = status.get("files")
if isinstance(files, list) and files:
first = files[0]
if isinstance(first, dict):
candidate = first.get("path")
if isinstance(candidate, str) and candidate:
return candidate
return default_path
@staticmethod
def _parse_int(value: Any) -> int:
try:
return int(value)
except (TypeError, ValueError):
return 0
async def _resolve_authenticated_redirect_url(
self,
url: str,
headers: Dict[str, str],
) -> str:
downloader = await get_downloader()
session = await downloader.session
request_headers = dict(downloader.default_headers)
request_headers.update(headers)
request_headers["Accept-Encoding"] = "identity"
try:
async with session.get(
url,
headers=request_headers,
allow_redirects=False,
proxy=downloader.proxy_url,
) as response:
if response.status in {301, 302, 303, 307, 308}:
location = response.headers.get("Location")
if location:
return location
raise Aria2Error(
"Authenticated Civitai redirect did not include a Location header"
)
if response.status == 200:
return url
body = await response.text()
raise Aria2Error(
f"Failed to resolve authenticated Civitai redirect: status={response.status} body={body[:300]}"
)
except aiohttp.ClientError as exc:
raise Aria2Error(
f"Failed to resolve authenticated Civitai redirect: {exc}"
) from exc
async def _ensure_process(self) -> None:
async with self._process_lock:
if self.is_running and await self._ping():
return
await self.close()
executable = self._resolve_executable()
self._rpc_port = self._find_free_port()
self._rpc_secret = secrets.token_hex(16)
self._rpc_url = f"http://127.0.0.1:{self._rpc_port}/jsonrpc"
command = [
executable,
"--enable-rpc=true",
"--rpc-listen-all=false",
f"--rpc-listen-port={self._rpc_port}",
f"--rpc-secret={self._rpc_secret}",
"--check-certificate=true",
"--allow-overwrite=true",
"--auto-file-renaming=false",
"--file-allocation=none",
"--max-concurrent-downloads=5",
"--continue=true",
"--daemon=false",
"--quiet=true",
f"--stop-with-process={os.getpid()}",
]
logger.info("Starting aria2 RPC daemon from %s", executable)
self._process = await asyncio.create_subprocess_exec(
*command,
stdout=asyncio.subprocess.DEVNULL,
stderr=asyncio.subprocess.PIPE,
)
await self._wait_until_ready()
def _resolve_executable(self) -> str:
settings = get_settings_manager()
configured_path = (settings.get("aria2c_path") or "").strip()
candidate = configured_path or "aria2c"
resolved = shutil.which(candidate)
if resolved:
return resolved
if configured_path and os.path.isfile(configured_path) and os.access(
configured_path, os.X_OK
):
return configured_path
raise Aria2Error(
"aria2c executable was not found. Install aria2 or configure aria2c_path."
)
async def _wait_until_ready(self) -> None:
assert self._process is not None
start_time = asyncio.get_running_loop().time()
last_error = ""
while asyncio.get_running_loop().time() - start_time < 10.0:
if self._process.returncode is not None:
stderr_output = ""
if self._process.stderr is not None:
try:
stderr_output = (
await asyncio.wait_for(self._process.stderr.read(), timeout=0.2)
).decode("utf-8", errors="replace")
except Exception:
stderr_output = ""
raise Aria2Error(
f"aria2 RPC process exited early with code {self._process.returncode}: {stderr_output.strip()}"
)
try:
if await self._ping():
return
except Exception as exc: # pragma: no cover - startup race
last_error = str(exc)
await asyncio.sleep(0.2)
raise Aria2Error(
f"Timed out waiting for aria2 RPC to become ready{': ' + last_error if last_error else ''}"
)
async def _ping(self) -> bool:
try:
result = await self._rpc_call("aria2.getVersion", [])
except Exception:
return False
return isinstance(result, dict)
async def _rpc_call(self, method: str, params: list[Any]) -> Any:
if not self._rpc_url:
raise Aria2Error("aria2 RPC endpoint is not initialized")
session = await self._get_rpc_session()
payload = {
"jsonrpc": "2.0",
"id": secrets.token_hex(8),
"method": method,
"params": [f"token:{self._rpc_secret}", *params],
}
async with session.post(self._rpc_url, json=payload) as response:
text = await response.text()
try:
body = json.loads(text)
except json.JSONDecodeError:
body = None
if body is None:
if response.status != 200:
raise Aria2Error(
f"aria2 RPC returned status {response.status} with non-JSON body: {text}"
)
raise Aria2Error(f"Invalid aria2 RPC response: {text}")
if "error" in body:
error = body["error"] or {}
code = error.get("code") if isinstance(error, dict) else None
message = error.get("message") if isinstance(error, dict) else str(error)
logger.error(
"aria2 RPC %s failed with HTTP %s, code=%s, message=%s",
method,
response.status,
code,
message,
)
status_message = (
f"aria2 RPC {method} failed with status {response.status}: {message}"
if response.status != 200
else message
)
raise Aria2Error(status_message or "Unknown aria2 RPC error")
if response.status != 200:
logger.error(
"aria2 RPC %s returned unexpected HTTP status %s without error payload: %s",
method,
response.status,
body,
)
raise Aria2Error(
f"aria2 RPC {method} returned unexpected status {response.status}"
)
return body.get("result")
async def _get_rpc_session(self) -> aiohttp.ClientSession:
if self._rpc_session is None or self._rpc_session.closed:
async with self._rpc_session_lock:
if self._rpc_session is None or self._rpc_session.closed:
timeout = aiohttp.ClientTimeout(total=30)
self._rpc_session = aiohttp.ClientSession(timeout=timeout)
return self._rpc_session
@staticmethod
def _find_free_port() -> int:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
sock.bind(("127.0.0.1", 0))
sock.listen(1)
return int(sock.getsockname()[1])
async def get_aria2_downloader() -> Aria2Downloader:
"""Get the singleton aria2 downloader."""
return await Aria2Downloader.get_instance()

View File

@@ -5,6 +5,7 @@ import asyncio
import inspect import inspect
import shutil import shutil
import zipfile import zipfile
from concurrent.futures import ThreadPoolExecutor
from collections import OrderedDict from collections import OrderedDict
import uuid import uuid
from typing import Dict, List, Optional, Set, Tuple from typing import Dict, List, Optional, Set, Tuple
@@ -25,6 +26,7 @@ from .service_registry import ServiceRegistry
from .settings_manager import get_settings_manager from .settings_manager import get_settings_manager
from .metadata_service import get_default_metadata_provider, get_metadata_provider from .metadata_service import get_default_metadata_provider, get_metadata_provider
from .downloader import get_downloader, DownloadProgress, DownloadStreamControl from .downloader import get_downloader, DownloadProgress, DownloadStreamControl
from .aria2_downloader import Aria2Error, get_aria2_downloader
# Download to temporary file first # Download to temporary file first
import tempfile import tempfile
@@ -60,6 +62,59 @@ class DownloadManager:
self._download_semaphore = asyncio.Semaphore(5) # Limit concurrent downloads self._download_semaphore = asyncio.Semaphore(5) # Limit concurrent downloads
self._download_tasks = {} # download_id -> asyncio.Task self._download_tasks = {} # download_id -> asyncio.Task
self._pause_events: Dict[str, DownloadStreamControl] = {} self._pause_events: Dict[str, DownloadStreamControl] = {}
self._archive_executor = ThreadPoolExecutor(
max_workers=2, thread_name_prefix="lm-archive"
)
@staticmethod
def _get_model_download_backend() -> str:
backend = (get_settings_manager().get("download_backend") or "python").strip()
return backend.lower() or "python"
async def _download_model_file(
self,
download_url: str,
save_path: str,
*,
backend: str,
progress_callback,
use_auth: bool,
download_id: Optional[str],
pause_control: Optional[DownloadStreamControl],
) -> Tuple[bool, str]:
if backend == "aria2":
if not download_id:
return False, "aria2 downloads require a tracked download_id"
headers: Dict[str, str] = {}
if use_auth:
api_key = (get_settings_manager().get("civitai_api_key") or "").strip()
if api_key:
headers["Authorization"] = f"Bearer {api_key}"
try:
aria2_downloader = await get_aria2_downloader()
return await aria2_downloader.download_file(
download_url,
save_path,
download_id=download_id,
progress_callback=progress_callback,
headers=headers or None,
)
except Aria2Error as exc:
logger.error("aria2 download failed for %s: %s", download_url, exc)
return False, str(exc)
download_kwargs = {
"progress_callback": progress_callback,
"use_auth": use_auth,
}
if pause_control is not None:
download_kwargs["pause_event"] = pause_control
downloader = await get_downloader()
return await downloader.download_file(download_url, save_path, **download_kwargs)
async def _get_lora_scanner(self): async def _get_lora_scanner(self):
"""Get the lora scanner from registry""" """Get the lora scanner from registry"""
@@ -126,6 +181,7 @@ class DownloadManager:
"model_version_id": model_version_id, "model_version_id": model_version_id,
"progress": 0, "progress": 0,
"status": "queued", "status": "queued",
"transfer_backend": self._get_model_download_backend(),
"bytes_downloaded": 0, "bytes_downloaded": 0,
"total_bytes": None, "total_bytes": None,
"bytes_per_second": 0.0, "bytes_per_second": 0.0,
@@ -240,6 +296,9 @@ class DownloadManager:
tracking_callback, tracking_callback,
use_default_paths, use_default_paths,
task_id, task_id,
self._active_downloads.get(task_id, {}).get(
"transfer_backend", "python"
),
source, source,
file_params, file_params,
) )
@@ -294,6 +353,7 @@ class DownloadManager:
progress_callback, progress_callback,
use_default_paths, use_default_paths,
download_id=None, download_id=None,
transfer_backend="python",
source=None, source=None,
file_params=None, file_params=None,
): ):
@@ -696,16 +756,27 @@ class DownloadManager:
logger.info(f"Creating EmbeddingMetadata for {file_name}") logger.info(f"Creating EmbeddingMetadata for {file_name}")
# 6. Start download process # 6. Start download process
result = await self._execute_download( execute_kwargs = {
download_urls=download_urls, "download_urls": download_urls,
save_dir=save_dir, "save_dir": save_dir,
metadata=metadata, "metadata": metadata,
version_info=version_info, "version_info": version_info,
relative_path=relative_path, "relative_path": relative_path,
progress_callback=progress_callback, "progress_callback": progress_callback,
model_type=model_type, "model_type": model_type,
download_id=download_id, "download_id": download_id,
) }
execute_signature = inspect.signature(self._execute_download)
if (
"transfer_backend" in execute_signature.parameters
or any(
parameter.kind == inspect.Parameter.VAR_KEYWORD
for parameter in execute_signature.parameters.values()
)
):
execute_kwargs["transfer_backend"] = transfer_backend
result = await self._execute_download(**execute_kwargs)
if result.get("success", False): if result.get("success", False):
resolved_model_id = ( resolved_model_id = (
@@ -965,6 +1036,7 @@ class DownloadManager:
progress_callback=None, progress_callback=None,
model_type: str = "lora", model_type: str = "lora",
download_id: str = None, download_id: str = None,
transfer_backend: Optional[str] = None,
) -> Dict: ) -> Dict:
"""Execute the actual download process including preview images and model files""" """Execute the actual download process including preview images and model files"""
metadata_entries: List = [] metadata_entries: List = []
@@ -974,6 +1046,7 @@ class DownloadManager:
preview_targets: List[str] = [] preview_targets: List[str] = []
preview_path: str | None = None preview_path: str | None = None
preview_nsfw_level = 0 preview_nsfw_level = 0
transfer_backend = (transfer_backend or self._get_model_download_backend()).lower()
try: try:
# Extract original filename details # Extract original filename details
original_filename = os.path.basename(metadata.file_path) original_filename = os.path.basename(metadata.file_path)
@@ -1136,32 +1209,37 @@ class DownloadManager:
if progress_callback: if progress_callback:
await progress_callback(3) # 3% progress after preview download await progress_callback(3) # 3% progress after preview download
# Download model file with progress tracking using downloader # Download model file with progress tracking using the configured backend
downloader = await get_downloader() downloader = None
if pause_control is not None: if transfer_backend == "python":
pause_control.update_stall_timeout(downloader.stall_timeout) downloader = await get_downloader()
if pause_control is not None:
pause_control.update_stall_timeout(downloader.stall_timeout)
if pause_control is not None and pause_control.is_paused():
if download_id and download_id in self._active_downloads:
self._active_downloads[download_id]["status"] = "paused"
self._active_downloads[download_id]["bytes_per_second"] = 0.0
await pause_control.wait()
if download_id and download_id in self._active_downloads:
self._active_downloads[download_id]["status"] = "downloading"
last_error = None last_error = None
for download_url in download_urls: for download_url in download_urls:
download_url = normalize_civitai_download_url(download_url) download_url = normalize_civitai_download_url(download_url)
use_auth = download_url.startswith(CIVITAI_DOWNLOAD_URL_PREFIXES) use_auth = download_url.startswith(CIVITAI_DOWNLOAD_URL_PREFIXES)
download_kwargs = { success, result = await self._download_model_file(
"progress_callback": lambda progress, snapshot=None: ( download_url,
save_path,
backend=transfer_backend,
progress_callback=lambda progress, snapshot=None: (
self._handle_download_progress( self._handle_download_progress(
progress, progress,
progress_callback, progress_callback,
snapshot, snapshot,
) )
), ),
"use_auth": use_auth, # Only use authentication for Civitai downloads use_auth=use_auth,
} download_id=download_id,
pause_control=pause_control,
if pause_control is not None:
download_kwargs["pause_event"] = pause_control
success, result = await downloader.download_file(
download_url,
save_path, # Use full path instead of separate dir and filename
**download_kwargs,
) )
if success: if success:
@@ -1401,7 +1479,8 @@ class DownloadManager:
extracted_files.append(dest_path) extracted_files.append(dest_path)
return extracted_files return extracted_files
return await asyncio.to_thread(_extract_sync) loop = asyncio.get_running_loop()
return await loop.run_in_executor(self._archive_executor, _extract_sync)
async def _build_metadata_entries( async def _build_metadata_entries(
self, base_metadata, file_paths: List[str] self, base_metadata, file_paths: List[str]
@@ -1511,8 +1590,28 @@ class DownloadManager:
return {"success": False, "error": "Download task not found"} return {"success": False, "error": "Download task not found"}
try: try:
# Get the task and cancel it
task = self._download_tasks[download_id] task = self._download_tasks[download_id]
backend = (
self._active_downloads.get(download_id, {}).get("transfer_backend")
or "python"
)
if backend == "aria2":
try:
aria2_downloader = await get_aria2_downloader()
cancel_result = await aria2_downloader.cancel_download(download_id)
if (
not cancel_result.get("success")
and cancel_result.get("error") != "Download task not found"
):
return cancel_result
except Exception as exc:
logger.warning(
"Failed to cancel aria2 transfer for %s, continuing with local task cancellation: %s",
download_id,
exc,
)
task.cancel() task.cancel()
pause_control = self._pause_events.get(download_id) pause_control = self._pause_events.get(download_id)
@@ -1613,6 +1712,28 @@ class DownloadManager:
pause_control.pause() pause_control.pause()
backend = (
self._active_downloads.get(download_id, {}).get("transfer_backend")
or "python"
)
if backend == "aria2":
try:
aria2_downloader = await get_aria2_downloader()
if await aria2_downloader.has_transfer(download_id):
result = await aria2_downloader.pause_download(download_id)
if not result.get("success"):
pause_control.resume()
return result
except Exception as exc:
pause_control.resume()
return {"success": False, "error": str(exc)}
download_info = self._active_downloads.get(download_id)
if download_info is not None:
download_info["status"] = "paused"
download_info["bytes_per_second"] = 0.0
return {"success": True, "message": "Download paused successfully"}
download_info = self._active_downloads.get(download_id) download_info = self._active_downloads.get(download_id)
if download_info is not None: if download_info is not None:
download_info["status"] = "paused" download_info["status"] = "paused"
@@ -1631,6 +1752,28 @@ class DownloadManager:
return {"success": False, "error": "Download is not paused"} return {"success": False, "error": "Download is not paused"}
download_info = self._active_downloads.get(download_id) download_info = self._active_downloads.get(download_id)
backend = (
self._active_downloads.get(download_id, {}).get("transfer_backend")
or "python"
)
if backend == "aria2":
try:
aria2_downloader = await get_aria2_downloader()
if await aria2_downloader.has_transfer(download_id):
result = await aria2_downloader.resume_download(download_id)
if not result.get("success"):
return result
except Exception as exc:
return {"success": False, "error": str(exc)}
pause_control.resume()
if download_info is not None:
if download_info.get("status") == "paused":
download_info["status"] = "downloading"
download_info.setdefault("bytes_per_second", 0.0)
return {"success": True, "message": "Download resumed successfully"}
force_reconnect = False force_reconnect = False
if pause_control is not None: if pause_control is not None:
elapsed = pause_control.time_since_last_progress() elapsed = pause_control.time_since_last_progress()

View File

@@ -55,6 +55,8 @@ DEFAULT_KEYS_CLEANUP_THRESHOLD = 10
DEFAULT_SETTINGS: Dict[str, Any] = { DEFAULT_SETTINGS: Dict[str, Any] = {
"civitai_api_key": "", "civitai_api_key": "",
"civitai_host": "civitai.com", "civitai_host": "civitai.com",
"download_backend": "python",
"aria2c_path": "",
"use_portable_settings": False, "use_portable_settings": False,
"hash_chunk_size_mb": DEFAULT_HASH_CHUNK_SIZE_MB, "hash_chunk_size_mb": DEFAULT_HASH_CHUNK_SIZE_MB,
"language": "en", "language": "en",

View File

@@ -807,6 +807,16 @@ export class SettingsManager {
civitaiHostSelect.value = state.global.settings.civitai_host || 'civitai.com'; civitaiHostSelect.value = state.global.settings.civitai_host || 'civitai.com';
} }
const downloadBackendSelect = document.getElementById('downloadBackend');
if (downloadBackendSelect) {
downloadBackendSelect.value = state.global.settings.download_backend || 'python';
}
const aria2cPathInput = document.getElementById('aria2cPath');
if (aria2cPathInput) {
aria2cPathInput.value = state.global.settings.aria2c_path || '';
}
const recipesPathInput = document.getElementById('recipesPath'); const recipesPathInput = document.getElementById('recipesPath');
if (recipesPathInput) { if (recipesPathInput) {
recipesPathInput.value = state.global.settings.recipes_path || ''; recipesPathInput.value = state.global.settings.recipes_path || '';
@@ -950,9 +960,36 @@ export class SettingsManager {
languageSelect.value = currentLanguage; languageSelect.value = currentLanguage;
} }
this.loadDownloadBackendSettings();
this.loadProxySettings(); this.loadProxySettings();
} }
loadDownloadBackendSettings() {
const downloadBackendSelect = document.getElementById('downloadBackend');
const aria2PathSetting = document.getElementById('aria2PathSetting');
const updateVisibility = () => {
if (!aria2PathSetting || !downloadBackendSelect) {
return;
}
aria2PathSetting.style.display = downloadBackendSelect.value === 'aria2' ? 'block' : 'none';
};
if (downloadBackendSelect) {
downloadBackendSelect.value = state.global.settings.download_backend || 'python';
downloadBackendSelect.onchange = () => {
updateVisibility();
this.saveSelectSetting('downloadBackend', 'download_backend');
};
}
const aria2cPathInput = document.getElementById('aria2cPath');
if (aria2cPathInput) {
aria2cPathInput.value = state.global.settings.aria2c_path || '';
}
updateVisibility();
}
setupPriorityTagInputs() { setupPriorityTagInputs() {
['lora', 'checkpoint', 'embedding'].forEach((modelType) => { ['lora', 'checkpoint', 'embedding'].forEach((modelType) => {
const textarea = document.getElementById(`${modelType}PriorityTagsInput`); const textarea = document.getElementById(`${modelType}PriorityTagsInput`);

View File

@@ -6,6 +6,8 @@ import { DEFAULT_PATH_TEMPLATES, DEFAULT_PRIORITY_TAG_CONFIG } from '../utils/co
const DEFAULT_SETTINGS_BASE = Object.freeze({ const DEFAULT_SETTINGS_BASE = Object.freeze({
civitai_api_key: '', civitai_api_key: '',
civitai_host: 'civitai.com', civitai_host: 'civitai.com',
download_backend: 'python',
aria2c_path: '',
use_portable_settings: false, use_portable_settings: false,
language: 'en', language: 'en',
show_only_sfw: false, show_only_sfw: false,

View File

@@ -129,6 +129,43 @@
</div> </div>
</div> </div>
<div class="settings-subsection">
<div class="settings-subsection-header">
<h4>{{ t('settings.sections.downloads') }}</h4>
</div>
<div class="setting-item">
<div class="setting-row">
<div class="setting-info">
<label for="downloadBackend">{{ t('settings.downloadBackend.label') }}</label>
<i class="fas fa-info-circle info-icon" data-tooltip="{{ t('settings.downloadBackend.help') }}"></i>
</div>
<div class="setting-control select-control">
<select id="downloadBackend" onchange="settingsManager.saveSelectSetting('downloadBackend', 'download_backend')">
<option value="python">{{ t('settings.downloadBackend.options.python') }}</option>
<option value="aria2">{{ t('settings.downloadBackend.options.aria2') }}</option>
</select>
</div>
</div>
</div>
<div class="setting-item" id="aria2PathSetting" style="display: none;">
<div class="setting-row">
<div class="setting-info">
<label for="aria2cPath">{{ t('settings.aria2cPath.label') }}</label>
<i class="fas fa-info-circle info-icon" data-tooltip="{{ t('settings.aria2cPath.help') }}"></i>
</div>
<div class="setting-control">
<div class="text-input-wrapper">
<input type="text"
id="aria2cPath"
placeholder="{{ t('settings.aria2cPath.placeholder') }}"
onblur="settingsManager.saveInputSetting('aria2cPath', 'aria2c_path')"
onkeydown="if(event.key === 'Enter') { this.blur(); }" />
</div>
</div>
</div>
</div>
</div>
<!-- Backup --> <!-- Backup -->
<div class="settings-subsection"> <div class="settings-subsection">
<div class="settings-subsection-header"> <div class="settings-subsection-header">

View File

@@ -305,4 +305,39 @@ describe('SettingsManager library controls', () => {
'success', 'success',
); );
}); });
it('loads download backend settings and toggles the aria2 path field', () => {
const manager = createManager();
document.body.innerHTML = `
<select id="downloadBackend">
<option value="python">Python</option>
<option value="aria2">aria2</option>
</select>
<div id="aria2PathSetting" style="display: none;"></div>
<input id="aria2cPath" />
`;
state.global.settings = {
download_backend: 'aria2',
aria2c_path: '/usr/bin/aria2c',
};
const saveSpy = vi.spyOn(manager, 'saveSelectSetting').mockResolvedValue();
manager.loadDownloadBackendSettings();
const backendSelect = document.getElementById('downloadBackend');
const aria2PathSetting = document.getElementById('aria2PathSetting');
const aria2cPath = document.getElementById('aria2cPath');
expect(backendSelect.value).toBe('aria2');
expect(aria2cPath.value).toBe('/usr/bin/aria2c');
expect(aria2PathSetting.style.display).toBe('block');
backendSelect.value = 'python';
backendSelect.onchange();
expect(aria2PathSetting.style.display).toBe('none');
expect(saveSpy).toHaveBeenCalledWith('downloadBackend', 'download_backend');
});
}); });

View File

@@ -0,0 +1,269 @@
from __future__ import annotations
from pathlib import Path
from unittest.mock import AsyncMock
import pytest
from py.services.aria2_downloader import Aria2Downloader, Aria2Error
@pytest.mark.asyncio
async def test_download_file_polls_until_complete(tmp_path, monkeypatch):
downloader = Aria2Downloader()
downloader._rpc_url = "http://127.0.0.1/jsonrpc"
downloader._rpc_secret = "secret"
save_path = tmp_path / "downloads" / "model.safetensors"
progress_events = []
rpc_calls = []
statuses = iter(
[
{
"gid": "gid-1",
"status": "active",
"completedLength": "5",
"totalLength": "10",
"downloadSpeed": "25",
},
{
"gid": "gid-1",
"status": "complete",
"completedLength": "10",
"totalLength": "10",
"downloadSpeed": "0",
"files": [{"path": str(save_path)}],
},
]
)
async def fake_rpc_call(method, params):
rpc_calls.append((method, params))
if method == "aria2.addUri":
return "gid-1"
if method == "aria2.tellStatus":
return next(statuses)
raise AssertionError(f"Unexpected RPC method: {method}")
monkeypatch.setattr(downloader, "_ensure_process", AsyncMock())
monkeypatch.setattr(
downloader,
"_resolve_authenticated_redirect_url",
AsyncMock(
return_value="https://signed.example.com/model.safetensors?token=abc"
),
)
monkeypatch.setattr(downloader, "_rpc_call", fake_rpc_call)
monkeypatch.setattr("py.services.aria2_downloader.asyncio.sleep", AsyncMock())
async def progress_callback(progress, snapshot=None):
progress_events.append(snapshot.percent_complete if snapshot else progress)
success, result = await downloader.download_file(
"https://civitai.com/api/download/models/123",
str(save_path),
download_id="download-1",
progress_callback=progress_callback,
headers={"Authorization": "Bearer token"},
)
assert success is True
assert result == str(save_path)
assert progress_events == [50.0, 100.0]
assert downloader._transfers == {}
assert rpc_calls[0][0] == "aria2.addUri"
assert rpc_calls[0][1][0] == [
"https://signed.example.com/model.safetensors?token=abc"
]
assert rpc_calls[0][1][1]["out"] == "model.safetensors"
assert "header" not in rpc_calls[0][1][1]
@pytest.mark.asyncio
async def test_download_file_keeps_auth_headers_when_civitai_does_not_redirect(
tmp_path, monkeypatch
):
downloader = Aria2Downloader()
downloader._rpc_url = "http://127.0.0.1/jsonrpc"
downloader._rpc_secret = "secret"
save_path = tmp_path / "downloads" / "model.safetensors"
rpc_calls = []
statuses = iter(
[
{
"gid": "gid-1",
"status": "complete",
"completedLength": "10",
"totalLength": "10",
"downloadSpeed": "0",
"files": [{"path": str(save_path)}],
},
]
)
async def fake_rpc_call(method, params):
rpc_calls.append((method, params))
if method == "aria2.addUri":
return "gid-1"
if method == "aria2.tellStatus":
return next(statuses)
raise AssertionError(f"Unexpected RPC method: {method}")
monkeypatch.setattr(downloader, "_ensure_process", AsyncMock())
monkeypatch.setattr(
downloader,
"_resolve_authenticated_redirect_url",
AsyncMock(return_value="https://civitai.com/api/download/models/123"),
)
monkeypatch.setattr(downloader, "_rpc_call", fake_rpc_call)
monkeypatch.setattr("py.services.aria2_downloader.asyncio.sleep", AsyncMock())
success, result = await downloader.download_file(
"https://civitai.com/api/download/models/123",
str(save_path),
download_id="download-1",
headers={"Authorization": "Bearer token"},
)
assert success is True
assert result == str(save_path)
assert rpc_calls[0][1][0] == ["https://civitai.com/api/download/models/123"]
assert rpc_calls[0][1][1]["header"] == ["Authorization: Bearer token"]
@pytest.mark.asyncio
async def test_pause_resume_cancel_forward_to_rpc(monkeypatch):
downloader = Aria2Downloader()
downloader._transfers["download-1"] = type(
"Transfer", (), {"gid": "gid-1", "save_path": "/tmp/model.safetensors"}
)()
calls = []
async def fake_rpc_call(method, params):
calls.append((method, params))
return "gid-1"
monkeypatch.setattr(downloader, "_rpc_call", fake_rpc_call)
pause_result = await downloader.pause_download("download-1")
resume_result = await downloader.resume_download("download-1")
cancel_result = await downloader.cancel_download("download-1")
assert pause_result["success"] is True
assert resume_result["success"] is True
assert cancel_result["success"] is True
assert calls == [
("aria2.forcePause", ["gid-1"]),
("aria2.unpause", ["gid-1"]),
("aria2.forceRemove", ["gid-1"]),
]
def test_build_progress_snapshot_normalizes_numeric_fields():
downloader = Aria2Downloader()
snapshot = downloader._build_progress_snapshot(
{
"completedLength": "75",
"totalLength": "100",
"downloadSpeed": "512",
}
)
assert snapshot.percent_complete == 75.0
assert snapshot.bytes_downloaded == 75
assert snapshot.total_bytes == 100
assert snapshot.bytes_per_second == 512.0
def test_resolve_executable_raises_when_binary_missing(monkeypatch):
downloader = Aria2Downloader()
settings = type("Settings", (), {"get": lambda self, key, default=None: ""})()
monkeypatch.setattr("py.services.aria2_downloader.get_settings_manager", lambda: settings)
monkeypatch.setattr("py.services.aria2_downloader.shutil.which", lambda _: None)
with pytest.raises(Aria2Error):
downloader._resolve_executable()
@pytest.mark.asyncio
async def test_rpc_call_surfaces_json_error_on_non_200(monkeypatch):
downloader = Aria2Downloader()
downloader._rpc_url = "http://127.0.0.1:6800/jsonrpc"
downloader._rpc_secret = "secret"
class FakeResponse:
status = 400
async def text(self):
return (
'{"jsonrpc":"2.0","id":"x","error":{"code":1,"message":"Unauthorized"}}'
)
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc, tb):
return False
class FakeSession:
def post(self, _url, json=None):
return FakeResponse()
monkeypatch.setattr(downloader, "_get_rpc_session", AsyncMock(return_value=FakeSession()))
with pytest.raises(Aria2Error) as exc_info:
await downloader._rpc_call("aria2.addUri", [["https://example.com/file"]])
assert "Unauthorized" in str(exc_info.value)
assert "aria2.addUri" in str(exc_info.value)
@pytest.mark.asyncio
async def test_resolve_authenticated_redirect_url_returns_location(monkeypatch):
downloader = Aria2Downloader()
class FakeResponse:
status = 307
headers = {"Location": "https://signed.example.com/file.safetensors"}
async def text(self):
return ""
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc, tb):
return False
class FakeSession:
def get(self, _url, headers=None, allow_redirects=False, proxy=None):
return FakeResponse()
class FakeDownloader:
default_headers = {"User-Agent": "ComfyUI-LoRA-Manager/1.0"}
proxy_url = None
@property
def session(self):
async def _session():
return FakeSession()
return _session()
fake_downloader = FakeDownloader()
monkeypatch.setattr(
"py.services.aria2_downloader.get_downloader",
AsyncMock(return_value=fake_downloader),
)
result = await downloader._resolve_authenticated_redirect_url(
"https://civitai.com/api/download/models/123",
{"Authorization": "Bearer token"},
)
assert result == "https://signed.example.com/file.safetensors"

View File

@@ -179,6 +179,7 @@ async def test_successful_download_uses_defaults(
progress_callback, progress_callback,
model_type, model_type,
download_id, download_id,
transfer_backend=None,
): ):
captured.update( captured.update(
{ {
@@ -268,6 +269,7 @@ async def test_download_uses_active_mirrors(
progress_callback, progress_callback,
model_type, model_type,
download_id, download_id,
transfer_backend=None,
): ):
captured["download_urls"] = download_urls captured["download_urls"] = download_urls
return {"success": True} return {"success": True}
@@ -288,6 +290,214 @@ async def test_download_uses_active_mirrors(
assert captured["download_urls"] == ["https://mirror.example/file.safetensors"] assert captured["download_urls"] == ["https://mirror.example/file.safetensors"]
@pytest.mark.asyncio
async def test_pause_resume_cancel_delegate_to_aria2_backend(monkeypatch):
manager = DownloadManager()
task = asyncio.create_task(asyncio.sleep(60))
manager._download_tasks["download-1"] = task
manager._pause_events["download-1"] = download_manager.DownloadStreamControl()
manager._active_downloads["download-1"] = {
"transfer_backend": "aria2",
"status": "downloading",
}
class DummyAria2Downloader:
def __init__(self):
self.calls = []
async def pause_download(self, download_id):
self.calls.append(("pause", download_id))
return {"success": True, "message": "paused"}
async def resume_download(self, download_id):
self.calls.append(("resume", download_id))
return {"success": True, "message": "resumed"}
async def cancel_download(self, download_id):
self.calls.append(("cancel", download_id))
return {"success": True, "message": "cancelled"}
async def has_transfer(self, download_id):
self.calls.append(("has_transfer", download_id))
return True
dummy_aria2 = DummyAria2Downloader()
monkeypatch.setattr(
download_manager,
"get_aria2_downloader",
AsyncMock(return_value=dummy_aria2),
)
pause_result = await manager.pause_download("download-1")
assert pause_result["success"] is True
assert manager._active_downloads["download-1"]["status"] == "paused"
resume_result = await manager.resume_download("download-1")
assert resume_result["success"] is True
assert manager._active_downloads["download-1"]["status"] == "downloading"
cancel_result = await manager.cancel_download("download-1")
assert cancel_result["success"] is True
assert task.cancelled() or task.done()
assert dummy_aria2.calls == [
("has_transfer", "download-1"),
("pause", "download-1"),
("has_transfer", "download-1"),
("resume", "download-1"),
("cancel", "download-1"),
]
@pytest.mark.asyncio
async def test_cancel_allows_queued_aria2_task_without_transfer(monkeypatch):
manager = DownloadManager()
started = asyncio.Event()
async def blocked_task():
started.set()
await asyncio.sleep(60)
task = asyncio.create_task(blocked_task())
await started.wait()
manager._download_tasks["download-queued"] = task
manager._active_downloads["download-queued"] = {
"transfer_backend": "aria2",
"status": "queued",
}
class DummyAria2Downloader:
async def cancel_download(self, download_id):
return {"success": False, "error": "Download task not found"}
monkeypatch.setattr(
download_manager,
"get_aria2_downloader",
AsyncMock(return_value=DummyAria2Downloader()),
)
result = await manager.cancel_download("download-queued")
assert result["success"] is True
assert task.cancelled() or task.done()
@pytest.mark.asyncio
async def test_pause_resume_queued_aria2_task_without_transfer(monkeypatch):
manager = DownloadManager()
task = asyncio.create_task(asyncio.sleep(60))
manager._download_tasks["download-queued"] = task
manager._pause_events["download-queued"] = download_manager.DownloadStreamControl()
manager._active_downloads["download-queued"] = {
"transfer_backend": "aria2",
"status": "waiting",
"bytes_per_second": 12.0,
}
class DummyAria2Downloader:
def __init__(self):
self.calls = []
async def has_transfer(self, download_id):
self.calls.append(("has_transfer", download_id))
return False
async def pause_download(self, download_id):
self.calls.append(("pause", download_id))
return {"success": True, "message": "paused"}
async def resume_download(self, download_id):
self.calls.append(("resume", download_id))
return {"success": True, "message": "resumed"}
dummy_aria2 = DummyAria2Downloader()
monkeypatch.setattr(
download_manager,
"get_aria2_downloader",
AsyncMock(return_value=dummy_aria2),
)
pause_result = await manager.pause_download("download-queued")
assert pause_result == {"success": True, "message": "Download paused successfully"}
assert manager._active_downloads["download-queued"]["status"] == "paused"
assert manager._pause_events["download-queued"].is_paused() is True
resume_result = await manager.resume_download("download-queued")
assert resume_result == {"success": True, "message": "Download resumed successfully"}
assert manager._active_downloads["download-queued"]["status"] == "downloading"
assert manager._pause_events["download-queued"].is_set() is True
assert dummy_aria2.calls == [
("has_transfer", "download-queued"),
("has_transfer", "download-queued"),
]
task.cancel()
with pytest.raises(asyncio.CancelledError):
await task
@pytest.mark.asyncio
async def test_download_uses_captured_backend_when_settings_change(
monkeypatch, scanners, metadata_provider, tmp_path
):
manager = DownloadManager()
settings = get_settings_manager()
settings.settings["download_backend"] = "aria2"
semaphore = asyncio.Semaphore(0)
manager._download_semaphore = semaphore
captured = {}
async def fake_execute_original_download(
self,
model_id,
model_version_id,
save_dir,
relative_path,
progress_callback,
use_default_paths,
download_id=None,
transfer_backend="python",
source=None,
file_params=None,
):
captured["transfer_backend"] = transfer_backend
return {"success": True}
monkeypatch.setattr(
DownloadManager,
"_execute_original_download",
fake_execute_original_download,
)
download_task = asyncio.create_task(
manager.download_from_civitai(
model_version_id=99,
save_dir=str(tmp_path),
use_default_paths=True,
progress_callback=None,
source=None,
)
)
await asyncio.sleep(0)
assert len(manager._active_downloads) == 1
download_id = next(iter(manager._active_downloads))
assert manager._active_downloads[download_id]["transfer_backend"] == "aria2"
settings.settings["download_backend"] = "python"
semaphore.release()
result = await download_task
assert result["success"] is True
assert captured["transfer_backend"] == "aria2"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_download_aborts_when_version_exists( async def test_download_aborts_when_version_exists(
monkeypatch, scanners, metadata_provider monkeypatch, scanners, metadata_provider

View File

@@ -136,6 +136,190 @@ async def test_execute_download_retries_urls(monkeypatch, tmp_path):
assert dummy_scanner.calls # ensure cache updated assert dummy_scanner.calls # ensure cache updated
@pytest.mark.asyncio
async def test_execute_download_uses_aria2_backend_for_model_files(monkeypatch, tmp_path):
manager = DownloadManager()
settings = get_settings_manager()
settings.settings["download_backend"] = "aria2"
settings.settings["civitai_api_key"] = "secret-key"
save_dir = tmp_path / "downloads"
save_dir.mkdir()
target_path = save_dir / "file.safetensors"
class DummyMetadata:
def __init__(self, path: Path):
self.file_path = str(path)
self.sha256 = "sha256"
self.file_name = path.stem
self.preview_url = None
def generate_unique_filename(self, *_args, **_kwargs):
return os.path.basename(self.file_path)
def update_file_info(self, _path):
return None
def to_dict(self):
return {"file_path": self.file_path}
class DummyAria2Downloader:
def __init__(self):
self.calls = []
async def download_file(
self,
url,
save_path,
*,
download_id,
progress_callback=None,
headers=None,
):
self.calls.append(
{
"url": url,
"save_path": save_path,
"download_id": download_id,
"headers": headers,
}
)
Path(save_path).write_text("content")
return True, save_path
dummy_aria2 = DummyAria2Downloader()
monkeypatch.setattr(
download_manager,
"get_aria2_downloader",
AsyncMock(return_value=dummy_aria2),
)
monkeypatch.setattr(
download_manager,
"get_downloader",
AsyncMock(side_effect=AssertionError("python downloader should not be used")),
)
class DummyScanner:
async def add_model_to_cache(self, metadata_dict, relative_path):
return {"metadata": metadata_dict, "relative_path": relative_path}
dummy_scanner = DummyScanner()
monkeypatch.setattr(
DownloadManager, "_get_lora_scanner", AsyncMock(return_value=dummy_scanner)
)
monkeypatch.setattr(
DownloadManager,
"_get_checkpoint_scanner",
AsyncMock(return_value=dummy_scanner),
)
monkeypatch.setattr(
ServiceRegistry, "get_embedding_scanner", AsyncMock(return_value=dummy_scanner)
)
monkeypatch.setattr(MetadataManager, "save_metadata", AsyncMock(return_value=True))
result = await manager._execute_download(
download_urls=["https://civitai.com/api/download/models/1"],
save_dir=str(save_dir),
metadata=DummyMetadata(target_path),
version_info={"images": []},
relative_path="",
progress_callback=None,
model_type="lora",
download_id="download-1",
)
assert result == {"success": True}
assert dummy_aria2.calls == [
{
"url": "https://civitai.com/api/download/models/1",
"save_path": str(target_path),
"download_id": "download-1",
"headers": {"Authorization": "Bearer secret-key"},
}
]
@pytest.mark.asyncio
async def test_execute_download_allows_anonymous_civitai_with_aria2(
monkeypatch, tmp_path
):
manager = DownloadManager()
settings = get_settings_manager()
settings.settings["download_backend"] = "aria2"
settings.settings["civitai_api_key"] = ""
save_dir = tmp_path / "downloads"
save_dir.mkdir()
target_path = save_dir / "file.safetensors"
class DummyMetadata:
def __init__(self, path: Path):
self.file_path = str(path)
self.sha256 = "sha256"
self.file_name = path.stem
self.preview_url = None
def generate_unique_filename(self, *_args, **_kwargs):
return os.path.basename(self.file_path)
def update_file_info(self, _path):
return None
def to_dict(self):
return {"file_path": self.file_path}
class DummyAria2Downloader:
def __init__(self):
self.calls = []
async def download_file(
self,
url,
save_path,
*,
download_id,
progress_callback=None,
headers=None,
):
self.calls.append({"url": url, "headers": headers, "download_id": download_id})
Path(save_path).write_text("content")
return True, save_path
dummy_aria2 = DummyAria2Downloader()
monkeypatch.setattr(
download_manager,
"get_aria2_downloader",
AsyncMock(return_value=dummy_aria2),
)
dummy_scanner = SimpleNamespace(add_model_to_cache=AsyncMock(return_value=None))
monkeypatch.setattr(
DownloadManager, "_get_lora_scanner", AsyncMock(return_value=dummy_scanner)
)
monkeypatch.setattr(MetadataManager, "save_metadata", AsyncMock(return_value=True))
result = await manager._execute_download(
download_urls=["https://civitai.com/api/download/models/1"],
save_dir=str(save_dir),
metadata=DummyMetadata(target_path),
version_info={"images": []},
relative_path="",
progress_callback=None,
model_type="lora",
download_id="download-2",
)
assert result == {"success": True}
assert dummy_aria2.calls == [
{
"url": "https://civitai.com/api/download/models/1",
"headers": None,
"download_id": "download-2",
}
]
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_execute_download_adjusts_checkpoint_sub_type(monkeypatch, tmp_path): async def test_execute_download_adjusts_checkpoint_sub_type(monkeypatch, tmp_path):
"""Test that checkpoint sub_type is adjusted during download.""" """Test that checkpoint sub_type is adjusted during download."""
@@ -276,6 +460,13 @@ async def test_execute_download_extracts_zip_single_model(monkeypatch, tmp_path)
monkeypatch.setattr( monkeypatch.setattr(
download_manager, "get_downloader", AsyncMock(return_value=DummyDownloader()) download_manager, "get_downloader", AsyncMock(return_value=DummyDownloader())
) )
class ImmediateLoop:
async def run_in_executor(self, executor, func, *args):
return func(*args)
monkeypatch.setattr(download_manager.asyncio, "get_running_loop", lambda: ImmediateLoop())
dummy_scanner = SimpleNamespace(add_model_to_cache=AsyncMock(return_value=None)) dummy_scanner = SimpleNamespace(add_model_to_cache=AsyncMock(return_value=None))
monkeypatch.setattr( monkeypatch.setattr(
DownloadManager, "_get_lora_scanner", AsyncMock(return_value=dummy_scanner) DownloadManager, "_get_lora_scanner", AsyncMock(return_value=dummy_scanner)
@@ -344,6 +535,13 @@ async def test_execute_download_extracts_zip_multiple_models(monkeypatch, tmp_pa
monkeypatch.setattr( monkeypatch.setattr(
download_manager, "get_downloader", AsyncMock(return_value=DummyDownloader()) download_manager, "get_downloader", AsyncMock(return_value=DummyDownloader())
) )
class ImmediateLoop:
async def run_in_executor(self, executor, func, *args):
return func(*args)
monkeypatch.setattr(download_manager.asyncio, "get_running_loop", lambda: ImmediateLoop())
dummy_scanner = SimpleNamespace(add_model_to_cache=AsyncMock(return_value=None)) dummy_scanner = SimpleNamespace(add_model_to_cache=AsyncMock(return_value=None))
monkeypatch.setattr( monkeypatch.setattr(
DownloadManager, "_get_lora_scanner", AsyncMock(return_value=dummy_scanner) DownloadManager, "_get_lora_scanner", AsyncMock(return_value=dummy_scanner)
@@ -418,6 +616,13 @@ async def test_execute_download_extracts_zip_pt_embedding(monkeypatch, tmp_path)
monkeypatch.setattr( monkeypatch.setattr(
download_manager, "get_downloader", AsyncMock(return_value=DummyDownloader()) download_manager, "get_downloader", AsyncMock(return_value=DummyDownloader())
) )
class ImmediateLoop:
async def run_in_executor(self, executor, func, *args):
return func(*args)
monkeypatch.setattr(download_manager.asyncio, "get_running_loop", lambda: ImmediateLoop())
dummy_scanner = SimpleNamespace(add_model_to_cache=AsyncMock(return_value=None)) dummy_scanner = SimpleNamespace(add_model_to_cache=AsyncMock(return_value=None))
monkeypatch.setattr( monkeypatch.setattr(
ServiceRegistry, "get_embedding_scanner", AsyncMock(return_value=dummy_scanner) ServiceRegistry, "get_embedding_scanner", AsyncMock(return_value=dummy_scanner)
@@ -446,6 +651,36 @@ async def test_execute_download_extracts_zip_pt_embedding(monkeypatch, tmp_path)
assert dummy_scanner.add_model_to_cache.await_count == 1 assert dummy_scanner.add_model_to_cache.await_count == 1
@pytest.mark.asyncio
async def test_extract_model_files_from_archive_uses_executor(monkeypatch, tmp_path):
manager = DownloadManager()
archive_path = tmp_path / "bundle.zip"
with zipfile.ZipFile(archive_path, "w") as archive:
archive.writestr("inner/model.safetensors", b"model")
captured = {}
class ImmediateLoop:
async def run_in_executor(self, executor, func, *args):
captured["executor"] = executor
return func(*args)
monkeypatch.setattr(
download_manager.asyncio,
"get_running_loop",
lambda: ImmediateLoop(),
)
extracted = await manager._extract_model_files_from_archive(
str(archive_path),
{".safetensors"},
)
assert captured["executor"] is manager._archive_executor
assert len(extracted) == 1
assert extracted[0].endswith("model.safetensors")
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_pause_download_updates_state(): async def test_pause_download_updates_state():
"""Test that pause_download updates download state correctly.""" """Test that pause_download updates download state correctly."""
@@ -469,6 +704,233 @@ async def test_pause_download_updates_state():
assert manager._active_downloads[download_id]["bytes_per_second"] == 0.0 assert manager._active_downloads[download_id]["bytes_per_second"] == 0.0
@pytest.mark.asyncio
async def test_pause_download_reverts_local_pause_when_aria2_pause_fails(monkeypatch):
manager = DownloadManager()
download_id = "dl"
manager._download_tasks[download_id] = object()
pause_control = DownloadStreamControl()
manager._pause_events[download_id] = pause_control
manager._active_downloads[download_id] = {
"transfer_backend": "aria2",
"status": "downloading",
"bytes_per_second": 42.0,
}
class DummyAria2Downloader:
async def has_transfer(self, _download_id):
return True
async def pause_download(self, _download_id):
return {"success": False, "error": "rpc failed"}
monkeypatch.setattr(
download_manager,
"get_aria2_downloader",
AsyncMock(return_value=DummyAria2Downloader()),
)
result = await manager.pause_download(download_id)
assert result == {"success": False, "error": "rpc failed"}
assert pause_control.is_set() is True
assert manager._active_downloads[download_id]["status"] == "downloading"
@pytest.mark.asyncio
async def test_pause_download_reverts_local_pause_when_aria2_probe_raises(monkeypatch):
manager = DownloadManager()
download_id = "dl"
manager._download_tasks[download_id] = object()
pause_control = DownloadStreamControl()
manager._pause_events[download_id] = pause_control
manager._active_downloads[download_id] = {
"transfer_backend": "aria2",
"status": "downloading",
"bytes_per_second": 42.0,
}
class DummyAria2Downloader:
async def has_transfer(self, _download_id):
raise RuntimeError("rpc unavailable")
monkeypatch.setattr(
download_manager,
"get_aria2_downloader",
AsyncMock(return_value=DummyAria2Downloader()),
)
result = await manager.pause_download(download_id)
assert result == {"success": False, "error": "rpc unavailable"}
assert pause_control.is_set() is True
assert manager._active_downloads[download_id]["status"] == "downloading"
@pytest.mark.asyncio
async def test_resume_download_returns_error_when_aria2_probe_raises(monkeypatch):
manager = DownloadManager()
download_id = "dl"
pause_control = DownloadStreamControl()
pause_control.pause()
manager._pause_events[download_id] = pause_control
manager._active_downloads[download_id] = {
"transfer_backend": "aria2",
"status": "paused",
"bytes_per_second": 0.0,
}
class DummyAria2Downloader:
async def has_transfer(self, _download_id):
raise RuntimeError("rpc unavailable")
monkeypatch.setattr(
download_manager,
"get_aria2_downloader",
AsyncMock(return_value=DummyAria2Downloader()),
)
result = await manager.resume_download(download_id)
assert result == {"success": False, "error": "rpc unavailable"}
assert pause_control.is_paused() is True
assert manager._active_downloads[download_id]["status"] == "paused"
@pytest.mark.asyncio
async def test_cancel_download_still_cancels_local_task_when_aria2_raises(monkeypatch):
manager = DownloadManager()
started = asyncio.Event()
async def blocked_task():
started.set()
await asyncio.sleep(60)
task = asyncio.create_task(blocked_task())
await started.wait()
download_id = "download-queued"
manager._download_tasks[download_id] = task
manager._active_downloads[download_id] = {
"transfer_backend": "aria2",
"status": "queued",
}
class DummyAria2Downloader:
async def cancel_download(self, _download_id):
raise RuntimeError("rpc unavailable")
monkeypatch.setattr(
download_manager,
"get_aria2_downloader",
AsyncMock(return_value=DummyAria2Downloader()),
)
result = await manager.cancel_download(download_id)
assert result["success"] is True
assert task.cancelled() or task.done()
@pytest.mark.asyncio
async def test_execute_download_waits_for_paused_pre_transfer_gate(monkeypatch, tmp_path):
manager = DownloadManager()
save_dir = tmp_path / "downloads"
save_dir.mkdir()
target_path = save_dir / "file.safetensors"
class DummyMetadata:
def __init__(self, path: Path):
self.file_path = str(path)
self.sha256 = "sha256"
self.file_name = path.stem
self.preview_url = None
def generate_unique_filename(self, *_args, **_kwargs):
return os.path.basename(self.file_path)
def update_file_info(self, _path):
return None
def to_dict(self):
return {"file_path": self.file_path}
pause_control = DownloadStreamControl()
pause_control.pause()
manager._pause_events["download-1"] = pause_control
manager._active_downloads["download-1"] = {
"status": "downloading",
"bytes_per_second": 42.0,
}
dummy_scanner = SimpleNamespace(add_model_to_cache=AsyncMock(return_value=None))
monkeypatch.setattr(
DownloadManager, "_get_lora_scanner", AsyncMock(return_value=dummy_scanner)
)
monkeypatch.setattr(MetadataManager, "save_metadata", AsyncMock(return_value=True))
started = asyncio.Event()
allow_finish = asyncio.Event()
captured = {"calls": 0}
async def fake_download_model_file(
self,
download_url,
save_path,
*,
backend,
progress_callback,
use_auth,
download_id,
pause_control,
):
captured["calls"] += 1
started.set()
await allow_finish.wait()
Path(save_path).write_text("content")
return True, save_path
monkeypatch.setattr(
DownloadManager,
"_download_model_file",
fake_download_model_file,
)
task = asyncio.create_task(
manager._execute_download(
download_urls=["https://civitai.com/api/download/models/1"],
save_dir=str(save_dir),
metadata=DummyMetadata(target_path),
version_info={"images": []},
relative_path="",
progress_callback=None,
model_type="lora",
download_id="download-1",
transfer_backend="aria2",
)
)
await asyncio.sleep(0)
assert started.is_set() is False
assert captured["calls"] == 0
assert manager._active_downloads["download-1"]["status"] == "paused"
pause_control.resume()
await asyncio.wait_for(started.wait(), timeout=1.0)
assert captured["calls"] == 1
assert manager._active_downloads["download-1"]["status"] == "downloading"
allow_finish.set()
result = await task
assert result == {"success": True}
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_pause_download_rejects_unknown_task(): async def test_pause_download_rejects_unknown_task():
"""Test that pause_download rejects unknown download tasks.""" """Test that pause_download rejects unknown download tasks."""

View File

@@ -147,6 +147,11 @@ def test_environment_variable_overrides_settings(tmp_path, monkeypatch):
assert mgr.get("civitai_api_key") == "secret" assert mgr.get("civitai_api_key") == "secret"
def test_default_download_backend_is_python(manager):
assert manager.get("download_backend") == "python"
assert manager.get("aria2c_path") == ""
def _create_manager_with_settings( def _create_manager_with_settings(
tmp_path, monkeypatch, initial_settings, *, save_spy=None tmp_path, monkeypatch, initial_settings, *, save_spy=None
): ):