From 761108bfd198854ba615b39e5ed0547ef736fc37 Mon Sep 17 00:00:00 2001 From: Will Miao Date: Mon, 20 Apr 2026 09:52:48 +0800 Subject: [PATCH] fix(download): restore aria2 resume lifecycle --- py/services/aria2_downloader.py | 123 ++- py/services/aria2_transfer_state.py | 108 +++ py/services/download_manager.py | 876 ++++++++++++++++-- tests/services/test_aria2_downloader.py | 85 ++ tests/services/test_download_manager_basic.py | 441 +++++++++ tests/services/test_download_manager_error.py | 610 ++++++++++++ 6 files changed, 2123 insertions(+), 120 deletions(-) create mode 100644 py/services/aria2_transfer_state.py diff --git a/py/services/aria2_downloader.py b/py/services/aria2_downloader.py index d1b1b018..f50b6a1c 100644 --- a/py/services/aria2_downloader.py +++ b/py/services/aria2_downloader.py @@ -15,6 +15,7 @@ from typing import Any, Dict, Optional, Tuple import aiohttp from .downloader import DownloadProgress, get_downloader +from .aria2_transfer_state import Aria2TransferStateStore from .settings_manager import get_settings_manager logger = logging.getLogger(__name__) @@ -64,6 +65,7 @@ class Aria2Downloader: self._process_lock = asyncio.Lock() self._transfers: Dict[str, Aria2Transfer] = {} self._poll_interval = 0.5 + self._state_store = Aria2TransferStateStore() @property def is_running(self) -> bool: @@ -82,6 +84,48 @@ class Aria2Downloader: await self._ensure_process() save_path = os.path.abspath(save_path) + transfer = self._transfers.get(download_id) + if transfer is None or os.path.abspath(transfer.save_path) != save_path: + gid = await self._schedule_download( + url, + save_path, + download_id=download_id, + headers=headers, + ) + transfer = Aria2Transfer(gid=gid, save_path=save_path) + self._transfers[download_id] = transfer + + try: + while True: + status = await self.get_status(download_id) + if status is None: + return False, "aria2 download not found" + + snapshot = self._build_progress_snapshot(status) + if progress_callback is not None: + await 
self._dispatch_progress(progress_callback, snapshot) + + state = status.get("status", "") + if state == "complete": + completed_path = self._resolve_completed_path(status, save_path) + return True, completed_path + if state == "error": + return False, status.get("errorMessage") or "aria2 download failed" + if state == "removed": + return False, "Download was cancelled" + + await asyncio.sleep(self._poll_interval) + finally: + self._transfers.pop(download_id, None) + + async def _schedule_download( + self, + url: str, + save_path: str, + *, + download_id: str, + headers: Optional[Dict[str, str]] = None, + ) -> str: save_dir = os.path.dirname(save_path) out_name = os.path.basename(save_path) @@ -128,31 +172,16 @@ class Aria2Downloader: raise Aria2Error(f"Failed to schedule aria2 download: {exc}") from exc logger.debug("aria2 accepted download %s with gid %s", download_id, gid) - - self._transfers[download_id] = Aria2Transfer(gid=gid, save_path=save_path) - - try: - while True: - status = await self.get_status(download_id) - if status is None: - return False, "aria2 download not found" - - snapshot = self._build_progress_snapshot(status) - if progress_callback is not None: - await self._dispatch_progress(progress_callback, snapshot) - - state = status.get("status", "") - if state == "complete": - completed_path = self._resolve_completed_path(status, save_path) - return True, completed_path - if state == "error": - return False, status.get("errorMessage") or "aria2 download failed" - if state == "removed": - return False, "Download was cancelled" - - await asyncio.sleep(self._poll_interval) - finally: - self._transfers.pop(download_id, None) + await self._state_store.upsert( + download_id, + { + "gid": gid, + "save_path": save_path, + "status": "downloading", + "url": url, + }, + ) + return gid async def get_status(self, download_id: str) -> Optional[Dict[str, Any]]: """Return the raw aria2 status payload for a known download.""" @@ -179,6 +208,47 @@ class 
Aria2Downloader: return status return None + async def get_status_by_gid(self, gid: str) -> Optional[Dict[str, Any]]: + keys = [ + "gid", + "status", + "totalLength", + "completedLength", + "downloadSpeed", + "errorMessage", + "files", + ] + try: + status = await self._rpc_call("aria2.tellStatus", [gid, keys]) + except Exception as exc: + message = str(exc) + if "cannot be found" in message.lower() or "not found" in message.lower(): + return None + raise Aria2Error(f"Failed to query aria2 download status: {exc}") from exc + + if isinstance(status, dict): + return status + return None + + async def restore_transfer(self, download_id: str, gid: str, save_path: str) -> None: + await self._ensure_process() + self._transfers[download_id] = Aria2Transfer( + gid=gid, + save_path=os.path.abspath(save_path), + ) + + async def reassign_transfer( + self, from_download_id: str, to_download_id: str + ) -> Optional[Aria2Transfer]: + transfer = self._transfers.get(from_download_id) + if transfer is None: + return None + + self._transfers[to_download_id] = transfer + if from_download_id != to_download_id: + self._transfers.pop(from_download_id, None) + return transfer + async def has_transfer(self, download_id: str) -> bool: return download_id in self._transfers @@ -192,6 +262,7 @@ class Aria2Downloader: except Exception as exc: return {"success": False, "error": str(exc)} + await self._state_store.upsert(download_id, {"status": "paused"}) return {"success": True, "message": "Download paused successfully"} async def resume_download(self, download_id: str) -> Dict[str, Any]: @@ -204,6 +275,7 @@ class Aria2Downloader: except Exception as exc: return {"success": False, "error": str(exc)} + await self._state_store.upsert(download_id, {"status": "downloading"}) return {"success": True, "message": "Download resumed successfully"} async def cancel_download(self, download_id: str) -> Dict[str, Any]: @@ -216,6 +288,7 @@ class Aria2Downloader: except Exception as exc: return {"success": 
False, "error": str(exc)} + await self._state_store.remove(download_id) return {"success": True, "message": "Download cancelled successfully"} async def close(self) -> None: diff --git a/py/services/aria2_transfer_state.py b/py/services/aria2_transfer_state.py new file mode 100644 index 00000000..1754c95d --- /dev/null +++ b/py/services/aria2_transfer_state.py @@ -0,0 +1,108 @@ +from __future__ import annotations + +import asyncio +import json +import os +from copy import deepcopy +from typing import Any, Dict, Optional + +from ..utils.cache_paths import get_cache_base_dir + + +def get_aria2_state_path() -> str: + base_dir = get_cache_base_dir(create=True) + state_dir = os.path.join(base_dir, "aria2") + os.makedirs(state_dir, exist_ok=True) + return os.path.join(state_dir, "downloads.json") + + +class Aria2TransferStateStore: + """Persist aria2 transfer metadata needed for restart recovery.""" + + _locks_by_path: Dict[str, asyncio.Lock] = {} + + def __init__(self, state_path: Optional[str] = None) -> None: + self._state_path = os.path.abspath(state_path or get_aria2_state_path()) + self._lock = self._locks_by_path.setdefault(self._state_path, asyncio.Lock()) + + def _read_all_unlocked(self) -> Dict[str, Dict[str, Any]]: + try: + with open(self._state_path, "r", encoding="utf-8") as handle: + data = json.load(handle) + except FileNotFoundError: + return {} + except json.JSONDecodeError: + return {} + + if not isinstance(data, dict): + return {} + + normalized: Dict[str, Dict[str, Any]] = {} + for download_id, entry in data.items(): + if isinstance(download_id, str) and isinstance(entry, dict): + normalized[download_id] = entry + return normalized + + def _write_all_unlocked(self, data: Dict[str, Dict[str, Any]]) -> None: + directory = os.path.dirname(self._state_path) + if directory: + os.makedirs(directory, exist_ok=True) + + temp_path = f"{self._state_path}.tmp" + with open(temp_path, "w", encoding="utf-8") as handle: + json.dump(data, handle, ensure_ascii=True, 
indent=2, sort_keys=True) + os.replace(temp_path, self._state_path) + + async def load_all(self) -> Dict[str, Dict[str, Any]]: + async with self._lock: + return deepcopy(self._read_all_unlocked()) + + async def get(self, download_id: str) -> Optional[Dict[str, Any]]: + async with self._lock: + return deepcopy(self._read_all_unlocked().get(download_id)) + + async def upsert(self, download_id: str, payload: Dict[str, Any]) -> Dict[str, Any]: + async with self._lock: + data = self._read_all_unlocked() + current = data.get(download_id, {}) + current.update(payload) + data[download_id] = current + self._write_all_unlocked(data) + return deepcopy(current) + + async def remove(self, download_id: str) -> None: + async with self._lock: + data = self._read_all_unlocked() + if download_id in data: + del data[download_id] + self._write_all_unlocked(data) + + async def find_by_save_path( + self, save_path: str, *, exclude_download_id: Optional[str] = None + ) -> Optional[Dict[str, Any]]: + normalized_target = os.path.abspath(save_path) + async with self._lock: + data = self._read_all_unlocked() + for download_id, entry in data.items(): + if exclude_download_id and download_id == exclude_download_id: + continue + candidate = entry.get("save_path") + if isinstance(candidate, str) and os.path.abspath(candidate) == normalized_target: + result = dict(entry) + result["download_id"] = download_id + return result + return None + + async def reassign(self, from_download_id: str, to_download_id: str) -> Optional[Dict[str, Any]]: + async with self._lock: + data = self._read_all_unlocked() + existing = data.get(from_download_id) + if existing is None: + return None + updated = dict(existing) + updated["download_id"] = to_download_id + data[to_download_id] = updated + if from_download_id != to_download_id: + data.pop(from_download_id, None) + self._write_all_unlocked(data) + return deepcopy(updated) diff --git a/py/services/download_manager.py b/py/services/download_manager.py index 
ce727abd..5b297356 100644 --- a/py/services/download_manager.py +++ b/py/services/download_manager.py @@ -27,6 +27,7 @@ from .settings_manager import get_settings_manager from .metadata_service import get_default_metadata_provider, get_metadata_provider from .downloader import get_downloader, DownloadProgress, DownloadStreamControl from .aria2_downloader import Aria2Error, get_aria2_downloader +from .aria2_transfer_state import Aria2TransferStateStore # Download to temporary file first import tempfile @@ -65,6 +66,9 @@ class DownloadManager: self._archive_executor = ThreadPoolExecutor( max_workers=2, thread_name_prefix="lm-archive" ) + self._aria2_state_store = Aria2TransferStateStore() + self._restored_persisted_downloads = False + self._restore_lock = asyncio.Lock() @staticmethod def _get_model_download_backend() -> str: @@ -179,6 +183,11 @@ class DownloadManager: self._active_downloads[task_id] = { "model_id": model_id, "model_version_id": model_version_id, + "save_dir": save_dir, + "relative_path": relative_path, + "use_default_paths": bool(use_default_paths), + "source": source, + "file_params": copy.deepcopy(file_params) if file_params is not None else None, "progress": 0, "status": "queued", "transfer_backend": self._get_model_download_backend(), @@ -191,6 +200,9 @@ class DownloadManager: pause_control = DownloadStreamControl() self._pause_events[task_id] = pause_control + if self._active_downloads[task_id]["transfer_backend"] == "aria2": + await self._persist_aria2_state(task_id) + # Create tracking task download_task = asyncio.create_task( self._download_with_semaphore( @@ -242,6 +254,8 @@ class DownloadManager: # Update status to waiting if task_id in self._active_downloads: self._active_downloads[task_id]["status"] = "waiting" + if self._active_downloads[task_id].get("transfer_backend") == "aria2": + await self._persist_aria2_state(task_id) # Wrap progress callback to track progress in active_downloads original_callback = progress_callback @@ -276,11 
+290,15 @@ class DownloadManager: if task_id in self._active_downloads: self._active_downloads[task_id]["status"] = "paused" self._active_downloads[task_id]["bytes_per_second"] = 0.0 + if self._active_downloads[task_id].get("transfer_backend") == "aria2": + await self._persist_aria2_state(task_id) await pause_control.wait() # Update status to downloading if task_id in self._active_downloads: self._active_downloads[task_id]["status"] = "downloading" + if self._active_downloads[task_id].get("transfer_backend") == "aria2": + await self._persist_aria2_state(task_id) # Use original download implementation try: @@ -315,6 +333,8 @@ class DownloadManager: "error", "Unknown error" ) self._active_downloads[task_id]["bytes_per_second"] = 0.0 + if self._active_downloads[task_id].get("transfer_backend") == "aria2": + await self._persist_aria2_state(task_id) return result except asyncio.CancelledError: @@ -322,6 +342,8 @@ class DownloadManager: if task_id in self._active_downloads: self._active_downloads[task_id]["status"] = "cancelled" self._active_downloads[task_id]["bytes_per_second"] = 0.0 + if self._active_downloads[task_id].get("transfer_backend") == "aria2": + await self._persist_aria2_state(task_id) logger.info(f"Download cancelled for task {task_id}") raise except Exception as e: @@ -333,17 +355,639 @@ class DownloadManager: self._active_downloads[task_id]["status"] = "failed" self._active_downloads[task_id]["error"] = str(e) self._active_downloads[task_id]["bytes_per_second"] = 0.0 + if self._active_downloads[task_id].get("transfer_backend") == "aria2": + await self._persist_aria2_state(task_id) return {"success": False, "error": str(e)} finally: # Schedule cleanup of download record after delay asyncio.create_task(self._cleanup_download_record(task_id)) + def _start_background_download_task(self, download_id: str, coroutine) -> asyncio.Task: + task = asyncio.create_task(coroutine) + self._download_tasks[download_id] = task + + def _cleanup_done_task(done_task: 
asyncio.Task) -> None: + current_task = self._download_tasks.get(download_id) + if current_task is done_task: + self._download_tasks.pop(download_id, None) + self._pause_events.pop(download_id, None) + + task.add_done_callback(_cleanup_done_task) + return task + async def _cleanup_download_record(self, task_id: str): """Keep completed downloads in history for a short time""" await asyncio.sleep(600) # Keep for 10 minutes if task_id in self._active_downloads: del self._active_downloads[task_id] + async def _delete_file_with_retries( + self, + path: Optional[str], + *, + retries: int = 5, + delay: float = 0.1, + ) -> bool: + if not path: + return False + + for attempt in range(retries): + if not os.path.exists(path): + return True + try: + os.unlink(path) + return True + except FileNotFoundError: + return True + except Exception: + if attempt == retries - 1: + return False + await asyncio.sleep(delay) + return False + + async def _cleanup_cancelled_download_files( + self, + download_id: str, + download_info: Optional[Dict], + ) -> None: + target_files = set() + persisted = await self._aria2_state_store.get(download_id) + + primary_path = None + if isinstance(download_info, dict): + primary_path = download_info.get("file_path") + if not primary_path and isinstance(persisted, dict): + primary_path = persisted.get("save_path") or persisted.get("file_path") + if primary_path: + target_files.add(primary_path) + + if isinstance(download_info, dict): + for extra_path in download_info.get("extracted_paths", []): + if extra_path: + target_files.add(extra_path) + + for file_path in target_files: + deleted = await self._delete_file_with_retries(file_path) + if deleted: + logger.debug(f"Deleted cancelled download: {file_path}") + elif os.path.exists(file_path): + logger.error(f"Error deleting file: {file_path}") + + part_path = None + if isinstance(download_info, dict): + part_path = download_info.get("part_path") + if part_path: + deleted = await 
self._delete_file_with_retries(part_path) + if deleted: + logger.debug(f"Deleted partial download: {part_path}") + elif os.path.exists(part_path): + logger.error(f"Error deleting part file: {part_path}") + + aria2_control_path = None + if isinstance(download_info, dict): + aria2_control_path = download_info.get("aria2_control_path") + if not aria2_control_path and primary_path: + aria2_control_path = f"{primary_path}.aria2" + if aria2_control_path: + deleted = await self._delete_file_with_retries(aria2_control_path) + if deleted: + logger.debug(f"Deleted aria2 control file: {aria2_control_path}") + elif os.path.exists(aria2_control_path): + logger.warning( + "Failed to delete aria2 control file after retries: %s", + aria2_control_path, + ) + + for file_path in target_files: + metadata_path = os.path.splitext(file_path)[0] + ".metadata.json" + deleted = await self._delete_file_with_retries(metadata_path) + if not deleted and os.path.exists(metadata_path): + logger.error(f"Error deleting metadata file: {metadata_path}") + + preview_candidates = set() + if isinstance(download_info, dict): + preview_path_value = download_info.get("preview_path") + if preview_path_value: + preview_candidates.add(preview_path_value) + + for preview_path in preview_candidates: + deleted = await self._delete_file_with_retries(preview_path) + if deleted and not os.path.exists(preview_path): + logger.debug(f"Deleted preview file: {preview_path}") + elif os.path.exists(preview_path): + logger.error(f"Error deleting preview file: {preview_path}") + + async def _persist_aria2_state( + self, + download_id: str, + *, + extra: Optional[Dict] = None, + ) -> None: + info = self._active_downloads.get(download_id) + if not info: + return + + payload = { + "download_id": download_id, + "model_id": info.get("model_id"), + "model_version_id": info.get("model_version_id"), + "save_dir": info.get("save_dir"), + "relative_path": info.get("relative_path", ""), + "use_default_paths": 
bool(info.get("use_default_paths", False)), + "source": info.get("source"), + "file_params": copy.deepcopy(info.get("file_params")), + "transfer_backend": info.get("transfer_backend", "aria2"), + "status": info.get("status", "queued"), + "progress": info.get("progress", 0), + "bytes_downloaded": info.get("bytes_downloaded", 0), + "total_bytes": info.get("total_bytes"), + "bytes_per_second": info.get("bytes_per_second", 0.0), + "file_path": info.get("file_path"), + } + if extra: + payload.update(extra) + + await self._aria2_state_store.upsert(download_id, payload) + + def _build_restored_download_info(self, record: Dict, save_path: str) -> Dict: + return { + "model_id": record.get("model_id"), + "model_version_id": record.get("model_version_id"), + "save_dir": record.get("save_dir"), + "relative_path": record.get("relative_path", ""), + "use_default_paths": bool(record.get("use_default_paths", False)), + "source": record.get("source"), + "file_params": copy.deepcopy(record.get("file_params")), + "progress": record.get("progress", 0), + "status": record.get("status", "paused"), + "transfer_backend": "aria2", + "bytes_downloaded": record.get("bytes_downloaded", 0), + "total_bytes": record.get("total_bytes"), + "bytes_per_second": record.get("bytes_per_second", 0.0), + "last_progress_timestamp": None, + "file_path": save_path, + "aria2_control_path": f"{save_path}.aria2", + } + + def _is_same_aria2_download_request( + self, + current_info: Optional[Dict], + persisted_record: Dict, + ) -> bool: + if not isinstance(current_info, dict): + return False + + current_version_id = current_info.get("model_version_id") + persisted_version_id = persisted_record.get("model_version_id") + if current_version_id is None or persisted_version_id is None: + return False + + return current_version_id == persisted_version_id + + def _build_download_urls_from_file_info(self, file_info: Dict, source: str = None) -> List[str]: + mirrors = file_info.get("mirrors") or [] + download_urls: 
List[str] = [] + if mirrors: + for mirror in mirrors: + if mirror.get("deletedAt") is None and mirror.get("url"): + download_urls.append(normalize_civitai_download_url(mirror["url"])) + + if source == "civarchive" and len(download_urls) > 1: + civitai_urls = [ + u for u in download_urls if u.startswith(CIVITAI_DOWNLOAD_URL_PREFIXES) + ] + non_civitai_urls = [ + u for u in download_urls if not u.startswith(CIVITAI_DOWNLOAD_URL_PREFIXES) + ] + download_urls = non_civitai_urls + civitai_urls + else: + download_url = file_info.get("downloadUrl") + if download_url: + download_urls.append(normalize_civitai_download_url(download_url)) + + return download_urls + + def _build_metadata_for_resume( + self, + *, + model_type: str, + version_info: Dict, + file_info: Dict, + save_path: str, + ): + if model_type == "checkpoint": + return CheckpointMetadata.from_civitai_info(version_info, file_info, save_path) + if model_type == "embedding": + return EmbeddingMetadata.from_civitai_info(version_info, file_info, save_path) + return LoraMetadata.from_civitai_info(version_info, file_info, save_path) + + def _resolve_save_path_from_persisted_record(self, record: Dict) -> Optional[str]: + save_path = record.get("save_path") or record.get("file_path") + if isinstance(save_path, str) and save_path: + return os.path.abspath(save_path) + + resume_context = record.get("resume_context") + if not isinstance(resume_context, dict): + return None + + save_dir = resume_context.get("save_dir") + file_info = resume_context.get("file_info") + if not isinstance(save_dir, str) or not save_dir: + return None + if not isinstance(file_info, dict): + return None + + file_name = file_info.get("name") + if not isinstance(file_name, str) or not file_name: + return None + + return os.path.abspath(os.path.join(save_dir, file_name)) + + async def _resume_restored_aria2_download(self, download_id: str, record: Dict) -> Dict: + try: + if download_id in self._active_downloads: + 
self._active_downloads[download_id]["status"] = "downloading" + self._active_downloads[download_id]["bytes_per_second"] = 0.0 + if self._active_downloads[download_id].get("transfer_backend") == "aria2": + await self._persist_aria2_state(download_id) + + resume_context = record.get("resume_context") + if not isinstance(resume_context, dict): + result = {"success": False, "error": "Missing aria2 resume context"} + else: + version_info = copy.deepcopy(resume_context.get("version_info") or {}) + file_info = copy.deepcopy(resume_context.get("file_info") or {}) + model_type = (resume_context.get("model_type") or "").lower() + relative_path = resume_context.get("relative_path", "") + save_dir = resume_context.get("save_dir") + source = record.get("source") + + if not version_info or not file_info or not model_type or not save_dir: + result = {"success": False, "error": "Incomplete aria2 resume context"} + else: + save_path = ( + record.get("save_path") + or record.get("file_path") + or os.path.join(save_dir, file_info.get("name", "")) + ) + metadata = self._build_metadata_for_resume( + model_type=model_type, + version_info=version_info, + file_info=file_info, + save_path=save_path, + ) + download_urls = resume_context.get("download_urls") + if not isinstance(download_urls, list) or not download_urls: + download_urls = self._build_download_urls_from_file_info( + file_info, source=source + ) + if not download_urls: + result = {"success": False, "error": "No mirror URL found"} + else: + result = await self._execute_download( + download_urls=download_urls, + save_dir=save_dir, + metadata=metadata, + version_info=version_info, + relative_path=relative_path, + progress_callback=None, + model_type=model_type, + download_id=download_id, + transfer_backend="aria2", + ) + + if result.get("success", False): + resolved_model_id = ( + record.get("model_id") + or version_info.get("modelId") + or (version_info.get("model") or {}).get("id") + ) + await 
self._record_downloaded_version_history( + model_type, + resolved_model_id, + version_info, + record.get("model_version_id"), + record.get("save_path") or record.get("file_path"), + ) + await self._sync_downloaded_version( + model_type, + resolved_model_id, + version_info, + record.get("model_version_id"), + ) + + if download_id in self._active_downloads: + self._active_downloads[download_id]["status"] = ( + result.get("status", "completed") + if result["success"] + else "failed" + ) + if not result["success"]: + self._active_downloads[download_id]["error"] = result.get( + "error", "Unknown error" + ) + self._active_downloads[download_id]["bytes_per_second"] = 0.0 + if self._active_downloads[download_id].get("transfer_backend") == "aria2": + await self._persist_aria2_state(download_id) + + return result + except asyncio.CancelledError: + if download_id in self._active_downloads: + self._active_downloads[download_id]["status"] = "cancelled" + self._active_downloads[download_id]["bytes_per_second"] = 0.0 + if self._active_downloads[download_id].get("transfer_backend") == "aria2": + await self._persist_aria2_state(download_id) + logger.info(f"Download cancelled for task {download_id}") + raise + except Exception as exc: + logger.error( + f"Download error for task {download_id}: {str(exc)}", exc_info=True + ) + if download_id in self._active_downloads: + self._active_downloads[download_id]["status"] = "failed" + self._active_downloads[download_id]["error"] = str(exc) + self._active_downloads[download_id]["bytes_per_second"] = 0.0 + if self._active_downloads[download_id].get("transfer_backend") == "aria2": + await self._persist_aria2_state(download_id) + return {"success": False, "error": str(exc)} + finally: + asyncio.create_task(self._cleanup_download_record(download_id)) + + async def _adopt_existing_aria2_download( + self, + previous_download_id: str, + new_download_id: str, + persisted_record: Dict, + save_path: str, + ) -> None: + aria2_downloader = await 
get_aria2_downloader() + await aria2_downloader.reassign_transfer(previous_download_id, new_download_id) + + old_task = self._download_tasks.get(previous_download_id) + if old_task is not None and not old_task.done(): + old_task.cancel() + old_pause_control = self._pause_events.get(previous_download_id) + if old_pause_control is not None: + old_pause_control.resume() + try: + await asyncio.wait_for(asyncio.shield(old_task), timeout=2.0) + except (asyncio.CancelledError, asyncio.TimeoutError): + pass + + if previous_download_id != new_download_id: + self._active_downloads.pop(previous_download_id, None) + self._pause_events.pop(previous_download_id, None) + self._download_tasks.pop(previous_download_id, None) + + reassigned = await self._aria2_state_store.reassign( + previous_download_id, new_download_id + ) + merged_record = dict(persisted_record) + if reassigned: + merged_record.update(reassigned) + + current_info = self._active_downloads.get(new_download_id) + if current_info is not None: + current_info.update( + { + "model_id": merged_record.get("model_id", current_info.get("model_id")), + "model_version_id": merged_record.get( + "model_version_id", current_info.get("model_version_id") + ), + "save_dir": merged_record.get("save_dir", current_info.get("save_dir")), + "relative_path": merged_record.get( + "relative_path", current_info.get("relative_path", "") + ), + "source": merged_record.get("source", current_info.get("source")), + "file_params": copy.deepcopy( + merged_record.get("file_params", current_info.get("file_params")) + ), + "file_path": save_path, + "aria2_control_path": f"{save_path}.aria2", + } + ) + else: + self._active_downloads[new_download_id] = self._build_restored_download_info( + merged_record, save_path + ) + + async def _restore_persisted_downloads(self) -> None: + if self._restored_persisted_downloads: + return + + async with self._restore_lock: + if self._restored_persisted_downloads: + return + + persisted = await 
self._aria2_state_store.load_all() + if not persisted: + self._restored_persisted_downloads = True + return + + aria2_downloader = await get_aria2_downloader() + for download_id, record in persisted.items(): + if record.get("transfer_backend") != "aria2": + continue + + save_path = self._resolve_save_path_from_persisted_record(record) + if save_path is None: + continue + + if ( + record.get("save_path") != save_path + or record.get("file_path") != save_path + ): + await self._aria2_state_store.upsert( + download_id, + { + "save_path": save_path, + "file_path": save_path, + }, + ) + control_path = f"{save_path}.aria2" + gid = record.get("gid") + status_payload = None + if isinstance(gid, str) and gid: + try: + status_payload = await aria2_downloader.get_status_by_gid(gid) + except Exception: + status_payload = None + + if status_payload is not None: + remote_status = status_payload.get("status", "") + if remote_status in {"active", "waiting", "paused"}: + await aria2_downloader.restore_transfer(download_id, gid, save_path) + restored = self._active_downloads.setdefault( + download_id, + self._build_restored_download_info(record, save_path), + ) + restored["status"] = ( + "paused" if remote_status == "paused" else "downloading" + ) + pause_control = self._pause_events.get(download_id) + if pause_control is None: + pause_control = DownloadStreamControl() + self._pause_events[download_id] = pause_control + if remote_status == "paused": + pause_control.pause() + else: + pause_control.resume() + await self._aria2_state_store.upsert( + download_id, + { + "gid": gid, + "save_path": save_path, + "file_path": save_path, + "status": restored["status"], + }, + ) + if ( + remote_status in {"active", "waiting"} + and download_id not in self._download_tasks + ): + resume_context = record.get("resume_context") + if isinstance(resume_context, dict): + self._start_background_download_task( + download_id, + self._resume_restored_aria2_download( + download_id, + dict(record), + ) + ) 
+ else: + self._start_background_download_task( + download_id, + self._download_with_semaphore( + download_id, + restored.get("model_id"), + restored.get("model_version_id"), + restored.get("save_dir"), + restored.get("relative_path", ""), + None, + bool(restored.get("use_default_paths", False)), + restored.get("source"), + restored.get("file_params"), + ) + ) + continue + + if remote_status == "complete" and not os.path.exists(control_path): + await self._aria2_state_store.remove(download_id) + continue + + if os.path.exists(save_path) and os.path.exists(control_path): + restored = self._active_downloads.setdefault( + download_id, + self._build_restored_download_info(record, save_path), + ) + pause_control = self._pause_events.get(download_id) + if pause_control is None: + pause_control = DownloadStreamControl() + self._pause_events[download_id] = pause_control + + # No live aria2 gid was found, so restore this partial as resumable-but-paused. + pause_control.pause() + restored["status"] = "paused" + await self._aria2_state_store.upsert( + download_id, + { + "save_path": save_path, + "file_path": save_path, + "status": "paused", + }, + ) + continue + + await self._aria2_state_store.remove(download_id) + + self._restored_persisted_downloads = True + + async def _resolve_download_target_path( + self, + save_dir: str, + metadata, + *, + transfer_backend: str, + download_id: Optional[str], + ) -> Tuple[bool, str]: + original_filename = os.path.basename(metadata.file_path) + base_name, extension = os.path.splitext(original_filename) + original_path = os.path.join(save_dir, original_filename) + + if transfer_backend == "aria2": + control_path = f"{original_path}.aria2" + if os.path.exists(original_path) and os.path.exists(control_path): + persisted_record = None + if download_id: + persisted_record = await self._aria2_state_store.get(download_id) + if persisted_record: + persisted_path = ( + persisted_record.get("save_path") + or persisted_record.get("file_path") + ) + 
if isinstance(persisted_path, str) and os.path.abspath( + persisted_path + ) == os.path.abspath(original_path): + logger.info( + "Reusing aria2 partial target %s for %s", + original_path, + download_id, + ) + return True, original_path + + conflict_record = await self._aria2_state_store.find_by_save_path( + original_path, exclude_download_id=download_id + ) + if conflict_record is not None: + current_info = self._active_downloads.get(download_id) if download_id else None + if download_id and self._is_same_aria2_download_request( + current_info, conflict_record + ): + logger.info( + "Reassigning aria2 partial target %s from %s to %s", + original_path, + conflict_record.get("download_id"), + download_id, + ) + await self._adopt_existing_aria2_download( + conflict_record["download_id"], + download_id, + conflict_record, + original_path, + ) + return True, original_path + + return ( + False, + f"Another aria2 download is already using '{original_filename}' for resume", + ) + + if download_id: + logger.info( + "Reusing aria2 partial target %s for %s", + original_path, + download_id, + ) + return True, original_path + + def hash_provider(): + return metadata.sha256 + + unique_filename = metadata.generate_unique_filename( + save_dir, base_name, extension, hash_provider=hash_provider + ) + + if unique_filename != original_filename: + logger.info( + f"Filename conflict detected. Changing '{original_filename}' to '{unique_filename}'" + ) + save_path = os.path.join(save_dir, unique_filename) + metadata.file_path = save_path.replace(os.sep, "/") + metadata.file_name = os.path.splitext(unique_filename)[0] + return True, save_path + + return True, metadata.file_path + async def _execute_original_download( self, model_id, @@ -756,6 +1400,23 @@ class DownloadManager: logger.info(f"Creating EmbeddingMetadata for {file_name}") # 6. 
Start download process + if transfer_backend == "aria2" and download_id: + await self._persist_aria2_state( + download_id, + extra={ + "save_dir": save_dir, + "relative_path": relative_path, + "resume_context": { + "version_info": copy.deepcopy(version_info), + "file_info": copy.deepcopy(file_info), + "model_type": model_type, + "relative_path": relative_path, + "save_dir": save_dir, + "download_urls": copy.deepcopy(download_urls), + }, + }, + ) + execute_kwargs = { "download_urls": download_urls, "save_dir": save_dir, @@ -1048,30 +1709,14 @@ class DownloadManager: preview_nsfw_level = 0 transfer_backend = (transfer_backend or self._get_model_download_backend()).lower() try: - # Extract original filename details - original_filename = os.path.basename(metadata.file_path) - base_name, extension = os.path.splitext(original_filename) - - # Check for filename conflicts and generate unique filename if needed - # Use the hash from metadata for conflict resolution - def hash_provider(): - return metadata.sha256 - - unique_filename = metadata.generate_unique_filename( - save_dir, base_name, extension, hash_provider=hash_provider + resolved, save_path = await self._resolve_download_target_path( + save_dir, + metadata, + transfer_backend=transfer_backend, + download_id=download_id, ) - - # Update paths if filename changed - if unique_filename != original_filename: - logger.info( - f"Filename conflict detected. 
Changing '{original_filename}' to '{unique_filename}'" - ) - save_path = os.path.join(save_dir, unique_filename) - # Update metadata with new file path and name - metadata.file_path = save_path.replace(os.sep, "/") - metadata.file_name = os.path.splitext(unique_filename)[0] - else: - save_path = metadata.file_path + if not resolved: + return {"success": False, "error": save_path} part_path = save_path + ".part" metadata_path = os.path.splitext(save_path)[0] + ".metadata.json" @@ -1081,7 +1726,12 @@ class DownloadManager: # Store file paths in active_downloads for potential cleanup if download_id and download_id in self._active_downloads: self._active_downloads[download_id]["file_path"] = save_path - self._active_downloads[download_id]["part_path"] = part_path + if transfer_backend == "python": + self._active_downloads[download_id]["part_path"] = part_path + if transfer_backend == "aria2": + self._active_downloads[download_id]["aria2_control_path"] = ( + f"{save_path}.aria2" + ) # Download preview image if available images = version_info.get("images", []) @@ -1205,6 +1855,8 @@ class DownloadManager: preview_nsfw_level = nsfw_level metadata.preview_url = preview_path.replace(os.sep, "/") metadata.preview_nsfw_level = nsfw_level + if download_id and download_id in self._active_downloads: + self._active_downloads[download_id]["preview_path"] = preview_path if progress_callback: await progress_callback(3) # 3% progress after preview download @@ -1226,6 +1878,18 @@ class DownloadManager: for download_url in download_urls: download_url = normalize_civitai_download_url(download_url) use_auth = download_url.startswith(CIVITAI_DOWNLOAD_URL_PREFIXES) + if transfer_backend == "aria2" and download_id: + await self._persist_aria2_state( + download_id, + extra={ + "status": self._active_downloads.get(download_id, {}).get( + "status", "downloading" + ), + "save_path": save_path, + "file_path": save_path, + "url": download_url, + }, + ) success, result = await 
self._download_model_file( download_url, save_path, @@ -1267,9 +1931,20 @@ class DownloadManager: except Exception as e: logger.warning(f"Failed to cleanup file {path}: {e}") - # Log but don't remove .part file to allow resume - if os.path.exists(part_path): + # Keep resumable partial state for the matching backend. + if transfer_backend == "python" and os.path.exists(part_path): logger.info(f"Preserving partial download for resume: {part_path}") + elif transfer_backend == "aria2" and os.path.exists(f"{save_path}.aria2"): + logger.info("Preserving aria2 partial download for resume: %s", save_path) + if download_id: + await self._persist_aria2_state( + download_id, + extra={ + "status": "failed", + "save_path": save_path, + "file_path": save_path, + }, + ) return { "success": False, @@ -1384,6 +2059,9 @@ class DownloadManager: if scanner is not None: await scanner.add_model_to_cache(metadata_dict, relative_path) + if transfer_backend == "aria2" and download_id: + await self._aria2_state_store.remove(download_id) + # Report 100% completion if progress_callback: await progress_callback(100) @@ -1586,11 +2264,22 @@ class DownloadManager: Returns: Dict: Status of the cancellation operation """ - if download_id not in self._download_tasks: + await self._restore_persisted_downloads() + + if download_id not in self._download_tasks and download_id not in self._active_downloads: return {"success": False, "error": "Download task not found"} + download_info = self._active_downloads.get(download_id) + task = self._download_tasks.get(download_id) + active_statuses = {"queued", "waiting", "downloading", "paused", "cancelling"} + if task is None and ( + not isinstance(download_info, dict) + or download_info.get("status") not in active_statuses + ): + return {"success": False, "error": "Download task not found"} + + should_cleanup_local_tracking = False try: - task = self._download_tasks[download_id] backend = ( self._active_downloads.get(download_id, {}).get("transfer_backend") or 
"python" @@ -1605,14 +2294,19 @@ class DownloadManager: and cancel_result.get("error") != "Download task not found" ): return cancel_result + should_cleanup_local_tracking = True except Exception as exc: logger.warning( "Failed to cancel aria2 transfer for %s, continuing with local task cancellation: %s", download_id, exc, ) + should_cleanup_local_tracking = True + else: + should_cleanup_local_tracking = True - task.cancel() + if task is not None: + task.cancel() pause_control = self._pause_events.get(download_id) if pause_control is not None: @@ -1624,83 +2318,31 @@ class DownloadManager: self._active_downloads[download_id]["bytes_per_second"] = 0.0 # Wait briefly for the task to acknowledge cancellation - try: - await asyncio.wait_for(asyncio.shield(task), timeout=2.0) - except (asyncio.CancelledError, asyncio.TimeoutError): - pass + if task is not None: + try: + await asyncio.wait_for(asyncio.shield(task), timeout=2.0) + except (asyncio.CancelledError, asyncio.TimeoutError): + pass # Clean up ALL files including .part when user cancels download_info = self._active_downloads.get(download_id) - if download_info: - target_files = set() - primary_path = download_info.get("file_path") - if primary_path: - target_files.add(primary_path) - - for extra_path in download_info.get("extracted_paths", []): - if extra_path: - target_files.add(extra_path) - - for file_path in target_files: - if os.path.exists(file_path): - try: - os.unlink(file_path) - logger.debug(f"Deleted cancelled download: {file_path}") - except Exception as e: - logger.error(f"Error deleting file: {e}") - - # Delete the .part file (only on user cancellation) - if "part_path" in download_info: - part_path = download_info["part_path"] - if os.path.exists(part_path): - try: - os.unlink(part_path) - logger.debug(f"Deleted partial download: {part_path}") - except Exception as e: - logger.error(f"Error deleting part file: {e}") - - # Delete metadata files for each resolved path - for file_path in target_files: 
- metadata_path = os.path.splitext(file_path)[0] + ".metadata.json" - if os.path.exists(metadata_path): - try: - os.unlink(metadata_path) - except Exception as e: - logger.error(f"Error deleting metadata file: {e}") - - preview_path_value = download_info.get("preview_path") - if preview_path_value and os.path.exists(preview_path_value): - try: - os.unlink(preview_path_value) - logger.debug(f"Deleted preview file: {preview_path_value}") - except Exception as e: - logger.error( - f"Error deleting preview file: {preview_path_value}" - ) - - # Delete preview file if exists (.webp or .mp4) for legacy paths - for file_path in target_files: - for preview_ext in [".webp", ".mp4"]: - preview_path = os.path.splitext(file_path)[0] + preview_ext - if os.path.exists(preview_path): - try: - os.unlink(preview_path) - logger.debug(f"Deleted preview file: {preview_path}") - except Exception as e: - logger.error( - f"Error deleting preview file: {preview_path}" - ) + await self._cleanup_cancelled_download_files(download_id, download_info) return {"success": True, "message": "Download cancelled successfully"} except Exception as e: logger.error(f"Error cancelling download: {e}", exc_info=True) return {"success": False, "error": str(e)} finally: - self._pause_events.pop(download_id, None) + if should_cleanup_local_tracking: + self._pause_events.pop(download_id, None) + self._download_tasks.pop(download_id, None) + await self._aria2_state_store.remove(download_id) async def pause_download(self, download_id: str) -> Dict: """Pause an active download without losing progress.""" - if download_id not in self._download_tasks: + await self._restore_persisted_downloads() + + if download_id not in self._download_tasks and download_id not in self._active_downloads: return {"success": False, "error": "Download task not found"} pause_control = self._pause_events.get(download_id) @@ -1732,6 +2374,7 @@ class DownloadManager: if download_info is not None: download_info["status"] = "paused" 
download_info["bytes_per_second"] = 0.0 + await self._persist_aria2_state(download_id) return {"success": True, "message": "Download paused successfully"} download_info = self._active_downloads.get(download_id) @@ -1744,9 +2387,22 @@ class DownloadManager: async def resume_download(self, download_id: str) -> Dict: """Resume a previously paused download.""" + await self._restore_persisted_downloads() + pause_control = self._pause_events.get(download_id) if pause_control is None: - return {"success": False, "error": "Download task not found"} + persisted = await self._aria2_state_store.get(download_id) + if not persisted or persisted.get("transfer_backend") != "aria2": + return {"success": False, "error": "Download task not found"} + + save_path = persisted.get("save_path") or persisted.get("file_path") + pause_control = DownloadStreamControl() + pause_control.pause() + self._pause_events[download_id] = pause_control + self._active_downloads[download_id] = self._build_restored_download_info( + persisted, + os.path.abspath(save_path), + ) if pause_control.is_set(): return {"success": False, "error": "Download is not paused"} @@ -1758,11 +2414,39 @@ class DownloadManager: ) if backend == "aria2": try: + persisted = None + if download_id not in self._download_tasks: + persisted = await self._aria2_state_store.get(download_id) aria2_downloader = await get_aria2_downloader() if await aria2_downloader.has_transfer(download_id): result = await aria2_downloader.resume_download(download_id) if not result.get("success"): return result + if download_id not in self._download_tasks and persisted: + resume_context = persisted.get("resume_context") + if isinstance(resume_context, dict): + self._start_background_download_task( + download_id, + self._resume_restored_aria2_download( + download_id, + dict(persisted), + ), + ) + else: + self._start_background_download_task( + download_id, + self._download_with_semaphore( + download_id, + persisted.get("model_id"), + 
persisted.get("model_version_id"), + persisted.get("save_dir"), + persisted.get("relative_path", ""), + None, + bool(persisted.get("use_default_paths", False)), + persisted.get("source"), + persisted.get("file_params"), + ), + ) except Exception as exc: return {"success": False, "error": str(exc)} @@ -1772,6 +2456,7 @@ class DownloadManager: if download_info.get("status") == "paused": download_info["status"] = "downloading" download_info.setdefault("bytes_per_second", 0.0) + await self._persist_aria2_state(download_id) return {"success": True, "message": "Download resumed successfully"} force_reconnect = False @@ -1849,6 +2534,7 @@ class DownloadManager: Returns: Dict: List of active downloads and their status """ + await self._restore_persisted_downloads() return { "downloads": [ { diff --git a/tests/services/test_aria2_downloader.py b/tests/services/test_aria2_downloader.py index 606e1d56..268aa91e 100644 --- a/tests/services/test_aria2_downloader.py +++ b/tests/services/test_aria2_downloader.py @@ -1,11 +1,24 @@ from __future__ import annotations +import asyncio from pathlib import Path from unittest.mock import AsyncMock import pytest from py.services.aria2_downloader import Aria2Downloader, Aria2Error +from py.services.aria2_transfer_state import Aria2TransferStateStore +from py.services import aria2_transfer_state + + +@pytest.fixture(autouse=True) +def isolate_aria2_state(monkeypatch, tmp_path): + state_path = tmp_path / "cache" / "aria2" / "downloads.json" + monkeypatch.setattr( + aria2_transfer_state, + "get_aria2_state_path", + lambda: str(state_path), + ) @pytest.mark.asyncio @@ -79,6 +92,23 @@ async def test_download_file_polls_until_complete(tmp_path, monkeypatch): assert "header" not in rpc_calls[0][1][1] +@pytest.mark.asyncio +async def test_transfer_state_store_shares_lock_and_preserves_concurrent_updates(tmp_path): + state_path = tmp_path / "cache" / "aria2" / "downloads.json" + store_a = Aria2TransferStateStore(str(state_path)) + store_b = 
Aria2TransferStateStore(str(state_path)) + + assert store_a._lock is store_b._lock + + await asyncio.gather( + store_a.upsert("download-1", {"status": "downloading", "gid": "gid-1"}), + store_b.upsert("download-2", {"status": "paused", "gid": "gid-2"}), + ) + + assert await store_a.get("download-1") == {"status": "downloading", "gid": "gid-1"} + assert await store_b.get("download-2") == {"status": "paused", "gid": "gid-2"} + + @pytest.mark.asyncio async def test_download_file_keeps_auth_headers_when_civitai_does_not_redirect( tmp_path, monkeypatch @@ -161,6 +191,61 @@ async def test_pause_resume_cancel_forward_to_rpc(monkeypatch): ] +@pytest.mark.asyncio +async def test_download_file_reuses_existing_transfer_without_add_uri( + tmp_path, monkeypatch +): + downloader = Aria2Downloader() + downloader._rpc_url = "http://127.0.0.1/jsonrpc" + downloader._rpc_secret = "secret" + + save_path = tmp_path / "downloads" / "model.safetensors" + downloader._transfers["download-1"] = type( + "Transfer", (), {"gid": "gid-1", "save_path": str(save_path)} + )() + + rpc_calls = [] + statuses = iter( + [ + { + "gid": "gid-1", + "status": "active", + "completedLength": "5", + "totalLength": "10", + "downloadSpeed": "25", + }, + { + "gid": "gid-1", + "status": "complete", + "completedLength": "10", + "totalLength": "10", + "downloadSpeed": "0", + "files": [{"path": str(save_path)}], + }, + ] + ) + + async def fake_rpc_call(method, params): + rpc_calls.append((method, params)) + if method == "aria2.tellStatus": + return next(statuses) + raise AssertionError(f"Unexpected RPC method: {method}") + + monkeypatch.setattr(downloader, "_ensure_process", AsyncMock()) + monkeypatch.setattr(downloader, "_rpc_call", fake_rpc_call) + monkeypatch.setattr("py.services.aria2_downloader.asyncio.sleep", AsyncMock()) + + success, result = await downloader.download_file( + "https://example.com/model.safetensors", + str(save_path), + download_id="download-1", + ) + + assert success is True + assert result 
== str(save_path) + assert [call[0] for call in rpc_calls] == ["aria2.tellStatus", "aria2.tellStatus"] + + def test_build_progress_snapshot_normalizes_numeric_fields(): downloader = Aria2Downloader() diff --git a/tests/services/test_download_manager_basic.py b/tests/services/test_download_manager_basic.py index 10fae8d7..3117d612 100644 --- a/tests/services/test_download_manager_basic.py +++ b/tests/services/test_download_manager_basic.py @@ -10,6 +10,7 @@ import pytest from py.services.download_manager import DownloadManager from py.services import download_manager +from py.services import aria2_transfer_state from py.services.service_registry import ServiceRegistry from py.services.settings_manager import SettingsManager, get_settings_manager @@ -46,6 +47,16 @@ def isolate_settings(monkeypatch, tmp_path): monkeypatch.setattr(SettingsManager, "_save_settings", lambda self: None) +@pytest.fixture(autouse=True) +def isolate_aria2_state(monkeypatch, tmp_path): + state_path = tmp_path / "cache" / "aria2" / "downloads.json" + monkeypatch.setattr( + aria2_transfer_state, + "get_aria2_state_path", + lambda: str(state_path), + ) + + @pytest.fixture(autouse=True) def stub_metadata(monkeypatch): class _StubMetadata: @@ -439,6 +450,436 @@ async def test_pause_resume_queued_aria2_task_without_transfer(monkeypatch): await task +@pytest.mark.asyncio +async def test_resume_download_restores_persisted_aria2_task(monkeypatch, tmp_path): + manager = DownloadManager() + save_dir = tmp_path / "downloads" + save_dir.mkdir() + save_path = save_dir / "file.safetensors" + save_path.write_text("partial") + (save_dir / "file.safetensors.aria2").write_text("control") + + await manager._aria2_state_store.upsert( + "download-1", + { + "download_id": "download-1", + "transfer_backend": "aria2", + "status": "paused", + "save_dir": str(save_dir), + "relative_path": "", + "use_default_paths": False, + "save_path": str(save_path), + "file_path": str(save_path), + "model_id": 12, + 
"model_version_id": 34, + }, + ) + + created = {} + + async def fake_download_with_semaphore( + self, + task_id, + model_id, + model_version_id, + save_dir, + relative_path, + progress_callback=None, + use_default_paths=False, + source=None, + file_params=None, + ): + created.update( + { + "task_id": task_id, + "model_id": model_id, + "model_version_id": model_version_id, + "save_dir": save_dir, + } + ) + return {"success": True} + + class DummyAria2Downloader: + def __init__(self): + self.calls = [] + + async def get_status_by_gid(self, gid): + return None + + async def has_transfer(self, download_id): + self.calls.append(("has_transfer", download_id)) + return False + + async def resume_download(self, download_id): + self.calls.append(("resume", download_id)) + return {"success": True, "message": "resumed"} + + async def restore_transfer(self, download_id, gid, save_path): + self.calls.append(("restore_transfer", download_id, gid, save_path)) + + dummy_aria2 = DummyAria2Downloader() + monkeypatch.setattr( + download_manager, "_download_with_semaphore", None, raising=False + ) + monkeypatch.setattr( + DownloadManager, + "_download_with_semaphore", + fake_download_with_semaphore, + ) + monkeypatch.setattr( + download_manager, + "get_aria2_downloader", + AsyncMock(return_value=dummy_aria2), + ) + + result = await manager.resume_download("download-1") + await asyncio.sleep(0) + + assert result == {"success": True, "message": "Download resumed successfully"} + assert created["task_id"] == "download-1" + assert created["model_version_id"] == 34 + assert manager._active_downloads["download-1"]["status"] == "downloading" + assert manager._pause_events["download-1"].is_set() is True + + +@pytest.mark.asyncio +async def test_get_active_downloads_restores_persisted_aria2_entries(monkeypatch, tmp_path): + manager = DownloadManager() + save_dir = tmp_path / "downloads" + save_dir.mkdir() + save_path = save_dir / "file.safetensors" + save_path.write_text("partial") + (save_dir 
/ "file.safetensors.aria2").write_text("control") + + await manager._aria2_state_store.upsert( + "download-1", + { + "download_id": "download-1", + "transfer_backend": "aria2", + "status": "paused", + "save_path": str(save_path), + "file_path": str(save_path), + "model_id": 12, + "model_version_id": 34, + }, + ) + + class DummyAria2Downloader: + async def get_status_by_gid(self, gid): + return None + + monkeypatch.setattr( + download_manager, + "get_aria2_downloader", + AsyncMock(return_value=DummyAria2Downloader()), + ) + + downloads = await manager.get_active_downloads() + + assert downloads["downloads"] == [ + { + "download_id": "download-1", + "model_id": 12, + "model_version_id": 34, + "progress": 0, + "status": "paused", + "error": None, + "bytes_downloaded": 0, + "total_bytes": None, + "bytes_per_second": 0.0, + } + ] + + +@pytest.mark.asyncio +async def test_get_active_downloads_restores_orphaned_aria2_partial_as_paused( + monkeypatch, tmp_path +): + manager = DownloadManager() + save_dir = tmp_path / "downloads" + save_dir.mkdir() + save_path = save_dir / "file.safetensors" + save_path.write_text("partial") + (save_dir / "file.safetensors.aria2").write_text("control") + + await manager._aria2_state_store.upsert( + "download-1", + { + "download_id": "download-1", + "transfer_backend": "aria2", + "status": "downloading", + "save_path": str(save_path), + "file_path": str(save_path), + "model_id": 12, + "model_version_id": 34, + "gid": "missing-gid", + }, + ) + + class DummyAria2Downloader: + async def get_status_by_gid(self, gid): + return None + + monkeypatch.setattr( + download_manager, + "get_aria2_downloader", + AsyncMock(return_value=DummyAria2Downloader()), + ) + + downloads = await manager.get_active_downloads() + persisted = await manager._aria2_state_store.get("download-1") + + assert downloads["downloads"] == [ + { + "download_id": "download-1", + "model_id": 12, + "model_version_id": 34, + "progress": 0, + "status": "paused", + "error": None, + 
"bytes_downloaded": 0, + "total_bytes": None, + "bytes_per_second": 0.0, + } + ] + assert manager._pause_events["download-1"].is_paused() is True + assert persisted["status"] == "paused" + + +@pytest.mark.asyncio +async def test_get_active_downloads_restarts_from_resume_context_for_active_restored_aria2( + monkeypatch, tmp_path +): + manager = DownloadManager() + save_dir = tmp_path / "downloads" + save_dir.mkdir() + save_path = save_dir / "file.safetensors" + save_path.write_text("partial") + + await manager._aria2_state_store.upsert( + "download-1", + { + "download_id": "download-1", + "transfer_backend": "aria2", + "status": "downloading", + "save_path": str(save_path), + "file_path": str(save_path), + "model_id": 12, + "model_version_id": 34, + "gid": "gid-1", + "resume_context": { + "version_info": { + "id": 34, + "modelId": 12, + "model": {"id": 12, "type": "LoRA", "tags": ["fantasy"]}, + "images": [], + }, + "file_info": { + "name": "file.safetensors", + "type": "Model", + "primary": True, + "downloadUrl": "https://example.com/file.safetensors", + }, + "model_type": "lora", + "relative_path": "", + "save_dir": str(save_dir), + "download_urls": ["https://example.com/file.safetensors"], + }, + }, + ) + + restarted = {} + + class DummyAria2Downloader: + async def get_status_by_gid(self, gid): + return {"gid": gid, "status": "active"} + + async def restore_transfer(self, download_id, gid, restored_path): + return None + + monkeypatch.setattr( + download_manager, + "get_aria2_downloader", + AsyncMock(return_value=DummyAria2Downloader()), + ) + + async def fake_resume_restored_aria2_download(self, download_id, record): + restarted.update( + { + "download_id": download_id, + "model_id": record.get("model_id"), + "model_version_id": record.get("model_version_id"), + "save_dir": record.get("save_dir"), + "resume_context": record.get("resume_context"), + } + ) + return {"success": True} + + monkeypatch.setattr( + DownloadManager, + "_resume_restored_aria2_download", + 
fake_resume_restored_aria2_download, + ) + execute_original = AsyncMock(side_effect=AssertionError("should not refetch metadata")) + monkeypatch.setattr( + DownloadManager, + "_execute_original_download", + execute_original, + ) + + downloads = await manager.get_active_downloads() + assert downloads["downloads"][0]["status"] == "downloading" + restarted_task = manager._download_tasks["download-1"] + await restarted_task + + assert restarted["download_id"] == "download-1" + assert restarted["model_id"] == 12 + assert restarted["model_version_id"] == 34 + assert restarted["save_dir"] is None + assert restarted["resume_context"]["model_type"] == "lora" + assert execute_original.await_count == 0 + + +@pytest.mark.asyncio +async def test_get_active_downloads_restores_persisted_aria2_without_initial_save_path( + monkeypatch, tmp_path +): + manager = DownloadManager() + save_dir = tmp_path / "downloads" + save_dir.mkdir() + save_path = save_dir / "file.safetensors" + save_path.write_text("partial") + (save_dir / "file.safetensors.aria2").write_text("control") + + await manager._aria2_state_store.upsert( + "download-1", + { + "download_id": "download-1", + "transfer_backend": "aria2", + "status": "paused", + "model_id": 12, + "model_version_id": 34, + "resume_context": { + "version_info": { + "id": 34, + "modelId": 12, + "model": {"id": 12, "type": "LoRA"}, + "images": [], + }, + "file_info": { + "name": "file.safetensors", + "type": "Model", + "primary": True, + "downloadUrl": "https://example.com/file.safetensors", + }, + "model_type": "lora", + "relative_path": "", + "save_dir": str(save_dir), + "download_urls": ["https://example.com/file.safetensors"], + }, + }, + ) + + class DummyAria2Downloader: + async def get_status_by_gid(self, gid): + return None + + monkeypatch.setattr( + download_manager, + "get_aria2_downloader", + AsyncMock(return_value=DummyAria2Downloader()), + ) + + downloads = await manager.get_active_downloads() + persisted = await 
manager._aria2_state_store.get("download-1") + + assert downloads["downloads"] == [ + { + "download_id": "download-1", + "model_id": 12, + "model_version_id": 34, + "progress": 0, + "status": "paused", + "error": None, + "bytes_downloaded": 0, + "total_bytes": None, + "bytes_per_second": 0.0, + } + ] + assert manager._active_downloads["download-1"]["file_path"] == str(save_path) + assert persisted["save_path"] == str(save_path) + assert persisted["file_path"] == str(save_path) + + +@pytest.mark.asyncio +async def test_resume_restored_aria2_download_updates_terminal_status_and_cleanup(monkeypatch): + manager = DownloadManager() + manager._active_downloads["download-1"] = { + "transfer_backend": "aria2", + "status": "paused", + "model_id": 12, + "model_version_id": 34, + "bytes_per_second": 10.0, + } + + persist_state = AsyncMock() + cleanup_record = AsyncMock(return_value=None) + execute_download = AsyncMock(return_value={"success": True}) + record_history = AsyncMock(return_value=None) + sync_version = AsyncMock(return_value=None) + + monkeypatch.setattr(manager, "_persist_aria2_state", persist_state) + monkeypatch.setattr(manager, "_cleanup_download_record", cleanup_record) + monkeypatch.setattr(manager, "_execute_download", execute_download) + monkeypatch.setattr(manager, "_record_downloaded_version_history", record_history) + monkeypatch.setattr(manager, "_sync_downloaded_version", sync_version) + + scheduled_tasks = [] + original_create_task = asyncio.create_task + + def tracking_create_task(coro): + task = original_create_task(coro) + scheduled_tasks.append(task) + return task + + monkeypatch.setattr(download_manager.asyncio, "create_task", tracking_create_task) + + result = await manager._resume_restored_aria2_download( + "download-1", + { + "download_id": "download-1", + "save_path": "/tmp/file.safetensors", + "file_path": "/tmp/file.safetensors", + "model_id": 12, + "model_version_id": 34, + "resume_context": { + "version_info": { + "id": 34, + "modelId": 
12, + "model": {"id": 12}, + "images": [], + }, + "file_info": { + "name": "file.safetensors", + "downloadUrl": "https://example.com/file.safetensors", + }, + "model_type": "lora", + "relative_path": "", + "save_dir": "/tmp", + "download_urls": ["https://example.com/file.safetensors"], + }, + }, + ) + + assert result == {"success": True} + assert manager._active_downloads["download-1"]["status"] == "completed" + assert manager._active_downloads["download-1"]["bytes_per_second"] == 0.0 + assert persist_state.await_count == 2 + assert len(scheduled_tasks) == 1 + await asyncio.gather(*scheduled_tasks) + cleanup_record.assert_awaited_once_with("download-1") + + @pytest.mark.asyncio async def test_download_uses_captured_backend_when_settings_change( monkeypatch, scanners, metadata_provider, tmp_path diff --git a/tests/services/test_download_manager_error.py b/tests/services/test_download_manager_error.py index 8f75e6ea..c462327d 100644 --- a/tests/services/test_download_manager_error.py +++ b/tests/services/test_download_manager_error.py @@ -14,6 +14,7 @@ import pytest from py.services.download_manager import DownloadManager from py.services.downloader import DownloadStreamControl from py.services import download_manager +from py.services import aria2_transfer_state from py.services.service_registry import ServiceRegistry from py.services.settings_manager import SettingsManager, get_settings_manager from py.utils.metadata_manager import MetadataManager @@ -49,6 +50,16 @@ def isolate_settings(monkeypatch, tmp_path): monkeypatch.setattr(SettingsManager, "_save_settings", lambda self: None) +@pytest.fixture(autouse=True) +def isolate_aria2_state(monkeypatch, tmp_path): + state_path = tmp_path / "cache" / "aria2" / "downloads.json" + monkeypatch.setattr( + aria2_transfer_state, + "get_aria2_state_path", + lambda: str(state_path), + ) + + @pytest.mark.asyncio async def test_execute_download_retries_urls(monkeypatch, tmp_path): """Test that download retries multiple URLs on 
failure.""" @@ -800,6 +811,89 @@ async def test_resume_download_returns_error_when_aria2_probe_raises(monkeypatch assert manager._active_downloads[download_id]["status"] == "paused" +@pytest.mark.asyncio +async def test_resume_download_does_not_spawn_restored_worker_when_aria2_resume_fails( + monkeypatch, tmp_path +): + manager = DownloadManager() + + download_id = "dl" + save_path = tmp_path / "file.safetensors" + pause_control = DownloadStreamControl() + pause_control.pause() + manager._pause_events[download_id] = pause_control + manager._active_downloads[download_id] = { + "transfer_backend": "aria2", + "status": "paused", + "bytes_per_second": 0.0, + } + + await manager._aria2_state_store.upsert( + download_id, + { + "download_id": download_id, + "transfer_backend": "aria2", + "status": "paused", + "save_path": str(save_path), + "file_path": str(save_path), + "model_id": 12, + "model_version_id": 34, + "resume_context": { + "version_info": {"id": 34, "modelId": 12, "model": {"id": 12}}, + "file_info": { + "name": "file.safetensors", + "downloadUrl": "https://example.com/file.safetensors", + }, + "model_type": "lora", + "relative_path": "", + "save_dir": str(tmp_path), + "download_urls": ["https://example.com/file.safetensors"], + }, + }, + ) + + resume_restored = AsyncMock(return_value={"success": True}) + monkeypatch.setattr(manager, "_resume_restored_aria2_download", resume_restored) + + class DummyAria2Downloader: + async def has_transfer(self, _download_id): + return True + + async def resume_download(self, _download_id): + return {"success": False, "error": "rpc unavailable"} + + monkeypatch.setattr( + download_manager, + "get_aria2_downloader", + AsyncMock(return_value=DummyAria2Downloader()), + ) + + result = await manager.resume_download(download_id) + + assert result == {"success": False, "error": "rpc unavailable"} + assert download_id not in manager._download_tasks + assert resume_restored.await_count == 0 + assert pause_control.is_paused() is True + 
assert manager._active_downloads[download_id]["status"] == "paused" + + +@pytest.mark.asyncio +async def test_start_background_download_task_cleans_up_finished_restore_task(): + manager = DownloadManager() + download_id = "download-1" + manager._pause_events[download_id] = DownloadStreamControl() + + async def finished_restore(): + return {"success": True} + + task = manager._start_background_download_task(download_id, finished_restore()) + await task + await asyncio.sleep(0) + + assert download_id not in manager._download_tasks + assert download_id not in manager._pause_events + + @pytest.mark.asyncio async def test_cancel_download_still_cancels_local_task_when_aria2_raises(monkeypatch): manager = DownloadManager() @@ -836,6 +930,217 @@ async def test_cancel_download_still_cancels_local_task_when_aria2_raises(monkey assert task.cancelled() or task.done() +@pytest.mark.asyncio +async def test_cancel_download_preserves_tracking_when_aria2_returns_error(monkeypatch, tmp_path): + manager = DownloadManager() + download_id = "download-queued" + save_path = tmp_path / "file.safetensors" + save_path.write_text("partial") + (tmp_path / "file.safetensors.aria2").write_text("control") + + pause_control = DownloadStreamControl() + manager._pause_events[download_id] = pause_control + manager._download_tasks[download_id] = object() + manager._active_downloads[download_id] = { + "transfer_backend": "aria2", + "status": "downloading", + "file_path": str(save_path), + } + + await manager._aria2_state_store.upsert( + download_id, + { + "download_id": download_id, + "transfer_backend": "aria2", + "status": "downloading", + "save_path": str(save_path), + "file_path": str(save_path), + }, + ) + + cleanup_files = AsyncMock(return_value=None) + monkeypatch.setattr(manager, "_cleanup_cancelled_download_files", cleanup_files) + + class DummyAria2Downloader: + async def cancel_download(self, _download_id): + return {"success": False, "error": "rpc unavailable"} + + monkeypatch.setattr( + 
@pytest.mark.asyncio
async def test_cancel_download_rejects_completed_history_entry(tmp_path):
    """Cancelling an already-completed download is refused and deletes nothing."""
    mgr = DownloadManager()
    download_id = "completed-download"

    model_file = tmp_path / "file.safetensors"
    metadata_file = tmp_path / "file.metadata.json"
    preview_file = tmp_path / "file.jpeg"
    model_file.write_text("complete")
    metadata_file.write_text("{}")
    preview_file.write_text("preview")

    mgr._active_downloads[download_id] = {
        "transfer_backend": "aria2",
        "status": "completed",
        "file_path": str(model_file),
        "preview_path": str(preview_file),
    }

    outcome = await mgr.cancel_download(download_id)

    assert outcome == {"success": False, "error": "Download task not found"}
    assert model_file.exists()
    assert metadata_file.exists()
    assert preview_file.exists()


@pytest.mark.asyncio
async def test_cancel_download_removes_preview_and_aria2_control_files(monkeypatch, tmp_path):
    """A successful cancel removes the partial file, the .aria2 control file and the preview."""
    mgr = DownloadManager()

    running = asyncio.Event()

    async def never_finishes():
        running.set()
        await asyncio.sleep(60)

    worker = asyncio.create_task(never_finishes())
    await running.wait()

    model_file = tmp_path / "file.safetensors"
    control_file = tmp_path / "file.safetensors.aria2"
    preview_file = tmp_path / "file.jpeg"
    model_file.write_text("partial")
    control_file.write_text("control")
    preview_file.write_text("preview")

    download_id = "download-queued"
    mgr._download_tasks[download_id] = worker
    mgr._active_downloads[download_id] = {
        "transfer_backend": "aria2",
        "status": "queued",
        "file_path": str(model_file),
        "aria2_control_path": str(control_file),
        "preview_path": str(preview_file),
    }

    class CancelOkAria2:
        async def cancel_download(self, _download_id):
            return {"success": True, "message": "cancelled"}

    monkeypatch.setattr(
        download_manager,
        "get_aria2_downloader",
        AsyncMock(return_value=CancelOkAria2()),
    )

    outcome = await mgr.cancel_download(download_id)

    assert outcome["success"] is True
    assert not model_file.exists()
    assert not control_file.exists()
    assert not preview_file.exists()


@pytest.mark.asyncio
async def test_cancel_download_does_not_delete_untracked_same_basename_preview(
    monkeypatch, tmp_path
):
    """Only tracked artefacts are removed; an untracked same-basename file survives."""
    mgr = DownloadManager()

    running = asyncio.Event()

    async def never_finishes():
        running.set()
        await asyncio.sleep(60)

    worker = asyncio.create_task(never_finishes())
    await running.wait()

    model_file = tmp_path / "file.safetensors"
    control_file = tmp_path / "file.safetensors.aria2"
    model_file.write_text("partial")
    control_file.write_text("control")
    # A user-placed preview that shares the basename but is NOT tracked.
    manual_preview = tmp_path / "file.jpg"
    manual_preview.write_text("manual")

    download_id = "download-queued"
    mgr._download_tasks[download_id] = worker
    mgr._active_downloads[download_id] = {
        "transfer_backend": "aria2",
        "status": "queued",
        "file_path": str(model_file),
        "aria2_control_path": str(control_file),
    }

    class CancelOkAria2:
        async def cancel_download(self, _download_id):
            return {"success": True, "message": "cancelled"}

    monkeypatch.setattr(
        download_manager,
        "get_aria2_downloader",
        AsyncMock(return_value=CancelOkAria2()),
    )

    outcome = await mgr.cancel_download(download_id)

    assert outcome["success"] is True
    assert not model_file.exists()
    assert not control_file.exists()
    assert manual_preview.exists()


@pytest.mark.asyncio
async def test_cleanup_cancelled_download_files_retries_aria2_control_deletion(
    monkeypatch, tmp_path
):
    """A transient unlink failure on the .aria2 control file is retried until it succeeds."""
    mgr = DownloadManager()
    download_id = "download-1"

    model_file = tmp_path / "file.safetensors"
    control_file = tmp_path / "file.safetensors.aria2"
    model_file.write_text("partial")
    control_file.write_text("control")

    real_unlink = os.unlink
    failures = {"count": 0}

    def unlink_once_locked(path):
        # Fail exactly once for the control file, then behave normally.
        if path == str(control_file) and failures["count"] == 0:
            failures["count"] += 1
            raise PermissionError("still locked")
        return real_unlink(path)

    monkeypatch.setattr(download_manager.os, "unlink", unlink_once_locked)
    # Stub the backoff sleep so the retry loop stays fast.
    monkeypatch.setattr("py.services.download_manager.asyncio.sleep", AsyncMock())

    await mgr._cleanup_cancelled_download_files(
        download_id,
        {
            "file_path": str(model_file),
            "aria2_control_path": str(control_file),
            "transfer_backend": "aria2",
        },
    )

    assert failures["count"] == 1
    assert not model_file.exists()
    assert not control_file.exists()
@pytest.mark.asyncio
async def test_execute_download_reuses_existing_aria2_partial_path(monkeypatch, tmp_path):
    """An aria2 partial owned by the same download id is reused, never renamed."""
    mgr = DownloadManager()

    downloads_dir = tmp_path / "downloads"
    downloads_dir.mkdir()
    partial_file = downloads_dir / "file.safetensors"
    partial_file.write_text("partial")
    control_file = downloads_dir / "file.safetensors.aria2"
    control_file.write_text("control")

    await mgr._aria2_state_store.upsert(
        "download-1",
        {
            "download_id": "download-1",
            "transfer_backend": "aria2",
            "save_path": str(partial_file),
            "file_path": str(partial_file),
            "status": "paused",
        },
    )

    class StubMetadata:
        def __init__(self, path: Path):
            self.file_path = str(path)
            self.sha256 = "sha256"
            self.file_name = path.stem
            self.preview_url = None

        def generate_unique_filename(self, *_args, **_kwargs):
            return "renamed.safetensors"

        def update_file_info(self, _path):
            return None

        def to_dict(self):
            return {"file_path": self.file_path}

    mgr._active_downloads["download-1"] = {"transfer_backend": "aria2"}
    scanner_stub = SimpleNamespace(add_model_to_cache=AsyncMock(return_value=None))
    monkeypatch.setattr(
        DownloadManager, "_get_lora_scanner", AsyncMock(return_value=scanner_stub)
    )
    monkeypatch.setattr(MetadataManager, "save_metadata", AsyncMock(return_value=True))

    # Parameter names must match the production keyword call site exactly.
    async def stub_download_model_file(
        self,
        download_url,
        save_path,
        *,
        backend,
        progress_callback,
        use_auth,
        download_id,
        pause_control,
    ):
        Path(save_path).write_text("content")
        return True, save_path

    monkeypatch.setattr(DownloadManager, "_download_model_file", stub_download_model_file)

    outcome = await mgr._execute_download(
        download_urls=["https://example.com/file.safetensors"],
        save_dir=str(downloads_dir),
        metadata=StubMetadata(partial_file),
        version_info={"images": []},
        relative_path="",
        progress_callback=None,
        model_type="lora",
        download_id="download-1",
        transfer_backend="aria2",
    )

    assert outcome == {"success": True}
    assert mgr._active_downloads["download-1"]["file_path"] == str(partial_file)
    assert not (downloads_dir / "renamed.safetensors").exists()


@pytest.mark.asyncio
async def test_execute_download_rejects_conflicting_aria2_partial_path(tmp_path):
    """A partial owned by a different download id is rejected without renaming."""
    mgr = DownloadManager()

    downloads_dir = tmp_path / "downloads"
    downloads_dir.mkdir()
    partial_file = downloads_dir / "file.safetensors"
    partial_file.write_text("partial")
    (downloads_dir / "file.safetensors.aria2").write_text("control")

    await mgr._aria2_state_store.upsert(
        "other-download",
        {
            "download_id": "other-download",
            "transfer_backend": "aria2",
            "save_path": str(partial_file),
            "file_path": str(partial_file),
            "status": "paused",
        },
    )

    class StubMetadata:
        def __init__(self, path: Path):
            self.file_path = str(path)
            self.sha256 = "sha256"
            self.file_name = path.stem
            self.preview_url = None

        def generate_unique_filename(self, *_args, **_kwargs):
            raise AssertionError("should not rename")

    outcome = await mgr._execute_download(
        download_urls=["https://example.com/file.safetensors"],
        save_dir=str(downloads_dir),
        metadata=StubMetadata(partial_file),
        version_info={"images": []},
        relative_path="",
        progress_callback=None,
        model_type="lora",
        download_id="download-1",
        transfer_backend="aria2",
    )

    assert outcome["success"] is False
    assert "already using" in outcome["error"]
@pytest.mark.asyncio
async def test_execute_download_reassigns_same_aria2_partial_to_new_download_id(
    monkeypatch, tmp_path
):
    """The same model/version request adopts the old partial under the new download id."""
    mgr = DownloadManager()

    downloads_dir = tmp_path / "downloads"
    downloads_dir.mkdir()
    partial_file = downloads_dir / "file.safetensors"
    partial_file.write_text("partial")
    (downloads_dir / "file.safetensors.aria2").write_text("control")

    await mgr._aria2_state_store.upsert(
        "old-download",
        {
            "download_id": "old-download",
            "transfer_backend": "aria2",
            "save_path": str(partial_file),
            "file_path": str(partial_file),
            "status": "paused",
            "model_id": 11,
            "model_version_id": 22,
        },
    )

    class StubMetadata:
        def __init__(self, path: Path):
            self.file_path = str(path)
            self.sha256 = "sha256"
            self.file_name = path.stem
            self.preview_url = None

        def generate_unique_filename(self, *_args, **_kwargs):
            raise AssertionError("should not rename")

        def update_file_info(self, _path):
            return None

        def to_dict(self):
            return {"file_path": self.file_path}

    class RecordingAria2:
        def __init__(self):
            self.calls = []

        async def reassign_transfer(self, previous_download_id, new_download_id):
            self.calls.append(("reassign_transfer", previous_download_id, new_download_id))
            return None

    aria2_stub = RecordingAria2()
    monkeypatch.setattr(
        download_manager,
        "get_aria2_downloader",
        AsyncMock(return_value=aria2_stub),
    )

    mgr._active_downloads["old-download"] = {
        "transfer_backend": "aria2",
        "model_id": 11,
        "model_version_id": 22,
        "status": "paused",
    }
    mgr._active_downloads["new-download"] = {
        "transfer_backend": "aria2",
        "model_id": 11,
        "model_version_id": 22,
        "status": "queued",
    }

    resolved, resolved_path = await mgr._resolve_download_target_path(
        str(downloads_dir),
        StubMetadata(partial_file),
        transfer_backend="aria2",
        download_id="new-download",
    )

    assert resolved is True
    assert resolved_path == str(partial_file)
    assert "old-download" not in mgr._active_downloads
    assert mgr._active_downloads["new-download"]["file_path"] == str(partial_file)
    assert aria2_stub.calls == [("reassign_transfer", "old-download", "new-download")]
    assert await mgr._aria2_state_store.get("old-download") is None
    assert (await mgr._aria2_state_store.get("new-download"))["save_path"] == str(
        partial_file
    )


def test_is_same_aria2_download_request_requires_version_id_match():
    """A missing model_version_id on either side must not match."""
    mgr = DownloadManager()

    left_missing = mgr._is_same_aria2_download_request(
        {"model_id": 1, "model_version_id": None},
        {"model_id": 1, "model_version_id": 2},
    )
    right_missing = mgr._is_same_aria2_download_request(
        {"model_id": 1, "model_version_id": 3},
        {"model_id": 1, "model_version_id": None},
    )

    assert left_missing is False
    assert right_missing is False


@pytest.mark.asyncio
async def test_adopt_existing_aria2_download_cancels_old_running_task(monkeypatch, tmp_path):
    """Adoption reassigns the transfer first, then cancels the old worker task."""
    mgr = DownloadManager()
    target = tmp_path / "file.safetensors"

    running = asyncio.Event()
    was_cancelled = asyncio.Event()
    call_order = []

    async def long_running_old_download():
        running.set()
        try:
            await asyncio.sleep(60)
        except asyncio.CancelledError:
            call_order.append("old-task-cancelled")
            was_cancelled.set()
            raise

    old_worker = asyncio.create_task(long_running_old_download())
    await running.wait()

    mgr._download_tasks["old-download"] = old_worker
    old_control = DownloadStreamControl()
    old_control.pause()
    mgr._pause_events["old-download"] = old_control
    mgr._active_downloads["old-download"] = {
        "transfer_backend": "aria2",
        "model_id": 11,
        "model_version_id": 22,
        "status": "downloading",
    }
    mgr._active_downloads["new-download"] = {
        "transfer_backend": "aria2",
        "model_id": 11,
        "model_version_id": 22,
        "status": "queued",
    }

    await mgr._aria2_state_store.upsert(
        "old-download",
        {
            "download_id": "old-download",
            "transfer_backend": "aria2",
            "save_path": str(target),
            "file_path": str(target),
            "status": "downloading",
            "model_id": 11,
            "model_version_id": 22,
        },
    )

    class RecordingAria2:
        async def reassign_transfer(self, previous_download_id, new_download_id):
            call_order.append("reassign-transfer")
            return None

    monkeypatch.setattr(
        download_manager,
        "get_aria2_downloader",
        AsyncMock(return_value=RecordingAria2()),
    )

    await mgr._adopt_existing_aria2_download(
        "old-download",
        "new-download",
        {"model_id": 11, "model_version_id": 22},
        str(target),
    )

    assert was_cancelled.is_set() is True
    assert "old-download" not in mgr._download_tasks
    assert call_order == ["reassign-transfer", "old-task-cancelled"]