diff --git a/locales/de.json b/locales/de.json index 8584a12d..bd795d45 100644 --- a/locales/de.json +++ b/locales/de.json @@ -263,6 +263,19 @@ "red": "civitai.red (uneingeschränkt)" } }, + "downloadBackend": { + "label": "Download-Backend", + "help": "Wähle aus, wie Modelldateien heruntergeladen werden. Python verwendet den eingebauten Downloader. aria2 verwendet den experimentellen externen Downloader-Prozess.", + "options": { + "python": "Python (integriert)", + "aria2": "aria2 (experimentell)" + } + }, + "aria2cPath": { + "label": "aria2c-Pfad", + "help": "Optionaler Pfad zur ausführbaren aria2c-Datei. Leer lassen, um aria2c aus dem System-PATH zu verwenden.", + "placeholder": "Leer lassen, um aria2c aus dem PATH zu verwenden" + }, "civitaiHostBanner": { "title": "Civitai-Host-Einstellung verfügbar", "content": "Civitai verwendet jetzt civitai.com für SFW-Inhalte und civitai.red für uneingeschränkte Inhalte. In den Einstellungen können Sie ändern, welche Seite standardmäßig geöffnet wird.", @@ -278,6 +291,7 @@ }, "sections": { "contentFiltering": "Inhaltsfilterung", + "downloads": "Downloads", "videoSettings": "Video-Einstellungen", "layoutSettings": "Layout-Einstellungen", "misc": "Verschiedenes", diff --git a/locales/en.json b/locales/en.json index e4272a48..a7e41e0f 100644 --- a/locales/en.json +++ b/locales/en.json @@ -263,6 +263,19 @@ "red": "civitai.red (unrestricted)" } }, + "downloadBackend": { + "label": "Download backend", + "help": "Choose how model files are downloaded. Python uses the built-in downloader. aria2 uses the experimental external downloader process.", + "options": { + "python": "Python (built-in)", + "aria2": "aria2 (experimental)" + } + }, + "aria2cPath": { + "label": "aria2c path", + "help": "Optional path to the aria2c executable. Leave empty to use aria2c from your system PATH.", + "placeholder": "Leave empty to use aria2c from PATH" + }, "civitaiHostBanner": { "title": "Civitai host preference available", "content": "Civitai now uses civitai.com for SFW content and civitai.red for unrestricted content. You can change which site opens by default in Settings.", @@ -278,6 +291,7 @@ }, "sections": { "contentFiltering": "Content Filtering", + "downloads": "Downloads", "videoSettings": "Video Settings", "layoutSettings": "Layout Settings", "misc": "Miscellaneous", diff --git a/locales/es.json b/locales/es.json index 45ba1491..8154b625 100644 --- a/locales/es.json +++ b/locales/es.json @@ -263,6 +263,19 @@ "red": "civitai.red (sin restricciones)" } }, + "downloadBackend": { + "label": "Backend de descarga", + "help": "Elige cómo se descargan los archivos del modelo. Python usa el descargador integrado. aria2 usa el proceso externo experimental de descarga.", + "options": { + "python": "Python (integrado)", + "aria2": "aria2 (experimental)" + } + }, + "aria2cPath": { + "label": "Ruta de aria2c", + "help": "Ruta opcional al ejecutable aria2c. Déjalo vacío para usar aria2c desde el PATH del sistema.", + "placeholder": "Déjalo vacío para usar aria2c desde el PATH" + }, "civitaiHostBanner": { "title": "Preferencia de host de Civitai disponible", "content": "Civitai ahora usa civitai.com para contenido SFW y civitai.red para contenido sin restricciones. Puedes cambiar en Ajustes qué sitio se abre por defecto.", @@ -278,6 +291,7 @@ }, "sections": { "contentFiltering": "Filtrado de contenido", + "downloads": "Descargas", "videoSettings": "Configuración de video", "layoutSettings": "Configuración de diseño", "misc": "Varios", diff --git a/locales/fr.json b/locales/fr.json index 15ac157c..fd156e34 100644 --- a/locales/fr.json +++ b/locales/fr.json @@ -263,6 +263,19 @@ "red": "civitai.red (sans restriction)" } }, + "downloadBackend": { + "label": "Moteur de téléchargement", + "help": "Choisissez comment les fichiers de modèles sont téléchargés. Python utilise le téléchargeur intégré. aria2 utilise le processus externe expérimental de téléchargement.", + "options": { + "python": "Python (intégré)", + "aria2": "aria2 (expérimental)" + } + }, + "aria2cPath": { + "label": "Chemin vers aria2c", + "help": "Chemin facultatif vers l’exécutable aria2c. Laissez vide pour utiliser aria2c depuis le PATH système.", + "placeholder": "Laisser vide pour utiliser aria2c depuis le PATH" + }, "civitaiHostBanner": { "title": "Préférence d’hôte Civitai disponible", "content": "Civitai utilise désormais civitai.com pour le contenu SFW et civitai.red pour le contenu sans restriction. Vous pouvez modifier dans les paramètres le site ouvert par défaut.", @@ -278,6 +291,7 @@ }, "sections": { "contentFiltering": "Filtrage du contenu", + "downloads": "Téléchargements", "videoSettings": "Paramètres vidéo", "layoutSettings": "Paramètres d'affichage", "misc": "Divers", diff --git a/locales/he.json b/locales/he.json index 8dfdab61..9cd794d5 100644 --- a/locales/he.json +++ b/locales/he.json @@ -263,6 +263,19 @@ "red": "civitai.red (ללא הגבלות)" } }, + "downloadBackend": { + "label": "מנגנון הורדה", + "help": "בחר כיצד יורדים קבצי המודל. Python משתמש במוריד המובנה. aria2 משתמש בתהליך הורדה חיצוני ניסיוני.", + "options": { + "python": "Python (מובנה)", + "aria2": "aria2 (ניסיוני)" + } + }, + "aria2cPath": { + "label": "נתיב aria2c", + "help": "נתיב אופציונלי לקובץ ההפעלה aria2c. השאר ריק כדי להשתמש ב-aria2c מתוך ה-PATH של המערכת.", + "placeholder": "השאר ריק כדי להשתמש ב-aria2c מתוך ה-PATH" + }, "civitaiHostBanner": { "title": "העדפת מארח Civitai זמינה", "content": "Civitai משתמש כעת ב-civitai.com עבור תוכן SFW וב-civitai.red עבור תוכן ללא הגבלות. ניתן לשנות בהגדרות איזה אתר ייפתח כברירת מחדל.", @@ -278,6 +291,7 @@ }, "sections": { "contentFiltering": "סינון תוכן", + "downloads": "הורדות", "videoSettings": "הגדרות וידאו", "layoutSettings": "הגדרות פריסה", "misc": "שונות", diff --git a/locales/ja.json b/locales/ja.json index e1d4448d..f5d286c2 100644 --- a/locales/ja.json +++ b/locales/ja.json @@ -263,6 +263,19 @@ "red": "civitai.red(制限なし)" } }, + "downloadBackend": { + "label": "ダウンロードバックエンド", + "help": "モデルファイルのダウンロード方法を選択します。Python は内蔵ダウンローダーを使用し、aria2 は実験的な外部ダウンローダープロセスを使用します。", + "options": { + "python": "Python(内蔵)", + "aria2": "aria2(実験的)" + } + }, + "aria2cPath": { + "label": "aria2c のパス", + "help": "aria2c 実行ファイルへの任意のパスです。空欄のままにすると、システム PATH 上の aria2c を使用します。", + "placeholder": "空欄のままにすると PATH 上の aria2c を使用します" + }, "civitaiHostBanner": { "title": "Civitai ホスト設定を利用できます", "content": "Civitai は現在、SFW コンテンツには civitai.com、制限なしコンテンツには civitai.red を使用しています。設定で既定で開くサイトを変更できます。", @@ -278,6 +291,7 @@ }, "sections": { "contentFiltering": "コンテンツフィルタリング", + "downloads": "ダウンロード", "videoSettings": "動画設定", "layoutSettings": "レイアウト設定", "misc": "その他", diff --git a/locales/ko.json b/locales/ko.json index 5a3f2ec0..2e1f40ca 100644 --- a/locales/ko.json +++ b/locales/ko.json @@ -263,6 +263,19 @@ "red": "civitai.red(무제한)" } }, + "downloadBackend": { + "label": "다운로드 백엔드", + "help": "모델 파일을 다운로드하는 방식을 선택합니다. Python은 내장 다운로더를 사용하고, aria2는 실험적인 외부 다운로더 프로세스를 사용합니다.", + "options": { + "python": "Python(내장)", + "aria2": "aria2(실험적)" + } + }, + "aria2cPath": { + "label": "aria2c 경로", + "help": "aria2c 실행 파일의 선택적 경로입니다. 비워 두면 시스템 PATH의 aria2c를 사용합니다.", + "placeholder": "비워 두면 PATH의 aria2c를 사용합니다" + }, "civitaiHostBanner": { "title": "Civitai 호스트 기본 설정 사용 가능", "content": "이제 Civitai는 SFW 콘텐츠에 civitai.com을, 무제한 콘텐츠에 civitai.red를 사용합니다. 설정에서 기본으로 열 사이트를 변경할 수 있습니다.", @@ -278,6 +291,7 @@ }, "sections": { "contentFiltering": "콘텐츠 필터링", + "downloads": "다운로드", "videoSettings": "비디오 설정", "layoutSettings": "레이아웃 설정", "misc": "기타", diff --git a/locales/ru.json b/locales/ru.json index cf0612fe..58eccf53 100644 --- a/locales/ru.json +++ b/locales/ru.json @@ -263,6 +263,19 @@ "red": "civitai.red (без ограничений)" } }, + "downloadBackend": { + "label": "Бэкенд загрузки", + "help": "Выберите способ загрузки файлов моделей. Python использует встроенный загрузчик. aria2 использует экспериментальный внешний процесс загрузки.", + "options": { + "python": "Python (встроенный)", + "aria2": "aria2 (экспериментальный)" + } + }, + "aria2cPath": { + "label": "Путь к aria2c", + "help": "Необязательный путь к исполняемому файлу aria2c. Оставьте пустым, чтобы использовать aria2c из системного PATH.", + "placeholder": "Оставьте пустым, чтобы использовать aria2c из PATH" + }, "civitaiHostBanner": { "title": "Доступна настройка хоста Civitai", "content": "Теперь Civitai использует civitai.com для контента SFW и civitai.red для контента без ограничений. В настройках можно изменить, какой сайт открывать по умолчанию.", @@ -278,6 +291,7 @@ }, "sections": { "contentFiltering": "Фильтрация контента", + "downloads": "Загрузки", "videoSettings": "Настройки видео", "layoutSettings": "Настройки макета", "misc": "Разное", diff --git a/locales/zh-CN.json b/locales/zh-CN.json index 898f550a..17644c63 100644 --- a/locales/zh-CN.json +++ b/locales/zh-CN.json @@ -263,6 +263,19 @@ "red": "civitai.red(无限制)" } }, + "downloadBackend": { + "label": "下载后端", + "help": "选择模型文件的下载方式。Python 使用内置下载器。aria2 使用实验性的外部下载进程。", + "options": { + "python": "Python(内置)", + "aria2": "aria2(实验性)" + } + }, + "aria2cPath": { + "label": "aria2c 路径", + "help": "可选的 aria2c 可执行文件路径。留空则使用系统 PATH 中的 aria2c。", + "placeholder": "留空则使用 PATH 中的 aria2c" + }, "civitaiHostBanner": { "title": "已提供 Civitai 站点偏好设置", "content": "Civitai 现在使用 civitai.com 提供 SFW 内容,使用 civitai.red 提供无限制内容。你可以在设置中更改默认打开的站点。", @@ -278,6 +291,7 @@ }, "sections": { "contentFiltering": "内容过滤", + "downloads": "下载", "videoSettings": "视频设置", "layoutSettings": "布局设置", "misc": "其他", diff --git a/locales/zh-TW.json b/locales/zh-TW.json index ef4077ba..c2139ef0 100644 --- a/locales/zh-TW.json +++ b/locales/zh-TW.json @@ -263,6 +263,19 @@ "red": "civitai.red(無限制)" } }, + "downloadBackend": { + "label": "下載後端", + "help": "選擇模型檔案的下載方式。Python 使用內建下載器。aria2 使用實驗性的外部下載程序。", + "options": { + "python": "Python(內建)", + "aria2": "aria2(實驗性)" + } + }, + "aria2cPath": { + "label": "aria2c 路徑", + "help": "可選的 aria2c 可執行檔路徑。留空則使用系統 PATH 中的 aria2c。", + "placeholder": "留空則使用 PATH 中的 aria2c" + }, "civitaiHostBanner": { "title": "已提供 Civitai 站點偏好設定", "content": "Civitai 現在使用 civitai.com 提供 SFW 內容,使用 civitai.red 提供無限制內容。你可以在設定中變更預設開啟的站點。", @@ -278,6 +291,7 @@ }, "sections": { "contentFiltering": "內容過濾", + "downloads": "下載", "videoSettings": "影片設定", "layoutSettings": "版面設定", "misc": "其他", diff --git a/py/config.py b/py/config.py index 228491d7..997e76c2 100644 --- a/py/config.py +++ b/py/config.py @@ -26,20 +26,44 @@ logger = logging.getLogger(__name__) def _resolve_valid_default_root( - current: str, primary_paths: List[str], name: str + current: str, primary_paths: List[str], allowed_paths: List[str], name: str ) -> str: - """Return a valid default root from the current primary path set.""" + """Return a valid default root from the current primary/extra path set.""" valid_paths = [path for path in primary_paths if isinstance(path, str) and path.strip()] - if not valid_paths: - return "" + fallback_paths: List[str] = [] + seen: Set[str] = set() + for path in allowed_paths: + if not isinstance(path, str): + continue + stripped = path.strip() + if not stripped or stripped in seen: + continue + seen.add(stripped) + fallback_paths.append(stripped) - if current in valid_paths: + allowed = set(fallback_paths) + + if current and current in allowed: return current + if not valid_paths: + if not fallback_paths: + return "" + if current: + logger.info( + "Repaired stale %s from '%s' to '%s' because it is not present in primary or extra roots", + name, + current, + fallback_paths[0], + ) + else: + logger.info("Auto-setting %s to '%s'", name, fallback_paths[0]) + return fallback_paths[0] + if current: logger.info( - "Repaired stale %s from '%s' to '%s'", + "Repaired stale %s from '%s' to '%s' because it is not present in primary or extra roots", name, current, valid_paths[0], @@ -226,39 +250,76 @@ class Config: default_lora_root = _resolve_valid_default_root( comfy_library.get("default_lora_root", ""), list(self.loras_roots or []), + list(self.loras_roots or []) + + list(comfy_library.get("extra_folder_paths", {}).get("loras", []) or []), "default_lora_root", ) default_checkpoint_root = _resolve_valid_default_root( comfy_library.get("default_checkpoint_root", ""), list(self.checkpoints_roots or []), + list(self.checkpoints_roots or []) + + list(comfy_library.get("extra_folder_paths", {}).get("checkpoints", []) or []), "default_checkpoint_root", ) default_embedding_root = _resolve_valid_default_root( comfy_library.get("default_embedding_root", ""), list(self.embeddings_roots or []), + list(self.embeddings_roots or []) + + list(comfy_library.get("extra_folder_paths", {}).get("embeddings", []) or []), "default_embedding_root", ) metadata = dict(comfy_library.get("metadata", {})) metadata.setdefault("display_name", "ComfyUI") metadata["source"] = "comfyui" + extra_folder_paths = {} + if isinstance(comfy_library, Mapping): + existing_extra_paths = comfy_library.get("extra_folder_paths", {}) + if isinstance(existing_extra_paths, Mapping): + extra_folder_paths = { + key: list(value) if isinstance(value, list) else [] + for key, value in existing_extra_paths.items() + } + + active_library_name = settings_service.get_active_library_name() + should_activate = ( + active_library_name == "comfyui" + or self._should_activate_comfy_library(libraries, libraries_changed) + ) settings_service.upsert_library( "comfyui", folder_paths=target_folder_paths, + extra_folder_paths=extra_folder_paths, default_lora_root=default_lora_root, default_checkpoint_root=default_checkpoint_root, default_embedding_root=default_embedding_root, metadata=metadata, - activate=True, + activate=should_activate, ) - logger.info("Updated 'comfyui' library with current folder paths") + if should_activate: + logger.info("Updated 'comfyui' library with current folder paths") + else: + logger.info( + "Updated 'comfyui' library with current folder paths without activating it" + ) except Exception as e: logger.warning(f"Failed to save folder paths: {e}") + def _should_activate_comfy_library( + self, libraries: Mapping[str, Any], libraries_changed: bool + ) -> bool: + """Return whether startup sync should make the ComfyUI library active.""" + + if libraries_changed: + return True + if not libraries: + return True + return "comfyui" in libraries and len(libraries) == 1 + def _is_link(self, path: str) -> bool: try: if os.path.islink(path): diff --git a/py/services/aria2_downloader.py b/py/services/aria2_downloader.py new file mode 100644 index 00000000..f50b6a1c --- /dev/null +++ b/py/services/aria2_downloader.py @@ -0,0 +1,570 @@ +from __future__ import annotations + +import asyncio +import json +import logging +import os +import secrets +import shutil +import socket +from dataclasses import dataclass +from datetime import datetime +from pathlib import Path +from typing import Any, Dict, Optional, Tuple + +import aiohttp + +from .downloader import DownloadProgress, get_downloader +from .aria2_transfer_state import Aria2TransferStateStore +from .settings_manager import get_settings_manager + +logger = logging.getLogger(__name__) + +CIVITAI_DOWNLOAD_URL_PREFIXES = ( + "https://civitai.com/api/download/", + "https://civitai.red/api/download/", +) + + +class Aria2Error(RuntimeError): + """Raised when aria2 integration fails.""" + + +@dataclass +class Aria2Transfer: + """Track an aria2 download registered by the Python coordinator.""" + + gid: str + save_path: str + + +class Aria2Downloader: + """Manage an aria2 RPC daemon for experimental model downloads.""" + + _instance = None + _lock = asyncio.Lock() + + @classmethod + async def get_instance(cls) -> "Aria2Downloader": + async with cls._lock: + if cls._instance is None: + cls._instance = cls() + return cls._instance + + def __init__(self) -> None: + if hasattr(self, "_initialized"): + return + + self._initialized = True + self._process: Optional[asyncio.subprocess.Process] = None + self._rpc_port: Optional[int] = None + self._rpc_secret = "" + self._rpc_url = "" + self._rpc_session: Optional[aiohttp.ClientSession] = None + self._rpc_session_lock = asyncio.Lock() + self._process_lock = asyncio.Lock() + self._transfers: Dict[str, Aria2Transfer] = {} + self._poll_interval = 0.5 + self._state_store = Aria2TransferStateStore() + + @property + def is_running(self) -> bool: + return self._process is not None and self._process.returncode is None + + async def download_file( + self, + url: str, + save_path: str, + *, + download_id: str, + progress_callback=None, + headers: Optional[Dict[str, str]] = None, + ) -> Tuple[bool, str]: + """Download a file using aria2 RPC and wait for completion.""" + + await self._ensure_process() + save_path = os.path.abspath(save_path) + transfer = self._transfers.get(download_id) + if transfer is None or os.path.abspath(transfer.save_path) != save_path: + gid = await self._schedule_download( + url, + save_path, + download_id=download_id, + headers=headers, + ) + transfer = Aria2Transfer(gid=gid, save_path=save_path) + self._transfers[download_id] = transfer + + try: + while True: + status = await self.get_status(download_id) + if status is None: + return False, "aria2 download not found" + + snapshot = self._build_progress_snapshot(status) + if progress_callback is not None: + await self._dispatch_progress(progress_callback, snapshot) + + state = status.get("status", "") + if state == "complete": + completed_path = self._resolve_completed_path(status, save_path) + return True, completed_path + if state == "error": + return False, status.get("errorMessage") or "aria2 download failed" + if state == "removed": + return False, "Download was cancelled" + + await asyncio.sleep(self._poll_interval) + finally: + self._transfers.pop(download_id, None) + + async def _schedule_download( + self, + url: str, + save_path: str, + *, + download_id: str, + headers: Optional[Dict[str, str]] = None, + ) -> str: + save_dir = os.path.dirname(save_path) + out_name = os.path.basename(save_path) + + Path(save_dir).mkdir(parents=True, exist_ok=True) + + resolved_url = url + request_headers = headers + if headers and url.startswith(CIVITAI_DOWNLOAD_URL_PREFIXES): + resolved_url = await self._resolve_authenticated_redirect_url(url, headers) + if resolved_url != url: + request_headers = None + logger.debug( + "Resolved Civitai download %s to signed URL for aria2", + download_id, + ) + + options: Dict[str, str] = { + "dir": save_dir, + "out": out_name, + "continue": "true", + "max-connection-per-server": "4", + "split": "4", + "min-split-size": "1M", + "allow-overwrite": "true", + "auto-file-renaming": "false", + "file-allocation": "none", + } + if request_headers: + options["header"] = [ + f"{key}: {value}" for key, value in request_headers.items() + ] + + logger.debug( + "Submitting aria2 download %s -> %s (auth=%s, civitai_signed=%s)", + download_id, + save_path, + bool(request_headers), + resolved_url != url, + ) + + try: + gid = await self._rpc_call("aria2.addUri", [[resolved_url], options]) + except Exception as exc: + raise Aria2Error(f"Failed to schedule aria2 download: {exc}") from exc + + logger.debug("aria2 accepted download %s with gid %s", download_id, gid) + await self._state_store.upsert( + download_id, + { + "gid": gid, + "save_path": save_path, + "status": "downloading", + "url": url, + }, + ) + return gid + + async def get_status(self, download_id: str) -> Optional[Dict[str, Any]]: + """Return the raw aria2 status payload for a known download.""" + + transfer = self._transfers.get(download_id) + if transfer is None: + return None + + keys = [ + "gid", + "status", + "totalLength", + "completedLength", + "downloadSpeed", + "errorMessage", + "files", + ] + try: + status = await self._rpc_call("aria2.tellStatus", [transfer.gid, keys]) + except Exception as exc: + raise Aria2Error(f"Failed to query aria2 download status: {exc}") from exc + + if isinstance(status, dict): + return status + return None + + async def get_status_by_gid(self, gid: str) -> Optional[Dict[str, Any]]: + keys = [ + "gid", + "status", + "totalLength", + "completedLength", + "downloadSpeed", + "errorMessage", + "files", + ] + try: + status = await self._rpc_call("aria2.tellStatus", [gid, keys]) + except Exception as exc: + message = str(exc) + if "cannot be found" in message.lower() or "not found" in message.lower(): + return None + raise Aria2Error(f"Failed to query aria2 download status: {exc}") from exc + + if isinstance(status, dict): + return status + return None + + async def restore_transfer(self, download_id: str, gid: str, save_path: str) -> None: + await self._ensure_process() + self._transfers[download_id] = Aria2Transfer( + gid=gid, + save_path=os.path.abspath(save_path), + ) + + async def reassign_transfer( + self, from_download_id: str, to_download_id: str + ) -> Optional[Aria2Transfer]: + transfer = self._transfers.get(from_download_id) + if transfer is None: + return None + + self._transfers[to_download_id] = transfer + if from_download_id != to_download_id: + self._transfers.pop(from_download_id, None) + return transfer + + async def has_transfer(self, download_id: str) -> bool: + return download_id in self._transfers + + async def pause_download(self, download_id: str) -> Dict[str, Any]: + transfer = self._transfers.get(download_id) + if transfer is None: + return {"success": False, "error": "Download task not found"} + + try: + await self._rpc_call("aria2.forcePause", [transfer.gid]) + except Exception as exc: + return {"success": False, "error": str(exc)} + + await self._state_store.upsert(download_id, {"status": "paused"}) + return {"success": True, "message": "Download paused successfully"} + + async def resume_download(self, download_id: str) -> Dict[str, Any]: + transfer = self._transfers.get(download_id) + if transfer is None: + return {"success": False, "error": "Download task not found"} + + try: + await self._rpc_call("aria2.unpause", [transfer.gid]) + except Exception as exc: + return {"success": False, "error": str(exc)} + + await self._state_store.upsert(download_id, {"status": "downloading"}) + return {"success": True, "message": "Download resumed successfully"} + + async def cancel_download(self, download_id: str) -> Dict[str, Any]: + transfer = self._transfers.get(download_id) + if transfer is None: + return {"success": False, "error": "Download task not found"} + + try: + await self._rpc_call("aria2.forceRemove", [transfer.gid]) + except Exception as exc: + return {"success": False, "error": str(exc)} + + await self._state_store.remove(download_id) + return {"success": True, "message": "Download cancelled successfully"} + + async def close(self) -> None: + """Shut down the RPC process and session.""" + + if self._rpc_session is not None: + await self._rpc_session.close() + self._rpc_session = None + + process = self._process + self._process = None + self._transfers.clear() + + if process is None: + return + + if process.returncode is None: + process.terminate() + try: + await asyncio.wait_for(process.wait(), timeout=5.0) + except asyncio.TimeoutError: + process.kill() + await process.wait() + + async def _dispatch_progress(self, callback, snapshot: DownloadProgress) -> None: + try: + result = callback(snapshot, snapshot) + except TypeError: + result = callback(snapshot.percent_complete) + + if asyncio.iscoroutine(result): + await result + elif hasattr(result, "__await__"): + await result + + def _build_progress_snapshot(self, status: Dict[str, Any]) -> DownloadProgress: + completed = self._parse_int(status.get("completedLength")) + total = self._parse_int(status.get("totalLength")) + speed = float(self._parse_int(status.get("downloadSpeed"))) + percent = 0.0 + if total > 0: + percent = (completed / total) * 100.0 + + return DownloadProgress( + percent_complete=max(0.0, min(percent, 100.0)), + bytes_downloaded=completed, + total_bytes=total or None, + bytes_per_second=speed, + timestamp=datetime.now().timestamp(), + ) + + def _resolve_completed_path(self, status: Dict[str, Any], default_path: str) -> str: + files = status.get("files") + if isinstance(files, list) and files: + first = files[0] + if isinstance(first, dict): + candidate = first.get("path") + if isinstance(candidate, str) and candidate: + return candidate + return default_path + + @staticmethod + def _parse_int(value: Any) -> int: + try: + return int(value) + except (TypeError, ValueError): + return 0 + + async def _resolve_authenticated_redirect_url( + self, + url: str, + headers: Dict[str, str], + ) -> str: + downloader = await get_downloader() + session = await downloader.session + request_headers = dict(downloader.default_headers) + request_headers.update(headers) + request_headers["Accept-Encoding"] = "identity" + + try: + async with session.get( + url, + headers=request_headers, + allow_redirects=False, + proxy=downloader.proxy_url, + ) as response: + if response.status in {301, 302, 303, 307, 308}: + location = response.headers.get("Location") + if location: + return location + raise Aria2Error( + "Authenticated Civitai redirect did not include a Location header" + ) + + if response.status == 200: + return url + + body = await response.text() + raise Aria2Error( + f"Failed to resolve authenticated Civitai redirect: status={response.status} body={body[:300]}" + ) + except aiohttp.ClientError as exc: + raise Aria2Error( + f"Failed to resolve authenticated Civitai redirect: {exc}" + ) from exc + + async def _ensure_process(self) -> None: + async with self._process_lock: + if self.is_running and await self._ping(): + return + + await self.close() + + executable = self._resolve_executable() + self._rpc_port = self._find_free_port() + self._rpc_secret = secrets.token_hex(16) + self._rpc_url = f"http://127.0.0.1:{self._rpc_port}/jsonrpc" + + command = [ + executable, + "--enable-rpc=true", + "--rpc-listen-all=false", + f"--rpc-listen-port={self._rpc_port}", + f"--rpc-secret={self._rpc_secret}", + "--check-certificate=true", + "--allow-overwrite=true", + "--auto-file-renaming=false", + "--file-allocation=none", + "--max-concurrent-downloads=5", + "--continue=true", + "--daemon=false", + "--quiet=true", + f"--stop-with-process={os.getpid()}", + ] + + logger.info("Starting aria2 RPC daemon from %s", executable) + self._process = await asyncio.create_subprocess_exec( + *command, + stdout=asyncio.subprocess.DEVNULL, + stderr=asyncio.subprocess.PIPE, + ) + + await self._wait_until_ready() + + def _resolve_executable(self) -> str: + settings = get_settings_manager() + configured_path = (settings.get("aria2c_path") or "").strip() + candidate = configured_path or "aria2c" + + resolved = shutil.which(candidate) + if resolved: + return resolved + + if configured_path and os.path.isfile(configured_path) and os.access( + configured_path, os.X_OK + ): + return configured_path + + raise Aria2Error( + "aria2c executable was not found. Install aria2 or configure aria2c_path." + ) + + async def _wait_until_ready(self) -> None: + assert self._process is not None + + start_time = asyncio.get_running_loop().time() + last_error = "" + while asyncio.get_running_loop().time() - start_time < 10.0: + if self._process.returncode is not None: + stderr_output = "" + if self._process.stderr is not None: + try: + stderr_output = ( + await asyncio.wait_for(self._process.stderr.read(), timeout=0.2) + ).decode("utf-8", errors="replace") + except Exception: + stderr_output = "" + raise Aria2Error( + f"aria2 RPC process exited early with code {self._process.returncode}: {stderr_output.strip()}" + ) + + try: + if await self._ping(): + return + except Exception as exc: # pragma: no cover - startup race + last_error = str(exc) + + await asyncio.sleep(0.2) + + raise Aria2Error( + f"Timed out waiting for aria2 RPC to become ready{': ' + last_error if last_error else ''}" + ) + + async def _ping(self) -> bool: + try: + result = await self._rpc_call("aria2.getVersion", []) + except Exception: + return False + + return isinstance(result, dict) + + async def _rpc_call(self, method: str, params: list[Any]) -> Any: + if not self._rpc_url: + raise Aria2Error("aria2 RPC endpoint is not initialized") + + session = await self._get_rpc_session() + payload = { + "jsonrpc": "2.0", + "id": secrets.token_hex(8), + "method": method, + "params": [f"token:{self._rpc_secret}", *params], + } + + async with session.post(self._rpc_url, json=payload) as response: + text = await response.text() + + try: + body = json.loads(text) + except json.JSONDecodeError: + body = None + + if body is None: + if response.status != 200: + raise Aria2Error( + f"aria2 RPC returned status {response.status} with non-JSON body: {text}" + ) + raise Aria2Error(f"Invalid aria2 RPC response: {text}") + + if "error" in body: + error = body["error"] or {} + code = error.get("code") if isinstance(error, dict) else None + message = error.get("message") if isinstance(error, dict) else str(error) + logger.error( + "aria2 RPC %s failed with HTTP %s, code=%s, message=%s", + method, + response.status, + code, + message, + ) + status_message = ( + f"aria2 RPC {method} failed with status {response.status}: {message}" + if response.status != 200 + else message + ) + raise Aria2Error(status_message or "Unknown aria2 RPC error") + + if response.status != 200: + logger.error( + "aria2 RPC %s returned unexpected HTTP status %s without error payload: %s", + method, + response.status, + body, + ) + raise Aria2Error( + f"aria2 RPC {method} returned unexpected status {response.status}" + ) + + return body.get("result") + + async def _get_rpc_session(self) -> aiohttp.ClientSession: + if self._rpc_session is None or self._rpc_session.closed: + async with self._rpc_session_lock: + if self._rpc_session is None or self._rpc_session.closed: + timeout = aiohttp.ClientTimeout(total=30) + self._rpc_session = aiohttp.ClientSession(timeout=timeout) + return self._rpc_session + + @staticmethod + def _find_free_port() -> int: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: + sock.bind(("127.0.0.1", 0)) + sock.listen(1) + return int(sock.getsockname()[1]) + + +async def get_aria2_downloader() -> Aria2Downloader: + """Get the singleton aria2 downloader.""" + + return await Aria2Downloader.get_instance() diff --git a/py/services/aria2_transfer_state.py b/py/services/aria2_transfer_state.py new file mode 100644 index 00000000..1754c95d --- /dev/null +++ b/py/services/aria2_transfer_state.py @@ -0,0 +1,108 @@ +from __future__ import annotations + +import asyncio +import json +import os +from copy import deepcopy +from typing import Any, Dict, Optional + +from ..utils.cache_paths import get_cache_base_dir + + +def get_aria2_state_path() -> str: + base_dir = get_cache_base_dir(create=True) + state_dir = os.path.join(base_dir, "aria2") + os.makedirs(state_dir, exist_ok=True) + return os.path.join(state_dir, "downloads.json") + + +class Aria2TransferStateStore: + """Persist aria2 transfer metadata needed for restart recovery.""" + + _locks_by_path: Dict[str, asyncio.Lock] = {} + + def __init__(self, state_path: Optional[str] = None) -> None: + self._state_path = os.path.abspath(state_path or get_aria2_state_path()) + self._lock = self._locks_by_path.setdefault(self._state_path, asyncio.Lock()) + + def _read_all_unlocked(self) -> Dict[str, Dict[str, Any]]: + try: + with open(self._state_path, "r", encoding="utf-8") as handle: + data = json.load(handle) + except FileNotFoundError: + return {} + except json.JSONDecodeError: + return {} + + if not isinstance(data, dict): + return {} + + normalized: Dict[str, Dict[str, Any]] = {} + for download_id, entry in data.items(): + if isinstance(download_id, str) and isinstance(entry, dict): + normalized[download_id] = entry + return normalized + + def _write_all_unlocked(self, data: Dict[str, Dict[str, Any]]) -> None: + directory = os.path.dirname(self._state_path) + if directory: + os.makedirs(directory, exist_ok=True) + + temp_path = f"{self._state_path}.tmp" + with open(temp_path, "w", encoding="utf-8") as handle: + json.dump(data, handle, ensure_ascii=True, indent=2, sort_keys=True) + os.replace(temp_path, self._state_path) + + async def load_all(self) -> Dict[str, Dict[str, Any]]: + async with self._lock: + return deepcopy(self._read_all_unlocked()) + + async def get(self, download_id: str) -> Optional[Dict[str, Any]]: + async with self._lock: + return deepcopy(self._read_all_unlocked().get(download_id)) + + async def upsert(self, download_id: str, payload: Dict[str, Any]) -> Dict[str, Any]: + async with self._lock: + data = self._read_all_unlocked() + current = data.get(download_id, {}) + current.update(payload) + data[download_id] = current + self._write_all_unlocked(data) + return deepcopy(current) + + async def remove(self, download_id: str) -> None: + async with self._lock: + data = self._read_all_unlocked() + if download_id in data: + del data[download_id] + self._write_all_unlocked(data) + + async def find_by_save_path( + self, save_path: str, *, exclude_download_id: Optional[str] = None + ) -> Optional[Dict[str, Any]]: + normalized_target = os.path.abspath(save_path) + async with self._lock: + data = self._read_all_unlocked() + for download_id, entry in data.items(): + if exclude_download_id and download_id == exclude_download_id: + continue + candidate = entry.get("save_path") + if isinstance(candidate, str) and os.path.abspath(candidate) == normalized_target: + result = dict(entry) + result["download_id"] = download_id + return result + return None + + async def reassign(self, from_download_id: str, to_download_id: str) -> Optional[Dict[str, Any]]: + async with self._lock: + data = self._read_all_unlocked() + existing = data.get(from_download_id) + if existing is None: + return None + updated = dict(existing) + updated["download_id"] = to_download_id + data[to_download_id] = updated + if from_download_id != to_download_id: + data.pop(from_download_id, None) + self._write_all_unlocked(data) + return deepcopy(updated) diff --git a/py/services/connectivity_guard.py b/py/services/connectivity_guard.py index 05de8004..1f60d5df 100644 --- a/py/services/connectivity_guard.py +++ b/py/services/connectivity_guard.py @@ -6,6 +6,7 @@ import asyncio import errno import logging import socket +from dataclasses import dataclass from datetime import datetime, timedelta from typing import Any @@ -49,68 +50,118 @@ class ConnectivityGuard: if hasattr(self, "_initialized"): return self._initialized = True - self.online = True - self.failure_count = 0 - self.cooldown_until: datetime | None = None + self._default_destination = "__global__" + self._destination_states: dict[str, _DestinationState] = { + self._default_destination: _DestinationState() + } self.base_backoff_seconds = 30 self.max_backoff_seconds = 300 self.failure_threshold = 3 + @property + def online(self) -> bool: + return self._state_for_destination(None).online + + @online.setter + def online(self, value: bool) -> None: + self._state_for_destination(None).online = value + + @property + def failure_count(self) -> int: + return self._state_for_destination(None).failure_count + + @failure_count.setter + def failure_count(self, value: int) -> None: + self._state_for_destination(None).failure_count = value + + @property + def cooldown_until(self) -> datetime | None: + return self._state_for_destination(None).cooldown_until + + @cooldown_until.setter + def cooldown_until(self, value: datetime | None) -> None: + self._state_for_destination(None).cooldown_until = value + def _now(self) -> datetime: return datetime.now() - def in_cooldown(self) -> bool: - if self.cooldown_until is None: + def _normalize_destination(self, destination: str | None) -> str: + if destination is None or not destination.strip(): + return self._default_destination + return destination.lower().strip() + + def _state_for_destination(self, destination: str | None) -> "_DestinationState": + destination_key = self._normalize_destination(destination) + if destination_key not in self._destination_states: + self._destination_states[destination_key] = _DestinationState() + return self._destination_states[destination_key] + + def in_cooldown(self, destination: str | None = None) -> bool: + state = self._state_for_destination(destination) + if state.cooldown_until is None: return False - return self._now() < self.cooldown_until + return self._now() < state.cooldown_until - def cooldown_remaining_seconds(self) -> float: - if self.cooldown_until is None: + def cooldown_remaining_seconds(self, destination: str | None = None) -> float: + state = self._state_for_destination(destination) + if state.cooldown_until is None: return 0.0 - return max(0.0, (self.cooldown_until - self._now()).total_seconds()) + return max(0.0, (state.cooldown_until - self._now()).total_seconds()) - def should_block_request(self) -> bool: - return self.in_cooldown() + def should_block_request(self, destination: str | None = None) -> bool: + return self.in_cooldown(destination) - def register_success(self) -> None: - was_offline = (not self.online) or self.cooldown_until is not None - self.online = True - self.failure_count = 0 - self.cooldown_until = None + def register_success(self, destination: str | None = None) -> None: + destination_key = self._normalize_destination(destination) + state = self._state_for_destination(destination_key) + was_offline = (not state.online) or state.cooldown_until is not None + state.online = True + state.failure_count = 0 + state.cooldown_until = None if was_offline: - logger.info("Connectivity restored; requests resumed.") + logger.info( + "Connectivity restored for destination '%s'; requests resumed.", + destination_key, + ) - def register_network_failure(self, exc: Exception) -> None: - self.online = False - self.failure_count += 1 + def register_network_failure( + self, exc: Exception, destination: str | None = None + ) -> None: + destination_key = self._normalize_destination(destination) + state = self._state_for_destination(destination_key) + state.online = False + state.failure_count += 1 - if self.failure_count < self.failure_threshold: + if state.failure_count < self.failure_threshold: logger.debug( - "Network failure tracked (%d/%d): %s", - self.failure_count, + "Network failure tracked for destination '%s' (%d/%d): %s", + destination_key, + state.failure_count, self.failure_threshold, exc, ) return - retry_step = self.failure_count - self.failure_threshold + retry_step = state.failure_count - self.failure_threshold backoff = min( self.max_backoff_seconds, self.base_backoff_seconds * (2**retry_step), ) - should_log_warning = not self.in_cooldown() - self.cooldown_until = self._now() + timedelta(seconds=backoff) + should_log_warning = not self.in_cooldown(destination_key) + state.cooldown_until = self._now() + timedelta(seconds=backoff) if should_log_warning: logger.warning( - "Connectivity offline; enter cooldown for %ss after %d network failures.", + "Connectivity offline for destination '%s'; enter cooldown for %ss after %d network failures.", + destination_key, int(backoff), - self.failure_count, + state.failure_count, ) else: logger.debug( - "Cooldown still active; failure_count=%d, backoff=%ss.", - self.failure_count, + "Cooldown still active for destination '%s'; failure_count=%d, backoff=%ss.", + destination_key, + state.failure_count, int(backoff), ) @@ -145,3 +196,9 @@ class ConnectivityGuard: return False + +@dataclass +class _DestinationState: + online: bool = True + failure_count: int = 0 + cooldown_until: datetime | None = None diff --git a/py/services/download_manager.py b/py/services/download_manager.py index 2fe07cff..5b297356 100644 --- a/py/services/download_manager.py +++ b/py/services/download_manager.py @@ -5,6 +5,7 @@ import asyncio import inspect import shutil import zipfile +from concurrent.futures import ThreadPoolExecutor from collections import OrderedDict import uuid from typing import Dict, List, Optional, Set, Tuple @@ -25,6 +26,8 @@ from .service_registry import ServiceRegistry from .settings_manager import get_settings_manager from .metadata_service import get_default_metadata_provider, get_metadata_provider from .downloader import get_downloader, DownloadProgress, DownloadStreamControl +from .aria2_downloader import Aria2Error, get_aria2_downloader +from .aria2_transfer_state import Aria2TransferStateStore # Download to temporary file first import tempfile @@ -60,6 +63,62 @@ class DownloadManager: self._download_semaphore = asyncio.Semaphore(5) # Limit concurrent downloads self._download_tasks = {} # download_id -> asyncio.Task self._pause_events: Dict[str, DownloadStreamControl] = {} + self._archive_executor = ThreadPoolExecutor( + max_workers=2, thread_name_prefix="lm-archive" + ) + self._aria2_state_store = Aria2TransferStateStore() + self._restored_persisted_downloads = False + self._restore_lock = asyncio.Lock() + + @staticmethod + def _get_model_download_backend() -> str: + backend = (get_settings_manager().get("download_backend") or "python").strip() + return backend.lower() or "python" + + async def _download_model_file( + self, + download_url: str, + save_path: str, + *, + backend: str, + progress_callback, + use_auth: bool, + download_id: Optional[str], + pause_control: Optional[DownloadStreamControl], + ) -> Tuple[bool, str]: + if backend == "aria2": + if not download_id: + return False, "aria2 downloads require a tracked download_id" + + headers: Dict[str, str] = {} + if use_auth: + api_key = (get_settings_manager().get("civitai_api_key") or "").strip() + if api_key: + headers["Authorization"] = f"Bearer {api_key}" + + try: + aria2_downloader = await get_aria2_downloader() + return await aria2_downloader.download_file( + download_url, + save_path, + download_id=download_id, + progress_callback=progress_callback, + headers=headers or None, + ) + except Aria2Error as exc: + logger.error("aria2 download failed for %s: %s", download_url, exc) + return False, str(exc) + + download_kwargs = { + "progress_callback": progress_callback, + "use_auth": use_auth, + } + + if pause_control is not None: + download_kwargs["pause_event"] = pause_control + + downloader = await get_downloader() + return await downloader.download_file(download_url, save_path, **download_kwargs) async def _get_lora_scanner(self): """Get the lora scanner from registry""" @@ -124,8 +183,14 @@ class DownloadManager: self._active_downloads[task_id] = { "model_id": model_id, "model_version_id": model_version_id, + "save_dir": save_dir, + "relative_path": relative_path, + "use_default_paths": bool(use_default_paths), + "source": source, + "file_params": copy.deepcopy(file_params) if file_params is not None else None, "progress": 0, "status": "queued", + "transfer_backend": self._get_model_download_backend(), "bytes_downloaded": 0, "total_bytes": None, "bytes_per_second": 0.0, @@ -135,6 +200,9 @@ class DownloadManager: pause_control = DownloadStreamControl() self._pause_events[task_id] = pause_control + if self._active_downloads[task_id]["transfer_backend"] == "aria2": + await self._persist_aria2_state(task_id) + # Create tracking task download_task = asyncio.create_task( self._download_with_semaphore( @@ -186,6 +254,8 @@ class DownloadManager: # Update status to waiting if task_id in self._active_downloads: self._active_downloads[task_id]["status"] = "waiting" + if self._active_downloads[task_id].get("transfer_backend") == "aria2": + await self._persist_aria2_state(task_id) # Wrap progress callback to track progress in active_downloads original_callback = progress_callback @@ -220,11 +290,15 @@ class DownloadManager: if task_id in self._active_downloads: self._active_downloads[task_id]["status"] = "paused" self._active_downloads[task_id]["bytes_per_second"] = 0.0 + if self._active_downloads[task_id].get("transfer_backend") == "aria2": + await self._persist_aria2_state(task_id) await pause_control.wait() # Update status to downloading if task_id in self._active_downloads: self._active_downloads[task_id]["status"] = "downloading" + if self._active_downloads[task_id].get("transfer_backend") == "aria2": + await self._persist_aria2_state(task_id) # Use original download implementation try: @@ -240,6 +314,9 @@ class DownloadManager: tracking_callback, use_default_paths, task_id, + self._active_downloads.get(task_id, {}).get( + "transfer_backend", "python" + ), source, file_params, ) @@ -256,6 +333,8 @@ class DownloadManager: "error", "Unknown error" ) self._active_downloads[task_id]["bytes_per_second"] = 0.0 + if self._active_downloads[task_id].get("transfer_backend") == "aria2": + await self._persist_aria2_state(task_id) return result except asyncio.CancelledError: @@ -263,6 +342,8 @@ class DownloadManager: if task_id in self._active_downloads: self._active_downloads[task_id]["status"] = "cancelled" self._active_downloads[task_id]["bytes_per_second"] = 0.0 + if self._active_downloads[task_id].get("transfer_backend") == "aria2": + await self._persist_aria2_state(task_id) logger.info(f"Download cancelled for task {task_id}") raise except Exception as e: @@ -274,17 +355,639 @@ class DownloadManager: self._active_downloads[task_id]["status"] = "failed" self._active_downloads[task_id]["error"] = str(e) self._active_downloads[task_id]["bytes_per_second"] = 0.0 + if self._active_downloads[task_id].get("transfer_backend") == "aria2": + await self._persist_aria2_state(task_id) return {"success": False, "error": str(e)} finally: # Schedule cleanup of download record after delay asyncio.create_task(self._cleanup_download_record(task_id)) + def _start_background_download_task(self, download_id: str, coroutine) -> asyncio.Task: + task = asyncio.create_task(coroutine) + self._download_tasks[download_id] = task + + def _cleanup_done_task(done_task: asyncio.Task) -> None: + current_task = self._download_tasks.get(download_id) + if current_task is done_task: + self._download_tasks.pop(download_id, None) + self._pause_events.pop(download_id, None) + + task.add_done_callback(_cleanup_done_task) + return task + async def _cleanup_download_record(self, task_id: str): """Keep completed downloads in history for a short time""" await asyncio.sleep(600) # Keep for 10 minutes if task_id in self._active_downloads: del self._active_downloads[task_id] + async def _delete_file_with_retries( + self, + path: Optional[str], + *, + retries: int = 5, + delay: float = 0.1, + ) -> bool: + if not path: + return False + + for attempt in range(retries): + if not os.path.exists(path): + return True + try: + os.unlink(path) + return True + except FileNotFoundError: + return True + except Exception: + if attempt == retries - 1: + return False + await asyncio.sleep(delay) + return False + + async def _cleanup_cancelled_download_files( + self, + download_id: str, + download_info: Optional[Dict], + ) -> None: + target_files = set() + persisted = await self._aria2_state_store.get(download_id) + + primary_path = None + if isinstance(download_info, dict): + primary_path = download_info.get("file_path") + if not primary_path and isinstance(persisted, dict): + primary_path = persisted.get("save_path") or persisted.get("file_path") + if primary_path: + target_files.add(primary_path) + + if isinstance(download_info, dict): + for extra_path in download_info.get("extracted_paths", []): + if extra_path: + target_files.add(extra_path) + + for file_path in target_files: + deleted = await self._delete_file_with_retries(file_path) + if deleted: + logger.debug(f"Deleted cancelled download: {file_path}") + elif os.path.exists(file_path): + logger.error(f"Error deleting file: {file_path}") + + part_path = None + if isinstance(download_info, dict): + part_path = download_info.get("part_path") + if part_path: + deleted = await self._delete_file_with_retries(part_path) + if deleted: + logger.debug(f"Deleted partial download: {part_path}") + elif os.path.exists(part_path): + logger.error(f"Error deleting part file: {part_path}") + + aria2_control_path = None + if isinstance(download_info, dict): + aria2_control_path = download_info.get("aria2_control_path") + if not aria2_control_path and primary_path: + aria2_control_path = f"{primary_path}.aria2" + if aria2_control_path: + deleted = await self._delete_file_with_retries(aria2_control_path) + if deleted: + logger.debug(f"Deleted aria2 control file: {aria2_control_path}") + elif os.path.exists(aria2_control_path): + logger.warning( + "Failed to delete aria2 control file after retries: %s", + aria2_control_path, + ) + + for file_path in target_files: + metadata_path = os.path.splitext(file_path)[0] + ".metadata.json" + deleted = await self._delete_file_with_retries(metadata_path) + if not deleted and os.path.exists(metadata_path): + logger.error(f"Error deleting metadata file: {metadata_path}") + + preview_candidates = set() + if isinstance(download_info, dict): + preview_path_value = download_info.get("preview_path") + if preview_path_value: + preview_candidates.add(preview_path_value) + + for preview_path in preview_candidates: + deleted = await self._delete_file_with_retries(preview_path) + if deleted and not os.path.exists(preview_path): + logger.debug(f"Deleted preview file: {preview_path}") + elif os.path.exists(preview_path): + logger.error(f"Error deleting preview file: {preview_path}") + + async def _persist_aria2_state( + self, + download_id: str, + *, + extra: Optional[Dict] = None, + ) -> None: + info = self._active_downloads.get(download_id) + if not info: + return + + payload = { + "download_id": download_id, + "model_id": info.get("model_id"), + "model_version_id": info.get("model_version_id"), + "save_dir": info.get("save_dir"), + "relative_path": info.get("relative_path", ""), + "use_default_paths": bool(info.get("use_default_paths", False)), + "source": info.get("source"), + "file_params": copy.deepcopy(info.get("file_params")), + "transfer_backend": info.get("transfer_backend", "aria2"), + "status": info.get("status", "queued"), + "progress": info.get("progress", 0), + "bytes_downloaded": info.get("bytes_downloaded", 0), + "total_bytes": info.get("total_bytes"), + "bytes_per_second": info.get("bytes_per_second", 0.0), + "file_path": info.get("file_path"), + } + if extra: + payload.update(extra) + + await self._aria2_state_store.upsert(download_id, payload) + + def _build_restored_download_info(self, record: Dict, save_path: str) -> Dict: + return { + "model_id": record.get("model_id"), + "model_version_id": record.get("model_version_id"), + "save_dir": record.get("save_dir"), + "relative_path": record.get("relative_path", ""), + "use_default_paths": bool(record.get("use_default_paths", False)), + "source": record.get("source"), + "file_params": copy.deepcopy(record.get("file_params")), + "progress": record.get("progress", 0), + "status": record.get("status", "paused"), + "transfer_backend": "aria2", + "bytes_downloaded": record.get("bytes_downloaded", 0), + "total_bytes": record.get("total_bytes"), + "bytes_per_second": record.get("bytes_per_second", 0.0), + "last_progress_timestamp": None, + "file_path": save_path, + "aria2_control_path": f"{save_path}.aria2", + } + + def _is_same_aria2_download_request( + self, + current_info: Optional[Dict], + persisted_record: Dict, + ) -> bool: + if not isinstance(current_info, dict): + return False + + current_version_id = current_info.get("model_version_id") + persisted_version_id = persisted_record.get("model_version_id") + if current_version_id is None or persisted_version_id is None: + return False + + return current_version_id == persisted_version_id + + def _build_download_urls_from_file_info(self, file_info: Dict, source: str = None) -> List[str]: + mirrors = file_info.get("mirrors") or [] + download_urls: List[str] = [] + if mirrors: + for mirror in mirrors: + if mirror.get("deletedAt") is None and mirror.get("url"): + download_urls.append(normalize_civitai_download_url(mirror["url"])) + + if source == "civarchive" and len(download_urls) > 1: + civitai_urls = [ + u for u in download_urls if u.startswith(CIVITAI_DOWNLOAD_URL_PREFIXES) + ] + non_civitai_urls = [ + u for u in download_urls if not u.startswith(CIVITAI_DOWNLOAD_URL_PREFIXES) + ] + download_urls = non_civitai_urls + civitai_urls + else: + download_url = file_info.get("downloadUrl") + if download_url: + download_urls.append(normalize_civitai_download_url(download_url)) + + return download_urls + + def _build_metadata_for_resume( + self, + *, + model_type: str, + version_info: Dict, + file_info: Dict, + save_path: str, + ): + if model_type == "checkpoint": + return CheckpointMetadata.from_civitai_info(version_info, file_info, save_path) + if model_type == "embedding": + return EmbeddingMetadata.from_civitai_info(version_info, file_info, save_path) + return LoraMetadata.from_civitai_info(version_info, file_info, save_path) + + def _resolve_save_path_from_persisted_record(self, record: Dict) -> Optional[str]: + save_path = record.get("save_path") or record.get("file_path") + if isinstance(save_path, str) and save_path: + return os.path.abspath(save_path) + + resume_context = record.get("resume_context") + if not isinstance(resume_context, dict): + return None + + save_dir = resume_context.get("save_dir") + file_info = resume_context.get("file_info") + if not isinstance(save_dir, str) or not save_dir: + return None + if not isinstance(file_info, dict): + return None + + file_name = file_info.get("name") + if not isinstance(file_name, str) or not file_name: + return None + + return os.path.abspath(os.path.join(save_dir, file_name)) + + async def _resume_restored_aria2_download(self, download_id: str, record: Dict) -> Dict: + try: + if download_id in self._active_downloads: + self._active_downloads[download_id]["status"] = "downloading" + self._active_downloads[download_id]["bytes_per_second"] = 0.0 + if self._active_downloads[download_id].get("transfer_backend") == "aria2": + await self._persist_aria2_state(download_id) + + resume_context = record.get("resume_context") + if not isinstance(resume_context, dict): + result = {"success": False, "error": "Missing aria2 resume context"} + else: + version_info = copy.deepcopy(resume_context.get("version_info") or {}) + file_info = copy.deepcopy(resume_context.get("file_info") or {}) + model_type = (resume_context.get("model_type") or "").lower() + relative_path = resume_context.get("relative_path", "") + save_dir = resume_context.get("save_dir") + source = record.get("source") + + if not version_info or not file_info or not model_type or not save_dir: + result = {"success": False, "error": "Incomplete aria2 resume context"} + else: + save_path = ( + record.get("save_path") + or record.get("file_path") + or os.path.join(save_dir, file_info.get("name", "")) + ) + metadata = self._build_metadata_for_resume( + model_type=model_type, + version_info=version_info, + file_info=file_info, + save_path=save_path, + ) + download_urls = resume_context.get("download_urls") + if not isinstance(download_urls, list) or not download_urls: + download_urls = self._build_download_urls_from_file_info( + file_info, source=source + ) + if not download_urls: + result = {"success": False, "error": "No mirror URL found"} + else: + result = await self._execute_download( + download_urls=download_urls, + save_dir=save_dir, + metadata=metadata, + version_info=version_info, + relative_path=relative_path, + progress_callback=None, + model_type=model_type, + download_id=download_id, + transfer_backend="aria2", + ) + + if result.get("success", False): + resolved_model_id = ( + record.get("model_id") + or version_info.get("modelId") + or (version_info.get("model") or {}).get("id") + ) + await self._record_downloaded_version_history( + model_type, + resolved_model_id, + version_info, + record.get("model_version_id"), + record.get("save_path") or record.get("file_path"), + ) + await self._sync_downloaded_version( + model_type, + resolved_model_id, + version_info, + record.get("model_version_id"), + ) + + if download_id in self._active_downloads: + self._active_downloads[download_id]["status"] = ( + result.get("status", "completed") + if result["success"] + else "failed" + ) + if not result["success"]: + self._active_downloads[download_id]["error"] = result.get( + "error", "Unknown error" + ) + self._active_downloads[download_id]["bytes_per_second"] = 0.0 + if self._active_downloads[download_id].get("transfer_backend") == "aria2": + await self._persist_aria2_state(download_id) + + return result + except asyncio.CancelledError: + if download_id in self._active_downloads: + self._active_downloads[download_id]["status"] = "cancelled" + self._active_downloads[download_id]["bytes_per_second"] = 0.0 + if self._active_downloads[download_id].get("transfer_backend") == "aria2": + await self._persist_aria2_state(download_id) + logger.info(f"Download cancelled for task {download_id}") + raise + except Exception as exc: + logger.error( + f"Download error for task {download_id}: {str(exc)}", exc_info=True + ) + if download_id in self._active_downloads: + self._active_downloads[download_id]["status"] = "failed" + self._active_downloads[download_id]["error"] = str(exc) + self._active_downloads[download_id]["bytes_per_second"] = 0.0 + if self._active_downloads[download_id].get("transfer_backend") == "aria2": + await self._persist_aria2_state(download_id) + return {"success": False, "error": str(exc)} + finally: + asyncio.create_task(self._cleanup_download_record(download_id)) + + async def _adopt_existing_aria2_download( + self, + previous_download_id: str, + new_download_id: str, + persisted_record: Dict, + save_path: str, + ) -> None: + aria2_downloader = await get_aria2_downloader() + await aria2_downloader.reassign_transfer(previous_download_id, new_download_id) + + old_task = self._download_tasks.get(previous_download_id) + if old_task is not None and not old_task.done(): + old_task.cancel() + old_pause_control = self._pause_events.get(previous_download_id) + if old_pause_control is not None: + old_pause_control.resume() + try: + await asyncio.wait_for(asyncio.shield(old_task), timeout=2.0) + except (asyncio.CancelledError, asyncio.TimeoutError): + pass + + if previous_download_id != new_download_id: + self._active_downloads.pop(previous_download_id, None) + self._pause_events.pop(previous_download_id, None) + self._download_tasks.pop(previous_download_id, None) + + reassigned = await self._aria2_state_store.reassign( + previous_download_id, new_download_id + ) + merged_record = dict(persisted_record) + if reassigned: + merged_record.update(reassigned) + + current_info = self._active_downloads.get(new_download_id) + if current_info is not None: + current_info.update( + { + "model_id": merged_record.get("model_id", current_info.get("model_id")), + "model_version_id": merged_record.get( + "model_version_id", current_info.get("model_version_id") + ), + "save_dir": merged_record.get("save_dir", current_info.get("save_dir")), + "relative_path": merged_record.get( + "relative_path", current_info.get("relative_path", "") + ), + "source": merged_record.get("source", current_info.get("source")), + "file_params": copy.deepcopy( + merged_record.get("file_params", current_info.get("file_params")) + ), + "file_path": save_path, + "aria2_control_path": f"{save_path}.aria2", + } + ) + else: + self._active_downloads[new_download_id] = self._build_restored_download_info( + merged_record, save_path + ) + + async def _restore_persisted_downloads(self) -> None: + if self._restored_persisted_downloads: + return + + async with self._restore_lock: + if self._restored_persisted_downloads: + return + + persisted = await self._aria2_state_store.load_all() + if not persisted: + self._restored_persisted_downloads = True + return + + aria2_downloader = await get_aria2_downloader() + for download_id, record in persisted.items(): + if record.get("transfer_backend") != "aria2": + continue + + save_path = self._resolve_save_path_from_persisted_record(record) + if save_path is None: + continue + + if ( + record.get("save_path") != save_path + or record.get("file_path") != save_path + ): + await self._aria2_state_store.upsert( + download_id, + { + "save_path": save_path, + "file_path": save_path, + }, + ) + control_path = f"{save_path}.aria2" + gid = record.get("gid") + status_payload = None + if isinstance(gid, str) and gid: + try: + status_payload = await aria2_downloader.get_status_by_gid(gid) + except Exception: + status_payload = None + + if status_payload is not None: + remote_status = status_payload.get("status", "") + if remote_status in {"active", "waiting", "paused"}: + await aria2_downloader.restore_transfer(download_id, gid, save_path) + restored = self._active_downloads.setdefault( + download_id, + self._build_restored_download_info(record, save_path), + ) + restored["status"] = ( + "paused" if remote_status == "paused" else "downloading" + ) + pause_control = self._pause_events.get(download_id) + if pause_control is None: + pause_control = DownloadStreamControl() + self._pause_events[download_id] = pause_control + if remote_status == "paused": + pause_control.pause() + else: + pause_control.resume() + await self._aria2_state_store.upsert( + download_id, + { + "gid": gid, + "save_path": save_path, + "file_path": save_path, + "status": restored["status"], + }, + ) + if ( + remote_status in {"active", "waiting"} + and download_id not in self._download_tasks + ): + resume_context = record.get("resume_context") + if isinstance(resume_context, dict): + self._start_background_download_task( + download_id, + self._resume_restored_aria2_download( + download_id, + dict(record), + ) + ) + else: + self._start_background_download_task( + download_id, + self._download_with_semaphore( + download_id, + restored.get("model_id"), + restored.get("model_version_id"), + restored.get("save_dir"), + restored.get("relative_path", ""), + None, + bool(restored.get("use_default_paths", False)), + restored.get("source"), + restored.get("file_params"), + ) + ) + continue + + if remote_status == "complete" and not os.path.exists(control_path): + await self._aria2_state_store.remove(download_id) + continue + + if os.path.exists(save_path) and os.path.exists(control_path): + restored = self._active_downloads.setdefault( + download_id, + self._build_restored_download_info(record, save_path), + ) + pause_control = self._pause_events.get(download_id) + if pause_control is None: + pause_control = DownloadStreamControl() + self._pause_events[download_id] = pause_control + + # No live aria2 gid was found, so restore this partial as resumable-but-paused. + pause_control.pause() + restored["status"] = "paused" + await self._aria2_state_store.upsert( + download_id, + { + "save_path": save_path, + "file_path": save_path, + "status": "paused", + }, + ) + continue + + await self._aria2_state_store.remove(download_id) + + self._restored_persisted_downloads = True + + async def _resolve_download_target_path( + self, + save_dir: str, + metadata, + *, + transfer_backend: str, + download_id: Optional[str], + ) -> Tuple[bool, str]: + original_filename = os.path.basename(metadata.file_path) + base_name, extension = os.path.splitext(original_filename) + original_path = os.path.join(save_dir, original_filename) + + if transfer_backend == "aria2": + control_path = f"{original_path}.aria2" + if os.path.exists(original_path) and os.path.exists(control_path): + persisted_record = None + if download_id: + persisted_record = await self._aria2_state_store.get(download_id) + if persisted_record: + persisted_path = ( + persisted_record.get("save_path") + or persisted_record.get("file_path") + ) + if isinstance(persisted_path, str) and os.path.abspath( + persisted_path + ) == os.path.abspath(original_path): + logger.info( + "Reusing aria2 partial target %s for %s", + original_path, + download_id, + ) + return True, original_path + + conflict_record = await self._aria2_state_store.find_by_save_path( + original_path, exclude_download_id=download_id + ) + if conflict_record is not None: + current_info = self._active_downloads.get(download_id) if download_id else None + if download_id and self._is_same_aria2_download_request( + current_info, conflict_record + ): + logger.info( + "Reassigning aria2 partial target %s from %s to %s", + original_path, + conflict_record.get("download_id"), + download_id, + ) + await self._adopt_existing_aria2_download( + conflict_record["download_id"], + download_id, + conflict_record, + original_path, + ) + return True, original_path + + return ( + False, + f"Another aria2 download is already using '{original_filename}' for resume", + ) + + if download_id: + logger.info( + "Reusing aria2 partial target %s for %s", + original_path, + download_id, + ) + return True, original_path + + def hash_provider(): + return metadata.sha256 + + unique_filename = metadata.generate_unique_filename( + save_dir, base_name, extension, hash_provider=hash_provider + ) + + if unique_filename != original_filename: + logger.info( + f"Filename conflict detected. Changing '{original_filename}' to '{unique_filename}'" + ) + save_path = os.path.join(save_dir, unique_filename) + metadata.file_path = save_path.replace(os.sep, "/") + metadata.file_name = os.path.splitext(unique_filename)[0] + return True, save_path + + return True, metadata.file_path + async def _execute_original_download( self, model_id, @@ -294,6 +997,7 @@ class DownloadManager: progress_callback, use_default_paths, download_id=None, + transfer_backend="python", source=None, file_params=None, ): @@ -696,16 +1400,44 @@ class DownloadManager: logger.info(f"Creating EmbeddingMetadata for {file_name}") # 6. Start download process - result = await self._execute_download( - download_urls=download_urls, - save_dir=save_dir, - metadata=metadata, - version_info=version_info, - relative_path=relative_path, - progress_callback=progress_callback, - model_type=model_type, - download_id=download_id, - ) + if transfer_backend == "aria2" and download_id: + await self._persist_aria2_state( + download_id, + extra={ + "save_dir": save_dir, + "relative_path": relative_path, + "resume_context": { + "version_info": copy.deepcopy(version_info), + "file_info": copy.deepcopy(file_info), + "model_type": model_type, + "relative_path": relative_path, + "save_dir": save_dir, + "download_urls": copy.deepcopy(download_urls), + }, + }, + ) + + execute_kwargs = { + "download_urls": download_urls, + "save_dir": save_dir, + "metadata": metadata, + "version_info": version_info, + "relative_path": relative_path, + "progress_callback": progress_callback, + "model_type": model_type, + "download_id": download_id, + } + execute_signature = inspect.signature(self._execute_download) + if ( + "transfer_backend" in execute_signature.parameters + or any( + parameter.kind == inspect.Parameter.VAR_KEYWORD + for parameter in execute_signature.parameters.values() + ) + ): + execute_kwargs["transfer_backend"] = transfer_backend + + result = await self._execute_download(**execute_kwargs) if result.get("success", False): resolved_model_id = ( @@ -965,6 +1697,7 @@ class DownloadManager: progress_callback=None, model_type: str = "lora", download_id: str = None, + transfer_backend: Optional[str] = None, ) -> Dict: """Execute the actual download process including preview images and model files""" metadata_entries: List = [] @@ -974,31 +1707,16 @@ class DownloadManager: preview_targets: List[str] = [] preview_path: str | None = None preview_nsfw_level = 0 + transfer_backend = (transfer_backend or self._get_model_download_backend()).lower() try: - # Extract original filename details - original_filename = os.path.basename(metadata.file_path) - base_name, extension = os.path.splitext(original_filename) - - # Check for filename conflicts and generate unique filename if needed - # Use the hash from metadata for conflict resolution - def hash_provider(): - return metadata.sha256 - - unique_filename = metadata.generate_unique_filename( - save_dir, base_name, extension, hash_provider=hash_provider + resolved, save_path = await self._resolve_download_target_path( + save_dir, + metadata, + transfer_backend=transfer_backend, + download_id=download_id, ) - - # Update paths if filename changed - if unique_filename != original_filename: - logger.info( - f"Filename conflict detected. Changing '{original_filename}' to '{unique_filename}'" - ) - save_path = os.path.join(save_dir, unique_filename) - # Update metadata with new file path and name - metadata.file_path = save_path.replace(os.sep, "/") - metadata.file_name = os.path.splitext(unique_filename)[0] - else: - save_path = metadata.file_path + if not resolved: + return {"success": False, "error": save_path} part_path = save_path + ".part" metadata_path = os.path.splitext(save_path)[0] + ".metadata.json" @@ -1008,7 +1726,12 @@ class DownloadManager: # Store file paths in active_downloads for potential cleanup if download_id and download_id in self._active_downloads: self._active_downloads[download_id]["file_path"] = save_path - self._active_downloads[download_id]["part_path"] = part_path + if transfer_backend == "python": + self._active_downloads[download_id]["part_path"] = part_path + if transfer_backend == "aria2": + self._active_downloads[download_id]["aria2_control_path"] = ( + f"{save_path}.aria2" + ) # Download preview image if available images = version_info.get("images", []) @@ -1132,36 +1855,55 @@ class DownloadManager: preview_nsfw_level = nsfw_level metadata.preview_url = preview_path.replace(os.sep, "/") metadata.preview_nsfw_level = nsfw_level + if download_id and download_id in self._active_downloads: + self._active_downloads[download_id]["preview_path"] = preview_path if progress_callback: await progress_callback(3) # 3% progress after preview download - # Download model file with progress tracking using downloader - downloader = await get_downloader() - if pause_control is not None: - pause_control.update_stall_timeout(downloader.stall_timeout) + # Download model file with progress tracking using the configured backend + downloader = None + if transfer_backend == "python": + downloader = await get_downloader() + if pause_control is not None: + pause_control.update_stall_timeout(downloader.stall_timeout) + if pause_control is not None and pause_control.is_paused(): + if download_id and download_id in self._active_downloads: + self._active_downloads[download_id]["status"] = "paused" + self._active_downloads[download_id]["bytes_per_second"] = 0.0 + await pause_control.wait() + if download_id and download_id in self._active_downloads: + self._active_downloads[download_id]["status"] = "downloading" last_error = None for download_url in download_urls: download_url = normalize_civitai_download_url(download_url) use_auth = download_url.startswith(CIVITAI_DOWNLOAD_URL_PREFIXES) - download_kwargs = { - "progress_callback": lambda progress, snapshot=None: ( + if transfer_backend == "aria2" and download_id: + await self._persist_aria2_state( + download_id, + extra={ + "status": self._active_downloads.get(download_id, {}).get( + "status", "downloading" + ), + "save_path": save_path, + "file_path": save_path, + "url": download_url, + }, + ) + success, result = await self._download_model_file( + download_url, + save_path, + backend=transfer_backend, + progress_callback=lambda progress, snapshot=None: ( self._handle_download_progress( progress, progress_callback, snapshot, ) ), - "use_auth": use_auth, # Only use authentication for Civitai downloads - } - - if pause_control is not None: - download_kwargs["pause_event"] = pause_control - - success, result = await downloader.download_file( - download_url, - save_path, # Use full path instead of separate dir and filename - **download_kwargs, + use_auth=use_auth, + download_id=download_id, + pause_control=pause_control, ) if success: @@ -1189,9 +1931,20 @@ class DownloadManager: except Exception as e: logger.warning(f"Failed to cleanup file {path}: {e}") - # Log but don't remove .part file to allow resume - if os.path.exists(part_path): + # Keep resumable partial state for the matching backend. + if transfer_backend == "python" and os.path.exists(part_path): logger.info(f"Preserving partial download for resume: {part_path}") + elif transfer_backend == "aria2" and os.path.exists(f"{save_path}.aria2"): + logger.info("Preserving aria2 partial download for resume: %s", save_path) + if download_id: + await self._persist_aria2_state( + download_id, + extra={ + "status": "failed", + "save_path": save_path, + "file_path": save_path, + }, + ) return { "success": False, @@ -1306,6 +2059,9 @@ class DownloadManager: if scanner is not None: await scanner.add_model_to_cache(metadata_dict, relative_path) + if transfer_backend == "aria2" and download_id: + await self._aria2_state_store.remove(download_id) + # Report 100% completion if progress_callback: await progress_callback(100) @@ -1401,7 +2157,8 @@ class DownloadManager: extracted_files.append(dest_path) return extracted_files - return await asyncio.to_thread(_extract_sync) + loop = asyncio.get_running_loop() + return await loop.run_in_executor(self._archive_executor, _extract_sync) async def _build_metadata_entries( self, base_metadata, file_paths: List[str] @@ -1507,13 +2264,49 @@ class DownloadManager: Returns: Dict: Status of the cancellation operation """ - if download_id not in self._download_tasks: + await self._restore_persisted_downloads() + + if download_id not in self._download_tasks and download_id not in self._active_downloads: return {"success": False, "error": "Download task not found"} + download_info = self._active_downloads.get(download_id) + task = self._download_tasks.get(download_id) + active_statuses = {"queued", "waiting", "downloading", "paused", "cancelling"} + if task is None and ( + not isinstance(download_info, dict) + or download_info.get("status") not in active_statuses + ): + return {"success": False, "error": "Download task not found"} + + should_cleanup_local_tracking = False try: - # Get the task and cancel it - task = self._download_tasks[download_id] - task.cancel() + backend = ( + self._active_downloads.get(download_id, {}).get("transfer_backend") + or "python" + ) + + if backend == "aria2": + try: + aria2_downloader = await get_aria2_downloader() + cancel_result = await aria2_downloader.cancel_download(download_id) + if ( + not cancel_result.get("success") + and cancel_result.get("error") != "Download task not found" + ): + return cancel_result + should_cleanup_local_tracking = True + except Exception as exc: + logger.warning( + "Failed to cancel aria2 transfer for %s, continuing with local task cancellation: %s", + download_id, + exc, + ) + should_cleanup_local_tracking = True + else: + should_cleanup_local_tracking = True + + if task is not None: + task.cancel() pause_control = self._pause_events.get(download_id) if pause_control is not None: @@ -1525,83 +2318,31 @@ class DownloadManager: self._active_downloads[download_id]["bytes_per_second"] = 0.0 # Wait briefly for the task to acknowledge cancellation - try: - await asyncio.wait_for(asyncio.shield(task), timeout=2.0) - except (asyncio.CancelledError, asyncio.TimeoutError): - pass + if task is not None: + try: + await asyncio.wait_for(asyncio.shield(task), timeout=2.0) + except (asyncio.CancelledError, asyncio.TimeoutError): + pass # Clean up ALL files including .part when user cancels download_info = self._active_downloads.get(download_id) - if download_info: - target_files = set() - primary_path = download_info.get("file_path") - if primary_path: - target_files.add(primary_path) - - for extra_path in download_info.get("extracted_paths", []): - if extra_path: - target_files.add(extra_path) - - for file_path in target_files: - if os.path.exists(file_path): - try: - os.unlink(file_path) - logger.debug(f"Deleted cancelled download: {file_path}") - except Exception as e: - logger.error(f"Error deleting file: {e}") - - # Delete the .part file (only on user cancellation) - if "part_path" in download_info: - part_path = download_info["part_path"] - if os.path.exists(part_path): - try: - os.unlink(part_path) - logger.debug(f"Deleted partial download: {part_path}") - except Exception as e: - logger.error(f"Error deleting part file: {e}") - - # Delete metadata files for each resolved path - for file_path in target_files: - metadata_path = os.path.splitext(file_path)[0] + ".metadata.json" - if os.path.exists(metadata_path): - try: - os.unlink(metadata_path) - except Exception as e: - logger.error(f"Error deleting metadata file: {e}") - - preview_path_value = download_info.get("preview_path") - if preview_path_value and os.path.exists(preview_path_value): - try: - os.unlink(preview_path_value) - logger.debug(f"Deleted preview file: {preview_path_value}") - except Exception as e: - logger.error( - f"Error deleting preview file: {preview_path_value}" - ) - - # Delete preview file if exists (.webp or .mp4) for legacy paths - for file_path in target_files: - for preview_ext in [".webp", ".mp4"]: - preview_path = os.path.splitext(file_path)[0] + preview_ext - if os.path.exists(preview_path): - try: - os.unlink(preview_path) - logger.debug(f"Deleted preview file: {preview_path}") - except Exception as e: - logger.error( - f"Error deleting preview file: {preview_path}" - ) + await self._cleanup_cancelled_download_files(download_id, download_info) return {"success": True, "message": "Download cancelled successfully"} except Exception as e: logger.error(f"Error cancelling download: {e}", exc_info=True) return {"success": False, "error": str(e)} finally: - self._pause_events.pop(download_id, None) + if should_cleanup_local_tracking: + self._pause_events.pop(download_id, None) + self._download_tasks.pop(download_id, None) + await self._aria2_state_store.remove(download_id) async def pause_download(self, download_id: str) -> Dict: """Pause an active download without losing progress.""" - if download_id not in self._download_tasks: + await self._restore_persisted_downloads() + + if download_id not in self._download_tasks and download_id not in self._active_downloads: return {"success": False, "error": "Download task not found"} pause_control = self._pause_events.get(download_id) @@ -1613,6 +2354,29 @@ class DownloadManager: pause_control.pause() + backend = ( + self._active_downloads.get(download_id, {}).get("transfer_backend") + or "python" + ) + if backend == "aria2": + try: + aria2_downloader = await get_aria2_downloader() + if await aria2_downloader.has_transfer(download_id): + result = await aria2_downloader.pause_download(download_id) + if not result.get("success"): + pause_control.resume() + return result + except Exception as exc: + pause_control.resume() + return {"success": False, "error": str(exc)} + + download_info = self._active_downloads.get(download_id) + if download_info is not None: + download_info["status"] = "paused" + download_info["bytes_per_second"] = 0.0 + await self._persist_aria2_state(download_id) + return {"success": True, "message": "Download paused successfully"} + download_info = self._active_downloads.get(download_id) if download_info is not None: download_info["status"] = "paused" @@ -1623,14 +2387,78 @@ class DownloadManager: async def resume_download(self, download_id: str) -> Dict: """Resume a previously paused download.""" + await self._restore_persisted_downloads() + pause_control = self._pause_events.get(download_id) if pause_control is None: - return {"success": False, "error": "Download task not found"} + persisted = await self._aria2_state_store.get(download_id) + if not persisted or persisted.get("transfer_backend") != "aria2": + return {"success": False, "error": "Download task not found"} + + save_path = persisted.get("save_path") or persisted.get("file_path") + pause_control = DownloadStreamControl() + pause_control.pause() + self._pause_events[download_id] = pause_control + self._active_downloads[download_id] = self._build_restored_download_info( + persisted, + os.path.abspath(save_path), + ) if pause_control.is_set(): return {"success": False, "error": "Download is not paused"} download_info = self._active_downloads.get(download_id) + backend = ( + self._active_downloads.get(download_id, {}).get("transfer_backend") + or "python" + ) + if backend == "aria2": + try: + persisted = None + if download_id not in self._download_tasks: + persisted = await self._aria2_state_store.get(download_id) + aria2_downloader = await get_aria2_downloader() + if await aria2_downloader.has_transfer(download_id): + result = await aria2_downloader.resume_download(download_id) + if not result.get("success"): + return result + if download_id not in self._download_tasks and persisted: + resume_context = persisted.get("resume_context") + if isinstance(resume_context, dict): + self._start_background_download_task( + download_id, + self._resume_restored_aria2_download( + download_id, + dict(persisted), + ), + ) + else: + self._start_background_download_task( + download_id, + self._download_with_semaphore( + download_id, + persisted.get("model_id"), + persisted.get("model_version_id"), + persisted.get("save_dir"), + persisted.get("relative_path", ""), + None, + bool(persisted.get("use_default_paths", False)), + persisted.get("source"), + persisted.get("file_params"), + ), + ) + except Exception as exc: + return {"success": False, "error": str(exc)} + + pause_control.resume() + + if download_info is not None: + if download_info.get("status") == "paused": + download_info["status"] = "downloading" + download_info.setdefault("bytes_per_second", 0.0) + await self._persist_aria2_state(download_id) + return {"success": True, "message": "Download resumed successfully"} + force_reconnect = False if pause_control is not None: elapsed = pause_control.time_since_last_progress() @@ -1706,6 +2534,7 @@ class DownloadManager: Returns: Dict: List of active downloads and their status """ + await self._restore_persisted_downloads() return { "downloads": [ { diff --git a/py/services/downloader.py b/py/services/downloader.py index 71360538..cfac9bf9 100644 --- a/py/services/downloader.py +++ b/py/services/downloader.py @@ -18,6 +18,7 @@ from collections import deque from dataclasses import dataclass from datetime import datetime, timedelta from email.utils import parsedate_to_datetime +from urllib.parse import urlparse from typing import Optional, Dict, Tuple, Callable, Union, Awaitable from ..services.settings_manager import get_settings_manager from .connectivity_guard import ( @@ -828,7 +829,7 @@ class Downloader: ) as response: if response.status == 200: content = await response.read() - guard.register_success() + guard.register_success(destination) if return_headers: return True, content, dict(response.headers) else: @@ -874,7 +875,8 @@ class Downloader: Tuple[bool, Union[Dict, str]]: (success, headers dict or error message) """ guard = await ConnectivityGuard.get_instance() - if guard.should_block_request(): + destination = self._guard_destination(url) + if guard.should_block_request(destination): return False, OFFLINE_COOLDOWN_ERROR try: @@ -898,15 +900,15 @@ class Downloader: url, headers=headers, proxy=self.proxy_url ) as response: if response.status == 200: - guard.register_success() + guard.register_success(destination) return True, dict(response.headers) else: return False, f"Head request failed with status {response.status}" except Exception as e: if guard.is_network_unreachable_error(e): - guard.register_network_failure(e) - if guard.should_block_request(): + guard.register_network_failure(e, destination) + if guard.should_block_request(destination): return False, OFFLINE_COOLDOWN_ERROR logger.debug("Network unavailable during header probe: %s", e) return False, str(e) @@ -935,7 +937,8 @@ class Downloader: Tuple[bool, Union[Dict, str]]: (success, response data or error message) """ guard = await ConnectivityGuard.get_instance() - if guard.should_block_request(): + destination = self._guard_destination(url) + if guard.should_block_request(destination): return False, OFFLINE_COOLDOWN_ERROR try: @@ -961,7 +964,7 @@ class Downloader: method, url, headers=headers, **kwargs ) as response: if response.status == 200: - guard.register_success() + guard.register_success(destination) # Try to parse as JSON, fall back to text try: data = await response.json() @@ -993,8 +996,8 @@ class Downloader: except Exception as e: if guard.is_network_unreachable_error(e): - guard.register_network_failure(e) - if guard.should_block_request(): + guard.register_network_failure(e, destination) + if guard.should_block_request(destination): return False, OFFLINE_COOLDOWN_ERROR logger.debug("Network unavailable for %s %s: %s", method, url, e) return False, str(e) @@ -1048,6 +1051,14 @@ class Downloader: delta = retry_datetime - datetime.now(tz=retry_datetime.tzinfo) return max(0.0, delta.total_seconds()) + @staticmethod + def _guard_destination(url: str) -> str: + """Build per-destination connectivity guard scope from request URL.""" + parsed_url = urlparse(url) + if parsed_url.hostname: + return parsed_url.hostname.lower() + return "unknown" + # Global instance accessor async def get_downloader() -> Downloader: diff --git a/py/services/settings_manager.py b/py/services/settings_manager.py index 8b6cfebc..2c1fc727 100644 --- a/py/services/settings_manager.py +++ b/py/services/settings_manager.py @@ -55,6 +55,8 @@ DEFAULT_KEYS_CLEANUP_THRESHOLD = 10 DEFAULT_SETTINGS: Dict[str, Any] = { "civitai_api_key": "", "civitai_host": "civitai.com", + "download_backend": "python", + "aria2c_path": "", "use_portable_settings": False, "hash_chunk_size_mb": DEFAULT_HASH_CHUNK_SIZE_MB, "language": "en", @@ -761,34 +763,29 @@ class SettingsManager: if self._preserve_disk_template: return - folder_paths = self.settings.get("folder_paths", {}) updated = False def _check_and_auto_set(key: str, setting_key: str) -> bool: """Repair default roots when empty or no longer present.""" current = self.settings.get(setting_key, "") - candidates = folder_paths.get(key, []) - if not isinstance(candidates, list) or not candidates: + primary_candidates = self._get_valid_root_candidates(key) + if not primary_candidates: return False - # Filter valid path strings - valid_paths = [p for p in candidates if isinstance(p, str) and p.strip()] - if not valid_paths: + allowed_roots = self._get_allowed_roots(key) + if current and current in allowed_roots: return False - if current in valid_paths: - return False - - self.settings[setting_key] = valid_paths[0] + self.settings[setting_key] = primary_candidates[0] if current: logger.info( - "Repaired stale %s from '%s' to '%s'", + "Repaired stale %s from '%s' to '%s' because it is not present in primary or extra roots", setting_key, current, - valid_paths[0], + primary_candidates[0], ) else: - logger.info("Auto-set %s to '%s'", setting_key, valid_paths[0]) + logger.info("Auto-set %s to '%s'", setting_key, primary_candidates[0]) return True # Process all model types @@ -811,6 +808,33 @@ class SettingsManager: else: self._save_settings() + def _get_valid_root_candidates(self, key: str) -> List[str]: + """Return stable root candidates, preferring primary roots over extra roots.""" + + candidates: List[str] = [] + seen: set[str] = set() + for mapping_key in ("folder_paths", "extra_folder_paths"): + raw_paths = self.settings.get(mapping_key, {}) + if not isinstance(raw_paths, Mapping): + continue + values = raw_paths.get(key, []) + if not isinstance(values, list): + continue + for value in values: + if not isinstance(value, str): + continue + normalized = value.strip() + if not normalized or normalized in seen: + continue + seen.add(normalized) + candidates.append(normalized) + return candidates + + def _get_allowed_roots(self, key: str) -> set[str]: + """Return all valid roots for a model type, including extra roots.""" + + return set(self._get_valid_root_candidates(key)) + def _check_environment_variables(self) -> None: """Check for environment variables and update settings if needed""" env_api_key = os.environ.get("CIVITAI_API_KEY") diff --git a/static/css/components/modal/settings-modal.css b/static/css/components/modal/settings-modal.css index 7a1f4902..a0f960b0 100644 --- a/static/css/components/modal/settings-modal.css +++ b/static/css/components/modal/settings-modal.css @@ -346,11 +346,13 @@ .api-key-input input { width: 100%; padding: 6px 40px 6px 10px; /* Add left padding */ - height: 20px; + height: 32px; + box-sizing: border-box; border-radius: var(--border-radius-xs); border: 1px solid var(--border-color); background-color: var(--lora-surface); color: var(--text-color); + font-size: 0.95em; } .api-key-input .toggle-visibility { @@ -379,7 +381,8 @@ .text-input-wrapper input { width: 100%; padding: 6px 10px; - height: 20px; + height: 32px; + box-sizing: border-box; border-radius: var(--border-radius-xs); border: 1px solid var(--border-color); background-color: var(--lora-surface); @@ -760,10 +763,12 @@ } .setting-control { - width: 60%; /* Decreased slightly from 65% */ + flex: 0 0 60%; + max-width: 60%; margin-bottom: 0; display: flex; justify-content: flex-end; /* Right-align all controls */ + min-width: 0; } /* Select Control Styles */ @@ -773,6 +778,13 @@ justify-content: flex-end; } +.setting-control select, +.setting-control input[type="text"], +.setting-control input[type="password"], +.setting-control input[type="number"] { + font-size: 0.95em; +} + .select-control select { width: 100%; max-width: 100%; /* Increased from 200px */ @@ -781,8 +793,8 @@ border: 1px solid var(--border-color); background-color: var(--lora-surface); color: var(--text-color); - font-size: 0.95em; height: 32px; + box-sizing: border-box; } /* Fix dark theme select dropdown text color */ @@ -888,8 +900,8 @@ input:checked + .toggle-slider:before { border: 1px solid var(--border-color); background-color: var(--lora-surface); color: var(--text-color); - font-size: 0.95em; height: 32px; + box-sizing: border-box; } /* Add warning text style for settings */ diff --git a/static/js/managers/SettingsManager.js b/static/js/managers/SettingsManager.js index 1dcc5d2f..ab2906b2 100644 --- a/static/js/managers/SettingsManager.js +++ b/static/js/managers/SettingsManager.js @@ -807,6 +807,16 @@ export class SettingsManager { civitaiHostSelect.value = state.global.settings.civitai_host || 'civitai.com'; } + const downloadBackendSelect = document.getElementById('downloadBackend'); + if (downloadBackendSelect) { + downloadBackendSelect.value = state.global.settings.download_backend || 'python'; + } + + const aria2cPathInput = document.getElementById('aria2cPath'); + if (aria2cPathInput) { + aria2cPathInput.value = state.global.settings.aria2c_path || ''; + } + const recipesPathInput = document.getElementById('recipesPath'); if (recipesPathInput) { recipesPathInput.value = state.global.settings.recipes_path || ''; @@ -950,9 +960,36 @@ export class SettingsManager { languageSelect.value = currentLanguage; } + this.loadDownloadBackendSettings(); this.loadProxySettings(); } + loadDownloadBackendSettings() { + const downloadBackendSelect = document.getElementById('downloadBackend'); + const aria2PathSetting = document.getElementById('aria2PathSetting'); + const updateVisibility = () => { + if (!aria2PathSetting || !downloadBackendSelect) { + return; + } + aria2PathSetting.style.display = downloadBackendSelect.value === 'aria2' ? 'block' : 'none'; + }; + + if (downloadBackendSelect) { + downloadBackendSelect.value = state.global.settings.download_backend || 'python'; + downloadBackendSelect.onchange = () => { + updateVisibility(); + this.saveSelectSetting('downloadBackend', 'download_backend'); + }; + } + + const aria2cPathInput = document.getElementById('aria2cPath'); + if (aria2cPathInput) { + aria2cPathInput.value = state.global.settings.aria2c_path || ''; + } + + updateVisibility(); + } + setupPriorityTagInputs() { ['lora', 'checkpoint', 'embedding'].forEach((modelType) => { const textarea = document.getElementById(`${modelType}PriorityTagsInput`); diff --git a/static/js/state/index.js b/static/js/state/index.js index 4a55ab26..5206c3d4 100644 --- a/static/js/state/index.js +++ b/static/js/state/index.js @@ -6,6 +6,8 @@ import { DEFAULT_PATH_TEMPLATES, DEFAULT_PRIORITY_TAG_CONFIG } from '../utils/co const DEFAULT_SETTINGS_BASE = Object.freeze({ civitai_api_key: '', civitai_host: 'civitai.com', + download_backend: 'python', + aria2c_path: '', use_portable_settings: false, language: 'en', show_only_sfw: false, diff --git a/templates/components/modals/settings_modal.html b/templates/components/modals/settings_modal.html index c85a7cfc..607241e0 100644 --- a/templates/components/modals/settings_modal.html +++ b/templates/components/modals/settings_modal.html @@ -129,6 +129,43 @@ +
+
+

{{ t('settings.sections.downloads') }}

+
+
+
+
+ + +
+
+ +
+
+
+ +
+
diff --git a/tests/config/test_config_save_paths.py b/tests/config/test_config_save_paths.py index 8fca10ed..0d2d13b8 100644 --- a/tests/config/test_config_save_paths.py +++ b/tests/config/test_config_save_paths.py @@ -46,6 +46,7 @@ def test_save_paths_renames_default_library(monkeypatch: pytest.MonkeyPatch, tmp self.delete_calls = [] self.upsert_calls = [] self._renamed = False + self.active_library = "default" def get_libraries(self): if self._renamed: @@ -62,6 +63,11 @@ def test_save_paths_renames_default_library(monkeypatch: pytest.MonkeyPatch, tmp def rename_library(self, old_name: str, new_name: str): self.rename_calls.append((old_name, new_name)) self._renamed = True + if self.active_library == old_name: + self.active_library = new_name + + def get_active_library_name(self): + return self.active_library def delete_library(self, name: str): # pragma: no cover - defensive guard self.delete_calls.append(name) @@ -104,6 +110,7 @@ def test_save_paths_logs_warning_when_upsert_fails( class RaisingSettingsService: def __init__(self): self.upsert_attempts = [] + self.active_library = "comfyui" def get_libraries(self): return { @@ -116,6 +123,9 @@ def test_save_paths_logs_warning_when_upsert_fails( def rename_library(self, *_): raise AssertionError("rename_library should not be invoked") + def get_active_library_name(self): + return self.active_library + def upsert_library(self, name: str, **payload): self.upsert_attempts.append((name, payload)) raise RuntimeError("boom") @@ -135,6 +145,8 @@ def test_save_paths_repairs_empty_default_roots(monkeypatch: pytest.MonkeyPatch, folder_paths = _setup_config_environment(monkeypatch, tmp_path) class FakeSettingsService: + active_library = "comfyui" + def get_libraries(self): return { "comfyui": { @@ -148,6 +160,9 @@ def test_save_paths_repairs_empty_default_roots(monkeypatch: pytest.MonkeyPatch, def rename_library(self, *_): raise AssertionError("rename_library should not be invoked") + def get_active_library_name(self): + return self.active_library + def upsert_library(self, name: str, **payload): self.name = name self.payload = payload @@ -167,6 +182,8 @@ def test_save_paths_repairs_stale_default_roots(monkeypatch: pytest.MonkeyPatch, folder_paths = _setup_config_environment(monkeypatch, tmp_path) class FakeSettingsService: + active_library = "comfyui" + def get_libraries(self): return { "comfyui": { @@ -180,6 +197,9 @@ def test_save_paths_repairs_stale_default_roots(monkeypatch: pytest.MonkeyPatch, def rename_library(self, *_): raise AssertionError("rename_library should not be invoked") + def get_active_library_name(self): + return self.active_library + def upsert_library(self, name: str, **payload): self.name = name self.payload = payload @@ -199,6 +219,8 @@ def test_save_paths_keeps_valid_default_roots(monkeypatch: pytest.MonkeyPatch, t folder_paths = _setup_config_environment(monkeypatch, tmp_path) class FakeSettingsService: + active_library = "comfyui" + def get_libraries(self): return { "comfyui": { @@ -212,6 +234,9 @@ def test_save_paths_keeps_valid_default_roots(monkeypatch: pytest.MonkeyPatch, t def rename_library(self, *_): raise AssertionError("rename_library should not be invoked") + def get_active_library_name(self): + return self.active_library + def upsert_library(self, name: str, **payload): self.name = name self.payload = payload @@ -258,6 +283,7 @@ def test_save_paths_removes_template_default_library(monkeypatch, tmp_path): self.rename_calls = [] self.delete_calls = [] self.upsert_calls = [] + self.active_library = "default" def get_libraries(self): return self.libraries @@ -265,6 +291,8 @@ def test_save_paths_removes_template_default_library(monkeypatch, tmp_path): def rename_library(self, old_name: str, new_name: str): self.rename_calls.append((old_name, new_name)) self.libraries[new_name] = self.libraries.pop(old_name) + if self.active_library == old_name: + self.active_library = new_name def delete_library(self, name: str): self.delete_calls.append(name) @@ -273,6 +301,11 @@ def test_save_paths_removes_template_default_library(monkeypatch, tmp_path): def upsert_library(self, name: str, **payload): self.upsert_calls.append((name, payload)) self.libraries[name] = {**payload} + if payload.get("activate"): + self.active_library = name + + def get_active_library_name(self): + return self.active_library fake_settings = FakeSettingsService() monkeypatch.setattr(settings_manager_module, "settings", fake_settings) @@ -313,6 +346,156 @@ def test_save_paths_removes_template_default_library(monkeypatch, tmp_path): assert payload["activate"] is True +def test_save_paths_keeps_default_roots_in_extra_paths(monkeypatch: pytest.MonkeyPatch, tmp_path): + folder_paths = _setup_config_environment(monkeypatch, tmp_path) + extra_lora_dir = tmp_path / "extra_loras" + extra_checkpoint_dir = tmp_path / "extra_checkpoints" + extra_embedding_dir = tmp_path / "extra_embeddings" + + for directory in (extra_lora_dir, extra_checkpoint_dir, extra_embedding_dir): + directory.mkdir() + + class FakeSettingsService: + active_library = "comfyui" + + def get_libraries(self): + return { + "comfyui": { + "folder_paths": {key: list(value) for key, value in folder_paths.items()}, + "extra_folder_paths": { + "loras": [str(extra_lora_dir)], + "checkpoints": [str(extra_checkpoint_dir)], + "embeddings": [str(extra_embedding_dir)], + }, + "default_lora_root": str(extra_lora_dir), + "default_checkpoint_root": str(extra_checkpoint_dir), + "default_embedding_root": str(extra_embedding_dir), + } + } + + def rename_library(self, *_): + raise AssertionError("rename_library should not be invoked") + + def get_active_library_name(self): + return self.active_library + + def upsert_library(self, name: str, **payload): + self.name = name + self.payload = payload + + fake_settings = FakeSettingsService() + monkeypatch.setattr(settings_manager_module, "settings", fake_settings) + + config_module.Config() + + assert fake_settings.name == "comfyui" + assert fake_settings.payload["extra_folder_paths"]["loras"] == [str(extra_lora_dir).replace("\\", "/")] + assert fake_settings.payload["extra_folder_paths"]["checkpoints"] == [ + str(extra_checkpoint_dir).replace("\\", "/") + ] + assert fake_settings.payload["extra_folder_paths"]["embeddings"] == [ + str(extra_embedding_dir).replace("\\", "/") + ] + assert fake_settings.payload["default_lora_root"] == str(extra_lora_dir).replace("\\", "/") + assert fake_settings.payload["default_checkpoint_root"] == str(extra_checkpoint_dir).replace("\\", "/") + assert fake_settings.payload["default_embedding_root"] == str(extra_embedding_dir).replace("\\", "/") + assert fake_settings.payload["activate"] is True + + +def test_save_paths_repairs_empty_default_roots_to_extra_paths_when_primary_missing( + monkeypatch: pytest.MonkeyPatch, tmp_path +): + _setup_config_environment(monkeypatch, tmp_path) + extra_lora_dir = tmp_path / "extra_loras" + extra_lora_dir.mkdir() + + monkeypatch.setattr( + config_module.folder_paths, + "get_folder_paths", + lambda kind: [] if kind == "loras" else [], + ) + + class FakeSettingsService: + active_library = "comfyui" + + def get_libraries(self): + return { + "comfyui": { + "folder_paths": { + "loras": [], + "checkpoints": [], + "unet": [], + "embeddings": [], + }, + "extra_folder_paths": { + "loras": [str(extra_lora_dir)], + }, + "default_lora_root": "", + } + } + + def rename_library(self, *_): + raise AssertionError("rename_library should not be invoked") + + def get_active_library_name(self): + return self.active_library + + def upsert_library(self, name: str, **payload): + self.name = name + self.payload = payload + + fake_settings = FakeSettingsService() + monkeypatch.setattr(settings_manager_module, "settings", fake_settings) + + config_module.Config() + + assert fake_settings.name == "comfyui" + assert fake_settings.payload["default_lora_root"] == str(extra_lora_dir).replace("\\", "/") + + +def test_save_paths_does_not_activate_comfyui_library_when_another_library_is_active( + monkeypatch: pytest.MonkeyPatch, tmp_path +): + folder_paths = _setup_config_environment(monkeypatch, tmp_path) + + class FakeSettingsService: + def __init__(self): + self.active_library = "studio" + self.upsert_calls = [] + + def get_libraries(self): + return { + "studio": { + "folder_paths": {"loras": ["/studio/loras"]}, + }, + "comfyui": { + "folder_paths": {key: list(value) for key, value in folder_paths.items()}, + "default_lora_root": folder_paths["loras"][0], + "default_checkpoint_root": folder_paths["checkpoints"][0], + "default_embedding_root": folder_paths["embeddings"][0], + }, + } + + def rename_library(self, *_): + raise AssertionError("rename_library should not be invoked") + + def get_active_library_name(self): + return self.active_library + + def upsert_library(self, name: str, **payload): + self.upsert_calls.append((name, payload)) + + fake_settings = FakeSettingsService() + monkeypatch.setattr(settings_manager_module, "settings", fake_settings) + + config_module.Config() + + assert len(fake_settings.upsert_calls) == 1 + name, payload = fake_settings.upsert_calls[0] + assert name == "comfyui" + assert payload["activate"] is False + + def test_apply_library_settings_merges_extra_paths(monkeypatch, tmp_path): """Test that apply_library_settings correctly merges folder_paths with extra_folder_paths.""" loras_dir = tmp_path / "loras" diff --git a/tests/frontend/managers/settingsManager.library.test.js b/tests/frontend/managers/settingsManager.library.test.js index 18cd2c24..2b8eb9a0 100644 --- a/tests/frontend/managers/settingsManager.library.test.js +++ b/tests/frontend/managers/settingsManager.library.test.js @@ -305,4 +305,39 @@ describe('SettingsManager library controls', () => { 'success', ); }); + + it('loads download backend settings and toggles the aria2 path field', () => { + const manager = createManager(); + document.body.innerHTML = ` + + + + `; + + state.global.settings = { + download_backend: 'aria2', + aria2c_path: '/usr/bin/aria2c', + }; + + const saveSpy = vi.spyOn(manager, 'saveSelectSetting').mockResolvedValue(); + + manager.loadDownloadBackendSettings(); + + const backendSelect = document.getElementById('downloadBackend'); + const aria2PathSetting = document.getElementById('aria2PathSetting'); + const aria2cPath = document.getElementById('aria2cPath'); + + expect(backendSelect.value).toBe('aria2'); + expect(aria2cPath.value).toBe('/usr/bin/aria2c'); + expect(aria2PathSetting.style.display).toBe('block'); + + backendSelect.value = 'python'; + backendSelect.onchange(); + + expect(aria2PathSetting.style.display).toBe('none'); + expect(saveSpy).toHaveBeenCalledWith('downloadBackend', 'download_backend'); + }); }); diff --git a/tests/services/test_aria2_downloader.py b/tests/services/test_aria2_downloader.py new file mode 100644 index 00000000..268aa91e --- /dev/null +++ b/tests/services/test_aria2_downloader.py @@ -0,0 +1,354 @@ +from __future__ import annotations + +import asyncio +from pathlib import Path +from unittest.mock import AsyncMock + +import pytest + +from py.services.aria2_downloader import Aria2Downloader, Aria2Error +from py.services.aria2_transfer_state import Aria2TransferStateStore +from py.services import aria2_transfer_state + + +@pytest.fixture(autouse=True) +def isolate_aria2_state(monkeypatch, tmp_path): + state_path = tmp_path / "cache" / "aria2" / "downloads.json" + monkeypatch.setattr( + aria2_transfer_state, + "get_aria2_state_path", + lambda: str(state_path), + ) + + +@pytest.mark.asyncio +async def test_download_file_polls_until_complete(tmp_path, monkeypatch): + downloader = Aria2Downloader() + downloader._rpc_url = "http://127.0.0.1/jsonrpc" + downloader._rpc_secret = "secret" + + save_path = tmp_path / "downloads" / "model.safetensors" + progress_events = [] + rpc_calls = [] + statuses = iter( + [ + { + "gid": "gid-1", + "status": "active", + "completedLength": "5", + "totalLength": "10", + "downloadSpeed": "25", + }, + { + "gid": "gid-1", + "status": "complete", + "completedLength": "10", + "totalLength": "10", + "downloadSpeed": "0", + "files": [{"path": str(save_path)}], + }, + ] + ) + + async def fake_rpc_call(method, params): + rpc_calls.append((method, params)) + if method == "aria2.addUri": + return "gid-1" + if method == "aria2.tellStatus": + return next(statuses) + raise AssertionError(f"Unexpected RPC method: {method}") + + monkeypatch.setattr(downloader, "_ensure_process", AsyncMock()) + monkeypatch.setattr( + downloader, + "_resolve_authenticated_redirect_url", + AsyncMock( + return_value="https://signed.example.com/model.safetensors?token=abc" + ), + ) + monkeypatch.setattr(downloader, "_rpc_call", fake_rpc_call) + monkeypatch.setattr("py.services.aria2_downloader.asyncio.sleep", AsyncMock()) + + async def progress_callback(progress, snapshot=None): + progress_events.append(snapshot.percent_complete if snapshot else progress) + + success, result = await downloader.download_file( + "https://civitai.com/api/download/models/123", + str(save_path), + download_id="download-1", + progress_callback=progress_callback, + headers={"Authorization": "Bearer token"}, + ) + + assert success is True + assert result == str(save_path) + assert progress_events == [50.0, 100.0] + assert downloader._transfers == {} + assert rpc_calls[0][0] == "aria2.addUri" + assert rpc_calls[0][1][0] == [ + "https://signed.example.com/model.safetensors?token=abc" + ] + assert rpc_calls[0][1][1]["out"] == "model.safetensors" + assert "header" not in rpc_calls[0][1][1] + + +@pytest.mark.asyncio +async def test_transfer_state_store_shares_lock_and_preserves_concurrent_updates(tmp_path): + state_path = tmp_path / "cache" / "aria2" / "downloads.json" + store_a = Aria2TransferStateStore(str(state_path)) + store_b = Aria2TransferStateStore(str(state_path)) + + assert store_a._lock is store_b._lock + + await asyncio.gather( + store_a.upsert("download-1", {"status": "downloading", "gid": "gid-1"}), + store_b.upsert("download-2", {"status": "paused", "gid": "gid-2"}), + ) + + assert await store_a.get("download-1") == {"status": "downloading", "gid": "gid-1"} + assert await store_b.get("download-2") == {"status": "paused", "gid": "gid-2"} + + +@pytest.mark.asyncio +async def test_download_file_keeps_auth_headers_when_civitai_does_not_redirect( + tmp_path, monkeypatch +): + downloader = Aria2Downloader() + downloader._rpc_url = "http://127.0.0.1/jsonrpc" + downloader._rpc_secret = "secret" + + save_path = tmp_path / "downloads" / "model.safetensors" + rpc_calls = [] + statuses = iter( + [ + { + "gid": "gid-1", + "status": "complete", + "completedLength": "10", + "totalLength": "10", + "downloadSpeed": "0", + "files": [{"path": str(save_path)}], + }, + ] + ) + + async def fake_rpc_call(method, params): + rpc_calls.append((method, params)) + if method == "aria2.addUri": + return "gid-1" + if method == "aria2.tellStatus": + return next(statuses) + raise AssertionError(f"Unexpected RPC method: {method}") + + monkeypatch.setattr(downloader, "_ensure_process", AsyncMock()) + monkeypatch.setattr( + downloader, + "_resolve_authenticated_redirect_url", + AsyncMock(return_value="https://civitai.com/api/download/models/123"), + ) + monkeypatch.setattr(downloader, "_rpc_call", fake_rpc_call) + monkeypatch.setattr("py.services.aria2_downloader.asyncio.sleep", AsyncMock()) + + success, result = await downloader.download_file( + "https://civitai.com/api/download/models/123", + str(save_path), + download_id="download-1", + headers={"Authorization": "Bearer token"}, + ) + + assert success is True + assert result == str(save_path) + assert rpc_calls[0][1][0] == ["https://civitai.com/api/download/models/123"] + assert rpc_calls[0][1][1]["header"] == ["Authorization: Bearer token"] + + +@pytest.mark.asyncio +async def test_pause_resume_cancel_forward_to_rpc(monkeypatch): + downloader = Aria2Downloader() + downloader._transfers["download-1"] = type( + "Transfer", (), {"gid": "gid-1", "save_path": "/tmp/model.safetensors"} + )() + + calls = [] + + async def fake_rpc_call(method, params): + calls.append((method, params)) + return "gid-1" + + monkeypatch.setattr(downloader, "_rpc_call", fake_rpc_call) + + pause_result = await downloader.pause_download("download-1") + resume_result = await downloader.resume_download("download-1") + cancel_result = await downloader.cancel_download("download-1") + + assert pause_result["success"] is True + assert resume_result["success"] is True + assert cancel_result["success"] is True + assert calls == [ + ("aria2.forcePause", ["gid-1"]), + ("aria2.unpause", ["gid-1"]), + ("aria2.forceRemove", ["gid-1"]), + ] + + +@pytest.mark.asyncio +async def test_download_file_reuses_existing_transfer_without_add_uri( + tmp_path, monkeypatch +): + downloader = Aria2Downloader() + downloader._rpc_url = "http://127.0.0.1/jsonrpc" + downloader._rpc_secret = "secret" + + save_path = tmp_path / "downloads" / "model.safetensors" + downloader._transfers["download-1"] = type( + "Transfer", (), {"gid": "gid-1", "save_path": str(save_path)} + )() + + rpc_calls = [] + statuses = iter( + [ + { + "gid": "gid-1", + "status": "active", + "completedLength": "5", + "totalLength": "10", + "downloadSpeed": "25", + }, + { + "gid": "gid-1", + "status": "complete", + "completedLength": "10", + "totalLength": "10", + "downloadSpeed": "0", + "files": [{"path": str(save_path)}], + }, + ] + ) + + async def fake_rpc_call(method, params): + rpc_calls.append((method, params)) + if method == "aria2.tellStatus": + return next(statuses) + raise AssertionError(f"Unexpected RPC method: {method}") + + monkeypatch.setattr(downloader, "_ensure_process", AsyncMock()) + monkeypatch.setattr(downloader, "_rpc_call", fake_rpc_call) + monkeypatch.setattr("py.services.aria2_downloader.asyncio.sleep", AsyncMock()) + + success, result = await downloader.download_file( + "https://example.com/model.safetensors", + str(save_path), + download_id="download-1", + ) + + assert success is True + assert result == str(save_path) + assert [call[0] for call in rpc_calls] == ["aria2.tellStatus", "aria2.tellStatus"] + + +def test_build_progress_snapshot_normalizes_numeric_fields(): + downloader = Aria2Downloader() + + snapshot = downloader._build_progress_snapshot( + { + "completedLength": "75", + "totalLength": "100", + "downloadSpeed": "512", + } + ) + + assert snapshot.percent_complete == 75.0 + assert snapshot.bytes_downloaded == 75 + assert snapshot.total_bytes == 100 + assert snapshot.bytes_per_second == 512.0 + + +def test_resolve_executable_raises_when_binary_missing(monkeypatch): + downloader = Aria2Downloader() + settings = type("Settings", (), {"get": lambda self, key, default=None: ""})() + + monkeypatch.setattr("py.services.aria2_downloader.get_settings_manager", lambda: settings) + monkeypatch.setattr("py.services.aria2_downloader.shutil.which", lambda _: None) + + with pytest.raises(Aria2Error): + downloader._resolve_executable() + + +@pytest.mark.asyncio +async def test_rpc_call_surfaces_json_error_on_non_200(monkeypatch): + downloader = Aria2Downloader() + downloader._rpc_url = "http://127.0.0.1:6800/jsonrpc" + downloader._rpc_secret = "secret" + + class FakeResponse: + status = 400 + + async def text(self): + return ( + '{"jsonrpc":"2.0","id":"x","error":{"code":1,"message":"Unauthorized"}}' + ) + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc, tb): + return False + + class FakeSession: + def post(self, _url, json=None): + return FakeResponse() + + monkeypatch.setattr(downloader, "_get_rpc_session", AsyncMock(return_value=FakeSession())) + + with pytest.raises(Aria2Error) as exc_info: + await downloader._rpc_call("aria2.addUri", [["https://example.com/file"]]) + + assert "Unauthorized" in str(exc_info.value) + assert "aria2.addUri" in str(exc_info.value) + + +@pytest.mark.asyncio +async def test_resolve_authenticated_redirect_url_returns_location(monkeypatch): + downloader = Aria2Downloader() + + class FakeResponse: + status = 307 + headers = {"Location": "https://signed.example.com/file.safetensors"} + + async def text(self): + return "" + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc, tb): + return False + + class FakeSession: + def get(self, _url, headers=None, allow_redirects=False, proxy=None): + return FakeResponse() + + class FakeDownloader: + default_headers = {"User-Agent": "ComfyUI-LoRA-Manager/1.0"} + proxy_url = None + + @property + def session(self): + async def _session(): + return FakeSession() + + return _session() + + fake_downloader = FakeDownloader() + + monkeypatch.setattr( + "py.services.aria2_downloader.get_downloader", + AsyncMock(return_value=fake_downloader), + ) + + result = await downloader._resolve_authenticated_redirect_url( + "https://civitai.com/api/download/models/123", + {"Authorization": "Bearer token"}, + ) + + assert result == "https://signed.example.com/file.safetensors" diff --git a/tests/services/test_connectivity_guard.py b/tests/services/test_connectivity_guard.py index 321837eb..31beb525 100644 --- a/tests/services/test_connectivity_guard.py +++ b/tests/services/test_connectivity_guard.py @@ -39,6 +39,26 @@ async def test_connectivity_guard_enters_cooldown_after_threshold(): assert guard.cooldown_remaining_seconds() > 0 +async def test_connectivity_guard_scopes_cooldown_to_destination(): + guard = await ConnectivityGuard.get_instance() + + destination_a = "civitai.com" + destination_b = "api.github.com" + + guard.register_network_failure( + OSError(errno.ENETUNREACH, "unreachable"), + destination_a, + ) + guard.register_network_failure(asyncio.TimeoutError("timeout"), destination_a) + guard.register_network_failure(ConnectionRefusedError("refused"), destination_a) + + assert guard.should_block_request(destination_a) is True + assert guard.should_block_request(destination_b) is False + + guard.register_success(destination_a) + assert guard.should_block_request(destination_a) is False + + async def test_connectivity_guard_recovers_after_success(): guard = await ConnectivityGuard.get_instance() guard.online = False @@ -55,21 +75,51 @@ async def test_connectivity_guard_recovers_after_success(): async def test_downloader_short_circuits_all_request_helpers_during_cooldown(): guard = await ConnectivityGuard.get_instance() - guard.cooldown_until = datetime.now() + timedelta(seconds=30) - guard.online = False - guard.failure_count = 3 + destination = "example.invalid" + guard.register_network_failure( + OSError(errno.ENETUNREACH, "unreachable"), + destination, + ) + guard.register_network_failure(asyncio.TimeoutError("timeout"), destination) + guard.register_network_failure( + ConnectionRefusedError("refused"), + destination, + ) downloader = Downloader() - ok, payload = await downloader.make_request("GET", "https://example.invalid") + ok, payload = await downloader.make_request("GET", f"https://{destination}") assert ok is False assert payload == OFFLINE_COOLDOWN_ERROR - ok, payload, headers = await downloader.download_to_memory("https://example.invalid") + ok, payload, headers = await downloader.download_to_memory(f"https://{destination}") assert ok is False assert payload == OFFLINE_FRIENDLY_MESSAGE assert headers is None - ok, payload = await downloader.get_response_headers("https://example.invalid") + ok, payload = await downloader.get_response_headers(f"https://{destination}") assert ok is False assert payload == OFFLINE_COOLDOWN_ERROR + + +async def test_downloader_only_short_circuits_requests_for_same_destination(): + guard = await ConnectivityGuard.get_instance() + guard.register_network_failure( + OSError(errno.ENETUNREACH, "unreachable"), + "example.invalid", + ) + guard.register_network_failure(asyncio.TimeoutError("timeout"), "example.invalid") + guard.register_network_failure( + ConnectionRefusedError("refused"), + "example.invalid", + ) + + downloader = Downloader() + ok, payload = await downloader.make_request("GET", "https://example.invalid") + assert ok is False + assert payload == OFFLINE_COOLDOWN_ERROR + + assert ( + guard.should_block_request(downloader._guard_destination("https://example.com")) + is False + ) diff --git a/tests/services/test_download_manager_basic.py b/tests/services/test_download_manager_basic.py index ac801212..3117d612 100644 --- a/tests/services/test_download_manager_basic.py +++ b/tests/services/test_download_manager_basic.py @@ -10,6 +10,7 @@ import pytest from py.services.download_manager import DownloadManager from py.services import download_manager +from py.services import aria2_transfer_state from py.services.service_registry import ServiceRegistry from py.services.settings_manager import SettingsManager, get_settings_manager @@ -46,6 +47,16 @@ def isolate_settings(monkeypatch, tmp_path): monkeypatch.setattr(SettingsManager, "_save_settings", lambda self: None) +@pytest.fixture(autouse=True) +def isolate_aria2_state(monkeypatch, tmp_path): + state_path = tmp_path / "cache" / "aria2" / "downloads.json" + monkeypatch.setattr( + aria2_transfer_state, + "get_aria2_state_path", + lambda: str(state_path), + ) + + @pytest.fixture(autouse=True) def stub_metadata(monkeypatch): class _StubMetadata: @@ -179,6 +190,7 @@ async def test_successful_download_uses_defaults( progress_callback, model_type, download_id, + transfer_backend=None, ): captured.update( { @@ -268,6 +280,7 @@ async def test_download_uses_active_mirrors( progress_callback, model_type, download_id, + transfer_backend=None, ): captured["download_urls"] = download_urls return {"success": True} @@ -288,6 +301,644 @@ async def test_download_uses_active_mirrors( assert captured["download_urls"] == ["https://mirror.example/file.safetensors"] +@pytest.mark.asyncio +async def test_pause_resume_cancel_delegate_to_aria2_backend(monkeypatch): + manager = DownloadManager() + + task = asyncio.create_task(asyncio.sleep(60)) + manager._download_tasks["download-1"] = task + manager._pause_events["download-1"] = download_manager.DownloadStreamControl() + manager._active_downloads["download-1"] = { + "transfer_backend": "aria2", + "status": "downloading", + } + + class DummyAria2Downloader: + def __init__(self): + self.calls = [] + + async def pause_download(self, download_id): + self.calls.append(("pause", download_id)) + return {"success": True, "message": "paused"} + + async def resume_download(self, download_id): + self.calls.append(("resume", download_id)) + return {"success": True, "message": "resumed"} + + async def cancel_download(self, download_id): + self.calls.append(("cancel", download_id)) + return {"success": True, "message": "cancelled"} + + async def has_transfer(self, download_id): + self.calls.append(("has_transfer", download_id)) + return True + + dummy_aria2 = DummyAria2Downloader() + monkeypatch.setattr( + download_manager, + "get_aria2_downloader", + AsyncMock(return_value=dummy_aria2), + ) + + pause_result = await manager.pause_download("download-1") + assert pause_result["success"] is True + assert manager._active_downloads["download-1"]["status"] == "paused" + + resume_result = await manager.resume_download("download-1") + assert resume_result["success"] is True + assert manager._active_downloads["download-1"]["status"] == "downloading" + + cancel_result = await manager.cancel_download("download-1") + assert cancel_result["success"] is True + assert task.cancelled() or task.done() + assert dummy_aria2.calls == [ + ("has_transfer", "download-1"), + ("pause", "download-1"), + ("has_transfer", "download-1"), + ("resume", "download-1"), + ("cancel", "download-1"), + ] + + +@pytest.mark.asyncio +async def test_cancel_allows_queued_aria2_task_without_transfer(monkeypatch): + manager = DownloadManager() + + started = asyncio.Event() + + async def blocked_task(): + started.set() + await asyncio.sleep(60) + + task = asyncio.create_task(blocked_task()) + await started.wait() + + manager._download_tasks["download-queued"] = task + manager._active_downloads["download-queued"] = { + "transfer_backend": "aria2", + "status": "queued", + } + + class DummyAria2Downloader: + async def cancel_download(self, download_id): + return {"success": False, "error": "Download task not found"} + + monkeypatch.setattr( + download_manager, + "get_aria2_downloader", + AsyncMock(return_value=DummyAria2Downloader()), + ) + + result = await manager.cancel_download("download-queued") + + assert result["success"] is True + assert task.cancelled() or task.done() + + +@pytest.mark.asyncio +async def test_pause_resume_queued_aria2_task_without_transfer(monkeypatch): + manager = DownloadManager() + + task = asyncio.create_task(asyncio.sleep(60)) + manager._download_tasks["download-queued"] = task + manager._pause_events["download-queued"] = download_manager.DownloadStreamControl() + manager._active_downloads["download-queued"] = { + "transfer_backend": "aria2", + "status": "waiting", + "bytes_per_second": 12.0, + } + + class DummyAria2Downloader: + def __init__(self): + self.calls = [] + + async def has_transfer(self, download_id): + self.calls.append(("has_transfer", download_id)) + return False + + async def pause_download(self, download_id): + self.calls.append(("pause", download_id)) + return {"success": True, "message": "paused"} + + async def resume_download(self, download_id): + self.calls.append(("resume", download_id)) + return {"success": True, "message": "resumed"} + + dummy_aria2 = DummyAria2Downloader() + monkeypatch.setattr( + download_manager, + "get_aria2_downloader", + AsyncMock(return_value=dummy_aria2), + ) + + pause_result = await manager.pause_download("download-queued") + assert pause_result == {"success": True, "message": "Download paused successfully"} + assert manager._active_downloads["download-queued"]["status"] == "paused" + assert manager._pause_events["download-queued"].is_paused() is True + + resume_result = await manager.resume_download("download-queued") + assert resume_result == {"success": True, "message": "Download resumed successfully"} + assert manager._active_downloads["download-queued"]["status"] == "downloading" + assert manager._pause_events["download-queued"].is_set() is True + assert dummy_aria2.calls == [ + ("has_transfer", "download-queued"), + ("has_transfer", "download-queued"), + ] + + task.cancel() + with pytest.raises(asyncio.CancelledError): + await task + + +@pytest.mark.asyncio +async def test_resume_download_restores_persisted_aria2_task(monkeypatch, tmp_path): + manager = DownloadManager() + save_dir = tmp_path / "downloads" + save_dir.mkdir() + save_path = save_dir / "file.safetensors" + save_path.write_text("partial") + (save_dir / "file.safetensors.aria2").write_text("control") + + await manager._aria2_state_store.upsert( + "download-1", + { + "download_id": "download-1", + "transfer_backend": "aria2", + "status": "paused", + "save_dir": str(save_dir), + "relative_path": "", + "use_default_paths": False, + "save_path": str(save_path), + "file_path": str(save_path), + "model_id": 12, + "model_version_id": 34, + }, + ) + + created = {} + + async def fake_download_with_semaphore( + self, + task_id, + model_id, + model_version_id, + save_dir, + relative_path, + progress_callback=None, + use_default_paths=False, + source=None, + file_params=None, + ): + created.update( + { + "task_id": task_id, + "model_id": model_id, + "model_version_id": model_version_id, + "save_dir": save_dir, + } + ) + return {"success": True} + + class DummyAria2Downloader: + def __init__(self): + self.calls = [] + + async def get_status_by_gid(self, gid): + return None + + async def has_transfer(self, download_id): + self.calls.append(("has_transfer", download_id)) + return False + + async def resume_download(self, download_id): + self.calls.append(("resume", download_id)) + return {"success": True, "message": "resumed"} + + async def restore_transfer(self, download_id, gid, save_path): + self.calls.append(("restore_transfer", download_id, gid, save_path)) + + dummy_aria2 = DummyAria2Downloader() + monkeypatch.setattr( + download_manager, "_download_with_semaphore", None, raising=False + ) + monkeypatch.setattr( + DownloadManager, + "_download_with_semaphore", + fake_download_with_semaphore, + ) + monkeypatch.setattr( + download_manager, + "get_aria2_downloader", + AsyncMock(return_value=dummy_aria2), + ) + + result = await manager.resume_download("download-1") + await asyncio.sleep(0) + + assert result == {"success": True, "message": "Download resumed successfully"} + assert created["task_id"] == "download-1" + assert created["model_version_id"] == 34 + assert manager._active_downloads["download-1"]["status"] == "downloading" + assert manager._pause_events["download-1"].is_set() is True + + +@pytest.mark.asyncio +async def test_get_active_downloads_restores_persisted_aria2_entries(monkeypatch, tmp_path): + manager = DownloadManager() + save_dir = tmp_path / "downloads" + save_dir.mkdir() + save_path = save_dir / "file.safetensors" + save_path.write_text("partial") + (save_dir / "file.safetensors.aria2").write_text("control") + + await manager._aria2_state_store.upsert( + "download-1", + { + "download_id": "download-1", + "transfer_backend": "aria2", + "status": "paused", + "save_path": str(save_path), + "file_path": str(save_path), + "model_id": 12, + "model_version_id": 34, + }, + ) + + class DummyAria2Downloader: + async def get_status_by_gid(self, gid): + return None + + monkeypatch.setattr( + download_manager, + "get_aria2_downloader", + AsyncMock(return_value=DummyAria2Downloader()), + ) + + downloads = await manager.get_active_downloads() + + assert downloads["downloads"] == [ + { + "download_id": "download-1", + "model_id": 12, + "model_version_id": 34, + "progress": 0, + "status": "paused", + "error": None, + "bytes_downloaded": 0, + "total_bytes": None, + "bytes_per_second": 0.0, + } + ] + + +@pytest.mark.asyncio +async def test_get_active_downloads_restores_orphaned_aria2_partial_as_paused( + monkeypatch, tmp_path +): + manager = DownloadManager() + save_dir = tmp_path / "downloads" + save_dir.mkdir() + save_path = save_dir / "file.safetensors" + save_path.write_text("partial") + (save_dir / "file.safetensors.aria2").write_text("control") + + await manager._aria2_state_store.upsert( + "download-1", + { + "download_id": "download-1", + "transfer_backend": "aria2", + "status": "downloading", + "save_path": str(save_path), + "file_path": str(save_path), + "model_id": 12, + "model_version_id": 34, + "gid": "missing-gid", + }, + ) + + class DummyAria2Downloader: + async def get_status_by_gid(self, gid): + return None + + monkeypatch.setattr( + download_manager, + "get_aria2_downloader", + AsyncMock(return_value=DummyAria2Downloader()), + ) + + downloads = await manager.get_active_downloads() + persisted = await manager._aria2_state_store.get("download-1") + + assert downloads["downloads"] == [ + { + "download_id": "download-1", + "model_id": 12, + "model_version_id": 34, + "progress": 0, + "status": "paused", + "error": None, + "bytes_downloaded": 0, + "total_bytes": None, + "bytes_per_second": 0.0, + } + ] + assert manager._pause_events["download-1"].is_paused() is True + assert persisted["status"] == "paused" + + +@pytest.mark.asyncio +async def test_get_active_downloads_restarts_from_resume_context_for_active_restored_aria2( + monkeypatch, tmp_path +): + manager = DownloadManager() + save_dir = tmp_path / "downloads" + save_dir.mkdir() + save_path = save_dir / "file.safetensors" + save_path.write_text("partial") + + await manager._aria2_state_store.upsert( + "download-1", + { + "download_id": "download-1", + "transfer_backend": "aria2", + "status": "downloading", + "save_path": str(save_path), + "file_path": str(save_path), + "model_id": 12, + "model_version_id": 34, + "gid": "gid-1", + "resume_context": { + "version_info": { + "id": 34, + "modelId": 12, + "model": {"id": 12, "type": "LoRA", "tags": ["fantasy"]}, + "images": [], + }, + "file_info": { + "name": "file.safetensors", + "type": "Model", + "primary": True, + "downloadUrl": "https://example.com/file.safetensors", + }, + "model_type": "lora", + "relative_path": "", + "save_dir": str(save_dir), + "download_urls": ["https://example.com/file.safetensors"], + }, + }, + ) + + restarted = {} + + class DummyAria2Downloader: + async def get_status_by_gid(self, gid): + return {"gid": gid, "status": "active"} + + async def restore_transfer(self, download_id, gid, restored_path): + return None + + monkeypatch.setattr( + download_manager, + "get_aria2_downloader", + AsyncMock(return_value=DummyAria2Downloader()), + ) + + async def fake_resume_restored_aria2_download(self, download_id, record): + restarted.update( + { + "download_id": download_id, + "model_id": record.get("model_id"), + "model_version_id": record.get("model_version_id"), + "save_dir": record.get("save_dir"), + "resume_context": record.get("resume_context"), + } + ) + return {"success": True} + + monkeypatch.setattr( + DownloadManager, + "_resume_restored_aria2_download", + fake_resume_restored_aria2_download, + ) + execute_original = AsyncMock(side_effect=AssertionError("should not refetch metadata")) + monkeypatch.setattr( + DownloadManager, + "_execute_original_download", + execute_original, + ) + + downloads = await manager.get_active_downloads() + assert downloads["downloads"][0]["status"] == "downloading" + restarted_task = manager._download_tasks["download-1"] + await restarted_task + + assert restarted["download_id"] == "download-1" + assert restarted["model_id"] == 12 + assert restarted["model_version_id"] == 34 + assert restarted["save_dir"] is None + assert restarted["resume_context"]["model_type"] == "lora" + assert execute_original.await_count == 0 + + +@pytest.mark.asyncio +async def test_get_active_downloads_restores_persisted_aria2_without_initial_save_path( + monkeypatch, tmp_path +): + manager = DownloadManager() + save_dir = tmp_path / "downloads" + save_dir.mkdir() + save_path = save_dir / "file.safetensors" + save_path.write_text("partial") + (save_dir / "file.safetensors.aria2").write_text("control") + + await manager._aria2_state_store.upsert( + "download-1", + { + "download_id": "download-1", + "transfer_backend": "aria2", + "status": "paused", + "model_id": 12, + "model_version_id": 34, + "resume_context": { + "version_info": { + "id": 34, + "modelId": 12, + "model": {"id": 12, "type": "LoRA"}, + "images": [], + }, + "file_info": { + "name": "file.safetensors", + "type": "Model", + "primary": True, + "downloadUrl": "https://example.com/file.safetensors", + }, + "model_type": "lora", + "relative_path": "", + "save_dir": str(save_dir), + "download_urls": ["https://example.com/file.safetensors"], + }, + }, + ) + + class DummyAria2Downloader: + async def get_status_by_gid(self, gid): + return None + + monkeypatch.setattr( + download_manager, + "get_aria2_downloader", + AsyncMock(return_value=DummyAria2Downloader()), + ) + + downloads = await manager.get_active_downloads() + persisted = await manager._aria2_state_store.get("download-1") + + assert downloads["downloads"] == [ + { + "download_id": "download-1", + "model_id": 12, + "model_version_id": 34, + "progress": 0, + "status": "paused", + "error": None, + "bytes_downloaded": 0, + "total_bytes": None, + "bytes_per_second": 0.0, + } + ] + assert manager._active_downloads["download-1"]["file_path"] == str(save_path) + assert persisted["save_path"] == str(save_path) + assert persisted["file_path"] == str(save_path) + + +@pytest.mark.asyncio +async def test_resume_restored_aria2_download_updates_terminal_status_and_cleanup(monkeypatch): + manager = DownloadManager() + manager._active_downloads["download-1"] = { + "transfer_backend": "aria2", + "status": "paused", + "model_id": 12, + "model_version_id": 34, + "bytes_per_second": 10.0, + } + + persist_state = AsyncMock() + cleanup_record = AsyncMock(return_value=None) + execute_download = AsyncMock(return_value={"success": True}) + record_history = AsyncMock(return_value=None) + sync_version = AsyncMock(return_value=None) + + monkeypatch.setattr(manager, "_persist_aria2_state", persist_state) + monkeypatch.setattr(manager, "_cleanup_download_record", cleanup_record) + monkeypatch.setattr(manager, "_execute_download", execute_download) + monkeypatch.setattr(manager, "_record_downloaded_version_history", record_history) + monkeypatch.setattr(manager, "_sync_downloaded_version", sync_version) + + scheduled_tasks = [] + original_create_task = asyncio.create_task + + def tracking_create_task(coro): + task = original_create_task(coro) + scheduled_tasks.append(task) + return task + + monkeypatch.setattr(download_manager.asyncio, "create_task", tracking_create_task) + + result = await manager._resume_restored_aria2_download( + "download-1", + { + "download_id": "download-1", + "save_path": "/tmp/file.safetensors", + "file_path": "/tmp/file.safetensors", + "model_id": 12, + "model_version_id": 34, + "resume_context": { + "version_info": { + "id": 34, + "modelId": 12, + "model": {"id": 12}, + "images": [], + }, + "file_info": { + "name": "file.safetensors", + "downloadUrl": "https://example.com/file.safetensors", + }, + "model_type": "lora", + "relative_path": "", + "save_dir": "/tmp", + "download_urls": ["https://example.com/file.safetensors"], + }, + }, + ) + + assert result == {"success": True} + assert manager._active_downloads["download-1"]["status"] == "completed" + assert manager._active_downloads["download-1"]["bytes_per_second"] == 0.0 + assert persist_state.await_count == 2 + assert len(scheduled_tasks) == 1 + await asyncio.gather(*scheduled_tasks) + cleanup_record.assert_awaited_once_with("download-1") + + +@pytest.mark.asyncio +async def test_download_uses_captured_backend_when_settings_change( + monkeypatch, scanners, metadata_provider, tmp_path +): + manager = DownloadManager() + settings = get_settings_manager() + settings.settings["download_backend"] = "aria2" + + semaphore = asyncio.Semaphore(0) + manager._download_semaphore = semaphore + + captured = {} + + async def fake_execute_original_download( + self, + model_id, + model_version_id, + save_dir, + relative_path, + progress_callback, + use_default_paths, + download_id=None, + transfer_backend="python", + source=None, + file_params=None, + ): + captured["transfer_backend"] = transfer_backend + return {"success": True} + + monkeypatch.setattr( + DownloadManager, + "_execute_original_download", + fake_execute_original_download, + ) + + download_task = asyncio.create_task( + manager.download_from_civitai( + model_version_id=99, + save_dir=str(tmp_path), + use_default_paths=True, + progress_callback=None, + source=None, + ) + ) + + await asyncio.sleep(0) + assert len(manager._active_downloads) == 1 + download_id = next(iter(manager._active_downloads)) + assert manager._active_downloads[download_id]["transfer_backend"] == "aria2" + + settings.settings["download_backend"] = "python" + semaphore.release() + + result = await download_task + + assert result["success"] is True + assert captured["transfer_backend"] == "aria2" + + @pytest.mark.asyncio async def test_download_aborts_when_version_exists( monkeypatch, scanners, metadata_provider diff --git a/tests/services/test_download_manager_error.py b/tests/services/test_download_manager_error.py index fd154f3b..c462327d 100644 --- a/tests/services/test_download_manager_error.py +++ b/tests/services/test_download_manager_error.py @@ -14,6 +14,7 @@ import pytest from py.services.download_manager import DownloadManager from py.services.downloader import DownloadStreamControl from py.services import download_manager +from py.services import aria2_transfer_state from py.services.service_registry import ServiceRegistry from py.services.settings_manager import SettingsManager, get_settings_manager from py.utils.metadata_manager import MetadataManager @@ -49,6 +50,16 @@ def isolate_settings(monkeypatch, tmp_path): monkeypatch.setattr(SettingsManager, "_save_settings", lambda self: None) +@pytest.fixture(autouse=True) +def isolate_aria2_state(monkeypatch, tmp_path): + state_path = tmp_path / "cache" / "aria2" / "downloads.json" + monkeypatch.setattr( + aria2_transfer_state, + "get_aria2_state_path", + lambda: str(state_path), + ) + + @pytest.mark.asyncio async def test_execute_download_retries_urls(monkeypatch, tmp_path): """Test that download retries multiple URLs on failure.""" @@ -136,6 +147,190 @@ async def test_execute_download_retries_urls(monkeypatch, tmp_path): assert dummy_scanner.calls # ensure cache updated +@pytest.mark.asyncio +async def test_execute_download_uses_aria2_backend_for_model_files(monkeypatch, tmp_path): + manager = DownloadManager() + settings = get_settings_manager() + settings.settings["download_backend"] = "aria2" + settings.settings["civitai_api_key"] = "secret-key" + + save_dir = tmp_path / "downloads" + save_dir.mkdir() + target_path = save_dir / "file.safetensors" + + class DummyMetadata: + def __init__(self, path: Path): + self.file_path = str(path) + self.sha256 = "sha256" + self.file_name = path.stem + self.preview_url = None + + def generate_unique_filename(self, *_args, **_kwargs): + return os.path.basename(self.file_path) + + def update_file_info(self, _path): + return None + + def to_dict(self): + return {"file_path": self.file_path} + + class DummyAria2Downloader: + def __init__(self): + self.calls = [] + + async def download_file( + self, + url, + save_path, + *, + download_id, + progress_callback=None, + headers=None, + ): + self.calls.append( + { + "url": url, + "save_path": save_path, + "download_id": download_id, + "headers": headers, + } + ) + Path(save_path).write_text("content") + return True, save_path + + dummy_aria2 = DummyAria2Downloader() + + monkeypatch.setattr( + download_manager, + "get_aria2_downloader", + AsyncMock(return_value=dummy_aria2), + ) + monkeypatch.setattr( + download_manager, + "get_downloader", + AsyncMock(side_effect=AssertionError("python downloader should not be used")), + ) + + class DummyScanner: + async def add_model_to_cache(self, metadata_dict, relative_path): + return {"metadata": metadata_dict, "relative_path": relative_path} + + dummy_scanner = DummyScanner() + monkeypatch.setattr( + DownloadManager, "_get_lora_scanner", AsyncMock(return_value=dummy_scanner) + ) + monkeypatch.setattr( + DownloadManager, + "_get_checkpoint_scanner", + AsyncMock(return_value=dummy_scanner), + ) + monkeypatch.setattr( + ServiceRegistry, "get_embedding_scanner", AsyncMock(return_value=dummy_scanner) + ) + monkeypatch.setattr(MetadataManager, "save_metadata", AsyncMock(return_value=True)) + + result = await manager._execute_download( + download_urls=["https://civitai.com/api/download/models/1"], + save_dir=str(save_dir), + metadata=DummyMetadata(target_path), + version_info={"images": []}, + relative_path="", + progress_callback=None, + model_type="lora", + download_id="download-1", + ) + + assert result == {"success": True} + assert dummy_aria2.calls == [ + { + "url": "https://civitai.com/api/download/models/1", + "save_path": str(target_path), + "download_id": "download-1", + "headers": {"Authorization": "Bearer secret-key"}, + } + ] + + +@pytest.mark.asyncio +async def test_execute_download_allows_anonymous_civitai_with_aria2( + monkeypatch, tmp_path +): + manager = DownloadManager() + settings = get_settings_manager() + settings.settings["download_backend"] = "aria2" + settings.settings["civitai_api_key"] = "" + + save_dir = tmp_path / "downloads" + save_dir.mkdir() + target_path = save_dir / "file.safetensors" + + class DummyMetadata: + def __init__(self, path: Path): + self.file_path = str(path) + self.sha256 = "sha256" + self.file_name = path.stem + self.preview_url = None + + def generate_unique_filename(self, *_args, **_kwargs): + return os.path.basename(self.file_path) + + def update_file_info(self, _path): + return None + + def to_dict(self): + return {"file_path": self.file_path} + + class DummyAria2Downloader: + def __init__(self): + self.calls = [] + + async def download_file( + self, + url, + save_path, + *, + download_id, + progress_callback=None, + headers=None, + ): + self.calls.append({"url": url, "headers": headers, "download_id": download_id}) + Path(save_path).write_text("content") + return True, save_path + + dummy_aria2 = DummyAria2Downloader() + monkeypatch.setattr( + download_manager, + "get_aria2_downloader", + AsyncMock(return_value=dummy_aria2), + ) + + dummy_scanner = SimpleNamespace(add_model_to_cache=AsyncMock(return_value=None)) + monkeypatch.setattr( + DownloadManager, "_get_lora_scanner", AsyncMock(return_value=dummy_scanner) + ) + monkeypatch.setattr(MetadataManager, "save_metadata", AsyncMock(return_value=True)) + + result = await manager._execute_download( + download_urls=["https://civitai.com/api/download/models/1"], + save_dir=str(save_dir), + metadata=DummyMetadata(target_path), + version_info={"images": []}, + relative_path="", + progress_callback=None, + model_type="lora", + download_id="download-2", + ) + + assert result == {"success": True} + assert dummy_aria2.calls == [ + { + "url": "https://civitai.com/api/download/models/1", + "headers": None, + "download_id": "download-2", + } + ] + + @pytest.mark.asyncio async def test_execute_download_adjusts_checkpoint_sub_type(monkeypatch, tmp_path): """Test that checkpoint sub_type is adjusted during download.""" @@ -276,6 +471,13 @@ async def test_execute_download_extracts_zip_single_model(monkeypatch, tmp_path) monkeypatch.setattr( download_manager, "get_downloader", AsyncMock(return_value=DummyDownloader()) ) + + class ImmediateLoop: + async def run_in_executor(self, executor, func, *args): + return func(*args) + + monkeypatch.setattr(download_manager.asyncio, "get_running_loop", lambda: ImmediateLoop()) + dummy_scanner = SimpleNamespace(add_model_to_cache=AsyncMock(return_value=None)) monkeypatch.setattr( DownloadManager, "_get_lora_scanner", AsyncMock(return_value=dummy_scanner) @@ -344,6 +546,13 @@ async def test_execute_download_extracts_zip_multiple_models(monkeypatch, tmp_pa monkeypatch.setattr( download_manager, "get_downloader", AsyncMock(return_value=DummyDownloader()) ) + + class ImmediateLoop: + async def run_in_executor(self, executor, func, *args): + return func(*args) + + monkeypatch.setattr(download_manager.asyncio, "get_running_loop", lambda: ImmediateLoop()) + dummy_scanner = SimpleNamespace(add_model_to_cache=AsyncMock(return_value=None)) monkeypatch.setattr( DownloadManager, "_get_lora_scanner", AsyncMock(return_value=dummy_scanner) @@ -418,6 +627,13 @@ async def test_execute_download_extracts_zip_pt_embedding(monkeypatch, tmp_path) monkeypatch.setattr( download_manager, "get_downloader", AsyncMock(return_value=DummyDownloader()) ) + + class ImmediateLoop: + async def run_in_executor(self, executor, func, *args): + return func(*args) + + monkeypatch.setattr(download_manager.asyncio, "get_running_loop", lambda: ImmediateLoop()) + dummy_scanner = SimpleNamespace(add_model_to_cache=AsyncMock(return_value=None)) monkeypatch.setattr( ServiceRegistry, "get_embedding_scanner", AsyncMock(return_value=dummy_scanner) @@ -446,6 +662,36 @@ async def test_execute_download_extracts_zip_pt_embedding(monkeypatch, tmp_path) assert dummy_scanner.add_model_to_cache.await_count == 1 +@pytest.mark.asyncio +async def test_extract_model_files_from_archive_uses_executor(monkeypatch, tmp_path): + manager = DownloadManager() + archive_path = tmp_path / "bundle.zip" + with zipfile.ZipFile(archive_path, "w") as archive: + archive.writestr("inner/model.safetensors", b"model") + + captured = {} + + class ImmediateLoop: + async def run_in_executor(self, executor, func, *args): + captured["executor"] = executor + return func(*args) + + monkeypatch.setattr( + download_manager.asyncio, + "get_running_loop", + lambda: ImmediateLoop(), + ) + + extracted = await manager._extract_model_files_from_archive( + str(archive_path), + {".safetensors"}, + ) + + assert captured["executor"] is manager._archive_executor + assert len(extracted) == 1 + assert extracted[0].endswith("model.safetensors") + + @pytest.mark.asyncio async def test_pause_download_updates_state(): """Test that pause_download updates download state correctly.""" @@ -469,6 +715,832 @@ async def test_pause_download_updates_state(): assert manager._active_downloads[download_id]["bytes_per_second"] == 0.0 +@pytest.mark.asyncio +async def test_pause_download_reverts_local_pause_when_aria2_pause_fails(monkeypatch): + manager = DownloadManager() + + download_id = "dl" + manager._download_tasks[download_id] = object() + pause_control = DownloadStreamControl() + manager._pause_events[download_id] = pause_control + manager._active_downloads[download_id] = { + "transfer_backend": "aria2", + "status": "downloading", + "bytes_per_second": 42.0, + } + + class DummyAria2Downloader: + async def has_transfer(self, _download_id): + return True + + async def pause_download(self, _download_id): + return {"success": False, "error": "rpc failed"} + + monkeypatch.setattr( + download_manager, + "get_aria2_downloader", + AsyncMock(return_value=DummyAria2Downloader()), + ) + + result = await manager.pause_download(download_id) + + assert result == {"success": False, "error": "rpc failed"} + assert pause_control.is_set() is True + assert manager._active_downloads[download_id]["status"] == "downloading" + + +@pytest.mark.asyncio +async def test_pause_download_reverts_local_pause_when_aria2_probe_raises(monkeypatch): + manager = DownloadManager() + + download_id = "dl" + manager._download_tasks[download_id] = object() + pause_control = DownloadStreamControl() + manager._pause_events[download_id] = pause_control + manager._active_downloads[download_id] = { + "transfer_backend": "aria2", + "status": "downloading", + "bytes_per_second": 42.0, + } + + class DummyAria2Downloader: + async def has_transfer(self, _download_id): + raise RuntimeError("rpc unavailable") + + monkeypatch.setattr( + download_manager, + "get_aria2_downloader", + AsyncMock(return_value=DummyAria2Downloader()), + ) + + result = await manager.pause_download(download_id) + + assert result == {"success": False, "error": "rpc unavailable"} + assert pause_control.is_set() is True + assert manager._active_downloads[download_id]["status"] == "downloading" + + +@pytest.mark.asyncio +async def test_resume_download_returns_error_when_aria2_probe_raises(monkeypatch): + manager = DownloadManager() + + download_id = "dl" + pause_control = DownloadStreamControl() + pause_control.pause() + manager._pause_events[download_id] = pause_control + manager._active_downloads[download_id] = { + "transfer_backend": "aria2", + "status": "paused", + "bytes_per_second": 0.0, + } + + class DummyAria2Downloader: + async def has_transfer(self, _download_id): + raise RuntimeError("rpc unavailable") + + monkeypatch.setattr( + download_manager, + "get_aria2_downloader", + AsyncMock(return_value=DummyAria2Downloader()), + ) + + result = await manager.resume_download(download_id) + + assert result == {"success": False, "error": "rpc unavailable"} + assert pause_control.is_paused() is True + assert manager._active_downloads[download_id]["status"] == "paused" + + +@pytest.mark.asyncio +async def test_resume_download_does_not_spawn_restored_worker_when_aria2_resume_fails( + monkeypatch, tmp_path +): + manager = DownloadManager() + + download_id = "dl" + save_path = tmp_path / "file.safetensors" + pause_control = DownloadStreamControl() + pause_control.pause() + manager._pause_events[download_id] = pause_control + manager._active_downloads[download_id] = { + "transfer_backend": "aria2", + "status": "paused", + "bytes_per_second": 0.0, + } + + await manager._aria2_state_store.upsert( + download_id, + { + "download_id": download_id, + "transfer_backend": "aria2", + "status": "paused", + "save_path": str(save_path), + "file_path": str(save_path), + "model_id": 12, + "model_version_id": 34, + "resume_context": { + "version_info": {"id": 34, "modelId": 12, "model": {"id": 12}}, + "file_info": { + "name": "file.safetensors", + "downloadUrl": "https://example.com/file.safetensors", + }, + "model_type": "lora", + "relative_path": "", + "save_dir": str(tmp_path), + "download_urls": ["https://example.com/file.safetensors"], + }, + }, + ) + + resume_restored = AsyncMock(return_value={"success": True}) + monkeypatch.setattr(manager, "_resume_restored_aria2_download", resume_restored) + + class DummyAria2Downloader: + async def has_transfer(self, _download_id): + return True + + async def resume_download(self, _download_id): + return {"success": False, "error": "rpc unavailable"} + + monkeypatch.setattr( + download_manager, + "get_aria2_downloader", + AsyncMock(return_value=DummyAria2Downloader()), + ) + + result = await manager.resume_download(download_id) + + assert result == {"success": False, "error": "rpc unavailable"} + assert download_id not in manager._download_tasks + assert resume_restored.await_count == 0 + assert pause_control.is_paused() is True + assert manager._active_downloads[download_id]["status"] == "paused" + + +@pytest.mark.asyncio +async def test_start_background_download_task_cleans_up_finished_restore_task(): + manager = DownloadManager() + download_id = "download-1" + manager._pause_events[download_id] = DownloadStreamControl() + + async def finished_restore(): + return {"success": True} + + task = manager._start_background_download_task(download_id, finished_restore()) + await task + await asyncio.sleep(0) + + assert download_id not in manager._download_tasks + assert download_id not in manager._pause_events + + +@pytest.mark.asyncio +async def test_cancel_download_still_cancels_local_task_when_aria2_raises(monkeypatch): + manager = DownloadManager() + + started = asyncio.Event() + + async def blocked_task(): + started.set() + await asyncio.sleep(60) + + task = asyncio.create_task(blocked_task()) + await started.wait() + + download_id = "download-queued" + manager._download_tasks[download_id] = task + manager._active_downloads[download_id] = { + "transfer_backend": "aria2", + "status": "queued", + } + + class DummyAria2Downloader: + async def cancel_download(self, _download_id): + raise RuntimeError("rpc unavailable") + + monkeypatch.setattr( + download_manager, + "get_aria2_downloader", + AsyncMock(return_value=DummyAria2Downloader()), + ) + + result = await manager.cancel_download(download_id) + + assert result["success"] is True + assert task.cancelled() or task.done() + + +@pytest.mark.asyncio +async def test_cancel_download_preserves_tracking_when_aria2_returns_error(monkeypatch, tmp_path): + manager = DownloadManager() + download_id = "download-queued" + save_path = tmp_path / "file.safetensors" + save_path.write_text("partial") + (tmp_path / "file.safetensors.aria2").write_text("control") + + pause_control = DownloadStreamControl() + manager._pause_events[download_id] = pause_control + manager._download_tasks[download_id] = object() + manager._active_downloads[download_id] = { + "transfer_backend": "aria2", + "status": "downloading", + "file_path": str(save_path), + } + + await manager._aria2_state_store.upsert( + download_id, + { + "download_id": download_id, + "transfer_backend": "aria2", + "status": "downloading", + "save_path": str(save_path), + "file_path": str(save_path), + }, + ) + + cleanup_files = AsyncMock(return_value=None) + monkeypatch.setattr(manager, "_cleanup_cancelled_download_files", cleanup_files) + + class DummyAria2Downloader: + async def cancel_download(self, _download_id): + return {"success": False, "error": "rpc unavailable"} + + monkeypatch.setattr( + download_manager, + "get_aria2_downloader", + AsyncMock(return_value=DummyAria2Downloader()), + ) + + result = await manager.cancel_download(download_id) + + assert result == {"success": False, "error": "rpc unavailable"} + assert download_id in manager._download_tasks + assert download_id in manager._pause_events + assert await manager._aria2_state_store.get(download_id) is not None + assert cleanup_files.await_count == 0 + + +@pytest.mark.asyncio +async def test_cancel_download_rejects_completed_history_entry(tmp_path): + manager = DownloadManager() + download_id = "completed-download" + save_path = tmp_path / "file.safetensors" + metadata_path = tmp_path / "file.metadata.json" + preview_path = tmp_path / "file.jpeg" + save_path.write_text("complete") + metadata_path.write_text("{}") + preview_path.write_text("preview") + + manager._active_downloads[download_id] = { + "transfer_backend": "aria2", + "status": "completed", + "file_path": str(save_path), + "preview_path": str(preview_path), + } + + result = await manager.cancel_download(download_id) + + assert result == {"success": False, "error": "Download task not found"} + assert save_path.exists() + assert metadata_path.exists() + assert preview_path.exists() + + +@pytest.mark.asyncio +async def test_cancel_download_removes_preview_and_aria2_control_files(monkeypatch, tmp_path): + manager = DownloadManager() + + started = asyncio.Event() + + async def blocked_task(): + started.set() + await asyncio.sleep(60) + + task = asyncio.create_task(blocked_task()) + await started.wait() + + save_path = tmp_path / "file.safetensors" + save_path.write_text("partial") + aria2_path = tmp_path / "file.safetensors.aria2" + aria2_path.write_text("control") + preview_path = tmp_path / "file.jpeg" + preview_path.write_text("preview") + + download_id = "download-queued" + manager._download_tasks[download_id] = task + manager._active_downloads[download_id] = { + "transfer_backend": "aria2", + "status": "queued", + "file_path": str(save_path), + "aria2_control_path": str(aria2_path), + "preview_path": str(preview_path), + } + + class DummyAria2Downloader: + async def cancel_download(self, _download_id): + return {"success": True, "message": "cancelled"} + + monkeypatch.setattr( + download_manager, + "get_aria2_downloader", + AsyncMock(return_value=DummyAria2Downloader()), + ) + + result = await manager.cancel_download(download_id) + + assert result["success"] is True + assert not save_path.exists() + assert not aria2_path.exists() + assert not preview_path.exists() + + +@pytest.mark.asyncio +async def test_cancel_download_does_not_delete_untracked_same_basename_preview( + monkeypatch, tmp_path +): + manager = DownloadManager() + + started = asyncio.Event() + + async def blocked_task(): + started.set() + await asyncio.sleep(60) + + task = asyncio.create_task(blocked_task()) + await started.wait() + + save_path = tmp_path / "file.safetensors" + save_path.write_text("partial") + aria2_path = tmp_path / "file.safetensors.aria2" + aria2_path.write_text("control") + manual_preview_path = tmp_path / "file.jpg" + manual_preview_path.write_text("manual") + + download_id = "download-queued" + manager._download_tasks[download_id] = task + manager._active_downloads[download_id] = { + "transfer_backend": "aria2", + "status": "queued", + "file_path": str(save_path), + "aria2_control_path": str(aria2_path), + } + + class DummyAria2Downloader: + async def cancel_download(self, _download_id): + return {"success": True, "message": "cancelled"} + + monkeypatch.setattr( + download_manager, + "get_aria2_downloader", + AsyncMock(return_value=DummyAria2Downloader()), + ) + + result = await manager.cancel_download(download_id) + + assert result["success"] is True + assert not save_path.exists() + assert not aria2_path.exists() + assert manual_preview_path.exists() + + +@pytest.mark.asyncio +async def test_cleanup_cancelled_download_files_retries_aria2_control_deletion( + monkeypatch, tmp_path +): + manager = DownloadManager() + download_id = "download-1" + + save_path = tmp_path / "file.safetensors" + aria2_path = tmp_path / "file.safetensors.aria2" + save_path.write_text("partial") + aria2_path.write_text("control") + + original_unlink = os.unlink + attempts = {"count": 0} + + def flaky_unlink(path): + if path == str(aria2_path) and attempts["count"] == 0: + attempts["count"] += 1 + raise PermissionError("still locked") + return original_unlink(path) + + monkeypatch.setattr(download_manager.os, "unlink", flaky_unlink) + monkeypatch.setattr("py.services.download_manager.asyncio.sleep", AsyncMock()) + + await manager._cleanup_cancelled_download_files( + download_id, + { + "file_path": str(save_path), + "aria2_control_path": str(aria2_path), + "transfer_backend": "aria2", + }, + ) + + assert attempts["count"] == 1 + assert not save_path.exists() + assert not aria2_path.exists() + + +@pytest.mark.asyncio +async def test_execute_download_waits_for_paused_pre_transfer_gate(monkeypatch, tmp_path): + manager = DownloadManager() + + save_dir = tmp_path / "downloads" + save_dir.mkdir() + target_path = save_dir / "file.safetensors" + + class DummyMetadata: + def __init__(self, path: Path): + self.file_path = str(path) + self.sha256 = "sha256" + self.file_name = path.stem + self.preview_url = None + + def generate_unique_filename(self, *_args, **_kwargs): + return os.path.basename(self.file_path) + + def update_file_info(self, _path): + return None + + def to_dict(self): + return {"file_path": self.file_path} + + pause_control = DownloadStreamControl() + pause_control.pause() + manager._pause_events["download-1"] = pause_control + manager._active_downloads["download-1"] = { + "status": "downloading", + "bytes_per_second": 42.0, + } + + dummy_scanner = SimpleNamespace(add_model_to_cache=AsyncMock(return_value=None)) + monkeypatch.setattr( + DownloadManager, "_get_lora_scanner", AsyncMock(return_value=dummy_scanner) + ) + monkeypatch.setattr(MetadataManager, "save_metadata", AsyncMock(return_value=True)) + + started = asyncio.Event() + allow_finish = asyncio.Event() + captured = {"calls": 0} + + async def fake_download_model_file( + self, + download_url, + save_path, + *, + backend, + progress_callback, + use_auth, + download_id, + pause_control, + ): + captured["calls"] += 1 + started.set() + await allow_finish.wait() + Path(save_path).write_text("content") + return True, save_path + + monkeypatch.setattr( + DownloadManager, + "_download_model_file", + fake_download_model_file, + ) + + task = asyncio.create_task( + manager._execute_download( + download_urls=["https://civitai.com/api/download/models/1"], + save_dir=str(save_dir), + metadata=DummyMetadata(target_path), + version_info={"images": []}, + relative_path="", + progress_callback=None, + model_type="lora", + download_id="download-1", + transfer_backend="aria2", + ) + ) + + await asyncio.sleep(0) + assert started.is_set() is False + assert captured["calls"] == 0 + assert manager._active_downloads["download-1"]["status"] == "paused" + + pause_control.resume() + await asyncio.wait_for(started.wait(), timeout=1.0) + assert captured["calls"] == 1 + assert manager._active_downloads["download-1"]["status"] == "downloading" + + allow_finish.set() + result = await task + + assert result == {"success": True} + + +@pytest.mark.asyncio +async def test_execute_download_reuses_existing_aria2_partial_path(monkeypatch, tmp_path): + manager = DownloadManager() + + save_dir = tmp_path / "downloads" + save_dir.mkdir() + target_path = save_dir / "file.safetensors" + target_path.write_text("partial") + control_path = save_dir / "file.safetensors.aria2" + control_path.write_text("control") + + await manager._aria2_state_store.upsert( + "download-1", + { + "download_id": "download-1", + "transfer_backend": "aria2", + "save_path": str(target_path), + "file_path": str(target_path), + "status": "paused", + }, + ) + + class DummyMetadata: + def __init__(self, path: Path): + self.file_path = str(path) + self.sha256 = "sha256" + self.file_name = path.stem + self.preview_url = None + + def generate_unique_filename(self, *_args, **_kwargs): + return "renamed.safetensors" + + def update_file_info(self, _path): + return None + + def to_dict(self): + return {"file_path": self.file_path} + + manager._active_downloads["download-1"] = {"transfer_backend": "aria2"} + dummy_scanner = SimpleNamespace(add_model_to_cache=AsyncMock(return_value=None)) + monkeypatch.setattr( + DownloadManager, "_get_lora_scanner", AsyncMock(return_value=dummy_scanner) + ) + monkeypatch.setattr(MetadataManager, "save_metadata", AsyncMock(return_value=True)) + + async def fake_download_model_file( + self, + download_url, + save_path, + *, + backend, + progress_callback, + use_auth, + download_id, + pause_control, + ): + Path(save_path).write_text("content") + return True, save_path + + monkeypatch.setattr(DownloadManager, "_download_model_file", fake_download_model_file) + + result = await manager._execute_download( + download_urls=["https://example.com/file.safetensors"], + save_dir=str(save_dir), + metadata=DummyMetadata(target_path), + version_info={"images": []}, + relative_path="", + progress_callback=None, + model_type="lora", + download_id="download-1", + transfer_backend="aria2", + ) + + assert result == {"success": True} + assert manager._active_downloads["download-1"]["file_path"] == str(target_path) + assert not (save_dir / "renamed.safetensors").exists() + + +@pytest.mark.asyncio +async def test_execute_download_rejects_conflicting_aria2_partial_path(tmp_path): + manager = DownloadManager() + + save_dir = tmp_path / "downloads" + save_dir.mkdir() + target_path = save_dir / "file.safetensors" + target_path.write_text("partial") + (save_dir / "file.safetensors.aria2").write_text("control") + + await manager._aria2_state_store.upsert( + "other-download", + { + "download_id": "other-download", + "transfer_backend": "aria2", + "save_path": str(target_path), + "file_path": str(target_path), + "status": "paused", + }, + ) + + class DummyMetadata: + def __init__(self, path: Path): + self.file_path = str(path) + self.sha256 = "sha256" + self.file_name = path.stem + self.preview_url = None + + def generate_unique_filename(self, *_args, **_kwargs): + raise AssertionError("should not rename") + + result = await manager._execute_download( + download_urls=["https://example.com/file.safetensors"], + save_dir=str(save_dir), + metadata=DummyMetadata(target_path), + version_info={"images": []}, + relative_path="", + progress_callback=None, + model_type="lora", + download_id="download-1", + transfer_backend="aria2", + ) + + assert result["success"] is False + assert "already using" in result["error"] + + +@pytest.mark.asyncio +async def test_execute_download_reassigns_same_aria2_partial_to_new_download_id( + monkeypatch, tmp_path +): + manager = DownloadManager() + + save_dir = tmp_path / "downloads" + save_dir.mkdir() + target_path = save_dir / "file.safetensors" + target_path.write_text("partial") + (save_dir / "file.safetensors.aria2").write_text("control") + + await manager._aria2_state_store.upsert( + "old-download", + { + "download_id": "old-download", + "transfer_backend": "aria2", + "save_path": str(target_path), + "file_path": str(target_path), + "status": "paused", + "model_id": 11, + "model_version_id": 22, + }, + ) + + class DummyMetadata: + def __init__(self, path: Path): + self.file_path = str(path) + self.sha256 = "sha256" + self.file_name = path.stem + self.preview_url = None + + def generate_unique_filename(self, *_args, **_kwargs): + raise AssertionError("should not rename") + + def update_file_info(self, _path): + return None + + def to_dict(self): + return {"file_path": self.file_path} + + class DummyAria2Downloader: + def __init__(self): + self.calls = [] + + async def reassign_transfer(self, previous_download_id, new_download_id): + self.calls.append(("reassign_transfer", previous_download_id, new_download_id)) + return None + + dummy_aria2 = DummyAria2Downloader() + monkeypatch.setattr( + download_manager, + "get_aria2_downloader", + AsyncMock(return_value=dummy_aria2), + ) + + manager._active_downloads["old-download"] = { + "transfer_backend": "aria2", + "model_id": 11, + "model_version_id": 22, + "status": "paused", + } + manager._active_downloads["new-download"] = { + "transfer_backend": "aria2", + "model_id": 11, + "model_version_id": 22, + "status": "queued", + } + + resolved, path = await manager._resolve_download_target_path( + str(save_dir), + DummyMetadata(target_path), + transfer_backend="aria2", + download_id="new-download", + ) + + assert resolved is True + assert path == str(target_path) + assert "old-download" not in manager._active_downloads + assert manager._active_downloads["new-download"]["file_path"] == str(target_path) + assert dummy_aria2.calls == [("reassign_transfer", "old-download", "new-download")] + assert await manager._aria2_state_store.get("old-download") is None + assert (await manager._aria2_state_store.get("new-download"))["save_path"] == str( + target_path + ) + + +def test_is_same_aria2_download_request_requires_version_id_match(): + manager = DownloadManager() + + assert ( + manager._is_same_aria2_download_request( + {"model_id": 1, "model_version_id": None}, + {"model_id": 1, "model_version_id": 2}, + ) + is False + ) + assert ( + manager._is_same_aria2_download_request( + {"model_id": 1, "model_version_id": 3}, + {"model_id": 1, "model_version_id": None}, + ) + is False + ) + + +@pytest.mark.asyncio +async def test_adopt_existing_aria2_download_cancels_old_running_task(monkeypatch, tmp_path): + manager = DownloadManager() + save_path = tmp_path / "file.safetensors" + + started = asyncio.Event() + cancelled = asyncio.Event() + call_order = [] + + async def old_download(): + started.set() + try: + await asyncio.sleep(60) + except asyncio.CancelledError: + call_order.append("old-task-cancelled") + cancelled.set() + raise + + old_task = asyncio.create_task(old_download()) + await started.wait() + + manager._download_tasks["old-download"] = old_task + old_pause_control = DownloadStreamControl() + old_pause_control.pause() + manager._pause_events["old-download"] = old_pause_control + manager._active_downloads["old-download"] = { + "transfer_backend": "aria2", + "model_id": 11, + "model_version_id": 22, + "status": "downloading", + } + manager._active_downloads["new-download"] = { + "transfer_backend": "aria2", + "model_id": 11, + "model_version_id": 22, + "status": "queued", + } + + await manager._aria2_state_store.upsert( + "old-download", + { + "download_id": "old-download", + "transfer_backend": "aria2", + "save_path": str(save_path), + "file_path": str(save_path), + "status": "downloading", + "model_id": 11, + "model_version_id": 22, + }, + ) + + class DummyAria2Downloader: + async def reassign_transfer(self, previous_download_id, new_download_id): + call_order.append("reassign-transfer") + return None + + monkeypatch.setattr( + download_manager, + "get_aria2_downloader", + AsyncMock(return_value=DummyAria2Downloader()), + ) + + await manager._adopt_existing_aria2_download( + "old-download", + "new-download", + {"model_id": 11, "model_version_id": 22}, + str(save_path), + ) + + assert cancelled.is_set() is True + assert "old-download" not in manager._download_tasks + assert call_order == ["reassign-transfer", "old-task-cancelled"] + + @pytest.mark.asyncio async def test_pause_download_rejects_unknown_task(): """Test that pause_download rejects unknown download tasks.""" diff --git a/tests/services/test_settings_manager.py b/tests/services/test_settings_manager.py index abb1fb97..7b14db91 100644 --- a/tests/services/test_settings_manager.py +++ b/tests/services/test_settings_manager.py @@ -147,6 +147,11 @@ def test_environment_variable_overrides_settings(tmp_path, monkeypatch): assert mgr.get("civitai_api_key") == "secret" +def test_default_download_backend_is_python(manager): + assert manager.get("download_backend") == "python" + assert manager.get("aria2c_path") == "" + + def _create_manager_with_settings( tmp_path, monkeypatch, initial_settings, *, save_spy=None ): @@ -327,6 +332,43 @@ def test_auto_set_default_roots_keeps_valid_values(manager): assert manager.get("default_embedding_root") == "/embeddings" +def test_auto_set_default_roots_keeps_valid_extra_values(manager): + manager.settings["default_lora_root"] = "/extra-loras" + manager.settings["default_checkpoint_root"] = "/extra-checkpoints" + manager.settings["default_embedding_root"] = "/extra-embeddings" + manager.settings["default_unet_root"] = "/extra-unet" + + manager.settings["folder_paths"] = { + "loras": ["/loras"], + "checkpoints": ["/checkpoints"], + "unet": ["/unet"], + "embeddings": ["/embeddings"], + } + manager.settings["extra_folder_paths"] = { + "loras": ["/extra-loras"], + "checkpoints": ["/extra-checkpoints"], + "unet": ["/extra-unet"], + "embeddings": ["/extra-embeddings"], + } + + manager._auto_set_default_roots() + + assert manager.get("default_lora_root") == "/extra-loras" + assert manager.get("default_checkpoint_root") == "/extra-checkpoints" + assert manager.get("default_unet_root") == "/extra-unet" + assert manager.get("default_embedding_root") == "/extra-embeddings" + + +def test_auto_set_default_roots_falls_back_to_extra_when_primary_missing(manager): + manager.settings["default_lora_root"] = "" + manager.settings["folder_paths"] = {"loras": []} + manager.settings["extra_folder_paths"] = {"loras": ["/extra-loras"]} + + manager._auto_set_default_roots() + + assert manager.get("default_lora_root") == "/extra-loras" + + def test_delete_setting(manager): manager.set("example", 1) manager.delete("example")