mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-05-06 16:36:45 -03:00
feat(download): add experimental aria2 backend
This commit is contained in:
@@ -5,6 +5,7 @@ import asyncio
|
||||
import inspect
|
||||
import shutil
|
||||
import zipfile
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from collections import OrderedDict
|
||||
import uuid
|
||||
from typing import Dict, List, Optional, Set, Tuple
|
||||
@@ -25,6 +26,7 @@ from .service_registry import ServiceRegistry
|
||||
from .settings_manager import get_settings_manager
|
||||
from .metadata_service import get_default_metadata_provider, get_metadata_provider
|
||||
from .downloader import get_downloader, DownloadProgress, DownloadStreamControl
|
||||
from .aria2_downloader import Aria2Error, get_aria2_downloader
|
||||
|
||||
# Download to temporary file first
|
||||
import tempfile
|
||||
@@ -60,6 +62,59 @@ class DownloadManager:
|
||||
self._download_semaphore = asyncio.Semaphore(5) # Limit concurrent downloads
|
||||
self._download_tasks = {} # download_id -> asyncio.Task
|
||||
self._pause_events: Dict[str, DownloadStreamControl] = {}
|
||||
self._archive_executor = ThreadPoolExecutor(
|
||||
max_workers=2, thread_name_prefix="lm-archive"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _get_model_download_backend() -> str:
|
||||
backend = (get_settings_manager().get("download_backend") or "python").strip()
|
||||
return backend.lower() or "python"
|
||||
|
||||
async def _download_model_file(
|
||||
self,
|
||||
download_url: str,
|
||||
save_path: str,
|
||||
*,
|
||||
backend: str,
|
||||
progress_callback,
|
||||
use_auth: bool,
|
||||
download_id: Optional[str],
|
||||
pause_control: Optional[DownloadStreamControl],
|
||||
) -> Tuple[bool, str]:
|
||||
if backend == "aria2":
|
||||
if not download_id:
|
||||
return False, "aria2 downloads require a tracked download_id"
|
||||
|
||||
headers: Dict[str, str] = {}
|
||||
if use_auth:
|
||||
api_key = (get_settings_manager().get("civitai_api_key") or "").strip()
|
||||
if api_key:
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
|
||||
try:
|
||||
aria2_downloader = await get_aria2_downloader()
|
||||
return await aria2_downloader.download_file(
|
||||
download_url,
|
||||
save_path,
|
||||
download_id=download_id,
|
||||
progress_callback=progress_callback,
|
||||
headers=headers or None,
|
||||
)
|
||||
except Aria2Error as exc:
|
||||
logger.error("aria2 download failed for %s: %s", download_url, exc)
|
||||
return False, str(exc)
|
||||
|
||||
download_kwargs = {
|
||||
"progress_callback": progress_callback,
|
||||
"use_auth": use_auth,
|
||||
}
|
||||
|
||||
if pause_control is not None:
|
||||
download_kwargs["pause_event"] = pause_control
|
||||
|
||||
downloader = await get_downloader()
|
||||
return await downloader.download_file(download_url, save_path, **download_kwargs)
|
||||
|
||||
async def _get_lora_scanner(self):
|
||||
"""Get the lora scanner from registry"""
|
||||
@@ -126,6 +181,7 @@ class DownloadManager:
|
||||
"model_version_id": model_version_id,
|
||||
"progress": 0,
|
||||
"status": "queued",
|
||||
"transfer_backend": self._get_model_download_backend(),
|
||||
"bytes_downloaded": 0,
|
||||
"total_bytes": None,
|
||||
"bytes_per_second": 0.0,
|
||||
@@ -240,6 +296,9 @@ class DownloadManager:
|
||||
tracking_callback,
|
||||
use_default_paths,
|
||||
task_id,
|
||||
self._active_downloads.get(task_id, {}).get(
|
||||
"transfer_backend", "python"
|
||||
),
|
||||
source,
|
||||
file_params,
|
||||
)
|
||||
@@ -294,6 +353,7 @@ class DownloadManager:
|
||||
progress_callback,
|
||||
use_default_paths,
|
||||
download_id=None,
|
||||
transfer_backend="python",
|
||||
source=None,
|
||||
file_params=None,
|
||||
):
|
||||
@@ -696,16 +756,27 @@ class DownloadManager:
|
||||
logger.info(f"Creating EmbeddingMetadata for {file_name}")
|
||||
|
||||
# 6. Start download process
|
||||
result = await self._execute_download(
|
||||
download_urls=download_urls,
|
||||
save_dir=save_dir,
|
||||
metadata=metadata,
|
||||
version_info=version_info,
|
||||
relative_path=relative_path,
|
||||
progress_callback=progress_callback,
|
||||
model_type=model_type,
|
||||
download_id=download_id,
|
||||
)
|
||||
execute_kwargs = {
|
||||
"download_urls": download_urls,
|
||||
"save_dir": save_dir,
|
||||
"metadata": metadata,
|
||||
"version_info": version_info,
|
||||
"relative_path": relative_path,
|
||||
"progress_callback": progress_callback,
|
||||
"model_type": model_type,
|
||||
"download_id": download_id,
|
||||
}
|
||||
execute_signature = inspect.signature(self._execute_download)
|
||||
if (
|
||||
"transfer_backend" in execute_signature.parameters
|
||||
or any(
|
||||
parameter.kind == inspect.Parameter.VAR_KEYWORD
|
||||
for parameter in execute_signature.parameters.values()
|
||||
)
|
||||
):
|
||||
execute_kwargs["transfer_backend"] = transfer_backend
|
||||
|
||||
result = await self._execute_download(**execute_kwargs)
|
||||
|
||||
if result.get("success", False):
|
||||
resolved_model_id = (
|
||||
@@ -965,6 +1036,7 @@ class DownloadManager:
|
||||
progress_callback=None,
|
||||
model_type: str = "lora",
|
||||
download_id: str = None,
|
||||
transfer_backend: Optional[str] = None,
|
||||
) -> Dict:
|
||||
"""Execute the actual download process including preview images and model files"""
|
||||
metadata_entries: List = []
|
||||
@@ -974,6 +1046,7 @@ class DownloadManager:
|
||||
preview_targets: List[str] = []
|
||||
preview_path: str | None = None
|
||||
preview_nsfw_level = 0
|
||||
transfer_backend = (transfer_backend or self._get_model_download_backend()).lower()
|
||||
try:
|
||||
# Extract original filename details
|
||||
original_filename = os.path.basename(metadata.file_path)
|
||||
@@ -1136,32 +1209,37 @@ class DownloadManager:
|
||||
if progress_callback:
|
||||
await progress_callback(3) # 3% progress after preview download
|
||||
|
||||
# Download model file with progress tracking using downloader
|
||||
downloader = await get_downloader()
|
||||
if pause_control is not None:
|
||||
pause_control.update_stall_timeout(downloader.stall_timeout)
|
||||
# Download model file with progress tracking using the configured backend
|
||||
downloader = None
|
||||
if transfer_backend == "python":
|
||||
downloader = await get_downloader()
|
||||
if pause_control is not None:
|
||||
pause_control.update_stall_timeout(downloader.stall_timeout)
|
||||
if pause_control is not None and pause_control.is_paused():
|
||||
if download_id and download_id in self._active_downloads:
|
||||
self._active_downloads[download_id]["status"] = "paused"
|
||||
self._active_downloads[download_id]["bytes_per_second"] = 0.0
|
||||
await pause_control.wait()
|
||||
if download_id and download_id in self._active_downloads:
|
||||
self._active_downloads[download_id]["status"] = "downloading"
|
||||
last_error = None
|
||||
for download_url in download_urls:
|
||||
download_url = normalize_civitai_download_url(download_url)
|
||||
use_auth = download_url.startswith(CIVITAI_DOWNLOAD_URL_PREFIXES)
|
||||
download_kwargs = {
|
||||
"progress_callback": lambda progress, snapshot=None: (
|
||||
success, result = await self._download_model_file(
|
||||
download_url,
|
||||
save_path,
|
||||
backend=transfer_backend,
|
||||
progress_callback=lambda progress, snapshot=None: (
|
||||
self._handle_download_progress(
|
||||
progress,
|
||||
progress_callback,
|
||||
snapshot,
|
||||
)
|
||||
),
|
||||
"use_auth": use_auth, # Only use authentication for Civitai downloads
|
||||
}
|
||||
|
||||
if pause_control is not None:
|
||||
download_kwargs["pause_event"] = pause_control
|
||||
|
||||
success, result = await downloader.download_file(
|
||||
download_url,
|
||||
save_path, # Use full path instead of separate dir and filename
|
||||
**download_kwargs,
|
||||
use_auth=use_auth,
|
||||
download_id=download_id,
|
||||
pause_control=pause_control,
|
||||
)
|
||||
|
||||
if success:
|
||||
@@ -1401,7 +1479,8 @@ class DownloadManager:
|
||||
extracted_files.append(dest_path)
|
||||
return extracted_files
|
||||
|
||||
return await asyncio.to_thread(_extract_sync)
|
||||
loop = asyncio.get_running_loop()
|
||||
return await loop.run_in_executor(self._archive_executor, _extract_sync)
|
||||
|
||||
async def _build_metadata_entries(
|
||||
self, base_metadata, file_paths: List[str]
|
||||
@@ -1511,8 +1590,28 @@ class DownloadManager:
|
||||
return {"success": False, "error": "Download task not found"}
|
||||
|
||||
try:
|
||||
# Get the task and cancel it
|
||||
task = self._download_tasks[download_id]
|
||||
backend = (
|
||||
self._active_downloads.get(download_id, {}).get("transfer_backend")
|
||||
or "python"
|
||||
)
|
||||
|
||||
if backend == "aria2":
|
||||
try:
|
||||
aria2_downloader = await get_aria2_downloader()
|
||||
cancel_result = await aria2_downloader.cancel_download(download_id)
|
||||
if (
|
||||
not cancel_result.get("success")
|
||||
and cancel_result.get("error") != "Download task not found"
|
||||
):
|
||||
return cancel_result
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"Failed to cancel aria2 transfer for %s, continuing with local task cancellation: %s",
|
||||
download_id,
|
||||
exc,
|
||||
)
|
||||
|
||||
task.cancel()
|
||||
|
||||
pause_control = self._pause_events.get(download_id)
|
||||
@@ -1613,6 +1712,28 @@ class DownloadManager:
|
||||
|
||||
pause_control.pause()
|
||||
|
||||
backend = (
|
||||
self._active_downloads.get(download_id, {}).get("transfer_backend")
|
||||
or "python"
|
||||
)
|
||||
if backend == "aria2":
|
||||
try:
|
||||
aria2_downloader = await get_aria2_downloader()
|
||||
if await aria2_downloader.has_transfer(download_id):
|
||||
result = await aria2_downloader.pause_download(download_id)
|
||||
if not result.get("success"):
|
||||
pause_control.resume()
|
||||
return result
|
||||
except Exception as exc:
|
||||
pause_control.resume()
|
||||
return {"success": False, "error": str(exc)}
|
||||
|
||||
download_info = self._active_downloads.get(download_id)
|
||||
if download_info is not None:
|
||||
download_info["status"] = "paused"
|
||||
download_info["bytes_per_second"] = 0.0
|
||||
return {"success": True, "message": "Download paused successfully"}
|
||||
|
||||
download_info = self._active_downloads.get(download_id)
|
||||
if download_info is not None:
|
||||
download_info["status"] = "paused"
|
||||
@@ -1631,6 +1752,28 @@ class DownloadManager:
|
||||
return {"success": False, "error": "Download is not paused"}
|
||||
|
||||
download_info = self._active_downloads.get(download_id)
|
||||
backend = (
|
||||
self._active_downloads.get(download_id, {}).get("transfer_backend")
|
||||
or "python"
|
||||
)
|
||||
if backend == "aria2":
|
||||
try:
|
||||
aria2_downloader = await get_aria2_downloader()
|
||||
if await aria2_downloader.has_transfer(download_id):
|
||||
result = await aria2_downloader.resume_download(download_id)
|
||||
if not result.get("success"):
|
||||
return result
|
||||
except Exception as exc:
|
||||
return {"success": False, "error": str(exc)}
|
||||
|
||||
pause_control.resume()
|
||||
|
||||
if download_info is not None:
|
||||
if download_info.get("status") == "paused":
|
||||
download_info["status"] = "downloading"
|
||||
download_info.setdefault("bytes_per_second", 0.0)
|
||||
return {"success": True, "message": "Download resumed successfully"}
|
||||
|
||||
force_reconnect = False
|
||||
if pause_control is not None:
|
||||
elapsed = pause_control.time_since_last_progress()
|
||||
|
||||
Reference in New Issue
Block a user