mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-05-06 08:26:45 -03:00
fix(download): restore aria2 resume lifecycle
This commit is contained in:
@@ -15,6 +15,7 @@ from typing import Any, Dict, Optional, Tuple
|
|||||||
import aiohttp
|
import aiohttp
|
||||||
|
|
||||||
from .downloader import DownloadProgress, get_downloader
|
from .downloader import DownloadProgress, get_downloader
|
||||||
|
from .aria2_transfer_state import Aria2TransferStateStore
|
||||||
from .settings_manager import get_settings_manager
|
from .settings_manager import get_settings_manager
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -64,6 +65,7 @@ class Aria2Downloader:
|
|||||||
self._process_lock = asyncio.Lock()
|
self._process_lock = asyncio.Lock()
|
||||||
self._transfers: Dict[str, Aria2Transfer] = {}
|
self._transfers: Dict[str, Aria2Transfer] = {}
|
||||||
self._poll_interval = 0.5
|
self._poll_interval = 0.5
|
||||||
|
self._state_store = Aria2TransferStateStore()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_running(self) -> bool:
|
def is_running(self) -> bool:
|
||||||
@@ -82,6 +84,48 @@ class Aria2Downloader:
|
|||||||
|
|
||||||
await self._ensure_process()
|
await self._ensure_process()
|
||||||
save_path = os.path.abspath(save_path)
|
save_path = os.path.abspath(save_path)
|
||||||
|
transfer = self._transfers.get(download_id)
|
||||||
|
if transfer is None or os.path.abspath(transfer.save_path) != save_path:
|
||||||
|
gid = await self._schedule_download(
|
||||||
|
url,
|
||||||
|
save_path,
|
||||||
|
download_id=download_id,
|
||||||
|
headers=headers,
|
||||||
|
)
|
||||||
|
transfer = Aria2Transfer(gid=gid, save_path=save_path)
|
||||||
|
self._transfers[download_id] = transfer
|
||||||
|
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
status = await self.get_status(download_id)
|
||||||
|
if status is None:
|
||||||
|
return False, "aria2 download not found"
|
||||||
|
|
||||||
|
snapshot = self._build_progress_snapshot(status)
|
||||||
|
if progress_callback is not None:
|
||||||
|
await self._dispatch_progress(progress_callback, snapshot)
|
||||||
|
|
||||||
|
state = status.get("status", "")
|
||||||
|
if state == "complete":
|
||||||
|
completed_path = self._resolve_completed_path(status, save_path)
|
||||||
|
return True, completed_path
|
||||||
|
if state == "error":
|
||||||
|
return False, status.get("errorMessage") or "aria2 download failed"
|
||||||
|
if state == "removed":
|
||||||
|
return False, "Download was cancelled"
|
||||||
|
|
||||||
|
await asyncio.sleep(self._poll_interval)
|
||||||
|
finally:
|
||||||
|
self._transfers.pop(download_id, None)
|
||||||
|
|
||||||
|
async def _schedule_download(
|
||||||
|
self,
|
||||||
|
url: str,
|
||||||
|
save_path: str,
|
||||||
|
*,
|
||||||
|
download_id: str,
|
||||||
|
headers: Optional[Dict[str, str]] = None,
|
||||||
|
) -> str:
|
||||||
save_dir = os.path.dirname(save_path)
|
save_dir = os.path.dirname(save_path)
|
||||||
out_name = os.path.basename(save_path)
|
out_name = os.path.basename(save_path)
|
||||||
|
|
||||||
@@ -128,31 +172,16 @@ class Aria2Downloader:
|
|||||||
raise Aria2Error(f"Failed to schedule aria2 download: {exc}") from exc
|
raise Aria2Error(f"Failed to schedule aria2 download: {exc}") from exc
|
||||||
|
|
||||||
logger.debug("aria2 accepted download %s with gid %s", download_id, gid)
|
logger.debug("aria2 accepted download %s with gid %s", download_id, gid)
|
||||||
|
await self._state_store.upsert(
|
||||||
self._transfers[download_id] = Aria2Transfer(gid=gid, save_path=save_path)
|
download_id,
|
||||||
|
{
|
||||||
try:
|
"gid": gid,
|
||||||
while True:
|
"save_path": save_path,
|
||||||
status = await self.get_status(download_id)
|
"status": "downloading",
|
||||||
if status is None:
|
"url": url,
|
||||||
return False, "aria2 download not found"
|
},
|
||||||
|
)
|
||||||
snapshot = self._build_progress_snapshot(status)
|
return gid
|
||||||
if progress_callback is not None:
|
|
||||||
await self._dispatch_progress(progress_callback, snapshot)
|
|
||||||
|
|
||||||
state = status.get("status", "")
|
|
||||||
if state == "complete":
|
|
||||||
completed_path = self._resolve_completed_path(status, save_path)
|
|
||||||
return True, completed_path
|
|
||||||
if state == "error":
|
|
||||||
return False, status.get("errorMessage") or "aria2 download failed"
|
|
||||||
if state == "removed":
|
|
||||||
return False, "Download was cancelled"
|
|
||||||
|
|
||||||
await asyncio.sleep(self._poll_interval)
|
|
||||||
finally:
|
|
||||||
self._transfers.pop(download_id, None)
|
|
||||||
|
|
||||||
async def get_status(self, download_id: str) -> Optional[Dict[str, Any]]:
|
async def get_status(self, download_id: str) -> Optional[Dict[str, Any]]:
|
||||||
"""Return the raw aria2 status payload for a known download."""
|
"""Return the raw aria2 status payload for a known download."""
|
||||||
@@ -179,6 +208,47 @@ class Aria2Downloader:
|
|||||||
return status
|
return status
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
async def get_status_by_gid(self, gid: str) -> Optional[Dict[str, Any]]:
|
||||||
|
keys = [
|
||||||
|
"gid",
|
||||||
|
"status",
|
||||||
|
"totalLength",
|
||||||
|
"completedLength",
|
||||||
|
"downloadSpeed",
|
||||||
|
"errorMessage",
|
||||||
|
"files",
|
||||||
|
]
|
||||||
|
try:
|
||||||
|
status = await self._rpc_call("aria2.tellStatus", [gid, keys])
|
||||||
|
except Exception as exc:
|
||||||
|
message = str(exc)
|
||||||
|
if "cannot be found" in message.lower() or "not found" in message.lower():
|
||||||
|
return None
|
||||||
|
raise Aria2Error(f"Failed to query aria2 download status: {exc}") from exc
|
||||||
|
|
||||||
|
if isinstance(status, dict):
|
||||||
|
return status
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def restore_transfer(self, download_id: str, gid: str, save_path: str) -> None:
|
||||||
|
await self._ensure_process()
|
||||||
|
self._transfers[download_id] = Aria2Transfer(
|
||||||
|
gid=gid,
|
||||||
|
save_path=os.path.abspath(save_path),
|
||||||
|
)
|
||||||
|
|
||||||
|
async def reassign_transfer(
|
||||||
|
self, from_download_id: str, to_download_id: str
|
||||||
|
) -> Optional[Aria2Transfer]:
|
||||||
|
transfer = self._transfers.get(from_download_id)
|
||||||
|
if transfer is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
self._transfers[to_download_id] = transfer
|
||||||
|
if from_download_id != to_download_id:
|
||||||
|
self._transfers.pop(from_download_id, None)
|
||||||
|
return transfer
|
||||||
|
|
||||||
async def has_transfer(self, download_id: str) -> bool:
|
async def has_transfer(self, download_id: str) -> bool:
|
||||||
return download_id in self._transfers
|
return download_id in self._transfers
|
||||||
|
|
||||||
@@ -192,6 +262,7 @@ class Aria2Downloader:
|
|||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
return {"success": False, "error": str(exc)}
|
return {"success": False, "error": str(exc)}
|
||||||
|
|
||||||
|
await self._state_store.upsert(download_id, {"status": "paused"})
|
||||||
return {"success": True, "message": "Download paused successfully"}
|
return {"success": True, "message": "Download paused successfully"}
|
||||||
|
|
||||||
async def resume_download(self, download_id: str) -> Dict[str, Any]:
|
async def resume_download(self, download_id: str) -> Dict[str, Any]:
|
||||||
@@ -204,6 +275,7 @@ class Aria2Downloader:
|
|||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
return {"success": False, "error": str(exc)}
|
return {"success": False, "error": str(exc)}
|
||||||
|
|
||||||
|
await self._state_store.upsert(download_id, {"status": "downloading"})
|
||||||
return {"success": True, "message": "Download resumed successfully"}
|
return {"success": True, "message": "Download resumed successfully"}
|
||||||
|
|
||||||
async def cancel_download(self, download_id: str) -> Dict[str, Any]:
|
async def cancel_download(self, download_id: str) -> Dict[str, Any]:
|
||||||
@@ -216,6 +288,7 @@ class Aria2Downloader:
|
|||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
return {"success": False, "error": str(exc)}
|
return {"success": False, "error": str(exc)}
|
||||||
|
|
||||||
|
await self._state_store.remove(download_id)
|
||||||
return {"success": True, "message": "Download cancelled successfully"}
|
return {"success": True, "message": "Download cancelled successfully"}
|
||||||
|
|
||||||
async def close(self) -> None:
|
async def close(self) -> None:
|
||||||
|
|||||||
108
py/services/aria2_transfer_state.py
Normal file
108
py/services/aria2_transfer_state.py
Normal file
@@ -0,0 +1,108 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
from copy import deepcopy
|
||||||
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
|
from ..utils.cache_paths import get_cache_base_dir
|
||||||
|
|
||||||
|
|
||||||
|
def get_aria2_state_path() -> str:
|
||||||
|
base_dir = get_cache_base_dir(create=True)
|
||||||
|
state_dir = os.path.join(base_dir, "aria2")
|
||||||
|
os.makedirs(state_dir, exist_ok=True)
|
||||||
|
return os.path.join(state_dir, "downloads.json")
|
||||||
|
|
||||||
|
|
||||||
|
class Aria2TransferStateStore:
|
||||||
|
"""Persist aria2 transfer metadata needed for restart recovery."""
|
||||||
|
|
||||||
|
_locks_by_path: Dict[str, asyncio.Lock] = {}
|
||||||
|
|
||||||
|
def __init__(self, state_path: Optional[str] = None) -> None:
|
||||||
|
self._state_path = os.path.abspath(state_path or get_aria2_state_path())
|
||||||
|
self._lock = self._locks_by_path.setdefault(self._state_path, asyncio.Lock())
|
||||||
|
|
||||||
|
def _read_all_unlocked(self) -> Dict[str, Dict[str, Any]]:
|
||||||
|
try:
|
||||||
|
with open(self._state_path, "r", encoding="utf-8") as handle:
|
||||||
|
data = json.load(handle)
|
||||||
|
except FileNotFoundError:
|
||||||
|
return {}
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
if not isinstance(data, dict):
|
||||||
|
return {}
|
||||||
|
|
||||||
|
normalized: Dict[str, Dict[str, Any]] = {}
|
||||||
|
for download_id, entry in data.items():
|
||||||
|
if isinstance(download_id, str) and isinstance(entry, dict):
|
||||||
|
normalized[download_id] = entry
|
||||||
|
return normalized
|
||||||
|
|
||||||
|
def _write_all_unlocked(self, data: Dict[str, Dict[str, Any]]) -> None:
|
||||||
|
directory = os.path.dirname(self._state_path)
|
||||||
|
if directory:
|
||||||
|
os.makedirs(directory, exist_ok=True)
|
||||||
|
|
||||||
|
temp_path = f"{self._state_path}.tmp"
|
||||||
|
with open(temp_path, "w", encoding="utf-8") as handle:
|
||||||
|
json.dump(data, handle, ensure_ascii=True, indent=2, sort_keys=True)
|
||||||
|
os.replace(temp_path, self._state_path)
|
||||||
|
|
||||||
|
async def load_all(self) -> Dict[str, Dict[str, Any]]:
|
||||||
|
async with self._lock:
|
||||||
|
return deepcopy(self._read_all_unlocked())
|
||||||
|
|
||||||
|
async def get(self, download_id: str) -> Optional[Dict[str, Any]]:
|
||||||
|
async with self._lock:
|
||||||
|
return deepcopy(self._read_all_unlocked().get(download_id))
|
||||||
|
|
||||||
|
async def upsert(self, download_id: str, payload: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
async with self._lock:
|
||||||
|
data = self._read_all_unlocked()
|
||||||
|
current = data.get(download_id, {})
|
||||||
|
current.update(payload)
|
||||||
|
data[download_id] = current
|
||||||
|
self._write_all_unlocked(data)
|
||||||
|
return deepcopy(current)
|
||||||
|
|
||||||
|
async def remove(self, download_id: str) -> None:
|
||||||
|
async with self._lock:
|
||||||
|
data = self._read_all_unlocked()
|
||||||
|
if download_id in data:
|
||||||
|
del data[download_id]
|
||||||
|
self._write_all_unlocked(data)
|
||||||
|
|
||||||
|
async def find_by_save_path(
|
||||||
|
self, save_path: str, *, exclude_download_id: Optional[str] = None
|
||||||
|
) -> Optional[Dict[str, Any]]:
|
||||||
|
normalized_target = os.path.abspath(save_path)
|
||||||
|
async with self._lock:
|
||||||
|
data = self._read_all_unlocked()
|
||||||
|
for download_id, entry in data.items():
|
||||||
|
if exclude_download_id and download_id == exclude_download_id:
|
||||||
|
continue
|
||||||
|
candidate = entry.get("save_path")
|
||||||
|
if isinstance(candidate, str) and os.path.abspath(candidate) == normalized_target:
|
||||||
|
result = dict(entry)
|
||||||
|
result["download_id"] = download_id
|
||||||
|
return result
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def reassign(self, from_download_id: str, to_download_id: str) -> Optional[Dict[str, Any]]:
|
||||||
|
async with self._lock:
|
||||||
|
data = self._read_all_unlocked()
|
||||||
|
existing = data.get(from_download_id)
|
||||||
|
if existing is None:
|
||||||
|
return None
|
||||||
|
updated = dict(existing)
|
||||||
|
updated["download_id"] = to_download_id
|
||||||
|
data[to_download_id] = updated
|
||||||
|
if from_download_id != to_download_id:
|
||||||
|
data.pop(from_download_id, None)
|
||||||
|
self._write_all_unlocked(data)
|
||||||
|
return deepcopy(updated)
|
||||||
File diff suppressed because it is too large
Load Diff
@@ -1,11 +1,24 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from unittest.mock import AsyncMock
|
from unittest.mock import AsyncMock
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from py.services.aria2_downloader import Aria2Downloader, Aria2Error
|
from py.services.aria2_downloader import Aria2Downloader, Aria2Error
|
||||||
|
from py.services.aria2_transfer_state import Aria2TransferStateStore
|
||||||
|
from py.services import aria2_transfer_state
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def isolate_aria2_state(monkeypatch, tmp_path):
|
||||||
|
state_path = tmp_path / "cache" / "aria2" / "downloads.json"
|
||||||
|
monkeypatch.setattr(
|
||||||
|
aria2_transfer_state,
|
||||||
|
"get_aria2_state_path",
|
||||||
|
lambda: str(state_path),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@@ -79,6 +92,23 @@ async def test_download_file_polls_until_complete(tmp_path, monkeypatch):
|
|||||||
assert "header" not in rpc_calls[0][1][1]
|
assert "header" not in rpc_calls[0][1][1]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_transfer_state_store_shares_lock_and_preserves_concurrent_updates(tmp_path):
|
||||||
|
state_path = tmp_path / "cache" / "aria2" / "downloads.json"
|
||||||
|
store_a = Aria2TransferStateStore(str(state_path))
|
||||||
|
store_b = Aria2TransferStateStore(str(state_path))
|
||||||
|
|
||||||
|
assert store_a._lock is store_b._lock
|
||||||
|
|
||||||
|
await asyncio.gather(
|
||||||
|
store_a.upsert("download-1", {"status": "downloading", "gid": "gid-1"}),
|
||||||
|
store_b.upsert("download-2", {"status": "paused", "gid": "gid-2"}),
|
||||||
|
)
|
||||||
|
|
||||||
|
assert await store_a.get("download-1") == {"status": "downloading", "gid": "gid-1"}
|
||||||
|
assert await store_b.get("download-2") == {"status": "paused", "gid": "gid-2"}
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_download_file_keeps_auth_headers_when_civitai_does_not_redirect(
|
async def test_download_file_keeps_auth_headers_when_civitai_does_not_redirect(
|
||||||
tmp_path, monkeypatch
|
tmp_path, monkeypatch
|
||||||
@@ -161,6 +191,61 @@ async def test_pause_resume_cancel_forward_to_rpc(monkeypatch):
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_download_file_reuses_existing_transfer_without_add_uri(
|
||||||
|
tmp_path, monkeypatch
|
||||||
|
):
|
||||||
|
downloader = Aria2Downloader()
|
||||||
|
downloader._rpc_url = "http://127.0.0.1/jsonrpc"
|
||||||
|
downloader._rpc_secret = "secret"
|
||||||
|
|
||||||
|
save_path = tmp_path / "downloads" / "model.safetensors"
|
||||||
|
downloader._transfers["download-1"] = type(
|
||||||
|
"Transfer", (), {"gid": "gid-1", "save_path": str(save_path)}
|
||||||
|
)()
|
||||||
|
|
||||||
|
rpc_calls = []
|
||||||
|
statuses = iter(
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"gid": "gid-1",
|
||||||
|
"status": "active",
|
||||||
|
"completedLength": "5",
|
||||||
|
"totalLength": "10",
|
||||||
|
"downloadSpeed": "25",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"gid": "gid-1",
|
||||||
|
"status": "complete",
|
||||||
|
"completedLength": "10",
|
||||||
|
"totalLength": "10",
|
||||||
|
"downloadSpeed": "0",
|
||||||
|
"files": [{"path": str(save_path)}],
|
||||||
|
},
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
async def fake_rpc_call(method, params):
|
||||||
|
rpc_calls.append((method, params))
|
||||||
|
if method == "aria2.tellStatus":
|
||||||
|
return next(statuses)
|
||||||
|
raise AssertionError(f"Unexpected RPC method: {method}")
|
||||||
|
|
||||||
|
monkeypatch.setattr(downloader, "_ensure_process", AsyncMock())
|
||||||
|
monkeypatch.setattr(downloader, "_rpc_call", fake_rpc_call)
|
||||||
|
monkeypatch.setattr("py.services.aria2_downloader.asyncio.sleep", AsyncMock())
|
||||||
|
|
||||||
|
success, result = await downloader.download_file(
|
||||||
|
"https://example.com/model.safetensors",
|
||||||
|
str(save_path),
|
||||||
|
download_id="download-1",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert success is True
|
||||||
|
assert result == str(save_path)
|
||||||
|
assert [call[0] for call in rpc_calls] == ["aria2.tellStatus", "aria2.tellStatus"]
|
||||||
|
|
||||||
|
|
||||||
def test_build_progress_snapshot_normalizes_numeric_fields():
|
def test_build_progress_snapshot_normalizes_numeric_fields():
|
||||||
downloader = Aria2Downloader()
|
downloader = Aria2Downloader()
|
||||||
|
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ import pytest
|
|||||||
|
|
||||||
from py.services.download_manager import DownloadManager
|
from py.services.download_manager import DownloadManager
|
||||||
from py.services import download_manager
|
from py.services import download_manager
|
||||||
|
from py.services import aria2_transfer_state
|
||||||
from py.services.service_registry import ServiceRegistry
|
from py.services.service_registry import ServiceRegistry
|
||||||
from py.services.settings_manager import SettingsManager, get_settings_manager
|
from py.services.settings_manager import SettingsManager, get_settings_manager
|
||||||
|
|
||||||
@@ -46,6 +47,16 @@ def isolate_settings(monkeypatch, tmp_path):
|
|||||||
monkeypatch.setattr(SettingsManager, "_save_settings", lambda self: None)
|
monkeypatch.setattr(SettingsManager, "_save_settings", lambda self: None)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def isolate_aria2_state(monkeypatch, tmp_path):
|
||||||
|
state_path = tmp_path / "cache" / "aria2" / "downloads.json"
|
||||||
|
monkeypatch.setattr(
|
||||||
|
aria2_transfer_state,
|
||||||
|
"get_aria2_state_path",
|
||||||
|
lambda: str(state_path),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
@pytest.fixture(autouse=True)
|
||||||
def stub_metadata(monkeypatch):
|
def stub_metadata(monkeypatch):
|
||||||
class _StubMetadata:
|
class _StubMetadata:
|
||||||
@@ -439,6 +450,436 @@ async def test_pause_resume_queued_aria2_task_without_transfer(monkeypatch):
|
|||||||
await task
|
await task
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_resume_download_restores_persisted_aria2_task(monkeypatch, tmp_path):
|
||||||
|
manager = DownloadManager()
|
||||||
|
save_dir = tmp_path / "downloads"
|
||||||
|
save_dir.mkdir()
|
||||||
|
save_path = save_dir / "file.safetensors"
|
||||||
|
save_path.write_text("partial")
|
||||||
|
(save_dir / "file.safetensors.aria2").write_text("control")
|
||||||
|
|
||||||
|
await manager._aria2_state_store.upsert(
|
||||||
|
"download-1",
|
||||||
|
{
|
||||||
|
"download_id": "download-1",
|
||||||
|
"transfer_backend": "aria2",
|
||||||
|
"status": "paused",
|
||||||
|
"save_dir": str(save_dir),
|
||||||
|
"relative_path": "",
|
||||||
|
"use_default_paths": False,
|
||||||
|
"save_path": str(save_path),
|
||||||
|
"file_path": str(save_path),
|
||||||
|
"model_id": 12,
|
||||||
|
"model_version_id": 34,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
created = {}
|
||||||
|
|
||||||
|
async def fake_download_with_semaphore(
|
||||||
|
self,
|
||||||
|
task_id,
|
||||||
|
model_id,
|
||||||
|
model_version_id,
|
||||||
|
save_dir,
|
||||||
|
relative_path,
|
||||||
|
progress_callback=None,
|
||||||
|
use_default_paths=False,
|
||||||
|
source=None,
|
||||||
|
file_params=None,
|
||||||
|
):
|
||||||
|
created.update(
|
||||||
|
{
|
||||||
|
"task_id": task_id,
|
||||||
|
"model_id": model_id,
|
||||||
|
"model_version_id": model_version_id,
|
||||||
|
"save_dir": save_dir,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return {"success": True}
|
||||||
|
|
||||||
|
class DummyAria2Downloader:
|
||||||
|
def __init__(self):
|
||||||
|
self.calls = []
|
||||||
|
|
||||||
|
async def get_status_by_gid(self, gid):
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def has_transfer(self, download_id):
|
||||||
|
self.calls.append(("has_transfer", download_id))
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def resume_download(self, download_id):
|
||||||
|
self.calls.append(("resume", download_id))
|
||||||
|
return {"success": True, "message": "resumed"}
|
||||||
|
|
||||||
|
async def restore_transfer(self, download_id, gid, save_path):
|
||||||
|
self.calls.append(("restore_transfer", download_id, gid, save_path))
|
||||||
|
|
||||||
|
dummy_aria2 = DummyAria2Downloader()
|
||||||
|
monkeypatch.setattr(
|
||||||
|
download_manager, "_download_with_semaphore", None, raising=False
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
DownloadManager,
|
||||||
|
"_download_with_semaphore",
|
||||||
|
fake_download_with_semaphore,
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
download_manager,
|
||||||
|
"get_aria2_downloader",
|
||||||
|
AsyncMock(return_value=dummy_aria2),
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await manager.resume_download("download-1")
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
|
assert result == {"success": True, "message": "Download resumed successfully"}
|
||||||
|
assert created["task_id"] == "download-1"
|
||||||
|
assert created["model_version_id"] == 34
|
||||||
|
assert manager._active_downloads["download-1"]["status"] == "downloading"
|
||||||
|
assert manager._pause_events["download-1"].is_set() is True
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_active_downloads_restores_persisted_aria2_entries(monkeypatch, tmp_path):
|
||||||
|
manager = DownloadManager()
|
||||||
|
save_dir = tmp_path / "downloads"
|
||||||
|
save_dir.mkdir()
|
||||||
|
save_path = save_dir / "file.safetensors"
|
||||||
|
save_path.write_text("partial")
|
||||||
|
(save_dir / "file.safetensors.aria2").write_text("control")
|
||||||
|
|
||||||
|
await manager._aria2_state_store.upsert(
|
||||||
|
"download-1",
|
||||||
|
{
|
||||||
|
"download_id": "download-1",
|
||||||
|
"transfer_backend": "aria2",
|
||||||
|
"status": "paused",
|
||||||
|
"save_path": str(save_path),
|
||||||
|
"file_path": str(save_path),
|
||||||
|
"model_id": 12,
|
||||||
|
"model_version_id": 34,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
class DummyAria2Downloader:
|
||||||
|
async def get_status_by_gid(self, gid):
|
||||||
|
return None
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
download_manager,
|
||||||
|
"get_aria2_downloader",
|
||||||
|
AsyncMock(return_value=DummyAria2Downloader()),
|
||||||
|
)
|
||||||
|
|
||||||
|
downloads = await manager.get_active_downloads()
|
||||||
|
|
||||||
|
assert downloads["downloads"] == [
|
||||||
|
{
|
||||||
|
"download_id": "download-1",
|
||||||
|
"model_id": 12,
|
||||||
|
"model_version_id": 34,
|
||||||
|
"progress": 0,
|
||||||
|
"status": "paused",
|
||||||
|
"error": None,
|
||||||
|
"bytes_downloaded": 0,
|
||||||
|
"total_bytes": None,
|
||||||
|
"bytes_per_second": 0.0,
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_active_downloads_restores_orphaned_aria2_partial_as_paused(
|
||||||
|
monkeypatch, tmp_path
|
||||||
|
):
|
||||||
|
manager = DownloadManager()
|
||||||
|
save_dir = tmp_path / "downloads"
|
||||||
|
save_dir.mkdir()
|
||||||
|
save_path = save_dir / "file.safetensors"
|
||||||
|
save_path.write_text("partial")
|
||||||
|
(save_dir / "file.safetensors.aria2").write_text("control")
|
||||||
|
|
||||||
|
await manager._aria2_state_store.upsert(
|
||||||
|
"download-1",
|
||||||
|
{
|
||||||
|
"download_id": "download-1",
|
||||||
|
"transfer_backend": "aria2",
|
||||||
|
"status": "downloading",
|
||||||
|
"save_path": str(save_path),
|
||||||
|
"file_path": str(save_path),
|
||||||
|
"model_id": 12,
|
||||||
|
"model_version_id": 34,
|
||||||
|
"gid": "missing-gid",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
class DummyAria2Downloader:
|
||||||
|
async def get_status_by_gid(self, gid):
|
||||||
|
return None
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
download_manager,
|
||||||
|
"get_aria2_downloader",
|
||||||
|
AsyncMock(return_value=DummyAria2Downloader()),
|
||||||
|
)
|
||||||
|
|
||||||
|
downloads = await manager.get_active_downloads()
|
||||||
|
persisted = await manager._aria2_state_store.get("download-1")
|
||||||
|
|
||||||
|
assert downloads["downloads"] == [
|
||||||
|
{
|
||||||
|
"download_id": "download-1",
|
||||||
|
"model_id": 12,
|
||||||
|
"model_version_id": 34,
|
||||||
|
"progress": 0,
|
||||||
|
"status": "paused",
|
||||||
|
"error": None,
|
||||||
|
"bytes_downloaded": 0,
|
||||||
|
"total_bytes": None,
|
||||||
|
"bytes_per_second": 0.0,
|
||||||
|
}
|
||||||
|
]
|
||||||
|
assert manager._pause_events["download-1"].is_paused() is True
|
||||||
|
assert persisted["status"] == "paused"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_active_downloads_restarts_from_resume_context_for_active_restored_aria2(
|
||||||
|
monkeypatch, tmp_path
|
||||||
|
):
|
||||||
|
manager = DownloadManager()
|
||||||
|
save_dir = tmp_path / "downloads"
|
||||||
|
save_dir.mkdir()
|
||||||
|
save_path = save_dir / "file.safetensors"
|
||||||
|
save_path.write_text("partial")
|
||||||
|
|
||||||
|
await manager._aria2_state_store.upsert(
|
||||||
|
"download-1",
|
||||||
|
{
|
||||||
|
"download_id": "download-1",
|
||||||
|
"transfer_backend": "aria2",
|
||||||
|
"status": "downloading",
|
||||||
|
"save_path": str(save_path),
|
||||||
|
"file_path": str(save_path),
|
||||||
|
"model_id": 12,
|
||||||
|
"model_version_id": 34,
|
||||||
|
"gid": "gid-1",
|
||||||
|
"resume_context": {
|
||||||
|
"version_info": {
|
||||||
|
"id": 34,
|
||||||
|
"modelId": 12,
|
||||||
|
"model": {"id": 12, "type": "LoRA", "tags": ["fantasy"]},
|
||||||
|
"images": [],
|
||||||
|
},
|
||||||
|
"file_info": {
|
||||||
|
"name": "file.safetensors",
|
||||||
|
"type": "Model",
|
||||||
|
"primary": True,
|
||||||
|
"downloadUrl": "https://example.com/file.safetensors",
|
||||||
|
},
|
||||||
|
"model_type": "lora",
|
||||||
|
"relative_path": "",
|
||||||
|
"save_dir": str(save_dir),
|
||||||
|
"download_urls": ["https://example.com/file.safetensors"],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
restarted = {}
|
||||||
|
|
||||||
|
class DummyAria2Downloader:
|
||||||
|
async def get_status_by_gid(self, gid):
|
||||||
|
return {"gid": gid, "status": "active"}
|
||||||
|
|
||||||
|
async def restore_transfer(self, download_id, gid, restored_path):
|
||||||
|
return None
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
download_manager,
|
||||||
|
"get_aria2_downloader",
|
||||||
|
AsyncMock(return_value=DummyAria2Downloader()),
|
||||||
|
)
|
||||||
|
|
||||||
|
async def fake_resume_restored_aria2_download(self, download_id, record):
|
||||||
|
restarted.update(
|
||||||
|
{
|
||||||
|
"download_id": download_id,
|
||||||
|
"model_id": record.get("model_id"),
|
||||||
|
"model_version_id": record.get("model_version_id"),
|
||||||
|
"save_dir": record.get("save_dir"),
|
||||||
|
"resume_context": record.get("resume_context"),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return {"success": True}
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
DownloadManager,
|
||||||
|
"_resume_restored_aria2_download",
|
||||||
|
fake_resume_restored_aria2_download,
|
||||||
|
)
|
||||||
|
execute_original = AsyncMock(side_effect=AssertionError("should not refetch metadata"))
|
||||||
|
monkeypatch.setattr(
|
||||||
|
DownloadManager,
|
||||||
|
"_execute_original_download",
|
||||||
|
execute_original,
|
||||||
|
)
|
||||||
|
|
||||||
|
downloads = await manager.get_active_downloads()
|
||||||
|
assert downloads["downloads"][0]["status"] == "downloading"
|
||||||
|
restarted_task = manager._download_tasks["download-1"]
|
||||||
|
await restarted_task
|
||||||
|
|
||||||
|
assert restarted["download_id"] == "download-1"
|
||||||
|
assert restarted["model_id"] == 12
|
||||||
|
assert restarted["model_version_id"] == 34
|
||||||
|
assert restarted["save_dir"] is None
|
||||||
|
assert restarted["resume_context"]["model_type"] == "lora"
|
||||||
|
assert execute_original.await_count == 0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_active_downloads_restores_persisted_aria2_without_initial_save_path(
|
||||||
|
monkeypatch, tmp_path
|
||||||
|
):
|
||||||
|
manager = DownloadManager()
|
||||||
|
save_dir = tmp_path / "downloads"
|
||||||
|
save_dir.mkdir()
|
||||||
|
save_path = save_dir / "file.safetensors"
|
||||||
|
save_path.write_text("partial")
|
||||||
|
(save_dir / "file.safetensors.aria2").write_text("control")
|
||||||
|
|
||||||
|
await manager._aria2_state_store.upsert(
|
||||||
|
"download-1",
|
||||||
|
{
|
||||||
|
"download_id": "download-1",
|
||||||
|
"transfer_backend": "aria2",
|
||||||
|
"status": "paused",
|
||||||
|
"model_id": 12,
|
||||||
|
"model_version_id": 34,
|
||||||
|
"resume_context": {
|
||||||
|
"version_info": {
|
||||||
|
"id": 34,
|
||||||
|
"modelId": 12,
|
||||||
|
"model": {"id": 12, "type": "LoRA"},
|
||||||
|
"images": [],
|
||||||
|
},
|
||||||
|
"file_info": {
|
||||||
|
"name": "file.safetensors",
|
||||||
|
"type": "Model",
|
||||||
|
"primary": True,
|
||||||
|
"downloadUrl": "https://example.com/file.safetensors",
|
||||||
|
},
|
||||||
|
"model_type": "lora",
|
||||||
|
"relative_path": "",
|
||||||
|
"save_dir": str(save_dir),
|
||||||
|
"download_urls": ["https://example.com/file.safetensors"],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
class DummyAria2Downloader:
|
||||||
|
async def get_status_by_gid(self, gid):
|
||||||
|
return None
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
download_manager,
|
||||||
|
"get_aria2_downloader",
|
||||||
|
AsyncMock(return_value=DummyAria2Downloader()),
|
||||||
|
)
|
||||||
|
|
||||||
|
downloads = await manager.get_active_downloads()
|
||||||
|
persisted = await manager._aria2_state_store.get("download-1")
|
||||||
|
|
||||||
|
assert downloads["downloads"] == [
|
||||||
|
{
|
||||||
|
"download_id": "download-1",
|
||||||
|
"model_id": 12,
|
||||||
|
"model_version_id": 34,
|
||||||
|
"progress": 0,
|
||||||
|
"status": "paused",
|
||||||
|
"error": None,
|
||||||
|
"bytes_downloaded": 0,
|
||||||
|
"total_bytes": None,
|
||||||
|
"bytes_per_second": 0.0,
|
||||||
|
}
|
||||||
|
]
|
||||||
|
assert manager._active_downloads["download-1"]["file_path"] == str(save_path)
|
||||||
|
assert persisted["save_path"] == str(save_path)
|
||||||
|
assert persisted["file_path"] == str(save_path)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_resume_restored_aria2_download_updates_terminal_status_and_cleanup(monkeypatch):
|
||||||
|
manager = DownloadManager()
|
||||||
|
manager._active_downloads["download-1"] = {
|
||||||
|
"transfer_backend": "aria2",
|
||||||
|
"status": "paused",
|
||||||
|
"model_id": 12,
|
||||||
|
"model_version_id": 34,
|
||||||
|
"bytes_per_second": 10.0,
|
||||||
|
}
|
||||||
|
|
||||||
|
persist_state = AsyncMock()
|
||||||
|
cleanup_record = AsyncMock(return_value=None)
|
||||||
|
execute_download = AsyncMock(return_value={"success": True})
|
||||||
|
record_history = AsyncMock(return_value=None)
|
||||||
|
sync_version = AsyncMock(return_value=None)
|
||||||
|
|
||||||
|
monkeypatch.setattr(manager, "_persist_aria2_state", persist_state)
|
||||||
|
monkeypatch.setattr(manager, "_cleanup_download_record", cleanup_record)
|
||||||
|
monkeypatch.setattr(manager, "_execute_download", execute_download)
|
||||||
|
monkeypatch.setattr(manager, "_record_downloaded_version_history", record_history)
|
||||||
|
monkeypatch.setattr(manager, "_sync_downloaded_version", sync_version)
|
||||||
|
|
||||||
|
scheduled_tasks = []
|
||||||
|
original_create_task = asyncio.create_task
|
||||||
|
|
||||||
|
def tracking_create_task(coro):
|
||||||
|
task = original_create_task(coro)
|
||||||
|
scheduled_tasks.append(task)
|
||||||
|
return task
|
||||||
|
|
||||||
|
monkeypatch.setattr(download_manager.asyncio, "create_task", tracking_create_task)
|
||||||
|
|
||||||
|
result = await manager._resume_restored_aria2_download(
|
||||||
|
"download-1",
|
||||||
|
{
|
||||||
|
"download_id": "download-1",
|
||||||
|
"save_path": "/tmp/file.safetensors",
|
||||||
|
"file_path": "/tmp/file.safetensors",
|
||||||
|
"model_id": 12,
|
||||||
|
"model_version_id": 34,
|
||||||
|
"resume_context": {
|
||||||
|
"version_info": {
|
||||||
|
"id": 34,
|
||||||
|
"modelId": 12,
|
||||||
|
"model": {"id": 12},
|
||||||
|
"images": [],
|
||||||
|
},
|
||||||
|
"file_info": {
|
||||||
|
"name": "file.safetensors",
|
||||||
|
"downloadUrl": "https://example.com/file.safetensors",
|
||||||
|
},
|
||||||
|
"model_type": "lora",
|
||||||
|
"relative_path": "",
|
||||||
|
"save_dir": "/tmp",
|
||||||
|
"download_urls": ["https://example.com/file.safetensors"],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result == {"success": True}
|
||||||
|
assert manager._active_downloads["download-1"]["status"] == "completed"
|
||||||
|
assert manager._active_downloads["download-1"]["bytes_per_second"] == 0.0
|
||||||
|
assert persist_state.await_count == 2
|
||||||
|
assert len(scheduled_tasks) == 1
|
||||||
|
await asyncio.gather(*scheduled_tasks)
|
||||||
|
cleanup_record.assert_awaited_once_with("download-1")
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_download_uses_captured_backend_when_settings_change(
|
async def test_download_uses_captured_backend_when_settings_change(
|
||||||
monkeypatch, scanners, metadata_provider, tmp_path
|
monkeypatch, scanners, metadata_provider, tmp_path
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ import pytest
|
|||||||
from py.services.download_manager import DownloadManager
|
from py.services.download_manager import DownloadManager
|
||||||
from py.services.downloader import DownloadStreamControl
|
from py.services.downloader import DownloadStreamControl
|
||||||
from py.services import download_manager
|
from py.services import download_manager
|
||||||
|
from py.services import aria2_transfer_state
|
||||||
from py.services.service_registry import ServiceRegistry
|
from py.services.service_registry import ServiceRegistry
|
||||||
from py.services.settings_manager import SettingsManager, get_settings_manager
|
from py.services.settings_manager import SettingsManager, get_settings_manager
|
||||||
from py.utils.metadata_manager import MetadataManager
|
from py.utils.metadata_manager import MetadataManager
|
||||||
@@ -49,6 +50,16 @@ def isolate_settings(monkeypatch, tmp_path):
|
|||||||
monkeypatch.setattr(SettingsManager, "_save_settings", lambda self: None)
|
monkeypatch.setattr(SettingsManager, "_save_settings", lambda self: None)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def isolate_aria2_state(monkeypatch, tmp_path):
|
||||||
|
state_path = tmp_path / "cache" / "aria2" / "downloads.json"
|
||||||
|
monkeypatch.setattr(
|
||||||
|
aria2_transfer_state,
|
||||||
|
"get_aria2_state_path",
|
||||||
|
lambda: str(state_path),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_execute_download_retries_urls(monkeypatch, tmp_path):
|
async def test_execute_download_retries_urls(monkeypatch, tmp_path):
|
||||||
"""Test that download retries multiple URLs on failure."""
|
"""Test that download retries multiple URLs on failure."""
|
||||||
@@ -800,6 +811,89 @@ async def test_resume_download_returns_error_when_aria2_probe_raises(monkeypatch
|
|||||||
assert manager._active_downloads[download_id]["status"] == "paused"
|
assert manager._active_downloads[download_id]["status"] == "paused"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_resume_download_does_not_spawn_restored_worker_when_aria2_resume_fails(
|
||||||
|
monkeypatch, tmp_path
|
||||||
|
):
|
||||||
|
manager = DownloadManager()
|
||||||
|
|
||||||
|
download_id = "dl"
|
||||||
|
save_path = tmp_path / "file.safetensors"
|
||||||
|
pause_control = DownloadStreamControl()
|
||||||
|
pause_control.pause()
|
||||||
|
manager._pause_events[download_id] = pause_control
|
||||||
|
manager._active_downloads[download_id] = {
|
||||||
|
"transfer_backend": "aria2",
|
||||||
|
"status": "paused",
|
||||||
|
"bytes_per_second": 0.0,
|
||||||
|
}
|
||||||
|
|
||||||
|
await manager._aria2_state_store.upsert(
|
||||||
|
download_id,
|
||||||
|
{
|
||||||
|
"download_id": download_id,
|
||||||
|
"transfer_backend": "aria2",
|
||||||
|
"status": "paused",
|
||||||
|
"save_path": str(save_path),
|
||||||
|
"file_path": str(save_path),
|
||||||
|
"model_id": 12,
|
||||||
|
"model_version_id": 34,
|
||||||
|
"resume_context": {
|
||||||
|
"version_info": {"id": 34, "modelId": 12, "model": {"id": 12}},
|
||||||
|
"file_info": {
|
||||||
|
"name": "file.safetensors",
|
||||||
|
"downloadUrl": "https://example.com/file.safetensors",
|
||||||
|
},
|
||||||
|
"model_type": "lora",
|
||||||
|
"relative_path": "",
|
||||||
|
"save_dir": str(tmp_path),
|
||||||
|
"download_urls": ["https://example.com/file.safetensors"],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
resume_restored = AsyncMock(return_value={"success": True})
|
||||||
|
monkeypatch.setattr(manager, "_resume_restored_aria2_download", resume_restored)
|
||||||
|
|
||||||
|
class DummyAria2Downloader:
|
||||||
|
async def has_transfer(self, _download_id):
|
||||||
|
return True
|
||||||
|
|
||||||
|
async def resume_download(self, _download_id):
|
||||||
|
return {"success": False, "error": "rpc unavailable"}
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
download_manager,
|
||||||
|
"get_aria2_downloader",
|
||||||
|
AsyncMock(return_value=DummyAria2Downloader()),
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await manager.resume_download(download_id)
|
||||||
|
|
||||||
|
assert result == {"success": False, "error": "rpc unavailable"}
|
||||||
|
assert download_id not in manager._download_tasks
|
||||||
|
assert resume_restored.await_count == 0
|
||||||
|
assert pause_control.is_paused() is True
|
||||||
|
assert manager._active_downloads[download_id]["status"] == "paused"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_start_background_download_task_cleans_up_finished_restore_task():
|
||||||
|
manager = DownloadManager()
|
||||||
|
download_id = "download-1"
|
||||||
|
manager._pause_events[download_id] = DownloadStreamControl()
|
||||||
|
|
||||||
|
async def finished_restore():
|
||||||
|
return {"success": True}
|
||||||
|
|
||||||
|
task = manager._start_background_download_task(download_id, finished_restore())
|
||||||
|
await task
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
|
assert download_id not in manager._download_tasks
|
||||||
|
assert download_id not in manager._pause_events
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_cancel_download_still_cancels_local_task_when_aria2_raises(monkeypatch):
|
async def test_cancel_download_still_cancels_local_task_when_aria2_raises(monkeypatch):
|
||||||
manager = DownloadManager()
|
manager = DownloadManager()
|
||||||
@@ -836,6 +930,217 @@ async def test_cancel_download_still_cancels_local_task_when_aria2_raises(monkey
|
|||||||
assert task.cancelled() or task.done()
|
assert task.cancelled() or task.done()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cancel_download_preserves_tracking_when_aria2_returns_error(monkeypatch, tmp_path):
|
||||||
|
manager = DownloadManager()
|
||||||
|
download_id = "download-queued"
|
||||||
|
save_path = tmp_path / "file.safetensors"
|
||||||
|
save_path.write_text("partial")
|
||||||
|
(tmp_path / "file.safetensors.aria2").write_text("control")
|
||||||
|
|
||||||
|
pause_control = DownloadStreamControl()
|
||||||
|
manager._pause_events[download_id] = pause_control
|
||||||
|
manager._download_tasks[download_id] = object()
|
||||||
|
manager._active_downloads[download_id] = {
|
||||||
|
"transfer_backend": "aria2",
|
||||||
|
"status": "downloading",
|
||||||
|
"file_path": str(save_path),
|
||||||
|
}
|
||||||
|
|
||||||
|
await manager._aria2_state_store.upsert(
|
||||||
|
download_id,
|
||||||
|
{
|
||||||
|
"download_id": download_id,
|
||||||
|
"transfer_backend": "aria2",
|
||||||
|
"status": "downloading",
|
||||||
|
"save_path": str(save_path),
|
||||||
|
"file_path": str(save_path),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
cleanup_files = AsyncMock(return_value=None)
|
||||||
|
monkeypatch.setattr(manager, "_cleanup_cancelled_download_files", cleanup_files)
|
||||||
|
|
||||||
|
class DummyAria2Downloader:
|
||||||
|
async def cancel_download(self, _download_id):
|
||||||
|
return {"success": False, "error": "rpc unavailable"}
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
download_manager,
|
||||||
|
"get_aria2_downloader",
|
||||||
|
AsyncMock(return_value=DummyAria2Downloader()),
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await manager.cancel_download(download_id)
|
||||||
|
|
||||||
|
assert result == {"success": False, "error": "rpc unavailable"}
|
||||||
|
assert download_id in manager._download_tasks
|
||||||
|
assert download_id in manager._pause_events
|
||||||
|
assert await manager._aria2_state_store.get(download_id) is not None
|
||||||
|
assert cleanup_files.await_count == 0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cancel_download_rejects_completed_history_entry(tmp_path):
|
||||||
|
manager = DownloadManager()
|
||||||
|
download_id = "completed-download"
|
||||||
|
save_path = tmp_path / "file.safetensors"
|
||||||
|
metadata_path = tmp_path / "file.metadata.json"
|
||||||
|
preview_path = tmp_path / "file.jpeg"
|
||||||
|
save_path.write_text("complete")
|
||||||
|
metadata_path.write_text("{}")
|
||||||
|
preview_path.write_text("preview")
|
||||||
|
|
||||||
|
manager._active_downloads[download_id] = {
|
||||||
|
"transfer_backend": "aria2",
|
||||||
|
"status": "completed",
|
||||||
|
"file_path": str(save_path),
|
||||||
|
"preview_path": str(preview_path),
|
||||||
|
}
|
||||||
|
|
||||||
|
result = await manager.cancel_download(download_id)
|
||||||
|
|
||||||
|
assert result == {"success": False, "error": "Download task not found"}
|
||||||
|
assert save_path.exists()
|
||||||
|
assert metadata_path.exists()
|
||||||
|
assert preview_path.exists()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cancel_download_removes_preview_and_aria2_control_files(monkeypatch, tmp_path):
|
||||||
|
manager = DownloadManager()
|
||||||
|
|
||||||
|
started = asyncio.Event()
|
||||||
|
|
||||||
|
async def blocked_task():
|
||||||
|
started.set()
|
||||||
|
await asyncio.sleep(60)
|
||||||
|
|
||||||
|
task = asyncio.create_task(blocked_task())
|
||||||
|
await started.wait()
|
||||||
|
|
||||||
|
save_path = tmp_path / "file.safetensors"
|
||||||
|
save_path.write_text("partial")
|
||||||
|
aria2_path = tmp_path / "file.safetensors.aria2"
|
||||||
|
aria2_path.write_text("control")
|
||||||
|
preview_path = tmp_path / "file.jpeg"
|
||||||
|
preview_path.write_text("preview")
|
||||||
|
|
||||||
|
download_id = "download-queued"
|
||||||
|
manager._download_tasks[download_id] = task
|
||||||
|
manager._active_downloads[download_id] = {
|
||||||
|
"transfer_backend": "aria2",
|
||||||
|
"status": "queued",
|
||||||
|
"file_path": str(save_path),
|
||||||
|
"aria2_control_path": str(aria2_path),
|
||||||
|
"preview_path": str(preview_path),
|
||||||
|
}
|
||||||
|
|
||||||
|
class DummyAria2Downloader:
|
||||||
|
async def cancel_download(self, _download_id):
|
||||||
|
return {"success": True, "message": "cancelled"}
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
download_manager,
|
||||||
|
"get_aria2_downloader",
|
||||||
|
AsyncMock(return_value=DummyAria2Downloader()),
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await manager.cancel_download(download_id)
|
||||||
|
|
||||||
|
assert result["success"] is True
|
||||||
|
assert not save_path.exists()
|
||||||
|
assert not aria2_path.exists()
|
||||||
|
assert not preview_path.exists()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cancel_download_does_not_delete_untracked_same_basename_preview(
|
||||||
|
monkeypatch, tmp_path
|
||||||
|
):
|
||||||
|
manager = DownloadManager()
|
||||||
|
|
||||||
|
started = asyncio.Event()
|
||||||
|
|
||||||
|
async def blocked_task():
|
||||||
|
started.set()
|
||||||
|
await asyncio.sleep(60)
|
||||||
|
|
||||||
|
task = asyncio.create_task(blocked_task())
|
||||||
|
await started.wait()
|
||||||
|
|
||||||
|
save_path = tmp_path / "file.safetensors"
|
||||||
|
save_path.write_text("partial")
|
||||||
|
aria2_path = tmp_path / "file.safetensors.aria2"
|
||||||
|
aria2_path.write_text("control")
|
||||||
|
manual_preview_path = tmp_path / "file.jpg"
|
||||||
|
manual_preview_path.write_text("manual")
|
||||||
|
|
||||||
|
download_id = "download-queued"
|
||||||
|
manager._download_tasks[download_id] = task
|
||||||
|
manager._active_downloads[download_id] = {
|
||||||
|
"transfer_backend": "aria2",
|
||||||
|
"status": "queued",
|
||||||
|
"file_path": str(save_path),
|
||||||
|
"aria2_control_path": str(aria2_path),
|
||||||
|
}
|
||||||
|
|
||||||
|
class DummyAria2Downloader:
|
||||||
|
async def cancel_download(self, _download_id):
|
||||||
|
return {"success": True, "message": "cancelled"}
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
download_manager,
|
||||||
|
"get_aria2_downloader",
|
||||||
|
AsyncMock(return_value=DummyAria2Downloader()),
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await manager.cancel_download(download_id)
|
||||||
|
|
||||||
|
assert result["success"] is True
|
||||||
|
assert not save_path.exists()
|
||||||
|
assert not aria2_path.exists()
|
||||||
|
assert manual_preview_path.exists()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cleanup_cancelled_download_files_retries_aria2_control_deletion(
|
||||||
|
monkeypatch, tmp_path
|
||||||
|
):
|
||||||
|
manager = DownloadManager()
|
||||||
|
download_id = "download-1"
|
||||||
|
|
||||||
|
save_path = tmp_path / "file.safetensors"
|
||||||
|
aria2_path = tmp_path / "file.safetensors.aria2"
|
||||||
|
save_path.write_text("partial")
|
||||||
|
aria2_path.write_text("control")
|
||||||
|
|
||||||
|
original_unlink = os.unlink
|
||||||
|
attempts = {"count": 0}
|
||||||
|
|
||||||
|
def flaky_unlink(path):
|
||||||
|
if path == str(aria2_path) and attempts["count"] == 0:
|
||||||
|
attempts["count"] += 1
|
||||||
|
raise PermissionError("still locked")
|
||||||
|
return original_unlink(path)
|
||||||
|
|
||||||
|
monkeypatch.setattr(download_manager.os, "unlink", flaky_unlink)
|
||||||
|
monkeypatch.setattr("py.services.download_manager.asyncio.sleep", AsyncMock())
|
||||||
|
|
||||||
|
await manager._cleanup_cancelled_download_files(
|
||||||
|
download_id,
|
||||||
|
{
|
||||||
|
"file_path": str(save_path),
|
||||||
|
"aria2_control_path": str(aria2_path),
|
||||||
|
"transfer_backend": "aria2",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert attempts["count"] == 1
|
||||||
|
assert not save_path.exists()
|
||||||
|
assert not aria2_path.exists()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_execute_download_waits_for_paused_pre_transfer_gate(monkeypatch, tmp_path):
|
async def test_execute_download_waits_for_paused_pre_transfer_gate(monkeypatch, tmp_path):
|
||||||
manager = DownloadManager()
|
manager = DownloadManager()
|
||||||
@@ -931,6 +1236,311 @@ async def test_execute_download_waits_for_paused_pre_transfer_gate(monkeypatch,
|
|||||||
assert result == {"success": True}
|
assert result == {"success": True}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_execute_download_reuses_existing_aria2_partial_path(monkeypatch, tmp_path):
|
||||||
|
manager = DownloadManager()
|
||||||
|
|
||||||
|
save_dir = tmp_path / "downloads"
|
||||||
|
save_dir.mkdir()
|
||||||
|
target_path = save_dir / "file.safetensors"
|
||||||
|
target_path.write_text("partial")
|
||||||
|
control_path = save_dir / "file.safetensors.aria2"
|
||||||
|
control_path.write_text("control")
|
||||||
|
|
||||||
|
await manager._aria2_state_store.upsert(
|
||||||
|
"download-1",
|
||||||
|
{
|
||||||
|
"download_id": "download-1",
|
||||||
|
"transfer_backend": "aria2",
|
||||||
|
"save_path": str(target_path),
|
||||||
|
"file_path": str(target_path),
|
||||||
|
"status": "paused",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
class DummyMetadata:
|
||||||
|
def __init__(self, path: Path):
|
||||||
|
self.file_path = str(path)
|
||||||
|
self.sha256 = "sha256"
|
||||||
|
self.file_name = path.stem
|
||||||
|
self.preview_url = None
|
||||||
|
|
||||||
|
def generate_unique_filename(self, *_args, **_kwargs):
|
||||||
|
return "renamed.safetensors"
|
||||||
|
|
||||||
|
def update_file_info(self, _path):
|
||||||
|
return None
|
||||||
|
|
||||||
|
def to_dict(self):
|
||||||
|
return {"file_path": self.file_path}
|
||||||
|
|
||||||
|
manager._active_downloads["download-1"] = {"transfer_backend": "aria2"}
|
||||||
|
dummy_scanner = SimpleNamespace(add_model_to_cache=AsyncMock(return_value=None))
|
||||||
|
monkeypatch.setattr(
|
||||||
|
DownloadManager, "_get_lora_scanner", AsyncMock(return_value=dummy_scanner)
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(MetadataManager, "save_metadata", AsyncMock(return_value=True))
|
||||||
|
|
||||||
|
async def fake_download_model_file(
|
||||||
|
self,
|
||||||
|
download_url,
|
||||||
|
save_path,
|
||||||
|
*,
|
||||||
|
backend,
|
||||||
|
progress_callback,
|
||||||
|
use_auth,
|
||||||
|
download_id,
|
||||||
|
pause_control,
|
||||||
|
):
|
||||||
|
Path(save_path).write_text("content")
|
||||||
|
return True, save_path
|
||||||
|
|
||||||
|
monkeypatch.setattr(DownloadManager, "_download_model_file", fake_download_model_file)
|
||||||
|
|
||||||
|
result = await manager._execute_download(
|
||||||
|
download_urls=["https://example.com/file.safetensors"],
|
||||||
|
save_dir=str(save_dir),
|
||||||
|
metadata=DummyMetadata(target_path),
|
||||||
|
version_info={"images": []},
|
||||||
|
relative_path="",
|
||||||
|
progress_callback=None,
|
||||||
|
model_type="lora",
|
||||||
|
download_id="download-1",
|
||||||
|
transfer_backend="aria2",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result == {"success": True}
|
||||||
|
assert manager._active_downloads["download-1"]["file_path"] == str(target_path)
|
||||||
|
assert not (save_dir / "renamed.safetensors").exists()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_execute_download_rejects_conflicting_aria2_partial_path(tmp_path):
|
||||||
|
manager = DownloadManager()
|
||||||
|
|
||||||
|
save_dir = tmp_path / "downloads"
|
||||||
|
save_dir.mkdir()
|
||||||
|
target_path = save_dir / "file.safetensors"
|
||||||
|
target_path.write_text("partial")
|
||||||
|
(save_dir / "file.safetensors.aria2").write_text("control")
|
||||||
|
|
||||||
|
await manager._aria2_state_store.upsert(
|
||||||
|
"other-download",
|
||||||
|
{
|
||||||
|
"download_id": "other-download",
|
||||||
|
"transfer_backend": "aria2",
|
||||||
|
"save_path": str(target_path),
|
||||||
|
"file_path": str(target_path),
|
||||||
|
"status": "paused",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
class DummyMetadata:
|
||||||
|
def __init__(self, path: Path):
|
||||||
|
self.file_path = str(path)
|
||||||
|
self.sha256 = "sha256"
|
||||||
|
self.file_name = path.stem
|
||||||
|
self.preview_url = None
|
||||||
|
|
||||||
|
def generate_unique_filename(self, *_args, **_kwargs):
|
||||||
|
raise AssertionError("should not rename")
|
||||||
|
|
||||||
|
result = await manager._execute_download(
|
||||||
|
download_urls=["https://example.com/file.safetensors"],
|
||||||
|
save_dir=str(save_dir),
|
||||||
|
metadata=DummyMetadata(target_path),
|
||||||
|
version_info={"images": []},
|
||||||
|
relative_path="",
|
||||||
|
progress_callback=None,
|
||||||
|
model_type="lora",
|
||||||
|
download_id="download-1",
|
||||||
|
transfer_backend="aria2",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result["success"] is False
|
||||||
|
assert "already using" in result["error"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_execute_download_reassigns_same_aria2_partial_to_new_download_id(
|
||||||
|
monkeypatch, tmp_path
|
||||||
|
):
|
||||||
|
manager = DownloadManager()
|
||||||
|
|
||||||
|
save_dir = tmp_path / "downloads"
|
||||||
|
save_dir.mkdir()
|
||||||
|
target_path = save_dir / "file.safetensors"
|
||||||
|
target_path.write_text("partial")
|
||||||
|
(save_dir / "file.safetensors.aria2").write_text("control")
|
||||||
|
|
||||||
|
await manager._aria2_state_store.upsert(
|
||||||
|
"old-download",
|
||||||
|
{
|
||||||
|
"download_id": "old-download",
|
||||||
|
"transfer_backend": "aria2",
|
||||||
|
"save_path": str(target_path),
|
||||||
|
"file_path": str(target_path),
|
||||||
|
"status": "paused",
|
||||||
|
"model_id": 11,
|
||||||
|
"model_version_id": 22,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
class DummyMetadata:
|
||||||
|
def __init__(self, path: Path):
|
||||||
|
self.file_path = str(path)
|
||||||
|
self.sha256 = "sha256"
|
||||||
|
self.file_name = path.stem
|
||||||
|
self.preview_url = None
|
||||||
|
|
||||||
|
def generate_unique_filename(self, *_args, **_kwargs):
|
||||||
|
raise AssertionError("should not rename")
|
||||||
|
|
||||||
|
def update_file_info(self, _path):
|
||||||
|
return None
|
||||||
|
|
||||||
|
def to_dict(self):
|
||||||
|
return {"file_path": self.file_path}
|
||||||
|
|
||||||
|
class DummyAria2Downloader:
|
||||||
|
def __init__(self):
|
||||||
|
self.calls = []
|
||||||
|
|
||||||
|
async def reassign_transfer(self, previous_download_id, new_download_id):
|
||||||
|
self.calls.append(("reassign_transfer", previous_download_id, new_download_id))
|
||||||
|
return None
|
||||||
|
|
||||||
|
dummy_aria2 = DummyAria2Downloader()
|
||||||
|
monkeypatch.setattr(
|
||||||
|
download_manager,
|
||||||
|
"get_aria2_downloader",
|
||||||
|
AsyncMock(return_value=dummy_aria2),
|
||||||
|
)
|
||||||
|
|
||||||
|
manager._active_downloads["old-download"] = {
|
||||||
|
"transfer_backend": "aria2",
|
||||||
|
"model_id": 11,
|
||||||
|
"model_version_id": 22,
|
||||||
|
"status": "paused",
|
||||||
|
}
|
||||||
|
manager._active_downloads["new-download"] = {
|
||||||
|
"transfer_backend": "aria2",
|
||||||
|
"model_id": 11,
|
||||||
|
"model_version_id": 22,
|
||||||
|
"status": "queued",
|
||||||
|
}
|
||||||
|
|
||||||
|
resolved, path = await manager._resolve_download_target_path(
|
||||||
|
str(save_dir),
|
||||||
|
DummyMetadata(target_path),
|
||||||
|
transfer_backend="aria2",
|
||||||
|
download_id="new-download",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert resolved is True
|
||||||
|
assert path == str(target_path)
|
||||||
|
assert "old-download" not in manager._active_downloads
|
||||||
|
assert manager._active_downloads["new-download"]["file_path"] == str(target_path)
|
||||||
|
assert dummy_aria2.calls == [("reassign_transfer", "old-download", "new-download")]
|
||||||
|
assert await manager._aria2_state_store.get("old-download") is None
|
||||||
|
assert (await manager._aria2_state_store.get("new-download"))["save_path"] == str(
|
||||||
|
target_path
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_is_same_aria2_download_request_requires_version_id_match():
|
||||||
|
manager = DownloadManager()
|
||||||
|
|
||||||
|
assert (
|
||||||
|
manager._is_same_aria2_download_request(
|
||||||
|
{"model_id": 1, "model_version_id": None},
|
||||||
|
{"model_id": 1, "model_version_id": 2},
|
||||||
|
)
|
||||||
|
is False
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
manager._is_same_aria2_download_request(
|
||||||
|
{"model_id": 1, "model_version_id": 3},
|
||||||
|
{"model_id": 1, "model_version_id": None},
|
||||||
|
)
|
||||||
|
is False
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_adopt_existing_aria2_download_cancels_old_running_task(monkeypatch, tmp_path):
|
||||||
|
manager = DownloadManager()
|
||||||
|
save_path = tmp_path / "file.safetensors"
|
||||||
|
|
||||||
|
started = asyncio.Event()
|
||||||
|
cancelled = asyncio.Event()
|
||||||
|
call_order = []
|
||||||
|
|
||||||
|
async def old_download():
|
||||||
|
started.set()
|
||||||
|
try:
|
||||||
|
await asyncio.sleep(60)
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
call_order.append("old-task-cancelled")
|
||||||
|
cancelled.set()
|
||||||
|
raise
|
||||||
|
|
||||||
|
old_task = asyncio.create_task(old_download())
|
||||||
|
await started.wait()
|
||||||
|
|
||||||
|
manager._download_tasks["old-download"] = old_task
|
||||||
|
old_pause_control = DownloadStreamControl()
|
||||||
|
old_pause_control.pause()
|
||||||
|
manager._pause_events["old-download"] = old_pause_control
|
||||||
|
manager._active_downloads["old-download"] = {
|
||||||
|
"transfer_backend": "aria2",
|
||||||
|
"model_id": 11,
|
||||||
|
"model_version_id": 22,
|
||||||
|
"status": "downloading",
|
||||||
|
}
|
||||||
|
manager._active_downloads["new-download"] = {
|
||||||
|
"transfer_backend": "aria2",
|
||||||
|
"model_id": 11,
|
||||||
|
"model_version_id": 22,
|
||||||
|
"status": "queued",
|
||||||
|
}
|
||||||
|
|
||||||
|
await manager._aria2_state_store.upsert(
|
||||||
|
"old-download",
|
||||||
|
{
|
||||||
|
"download_id": "old-download",
|
||||||
|
"transfer_backend": "aria2",
|
||||||
|
"save_path": str(save_path),
|
||||||
|
"file_path": str(save_path),
|
||||||
|
"status": "downloading",
|
||||||
|
"model_id": 11,
|
||||||
|
"model_version_id": 22,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
class DummyAria2Downloader:
|
||||||
|
async def reassign_transfer(self, previous_download_id, new_download_id):
|
||||||
|
call_order.append("reassign-transfer")
|
||||||
|
return None
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
download_manager,
|
||||||
|
"get_aria2_downloader",
|
||||||
|
AsyncMock(return_value=DummyAria2Downloader()),
|
||||||
|
)
|
||||||
|
|
||||||
|
await manager._adopt_existing_aria2_download(
|
||||||
|
"old-download",
|
||||||
|
"new-download",
|
||||||
|
{"model_id": 11, "model_version_id": 22},
|
||||||
|
str(save_path),
|
||||||
|
)
|
||||||
|
|
||||||
|
assert cancelled.is_set() is True
|
||||||
|
assert "old-download" not in manager._download_tasks
|
||||||
|
assert call_order == ["reassign-transfer", "old-task-cancelled"]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_pause_download_rejects_unknown_task():
|
async def test_pause_download_rejects_unknown_task():
|
||||||
"""Test that pause_download rejects unknown download tasks."""
|
"""Test that pause_download rejects unknown download tasks."""
|
||||||
|
|||||||
Reference in New Issue
Block a user