Merge pull request #501 from willmiao/codex/update-downloadmanager-to-handle-multiple-download-urls

feat: retry mirror downloads sequentially
This commit is contained in:
pixelpaws
2025-09-30 17:17:34 +08:00
committed by GitHub
2 changed files with 178 additions and 22 deletions

View File

@@ -3,7 +3,7 @@ import os
import asyncio import asyncio
from collections import OrderedDict from collections import OrderedDict
import uuid import uuid
from typing import Dict from typing import Dict, List
from ..utils.models import LoraMetadata, CheckpointMetadata, EmbeddingMetadata from ..utils.models import LoraMetadata, CheckpointMetadata, EmbeddingMetadata
from ..utils.constants import CARD_PREVIEW_WIDTH, VALID_LORA_TYPES, CIVITAI_MODEL_TAGS from ..utils.constants import CARD_PREVIEW_WIDTH, VALID_LORA_TYPES, CIVITAI_MODEL_TAGS
from ..utils.exif_utils import ExifUtils from ..utils.exif_utils import ExifUtils
@@ -294,7 +294,18 @@ class DownloadManager:
file_info = next((f for f in version_info.get('files', []) if f.get('primary')), None) file_info = next((f for f in version_info.get('files', []) if f.get('primary')), None)
if not file_info: if not file_info:
return {'success': False, 'error': 'No primary file found in metadata'} return {'success': False, 'error': 'No primary file found in metadata'}
if not file_info.get('downloadUrl'): mirrors = file_info.get('mirrors') or []
download_urls = []
if mirrors:
for mirror in mirrors:
if mirror.get('deletedAt') is None and mirror.get('url'):
download_urls.append(mirror['url'])
else:
download_url = file_info.get('downloadUrl')
if download_url:
download_urls.append(download_url)
if not download_urls:
return {'success': False, 'error': 'No download URL found for primary file'} return {'success': False, 'error': 'No download URL found for primary file'}
# 3. Prepare download # 3. Prepare download
@@ -314,7 +325,7 @@ class DownloadManager:
# 6. Start download process # 6. Start download process
result = await self._execute_download( result = await self._execute_download(
download_url=file_info.get('downloadUrl', ''), download_urls=download_urls,
save_dir=save_dir, save_dir=save_dir,
metadata=metadata, metadata=metadata,
version_info=version_info, version_info=version_info,
@@ -394,8 +405,8 @@ class DownloadManager:
return formatted_path return formatted_path
async def _execute_download(self, download_url: str, save_dir: str, async def _execute_download(self, download_urls: List[str], save_dir: str,
metadata, version_info: Dict, metadata, version_info: Dict,
relative_path: str, progress_callback=None, relative_path: str, progress_callback=None,
model_type: str = "lora", download_id: str = None) -> Dict: model_type: str = "lora", download_id: str = None) -> Dict:
"""Execute the actual download process including preview images and model files""" """Execute the actual download process including preview images and model files"""
@@ -506,33 +517,44 @@ class DownloadManager:
# Download model file with progress tracking using downloader # Download model file with progress tracking using downloader
downloader = await get_downloader() downloader = await get_downloader()
# Determine if the download URL is from Civitai last_error = None
use_auth = download_url.startswith("https://civitai.com/api/download/") for download_url in download_urls:
success, result = await downloader.download_file( use_auth = download_url.startswith("https://civitai.com/api/download/")
download_url, success, result = await downloader.download_file(
save_path, # Use full path instead of separate dir and filename download_url,
progress_callback=lambda p: self._handle_download_progress(p, progress_callback), save_path, # Use full path instead of separate dir and filename
use_auth=use_auth # Only use authentication for Civitai downloads progress_callback=lambda p: self._handle_download_progress(p, progress_callback),
) use_auth=use_auth # Only use authentication for Civitai downloads
)
if not success: if success:
break
last_error = result
if os.path.exists(save_path):
try:
os.remove(save_path)
except Exception as e:
logger.warning(f"Failed to remove incomplete file {save_path}: {e}")
else:
# Clean up files on failure, but preserve .part file for resume # Clean up files on failure, but preserve .part file for resume
cleanup_files = [metadata_path] cleanup_files = [metadata_path]
if metadata.preview_url and os.path.exists(metadata.preview_url): preview_path_value = getattr(metadata, 'preview_url', None)
cleanup_files.append(metadata.preview_url) if preview_path_value and os.path.exists(preview_path_value):
cleanup_files.append(preview_path_value)
for path in cleanup_files: for path in cleanup_files:
if path and os.path.exists(path): if path and os.path.exists(path):
try: try:
os.remove(path) os.remove(path)
except Exception as e: except Exception as e:
logger.warning(f"Failed to cleanup file {path}: {e}") logger.warning(f"Failed to cleanup file {path}: {e}")
# Log but don't remove .part file to allow resume # Log but don't remove .part file to allow resume
if os.path.exists(part_path): if os.path.exists(part_path):
logger.info(f"Preserving partial download for resume: {part_path}") logger.info(f"Preserving partial download for resume: {part_path}")
return {'success': False, 'error': result} return {'success': False, 'error': last_error or 'Failed to download file'}
# 4. Update file information (size and modified time) # 4. Update file information (size and modified time)
metadata.update_file_info(save_path) metadata.update_file_info(save_path)

View File

@@ -1,3 +1,4 @@
import os
from pathlib import Path from pathlib import Path
from types import SimpleNamespace from types import SimpleNamespace
from unittest.mock import AsyncMock from unittest.mock import AsyncMock
@@ -8,6 +9,7 @@ from py.services.download_manager import DownloadManager
from py.services import download_manager from py.services import download_manager
from py.services.service_registry import ServiceRegistry from py.services.service_registry import ServiceRegistry
from py.services.settings_manager import settings from py.services.settings_manager import settings
from py.utils.metadata_manager import MetadataManager
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
@@ -147,7 +149,7 @@ async def test_successful_download_uses_defaults(monkeypatch, scanners, metadata
async def fake_execute_download( async def fake_execute_download(
self, self,
*, *,
download_url, download_urls,
save_dir, save_dir,
metadata, metadata,
version_info, version_info,
@@ -158,7 +160,7 @@ async def test_successful_download_uses_defaults(monkeypatch, scanners, metadata
): ):
captured.update( captured.update(
{ {
"download_url": download_url, "download_urls": download_urls,
"save_dir": Path(save_dir), "save_dir": Path(save_dir),
"relative_path": relative_path, "relative_path": relative_path,
"progress_callback": progress_callback, "progress_callback": progress_callback,
@@ -188,6 +190,63 @@ async def test_successful_download_uses_defaults(monkeypatch, scanners, metadata
expected_dir = Path(settings.get("default_lora_root")) / "MappedModel" / "fantasy" expected_dir = Path(settings.get("default_lora_root")) / "MappedModel" / "fantasy"
assert captured["save_dir"] == expected_dir assert captured["save_dir"] == expected_dir
assert captured["model_type"] == "lora" assert captured["model_type"] == "lora"
assert captured["download_urls"] == [
"https://example.invalid/file.safetensors"
]
async def test_download_uses_active_mirrors(monkeypatch, scanners, metadata_provider, tmp_path):
    """Non-deleted mirror URLs take precedence over the file's primary downloadUrl."""
    manager = DownloadManager()

    # Version payload with one live mirror and one soft-deleted mirror.
    version_payload = {
        "id": 42,
        "model": {"type": "LoRA", "tags": ["fantasy"]},
        "baseModel": "BaseModel",
        "creator": {"username": "Author"},
        "files": [
            {
                "primary": True,
                "downloadUrl": "https://example.invalid/file.safetensors",
                "mirrors": [
                    {"url": "https://mirror.example/file.safetensors", "deletedAt": None},
                    {"url": "https://mirror.example/old.safetensors", "deletedAt": "2024-01-01"},
                ],
                "name": "file.safetensors",
            }
        ],
    }
    metadata_provider.get_model_version = AsyncMock(return_value=version_payload)

    seen = {}

    async def record_execute_download(
        self,
        *,
        download_urls,
        save_dir,
        metadata,
        version_info,
        relative_path,
        progress_callback,
        model_type,
        download_id,
    ):
        # Capture only what this test asserts on.
        seen["download_urls"] = download_urls
        return {"success": True}

    monkeypatch.setattr(
        DownloadManager, "_execute_download", record_execute_download, raising=False
    )

    result = await manager.download_from_civitai(
        model_version_id=99,
        save_dir=str(tmp_path),
        use_default_paths=True,
        progress_callback=None,
        source=None,
    )

    assert result["success"] is True
    # Only the live (deletedAt is None) mirror should be forwarded.
    assert seen["download_urls"] == ["https://mirror.example/file.safetensors"]
async def test_download_aborts_when_version_exists(monkeypatch, scanners, metadata_provider): async def test_download_aborts_when_version_exists(monkeypatch, scanners, metadata_provider):
@@ -259,3 +318,78 @@ def test_embedding_relative_path_replaces_spaces():
relative_path = manager._calculate_relative_path(version_info, "embedding") relative_path = manager._calculate_relative_path(version_info, "embedding")
assert relative_path == "Base_Model/tag_with_space" assert relative_path == "Base_Model/tag_with_space"
async def test_execute_download_retries_urls(monkeypatch, tmp_path):
    """_execute_download falls back to the next URL after the first one fails."""
    manager = DownloadManager()
    save_dir = tmp_path / "downloads"
    save_dir.mkdir()
    target_path = save_dir / "file.safetensors"

    class StubMetadata:
        """Minimal stand-in for the model metadata object."""

        def __init__(self, path: Path):
            self.file_path = str(path)
            self.sha256 = "sha256"
            self.file_name = path.stem
            self.preview_url = None

        def generate_unique_filename(self, *_args, **_kwargs):
            return os.path.basename(self.file_path)

        def update_file_info(self, _path):
            return None

        def to_dict(self):
            return {"file_path": self.file_path}

    class StubDownloader:
        """Fails on the first call, succeeds (and writes the file) on the second."""

        def __init__(self):
            self.calls = []

        async def download_file(self, url, path, progress_callback=None, use_auth=None):
            self.calls.append((url, path, use_auth))
            if len(self.calls) == 1:
                return False, "first failed"
            # Create the target file to simulate a successful download
            Path(path).write_text("content")
            return True, "second success"

    class StubScanner:
        """Records cache updates so the test can assert they happened."""

        def __init__(self):
            self.calls = []

        async def add_model_to_cache(self, metadata_dict, relative_path):
            self.calls.append((metadata_dict, relative_path))

    metadata = StubMetadata(target_path)
    version_info = {"images": []}
    urls = [
        "https://first.example/file.safetensors",
        "https://second.example/file.safetensors",
    ]

    stub_downloader = StubDownloader()
    stub_scanner = StubScanner()

    monkeypatch.setattr(
        download_manager, "get_downloader", AsyncMock(return_value=stub_downloader)
    )
    monkeypatch.setattr(
        DownloadManager, "_get_lora_scanner", AsyncMock(return_value=stub_scanner)
    )
    monkeypatch.setattr(
        DownloadManager, "_get_checkpoint_scanner", AsyncMock(return_value=stub_scanner)
    )
    monkeypatch.setattr(
        ServiceRegistry, "get_embedding_scanner", AsyncMock(return_value=stub_scanner)
    )
    monkeypatch.setattr(MetadataManager, "save_metadata", AsyncMock(return_value=True))

    result = await manager._execute_download(
        download_urls=urls,
        save_dir=str(save_dir),
        metadata=metadata,
        version_info=version_info,
        relative_path="",
        progress_callback=None,
        model_type="lora",
        download_id=None,
    )

    assert result == {"success": True}
    # Both URLs were attempted, in order.
    assert [call[0] for call in stub_downloader.calls] == urls
    assert stub_scanner.calls  # ensure cache updated