mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-05-06 16:36:45 -03:00
fix(download): restore aria2 resume lifecycle
This commit is contained in:
@@ -15,6 +15,7 @@ from typing import Any, Dict, Optional, Tuple
|
||||
import aiohttp
|
||||
|
||||
from .downloader import DownloadProgress, get_downloader
|
||||
from .aria2_transfer_state import Aria2TransferStateStore
|
||||
from .settings_manager import get_settings_manager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -64,6 +65,7 @@ class Aria2Downloader:
|
||||
self._process_lock = asyncio.Lock()
|
||||
self._transfers: Dict[str, Aria2Transfer] = {}
|
||||
self._poll_interval = 0.5
|
||||
self._state_store = Aria2TransferStateStore()
|
||||
|
||||
@property
|
||||
def is_running(self) -> bool:
|
||||
@@ -82,6 +84,48 @@ class Aria2Downloader:
|
||||
|
||||
await self._ensure_process()
|
||||
save_path = os.path.abspath(save_path)
|
||||
transfer = self._transfers.get(download_id)
|
||||
if transfer is None or os.path.abspath(transfer.save_path) != save_path:
|
||||
gid = await self._schedule_download(
|
||||
url,
|
||||
save_path,
|
||||
download_id=download_id,
|
||||
headers=headers,
|
||||
)
|
||||
transfer = Aria2Transfer(gid=gid, save_path=save_path)
|
||||
self._transfers[download_id] = transfer
|
||||
|
||||
try:
|
||||
while True:
|
||||
status = await self.get_status(download_id)
|
||||
if status is None:
|
||||
return False, "aria2 download not found"
|
||||
|
||||
snapshot = self._build_progress_snapshot(status)
|
||||
if progress_callback is not None:
|
||||
await self._dispatch_progress(progress_callback, snapshot)
|
||||
|
||||
state = status.get("status", "")
|
||||
if state == "complete":
|
||||
completed_path = self._resolve_completed_path(status, save_path)
|
||||
return True, completed_path
|
||||
if state == "error":
|
||||
return False, status.get("errorMessage") or "aria2 download failed"
|
||||
if state == "removed":
|
||||
return False, "Download was cancelled"
|
||||
|
||||
await asyncio.sleep(self._poll_interval)
|
||||
finally:
|
||||
self._transfers.pop(download_id, None)
|
||||
|
||||
async def _schedule_download(
|
||||
self,
|
||||
url: str,
|
||||
save_path: str,
|
||||
*,
|
||||
download_id: str,
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
) -> str:
|
||||
save_dir = os.path.dirname(save_path)
|
||||
out_name = os.path.basename(save_path)
|
||||
|
||||
@@ -128,31 +172,16 @@ class Aria2Downloader:
|
||||
raise Aria2Error(f"Failed to schedule aria2 download: {exc}") from exc
|
||||
|
||||
logger.debug("aria2 accepted download %s with gid %s", download_id, gid)
|
||||
|
||||
self._transfers[download_id] = Aria2Transfer(gid=gid, save_path=save_path)
|
||||
|
||||
try:
|
||||
while True:
|
||||
status = await self.get_status(download_id)
|
||||
if status is None:
|
||||
return False, "aria2 download not found"
|
||||
|
||||
snapshot = self._build_progress_snapshot(status)
|
||||
if progress_callback is not None:
|
||||
await self._dispatch_progress(progress_callback, snapshot)
|
||||
|
||||
state = status.get("status", "")
|
||||
if state == "complete":
|
||||
completed_path = self._resolve_completed_path(status, save_path)
|
||||
return True, completed_path
|
||||
if state == "error":
|
||||
return False, status.get("errorMessage") or "aria2 download failed"
|
||||
if state == "removed":
|
||||
return False, "Download was cancelled"
|
||||
|
||||
await asyncio.sleep(self._poll_interval)
|
||||
finally:
|
||||
self._transfers.pop(download_id, None)
|
||||
await self._state_store.upsert(
|
||||
download_id,
|
||||
{
|
||||
"gid": gid,
|
||||
"save_path": save_path,
|
||||
"status": "downloading",
|
||||
"url": url,
|
||||
},
|
||||
)
|
||||
return gid
|
||||
|
||||
async def get_status(self, download_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""Return the raw aria2 status payload for a known download."""
|
||||
@@ -179,6 +208,47 @@ class Aria2Downloader:
|
||||
return status
|
||||
return None
|
||||
|
||||
async def get_status_by_gid(self, gid: str) -> Optional[Dict[str, Any]]:
|
||||
keys = [
|
||||
"gid",
|
||||
"status",
|
||||
"totalLength",
|
||||
"completedLength",
|
||||
"downloadSpeed",
|
||||
"errorMessage",
|
||||
"files",
|
||||
]
|
||||
try:
|
||||
status = await self._rpc_call("aria2.tellStatus", [gid, keys])
|
||||
except Exception as exc:
|
||||
message = str(exc)
|
||||
if "cannot be found" in message.lower() or "not found" in message.lower():
|
||||
return None
|
||||
raise Aria2Error(f"Failed to query aria2 download status: {exc}") from exc
|
||||
|
||||
if isinstance(status, dict):
|
||||
return status
|
||||
return None
|
||||
|
||||
async def restore_transfer(self, download_id: str, gid: str, save_path: str) -> None:
|
||||
await self._ensure_process()
|
||||
self._transfers[download_id] = Aria2Transfer(
|
||||
gid=gid,
|
||||
save_path=os.path.abspath(save_path),
|
||||
)
|
||||
|
||||
async def reassign_transfer(
|
||||
self, from_download_id: str, to_download_id: str
|
||||
) -> Optional[Aria2Transfer]:
|
||||
transfer = self._transfers.get(from_download_id)
|
||||
if transfer is None:
|
||||
return None
|
||||
|
||||
self._transfers[to_download_id] = transfer
|
||||
if from_download_id != to_download_id:
|
||||
self._transfers.pop(from_download_id, None)
|
||||
return transfer
|
||||
|
||||
async def has_transfer(self, download_id: str) -> bool:
|
||||
return download_id in self._transfers
|
||||
|
||||
@@ -192,6 +262,7 @@ class Aria2Downloader:
|
||||
except Exception as exc:
|
||||
return {"success": False, "error": str(exc)}
|
||||
|
||||
await self._state_store.upsert(download_id, {"status": "paused"})
|
||||
return {"success": True, "message": "Download paused successfully"}
|
||||
|
||||
async def resume_download(self, download_id: str) -> Dict[str, Any]:
|
||||
@@ -204,6 +275,7 @@ class Aria2Downloader:
|
||||
except Exception as exc:
|
||||
return {"success": False, "error": str(exc)}
|
||||
|
||||
await self._state_store.upsert(download_id, {"status": "downloading"})
|
||||
return {"success": True, "message": "Download resumed successfully"}
|
||||
|
||||
async def cancel_download(self, download_id: str) -> Dict[str, Any]:
|
||||
@@ -216,6 +288,7 @@ class Aria2Downloader:
|
||||
except Exception as exc:
|
||||
return {"success": False, "error": str(exc)}
|
||||
|
||||
await self._state_store.remove(download_id)
|
||||
return {"success": True, "message": "Download cancelled successfully"}
|
||||
|
||||
async def close(self) -> None:
|
||||
|
||||
108
py/services/aria2_transfer_state.py
Normal file
108
py/services/aria2_transfer_state.py
Normal file
@@ -0,0 +1,108 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
from copy import deepcopy
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from ..utils.cache_paths import get_cache_base_dir
|
||||
|
||||
|
||||
def get_aria2_state_path() -> str:
|
||||
base_dir = get_cache_base_dir(create=True)
|
||||
state_dir = os.path.join(base_dir, "aria2")
|
||||
os.makedirs(state_dir, exist_ok=True)
|
||||
return os.path.join(state_dir, "downloads.json")
|
||||
|
||||
|
||||
class Aria2TransferStateStore:
|
||||
"""Persist aria2 transfer metadata needed for restart recovery."""
|
||||
|
||||
_locks_by_path: Dict[str, asyncio.Lock] = {}
|
||||
|
||||
def __init__(self, state_path: Optional[str] = None) -> None:
|
||||
self._state_path = os.path.abspath(state_path or get_aria2_state_path())
|
||||
self._lock = self._locks_by_path.setdefault(self._state_path, asyncio.Lock())
|
||||
|
||||
def _read_all_unlocked(self) -> Dict[str, Dict[str, Any]]:
|
||||
try:
|
||||
with open(self._state_path, "r", encoding="utf-8") as handle:
|
||||
data = json.load(handle)
|
||||
except FileNotFoundError:
|
||||
return {}
|
||||
except json.JSONDecodeError:
|
||||
return {}
|
||||
|
||||
if not isinstance(data, dict):
|
||||
return {}
|
||||
|
||||
normalized: Dict[str, Dict[str, Any]] = {}
|
||||
for download_id, entry in data.items():
|
||||
if isinstance(download_id, str) and isinstance(entry, dict):
|
||||
normalized[download_id] = entry
|
||||
return normalized
|
||||
|
||||
def _write_all_unlocked(self, data: Dict[str, Dict[str, Any]]) -> None:
|
||||
directory = os.path.dirname(self._state_path)
|
||||
if directory:
|
||||
os.makedirs(directory, exist_ok=True)
|
||||
|
||||
temp_path = f"{self._state_path}.tmp"
|
||||
with open(temp_path, "w", encoding="utf-8") as handle:
|
||||
json.dump(data, handle, ensure_ascii=True, indent=2, sort_keys=True)
|
||||
os.replace(temp_path, self._state_path)
|
||||
|
||||
async def load_all(self) -> Dict[str, Dict[str, Any]]:
|
||||
async with self._lock:
|
||||
return deepcopy(self._read_all_unlocked())
|
||||
|
||||
async def get(self, download_id: str) -> Optional[Dict[str, Any]]:
|
||||
async with self._lock:
|
||||
return deepcopy(self._read_all_unlocked().get(download_id))
|
||||
|
||||
async def upsert(self, download_id: str, payload: Dict[str, Any]) -> Dict[str, Any]:
|
||||
async with self._lock:
|
||||
data = self._read_all_unlocked()
|
||||
current = data.get(download_id, {})
|
||||
current.update(payload)
|
||||
data[download_id] = current
|
||||
self._write_all_unlocked(data)
|
||||
return deepcopy(current)
|
||||
|
||||
async def remove(self, download_id: str) -> None:
|
||||
async with self._lock:
|
||||
data = self._read_all_unlocked()
|
||||
if download_id in data:
|
||||
del data[download_id]
|
||||
self._write_all_unlocked(data)
|
||||
|
||||
async def find_by_save_path(
|
||||
self, save_path: str, *, exclude_download_id: Optional[str] = None
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
normalized_target = os.path.abspath(save_path)
|
||||
async with self._lock:
|
||||
data = self._read_all_unlocked()
|
||||
for download_id, entry in data.items():
|
||||
if exclude_download_id and download_id == exclude_download_id:
|
||||
continue
|
||||
candidate = entry.get("save_path")
|
||||
if isinstance(candidate, str) and os.path.abspath(candidate) == normalized_target:
|
||||
result = dict(entry)
|
||||
result["download_id"] = download_id
|
||||
return result
|
||||
return None
|
||||
|
||||
async def reassign(self, from_download_id: str, to_download_id: str) -> Optional[Dict[str, Any]]:
|
||||
async with self._lock:
|
||||
data = self._read_all_unlocked()
|
||||
existing = data.get(from_download_id)
|
||||
if existing is None:
|
||||
return None
|
||||
updated = dict(existing)
|
||||
updated["download_id"] = to_download_id
|
||||
data[to_download_id] = updated
|
||||
if from_download_id != to_download_id:
|
||||
data.pop(from_download_id, None)
|
||||
self._write_all_unlocked(data)
|
||||
return deepcopy(updated)
|
||||
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user