mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-05-06 16:36:45 -03:00
feat(download): add experimental aria2 backend
This commit is contained in:
497
py/services/aria2_downloader.py
Normal file
497
py/services/aria2_downloader.py
Normal file
@@ -0,0 +1,497 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import secrets
|
||||
import shutil
|
||||
import socket
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
|
||||
import aiohttp
|
||||
|
||||
from .downloader import DownloadProgress, get_downloader
|
||||
from .settings_manager import get_settings_manager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
CIVITAI_DOWNLOAD_URL_PREFIXES = (
|
||||
"https://civitai.com/api/download/",
|
||||
"https://civitai.red/api/download/",
|
||||
)
|
||||
|
||||
|
||||
class Aria2Error(RuntimeError):
|
||||
"""Raised when aria2 integration fails."""
|
||||
|
||||
|
||||
@dataclass
|
||||
class Aria2Transfer:
|
||||
"""Track an aria2 download registered by the Python coordinator."""
|
||||
|
||||
gid: str
|
||||
save_path: str
|
||||
|
||||
|
||||
class Aria2Downloader:
|
||||
"""Manage an aria2 RPC daemon for experimental model downloads."""
|
||||
|
||||
_instance = None
|
||||
_lock = asyncio.Lock()
|
||||
|
||||
@classmethod
|
||||
async def get_instance(cls) -> "Aria2Downloader":
|
||||
async with cls._lock:
|
||||
if cls._instance is None:
|
||||
cls._instance = cls()
|
||||
return cls._instance
|
||||
|
||||
def __init__(self) -> None:
|
||||
if hasattr(self, "_initialized"):
|
||||
return
|
||||
|
||||
self._initialized = True
|
||||
self._process: Optional[asyncio.subprocess.Process] = None
|
||||
self._rpc_port: Optional[int] = None
|
||||
self._rpc_secret = ""
|
||||
self._rpc_url = ""
|
||||
self._rpc_session: Optional[aiohttp.ClientSession] = None
|
||||
self._rpc_session_lock = asyncio.Lock()
|
||||
self._process_lock = asyncio.Lock()
|
||||
self._transfers: Dict[str, Aria2Transfer] = {}
|
||||
self._poll_interval = 0.5
|
||||
|
||||
@property
|
||||
def is_running(self) -> bool:
|
||||
return self._process is not None and self._process.returncode is None
|
||||
|
||||
async def download_file(
|
||||
self,
|
||||
url: str,
|
||||
save_path: str,
|
||||
*,
|
||||
download_id: str,
|
||||
progress_callback=None,
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
) -> Tuple[bool, str]:
|
||||
"""Download a file using aria2 RPC and wait for completion."""
|
||||
|
||||
await self._ensure_process()
|
||||
save_path = os.path.abspath(save_path)
|
||||
save_dir = os.path.dirname(save_path)
|
||||
out_name = os.path.basename(save_path)
|
||||
|
||||
Path(save_dir).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
resolved_url = url
|
||||
request_headers = headers
|
||||
if headers and url.startswith(CIVITAI_DOWNLOAD_URL_PREFIXES):
|
||||
resolved_url = await self._resolve_authenticated_redirect_url(url, headers)
|
||||
if resolved_url != url:
|
||||
request_headers = None
|
||||
logger.debug(
|
||||
"Resolved Civitai download %s to signed URL for aria2",
|
||||
download_id,
|
||||
)
|
||||
|
||||
options: Dict[str, str] = {
|
||||
"dir": save_dir,
|
||||
"out": out_name,
|
||||
"continue": "true",
|
||||
"max-connection-per-server": "4",
|
||||
"split": "4",
|
||||
"min-split-size": "1M",
|
||||
"allow-overwrite": "true",
|
||||
"auto-file-renaming": "false",
|
||||
"file-allocation": "none",
|
||||
}
|
||||
if request_headers:
|
||||
options["header"] = [
|
||||
f"{key}: {value}" for key, value in request_headers.items()
|
||||
]
|
||||
|
||||
logger.debug(
|
||||
"Submitting aria2 download %s -> %s (auth=%s, civitai_signed=%s)",
|
||||
download_id,
|
||||
save_path,
|
||||
bool(request_headers),
|
||||
resolved_url != url,
|
||||
)
|
||||
|
||||
try:
|
||||
gid = await self._rpc_call("aria2.addUri", [[resolved_url], options])
|
||||
except Exception as exc:
|
||||
raise Aria2Error(f"Failed to schedule aria2 download: {exc}") from exc
|
||||
|
||||
logger.debug("aria2 accepted download %s with gid %s", download_id, gid)
|
||||
|
||||
self._transfers[download_id] = Aria2Transfer(gid=gid, save_path=save_path)
|
||||
|
||||
try:
|
||||
while True:
|
||||
status = await self.get_status(download_id)
|
||||
if status is None:
|
||||
return False, "aria2 download not found"
|
||||
|
||||
snapshot = self._build_progress_snapshot(status)
|
||||
if progress_callback is not None:
|
||||
await self._dispatch_progress(progress_callback, snapshot)
|
||||
|
||||
state = status.get("status", "")
|
||||
if state == "complete":
|
||||
completed_path = self._resolve_completed_path(status, save_path)
|
||||
return True, completed_path
|
||||
if state == "error":
|
||||
return False, status.get("errorMessage") or "aria2 download failed"
|
||||
if state == "removed":
|
||||
return False, "Download was cancelled"
|
||||
|
||||
await asyncio.sleep(self._poll_interval)
|
||||
finally:
|
||||
self._transfers.pop(download_id, None)
|
||||
|
||||
async def get_status(self, download_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""Return the raw aria2 status payload for a known download."""
|
||||
|
||||
transfer = self._transfers.get(download_id)
|
||||
if transfer is None:
|
||||
return None
|
||||
|
||||
keys = [
|
||||
"gid",
|
||||
"status",
|
||||
"totalLength",
|
||||
"completedLength",
|
||||
"downloadSpeed",
|
||||
"errorMessage",
|
||||
"files",
|
||||
]
|
||||
try:
|
||||
status = await self._rpc_call("aria2.tellStatus", [transfer.gid, keys])
|
||||
except Exception as exc:
|
||||
raise Aria2Error(f"Failed to query aria2 download status: {exc}") from exc
|
||||
|
||||
if isinstance(status, dict):
|
||||
return status
|
||||
return None
|
||||
|
||||
async def has_transfer(self, download_id: str) -> bool:
|
||||
return download_id in self._transfers
|
||||
|
||||
async def pause_download(self, download_id: str) -> Dict[str, Any]:
|
||||
transfer = self._transfers.get(download_id)
|
||||
if transfer is None:
|
||||
return {"success": False, "error": "Download task not found"}
|
||||
|
||||
try:
|
||||
await self._rpc_call("aria2.forcePause", [transfer.gid])
|
||||
except Exception as exc:
|
||||
return {"success": False, "error": str(exc)}
|
||||
|
||||
return {"success": True, "message": "Download paused successfully"}
|
||||
|
||||
async def resume_download(self, download_id: str) -> Dict[str, Any]:
|
||||
transfer = self._transfers.get(download_id)
|
||||
if transfer is None:
|
||||
return {"success": False, "error": "Download task not found"}
|
||||
|
||||
try:
|
||||
await self._rpc_call("aria2.unpause", [transfer.gid])
|
||||
except Exception as exc:
|
||||
return {"success": False, "error": str(exc)}
|
||||
|
||||
return {"success": True, "message": "Download resumed successfully"}
|
||||
|
||||
async def cancel_download(self, download_id: str) -> Dict[str, Any]:
|
||||
transfer = self._transfers.get(download_id)
|
||||
if transfer is None:
|
||||
return {"success": False, "error": "Download task not found"}
|
||||
|
||||
try:
|
||||
await self._rpc_call("aria2.forceRemove", [transfer.gid])
|
||||
except Exception as exc:
|
||||
return {"success": False, "error": str(exc)}
|
||||
|
||||
return {"success": True, "message": "Download cancelled successfully"}
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Shut down the RPC process and session."""
|
||||
|
||||
if self._rpc_session is not None:
|
||||
await self._rpc_session.close()
|
||||
self._rpc_session = None
|
||||
|
||||
process = self._process
|
||||
self._process = None
|
||||
self._transfers.clear()
|
||||
|
||||
if process is None:
|
||||
return
|
||||
|
||||
if process.returncode is None:
|
||||
process.terminate()
|
||||
try:
|
||||
await asyncio.wait_for(process.wait(), timeout=5.0)
|
||||
except asyncio.TimeoutError:
|
||||
process.kill()
|
||||
await process.wait()
|
||||
|
||||
async def _dispatch_progress(self, callback, snapshot: DownloadProgress) -> None:
|
||||
try:
|
||||
result = callback(snapshot, snapshot)
|
||||
except TypeError:
|
||||
result = callback(snapshot.percent_complete)
|
||||
|
||||
if asyncio.iscoroutine(result):
|
||||
await result
|
||||
elif hasattr(result, "__await__"):
|
||||
await result
|
||||
|
||||
def _build_progress_snapshot(self, status: Dict[str, Any]) -> DownloadProgress:
|
||||
completed = self._parse_int(status.get("completedLength"))
|
||||
total = self._parse_int(status.get("totalLength"))
|
||||
speed = float(self._parse_int(status.get("downloadSpeed")))
|
||||
percent = 0.0
|
||||
if total > 0:
|
||||
percent = (completed / total) * 100.0
|
||||
|
||||
return DownloadProgress(
|
||||
percent_complete=max(0.0, min(percent, 100.0)),
|
||||
bytes_downloaded=completed,
|
||||
total_bytes=total or None,
|
||||
bytes_per_second=speed,
|
||||
timestamp=datetime.now().timestamp(),
|
||||
)
|
||||
|
||||
def _resolve_completed_path(self, status: Dict[str, Any], default_path: str) -> str:
|
||||
files = status.get("files")
|
||||
if isinstance(files, list) and files:
|
||||
first = files[0]
|
||||
if isinstance(first, dict):
|
||||
candidate = first.get("path")
|
||||
if isinstance(candidate, str) and candidate:
|
||||
return candidate
|
||||
return default_path
|
||||
|
||||
@staticmethod
|
||||
def _parse_int(value: Any) -> int:
|
||||
try:
|
||||
return int(value)
|
||||
except (TypeError, ValueError):
|
||||
return 0
|
||||
|
||||
async def _resolve_authenticated_redirect_url(
|
||||
self,
|
||||
url: str,
|
||||
headers: Dict[str, str],
|
||||
) -> str:
|
||||
downloader = await get_downloader()
|
||||
session = await downloader.session
|
||||
request_headers = dict(downloader.default_headers)
|
||||
request_headers.update(headers)
|
||||
request_headers["Accept-Encoding"] = "identity"
|
||||
|
||||
try:
|
||||
async with session.get(
|
||||
url,
|
||||
headers=request_headers,
|
||||
allow_redirects=False,
|
||||
proxy=downloader.proxy_url,
|
||||
) as response:
|
||||
if response.status in {301, 302, 303, 307, 308}:
|
||||
location = response.headers.get("Location")
|
||||
if location:
|
||||
return location
|
||||
raise Aria2Error(
|
||||
"Authenticated Civitai redirect did not include a Location header"
|
||||
)
|
||||
|
||||
if response.status == 200:
|
||||
return url
|
||||
|
||||
body = await response.text()
|
||||
raise Aria2Error(
|
||||
f"Failed to resolve authenticated Civitai redirect: status={response.status} body={body[:300]}"
|
||||
)
|
||||
except aiohttp.ClientError as exc:
|
||||
raise Aria2Error(
|
||||
f"Failed to resolve authenticated Civitai redirect: {exc}"
|
||||
) from exc
|
||||
|
||||
async def _ensure_process(self) -> None:
|
||||
async with self._process_lock:
|
||||
if self.is_running and await self._ping():
|
||||
return
|
||||
|
||||
await self.close()
|
||||
|
||||
executable = self._resolve_executable()
|
||||
self._rpc_port = self._find_free_port()
|
||||
self._rpc_secret = secrets.token_hex(16)
|
||||
self._rpc_url = f"http://127.0.0.1:{self._rpc_port}/jsonrpc"
|
||||
|
||||
command = [
|
||||
executable,
|
||||
"--enable-rpc=true",
|
||||
"--rpc-listen-all=false",
|
||||
f"--rpc-listen-port={self._rpc_port}",
|
||||
f"--rpc-secret={self._rpc_secret}",
|
||||
"--check-certificate=true",
|
||||
"--allow-overwrite=true",
|
||||
"--auto-file-renaming=false",
|
||||
"--file-allocation=none",
|
||||
"--max-concurrent-downloads=5",
|
||||
"--continue=true",
|
||||
"--daemon=false",
|
||||
"--quiet=true",
|
||||
f"--stop-with-process={os.getpid()}",
|
||||
]
|
||||
|
||||
logger.info("Starting aria2 RPC daemon from %s", executable)
|
||||
self._process = await asyncio.create_subprocess_exec(
|
||||
*command,
|
||||
stdout=asyncio.subprocess.DEVNULL,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
)
|
||||
|
||||
await self._wait_until_ready()
|
||||
|
||||
def _resolve_executable(self) -> str:
|
||||
settings = get_settings_manager()
|
||||
configured_path = (settings.get("aria2c_path") or "").strip()
|
||||
candidate = configured_path or "aria2c"
|
||||
|
||||
resolved = shutil.which(candidate)
|
||||
if resolved:
|
||||
return resolved
|
||||
|
||||
if configured_path and os.path.isfile(configured_path) and os.access(
|
||||
configured_path, os.X_OK
|
||||
):
|
||||
return configured_path
|
||||
|
||||
raise Aria2Error(
|
||||
"aria2c executable was not found. Install aria2 or configure aria2c_path."
|
||||
)
|
||||
|
||||
async def _wait_until_ready(self) -> None:
|
||||
assert self._process is not None
|
||||
|
||||
start_time = asyncio.get_running_loop().time()
|
||||
last_error = ""
|
||||
while asyncio.get_running_loop().time() - start_time < 10.0:
|
||||
if self._process.returncode is not None:
|
||||
stderr_output = ""
|
||||
if self._process.stderr is not None:
|
||||
try:
|
||||
stderr_output = (
|
||||
await asyncio.wait_for(self._process.stderr.read(), timeout=0.2)
|
||||
).decode("utf-8", errors="replace")
|
||||
except Exception:
|
||||
stderr_output = ""
|
||||
raise Aria2Error(
|
||||
f"aria2 RPC process exited early with code {self._process.returncode}: {stderr_output.strip()}"
|
||||
)
|
||||
|
||||
try:
|
||||
if await self._ping():
|
||||
return
|
||||
except Exception as exc: # pragma: no cover - startup race
|
||||
last_error = str(exc)
|
||||
|
||||
await asyncio.sleep(0.2)
|
||||
|
||||
raise Aria2Error(
|
||||
f"Timed out waiting for aria2 RPC to become ready{': ' + last_error if last_error else ''}"
|
||||
)
|
||||
|
||||
async def _ping(self) -> bool:
|
||||
try:
|
||||
result = await self._rpc_call("aria2.getVersion", [])
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
return isinstance(result, dict)
|
||||
|
||||
async def _rpc_call(self, method: str, params: list[Any]) -> Any:
|
||||
if not self._rpc_url:
|
||||
raise Aria2Error("aria2 RPC endpoint is not initialized")
|
||||
|
||||
session = await self._get_rpc_session()
|
||||
payload = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": secrets.token_hex(8),
|
||||
"method": method,
|
||||
"params": [f"token:{self._rpc_secret}", *params],
|
||||
}
|
||||
|
||||
async with session.post(self._rpc_url, json=payload) as response:
|
||||
text = await response.text()
|
||||
|
||||
try:
|
||||
body = json.loads(text)
|
||||
except json.JSONDecodeError:
|
||||
body = None
|
||||
|
||||
if body is None:
|
||||
if response.status != 200:
|
||||
raise Aria2Error(
|
||||
f"aria2 RPC returned status {response.status} with non-JSON body: {text}"
|
||||
)
|
||||
raise Aria2Error(f"Invalid aria2 RPC response: {text}")
|
||||
|
||||
if "error" in body:
|
||||
error = body["error"] or {}
|
||||
code = error.get("code") if isinstance(error, dict) else None
|
||||
message = error.get("message") if isinstance(error, dict) else str(error)
|
||||
logger.error(
|
||||
"aria2 RPC %s failed with HTTP %s, code=%s, message=%s",
|
||||
method,
|
||||
response.status,
|
||||
code,
|
||||
message,
|
||||
)
|
||||
status_message = (
|
||||
f"aria2 RPC {method} failed with status {response.status}: {message}"
|
||||
if response.status != 200
|
||||
else message
|
||||
)
|
||||
raise Aria2Error(status_message or "Unknown aria2 RPC error")
|
||||
|
||||
if response.status != 200:
|
||||
logger.error(
|
||||
"aria2 RPC %s returned unexpected HTTP status %s without error payload: %s",
|
||||
method,
|
||||
response.status,
|
||||
body,
|
||||
)
|
||||
raise Aria2Error(
|
||||
f"aria2 RPC {method} returned unexpected status {response.status}"
|
||||
)
|
||||
|
||||
return body.get("result")
|
||||
|
||||
async def _get_rpc_session(self) -> aiohttp.ClientSession:
|
||||
if self._rpc_session is None or self._rpc_session.closed:
|
||||
async with self._rpc_session_lock:
|
||||
if self._rpc_session is None or self._rpc_session.closed:
|
||||
timeout = aiohttp.ClientTimeout(total=30)
|
||||
self._rpc_session = aiohttp.ClientSession(timeout=timeout)
|
||||
return self._rpc_session
|
||||
|
||||
@staticmethod
|
||||
def _find_free_port() -> int:
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
|
||||
sock.bind(("127.0.0.1", 0))
|
||||
sock.listen(1)
|
||||
return int(sock.getsockname()[1])
|
||||
|
||||
|
||||
async def get_aria2_downloader() -> Aria2Downloader:
|
||||
"""Get the singleton aria2 downloader."""
|
||||
|
||||
return await Aria2Downloader.get_instance()
|
||||
@@ -5,6 +5,7 @@ import asyncio
|
||||
import inspect
|
||||
import shutil
|
||||
import zipfile
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from collections import OrderedDict
|
||||
import uuid
|
||||
from typing import Dict, List, Optional, Set, Tuple
|
||||
@@ -25,6 +26,7 @@ from .service_registry import ServiceRegistry
|
||||
from .settings_manager import get_settings_manager
|
||||
from .metadata_service import get_default_metadata_provider, get_metadata_provider
|
||||
from .downloader import get_downloader, DownloadProgress, DownloadStreamControl
|
||||
from .aria2_downloader import Aria2Error, get_aria2_downloader
|
||||
|
||||
# Download to temporary file first
|
||||
import tempfile
|
||||
@@ -60,6 +62,59 @@ class DownloadManager:
|
||||
self._download_semaphore = asyncio.Semaphore(5) # Limit concurrent downloads
|
||||
self._download_tasks = {} # download_id -> asyncio.Task
|
||||
self._pause_events: Dict[str, DownloadStreamControl] = {}
|
||||
self._archive_executor = ThreadPoolExecutor(
|
||||
max_workers=2, thread_name_prefix="lm-archive"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _get_model_download_backend() -> str:
|
||||
backend = (get_settings_manager().get("download_backend") or "python").strip()
|
||||
return backend.lower() or "python"
|
||||
|
||||
async def _download_model_file(
|
||||
self,
|
||||
download_url: str,
|
||||
save_path: str,
|
||||
*,
|
||||
backend: str,
|
||||
progress_callback,
|
||||
use_auth: bool,
|
||||
download_id: Optional[str],
|
||||
pause_control: Optional[DownloadStreamControl],
|
||||
) -> Tuple[bool, str]:
|
||||
if backend == "aria2":
|
||||
if not download_id:
|
||||
return False, "aria2 downloads require a tracked download_id"
|
||||
|
||||
headers: Dict[str, str] = {}
|
||||
if use_auth:
|
||||
api_key = (get_settings_manager().get("civitai_api_key") or "").strip()
|
||||
if api_key:
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
|
||||
try:
|
||||
aria2_downloader = await get_aria2_downloader()
|
||||
return await aria2_downloader.download_file(
|
||||
download_url,
|
||||
save_path,
|
||||
download_id=download_id,
|
||||
progress_callback=progress_callback,
|
||||
headers=headers or None,
|
||||
)
|
||||
except Aria2Error as exc:
|
||||
logger.error("aria2 download failed for %s: %s", download_url, exc)
|
||||
return False, str(exc)
|
||||
|
||||
download_kwargs = {
|
||||
"progress_callback": progress_callback,
|
||||
"use_auth": use_auth,
|
||||
}
|
||||
|
||||
if pause_control is not None:
|
||||
download_kwargs["pause_event"] = pause_control
|
||||
|
||||
downloader = await get_downloader()
|
||||
return await downloader.download_file(download_url, save_path, **download_kwargs)
|
||||
|
||||
async def _get_lora_scanner(self):
|
||||
"""Get the lora scanner from registry"""
|
||||
@@ -126,6 +181,7 @@ class DownloadManager:
|
||||
"model_version_id": model_version_id,
|
||||
"progress": 0,
|
||||
"status": "queued",
|
||||
"transfer_backend": self._get_model_download_backend(),
|
||||
"bytes_downloaded": 0,
|
||||
"total_bytes": None,
|
||||
"bytes_per_second": 0.0,
|
||||
@@ -240,6 +296,9 @@ class DownloadManager:
|
||||
tracking_callback,
|
||||
use_default_paths,
|
||||
task_id,
|
||||
self._active_downloads.get(task_id, {}).get(
|
||||
"transfer_backend", "python"
|
||||
),
|
||||
source,
|
||||
file_params,
|
||||
)
|
||||
@@ -294,6 +353,7 @@ class DownloadManager:
|
||||
progress_callback,
|
||||
use_default_paths,
|
||||
download_id=None,
|
||||
transfer_backend="python",
|
||||
source=None,
|
||||
file_params=None,
|
||||
):
|
||||
@@ -696,16 +756,27 @@ class DownloadManager:
|
||||
logger.info(f"Creating EmbeddingMetadata for {file_name}")
|
||||
|
||||
# 6. Start download process
|
||||
result = await self._execute_download(
|
||||
download_urls=download_urls,
|
||||
save_dir=save_dir,
|
||||
metadata=metadata,
|
||||
version_info=version_info,
|
||||
relative_path=relative_path,
|
||||
progress_callback=progress_callback,
|
||||
model_type=model_type,
|
||||
download_id=download_id,
|
||||
)
|
||||
execute_kwargs = {
|
||||
"download_urls": download_urls,
|
||||
"save_dir": save_dir,
|
||||
"metadata": metadata,
|
||||
"version_info": version_info,
|
||||
"relative_path": relative_path,
|
||||
"progress_callback": progress_callback,
|
||||
"model_type": model_type,
|
||||
"download_id": download_id,
|
||||
}
|
||||
execute_signature = inspect.signature(self._execute_download)
|
||||
if (
|
||||
"transfer_backend" in execute_signature.parameters
|
||||
or any(
|
||||
parameter.kind == inspect.Parameter.VAR_KEYWORD
|
||||
for parameter in execute_signature.parameters.values()
|
||||
)
|
||||
):
|
||||
execute_kwargs["transfer_backend"] = transfer_backend
|
||||
|
||||
result = await self._execute_download(**execute_kwargs)
|
||||
|
||||
if result.get("success", False):
|
||||
resolved_model_id = (
|
||||
@@ -965,6 +1036,7 @@ class DownloadManager:
|
||||
progress_callback=None,
|
||||
model_type: str = "lora",
|
||||
download_id: str = None,
|
||||
transfer_backend: Optional[str] = None,
|
||||
) -> Dict:
|
||||
"""Execute the actual download process including preview images and model files"""
|
||||
metadata_entries: List = []
|
||||
@@ -974,6 +1046,7 @@ class DownloadManager:
|
||||
preview_targets: List[str] = []
|
||||
preview_path: str | None = None
|
||||
preview_nsfw_level = 0
|
||||
transfer_backend = (transfer_backend or self._get_model_download_backend()).lower()
|
||||
try:
|
||||
# Extract original filename details
|
||||
original_filename = os.path.basename(metadata.file_path)
|
||||
@@ -1136,32 +1209,37 @@ class DownloadManager:
|
||||
if progress_callback:
|
||||
await progress_callback(3) # 3% progress after preview download
|
||||
|
||||
# Download model file with progress tracking using downloader
|
||||
downloader = await get_downloader()
|
||||
if pause_control is not None:
|
||||
pause_control.update_stall_timeout(downloader.stall_timeout)
|
||||
# Download model file with progress tracking using the configured backend
|
||||
downloader = None
|
||||
if transfer_backend == "python":
|
||||
downloader = await get_downloader()
|
||||
if pause_control is not None:
|
||||
pause_control.update_stall_timeout(downloader.stall_timeout)
|
||||
if pause_control is not None and pause_control.is_paused():
|
||||
if download_id and download_id in self._active_downloads:
|
||||
self._active_downloads[download_id]["status"] = "paused"
|
||||
self._active_downloads[download_id]["bytes_per_second"] = 0.0
|
||||
await pause_control.wait()
|
||||
if download_id and download_id in self._active_downloads:
|
||||
self._active_downloads[download_id]["status"] = "downloading"
|
||||
last_error = None
|
||||
for download_url in download_urls:
|
||||
download_url = normalize_civitai_download_url(download_url)
|
||||
use_auth = download_url.startswith(CIVITAI_DOWNLOAD_URL_PREFIXES)
|
||||
download_kwargs = {
|
||||
"progress_callback": lambda progress, snapshot=None: (
|
||||
success, result = await self._download_model_file(
|
||||
download_url,
|
||||
save_path,
|
||||
backend=transfer_backend,
|
||||
progress_callback=lambda progress, snapshot=None: (
|
||||
self._handle_download_progress(
|
||||
progress,
|
||||
progress_callback,
|
||||
snapshot,
|
||||
)
|
||||
),
|
||||
"use_auth": use_auth, # Only use authentication for Civitai downloads
|
||||
}
|
||||
|
||||
if pause_control is not None:
|
||||
download_kwargs["pause_event"] = pause_control
|
||||
|
||||
success, result = await downloader.download_file(
|
||||
download_url,
|
||||
save_path, # Use full path instead of separate dir and filename
|
||||
**download_kwargs,
|
||||
use_auth=use_auth,
|
||||
download_id=download_id,
|
||||
pause_control=pause_control,
|
||||
)
|
||||
|
||||
if success:
|
||||
@@ -1401,7 +1479,8 @@ class DownloadManager:
|
||||
extracted_files.append(dest_path)
|
||||
return extracted_files
|
||||
|
||||
return await asyncio.to_thread(_extract_sync)
|
||||
loop = asyncio.get_running_loop()
|
||||
return await loop.run_in_executor(self._archive_executor, _extract_sync)
|
||||
|
||||
async def _build_metadata_entries(
|
||||
self, base_metadata, file_paths: List[str]
|
||||
@@ -1511,8 +1590,28 @@ class DownloadManager:
|
||||
return {"success": False, "error": "Download task not found"}
|
||||
|
||||
try:
|
||||
# Get the task and cancel it
|
||||
task = self._download_tasks[download_id]
|
||||
backend = (
|
||||
self._active_downloads.get(download_id, {}).get("transfer_backend")
|
||||
or "python"
|
||||
)
|
||||
|
||||
if backend == "aria2":
|
||||
try:
|
||||
aria2_downloader = await get_aria2_downloader()
|
||||
cancel_result = await aria2_downloader.cancel_download(download_id)
|
||||
if (
|
||||
not cancel_result.get("success")
|
||||
and cancel_result.get("error") != "Download task not found"
|
||||
):
|
||||
return cancel_result
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"Failed to cancel aria2 transfer for %s, continuing with local task cancellation: %s",
|
||||
download_id,
|
||||
exc,
|
||||
)
|
||||
|
||||
task.cancel()
|
||||
|
||||
pause_control = self._pause_events.get(download_id)
|
||||
@@ -1613,6 +1712,28 @@ class DownloadManager:
|
||||
|
||||
pause_control.pause()
|
||||
|
||||
backend = (
|
||||
self._active_downloads.get(download_id, {}).get("transfer_backend")
|
||||
or "python"
|
||||
)
|
||||
if backend == "aria2":
|
||||
try:
|
||||
aria2_downloader = await get_aria2_downloader()
|
||||
if await aria2_downloader.has_transfer(download_id):
|
||||
result = await aria2_downloader.pause_download(download_id)
|
||||
if not result.get("success"):
|
||||
pause_control.resume()
|
||||
return result
|
||||
except Exception as exc:
|
||||
pause_control.resume()
|
||||
return {"success": False, "error": str(exc)}
|
||||
|
||||
download_info = self._active_downloads.get(download_id)
|
||||
if download_info is not None:
|
||||
download_info["status"] = "paused"
|
||||
download_info["bytes_per_second"] = 0.0
|
||||
return {"success": True, "message": "Download paused successfully"}
|
||||
|
||||
download_info = self._active_downloads.get(download_id)
|
||||
if download_info is not None:
|
||||
download_info["status"] = "paused"
|
||||
@@ -1631,6 +1752,28 @@ class DownloadManager:
|
||||
return {"success": False, "error": "Download is not paused"}
|
||||
|
||||
download_info = self._active_downloads.get(download_id)
|
||||
backend = (
|
||||
self._active_downloads.get(download_id, {}).get("transfer_backend")
|
||||
or "python"
|
||||
)
|
||||
if backend == "aria2":
|
||||
try:
|
||||
aria2_downloader = await get_aria2_downloader()
|
||||
if await aria2_downloader.has_transfer(download_id):
|
||||
result = await aria2_downloader.resume_download(download_id)
|
||||
if not result.get("success"):
|
||||
return result
|
||||
except Exception as exc:
|
||||
return {"success": False, "error": str(exc)}
|
||||
|
||||
pause_control.resume()
|
||||
|
||||
if download_info is not None:
|
||||
if download_info.get("status") == "paused":
|
||||
download_info["status"] = "downloading"
|
||||
download_info.setdefault("bytes_per_second", 0.0)
|
||||
return {"success": True, "message": "Download resumed successfully"}
|
||||
|
||||
force_reconnect = False
|
||||
if pause_control is not None:
|
||||
elapsed = pause_control.time_since_last_progress()
|
||||
|
||||
@@ -55,6 +55,8 @@ DEFAULT_KEYS_CLEANUP_THRESHOLD = 10
|
||||
DEFAULT_SETTINGS: Dict[str, Any] = {
|
||||
"civitai_api_key": "",
|
||||
"civitai_host": "civitai.com",
|
||||
"download_backend": "python",
|
||||
"aria2c_path": "",
|
||||
"use_portable_settings": False,
|
||||
"hash_chunk_size_mb": DEFAULT_HASH_CHUNK_SIZE_MB,
|
||||
"language": "en",
|
||||
|
||||
Reference in New Issue
Block a user