mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-05-06 16:36:45 -03:00
fix(download): restore aria2 resume lifecycle
This commit is contained in:
@@ -15,6 +15,7 @@ from typing import Any, Dict, Optional, Tuple
|
||||
import aiohttp
|
||||
|
||||
from .downloader import DownloadProgress, get_downloader
|
||||
from .aria2_transfer_state import Aria2TransferStateStore
|
||||
from .settings_manager import get_settings_manager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -64,6 +65,7 @@ class Aria2Downloader:
|
||||
self._process_lock = asyncio.Lock()
|
||||
self._transfers: Dict[str, Aria2Transfer] = {}
|
||||
self._poll_interval = 0.5
|
||||
self._state_store = Aria2TransferStateStore()
|
||||
|
||||
@property
|
||||
def is_running(self) -> bool:
|
||||
@@ -82,6 +84,48 @@ class Aria2Downloader:
|
||||
|
||||
await self._ensure_process()
|
||||
save_path = os.path.abspath(save_path)
|
||||
transfer = self._transfers.get(download_id)
|
||||
if transfer is None or os.path.abspath(transfer.save_path) != save_path:
|
||||
gid = await self._schedule_download(
|
||||
url,
|
||||
save_path,
|
||||
download_id=download_id,
|
||||
headers=headers,
|
||||
)
|
||||
transfer = Aria2Transfer(gid=gid, save_path=save_path)
|
||||
self._transfers[download_id] = transfer
|
||||
|
||||
try:
|
||||
while True:
|
||||
status = await self.get_status(download_id)
|
||||
if status is None:
|
||||
return False, "aria2 download not found"
|
||||
|
||||
snapshot = self._build_progress_snapshot(status)
|
||||
if progress_callback is not None:
|
||||
await self._dispatch_progress(progress_callback, snapshot)
|
||||
|
||||
state = status.get("status", "")
|
||||
if state == "complete":
|
||||
completed_path = self._resolve_completed_path(status, save_path)
|
||||
return True, completed_path
|
||||
if state == "error":
|
||||
return False, status.get("errorMessage") or "aria2 download failed"
|
||||
if state == "removed":
|
||||
return False, "Download was cancelled"
|
||||
|
||||
await asyncio.sleep(self._poll_interval)
|
||||
finally:
|
||||
self._transfers.pop(download_id, None)
|
||||
|
||||
async def _schedule_download(
|
||||
self,
|
||||
url: str,
|
||||
save_path: str,
|
||||
*,
|
||||
download_id: str,
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
) -> str:
|
||||
save_dir = os.path.dirname(save_path)
|
||||
out_name = os.path.basename(save_path)
|
||||
|
||||
@@ -128,31 +172,16 @@ class Aria2Downloader:
|
||||
raise Aria2Error(f"Failed to schedule aria2 download: {exc}") from exc
|
||||
|
||||
logger.debug("aria2 accepted download %s with gid %s", download_id, gid)
|
||||
|
||||
self._transfers[download_id] = Aria2Transfer(gid=gid, save_path=save_path)
|
||||
|
||||
try:
|
||||
while True:
|
||||
status = await self.get_status(download_id)
|
||||
if status is None:
|
||||
return False, "aria2 download not found"
|
||||
|
||||
snapshot = self._build_progress_snapshot(status)
|
||||
if progress_callback is not None:
|
||||
await self._dispatch_progress(progress_callback, snapshot)
|
||||
|
||||
state = status.get("status", "")
|
||||
if state == "complete":
|
||||
completed_path = self._resolve_completed_path(status, save_path)
|
||||
return True, completed_path
|
||||
if state == "error":
|
||||
return False, status.get("errorMessage") or "aria2 download failed"
|
||||
if state == "removed":
|
||||
return False, "Download was cancelled"
|
||||
|
||||
await asyncio.sleep(self._poll_interval)
|
||||
finally:
|
||||
self._transfers.pop(download_id, None)
|
||||
await self._state_store.upsert(
|
||||
download_id,
|
||||
{
|
||||
"gid": gid,
|
||||
"save_path": save_path,
|
||||
"status": "downloading",
|
||||
"url": url,
|
||||
},
|
||||
)
|
||||
return gid
|
||||
|
||||
async def get_status(self, download_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""Return the raw aria2 status payload for a known download."""
|
||||
@@ -179,6 +208,47 @@ class Aria2Downloader:
|
||||
return status
|
||||
return None
|
||||
|
||||
async def get_status_by_gid(self, gid: str) -> Optional[Dict[str, Any]]:
|
||||
keys = [
|
||||
"gid",
|
||||
"status",
|
||||
"totalLength",
|
||||
"completedLength",
|
||||
"downloadSpeed",
|
||||
"errorMessage",
|
||||
"files",
|
||||
]
|
||||
try:
|
||||
status = await self._rpc_call("aria2.tellStatus", [gid, keys])
|
||||
except Exception as exc:
|
||||
message = str(exc)
|
||||
if "cannot be found" in message.lower() or "not found" in message.lower():
|
||||
return None
|
||||
raise Aria2Error(f"Failed to query aria2 download status: {exc}") from exc
|
||||
|
||||
if isinstance(status, dict):
|
||||
return status
|
||||
return None
|
||||
|
||||
async def restore_transfer(self, download_id: str, gid: str, save_path: str) -> None:
|
||||
await self._ensure_process()
|
||||
self._transfers[download_id] = Aria2Transfer(
|
||||
gid=gid,
|
||||
save_path=os.path.abspath(save_path),
|
||||
)
|
||||
|
||||
async def reassign_transfer(
|
||||
self, from_download_id: str, to_download_id: str
|
||||
) -> Optional[Aria2Transfer]:
|
||||
transfer = self._transfers.get(from_download_id)
|
||||
if transfer is None:
|
||||
return None
|
||||
|
||||
self._transfers[to_download_id] = transfer
|
||||
if from_download_id != to_download_id:
|
||||
self._transfers.pop(from_download_id, None)
|
||||
return transfer
|
||||
|
||||
async def has_transfer(self, download_id: str) -> bool:
|
||||
return download_id in self._transfers
|
||||
|
||||
@@ -192,6 +262,7 @@ class Aria2Downloader:
|
||||
except Exception as exc:
|
||||
return {"success": False, "error": str(exc)}
|
||||
|
||||
await self._state_store.upsert(download_id, {"status": "paused"})
|
||||
return {"success": True, "message": "Download paused successfully"}
|
||||
|
||||
async def resume_download(self, download_id: str) -> Dict[str, Any]:
|
||||
@@ -204,6 +275,7 @@ class Aria2Downloader:
|
||||
except Exception as exc:
|
||||
return {"success": False, "error": str(exc)}
|
||||
|
||||
await self._state_store.upsert(download_id, {"status": "downloading"})
|
||||
return {"success": True, "message": "Download resumed successfully"}
|
||||
|
||||
async def cancel_download(self, download_id: str) -> Dict[str, Any]:
|
||||
@@ -216,6 +288,7 @@ class Aria2Downloader:
|
||||
except Exception as exc:
|
||||
return {"success": False, "error": str(exc)}
|
||||
|
||||
await self._state_store.remove(download_id)
|
||||
return {"success": True, "message": "Download cancelled successfully"}
|
||||
|
||||
async def close(self) -> None:
|
||||
|
||||
Reference in New Issue
Block a user