From 9121c12a2c1e3b61f8d4f209269b82b9072df009 Mon Sep 17 00:00:00 2001 From: pixelpaws Date: Tue, 30 Sep 2025 17:14:59 +0800 Subject: [PATCH] feat(download): retry mirror urls sequentially --- py/services/download_manager.py | 62 +++++++---- tests/services/test_download_manager.py | 138 +++++++++++++++++++++++- 2 files changed, 178 insertions(+), 22 deletions(-) diff --git a/py/services/download_manager.py b/py/services/download_manager.py index 46816479..b4bddc23 100644 --- a/py/services/download_manager.py +++ b/py/services/download_manager.py @@ -3,7 +3,7 @@ import os import asyncio from collections import OrderedDict import uuid -from typing import Dict +from typing import Dict, List from ..utils.models import LoraMetadata, CheckpointMetadata, EmbeddingMetadata from ..utils.constants import CARD_PREVIEW_WIDTH, VALID_LORA_TYPES, CIVITAI_MODEL_TAGS from ..utils.exif_utils import ExifUtils @@ -294,7 +294,18 @@ class DownloadManager: file_info = next((f for f in version_info.get('files', []) if f.get('primary')), None) if not file_info: return {'success': False, 'error': 'No primary file found in metadata'} - if not file_info.get('downloadUrl'): + mirrors = file_info.get('mirrors') or [] + download_urls = [] + if mirrors: + for mirror in mirrors: + if mirror.get('deletedAt') is None and mirror.get('url'): + download_urls.append(mirror['url']) + else: + download_url = file_info.get('downloadUrl') + if download_url: + download_urls.append(download_url) + + if not download_urls: return {'success': False, 'error': 'No download URL found for primary file'} # 3. Prepare download @@ -314,7 +325,7 @@ class DownloadManager: # 6. Start download process result = await self._execute_download( - download_url=file_info.get('downloadUrl', ''), + download_urls=download_urls, save_dir=save_dir, metadata=metadata, version_info=version_info, @@ -394,8 +405,8 @@ class DownloadManager: return formatted_path - async def _execute_download(self, download_url: str, save_dir: str, - metadata, version_info: Dict, + async def _execute_download(self, download_urls: List[str], save_dir: str, + metadata, version_info: Dict, relative_path: str, progress_callback=None, model_type: str = "lora", download_id: str = None) -> Dict: """Execute the actual download process including preview images and model files""" @@ -506,33 +517,44 @@ class DownloadManager: # Download model file with progress tracking using downloader downloader = await get_downloader() - # Determine if the download URL is from Civitai - use_auth = download_url.startswith("https://civitai.com/api/download/") - success, result = await downloader.download_file( - download_url, - save_path, # Use full path instead of separate dir and filename - progress_callback=lambda p: self._handle_download_progress(p, progress_callback), - use_auth=use_auth # Only use authentication for Civitai downloads - ) + last_error = None + for download_url in download_urls: + use_auth = download_url.startswith("https://civitai.com/api/download/") + success, result = await downloader.download_file( + download_url, + save_path, # Use full path instead of separate dir and filename + progress_callback=lambda p: self._handle_download_progress(p, progress_callback), + use_auth=use_auth # Only use authentication for Civitai downloads + ) - if not success: + if success: + break + + last_error = result + if os.path.exists(save_path): + try: + os.remove(save_path) + except Exception as e: + logger.warning(f"Failed to remove incomplete file {save_path}: {e}") + else: # Clean up files on failure, but preserve .part file for resume cleanup_files = [metadata_path] - if metadata.preview_url and os.path.exists(metadata.preview_url): - cleanup_files.append(metadata.preview_url) - + preview_path_value = getattr(metadata, 'preview_url', None) + if preview_path_value and os.path.exists(preview_path_value): + cleanup_files.append(preview_path_value) + for path in cleanup_files: if path and os.path.exists(path): try: os.remove(path) except Exception as e: logger.warning(f"Failed to cleanup file {path}: {e}") - + # Log but don't remove .part file to allow resume if os.path.exists(part_path): logger.info(f"Preserving partial download for resume: {part_path}") - - return {'success': False, 'error': result} + + return {'success': False, 'error': last_error or 'Failed to download file'} # 4. Update file information (size and modified time) metadata.update_file_info(save_path) diff --git a/tests/services/test_download_manager.py b/tests/services/test_download_manager.py index c7d3d5c3..c49bc7b4 100644 --- a/tests/services/test_download_manager.py +++ b/tests/services/test_download_manager.py @@ -1,3 +1,4 @@ +import os from pathlib import Path from types import SimpleNamespace from unittest.mock import AsyncMock @@ -8,6 +9,7 @@ from py.services.download_manager import DownloadManager from py.services import download_manager from py.services.service_registry import ServiceRegistry from py.services.settings_manager import settings +from py.utils.metadata_manager import MetadataManager @pytest.fixture(autouse=True) @@ -147,7 +149,7 @@ async def test_successful_download_uses_defaults(monkeypatch, scanners, metadata async def fake_execute_download( self, *, - download_url, + download_urls, save_dir, metadata, version_info, @@ -158,7 +160,7 @@ async def test_successful_download_uses_defaults(monkeypatch, scanners, metadata ): captured.update( { - "download_url": download_url, + "download_urls": download_urls, "save_dir": Path(save_dir), "relative_path": relative_path, "progress_callback": progress_callback, @@ -188,6 +190,63 @@ async def test_successful_download_uses_defaults(monkeypatch, scanners, metadata expected_dir = Path(settings.get("default_lora_root")) / "MappedModel" / "fantasy" assert captured["save_dir"] == expected_dir assert captured["model_type"] == "lora" + assert captured["download_urls"] == [ + "https://example.invalid/file.safetensors" + ] + + +async def test_download_uses_active_mirrors(monkeypatch, scanners, metadata_provider, tmp_path): + manager = DownloadManager() + + metadata_with_mirrors = { + "id": 42, + "model": {"type": "LoRA", "tags": ["fantasy"]}, + "baseModel": "BaseModel", + "creator": {"username": "Author"}, + "files": [ + { + "primary": True, + "downloadUrl": "https://example.invalid/file.safetensors", + "mirrors": [ + {"url": "https://mirror.example/file.safetensors", "deletedAt": None}, + {"url": "https://mirror.example/old.safetensors", "deletedAt": "2024-01-01"}, + ], + "name": "file.safetensors", + } + ], + } + + metadata_provider.get_model_version = AsyncMock(return_value=metadata_with_mirrors) + + captured = {} + + async def fake_execute_download( + self, + *, + download_urls, + save_dir, + metadata, + version_info, + relative_path, + progress_callback, + model_type, + download_id, + ): + captured["download_urls"] = download_urls + return {"success": True} + + monkeypatch.setattr(DownloadManager, "_execute_download", fake_execute_download, raising=False) + + result = await manager.download_from_civitai( + model_version_id=99, + save_dir=str(tmp_path), + use_default_paths=True, + progress_callback=None, + source=None, + ) + + assert result["success"] is True + assert captured["download_urls"] == ["https://mirror.example/file.safetensors"] async def test_download_aborts_when_version_exists(monkeypatch, scanners, metadata_provider): @@ -259,3 +318,78 @@ def test_embedding_relative_path_replaces_spaces(): relative_path = manager._calculate_relative_path(version_info, "embedding") assert relative_path == "Base_Model/tag_with_space" + + +async def test_execute_download_retries_urls(monkeypatch, tmp_path): + manager = DownloadManager() + + save_dir = tmp_path / "downloads" + save_dir.mkdir() + initial_path = save_dir / "file.safetensors" + + class DummyMetadata: + def __init__(self, path: Path): + self.file_path = str(path) + self.sha256 = "sha256" + self.file_name = path.stem + self.preview_url = None + + def generate_unique_filename(self, *_args, **_kwargs): + return os.path.basename(self.file_path) + + def update_file_info(self, _path): + return None + + def to_dict(self): + return {"file_path": self.file_path} + + metadata = DummyMetadata(initial_path) + version_info = {"images": []} + download_urls = [ + "https://first.example/file.safetensors", + "https://second.example/file.safetensors", + ] + + class DummyDownloader: + def __init__(self): + self.calls = [] + + async def download_file(self, url, path, progress_callback=None, use_auth=None): + self.calls.append((url, path, use_auth)) + if len(self.calls) == 1: + return False, "first failed" + # Create the target file to simulate a successful download + Path(path).write_text("content") + return True, "second success" + + dummy_downloader = DummyDownloader() + monkeypatch.setattr(download_manager, "get_downloader", AsyncMock(return_value=dummy_downloader)) + + class DummyScanner: + def __init__(self): + self.calls = [] + + async def add_model_to_cache(self, metadata_dict, relative_path): + self.calls.append((metadata_dict, relative_path)) + + dummy_scanner = DummyScanner() + monkeypatch.setattr(DownloadManager, "_get_lora_scanner", AsyncMock(return_value=dummy_scanner)) + monkeypatch.setattr(DownloadManager, "_get_checkpoint_scanner", AsyncMock(return_value=dummy_scanner)) + monkeypatch.setattr(ServiceRegistry, "get_embedding_scanner", AsyncMock(return_value=dummy_scanner)) + + monkeypatch.setattr(MetadataManager, "save_metadata", AsyncMock(return_value=True)) + + result = await manager._execute_download( + download_urls=download_urls, + save_dir=str(save_dir), + metadata=metadata, + version_info=version_info, + relative_path="", + progress_callback=None, + model_type="lora", + download_id=None, + ) + + assert result == {"success": True} + assert [url for url, *_ in dummy_downloader.calls] == download_urls + assert dummy_scanner.calls # ensure cache updated