mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-05-07 00:46:44 -03:00
fix(download): restore aria2 resume lifecycle
This commit is contained in:
@@ -10,6 +10,7 @@ import pytest
|
||||
|
||||
from py.services.download_manager import DownloadManager
|
||||
from py.services import download_manager
|
||||
from py.services import aria2_transfer_state
|
||||
from py.services.service_registry import ServiceRegistry
|
||||
from py.services.settings_manager import SettingsManager, get_settings_manager
|
||||
|
||||
@@ -46,6 +47,16 @@ def isolate_settings(monkeypatch, tmp_path):
|
||||
monkeypatch.setattr(SettingsManager, "_save_settings", lambda self: None)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def isolate_aria2_state(monkeypatch, tmp_path):
|
||||
state_path = tmp_path / "cache" / "aria2" / "downloads.json"
|
||||
monkeypatch.setattr(
|
||||
aria2_transfer_state,
|
||||
"get_aria2_state_path",
|
||||
lambda: str(state_path),
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def stub_metadata(monkeypatch):
|
||||
class _StubMetadata:
|
||||
@@ -439,6 +450,436 @@ async def test_pause_resume_queued_aria2_task_without_transfer(monkeypatch):
|
||||
await task
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resume_download_restores_persisted_aria2_task(monkeypatch, tmp_path):
|
||||
manager = DownloadManager()
|
||||
save_dir = tmp_path / "downloads"
|
||||
save_dir.mkdir()
|
||||
save_path = save_dir / "file.safetensors"
|
||||
save_path.write_text("partial")
|
||||
(save_dir / "file.safetensors.aria2").write_text("control")
|
||||
|
||||
await manager._aria2_state_store.upsert(
|
||||
"download-1",
|
||||
{
|
||||
"download_id": "download-1",
|
||||
"transfer_backend": "aria2",
|
||||
"status": "paused",
|
||||
"save_dir": str(save_dir),
|
||||
"relative_path": "",
|
||||
"use_default_paths": False,
|
||||
"save_path": str(save_path),
|
||||
"file_path": str(save_path),
|
||||
"model_id": 12,
|
||||
"model_version_id": 34,
|
||||
},
|
||||
)
|
||||
|
||||
created = {}
|
||||
|
||||
async def fake_download_with_semaphore(
|
||||
self,
|
||||
task_id,
|
||||
model_id,
|
||||
model_version_id,
|
||||
save_dir,
|
||||
relative_path,
|
||||
progress_callback=None,
|
||||
use_default_paths=False,
|
||||
source=None,
|
||||
file_params=None,
|
||||
):
|
||||
created.update(
|
||||
{
|
||||
"task_id": task_id,
|
||||
"model_id": model_id,
|
||||
"model_version_id": model_version_id,
|
||||
"save_dir": save_dir,
|
||||
}
|
||||
)
|
||||
return {"success": True}
|
||||
|
||||
class DummyAria2Downloader:
|
||||
def __init__(self):
|
||||
self.calls = []
|
||||
|
||||
async def get_status_by_gid(self, gid):
|
||||
return None
|
||||
|
||||
async def has_transfer(self, download_id):
|
||||
self.calls.append(("has_transfer", download_id))
|
||||
return False
|
||||
|
||||
async def resume_download(self, download_id):
|
||||
self.calls.append(("resume", download_id))
|
||||
return {"success": True, "message": "resumed"}
|
||||
|
||||
async def restore_transfer(self, download_id, gid, save_path):
|
||||
self.calls.append(("restore_transfer", download_id, gid, save_path))
|
||||
|
||||
dummy_aria2 = DummyAria2Downloader()
|
||||
monkeypatch.setattr(
|
||||
download_manager, "_download_with_semaphore", None, raising=False
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
DownloadManager,
|
||||
"_download_with_semaphore",
|
||||
fake_download_with_semaphore,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
download_manager,
|
||||
"get_aria2_downloader",
|
||||
AsyncMock(return_value=dummy_aria2),
|
||||
)
|
||||
|
||||
result = await manager.resume_download("download-1")
|
||||
await asyncio.sleep(0)
|
||||
|
||||
assert result == {"success": True, "message": "Download resumed successfully"}
|
||||
assert created["task_id"] == "download-1"
|
||||
assert created["model_version_id"] == 34
|
||||
assert manager._active_downloads["download-1"]["status"] == "downloading"
|
||||
assert manager._pause_events["download-1"].is_set() is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_active_downloads_restores_persisted_aria2_entries(monkeypatch, tmp_path):
|
||||
manager = DownloadManager()
|
||||
save_dir = tmp_path / "downloads"
|
||||
save_dir.mkdir()
|
||||
save_path = save_dir / "file.safetensors"
|
||||
save_path.write_text("partial")
|
||||
(save_dir / "file.safetensors.aria2").write_text("control")
|
||||
|
||||
await manager._aria2_state_store.upsert(
|
||||
"download-1",
|
||||
{
|
||||
"download_id": "download-1",
|
||||
"transfer_backend": "aria2",
|
||||
"status": "paused",
|
||||
"save_path": str(save_path),
|
||||
"file_path": str(save_path),
|
||||
"model_id": 12,
|
||||
"model_version_id": 34,
|
||||
},
|
||||
)
|
||||
|
||||
class DummyAria2Downloader:
|
||||
async def get_status_by_gid(self, gid):
|
||||
return None
|
||||
|
||||
monkeypatch.setattr(
|
||||
download_manager,
|
||||
"get_aria2_downloader",
|
||||
AsyncMock(return_value=DummyAria2Downloader()),
|
||||
)
|
||||
|
||||
downloads = await manager.get_active_downloads()
|
||||
|
||||
assert downloads["downloads"] == [
|
||||
{
|
||||
"download_id": "download-1",
|
||||
"model_id": 12,
|
||||
"model_version_id": 34,
|
||||
"progress": 0,
|
||||
"status": "paused",
|
||||
"error": None,
|
||||
"bytes_downloaded": 0,
|
||||
"total_bytes": None,
|
||||
"bytes_per_second": 0.0,
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_active_downloads_restores_orphaned_aria2_partial_as_paused(
|
||||
monkeypatch, tmp_path
|
||||
):
|
||||
manager = DownloadManager()
|
||||
save_dir = tmp_path / "downloads"
|
||||
save_dir.mkdir()
|
||||
save_path = save_dir / "file.safetensors"
|
||||
save_path.write_text("partial")
|
||||
(save_dir / "file.safetensors.aria2").write_text("control")
|
||||
|
||||
await manager._aria2_state_store.upsert(
|
||||
"download-1",
|
||||
{
|
||||
"download_id": "download-1",
|
||||
"transfer_backend": "aria2",
|
||||
"status": "downloading",
|
||||
"save_path": str(save_path),
|
||||
"file_path": str(save_path),
|
||||
"model_id": 12,
|
||||
"model_version_id": 34,
|
||||
"gid": "missing-gid",
|
||||
},
|
||||
)
|
||||
|
||||
class DummyAria2Downloader:
|
||||
async def get_status_by_gid(self, gid):
|
||||
return None
|
||||
|
||||
monkeypatch.setattr(
|
||||
download_manager,
|
||||
"get_aria2_downloader",
|
||||
AsyncMock(return_value=DummyAria2Downloader()),
|
||||
)
|
||||
|
||||
downloads = await manager.get_active_downloads()
|
||||
persisted = await manager._aria2_state_store.get("download-1")
|
||||
|
||||
assert downloads["downloads"] == [
|
||||
{
|
||||
"download_id": "download-1",
|
||||
"model_id": 12,
|
||||
"model_version_id": 34,
|
||||
"progress": 0,
|
||||
"status": "paused",
|
||||
"error": None,
|
||||
"bytes_downloaded": 0,
|
||||
"total_bytes": None,
|
||||
"bytes_per_second": 0.0,
|
||||
}
|
||||
]
|
||||
assert manager._pause_events["download-1"].is_paused() is True
|
||||
assert persisted["status"] == "paused"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_active_downloads_restarts_from_resume_context_for_active_restored_aria2(
|
||||
monkeypatch, tmp_path
|
||||
):
|
||||
manager = DownloadManager()
|
||||
save_dir = tmp_path / "downloads"
|
||||
save_dir.mkdir()
|
||||
save_path = save_dir / "file.safetensors"
|
||||
save_path.write_text("partial")
|
||||
|
||||
await manager._aria2_state_store.upsert(
|
||||
"download-1",
|
||||
{
|
||||
"download_id": "download-1",
|
||||
"transfer_backend": "aria2",
|
||||
"status": "downloading",
|
||||
"save_path": str(save_path),
|
||||
"file_path": str(save_path),
|
||||
"model_id": 12,
|
||||
"model_version_id": 34,
|
||||
"gid": "gid-1",
|
||||
"resume_context": {
|
||||
"version_info": {
|
||||
"id": 34,
|
||||
"modelId": 12,
|
||||
"model": {"id": 12, "type": "LoRA", "tags": ["fantasy"]},
|
||||
"images": [],
|
||||
},
|
||||
"file_info": {
|
||||
"name": "file.safetensors",
|
||||
"type": "Model",
|
||||
"primary": True,
|
||||
"downloadUrl": "https://example.com/file.safetensors",
|
||||
},
|
||||
"model_type": "lora",
|
||||
"relative_path": "",
|
||||
"save_dir": str(save_dir),
|
||||
"download_urls": ["https://example.com/file.safetensors"],
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
restarted = {}
|
||||
|
||||
class DummyAria2Downloader:
|
||||
async def get_status_by_gid(self, gid):
|
||||
return {"gid": gid, "status": "active"}
|
||||
|
||||
async def restore_transfer(self, download_id, gid, restored_path):
|
||||
return None
|
||||
|
||||
monkeypatch.setattr(
|
||||
download_manager,
|
||||
"get_aria2_downloader",
|
||||
AsyncMock(return_value=DummyAria2Downloader()),
|
||||
)
|
||||
|
||||
async def fake_resume_restored_aria2_download(self, download_id, record):
|
||||
restarted.update(
|
||||
{
|
||||
"download_id": download_id,
|
||||
"model_id": record.get("model_id"),
|
||||
"model_version_id": record.get("model_version_id"),
|
||||
"save_dir": record.get("save_dir"),
|
||||
"resume_context": record.get("resume_context"),
|
||||
}
|
||||
)
|
||||
return {"success": True}
|
||||
|
||||
monkeypatch.setattr(
|
||||
DownloadManager,
|
||||
"_resume_restored_aria2_download",
|
||||
fake_resume_restored_aria2_download,
|
||||
)
|
||||
execute_original = AsyncMock(side_effect=AssertionError("should not refetch metadata"))
|
||||
monkeypatch.setattr(
|
||||
DownloadManager,
|
||||
"_execute_original_download",
|
||||
execute_original,
|
||||
)
|
||||
|
||||
downloads = await manager.get_active_downloads()
|
||||
assert downloads["downloads"][0]["status"] == "downloading"
|
||||
restarted_task = manager._download_tasks["download-1"]
|
||||
await restarted_task
|
||||
|
||||
assert restarted["download_id"] == "download-1"
|
||||
assert restarted["model_id"] == 12
|
||||
assert restarted["model_version_id"] == 34
|
||||
assert restarted["save_dir"] is None
|
||||
assert restarted["resume_context"]["model_type"] == "lora"
|
||||
assert execute_original.await_count == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_active_downloads_restores_persisted_aria2_without_initial_save_path(
|
||||
monkeypatch, tmp_path
|
||||
):
|
||||
manager = DownloadManager()
|
||||
save_dir = tmp_path / "downloads"
|
||||
save_dir.mkdir()
|
||||
save_path = save_dir / "file.safetensors"
|
||||
save_path.write_text("partial")
|
||||
(save_dir / "file.safetensors.aria2").write_text("control")
|
||||
|
||||
await manager._aria2_state_store.upsert(
|
||||
"download-1",
|
||||
{
|
||||
"download_id": "download-1",
|
||||
"transfer_backend": "aria2",
|
||||
"status": "paused",
|
||||
"model_id": 12,
|
||||
"model_version_id": 34,
|
||||
"resume_context": {
|
||||
"version_info": {
|
||||
"id": 34,
|
||||
"modelId": 12,
|
||||
"model": {"id": 12, "type": "LoRA"},
|
||||
"images": [],
|
||||
},
|
||||
"file_info": {
|
||||
"name": "file.safetensors",
|
||||
"type": "Model",
|
||||
"primary": True,
|
||||
"downloadUrl": "https://example.com/file.safetensors",
|
||||
},
|
||||
"model_type": "lora",
|
||||
"relative_path": "",
|
||||
"save_dir": str(save_dir),
|
||||
"download_urls": ["https://example.com/file.safetensors"],
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
class DummyAria2Downloader:
|
||||
async def get_status_by_gid(self, gid):
|
||||
return None
|
||||
|
||||
monkeypatch.setattr(
|
||||
download_manager,
|
||||
"get_aria2_downloader",
|
||||
AsyncMock(return_value=DummyAria2Downloader()),
|
||||
)
|
||||
|
||||
downloads = await manager.get_active_downloads()
|
||||
persisted = await manager._aria2_state_store.get("download-1")
|
||||
|
||||
assert downloads["downloads"] == [
|
||||
{
|
||||
"download_id": "download-1",
|
||||
"model_id": 12,
|
||||
"model_version_id": 34,
|
||||
"progress": 0,
|
||||
"status": "paused",
|
||||
"error": None,
|
||||
"bytes_downloaded": 0,
|
||||
"total_bytes": None,
|
||||
"bytes_per_second": 0.0,
|
||||
}
|
||||
]
|
||||
assert manager._active_downloads["download-1"]["file_path"] == str(save_path)
|
||||
assert persisted["save_path"] == str(save_path)
|
||||
assert persisted["file_path"] == str(save_path)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resume_restored_aria2_download_updates_terminal_status_and_cleanup(monkeypatch):
|
||||
manager = DownloadManager()
|
||||
manager._active_downloads["download-1"] = {
|
||||
"transfer_backend": "aria2",
|
||||
"status": "paused",
|
||||
"model_id": 12,
|
||||
"model_version_id": 34,
|
||||
"bytes_per_second": 10.0,
|
||||
}
|
||||
|
||||
persist_state = AsyncMock()
|
||||
cleanup_record = AsyncMock(return_value=None)
|
||||
execute_download = AsyncMock(return_value={"success": True})
|
||||
record_history = AsyncMock(return_value=None)
|
||||
sync_version = AsyncMock(return_value=None)
|
||||
|
||||
monkeypatch.setattr(manager, "_persist_aria2_state", persist_state)
|
||||
monkeypatch.setattr(manager, "_cleanup_download_record", cleanup_record)
|
||||
monkeypatch.setattr(manager, "_execute_download", execute_download)
|
||||
monkeypatch.setattr(manager, "_record_downloaded_version_history", record_history)
|
||||
monkeypatch.setattr(manager, "_sync_downloaded_version", sync_version)
|
||||
|
||||
scheduled_tasks = []
|
||||
original_create_task = asyncio.create_task
|
||||
|
||||
def tracking_create_task(coro):
|
||||
task = original_create_task(coro)
|
||||
scheduled_tasks.append(task)
|
||||
return task
|
||||
|
||||
monkeypatch.setattr(download_manager.asyncio, "create_task", tracking_create_task)
|
||||
|
||||
result = await manager._resume_restored_aria2_download(
|
||||
"download-1",
|
||||
{
|
||||
"download_id": "download-1",
|
||||
"save_path": "/tmp/file.safetensors",
|
||||
"file_path": "/tmp/file.safetensors",
|
||||
"model_id": 12,
|
||||
"model_version_id": 34,
|
||||
"resume_context": {
|
||||
"version_info": {
|
||||
"id": 34,
|
||||
"modelId": 12,
|
||||
"model": {"id": 12},
|
||||
"images": [],
|
||||
},
|
||||
"file_info": {
|
||||
"name": "file.safetensors",
|
||||
"downloadUrl": "https://example.com/file.safetensors",
|
||||
},
|
||||
"model_type": "lora",
|
||||
"relative_path": "",
|
||||
"save_dir": "/tmp",
|
||||
"download_urls": ["https://example.com/file.safetensors"],
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
assert result == {"success": True}
|
||||
assert manager._active_downloads["download-1"]["status"] == "completed"
|
||||
assert manager._active_downloads["download-1"]["bytes_per_second"] == 0.0
|
||||
assert persist_state.await_count == 2
|
||||
assert len(scheduled_tasks) == 1
|
||||
await asyncio.gather(*scheduled_tasks)
|
||||
cleanup_record.assert_awaited_once_with("download-1")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_download_uses_captured_backend_when_settings_change(
|
||||
monkeypatch, scanners, metadata_provider, tmp_path
|
||||
|
||||
Reference in New Issue
Block a user