mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-24 22:52:12 -03:00
Merge pull request #501 from willmiao/codex/update-downloadmanager-to-handle-multiple-download-urls
feat: retry mirror downloads sequentially
This commit is contained in:
@@ -3,7 +3,7 @@ import os
|
|||||||
import asyncio
|
import asyncio
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Dict
|
from typing import Dict, List
|
||||||
from ..utils.models import LoraMetadata, CheckpointMetadata, EmbeddingMetadata
|
from ..utils.models import LoraMetadata, CheckpointMetadata, EmbeddingMetadata
|
||||||
from ..utils.constants import CARD_PREVIEW_WIDTH, VALID_LORA_TYPES, CIVITAI_MODEL_TAGS
|
from ..utils.constants import CARD_PREVIEW_WIDTH, VALID_LORA_TYPES, CIVITAI_MODEL_TAGS
|
||||||
from ..utils.exif_utils import ExifUtils
|
from ..utils.exif_utils import ExifUtils
|
||||||
@@ -294,7 +294,18 @@ class DownloadManager:
|
|||||||
file_info = next((f for f in version_info.get('files', []) if f.get('primary')), None)
|
file_info = next((f for f in version_info.get('files', []) if f.get('primary')), None)
|
||||||
if not file_info:
|
if not file_info:
|
||||||
return {'success': False, 'error': 'No primary file found in metadata'}
|
return {'success': False, 'error': 'No primary file found in metadata'}
|
||||||
if not file_info.get('downloadUrl'):
|
mirrors = file_info.get('mirrors') or []
|
||||||
|
download_urls = []
|
||||||
|
if mirrors:
|
||||||
|
for mirror in mirrors:
|
||||||
|
if mirror.get('deletedAt') is None and mirror.get('url'):
|
||||||
|
download_urls.append(mirror['url'])
|
||||||
|
else:
|
||||||
|
download_url = file_info.get('downloadUrl')
|
||||||
|
if download_url:
|
||||||
|
download_urls.append(download_url)
|
||||||
|
|
||||||
|
if not download_urls:
|
||||||
return {'success': False, 'error': 'No download URL found for primary file'}
|
return {'success': False, 'error': 'No download URL found for primary file'}
|
||||||
|
|
||||||
# 3. Prepare download
|
# 3. Prepare download
|
||||||
@@ -314,7 +325,7 @@ class DownloadManager:
|
|||||||
|
|
||||||
# 6. Start download process
|
# 6. Start download process
|
||||||
result = await self._execute_download(
|
result = await self._execute_download(
|
||||||
download_url=file_info.get('downloadUrl', ''),
|
download_urls=download_urls,
|
||||||
save_dir=save_dir,
|
save_dir=save_dir,
|
||||||
metadata=metadata,
|
metadata=metadata,
|
||||||
version_info=version_info,
|
version_info=version_info,
|
||||||
@@ -394,8 +405,8 @@ class DownloadManager:
|
|||||||
|
|
||||||
return formatted_path
|
return formatted_path
|
||||||
|
|
||||||
async def _execute_download(self, download_url: str, save_dir: str,
|
async def _execute_download(self, download_urls: List[str], save_dir: str,
|
||||||
metadata, version_info: Dict,
|
metadata, version_info: Dict,
|
||||||
relative_path: str, progress_callback=None,
|
relative_path: str, progress_callback=None,
|
||||||
model_type: str = "lora", download_id: str = None) -> Dict:
|
model_type: str = "lora", download_id: str = None) -> Dict:
|
||||||
"""Execute the actual download process including preview images and model files"""
|
"""Execute the actual download process including preview images and model files"""
|
||||||
@@ -506,33 +517,44 @@ class DownloadManager:
|
|||||||
|
|
||||||
# Download model file with progress tracking using downloader
|
# Download model file with progress tracking using downloader
|
||||||
downloader = await get_downloader()
|
downloader = await get_downloader()
|
||||||
# Determine if the download URL is from Civitai
|
last_error = None
|
||||||
use_auth = download_url.startswith("https://civitai.com/api/download/")
|
for download_url in download_urls:
|
||||||
success, result = await downloader.download_file(
|
use_auth = download_url.startswith("https://civitai.com/api/download/")
|
||||||
download_url,
|
success, result = await downloader.download_file(
|
||||||
save_path, # Use full path instead of separate dir and filename
|
download_url,
|
||||||
progress_callback=lambda p: self._handle_download_progress(p, progress_callback),
|
save_path, # Use full path instead of separate dir and filename
|
||||||
use_auth=use_auth # Only use authentication for Civitai downloads
|
progress_callback=lambda p: self._handle_download_progress(p, progress_callback),
|
||||||
)
|
use_auth=use_auth # Only use authentication for Civitai downloads
|
||||||
|
)
|
||||||
|
|
||||||
if not success:
|
if success:
|
||||||
|
break
|
||||||
|
|
||||||
|
last_error = result
|
||||||
|
if os.path.exists(save_path):
|
||||||
|
try:
|
||||||
|
os.remove(save_path)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to remove incomplete file {save_path}: {e}")
|
||||||
|
else:
|
||||||
# Clean up files on failure, but preserve .part file for resume
|
# Clean up files on failure, but preserve .part file for resume
|
||||||
cleanup_files = [metadata_path]
|
cleanup_files = [metadata_path]
|
||||||
if metadata.preview_url and os.path.exists(metadata.preview_url):
|
preview_path_value = getattr(metadata, 'preview_url', None)
|
||||||
cleanup_files.append(metadata.preview_url)
|
if preview_path_value and os.path.exists(preview_path_value):
|
||||||
|
cleanup_files.append(preview_path_value)
|
||||||
|
|
||||||
for path in cleanup_files:
|
for path in cleanup_files:
|
||||||
if path and os.path.exists(path):
|
if path and os.path.exists(path):
|
||||||
try:
|
try:
|
||||||
os.remove(path)
|
os.remove(path)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Failed to cleanup file {path}: {e}")
|
logger.warning(f"Failed to cleanup file {path}: {e}")
|
||||||
|
|
||||||
# Log but don't remove .part file to allow resume
|
# Log but don't remove .part file to allow resume
|
||||||
if os.path.exists(part_path):
|
if os.path.exists(part_path):
|
||||||
logger.info(f"Preserving partial download for resume: {part_path}")
|
logger.info(f"Preserving partial download for resume: {part_path}")
|
||||||
|
|
||||||
return {'success': False, 'error': result}
|
return {'success': False, 'error': last_error or 'Failed to download file'}
|
||||||
|
|
||||||
# 4. Update file information (size and modified time)
|
# 4. Update file information (size and modified time)
|
||||||
metadata.update_file_info(save_path)
|
metadata.update_file_info(save_path)
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from types import SimpleNamespace
|
from types import SimpleNamespace
|
||||||
from unittest.mock import AsyncMock
|
from unittest.mock import AsyncMock
|
||||||
@@ -8,6 +9,7 @@ from py.services.download_manager import DownloadManager
|
|||||||
from py.services import download_manager
|
from py.services import download_manager
|
||||||
from py.services.service_registry import ServiceRegistry
|
from py.services.service_registry import ServiceRegistry
|
||||||
from py.services.settings_manager import settings
|
from py.services.settings_manager import settings
|
||||||
|
from py.utils.metadata_manager import MetadataManager
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
@pytest.fixture(autouse=True)
|
||||||
@@ -147,7 +149,7 @@ async def test_successful_download_uses_defaults(monkeypatch, scanners, metadata
|
|||||||
async def fake_execute_download(
|
async def fake_execute_download(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
download_url,
|
download_urls,
|
||||||
save_dir,
|
save_dir,
|
||||||
metadata,
|
metadata,
|
||||||
version_info,
|
version_info,
|
||||||
@@ -158,7 +160,7 @@ async def test_successful_download_uses_defaults(monkeypatch, scanners, metadata
|
|||||||
):
|
):
|
||||||
captured.update(
|
captured.update(
|
||||||
{
|
{
|
||||||
"download_url": download_url,
|
"download_urls": download_urls,
|
||||||
"save_dir": Path(save_dir),
|
"save_dir": Path(save_dir),
|
||||||
"relative_path": relative_path,
|
"relative_path": relative_path,
|
||||||
"progress_callback": progress_callback,
|
"progress_callback": progress_callback,
|
||||||
@@ -188,6 +190,63 @@ async def test_successful_download_uses_defaults(monkeypatch, scanners, metadata
|
|||||||
expected_dir = Path(settings.get("default_lora_root")) / "MappedModel" / "fantasy"
|
expected_dir = Path(settings.get("default_lora_root")) / "MappedModel" / "fantasy"
|
||||||
assert captured["save_dir"] == expected_dir
|
assert captured["save_dir"] == expected_dir
|
||||||
assert captured["model_type"] == "lora"
|
assert captured["model_type"] == "lora"
|
||||||
|
assert captured["download_urls"] == [
|
||||||
|
"https://example.invalid/file.safetensors"
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
async def test_download_uses_active_mirrors(monkeypatch, scanners, metadata_provider, tmp_path):
|
||||||
|
manager = DownloadManager()
|
||||||
|
|
||||||
|
metadata_with_mirrors = {
|
||||||
|
"id": 42,
|
||||||
|
"model": {"type": "LoRA", "tags": ["fantasy"]},
|
||||||
|
"baseModel": "BaseModel",
|
||||||
|
"creator": {"username": "Author"},
|
||||||
|
"files": [
|
||||||
|
{
|
||||||
|
"primary": True,
|
||||||
|
"downloadUrl": "https://example.invalid/file.safetensors",
|
||||||
|
"mirrors": [
|
||||||
|
{"url": "https://mirror.example/file.safetensors", "deletedAt": None},
|
||||||
|
{"url": "https://mirror.example/old.safetensors", "deletedAt": "2024-01-01"},
|
||||||
|
],
|
||||||
|
"name": "file.safetensors",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
metadata_provider.get_model_version = AsyncMock(return_value=metadata_with_mirrors)
|
||||||
|
|
||||||
|
captured = {}
|
||||||
|
|
||||||
|
async def fake_execute_download(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
download_urls,
|
||||||
|
save_dir,
|
||||||
|
metadata,
|
||||||
|
version_info,
|
||||||
|
relative_path,
|
||||||
|
progress_callback,
|
||||||
|
model_type,
|
||||||
|
download_id,
|
||||||
|
):
|
||||||
|
captured["download_urls"] = download_urls
|
||||||
|
return {"success": True}
|
||||||
|
|
||||||
|
monkeypatch.setattr(DownloadManager, "_execute_download", fake_execute_download, raising=False)
|
||||||
|
|
||||||
|
result = await manager.download_from_civitai(
|
||||||
|
model_version_id=99,
|
||||||
|
save_dir=str(tmp_path),
|
||||||
|
use_default_paths=True,
|
||||||
|
progress_callback=None,
|
||||||
|
source=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result["success"] is True
|
||||||
|
assert captured["download_urls"] == ["https://mirror.example/file.safetensors"]
|
||||||
|
|
||||||
|
|
||||||
async def test_download_aborts_when_version_exists(monkeypatch, scanners, metadata_provider):
|
async def test_download_aborts_when_version_exists(monkeypatch, scanners, metadata_provider):
|
||||||
@@ -259,3 +318,78 @@ def test_embedding_relative_path_replaces_spaces():
|
|||||||
relative_path = manager._calculate_relative_path(version_info, "embedding")
|
relative_path = manager._calculate_relative_path(version_info, "embedding")
|
||||||
|
|
||||||
assert relative_path == "Base_Model/tag_with_space"
|
assert relative_path == "Base_Model/tag_with_space"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_execute_download_retries_urls(monkeypatch, tmp_path):
|
||||||
|
manager = DownloadManager()
|
||||||
|
|
||||||
|
save_dir = tmp_path / "downloads"
|
||||||
|
save_dir.mkdir()
|
||||||
|
initial_path = save_dir / "file.safetensors"
|
||||||
|
|
||||||
|
class DummyMetadata:
|
||||||
|
def __init__(self, path: Path):
|
||||||
|
self.file_path = str(path)
|
||||||
|
self.sha256 = "sha256"
|
||||||
|
self.file_name = path.stem
|
||||||
|
self.preview_url = None
|
||||||
|
|
||||||
|
def generate_unique_filename(self, *_args, **_kwargs):
|
||||||
|
return os.path.basename(self.file_path)
|
||||||
|
|
||||||
|
def update_file_info(self, _path):
|
||||||
|
return None
|
||||||
|
|
||||||
|
def to_dict(self):
|
||||||
|
return {"file_path": self.file_path}
|
||||||
|
|
||||||
|
metadata = DummyMetadata(initial_path)
|
||||||
|
version_info = {"images": []}
|
||||||
|
download_urls = [
|
||||||
|
"https://first.example/file.safetensors",
|
||||||
|
"https://second.example/file.safetensors",
|
||||||
|
]
|
||||||
|
|
||||||
|
class DummyDownloader:
|
||||||
|
def __init__(self):
|
||||||
|
self.calls = []
|
||||||
|
|
||||||
|
async def download_file(self, url, path, progress_callback=None, use_auth=None):
|
||||||
|
self.calls.append((url, path, use_auth))
|
||||||
|
if len(self.calls) == 1:
|
||||||
|
return False, "first failed"
|
||||||
|
# Create the target file to simulate a successful download
|
||||||
|
Path(path).write_text("content")
|
||||||
|
return True, "second success"
|
||||||
|
|
||||||
|
dummy_downloader = DummyDownloader()
|
||||||
|
monkeypatch.setattr(download_manager, "get_downloader", AsyncMock(return_value=dummy_downloader))
|
||||||
|
|
||||||
|
class DummyScanner:
|
||||||
|
def __init__(self):
|
||||||
|
self.calls = []
|
||||||
|
|
||||||
|
async def add_model_to_cache(self, metadata_dict, relative_path):
|
||||||
|
self.calls.append((metadata_dict, relative_path))
|
||||||
|
|
||||||
|
dummy_scanner = DummyScanner()
|
||||||
|
monkeypatch.setattr(DownloadManager, "_get_lora_scanner", AsyncMock(return_value=dummy_scanner))
|
||||||
|
monkeypatch.setattr(DownloadManager, "_get_checkpoint_scanner", AsyncMock(return_value=dummy_scanner))
|
||||||
|
monkeypatch.setattr(ServiceRegistry, "get_embedding_scanner", AsyncMock(return_value=dummy_scanner))
|
||||||
|
|
||||||
|
monkeypatch.setattr(MetadataManager, "save_metadata", AsyncMock(return_value=True))
|
||||||
|
|
||||||
|
result = await manager._execute_download(
|
||||||
|
download_urls=download_urls,
|
||||||
|
save_dir=str(save_dir),
|
||||||
|
metadata=metadata,
|
||||||
|
version_info=version_info,
|
||||||
|
relative_path="",
|
||||||
|
progress_callback=None,
|
||||||
|
model_type="lora",
|
||||||
|
download_id=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result == {"success": True}
|
||||||
|
assert [url for url, *_ in dummy_downloader.calls] == download_urls
|
||||||
|
assert dummy_scanner.calls # ensure cache updated
|
||||||
|
|||||||
Reference in New Issue
Block a user