feat(download): add experimental aria2 backend

This commit is contained in:
Will Miao
2026-04-19 21:46:09 +08:00
parent 0ced53c059
commit 1c530ea013
21 changed files with 1867 additions and 28 deletions

View File

@@ -5,6 +5,7 @@ import asyncio
import inspect
import shutil
import zipfile
from concurrent.futures import ThreadPoolExecutor
from collections import OrderedDict
import uuid
from typing import Dict, List, Optional, Set, Tuple
@@ -25,6 +26,7 @@ from .service_registry import ServiceRegistry
from .settings_manager import get_settings_manager
from .metadata_service import get_default_metadata_provider, get_metadata_provider
from .downloader import get_downloader, DownloadProgress, DownloadStreamControl
from .aria2_downloader import Aria2Error, get_aria2_downloader
# Download to temporary file first
import tempfile
@@ -60,6 +62,59 @@ class DownloadManager:
self._download_semaphore = asyncio.Semaphore(5) # Limit concurrent downloads
self._download_tasks = {} # download_id -> asyncio.Task
self._pause_events: Dict[str, DownloadStreamControl] = {}
self._archive_executor = ThreadPoolExecutor(
max_workers=2, thread_name_prefix="lm-archive"
)
@staticmethod
def _get_model_download_backend() -> str:
backend = (get_settings_manager().get("download_backend") or "python").strip()
return backend.lower() or "python"
async def _download_model_file(
self,
download_url: str,
save_path: str,
*,
backend: str,
progress_callback,
use_auth: bool,
download_id: Optional[str],
pause_control: Optional[DownloadStreamControl],
) -> Tuple[bool, str]:
if backend == "aria2":
if not download_id:
return False, "aria2 downloads require a tracked download_id"
headers: Dict[str, str] = {}
if use_auth:
api_key = (get_settings_manager().get("civitai_api_key") or "").strip()
if api_key:
headers["Authorization"] = f"Bearer {api_key}"
try:
aria2_downloader = await get_aria2_downloader()
return await aria2_downloader.download_file(
download_url,
save_path,
download_id=download_id,
progress_callback=progress_callback,
headers=headers or None,
)
except Aria2Error as exc:
logger.error("aria2 download failed for %s: %s", download_url, exc)
return False, str(exc)
download_kwargs = {
"progress_callback": progress_callback,
"use_auth": use_auth,
}
if pause_control is not None:
download_kwargs["pause_event"] = pause_control
downloader = await get_downloader()
return await downloader.download_file(download_url, save_path, **download_kwargs)
async def _get_lora_scanner(self):
"""Get the lora scanner from registry"""
@@ -126,6 +181,7 @@ class DownloadManager:
"model_version_id": model_version_id,
"progress": 0,
"status": "queued",
"transfer_backend": self._get_model_download_backend(),
"bytes_downloaded": 0,
"total_bytes": None,
"bytes_per_second": 0.0,
@@ -240,6 +296,9 @@ class DownloadManager:
tracking_callback,
use_default_paths,
task_id,
self._active_downloads.get(task_id, {}).get(
"transfer_backend", "python"
),
source,
file_params,
)
@@ -294,6 +353,7 @@ class DownloadManager:
progress_callback,
use_default_paths,
download_id=None,
transfer_backend="python",
source=None,
file_params=None,
):
@@ -696,16 +756,27 @@ class DownloadManager:
logger.info(f"Creating EmbeddingMetadata for {file_name}")
# 6. Start download process
result = await self._execute_download(
download_urls=download_urls,
save_dir=save_dir,
metadata=metadata,
version_info=version_info,
relative_path=relative_path,
progress_callback=progress_callback,
model_type=model_type,
download_id=download_id,
)
execute_kwargs = {
"download_urls": download_urls,
"save_dir": save_dir,
"metadata": metadata,
"version_info": version_info,
"relative_path": relative_path,
"progress_callback": progress_callback,
"model_type": model_type,
"download_id": download_id,
}
execute_signature = inspect.signature(self._execute_download)
if (
"transfer_backend" in execute_signature.parameters
or any(
parameter.kind == inspect.Parameter.VAR_KEYWORD
for parameter in execute_signature.parameters.values()
)
):
execute_kwargs["transfer_backend"] = transfer_backend
result = await self._execute_download(**execute_kwargs)
if result.get("success", False):
resolved_model_id = (
@@ -965,6 +1036,7 @@ class DownloadManager:
progress_callback=None,
model_type: str = "lora",
download_id: str = None,
transfer_backend: Optional[str] = None,
) -> Dict:
"""Execute the actual download process including preview images and model files"""
metadata_entries: List = []
@@ -974,6 +1046,7 @@ class DownloadManager:
preview_targets: List[str] = []
preview_path: str | None = None
preview_nsfw_level = 0
transfer_backend = (transfer_backend or self._get_model_download_backend()).lower()
try:
# Extract original filename details
original_filename = os.path.basename(metadata.file_path)
@@ -1136,32 +1209,37 @@ class DownloadManager:
if progress_callback:
await progress_callback(3) # 3% progress after preview download
# Download model file with progress tracking using downloader
downloader = await get_downloader()
if pause_control is not None:
pause_control.update_stall_timeout(downloader.stall_timeout)
# Download model file with progress tracking using the configured backend
downloader = None
if transfer_backend == "python":
downloader = await get_downloader()
if pause_control is not None:
pause_control.update_stall_timeout(downloader.stall_timeout)
if pause_control is not None and pause_control.is_paused():
if download_id and download_id in self._active_downloads:
self._active_downloads[download_id]["status"] = "paused"
self._active_downloads[download_id]["bytes_per_second"] = 0.0
await pause_control.wait()
if download_id and download_id in self._active_downloads:
self._active_downloads[download_id]["status"] = "downloading"
last_error = None
for download_url in download_urls:
download_url = normalize_civitai_download_url(download_url)
use_auth = download_url.startswith(CIVITAI_DOWNLOAD_URL_PREFIXES)
download_kwargs = {
"progress_callback": lambda progress, snapshot=None: (
success, result = await self._download_model_file(
download_url,
save_path,
backend=transfer_backend,
progress_callback=lambda progress, snapshot=None: (
self._handle_download_progress(
progress,
progress_callback,
snapshot,
)
),
"use_auth": use_auth, # Only use authentication for Civitai downloads
}
if pause_control is not None:
download_kwargs["pause_event"] = pause_control
success, result = await downloader.download_file(
download_url,
save_path, # Use full path instead of separate dir and filename
**download_kwargs,
use_auth=use_auth,
download_id=download_id,
pause_control=pause_control,
)
if success:
@@ -1401,7 +1479,8 @@ class DownloadManager:
extracted_files.append(dest_path)
return extracted_files
return await asyncio.to_thread(_extract_sync)
loop = asyncio.get_running_loop()
return await loop.run_in_executor(self._archive_executor, _extract_sync)
async def _build_metadata_entries(
self, base_metadata, file_paths: List[str]
@@ -1511,8 +1590,28 @@ class DownloadManager:
return {"success": False, "error": "Download task not found"}
try:
# Get the task and cancel it
task = self._download_tasks[download_id]
backend = (
self._active_downloads.get(download_id, {}).get("transfer_backend")
or "python"
)
if backend == "aria2":
try:
aria2_downloader = await get_aria2_downloader()
cancel_result = await aria2_downloader.cancel_download(download_id)
if (
not cancel_result.get("success")
and cancel_result.get("error") != "Download task not found"
):
return cancel_result
except Exception as exc:
logger.warning(
"Failed to cancel aria2 transfer for %s, continuing with local task cancellation: %s",
download_id,
exc,
)
task.cancel()
pause_control = self._pause_events.get(download_id)
@@ -1613,6 +1712,28 @@ class DownloadManager:
pause_control.pause()
backend = (
self._active_downloads.get(download_id, {}).get("transfer_backend")
or "python"
)
if backend == "aria2":
try:
aria2_downloader = await get_aria2_downloader()
if await aria2_downloader.has_transfer(download_id):
result = await aria2_downloader.pause_download(download_id)
if not result.get("success"):
pause_control.resume()
return result
except Exception as exc:
pause_control.resume()
return {"success": False, "error": str(exc)}
download_info = self._active_downloads.get(download_id)
if download_info is not None:
download_info["status"] = "paused"
download_info["bytes_per_second"] = 0.0
return {"success": True, "message": "Download paused successfully"}
download_info = self._active_downloads.get(download_id)
if download_info is not None:
download_info["status"] = "paused"
@@ -1631,6 +1752,28 @@ class DownloadManager:
return {"success": False, "error": "Download is not paused"}
download_info = self._active_downloads.get(download_id)
backend = (
self._active_downloads.get(download_id, {}).get("transfer_backend")
or "python"
)
if backend == "aria2":
try:
aria2_downloader = await get_aria2_downloader()
if await aria2_downloader.has_transfer(download_id):
result = await aria2_downloader.resume_download(download_id)
if not result.get("success"):
return result
except Exception as exc:
return {"success": False, "error": str(exc)}
pause_control.resume()
if download_info is not None:
if download_info.get("status") == "paused":
download_info["status"] = "downloading"
download_info.setdefault("bytes_per_second", 0.0)
return {"success": True, "message": "Download resumed successfully"}
force_reconnect = False
if pause_control is not None:
elapsed = pause_control.time_since_last_progress()