mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-05-06 16:36:45 -03:00
feat(download): add experimental aria2 backend
This commit is contained in:
@@ -263,6 +263,19 @@
|
|||||||
"red": "civitai.red (uneingeschränkt)"
|
"red": "civitai.red (uneingeschränkt)"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
"downloadBackend": {
|
||||||
|
"label": "Download-Backend",
|
||||||
|
"help": "Wähle aus, wie Modelldateien heruntergeladen werden. Python verwendet den eingebauten Downloader. aria2 verwendet den experimentellen externen Downloader-Prozess.",
|
||||||
|
"options": {
|
||||||
|
"python": "Python (integriert)",
|
||||||
|
"aria2": "aria2 (experimentell)"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"aria2cPath": {
|
||||||
|
"label": "aria2c-Pfad",
|
||||||
|
"help": "Optionaler Pfad zur ausführbaren aria2c-Datei. Leer lassen, um aria2c aus dem System-PATH zu verwenden.",
|
||||||
|
"placeholder": "Leer lassen, um aria2c aus dem PATH zu verwenden"
|
||||||
|
},
|
||||||
"civitaiHostBanner": {
|
"civitaiHostBanner": {
|
||||||
"title": "Civitai-Host-Einstellung verfügbar",
|
"title": "Civitai-Host-Einstellung verfügbar",
|
||||||
"content": "Civitai verwendet jetzt civitai.com für SFW-Inhalte und civitai.red für uneingeschränkte Inhalte. In den Einstellungen können Sie ändern, welche Seite standardmäßig geöffnet wird.",
|
"content": "Civitai verwendet jetzt civitai.com für SFW-Inhalte und civitai.red für uneingeschränkte Inhalte. In den Einstellungen können Sie ändern, welche Seite standardmäßig geöffnet wird.",
|
||||||
@@ -278,6 +291,7 @@
|
|||||||
},
|
},
|
||||||
"sections": {
|
"sections": {
|
||||||
"contentFiltering": "Inhaltsfilterung",
|
"contentFiltering": "Inhaltsfilterung",
|
||||||
|
"downloads": "Downloads",
|
||||||
"videoSettings": "Video-Einstellungen",
|
"videoSettings": "Video-Einstellungen",
|
||||||
"layoutSettings": "Layout-Einstellungen",
|
"layoutSettings": "Layout-Einstellungen",
|
||||||
"misc": "Verschiedenes",
|
"misc": "Verschiedenes",
|
||||||
|
|||||||
@@ -263,6 +263,19 @@
|
|||||||
"red": "civitai.red (unrestricted)"
|
"red": "civitai.red (unrestricted)"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
"downloadBackend": {
|
||||||
|
"label": "Download backend",
|
||||||
|
"help": "Choose how model files are downloaded. Python uses the built-in downloader. aria2 uses the experimental external downloader process.",
|
||||||
|
"options": {
|
||||||
|
"python": "Python (built-in)",
|
||||||
|
"aria2": "aria2 (experimental)"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"aria2cPath": {
|
||||||
|
"label": "aria2c path",
|
||||||
|
"help": "Optional path to the aria2c executable. Leave empty to use aria2c from your system PATH.",
|
||||||
|
"placeholder": "Leave empty to use aria2c from PATH"
|
||||||
|
},
|
||||||
"civitaiHostBanner": {
|
"civitaiHostBanner": {
|
||||||
"title": "Civitai host preference available",
|
"title": "Civitai host preference available",
|
||||||
"content": "Civitai now uses civitai.com for SFW content and civitai.red for unrestricted content. You can change which site opens by default in Settings.",
|
"content": "Civitai now uses civitai.com for SFW content and civitai.red for unrestricted content. You can change which site opens by default in Settings.",
|
||||||
@@ -278,6 +291,7 @@
|
|||||||
},
|
},
|
||||||
"sections": {
|
"sections": {
|
||||||
"contentFiltering": "Content Filtering",
|
"contentFiltering": "Content Filtering",
|
||||||
|
"downloads": "Downloads",
|
||||||
"videoSettings": "Video Settings",
|
"videoSettings": "Video Settings",
|
||||||
"layoutSettings": "Layout Settings",
|
"layoutSettings": "Layout Settings",
|
||||||
"misc": "Miscellaneous",
|
"misc": "Miscellaneous",
|
||||||
|
|||||||
@@ -263,6 +263,19 @@
|
|||||||
"red": "civitai.red (sin restricciones)"
|
"red": "civitai.red (sin restricciones)"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
"downloadBackend": {
|
||||||
|
"label": "Backend de descarga",
|
||||||
|
"help": "Elige cómo se descargan los archivos del modelo. Python usa el descargador integrado. aria2 usa el proceso externo experimental de descarga.",
|
||||||
|
"options": {
|
||||||
|
"python": "Python (integrado)",
|
||||||
|
"aria2": "aria2 (experimental)"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"aria2cPath": {
|
||||||
|
"label": "Ruta de aria2c",
|
||||||
|
"help": "Ruta opcional al ejecutable aria2c. Déjalo vacío para usar aria2c desde el PATH del sistema.",
|
||||||
|
"placeholder": "Déjalo vacío para usar aria2c desde el PATH"
|
||||||
|
},
|
||||||
"civitaiHostBanner": {
|
"civitaiHostBanner": {
|
||||||
"title": "Preferencia de host de Civitai disponible",
|
"title": "Preferencia de host de Civitai disponible",
|
||||||
"content": "Civitai ahora usa civitai.com para contenido SFW y civitai.red para contenido sin restricciones. Puedes cambiar en Ajustes qué sitio se abre por defecto.",
|
"content": "Civitai ahora usa civitai.com para contenido SFW y civitai.red para contenido sin restricciones. Puedes cambiar en Ajustes qué sitio se abre por defecto.",
|
||||||
@@ -278,6 +291,7 @@
|
|||||||
},
|
},
|
||||||
"sections": {
|
"sections": {
|
||||||
"contentFiltering": "Filtrado de contenido",
|
"contentFiltering": "Filtrado de contenido",
|
||||||
|
"downloads": "Descargas",
|
||||||
"videoSettings": "Configuración de video",
|
"videoSettings": "Configuración de video",
|
||||||
"layoutSettings": "Configuración de diseño",
|
"layoutSettings": "Configuración de diseño",
|
||||||
"misc": "Varios",
|
"misc": "Varios",
|
||||||
|
|||||||
@@ -263,6 +263,19 @@
|
|||||||
"red": "civitai.red (sans restriction)"
|
"red": "civitai.red (sans restriction)"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
"downloadBackend": {
|
||||||
|
"label": "Moteur de téléchargement",
|
||||||
|
"help": "Choisissez comment les fichiers de modèles sont téléchargés. Python utilise le téléchargeur intégré. aria2 utilise le processus externe expérimental de téléchargement.",
|
||||||
|
"options": {
|
||||||
|
"python": "Python (intégré)",
|
||||||
|
"aria2": "aria2 (expérimental)"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"aria2cPath": {
|
||||||
|
"label": "Chemin vers aria2c",
|
||||||
|
"help": "Chemin facultatif vers l’exécutable aria2c. Laissez vide pour utiliser aria2c depuis le PATH système.",
|
||||||
|
"placeholder": "Laisser vide pour utiliser aria2c depuis le PATH"
|
||||||
|
},
|
||||||
"civitaiHostBanner": {
|
"civitaiHostBanner": {
|
||||||
"title": "Préférence d’hôte Civitai disponible",
|
"title": "Préférence d’hôte Civitai disponible",
|
||||||
"content": "Civitai utilise désormais civitai.com pour le contenu SFW et civitai.red pour le contenu sans restriction. Vous pouvez modifier dans les paramètres le site ouvert par défaut.",
|
"content": "Civitai utilise désormais civitai.com pour le contenu SFW et civitai.red pour le contenu sans restriction. Vous pouvez modifier dans les paramètres le site ouvert par défaut.",
|
||||||
@@ -278,6 +291,7 @@
|
|||||||
},
|
},
|
||||||
"sections": {
|
"sections": {
|
||||||
"contentFiltering": "Filtrage du contenu",
|
"contentFiltering": "Filtrage du contenu",
|
||||||
|
"downloads": "Téléchargements",
|
||||||
"videoSettings": "Paramètres vidéo",
|
"videoSettings": "Paramètres vidéo",
|
||||||
"layoutSettings": "Paramètres d'affichage",
|
"layoutSettings": "Paramètres d'affichage",
|
||||||
"misc": "Divers",
|
"misc": "Divers",
|
||||||
|
|||||||
@@ -263,6 +263,19 @@
|
|||||||
"red": "civitai.red (ללא הגבלות)"
|
"red": "civitai.red (ללא הגבלות)"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
"downloadBackend": {
|
||||||
|
"label": "מנגנון הורדה",
|
||||||
|
"help": "בחר כיצד יורדים קבצי המודל. Python משתמש במוריד המובנה. aria2 משתמש בתהליך הורדה חיצוני ניסיוני.",
|
||||||
|
"options": {
|
||||||
|
"python": "Python (מובנה)",
|
||||||
|
"aria2": "aria2 (ניסיוני)"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"aria2cPath": {
|
||||||
|
"label": "נתיב aria2c",
|
||||||
|
"help": "נתיב אופציונלי לקובץ ההפעלה aria2c. השאר ריק כדי להשתמש ב-aria2c מתוך ה-PATH של המערכת.",
|
||||||
|
"placeholder": "השאר ריק כדי להשתמש ב-aria2c מתוך ה-PATH"
|
||||||
|
},
|
||||||
"civitaiHostBanner": {
|
"civitaiHostBanner": {
|
||||||
"title": "העדפת מארח Civitai זמינה",
|
"title": "העדפת מארח Civitai זמינה",
|
||||||
"content": "Civitai משתמש כעת ב-civitai.com עבור תוכן SFW וב-civitai.red עבור תוכן ללא הגבלות. ניתן לשנות בהגדרות איזה אתר ייפתח כברירת מחדל.",
|
"content": "Civitai משתמש כעת ב-civitai.com עבור תוכן SFW וב-civitai.red עבור תוכן ללא הגבלות. ניתן לשנות בהגדרות איזה אתר ייפתח כברירת מחדל.",
|
||||||
@@ -278,6 +291,7 @@
|
|||||||
},
|
},
|
||||||
"sections": {
|
"sections": {
|
||||||
"contentFiltering": "סינון תוכן",
|
"contentFiltering": "סינון תוכן",
|
||||||
|
"downloads": "הורדות",
|
||||||
"videoSettings": "הגדרות וידאו",
|
"videoSettings": "הגדרות וידאו",
|
||||||
"layoutSettings": "הגדרות פריסה",
|
"layoutSettings": "הגדרות פריסה",
|
||||||
"misc": "שונות",
|
"misc": "שונות",
|
||||||
|
|||||||
@@ -263,6 +263,19 @@
|
|||||||
"red": "civitai.red(制限なし)"
|
"red": "civitai.red(制限なし)"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
"downloadBackend": {
|
||||||
|
"label": "ダウンロードバックエンド",
|
||||||
|
"help": "モデルファイルのダウンロード方法を選択します。Python は内蔵ダウンローダーを使用し、aria2 は実験的な外部ダウンローダープロセスを使用します。",
|
||||||
|
"options": {
|
||||||
|
"python": "Python(内蔵)",
|
||||||
|
"aria2": "aria2(実験的)"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"aria2cPath": {
|
||||||
|
"label": "aria2c のパス",
|
||||||
|
"help": "aria2c 実行ファイルへの任意のパスです。空欄のままにすると、システム PATH 上の aria2c を使用します。",
|
||||||
|
"placeholder": "空欄のままにすると PATH 上の aria2c を使用します"
|
||||||
|
},
|
||||||
"civitaiHostBanner": {
|
"civitaiHostBanner": {
|
||||||
"title": "Civitai ホスト設定を利用できます",
|
"title": "Civitai ホスト設定を利用できます",
|
||||||
"content": "Civitai は現在、SFW コンテンツには civitai.com、制限なしコンテンツには civitai.red を使用しています。設定で既定で開くサイトを変更できます。",
|
"content": "Civitai は現在、SFW コンテンツには civitai.com、制限なしコンテンツには civitai.red を使用しています。設定で既定で開くサイトを変更できます。",
|
||||||
@@ -278,6 +291,7 @@
|
|||||||
},
|
},
|
||||||
"sections": {
|
"sections": {
|
||||||
"contentFiltering": "コンテンツフィルタリング",
|
"contentFiltering": "コンテンツフィルタリング",
|
||||||
|
"downloads": "ダウンロード",
|
||||||
"videoSettings": "動画設定",
|
"videoSettings": "動画設定",
|
||||||
"layoutSettings": "レイアウト設定",
|
"layoutSettings": "レイアウト設定",
|
||||||
"misc": "その他",
|
"misc": "その他",
|
||||||
|
|||||||
@@ -263,6 +263,19 @@
|
|||||||
"red": "civitai.red(무제한)"
|
"red": "civitai.red(무제한)"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
"downloadBackend": {
|
||||||
|
"label": "다운로드 백엔드",
|
||||||
|
"help": "모델 파일을 다운로드하는 방식을 선택합니다. Python은 내장 다운로더를 사용하고, aria2는 실험적인 외부 다운로더 프로세스를 사용합니다.",
|
||||||
|
"options": {
|
||||||
|
"python": "Python(내장)",
|
||||||
|
"aria2": "aria2(실험적)"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"aria2cPath": {
|
||||||
|
"label": "aria2c 경로",
|
||||||
|
"help": "aria2c 실행 파일의 선택적 경로입니다. 비워 두면 시스템 PATH의 aria2c를 사용합니다.",
|
||||||
|
"placeholder": "비워 두면 PATH의 aria2c를 사용합니다"
|
||||||
|
},
|
||||||
"civitaiHostBanner": {
|
"civitaiHostBanner": {
|
||||||
"title": "Civitai 호스트 기본 설정 사용 가능",
|
"title": "Civitai 호스트 기본 설정 사용 가능",
|
||||||
"content": "이제 Civitai는 SFW 콘텐츠에 civitai.com을, 무제한 콘텐츠에 civitai.red를 사용합니다. 설정에서 기본으로 열 사이트를 변경할 수 있습니다.",
|
"content": "이제 Civitai는 SFW 콘텐츠에 civitai.com을, 무제한 콘텐츠에 civitai.red를 사용합니다. 설정에서 기본으로 열 사이트를 변경할 수 있습니다.",
|
||||||
@@ -278,6 +291,7 @@
|
|||||||
},
|
},
|
||||||
"sections": {
|
"sections": {
|
||||||
"contentFiltering": "콘텐츠 필터링",
|
"contentFiltering": "콘텐츠 필터링",
|
||||||
|
"downloads": "다운로드",
|
||||||
"videoSettings": "비디오 설정",
|
"videoSettings": "비디오 설정",
|
||||||
"layoutSettings": "레이아웃 설정",
|
"layoutSettings": "레이아웃 설정",
|
||||||
"misc": "기타",
|
"misc": "기타",
|
||||||
|
|||||||
@@ -263,6 +263,19 @@
|
|||||||
"red": "civitai.red (без ограничений)"
|
"red": "civitai.red (без ограничений)"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
"downloadBackend": {
|
||||||
|
"label": "Бэкенд загрузки",
|
||||||
|
"help": "Выберите способ загрузки файлов моделей. Python использует встроенный загрузчик. aria2 использует экспериментальный внешний процесс загрузки.",
|
||||||
|
"options": {
|
||||||
|
"python": "Python (встроенный)",
|
||||||
|
"aria2": "aria2 (экспериментальный)"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"aria2cPath": {
|
||||||
|
"label": "Путь к aria2c",
|
||||||
|
"help": "Необязательный путь к исполняемому файлу aria2c. Оставьте пустым, чтобы использовать aria2c из системного PATH.",
|
||||||
|
"placeholder": "Оставьте пустым, чтобы использовать aria2c из PATH"
|
||||||
|
},
|
||||||
"civitaiHostBanner": {
|
"civitaiHostBanner": {
|
||||||
"title": "Доступна настройка хоста Civitai",
|
"title": "Доступна настройка хоста Civitai",
|
||||||
"content": "Теперь Civitai использует civitai.com для контента SFW и civitai.red для контента без ограничений. В настройках можно изменить, какой сайт открывать по умолчанию.",
|
"content": "Теперь Civitai использует civitai.com для контента SFW и civitai.red для контента без ограничений. В настройках можно изменить, какой сайт открывать по умолчанию.",
|
||||||
@@ -278,6 +291,7 @@
|
|||||||
},
|
},
|
||||||
"sections": {
|
"sections": {
|
||||||
"contentFiltering": "Фильтрация контента",
|
"contentFiltering": "Фильтрация контента",
|
||||||
|
"downloads": "Загрузки",
|
||||||
"videoSettings": "Настройки видео",
|
"videoSettings": "Настройки видео",
|
||||||
"layoutSettings": "Настройки макета",
|
"layoutSettings": "Настройки макета",
|
||||||
"misc": "Разное",
|
"misc": "Разное",
|
||||||
|
|||||||
@@ -263,6 +263,19 @@
|
|||||||
"red": "civitai.red(无限制)"
|
"red": "civitai.red(无限制)"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
"downloadBackend": {
|
||||||
|
"label": "下载后端",
|
||||||
|
"help": "选择模型文件的下载方式。Python 使用内置下载器。aria2 使用实验性的外部下载进程。",
|
||||||
|
"options": {
|
||||||
|
"python": "Python(内置)",
|
||||||
|
"aria2": "aria2(实验性)"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"aria2cPath": {
|
||||||
|
"label": "aria2c 路径",
|
||||||
|
"help": "可选的 aria2c 可执行文件路径。留空则使用系统 PATH 中的 aria2c。",
|
||||||
|
"placeholder": "留空则使用 PATH 中的 aria2c"
|
||||||
|
},
|
||||||
"civitaiHostBanner": {
|
"civitaiHostBanner": {
|
||||||
"title": "已提供 Civitai 站点偏好设置",
|
"title": "已提供 Civitai 站点偏好设置",
|
||||||
"content": "Civitai 现在使用 civitai.com 提供 SFW 内容,使用 civitai.red 提供无限制内容。你可以在设置中更改默认打开的站点。",
|
"content": "Civitai 现在使用 civitai.com 提供 SFW 内容,使用 civitai.red 提供无限制内容。你可以在设置中更改默认打开的站点。",
|
||||||
@@ -278,6 +291,7 @@
|
|||||||
},
|
},
|
||||||
"sections": {
|
"sections": {
|
||||||
"contentFiltering": "内容过滤",
|
"contentFiltering": "内容过滤",
|
||||||
|
"downloads": "下载",
|
||||||
"videoSettings": "视频设置",
|
"videoSettings": "视频设置",
|
||||||
"layoutSettings": "布局设置",
|
"layoutSettings": "布局设置",
|
||||||
"misc": "其他",
|
"misc": "其他",
|
||||||
|
|||||||
@@ -263,6 +263,19 @@
|
|||||||
"red": "civitai.red(無限制)"
|
"red": "civitai.red(無限制)"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
"downloadBackend": {
|
||||||
|
"label": "下載後端",
|
||||||
|
"help": "選擇模型檔案的下載方式。Python 使用內建下載器。aria2 使用實驗性的外部下載程序。",
|
||||||
|
"options": {
|
||||||
|
"python": "Python(內建)",
|
||||||
|
"aria2": "aria2(實驗性)"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"aria2cPath": {
|
||||||
|
"label": "aria2c 路徑",
|
||||||
|
"help": "可選的 aria2c 可執行檔路徑。留空則使用系統 PATH 中的 aria2c。",
|
||||||
|
"placeholder": "留空則使用 PATH 中的 aria2c"
|
||||||
|
},
|
||||||
"civitaiHostBanner": {
|
"civitaiHostBanner": {
|
||||||
"title": "已提供 Civitai 站點偏好設定",
|
"title": "已提供 Civitai 站點偏好設定",
|
||||||
"content": "Civitai 現在使用 civitai.com 提供 SFW 內容,使用 civitai.red 提供無限制內容。你可以在設定中變更預設開啟的站點。",
|
"content": "Civitai 現在使用 civitai.com 提供 SFW 內容,使用 civitai.red 提供無限制內容。你可以在設定中變更預設開啟的站點。",
|
||||||
@@ -278,6 +291,7 @@
|
|||||||
},
|
},
|
||||||
"sections": {
|
"sections": {
|
||||||
"contentFiltering": "內容過濾",
|
"contentFiltering": "內容過濾",
|
||||||
|
"downloads": "下載",
|
||||||
"videoSettings": "影片設定",
|
"videoSettings": "影片設定",
|
||||||
"layoutSettings": "版面設定",
|
"layoutSettings": "版面設定",
|
||||||
"misc": "其他",
|
"misc": "其他",
|
||||||
|
|||||||
497
py/services/aria2_downloader.py
Normal file
497
py/services/aria2_downloader.py
Normal file
@@ -0,0 +1,497 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import secrets
|
||||||
|
import shutil
|
||||||
|
import socket
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from datetime import datetime
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Dict, Optional, Tuple
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
|
|
||||||
|
from .downloader import DownloadProgress, get_downloader
|
||||||
|
from .settings_manager import get_settings_manager
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
CIVITAI_DOWNLOAD_URL_PREFIXES = (
|
||||||
|
"https://civitai.com/api/download/",
|
||||||
|
"https://civitai.red/api/download/",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class Aria2Error(RuntimeError):
|
||||||
|
"""Raised when aria2 integration fails."""
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Aria2Transfer:
|
||||||
|
"""Track an aria2 download registered by the Python coordinator."""
|
||||||
|
|
||||||
|
gid: str
|
||||||
|
save_path: str
|
||||||
|
|
||||||
|
|
||||||
|
class Aria2Downloader:
|
||||||
|
"""Manage an aria2 RPC daemon for experimental model downloads."""
|
||||||
|
|
||||||
|
_instance = None
|
||||||
|
_lock = asyncio.Lock()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def get_instance(cls) -> "Aria2Downloader":
|
||||||
|
async with cls._lock:
|
||||||
|
if cls._instance is None:
|
||||||
|
cls._instance = cls()
|
||||||
|
return cls._instance
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
if hasattr(self, "_initialized"):
|
||||||
|
return
|
||||||
|
|
||||||
|
self._initialized = True
|
||||||
|
self._process: Optional[asyncio.subprocess.Process] = None
|
||||||
|
self._rpc_port: Optional[int] = None
|
||||||
|
self._rpc_secret = ""
|
||||||
|
self._rpc_url = ""
|
||||||
|
self._rpc_session: Optional[aiohttp.ClientSession] = None
|
||||||
|
self._rpc_session_lock = asyncio.Lock()
|
||||||
|
self._process_lock = asyncio.Lock()
|
||||||
|
self._transfers: Dict[str, Aria2Transfer] = {}
|
||||||
|
self._poll_interval = 0.5
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_running(self) -> bool:
|
||||||
|
return self._process is not None and self._process.returncode is None
|
||||||
|
|
||||||
|
async def download_file(
|
||||||
|
self,
|
||||||
|
url: str,
|
||||||
|
save_path: str,
|
||||||
|
*,
|
||||||
|
download_id: str,
|
||||||
|
progress_callback=None,
|
||||||
|
headers: Optional[Dict[str, str]] = None,
|
||||||
|
) -> Tuple[bool, str]:
|
||||||
|
"""Download a file using aria2 RPC and wait for completion."""
|
||||||
|
|
||||||
|
await self._ensure_process()
|
||||||
|
save_path = os.path.abspath(save_path)
|
||||||
|
save_dir = os.path.dirname(save_path)
|
||||||
|
out_name = os.path.basename(save_path)
|
||||||
|
|
||||||
|
Path(save_dir).mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
resolved_url = url
|
||||||
|
request_headers = headers
|
||||||
|
if headers and url.startswith(CIVITAI_DOWNLOAD_URL_PREFIXES):
|
||||||
|
resolved_url = await self._resolve_authenticated_redirect_url(url, headers)
|
||||||
|
if resolved_url != url:
|
||||||
|
request_headers = None
|
||||||
|
logger.debug(
|
||||||
|
"Resolved Civitai download %s to signed URL for aria2",
|
||||||
|
download_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
options: Dict[str, str] = {
|
||||||
|
"dir": save_dir,
|
||||||
|
"out": out_name,
|
||||||
|
"continue": "true",
|
||||||
|
"max-connection-per-server": "4",
|
||||||
|
"split": "4",
|
||||||
|
"min-split-size": "1M",
|
||||||
|
"allow-overwrite": "true",
|
||||||
|
"auto-file-renaming": "false",
|
||||||
|
"file-allocation": "none",
|
||||||
|
}
|
||||||
|
if request_headers:
|
||||||
|
options["header"] = [
|
||||||
|
f"{key}: {value}" for key, value in request_headers.items()
|
||||||
|
]
|
||||||
|
|
||||||
|
logger.debug(
|
||||||
|
"Submitting aria2 download %s -> %s (auth=%s, civitai_signed=%s)",
|
||||||
|
download_id,
|
||||||
|
save_path,
|
||||||
|
bool(request_headers),
|
||||||
|
resolved_url != url,
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
gid = await self._rpc_call("aria2.addUri", [[resolved_url], options])
|
||||||
|
except Exception as exc:
|
||||||
|
raise Aria2Error(f"Failed to schedule aria2 download: {exc}") from exc
|
||||||
|
|
||||||
|
logger.debug("aria2 accepted download %s with gid %s", download_id, gid)
|
||||||
|
|
||||||
|
self._transfers[download_id] = Aria2Transfer(gid=gid, save_path=save_path)
|
||||||
|
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
status = await self.get_status(download_id)
|
||||||
|
if status is None:
|
||||||
|
return False, "aria2 download not found"
|
||||||
|
|
||||||
|
snapshot = self._build_progress_snapshot(status)
|
||||||
|
if progress_callback is not None:
|
||||||
|
await self._dispatch_progress(progress_callback, snapshot)
|
||||||
|
|
||||||
|
state = status.get("status", "")
|
||||||
|
if state == "complete":
|
||||||
|
completed_path = self._resolve_completed_path(status, save_path)
|
||||||
|
return True, completed_path
|
||||||
|
if state == "error":
|
||||||
|
return False, status.get("errorMessage") or "aria2 download failed"
|
||||||
|
if state == "removed":
|
||||||
|
return False, "Download was cancelled"
|
||||||
|
|
||||||
|
await asyncio.sleep(self._poll_interval)
|
||||||
|
finally:
|
||||||
|
self._transfers.pop(download_id, None)
|
||||||
|
|
||||||
|
async def get_status(self, download_id: str) -> Optional[Dict[str, Any]]:
|
||||||
|
"""Return the raw aria2 status payload for a known download."""
|
||||||
|
|
||||||
|
transfer = self._transfers.get(download_id)
|
||||||
|
if transfer is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
keys = [
|
||||||
|
"gid",
|
||||||
|
"status",
|
||||||
|
"totalLength",
|
||||||
|
"completedLength",
|
||||||
|
"downloadSpeed",
|
||||||
|
"errorMessage",
|
||||||
|
"files",
|
||||||
|
]
|
||||||
|
try:
|
||||||
|
status = await self._rpc_call("aria2.tellStatus", [transfer.gid, keys])
|
||||||
|
except Exception as exc:
|
||||||
|
raise Aria2Error(f"Failed to query aria2 download status: {exc}") from exc
|
||||||
|
|
||||||
|
if isinstance(status, dict):
|
||||||
|
return status
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def has_transfer(self, download_id: str) -> bool:
|
||||||
|
return download_id in self._transfers
|
||||||
|
|
||||||
|
async def pause_download(self, download_id: str) -> Dict[str, Any]:
|
||||||
|
transfer = self._transfers.get(download_id)
|
||||||
|
if transfer is None:
|
||||||
|
return {"success": False, "error": "Download task not found"}
|
||||||
|
|
||||||
|
try:
|
||||||
|
await self._rpc_call("aria2.forcePause", [transfer.gid])
|
||||||
|
except Exception as exc:
|
||||||
|
return {"success": False, "error": str(exc)}
|
||||||
|
|
||||||
|
return {"success": True, "message": "Download paused successfully"}
|
||||||
|
|
||||||
|
async def resume_download(self, download_id: str) -> Dict[str, Any]:
|
||||||
|
transfer = self._transfers.get(download_id)
|
||||||
|
if transfer is None:
|
||||||
|
return {"success": False, "error": "Download task not found"}
|
||||||
|
|
||||||
|
try:
|
||||||
|
await self._rpc_call("aria2.unpause", [transfer.gid])
|
||||||
|
except Exception as exc:
|
||||||
|
return {"success": False, "error": str(exc)}
|
||||||
|
|
||||||
|
return {"success": True, "message": "Download resumed successfully"}
|
||||||
|
|
||||||
|
async def cancel_download(self, download_id: str) -> Dict[str, Any]:
|
||||||
|
transfer = self._transfers.get(download_id)
|
||||||
|
if transfer is None:
|
||||||
|
return {"success": False, "error": "Download task not found"}
|
||||||
|
|
||||||
|
try:
|
||||||
|
await self._rpc_call("aria2.forceRemove", [transfer.gid])
|
||||||
|
except Exception as exc:
|
||||||
|
return {"success": False, "error": str(exc)}
|
||||||
|
|
||||||
|
return {"success": True, "message": "Download cancelled successfully"}
|
||||||
|
|
||||||
|
async def close(self) -> None:
|
||||||
|
"""Shut down the RPC process and session."""
|
||||||
|
|
||||||
|
if self._rpc_session is not None:
|
||||||
|
await self._rpc_session.close()
|
||||||
|
self._rpc_session = None
|
||||||
|
|
||||||
|
process = self._process
|
||||||
|
self._process = None
|
||||||
|
self._transfers.clear()
|
||||||
|
|
||||||
|
if process is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
if process.returncode is None:
|
||||||
|
process.terminate()
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(process.wait(), timeout=5.0)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
process.kill()
|
||||||
|
await process.wait()
|
||||||
|
|
||||||
|
async def _dispatch_progress(self, callback, snapshot: DownloadProgress) -> None:
|
||||||
|
try:
|
||||||
|
result = callback(snapshot, snapshot)
|
||||||
|
except TypeError:
|
||||||
|
result = callback(snapshot.percent_complete)
|
||||||
|
|
||||||
|
if asyncio.iscoroutine(result):
|
||||||
|
await result
|
||||||
|
elif hasattr(result, "__await__"):
|
||||||
|
await result
|
||||||
|
|
||||||
|
def _build_progress_snapshot(self, status: Dict[str, Any]) -> DownloadProgress:
|
||||||
|
completed = self._parse_int(status.get("completedLength"))
|
||||||
|
total = self._parse_int(status.get("totalLength"))
|
||||||
|
speed = float(self._parse_int(status.get("downloadSpeed")))
|
||||||
|
percent = 0.0
|
||||||
|
if total > 0:
|
||||||
|
percent = (completed / total) * 100.0
|
||||||
|
|
||||||
|
return DownloadProgress(
|
||||||
|
percent_complete=max(0.0, min(percent, 100.0)),
|
||||||
|
bytes_downloaded=completed,
|
||||||
|
total_bytes=total or None,
|
||||||
|
bytes_per_second=speed,
|
||||||
|
timestamp=datetime.now().timestamp(),
|
||||||
|
)
|
||||||
|
|
||||||
|
def _resolve_completed_path(self, status: Dict[str, Any], default_path: str) -> str:
|
||||||
|
files = status.get("files")
|
||||||
|
if isinstance(files, list) and files:
|
||||||
|
first = files[0]
|
||||||
|
if isinstance(first, dict):
|
||||||
|
candidate = first.get("path")
|
||||||
|
if isinstance(candidate, str) and candidate:
|
||||||
|
return candidate
|
||||||
|
return default_path
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _parse_int(value: Any) -> int:
|
||||||
|
try:
|
||||||
|
return int(value)
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
return 0
|
||||||
|
|
||||||
|
async def _resolve_authenticated_redirect_url(
|
||||||
|
self,
|
||||||
|
url: str,
|
||||||
|
headers: Dict[str, str],
|
||||||
|
) -> str:
|
||||||
|
downloader = await get_downloader()
|
||||||
|
session = await downloader.session
|
||||||
|
request_headers = dict(downloader.default_headers)
|
||||||
|
request_headers.update(headers)
|
||||||
|
request_headers["Accept-Encoding"] = "identity"
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with session.get(
|
||||||
|
url,
|
||||||
|
headers=request_headers,
|
||||||
|
allow_redirects=False,
|
||||||
|
proxy=downloader.proxy_url,
|
||||||
|
) as response:
|
||||||
|
if response.status in {301, 302, 303, 307, 308}:
|
||||||
|
location = response.headers.get("Location")
|
||||||
|
if location:
|
||||||
|
return location
|
||||||
|
raise Aria2Error(
|
||||||
|
"Authenticated Civitai redirect did not include a Location header"
|
||||||
|
)
|
||||||
|
|
||||||
|
if response.status == 200:
|
||||||
|
return url
|
||||||
|
|
||||||
|
body = await response.text()
|
||||||
|
raise Aria2Error(
|
||||||
|
f"Failed to resolve authenticated Civitai redirect: status={response.status} body={body[:300]}"
|
||||||
|
)
|
||||||
|
except aiohttp.ClientError as exc:
|
||||||
|
raise Aria2Error(
|
||||||
|
f"Failed to resolve authenticated Civitai redirect: {exc}"
|
||||||
|
) from exc
|
||||||
|
|
||||||
|
async def _ensure_process(self) -> None:
|
||||||
|
async with self._process_lock:
|
||||||
|
if self.is_running and await self._ping():
|
||||||
|
return
|
||||||
|
|
||||||
|
await self.close()
|
||||||
|
|
||||||
|
executable = self._resolve_executable()
|
||||||
|
self._rpc_port = self._find_free_port()
|
||||||
|
self._rpc_secret = secrets.token_hex(16)
|
||||||
|
self._rpc_url = f"http://127.0.0.1:{self._rpc_port}/jsonrpc"
|
||||||
|
|
||||||
|
command = [
|
||||||
|
executable,
|
||||||
|
"--enable-rpc=true",
|
||||||
|
"--rpc-listen-all=false",
|
||||||
|
f"--rpc-listen-port={self._rpc_port}",
|
||||||
|
f"--rpc-secret={self._rpc_secret}",
|
||||||
|
"--check-certificate=true",
|
||||||
|
"--allow-overwrite=true",
|
||||||
|
"--auto-file-renaming=false",
|
||||||
|
"--file-allocation=none",
|
||||||
|
"--max-concurrent-downloads=5",
|
||||||
|
"--continue=true",
|
||||||
|
"--daemon=false",
|
||||||
|
"--quiet=true",
|
||||||
|
f"--stop-with-process={os.getpid()}",
|
||||||
|
]
|
||||||
|
|
||||||
|
logger.info("Starting aria2 RPC daemon from %s", executable)
|
||||||
|
self._process = await asyncio.create_subprocess_exec(
|
||||||
|
*command,
|
||||||
|
stdout=asyncio.subprocess.DEVNULL,
|
||||||
|
stderr=asyncio.subprocess.PIPE,
|
||||||
|
)
|
||||||
|
|
||||||
|
await self._wait_until_ready()
|
||||||
|
|
||||||
|
def _resolve_executable(self) -> str:
|
||||||
|
settings = get_settings_manager()
|
||||||
|
configured_path = (settings.get("aria2c_path") or "").strip()
|
||||||
|
candidate = configured_path or "aria2c"
|
||||||
|
|
||||||
|
resolved = shutil.which(candidate)
|
||||||
|
if resolved:
|
||||||
|
return resolved
|
||||||
|
|
||||||
|
if configured_path and os.path.isfile(configured_path) and os.access(
|
||||||
|
configured_path, os.X_OK
|
||||||
|
):
|
||||||
|
return configured_path
|
||||||
|
|
||||||
|
raise Aria2Error(
|
||||||
|
"aria2c executable was not found. Install aria2 or configure aria2c_path."
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _wait_until_ready(self) -> None:
|
||||||
|
assert self._process is not None
|
||||||
|
|
||||||
|
start_time = asyncio.get_running_loop().time()
|
||||||
|
last_error = ""
|
||||||
|
while asyncio.get_running_loop().time() - start_time < 10.0:
|
||||||
|
if self._process.returncode is not None:
|
||||||
|
stderr_output = ""
|
||||||
|
if self._process.stderr is not None:
|
||||||
|
try:
|
||||||
|
stderr_output = (
|
||||||
|
await asyncio.wait_for(self._process.stderr.read(), timeout=0.2)
|
||||||
|
).decode("utf-8", errors="replace")
|
||||||
|
except Exception:
|
||||||
|
stderr_output = ""
|
||||||
|
raise Aria2Error(
|
||||||
|
f"aria2 RPC process exited early with code {self._process.returncode}: {stderr_output.strip()}"
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
if await self._ping():
|
||||||
|
return
|
||||||
|
except Exception as exc: # pragma: no cover - startup race
|
||||||
|
last_error = str(exc)
|
||||||
|
|
||||||
|
await asyncio.sleep(0.2)
|
||||||
|
|
||||||
|
raise Aria2Error(
|
||||||
|
f"Timed out waiting for aria2 RPC to become ready{': ' + last_error if last_error else ''}"
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _ping(self) -> bool:
|
||||||
|
try:
|
||||||
|
result = await self._rpc_call("aria2.getVersion", [])
|
||||||
|
except Exception:
|
||||||
|
return False
|
||||||
|
|
||||||
|
return isinstance(result, dict)
|
||||||
|
|
||||||
|
async def _rpc_call(self, method: str, params: list[Any]) -> Any:
|
||||||
|
if not self._rpc_url:
|
||||||
|
raise Aria2Error("aria2 RPC endpoint is not initialized")
|
||||||
|
|
||||||
|
session = await self._get_rpc_session()
|
||||||
|
payload = {
|
||||||
|
"jsonrpc": "2.0",
|
||||||
|
"id": secrets.token_hex(8),
|
||||||
|
"method": method,
|
||||||
|
"params": [f"token:{self._rpc_secret}", *params],
|
||||||
|
}
|
||||||
|
|
||||||
|
async with session.post(self._rpc_url, json=payload) as response:
|
||||||
|
text = await response.text()
|
||||||
|
|
||||||
|
try:
|
||||||
|
body = json.loads(text)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
body = None
|
||||||
|
|
||||||
|
if body is None:
|
||||||
|
if response.status != 200:
|
||||||
|
raise Aria2Error(
|
||||||
|
f"aria2 RPC returned status {response.status} with non-JSON body: {text}"
|
||||||
|
)
|
||||||
|
raise Aria2Error(f"Invalid aria2 RPC response: {text}")
|
||||||
|
|
||||||
|
if "error" in body:
|
||||||
|
error = body["error"] or {}
|
||||||
|
code = error.get("code") if isinstance(error, dict) else None
|
||||||
|
message = error.get("message") if isinstance(error, dict) else str(error)
|
||||||
|
logger.error(
|
||||||
|
"aria2 RPC %s failed with HTTP %s, code=%s, message=%s",
|
||||||
|
method,
|
||||||
|
response.status,
|
||||||
|
code,
|
||||||
|
message,
|
||||||
|
)
|
||||||
|
status_message = (
|
||||||
|
f"aria2 RPC {method} failed with status {response.status}: {message}"
|
||||||
|
if response.status != 200
|
||||||
|
else message
|
||||||
|
)
|
||||||
|
raise Aria2Error(status_message or "Unknown aria2 RPC error")
|
||||||
|
|
||||||
|
if response.status != 200:
|
||||||
|
logger.error(
|
||||||
|
"aria2 RPC %s returned unexpected HTTP status %s without error payload: %s",
|
||||||
|
method,
|
||||||
|
response.status,
|
||||||
|
body,
|
||||||
|
)
|
||||||
|
raise Aria2Error(
|
||||||
|
f"aria2 RPC {method} returned unexpected status {response.status}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return body.get("result")
|
||||||
|
|
||||||
|
async def _get_rpc_session(self) -> aiohttp.ClientSession:
|
||||||
|
if self._rpc_session is None or self._rpc_session.closed:
|
||||||
|
async with self._rpc_session_lock:
|
||||||
|
if self._rpc_session is None or self._rpc_session.closed:
|
||||||
|
timeout = aiohttp.ClientTimeout(total=30)
|
||||||
|
self._rpc_session = aiohttp.ClientSession(timeout=timeout)
|
||||||
|
return self._rpc_session
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _find_free_port() -> int:
|
||||||
|
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
|
||||||
|
sock.bind(("127.0.0.1", 0))
|
||||||
|
sock.listen(1)
|
||||||
|
return int(sock.getsockname()[1])
|
||||||
|
|
||||||
|
|
||||||
|
async def get_aria2_downloader() -> Aria2Downloader:
|
||||||
|
"""Get the singleton aria2 downloader."""
|
||||||
|
|
||||||
|
return await Aria2Downloader.get_instance()
|
||||||
@@ -5,6 +5,7 @@ import asyncio
|
|||||||
import inspect
|
import inspect
|
||||||
import shutil
|
import shutil
|
||||||
import zipfile
|
import zipfile
|
||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Dict, List, Optional, Set, Tuple
|
from typing import Dict, List, Optional, Set, Tuple
|
||||||
@@ -25,6 +26,7 @@ from .service_registry import ServiceRegistry
|
|||||||
from .settings_manager import get_settings_manager
|
from .settings_manager import get_settings_manager
|
||||||
from .metadata_service import get_default_metadata_provider, get_metadata_provider
|
from .metadata_service import get_default_metadata_provider, get_metadata_provider
|
||||||
from .downloader import get_downloader, DownloadProgress, DownloadStreamControl
|
from .downloader import get_downloader, DownloadProgress, DownloadStreamControl
|
||||||
|
from .aria2_downloader import Aria2Error, get_aria2_downloader
|
||||||
|
|
||||||
# Download to temporary file first
|
# Download to temporary file first
|
||||||
import tempfile
|
import tempfile
|
||||||
@@ -60,6 +62,59 @@ class DownloadManager:
|
|||||||
self._download_semaphore = asyncio.Semaphore(5) # Limit concurrent downloads
|
self._download_semaphore = asyncio.Semaphore(5) # Limit concurrent downloads
|
||||||
self._download_tasks = {} # download_id -> asyncio.Task
|
self._download_tasks = {} # download_id -> asyncio.Task
|
||||||
self._pause_events: Dict[str, DownloadStreamControl] = {}
|
self._pause_events: Dict[str, DownloadStreamControl] = {}
|
||||||
|
self._archive_executor = ThreadPoolExecutor(
|
||||||
|
max_workers=2, thread_name_prefix="lm-archive"
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _get_model_download_backend() -> str:
|
||||||
|
backend = (get_settings_manager().get("download_backend") or "python").strip()
|
||||||
|
return backend.lower() or "python"
|
||||||
|
|
||||||
|
async def _download_model_file(
|
||||||
|
self,
|
||||||
|
download_url: str,
|
||||||
|
save_path: str,
|
||||||
|
*,
|
||||||
|
backend: str,
|
||||||
|
progress_callback,
|
||||||
|
use_auth: bool,
|
||||||
|
download_id: Optional[str],
|
||||||
|
pause_control: Optional[DownloadStreamControl],
|
||||||
|
) -> Tuple[bool, str]:
|
||||||
|
if backend == "aria2":
|
||||||
|
if not download_id:
|
||||||
|
return False, "aria2 downloads require a tracked download_id"
|
||||||
|
|
||||||
|
headers: Dict[str, str] = {}
|
||||||
|
if use_auth:
|
||||||
|
api_key = (get_settings_manager().get("civitai_api_key") or "").strip()
|
||||||
|
if api_key:
|
||||||
|
headers["Authorization"] = f"Bearer {api_key}"
|
||||||
|
|
||||||
|
try:
|
||||||
|
aria2_downloader = await get_aria2_downloader()
|
||||||
|
return await aria2_downloader.download_file(
|
||||||
|
download_url,
|
||||||
|
save_path,
|
||||||
|
download_id=download_id,
|
||||||
|
progress_callback=progress_callback,
|
||||||
|
headers=headers or None,
|
||||||
|
)
|
||||||
|
except Aria2Error as exc:
|
||||||
|
logger.error("aria2 download failed for %s: %s", download_url, exc)
|
||||||
|
return False, str(exc)
|
||||||
|
|
||||||
|
download_kwargs = {
|
||||||
|
"progress_callback": progress_callback,
|
||||||
|
"use_auth": use_auth,
|
||||||
|
}
|
||||||
|
|
||||||
|
if pause_control is not None:
|
||||||
|
download_kwargs["pause_event"] = pause_control
|
||||||
|
|
||||||
|
downloader = await get_downloader()
|
||||||
|
return await downloader.download_file(download_url, save_path, **download_kwargs)
|
||||||
|
|
||||||
async def _get_lora_scanner(self):
|
async def _get_lora_scanner(self):
|
||||||
"""Get the lora scanner from registry"""
|
"""Get the lora scanner from registry"""
|
||||||
@@ -126,6 +181,7 @@ class DownloadManager:
|
|||||||
"model_version_id": model_version_id,
|
"model_version_id": model_version_id,
|
||||||
"progress": 0,
|
"progress": 0,
|
||||||
"status": "queued",
|
"status": "queued",
|
||||||
|
"transfer_backend": self._get_model_download_backend(),
|
||||||
"bytes_downloaded": 0,
|
"bytes_downloaded": 0,
|
||||||
"total_bytes": None,
|
"total_bytes": None,
|
||||||
"bytes_per_second": 0.0,
|
"bytes_per_second": 0.0,
|
||||||
@@ -240,6 +296,9 @@ class DownloadManager:
|
|||||||
tracking_callback,
|
tracking_callback,
|
||||||
use_default_paths,
|
use_default_paths,
|
||||||
task_id,
|
task_id,
|
||||||
|
self._active_downloads.get(task_id, {}).get(
|
||||||
|
"transfer_backend", "python"
|
||||||
|
),
|
||||||
source,
|
source,
|
||||||
file_params,
|
file_params,
|
||||||
)
|
)
|
||||||
@@ -294,6 +353,7 @@ class DownloadManager:
|
|||||||
progress_callback,
|
progress_callback,
|
||||||
use_default_paths,
|
use_default_paths,
|
||||||
download_id=None,
|
download_id=None,
|
||||||
|
transfer_backend="python",
|
||||||
source=None,
|
source=None,
|
||||||
file_params=None,
|
file_params=None,
|
||||||
):
|
):
|
||||||
@@ -696,16 +756,27 @@ class DownloadManager:
|
|||||||
logger.info(f"Creating EmbeddingMetadata for {file_name}")
|
logger.info(f"Creating EmbeddingMetadata for {file_name}")
|
||||||
|
|
||||||
# 6. Start download process
|
# 6. Start download process
|
||||||
result = await self._execute_download(
|
execute_kwargs = {
|
||||||
download_urls=download_urls,
|
"download_urls": download_urls,
|
||||||
save_dir=save_dir,
|
"save_dir": save_dir,
|
||||||
metadata=metadata,
|
"metadata": metadata,
|
||||||
version_info=version_info,
|
"version_info": version_info,
|
||||||
relative_path=relative_path,
|
"relative_path": relative_path,
|
||||||
progress_callback=progress_callback,
|
"progress_callback": progress_callback,
|
||||||
model_type=model_type,
|
"model_type": model_type,
|
||||||
download_id=download_id,
|
"download_id": download_id,
|
||||||
)
|
}
|
||||||
|
execute_signature = inspect.signature(self._execute_download)
|
||||||
|
if (
|
||||||
|
"transfer_backend" in execute_signature.parameters
|
||||||
|
or any(
|
||||||
|
parameter.kind == inspect.Parameter.VAR_KEYWORD
|
||||||
|
for parameter in execute_signature.parameters.values()
|
||||||
|
)
|
||||||
|
):
|
||||||
|
execute_kwargs["transfer_backend"] = transfer_backend
|
||||||
|
|
||||||
|
result = await self._execute_download(**execute_kwargs)
|
||||||
|
|
||||||
if result.get("success", False):
|
if result.get("success", False):
|
||||||
resolved_model_id = (
|
resolved_model_id = (
|
||||||
@@ -965,6 +1036,7 @@ class DownloadManager:
|
|||||||
progress_callback=None,
|
progress_callback=None,
|
||||||
model_type: str = "lora",
|
model_type: str = "lora",
|
||||||
download_id: str = None,
|
download_id: str = None,
|
||||||
|
transfer_backend: Optional[str] = None,
|
||||||
) -> Dict:
|
) -> Dict:
|
||||||
"""Execute the actual download process including preview images and model files"""
|
"""Execute the actual download process including preview images and model files"""
|
||||||
metadata_entries: List = []
|
metadata_entries: List = []
|
||||||
@@ -974,6 +1046,7 @@ class DownloadManager:
|
|||||||
preview_targets: List[str] = []
|
preview_targets: List[str] = []
|
||||||
preview_path: str | None = None
|
preview_path: str | None = None
|
||||||
preview_nsfw_level = 0
|
preview_nsfw_level = 0
|
||||||
|
transfer_backend = (transfer_backend or self._get_model_download_backend()).lower()
|
||||||
try:
|
try:
|
||||||
# Extract original filename details
|
# Extract original filename details
|
||||||
original_filename = os.path.basename(metadata.file_path)
|
original_filename = os.path.basename(metadata.file_path)
|
||||||
@@ -1136,32 +1209,37 @@ class DownloadManager:
|
|||||||
if progress_callback:
|
if progress_callback:
|
||||||
await progress_callback(3) # 3% progress after preview download
|
await progress_callback(3) # 3% progress after preview download
|
||||||
|
|
||||||
# Download model file with progress tracking using downloader
|
# Download model file with progress tracking using the configured backend
|
||||||
downloader = await get_downloader()
|
downloader = None
|
||||||
if pause_control is not None:
|
if transfer_backend == "python":
|
||||||
pause_control.update_stall_timeout(downloader.stall_timeout)
|
downloader = await get_downloader()
|
||||||
|
if pause_control is not None:
|
||||||
|
pause_control.update_stall_timeout(downloader.stall_timeout)
|
||||||
|
if pause_control is not None and pause_control.is_paused():
|
||||||
|
if download_id and download_id in self._active_downloads:
|
||||||
|
self._active_downloads[download_id]["status"] = "paused"
|
||||||
|
self._active_downloads[download_id]["bytes_per_second"] = 0.0
|
||||||
|
await pause_control.wait()
|
||||||
|
if download_id and download_id in self._active_downloads:
|
||||||
|
self._active_downloads[download_id]["status"] = "downloading"
|
||||||
last_error = None
|
last_error = None
|
||||||
for download_url in download_urls:
|
for download_url in download_urls:
|
||||||
download_url = normalize_civitai_download_url(download_url)
|
download_url = normalize_civitai_download_url(download_url)
|
||||||
use_auth = download_url.startswith(CIVITAI_DOWNLOAD_URL_PREFIXES)
|
use_auth = download_url.startswith(CIVITAI_DOWNLOAD_URL_PREFIXES)
|
||||||
download_kwargs = {
|
success, result = await self._download_model_file(
|
||||||
"progress_callback": lambda progress, snapshot=None: (
|
download_url,
|
||||||
|
save_path,
|
||||||
|
backend=transfer_backend,
|
||||||
|
progress_callback=lambda progress, snapshot=None: (
|
||||||
self._handle_download_progress(
|
self._handle_download_progress(
|
||||||
progress,
|
progress,
|
||||||
progress_callback,
|
progress_callback,
|
||||||
snapshot,
|
snapshot,
|
||||||
)
|
)
|
||||||
),
|
),
|
||||||
"use_auth": use_auth, # Only use authentication for Civitai downloads
|
use_auth=use_auth,
|
||||||
}
|
download_id=download_id,
|
||||||
|
pause_control=pause_control,
|
||||||
if pause_control is not None:
|
|
||||||
download_kwargs["pause_event"] = pause_control
|
|
||||||
|
|
||||||
success, result = await downloader.download_file(
|
|
||||||
download_url,
|
|
||||||
save_path, # Use full path instead of separate dir and filename
|
|
||||||
**download_kwargs,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if success:
|
if success:
|
||||||
@@ -1401,7 +1479,8 @@ class DownloadManager:
|
|||||||
extracted_files.append(dest_path)
|
extracted_files.append(dest_path)
|
||||||
return extracted_files
|
return extracted_files
|
||||||
|
|
||||||
return await asyncio.to_thread(_extract_sync)
|
loop = asyncio.get_running_loop()
|
||||||
|
return await loop.run_in_executor(self._archive_executor, _extract_sync)
|
||||||
|
|
||||||
async def _build_metadata_entries(
|
async def _build_metadata_entries(
|
||||||
self, base_metadata, file_paths: List[str]
|
self, base_metadata, file_paths: List[str]
|
||||||
@@ -1511,8 +1590,28 @@ class DownloadManager:
|
|||||||
return {"success": False, "error": "Download task not found"}
|
return {"success": False, "error": "Download task not found"}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Get the task and cancel it
|
|
||||||
task = self._download_tasks[download_id]
|
task = self._download_tasks[download_id]
|
||||||
|
backend = (
|
||||||
|
self._active_downloads.get(download_id, {}).get("transfer_backend")
|
||||||
|
or "python"
|
||||||
|
)
|
||||||
|
|
||||||
|
if backend == "aria2":
|
||||||
|
try:
|
||||||
|
aria2_downloader = await get_aria2_downloader()
|
||||||
|
cancel_result = await aria2_downloader.cancel_download(download_id)
|
||||||
|
if (
|
||||||
|
not cancel_result.get("success")
|
||||||
|
and cancel_result.get("error") != "Download task not found"
|
||||||
|
):
|
||||||
|
return cancel_result
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning(
|
||||||
|
"Failed to cancel aria2 transfer for %s, continuing with local task cancellation: %s",
|
||||||
|
download_id,
|
||||||
|
exc,
|
||||||
|
)
|
||||||
|
|
||||||
task.cancel()
|
task.cancel()
|
||||||
|
|
||||||
pause_control = self._pause_events.get(download_id)
|
pause_control = self._pause_events.get(download_id)
|
||||||
@@ -1613,6 +1712,28 @@ class DownloadManager:
|
|||||||
|
|
||||||
pause_control.pause()
|
pause_control.pause()
|
||||||
|
|
||||||
|
backend = (
|
||||||
|
self._active_downloads.get(download_id, {}).get("transfer_backend")
|
||||||
|
or "python"
|
||||||
|
)
|
||||||
|
if backend == "aria2":
|
||||||
|
try:
|
||||||
|
aria2_downloader = await get_aria2_downloader()
|
||||||
|
if await aria2_downloader.has_transfer(download_id):
|
||||||
|
result = await aria2_downloader.pause_download(download_id)
|
||||||
|
if not result.get("success"):
|
||||||
|
pause_control.resume()
|
||||||
|
return result
|
||||||
|
except Exception as exc:
|
||||||
|
pause_control.resume()
|
||||||
|
return {"success": False, "error": str(exc)}
|
||||||
|
|
||||||
|
download_info = self._active_downloads.get(download_id)
|
||||||
|
if download_info is not None:
|
||||||
|
download_info["status"] = "paused"
|
||||||
|
download_info["bytes_per_second"] = 0.0
|
||||||
|
return {"success": True, "message": "Download paused successfully"}
|
||||||
|
|
||||||
download_info = self._active_downloads.get(download_id)
|
download_info = self._active_downloads.get(download_id)
|
||||||
if download_info is not None:
|
if download_info is not None:
|
||||||
download_info["status"] = "paused"
|
download_info["status"] = "paused"
|
||||||
@@ -1631,6 +1752,28 @@ class DownloadManager:
|
|||||||
return {"success": False, "error": "Download is not paused"}
|
return {"success": False, "error": "Download is not paused"}
|
||||||
|
|
||||||
download_info = self._active_downloads.get(download_id)
|
download_info = self._active_downloads.get(download_id)
|
||||||
|
backend = (
|
||||||
|
self._active_downloads.get(download_id, {}).get("transfer_backend")
|
||||||
|
or "python"
|
||||||
|
)
|
||||||
|
if backend == "aria2":
|
||||||
|
try:
|
||||||
|
aria2_downloader = await get_aria2_downloader()
|
||||||
|
if await aria2_downloader.has_transfer(download_id):
|
||||||
|
result = await aria2_downloader.resume_download(download_id)
|
||||||
|
if not result.get("success"):
|
||||||
|
return result
|
||||||
|
except Exception as exc:
|
||||||
|
return {"success": False, "error": str(exc)}
|
||||||
|
|
||||||
|
pause_control.resume()
|
||||||
|
|
||||||
|
if download_info is not None:
|
||||||
|
if download_info.get("status") == "paused":
|
||||||
|
download_info["status"] = "downloading"
|
||||||
|
download_info.setdefault("bytes_per_second", 0.0)
|
||||||
|
return {"success": True, "message": "Download resumed successfully"}
|
||||||
|
|
||||||
force_reconnect = False
|
force_reconnect = False
|
||||||
if pause_control is not None:
|
if pause_control is not None:
|
||||||
elapsed = pause_control.time_since_last_progress()
|
elapsed = pause_control.time_since_last_progress()
|
||||||
|
|||||||
@@ -55,6 +55,8 @@ DEFAULT_KEYS_CLEANUP_THRESHOLD = 10
|
|||||||
DEFAULT_SETTINGS: Dict[str, Any] = {
|
DEFAULT_SETTINGS: Dict[str, Any] = {
|
||||||
"civitai_api_key": "",
|
"civitai_api_key": "",
|
||||||
"civitai_host": "civitai.com",
|
"civitai_host": "civitai.com",
|
||||||
|
"download_backend": "python",
|
||||||
|
"aria2c_path": "",
|
||||||
"use_portable_settings": False,
|
"use_portable_settings": False,
|
||||||
"hash_chunk_size_mb": DEFAULT_HASH_CHUNK_SIZE_MB,
|
"hash_chunk_size_mb": DEFAULT_HASH_CHUNK_SIZE_MB,
|
||||||
"language": "en",
|
"language": "en",
|
||||||
|
|||||||
@@ -807,6 +807,16 @@ export class SettingsManager {
|
|||||||
civitaiHostSelect.value = state.global.settings.civitai_host || 'civitai.com';
|
civitaiHostSelect.value = state.global.settings.civitai_host || 'civitai.com';
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const downloadBackendSelect = document.getElementById('downloadBackend');
|
||||||
|
if (downloadBackendSelect) {
|
||||||
|
downloadBackendSelect.value = state.global.settings.download_backend || 'python';
|
||||||
|
}
|
||||||
|
|
||||||
|
const aria2cPathInput = document.getElementById('aria2cPath');
|
||||||
|
if (aria2cPathInput) {
|
||||||
|
aria2cPathInput.value = state.global.settings.aria2c_path || '';
|
||||||
|
}
|
||||||
|
|
||||||
const recipesPathInput = document.getElementById('recipesPath');
|
const recipesPathInput = document.getElementById('recipesPath');
|
||||||
if (recipesPathInput) {
|
if (recipesPathInput) {
|
||||||
recipesPathInput.value = state.global.settings.recipes_path || '';
|
recipesPathInput.value = state.global.settings.recipes_path || '';
|
||||||
@@ -950,9 +960,36 @@ export class SettingsManager {
|
|||||||
languageSelect.value = currentLanguage;
|
languageSelect.value = currentLanguage;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
this.loadDownloadBackendSettings();
|
||||||
this.loadProxySettings();
|
this.loadProxySettings();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
loadDownloadBackendSettings() {
|
||||||
|
const downloadBackendSelect = document.getElementById('downloadBackend');
|
||||||
|
const aria2PathSetting = document.getElementById('aria2PathSetting');
|
||||||
|
const updateVisibility = () => {
|
||||||
|
if (!aria2PathSetting || !downloadBackendSelect) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
aria2PathSetting.style.display = downloadBackendSelect.value === 'aria2' ? 'block' : 'none';
|
||||||
|
};
|
||||||
|
|
||||||
|
if (downloadBackendSelect) {
|
||||||
|
downloadBackendSelect.value = state.global.settings.download_backend || 'python';
|
||||||
|
downloadBackendSelect.onchange = () => {
|
||||||
|
updateVisibility();
|
||||||
|
this.saveSelectSetting('downloadBackend', 'download_backend');
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
const aria2cPathInput = document.getElementById('aria2cPath');
|
||||||
|
if (aria2cPathInput) {
|
||||||
|
aria2cPathInput.value = state.global.settings.aria2c_path || '';
|
||||||
|
}
|
||||||
|
|
||||||
|
updateVisibility();
|
||||||
|
}
|
||||||
|
|
||||||
setupPriorityTagInputs() {
|
setupPriorityTagInputs() {
|
||||||
['lora', 'checkpoint', 'embedding'].forEach((modelType) => {
|
['lora', 'checkpoint', 'embedding'].forEach((modelType) => {
|
||||||
const textarea = document.getElementById(`${modelType}PriorityTagsInput`);
|
const textarea = document.getElementById(`${modelType}PriorityTagsInput`);
|
||||||
|
|||||||
@@ -6,6 +6,8 @@ import { DEFAULT_PATH_TEMPLATES, DEFAULT_PRIORITY_TAG_CONFIG } from '../utils/co
|
|||||||
const DEFAULT_SETTINGS_BASE = Object.freeze({
|
const DEFAULT_SETTINGS_BASE = Object.freeze({
|
||||||
civitai_api_key: '',
|
civitai_api_key: '',
|
||||||
civitai_host: 'civitai.com',
|
civitai_host: 'civitai.com',
|
||||||
|
download_backend: 'python',
|
||||||
|
aria2c_path: '',
|
||||||
use_portable_settings: false,
|
use_portable_settings: false,
|
||||||
language: 'en',
|
language: 'en',
|
||||||
show_only_sfw: false,
|
show_only_sfw: false,
|
||||||
|
|||||||
@@ -129,6 +129,43 @@
|
|||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
|
<div class="settings-subsection">
|
||||||
|
<div class="settings-subsection-header">
|
||||||
|
<h4>{{ t('settings.sections.downloads') }}</h4>
|
||||||
|
</div>
|
||||||
|
<div class="setting-item">
|
||||||
|
<div class="setting-row">
|
||||||
|
<div class="setting-info">
|
||||||
|
<label for="downloadBackend">{{ t('settings.downloadBackend.label') }}</label>
|
||||||
|
<i class="fas fa-info-circle info-icon" data-tooltip="{{ t('settings.downloadBackend.help') }}"></i>
|
||||||
|
</div>
|
||||||
|
<div class="setting-control select-control">
|
||||||
|
<select id="downloadBackend" onchange="settingsManager.saveSelectSetting('downloadBackend', 'download_backend')">
|
||||||
|
<option value="python">{{ t('settings.downloadBackend.options.python') }}</option>
|
||||||
|
<option value="aria2">{{ t('settings.downloadBackend.options.aria2') }}</option>
|
||||||
|
</select>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<div class="setting-item" id="aria2PathSetting" style="display: none;">
|
||||||
|
<div class="setting-row">
|
||||||
|
<div class="setting-info">
|
||||||
|
<label for="aria2cPath">{{ t('settings.aria2cPath.label') }}</label>
|
||||||
|
<i class="fas fa-info-circle info-icon" data-tooltip="{{ t('settings.aria2cPath.help') }}"></i>
|
||||||
|
</div>
|
||||||
|
<div class="setting-control">
|
||||||
|
<div class="text-input-wrapper">
|
||||||
|
<input type="text"
|
||||||
|
id="aria2cPath"
|
||||||
|
placeholder="{{ t('settings.aria2cPath.placeholder') }}"
|
||||||
|
onblur="settingsManager.saveInputSetting('aria2cPath', 'aria2c_path')"
|
||||||
|
onkeydown="if(event.key === 'Enter') { this.blur(); }" />
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
<!-- Backup -->
|
<!-- Backup -->
|
||||||
<div class="settings-subsection">
|
<div class="settings-subsection">
|
||||||
<div class="settings-subsection-header">
|
<div class="settings-subsection-header">
|
||||||
|
|||||||
@@ -305,4 +305,39 @@ describe('SettingsManager library controls', () => {
|
|||||||
'success',
|
'success',
|
||||||
);
|
);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
it('loads download backend settings and toggles the aria2 path field', () => {
|
||||||
|
const manager = createManager();
|
||||||
|
document.body.innerHTML = `
|
||||||
|
<select id="downloadBackend">
|
||||||
|
<option value="python">Python</option>
|
||||||
|
<option value="aria2">aria2</option>
|
||||||
|
</select>
|
||||||
|
<div id="aria2PathSetting" style="display: none;"></div>
|
||||||
|
<input id="aria2cPath" />
|
||||||
|
`;
|
||||||
|
|
||||||
|
state.global.settings = {
|
||||||
|
download_backend: 'aria2',
|
||||||
|
aria2c_path: '/usr/bin/aria2c',
|
||||||
|
};
|
||||||
|
|
||||||
|
const saveSpy = vi.spyOn(manager, 'saveSelectSetting').mockResolvedValue();
|
||||||
|
|
||||||
|
manager.loadDownloadBackendSettings();
|
||||||
|
|
||||||
|
const backendSelect = document.getElementById('downloadBackend');
|
||||||
|
const aria2PathSetting = document.getElementById('aria2PathSetting');
|
||||||
|
const aria2cPath = document.getElementById('aria2cPath');
|
||||||
|
|
||||||
|
expect(backendSelect.value).toBe('aria2');
|
||||||
|
expect(aria2cPath.value).toBe('/usr/bin/aria2c');
|
||||||
|
expect(aria2PathSetting.style.display).toBe('block');
|
||||||
|
|
||||||
|
backendSelect.value = 'python';
|
||||||
|
backendSelect.onchange();
|
||||||
|
|
||||||
|
expect(aria2PathSetting.style.display).toBe('none');
|
||||||
|
expect(saveSpy).toHaveBeenCalledWith('downloadBackend', 'download_backend');
|
||||||
|
});
|
||||||
});
|
});
|
||||||
|
|||||||
269
tests/services/test_aria2_downloader.py
Normal file
269
tests/services/test_aria2_downloader.py
Normal file
@@ -0,0 +1,269 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
from unittest.mock import AsyncMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from py.services.aria2_downloader import Aria2Downloader, Aria2Error
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_download_file_polls_until_complete(tmp_path, monkeypatch):
|
||||||
|
downloader = Aria2Downloader()
|
||||||
|
downloader._rpc_url = "http://127.0.0.1/jsonrpc"
|
||||||
|
downloader._rpc_secret = "secret"
|
||||||
|
|
||||||
|
save_path = tmp_path / "downloads" / "model.safetensors"
|
||||||
|
progress_events = []
|
||||||
|
rpc_calls = []
|
||||||
|
statuses = iter(
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"gid": "gid-1",
|
||||||
|
"status": "active",
|
||||||
|
"completedLength": "5",
|
||||||
|
"totalLength": "10",
|
||||||
|
"downloadSpeed": "25",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"gid": "gid-1",
|
||||||
|
"status": "complete",
|
||||||
|
"completedLength": "10",
|
||||||
|
"totalLength": "10",
|
||||||
|
"downloadSpeed": "0",
|
||||||
|
"files": [{"path": str(save_path)}],
|
||||||
|
},
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
async def fake_rpc_call(method, params):
|
||||||
|
rpc_calls.append((method, params))
|
||||||
|
if method == "aria2.addUri":
|
||||||
|
return "gid-1"
|
||||||
|
if method == "aria2.tellStatus":
|
||||||
|
return next(statuses)
|
||||||
|
raise AssertionError(f"Unexpected RPC method: {method}")
|
||||||
|
|
||||||
|
monkeypatch.setattr(downloader, "_ensure_process", AsyncMock())
|
||||||
|
monkeypatch.setattr(
|
||||||
|
downloader,
|
||||||
|
"_resolve_authenticated_redirect_url",
|
||||||
|
AsyncMock(
|
||||||
|
return_value="https://signed.example.com/model.safetensors?token=abc"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(downloader, "_rpc_call", fake_rpc_call)
|
||||||
|
monkeypatch.setattr("py.services.aria2_downloader.asyncio.sleep", AsyncMock())
|
||||||
|
|
||||||
|
async def progress_callback(progress, snapshot=None):
|
||||||
|
progress_events.append(snapshot.percent_complete if snapshot else progress)
|
||||||
|
|
||||||
|
success, result = await downloader.download_file(
|
||||||
|
"https://civitai.com/api/download/models/123",
|
||||||
|
str(save_path),
|
||||||
|
download_id="download-1",
|
||||||
|
progress_callback=progress_callback,
|
||||||
|
headers={"Authorization": "Bearer token"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert success is True
|
||||||
|
assert result == str(save_path)
|
||||||
|
assert progress_events == [50.0, 100.0]
|
||||||
|
assert downloader._transfers == {}
|
||||||
|
assert rpc_calls[0][0] == "aria2.addUri"
|
||||||
|
assert rpc_calls[0][1][0] == [
|
||||||
|
"https://signed.example.com/model.safetensors?token=abc"
|
||||||
|
]
|
||||||
|
assert rpc_calls[0][1][1]["out"] == "model.safetensors"
|
||||||
|
assert "header" not in rpc_calls[0][1][1]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_download_file_keeps_auth_headers_when_civitai_does_not_redirect(
|
||||||
|
tmp_path, monkeypatch
|
||||||
|
):
|
||||||
|
downloader = Aria2Downloader()
|
||||||
|
downloader._rpc_url = "http://127.0.0.1/jsonrpc"
|
||||||
|
downloader._rpc_secret = "secret"
|
||||||
|
|
||||||
|
save_path = tmp_path / "downloads" / "model.safetensors"
|
||||||
|
rpc_calls = []
|
||||||
|
statuses = iter(
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"gid": "gid-1",
|
||||||
|
"status": "complete",
|
||||||
|
"completedLength": "10",
|
||||||
|
"totalLength": "10",
|
||||||
|
"downloadSpeed": "0",
|
||||||
|
"files": [{"path": str(save_path)}],
|
||||||
|
},
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
async def fake_rpc_call(method, params):
|
||||||
|
rpc_calls.append((method, params))
|
||||||
|
if method == "aria2.addUri":
|
||||||
|
return "gid-1"
|
||||||
|
if method == "aria2.tellStatus":
|
||||||
|
return next(statuses)
|
||||||
|
raise AssertionError(f"Unexpected RPC method: {method}")
|
||||||
|
|
||||||
|
monkeypatch.setattr(downloader, "_ensure_process", AsyncMock())
|
||||||
|
monkeypatch.setattr(
|
||||||
|
downloader,
|
||||||
|
"_resolve_authenticated_redirect_url",
|
||||||
|
AsyncMock(return_value="https://civitai.com/api/download/models/123"),
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(downloader, "_rpc_call", fake_rpc_call)
|
||||||
|
monkeypatch.setattr("py.services.aria2_downloader.asyncio.sleep", AsyncMock())
|
||||||
|
|
||||||
|
success, result = await downloader.download_file(
|
||||||
|
"https://civitai.com/api/download/models/123",
|
||||||
|
str(save_path),
|
||||||
|
download_id="download-1",
|
||||||
|
headers={"Authorization": "Bearer token"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert success is True
|
||||||
|
assert result == str(save_path)
|
||||||
|
assert rpc_calls[0][1][0] == ["https://civitai.com/api/download/models/123"]
|
||||||
|
assert rpc_calls[0][1][1]["header"] == ["Authorization: Bearer token"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_pause_resume_cancel_forward_to_rpc(monkeypatch):
|
||||||
|
downloader = Aria2Downloader()
|
||||||
|
downloader._transfers["download-1"] = type(
|
||||||
|
"Transfer", (), {"gid": "gid-1", "save_path": "/tmp/model.safetensors"}
|
||||||
|
)()
|
||||||
|
|
||||||
|
calls = []
|
||||||
|
|
||||||
|
async def fake_rpc_call(method, params):
|
||||||
|
calls.append((method, params))
|
||||||
|
return "gid-1"
|
||||||
|
|
||||||
|
monkeypatch.setattr(downloader, "_rpc_call", fake_rpc_call)
|
||||||
|
|
||||||
|
pause_result = await downloader.pause_download("download-1")
|
||||||
|
resume_result = await downloader.resume_download("download-1")
|
||||||
|
cancel_result = await downloader.cancel_download("download-1")
|
||||||
|
|
||||||
|
assert pause_result["success"] is True
|
||||||
|
assert resume_result["success"] is True
|
||||||
|
assert cancel_result["success"] is True
|
||||||
|
assert calls == [
|
||||||
|
("aria2.forcePause", ["gid-1"]),
|
||||||
|
("aria2.unpause", ["gid-1"]),
|
||||||
|
("aria2.forceRemove", ["gid-1"]),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_progress_snapshot_normalizes_numeric_fields():
|
||||||
|
downloader = Aria2Downloader()
|
||||||
|
|
||||||
|
snapshot = downloader._build_progress_snapshot(
|
||||||
|
{
|
||||||
|
"completedLength": "75",
|
||||||
|
"totalLength": "100",
|
||||||
|
"downloadSpeed": "512",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
assert snapshot.percent_complete == 75.0
|
||||||
|
assert snapshot.bytes_downloaded == 75
|
||||||
|
assert snapshot.total_bytes == 100
|
||||||
|
assert snapshot.bytes_per_second == 512.0
|
||||||
|
|
||||||
|
|
||||||
|
def test_resolve_executable_raises_when_binary_missing(monkeypatch):
|
||||||
|
downloader = Aria2Downloader()
|
||||||
|
settings = type("Settings", (), {"get": lambda self, key, default=None: ""})()
|
||||||
|
|
||||||
|
monkeypatch.setattr("py.services.aria2_downloader.get_settings_manager", lambda: settings)
|
||||||
|
monkeypatch.setattr("py.services.aria2_downloader.shutil.which", lambda _: None)
|
||||||
|
|
||||||
|
with pytest.raises(Aria2Error):
|
||||||
|
downloader._resolve_executable()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_rpc_call_surfaces_json_error_on_non_200(monkeypatch):
|
||||||
|
downloader = Aria2Downloader()
|
||||||
|
downloader._rpc_url = "http://127.0.0.1:6800/jsonrpc"
|
||||||
|
downloader._rpc_secret = "secret"
|
||||||
|
|
||||||
|
class FakeResponse:
|
||||||
|
status = 400
|
||||||
|
|
||||||
|
async def text(self):
|
||||||
|
return (
|
||||||
|
'{"jsonrpc":"2.0","id":"x","error":{"code":1,"message":"Unauthorized"}}'
|
||||||
|
)
|
||||||
|
|
||||||
|
async def __aenter__(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def __aexit__(self, exc_type, exc, tb):
|
||||||
|
return False
|
||||||
|
|
||||||
|
class FakeSession:
|
||||||
|
def post(self, _url, json=None):
|
||||||
|
return FakeResponse()
|
||||||
|
|
||||||
|
monkeypatch.setattr(downloader, "_get_rpc_session", AsyncMock(return_value=FakeSession()))
|
||||||
|
|
||||||
|
with pytest.raises(Aria2Error) as exc_info:
|
||||||
|
await downloader._rpc_call("aria2.addUri", [["https://example.com/file"]])
|
||||||
|
|
||||||
|
assert "Unauthorized" in str(exc_info.value)
|
||||||
|
assert "aria2.addUri" in str(exc_info.value)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_resolve_authenticated_redirect_url_returns_location(monkeypatch):
|
||||||
|
downloader = Aria2Downloader()
|
||||||
|
|
||||||
|
class FakeResponse:
|
||||||
|
status = 307
|
||||||
|
headers = {"Location": "https://signed.example.com/file.safetensors"}
|
||||||
|
|
||||||
|
async def text(self):
|
||||||
|
return ""
|
||||||
|
|
||||||
|
async def __aenter__(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def __aexit__(self, exc_type, exc, tb):
|
||||||
|
return False
|
||||||
|
|
||||||
|
class FakeSession:
|
||||||
|
def get(self, _url, headers=None, allow_redirects=False, proxy=None):
|
||||||
|
return FakeResponse()
|
||||||
|
|
||||||
|
class FakeDownloader:
|
||||||
|
default_headers = {"User-Agent": "ComfyUI-LoRA-Manager/1.0"}
|
||||||
|
proxy_url = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def session(self):
|
||||||
|
async def _session():
|
||||||
|
return FakeSession()
|
||||||
|
|
||||||
|
return _session()
|
||||||
|
|
||||||
|
fake_downloader = FakeDownloader()
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"py.services.aria2_downloader.get_downloader",
|
||||||
|
AsyncMock(return_value=fake_downloader),
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await downloader._resolve_authenticated_redirect_url(
|
||||||
|
"https://civitai.com/api/download/models/123",
|
||||||
|
{"Authorization": "Bearer token"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result == "https://signed.example.com/file.safetensors"
|
||||||
@@ -179,6 +179,7 @@ async def test_successful_download_uses_defaults(
|
|||||||
progress_callback,
|
progress_callback,
|
||||||
model_type,
|
model_type,
|
||||||
download_id,
|
download_id,
|
||||||
|
transfer_backend=None,
|
||||||
):
|
):
|
||||||
captured.update(
|
captured.update(
|
||||||
{
|
{
|
||||||
@@ -268,6 +269,7 @@ async def test_download_uses_active_mirrors(
|
|||||||
progress_callback,
|
progress_callback,
|
||||||
model_type,
|
model_type,
|
||||||
download_id,
|
download_id,
|
||||||
|
transfer_backend=None,
|
||||||
):
|
):
|
||||||
captured["download_urls"] = download_urls
|
captured["download_urls"] = download_urls
|
||||||
return {"success": True}
|
return {"success": True}
|
||||||
@@ -288,6 +290,214 @@ async def test_download_uses_active_mirrors(
|
|||||||
assert captured["download_urls"] == ["https://mirror.example/file.safetensors"]
|
assert captured["download_urls"] == ["https://mirror.example/file.safetensors"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_pause_resume_cancel_delegate_to_aria2_backend(monkeypatch):
|
||||||
|
manager = DownloadManager()
|
||||||
|
|
||||||
|
task = asyncio.create_task(asyncio.sleep(60))
|
||||||
|
manager._download_tasks["download-1"] = task
|
||||||
|
manager._pause_events["download-1"] = download_manager.DownloadStreamControl()
|
||||||
|
manager._active_downloads["download-1"] = {
|
||||||
|
"transfer_backend": "aria2",
|
||||||
|
"status": "downloading",
|
||||||
|
}
|
||||||
|
|
||||||
|
class DummyAria2Downloader:
|
||||||
|
def __init__(self):
|
||||||
|
self.calls = []
|
||||||
|
|
||||||
|
async def pause_download(self, download_id):
|
||||||
|
self.calls.append(("pause", download_id))
|
||||||
|
return {"success": True, "message": "paused"}
|
||||||
|
|
||||||
|
async def resume_download(self, download_id):
|
||||||
|
self.calls.append(("resume", download_id))
|
||||||
|
return {"success": True, "message": "resumed"}
|
||||||
|
|
||||||
|
async def cancel_download(self, download_id):
|
||||||
|
self.calls.append(("cancel", download_id))
|
||||||
|
return {"success": True, "message": "cancelled"}
|
||||||
|
|
||||||
|
async def has_transfer(self, download_id):
|
||||||
|
self.calls.append(("has_transfer", download_id))
|
||||||
|
return True
|
||||||
|
|
||||||
|
dummy_aria2 = DummyAria2Downloader()
|
||||||
|
monkeypatch.setattr(
|
||||||
|
download_manager,
|
||||||
|
"get_aria2_downloader",
|
||||||
|
AsyncMock(return_value=dummy_aria2),
|
||||||
|
)
|
||||||
|
|
||||||
|
pause_result = await manager.pause_download("download-1")
|
||||||
|
assert pause_result["success"] is True
|
||||||
|
assert manager._active_downloads["download-1"]["status"] == "paused"
|
||||||
|
|
||||||
|
resume_result = await manager.resume_download("download-1")
|
||||||
|
assert resume_result["success"] is True
|
||||||
|
assert manager._active_downloads["download-1"]["status"] == "downloading"
|
||||||
|
|
||||||
|
cancel_result = await manager.cancel_download("download-1")
|
||||||
|
assert cancel_result["success"] is True
|
||||||
|
assert task.cancelled() or task.done()
|
||||||
|
assert dummy_aria2.calls == [
|
||||||
|
("has_transfer", "download-1"),
|
||||||
|
("pause", "download-1"),
|
||||||
|
("has_transfer", "download-1"),
|
||||||
|
("resume", "download-1"),
|
||||||
|
("cancel", "download-1"),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cancel_allows_queued_aria2_task_without_transfer(monkeypatch):
|
||||||
|
manager = DownloadManager()
|
||||||
|
|
||||||
|
started = asyncio.Event()
|
||||||
|
|
||||||
|
async def blocked_task():
|
||||||
|
started.set()
|
||||||
|
await asyncio.sleep(60)
|
||||||
|
|
||||||
|
task = asyncio.create_task(blocked_task())
|
||||||
|
await started.wait()
|
||||||
|
|
||||||
|
manager._download_tasks["download-queued"] = task
|
||||||
|
manager._active_downloads["download-queued"] = {
|
||||||
|
"transfer_backend": "aria2",
|
||||||
|
"status": "queued",
|
||||||
|
}
|
||||||
|
|
||||||
|
class DummyAria2Downloader:
|
||||||
|
async def cancel_download(self, download_id):
|
||||||
|
return {"success": False, "error": "Download task not found"}
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
download_manager,
|
||||||
|
"get_aria2_downloader",
|
||||||
|
AsyncMock(return_value=DummyAria2Downloader()),
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await manager.cancel_download("download-queued")
|
||||||
|
|
||||||
|
assert result["success"] is True
|
||||||
|
assert task.cancelled() or task.done()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_pause_resume_queued_aria2_task_without_transfer(monkeypatch):
|
||||||
|
manager = DownloadManager()
|
||||||
|
|
||||||
|
task = asyncio.create_task(asyncio.sleep(60))
|
||||||
|
manager._download_tasks["download-queued"] = task
|
||||||
|
manager._pause_events["download-queued"] = download_manager.DownloadStreamControl()
|
||||||
|
manager._active_downloads["download-queued"] = {
|
||||||
|
"transfer_backend": "aria2",
|
||||||
|
"status": "waiting",
|
||||||
|
"bytes_per_second": 12.0,
|
||||||
|
}
|
||||||
|
|
||||||
|
class DummyAria2Downloader:
|
||||||
|
def __init__(self):
|
||||||
|
self.calls = []
|
||||||
|
|
||||||
|
async def has_transfer(self, download_id):
|
||||||
|
self.calls.append(("has_transfer", download_id))
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def pause_download(self, download_id):
|
||||||
|
self.calls.append(("pause", download_id))
|
||||||
|
return {"success": True, "message": "paused"}
|
||||||
|
|
||||||
|
async def resume_download(self, download_id):
|
||||||
|
self.calls.append(("resume", download_id))
|
||||||
|
return {"success": True, "message": "resumed"}
|
||||||
|
|
||||||
|
dummy_aria2 = DummyAria2Downloader()
|
||||||
|
monkeypatch.setattr(
|
||||||
|
download_manager,
|
||||||
|
"get_aria2_downloader",
|
||||||
|
AsyncMock(return_value=dummy_aria2),
|
||||||
|
)
|
||||||
|
|
||||||
|
pause_result = await manager.pause_download("download-queued")
|
||||||
|
assert pause_result == {"success": True, "message": "Download paused successfully"}
|
||||||
|
assert manager._active_downloads["download-queued"]["status"] == "paused"
|
||||||
|
assert manager._pause_events["download-queued"].is_paused() is True
|
||||||
|
|
||||||
|
resume_result = await manager.resume_download("download-queued")
|
||||||
|
assert resume_result == {"success": True, "message": "Download resumed successfully"}
|
||||||
|
assert manager._active_downloads["download-queued"]["status"] == "downloading"
|
||||||
|
assert manager._pause_events["download-queued"].is_set() is True
|
||||||
|
assert dummy_aria2.calls == [
|
||||||
|
("has_transfer", "download-queued"),
|
||||||
|
("has_transfer", "download-queued"),
|
||||||
|
]
|
||||||
|
|
||||||
|
task.cancel()
|
||||||
|
with pytest.raises(asyncio.CancelledError):
|
||||||
|
await task
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_download_uses_captured_backend_when_settings_change(
|
||||||
|
monkeypatch, scanners, metadata_provider, tmp_path
|
||||||
|
):
|
||||||
|
manager = DownloadManager()
|
||||||
|
settings = get_settings_manager()
|
||||||
|
settings.settings["download_backend"] = "aria2"
|
||||||
|
|
||||||
|
semaphore = asyncio.Semaphore(0)
|
||||||
|
manager._download_semaphore = semaphore
|
||||||
|
|
||||||
|
captured = {}
|
||||||
|
|
||||||
|
async def fake_execute_original_download(
|
||||||
|
self,
|
||||||
|
model_id,
|
||||||
|
model_version_id,
|
||||||
|
save_dir,
|
||||||
|
relative_path,
|
||||||
|
progress_callback,
|
||||||
|
use_default_paths,
|
||||||
|
download_id=None,
|
||||||
|
transfer_backend="python",
|
||||||
|
source=None,
|
||||||
|
file_params=None,
|
||||||
|
):
|
||||||
|
captured["transfer_backend"] = transfer_backend
|
||||||
|
return {"success": True}
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
DownloadManager,
|
||||||
|
"_execute_original_download",
|
||||||
|
fake_execute_original_download,
|
||||||
|
)
|
||||||
|
|
||||||
|
download_task = asyncio.create_task(
|
||||||
|
manager.download_from_civitai(
|
||||||
|
model_version_id=99,
|
||||||
|
save_dir=str(tmp_path),
|
||||||
|
use_default_paths=True,
|
||||||
|
progress_callback=None,
|
||||||
|
source=None,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
assert len(manager._active_downloads) == 1
|
||||||
|
download_id = next(iter(manager._active_downloads))
|
||||||
|
assert manager._active_downloads[download_id]["transfer_backend"] == "aria2"
|
||||||
|
|
||||||
|
settings.settings["download_backend"] = "python"
|
||||||
|
semaphore.release()
|
||||||
|
|
||||||
|
result = await download_task
|
||||||
|
|
||||||
|
assert result["success"] is True
|
||||||
|
assert captured["transfer_backend"] == "aria2"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_download_aborts_when_version_exists(
|
async def test_download_aborts_when_version_exists(
|
||||||
monkeypatch, scanners, metadata_provider
|
monkeypatch, scanners, metadata_provider
|
||||||
|
|||||||
@@ -136,6 +136,190 @@ async def test_execute_download_retries_urls(monkeypatch, tmp_path):
|
|||||||
assert dummy_scanner.calls # ensure cache updated
|
assert dummy_scanner.calls # ensure cache updated
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_execute_download_uses_aria2_backend_for_model_files(monkeypatch, tmp_path):
|
||||||
|
manager = DownloadManager()
|
||||||
|
settings = get_settings_manager()
|
||||||
|
settings.settings["download_backend"] = "aria2"
|
||||||
|
settings.settings["civitai_api_key"] = "secret-key"
|
||||||
|
|
||||||
|
save_dir = tmp_path / "downloads"
|
||||||
|
save_dir.mkdir()
|
||||||
|
target_path = save_dir / "file.safetensors"
|
||||||
|
|
||||||
|
class DummyMetadata:
|
||||||
|
def __init__(self, path: Path):
|
||||||
|
self.file_path = str(path)
|
||||||
|
self.sha256 = "sha256"
|
||||||
|
self.file_name = path.stem
|
||||||
|
self.preview_url = None
|
||||||
|
|
||||||
|
def generate_unique_filename(self, *_args, **_kwargs):
|
||||||
|
return os.path.basename(self.file_path)
|
||||||
|
|
||||||
|
def update_file_info(self, _path):
|
||||||
|
return None
|
||||||
|
|
||||||
|
def to_dict(self):
|
||||||
|
return {"file_path": self.file_path}
|
||||||
|
|
||||||
|
class DummyAria2Downloader:
|
||||||
|
def __init__(self):
|
||||||
|
self.calls = []
|
||||||
|
|
||||||
|
async def download_file(
|
||||||
|
self,
|
||||||
|
url,
|
||||||
|
save_path,
|
||||||
|
*,
|
||||||
|
download_id,
|
||||||
|
progress_callback=None,
|
||||||
|
headers=None,
|
||||||
|
):
|
||||||
|
self.calls.append(
|
||||||
|
{
|
||||||
|
"url": url,
|
||||||
|
"save_path": save_path,
|
||||||
|
"download_id": download_id,
|
||||||
|
"headers": headers,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
Path(save_path).write_text("content")
|
||||||
|
return True, save_path
|
||||||
|
|
||||||
|
dummy_aria2 = DummyAria2Downloader()
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
download_manager,
|
||||||
|
"get_aria2_downloader",
|
||||||
|
AsyncMock(return_value=dummy_aria2),
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
download_manager,
|
||||||
|
"get_downloader",
|
||||||
|
AsyncMock(side_effect=AssertionError("python downloader should not be used")),
|
||||||
|
)
|
||||||
|
|
||||||
|
class DummyScanner:
|
||||||
|
async def add_model_to_cache(self, metadata_dict, relative_path):
|
||||||
|
return {"metadata": metadata_dict, "relative_path": relative_path}
|
||||||
|
|
||||||
|
dummy_scanner = DummyScanner()
|
||||||
|
monkeypatch.setattr(
|
||||||
|
DownloadManager, "_get_lora_scanner", AsyncMock(return_value=dummy_scanner)
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
DownloadManager,
|
||||||
|
"_get_checkpoint_scanner",
|
||||||
|
AsyncMock(return_value=dummy_scanner),
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
ServiceRegistry, "get_embedding_scanner", AsyncMock(return_value=dummy_scanner)
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(MetadataManager, "save_metadata", AsyncMock(return_value=True))
|
||||||
|
|
||||||
|
result = await manager._execute_download(
|
||||||
|
download_urls=["https://civitai.com/api/download/models/1"],
|
||||||
|
save_dir=str(save_dir),
|
||||||
|
metadata=DummyMetadata(target_path),
|
||||||
|
version_info={"images": []},
|
||||||
|
relative_path="",
|
||||||
|
progress_callback=None,
|
||||||
|
model_type="lora",
|
||||||
|
download_id="download-1",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result == {"success": True}
|
||||||
|
assert dummy_aria2.calls == [
|
||||||
|
{
|
||||||
|
"url": "https://civitai.com/api/download/models/1",
|
||||||
|
"save_path": str(target_path),
|
||||||
|
"download_id": "download-1",
|
||||||
|
"headers": {"Authorization": "Bearer secret-key"},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_execute_download_allows_anonymous_civitai_with_aria2(
|
||||||
|
monkeypatch, tmp_path
|
||||||
|
):
|
||||||
|
manager = DownloadManager()
|
||||||
|
settings = get_settings_manager()
|
||||||
|
settings.settings["download_backend"] = "aria2"
|
||||||
|
settings.settings["civitai_api_key"] = ""
|
||||||
|
|
||||||
|
save_dir = tmp_path / "downloads"
|
||||||
|
save_dir.mkdir()
|
||||||
|
target_path = save_dir / "file.safetensors"
|
||||||
|
|
||||||
|
class DummyMetadata:
|
||||||
|
def __init__(self, path: Path):
|
||||||
|
self.file_path = str(path)
|
||||||
|
self.sha256 = "sha256"
|
||||||
|
self.file_name = path.stem
|
||||||
|
self.preview_url = None
|
||||||
|
|
||||||
|
def generate_unique_filename(self, *_args, **_kwargs):
|
||||||
|
return os.path.basename(self.file_path)
|
||||||
|
|
||||||
|
def update_file_info(self, _path):
|
||||||
|
return None
|
||||||
|
|
||||||
|
def to_dict(self):
|
||||||
|
return {"file_path": self.file_path}
|
||||||
|
|
||||||
|
class DummyAria2Downloader:
|
||||||
|
def __init__(self):
|
||||||
|
self.calls = []
|
||||||
|
|
||||||
|
async def download_file(
|
||||||
|
self,
|
||||||
|
url,
|
||||||
|
save_path,
|
||||||
|
*,
|
||||||
|
download_id,
|
||||||
|
progress_callback=None,
|
||||||
|
headers=None,
|
||||||
|
):
|
||||||
|
self.calls.append({"url": url, "headers": headers, "download_id": download_id})
|
||||||
|
Path(save_path).write_text("content")
|
||||||
|
return True, save_path
|
||||||
|
|
||||||
|
dummy_aria2 = DummyAria2Downloader()
|
||||||
|
monkeypatch.setattr(
|
||||||
|
download_manager,
|
||||||
|
"get_aria2_downloader",
|
||||||
|
AsyncMock(return_value=dummy_aria2),
|
||||||
|
)
|
||||||
|
|
||||||
|
dummy_scanner = SimpleNamespace(add_model_to_cache=AsyncMock(return_value=None))
|
||||||
|
monkeypatch.setattr(
|
||||||
|
DownloadManager, "_get_lora_scanner", AsyncMock(return_value=dummy_scanner)
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(MetadataManager, "save_metadata", AsyncMock(return_value=True))
|
||||||
|
|
||||||
|
result = await manager._execute_download(
|
||||||
|
download_urls=["https://civitai.com/api/download/models/1"],
|
||||||
|
save_dir=str(save_dir),
|
||||||
|
metadata=DummyMetadata(target_path),
|
||||||
|
version_info={"images": []},
|
||||||
|
relative_path="",
|
||||||
|
progress_callback=None,
|
||||||
|
model_type="lora",
|
||||||
|
download_id="download-2",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result == {"success": True}
|
||||||
|
assert dummy_aria2.calls == [
|
||||||
|
{
|
||||||
|
"url": "https://civitai.com/api/download/models/1",
|
||||||
|
"headers": None,
|
||||||
|
"download_id": "download-2",
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_execute_download_adjusts_checkpoint_sub_type(monkeypatch, tmp_path):
|
async def test_execute_download_adjusts_checkpoint_sub_type(monkeypatch, tmp_path):
|
||||||
"""Test that checkpoint sub_type is adjusted during download."""
|
"""Test that checkpoint sub_type is adjusted during download."""
|
||||||
@@ -276,6 +460,13 @@ async def test_execute_download_extracts_zip_single_model(monkeypatch, tmp_path)
|
|||||||
monkeypatch.setattr(
|
monkeypatch.setattr(
|
||||||
download_manager, "get_downloader", AsyncMock(return_value=DummyDownloader())
|
download_manager, "get_downloader", AsyncMock(return_value=DummyDownloader())
|
||||||
)
|
)
|
||||||
|
|
||||||
|
class ImmediateLoop:
|
||||||
|
async def run_in_executor(self, executor, func, *args):
|
||||||
|
return func(*args)
|
||||||
|
|
||||||
|
monkeypatch.setattr(download_manager.asyncio, "get_running_loop", lambda: ImmediateLoop())
|
||||||
|
|
||||||
dummy_scanner = SimpleNamespace(add_model_to_cache=AsyncMock(return_value=None))
|
dummy_scanner = SimpleNamespace(add_model_to_cache=AsyncMock(return_value=None))
|
||||||
monkeypatch.setattr(
|
monkeypatch.setattr(
|
||||||
DownloadManager, "_get_lora_scanner", AsyncMock(return_value=dummy_scanner)
|
DownloadManager, "_get_lora_scanner", AsyncMock(return_value=dummy_scanner)
|
||||||
@@ -344,6 +535,13 @@ async def test_execute_download_extracts_zip_multiple_models(monkeypatch, tmp_pa
|
|||||||
monkeypatch.setattr(
|
monkeypatch.setattr(
|
||||||
download_manager, "get_downloader", AsyncMock(return_value=DummyDownloader())
|
download_manager, "get_downloader", AsyncMock(return_value=DummyDownloader())
|
||||||
)
|
)
|
||||||
|
|
||||||
|
class ImmediateLoop:
|
||||||
|
async def run_in_executor(self, executor, func, *args):
|
||||||
|
return func(*args)
|
||||||
|
|
||||||
|
monkeypatch.setattr(download_manager.asyncio, "get_running_loop", lambda: ImmediateLoop())
|
||||||
|
|
||||||
dummy_scanner = SimpleNamespace(add_model_to_cache=AsyncMock(return_value=None))
|
dummy_scanner = SimpleNamespace(add_model_to_cache=AsyncMock(return_value=None))
|
||||||
monkeypatch.setattr(
|
monkeypatch.setattr(
|
||||||
DownloadManager, "_get_lora_scanner", AsyncMock(return_value=dummy_scanner)
|
DownloadManager, "_get_lora_scanner", AsyncMock(return_value=dummy_scanner)
|
||||||
@@ -418,6 +616,13 @@ async def test_execute_download_extracts_zip_pt_embedding(monkeypatch, tmp_path)
|
|||||||
monkeypatch.setattr(
|
monkeypatch.setattr(
|
||||||
download_manager, "get_downloader", AsyncMock(return_value=DummyDownloader())
|
download_manager, "get_downloader", AsyncMock(return_value=DummyDownloader())
|
||||||
)
|
)
|
||||||
|
|
||||||
|
class ImmediateLoop:
|
||||||
|
async def run_in_executor(self, executor, func, *args):
|
||||||
|
return func(*args)
|
||||||
|
|
||||||
|
monkeypatch.setattr(download_manager.asyncio, "get_running_loop", lambda: ImmediateLoop())
|
||||||
|
|
||||||
dummy_scanner = SimpleNamespace(add_model_to_cache=AsyncMock(return_value=None))
|
dummy_scanner = SimpleNamespace(add_model_to_cache=AsyncMock(return_value=None))
|
||||||
monkeypatch.setattr(
|
monkeypatch.setattr(
|
||||||
ServiceRegistry, "get_embedding_scanner", AsyncMock(return_value=dummy_scanner)
|
ServiceRegistry, "get_embedding_scanner", AsyncMock(return_value=dummy_scanner)
|
||||||
@@ -446,6 +651,36 @@ async def test_execute_download_extracts_zip_pt_embedding(monkeypatch, tmp_path)
|
|||||||
assert dummy_scanner.add_model_to_cache.await_count == 1
|
assert dummy_scanner.add_model_to_cache.await_count == 1
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_extract_model_files_from_archive_uses_executor(monkeypatch, tmp_path):
|
||||||
|
manager = DownloadManager()
|
||||||
|
archive_path = tmp_path / "bundle.zip"
|
||||||
|
with zipfile.ZipFile(archive_path, "w") as archive:
|
||||||
|
archive.writestr("inner/model.safetensors", b"model")
|
||||||
|
|
||||||
|
captured = {}
|
||||||
|
|
||||||
|
class ImmediateLoop:
|
||||||
|
async def run_in_executor(self, executor, func, *args):
|
||||||
|
captured["executor"] = executor
|
||||||
|
return func(*args)
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
download_manager.asyncio,
|
||||||
|
"get_running_loop",
|
||||||
|
lambda: ImmediateLoop(),
|
||||||
|
)
|
||||||
|
|
||||||
|
extracted = await manager._extract_model_files_from_archive(
|
||||||
|
str(archive_path),
|
||||||
|
{".safetensors"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert captured["executor"] is manager._archive_executor
|
||||||
|
assert len(extracted) == 1
|
||||||
|
assert extracted[0].endswith("model.safetensors")
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_pause_download_updates_state():
|
async def test_pause_download_updates_state():
|
||||||
"""Test that pause_download updates download state correctly."""
|
"""Test that pause_download updates download state correctly."""
|
||||||
@@ -469,6 +704,233 @@ async def test_pause_download_updates_state():
|
|||||||
assert manager._active_downloads[download_id]["bytes_per_second"] == 0.0
|
assert manager._active_downloads[download_id]["bytes_per_second"] == 0.0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_pause_download_reverts_local_pause_when_aria2_pause_fails(monkeypatch):
|
||||||
|
manager = DownloadManager()
|
||||||
|
|
||||||
|
download_id = "dl"
|
||||||
|
manager._download_tasks[download_id] = object()
|
||||||
|
pause_control = DownloadStreamControl()
|
||||||
|
manager._pause_events[download_id] = pause_control
|
||||||
|
manager._active_downloads[download_id] = {
|
||||||
|
"transfer_backend": "aria2",
|
||||||
|
"status": "downloading",
|
||||||
|
"bytes_per_second": 42.0,
|
||||||
|
}
|
||||||
|
|
||||||
|
class DummyAria2Downloader:
|
||||||
|
async def has_transfer(self, _download_id):
|
||||||
|
return True
|
||||||
|
|
||||||
|
async def pause_download(self, _download_id):
|
||||||
|
return {"success": False, "error": "rpc failed"}
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
download_manager,
|
||||||
|
"get_aria2_downloader",
|
||||||
|
AsyncMock(return_value=DummyAria2Downloader()),
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await manager.pause_download(download_id)
|
||||||
|
|
||||||
|
assert result == {"success": False, "error": "rpc failed"}
|
||||||
|
assert pause_control.is_set() is True
|
||||||
|
assert manager._active_downloads[download_id]["status"] == "downloading"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_pause_download_reverts_local_pause_when_aria2_probe_raises(monkeypatch):
|
||||||
|
manager = DownloadManager()
|
||||||
|
|
||||||
|
download_id = "dl"
|
||||||
|
manager._download_tasks[download_id] = object()
|
||||||
|
pause_control = DownloadStreamControl()
|
||||||
|
manager._pause_events[download_id] = pause_control
|
||||||
|
manager._active_downloads[download_id] = {
|
||||||
|
"transfer_backend": "aria2",
|
||||||
|
"status": "downloading",
|
||||||
|
"bytes_per_second": 42.0,
|
||||||
|
}
|
||||||
|
|
||||||
|
class DummyAria2Downloader:
|
||||||
|
async def has_transfer(self, _download_id):
|
||||||
|
raise RuntimeError("rpc unavailable")
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
download_manager,
|
||||||
|
"get_aria2_downloader",
|
||||||
|
AsyncMock(return_value=DummyAria2Downloader()),
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await manager.pause_download(download_id)
|
||||||
|
|
||||||
|
assert result == {"success": False, "error": "rpc unavailable"}
|
||||||
|
assert pause_control.is_set() is True
|
||||||
|
assert manager._active_downloads[download_id]["status"] == "downloading"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_resume_download_returns_error_when_aria2_probe_raises(monkeypatch):
|
||||||
|
manager = DownloadManager()
|
||||||
|
|
||||||
|
download_id = "dl"
|
||||||
|
pause_control = DownloadStreamControl()
|
||||||
|
pause_control.pause()
|
||||||
|
manager._pause_events[download_id] = pause_control
|
||||||
|
manager._active_downloads[download_id] = {
|
||||||
|
"transfer_backend": "aria2",
|
||||||
|
"status": "paused",
|
||||||
|
"bytes_per_second": 0.0,
|
||||||
|
}
|
||||||
|
|
||||||
|
class DummyAria2Downloader:
|
||||||
|
async def has_transfer(self, _download_id):
|
||||||
|
raise RuntimeError("rpc unavailable")
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
download_manager,
|
||||||
|
"get_aria2_downloader",
|
||||||
|
AsyncMock(return_value=DummyAria2Downloader()),
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await manager.resume_download(download_id)
|
||||||
|
|
||||||
|
assert result == {"success": False, "error": "rpc unavailable"}
|
||||||
|
assert pause_control.is_paused() is True
|
||||||
|
assert manager._active_downloads[download_id]["status"] == "paused"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cancel_download_still_cancels_local_task_when_aria2_raises(monkeypatch):
|
||||||
|
manager = DownloadManager()
|
||||||
|
|
||||||
|
started = asyncio.Event()
|
||||||
|
|
||||||
|
async def blocked_task():
|
||||||
|
started.set()
|
||||||
|
await asyncio.sleep(60)
|
||||||
|
|
||||||
|
task = asyncio.create_task(blocked_task())
|
||||||
|
await started.wait()
|
||||||
|
|
||||||
|
download_id = "download-queued"
|
||||||
|
manager._download_tasks[download_id] = task
|
||||||
|
manager._active_downloads[download_id] = {
|
||||||
|
"transfer_backend": "aria2",
|
||||||
|
"status": "queued",
|
||||||
|
}
|
||||||
|
|
||||||
|
class DummyAria2Downloader:
|
||||||
|
async def cancel_download(self, _download_id):
|
||||||
|
raise RuntimeError("rpc unavailable")
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
download_manager,
|
||||||
|
"get_aria2_downloader",
|
||||||
|
AsyncMock(return_value=DummyAria2Downloader()),
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await manager.cancel_download(download_id)
|
||||||
|
|
||||||
|
assert result["success"] is True
|
||||||
|
assert task.cancelled() or task.done()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_execute_download_waits_for_paused_pre_transfer_gate(monkeypatch, tmp_path):
|
||||||
|
manager = DownloadManager()
|
||||||
|
|
||||||
|
save_dir = tmp_path / "downloads"
|
||||||
|
save_dir.mkdir()
|
||||||
|
target_path = save_dir / "file.safetensors"
|
||||||
|
|
||||||
|
class DummyMetadata:
|
||||||
|
def __init__(self, path: Path):
|
||||||
|
self.file_path = str(path)
|
||||||
|
self.sha256 = "sha256"
|
||||||
|
self.file_name = path.stem
|
||||||
|
self.preview_url = None
|
||||||
|
|
||||||
|
def generate_unique_filename(self, *_args, **_kwargs):
|
||||||
|
return os.path.basename(self.file_path)
|
||||||
|
|
||||||
|
def update_file_info(self, _path):
|
||||||
|
return None
|
||||||
|
|
||||||
|
def to_dict(self):
|
||||||
|
return {"file_path": self.file_path}
|
||||||
|
|
||||||
|
pause_control = DownloadStreamControl()
|
||||||
|
pause_control.pause()
|
||||||
|
manager._pause_events["download-1"] = pause_control
|
||||||
|
manager._active_downloads["download-1"] = {
|
||||||
|
"status": "downloading",
|
||||||
|
"bytes_per_second": 42.0,
|
||||||
|
}
|
||||||
|
|
||||||
|
dummy_scanner = SimpleNamespace(add_model_to_cache=AsyncMock(return_value=None))
|
||||||
|
monkeypatch.setattr(
|
||||||
|
DownloadManager, "_get_lora_scanner", AsyncMock(return_value=dummy_scanner)
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(MetadataManager, "save_metadata", AsyncMock(return_value=True))
|
||||||
|
|
||||||
|
started = asyncio.Event()
|
||||||
|
allow_finish = asyncio.Event()
|
||||||
|
captured = {"calls": 0}
|
||||||
|
|
||||||
|
async def fake_download_model_file(
|
||||||
|
self,
|
||||||
|
download_url,
|
||||||
|
save_path,
|
||||||
|
*,
|
||||||
|
backend,
|
||||||
|
progress_callback,
|
||||||
|
use_auth,
|
||||||
|
download_id,
|
||||||
|
pause_control,
|
||||||
|
):
|
||||||
|
captured["calls"] += 1
|
||||||
|
started.set()
|
||||||
|
await allow_finish.wait()
|
||||||
|
Path(save_path).write_text("content")
|
||||||
|
return True, save_path
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
DownloadManager,
|
||||||
|
"_download_model_file",
|
||||||
|
fake_download_model_file,
|
||||||
|
)
|
||||||
|
|
||||||
|
task = asyncio.create_task(
|
||||||
|
manager._execute_download(
|
||||||
|
download_urls=["https://civitai.com/api/download/models/1"],
|
||||||
|
save_dir=str(save_dir),
|
||||||
|
metadata=DummyMetadata(target_path),
|
||||||
|
version_info={"images": []},
|
||||||
|
relative_path="",
|
||||||
|
progress_callback=None,
|
||||||
|
model_type="lora",
|
||||||
|
download_id="download-1",
|
||||||
|
transfer_backend="aria2",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
assert started.is_set() is False
|
||||||
|
assert captured["calls"] == 0
|
||||||
|
assert manager._active_downloads["download-1"]["status"] == "paused"
|
||||||
|
|
||||||
|
pause_control.resume()
|
||||||
|
await asyncio.wait_for(started.wait(), timeout=1.0)
|
||||||
|
assert captured["calls"] == 1
|
||||||
|
assert manager._active_downloads["download-1"]["status"] == "downloading"
|
||||||
|
|
||||||
|
allow_finish.set()
|
||||||
|
result = await task
|
||||||
|
|
||||||
|
assert result == {"success": True}
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_pause_download_rejects_unknown_task():
|
async def test_pause_download_rejects_unknown_task():
|
||||||
"""Test that pause_download rejects unknown download tasks."""
|
"""Test that pause_download rejects unknown download tasks."""
|
||||||
|
|||||||
@@ -147,6 +147,11 @@ def test_environment_variable_overrides_settings(tmp_path, monkeypatch):
|
|||||||
assert mgr.get("civitai_api_key") == "secret"
|
assert mgr.get("civitai_api_key") == "secret"
|
||||||
|
|
||||||
|
|
||||||
|
def test_default_download_backend_is_python(manager):
|
||||||
|
assert manager.get("download_backend") == "python"
|
||||||
|
assert manager.get("aria2c_path") == ""
|
||||||
|
|
||||||
|
|
||||||
def _create_manager_with_settings(
|
def _create_manager_with_settings(
|
||||||
tmp_path, monkeypatch, initial_settings, *, save_spy=None
|
tmp_path, monkeypatch, initial_settings, *, save_spy=None
|
||||||
):
|
):
|
||||||
|
|||||||
Reference in New Issue
Block a user