mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-21 21:22:11 -03:00
fix(downloader): enforce file size integrity checks
This commit is contained in:
@@ -354,12 +354,60 @@ class Downloader:
|
||||
last_progress_report_time = now
|
||||
|
||||
# Download completed successfully
|
||||
# Verify file size if total_size was provided
|
||||
final_size = os.path.getsize(part_path)
|
||||
if total_size > 0 and final_size != total_size:
|
||||
logger.warning(f"File size mismatch. Expected: {total_size}, Got: {final_size}")
|
||||
# Don't treat this as fatal error, continue anyway
|
||||
|
||||
# Verify file size integrity before finalizing
|
||||
final_size = os.path.getsize(part_path) if os.path.exists(part_path) else 0
|
||||
expected_size = total_size if total_size > 0 else None
|
||||
|
||||
integrity_error: Optional[str] = None
|
||||
if final_size <= 0:
|
||||
integrity_error = "Downloaded file is empty"
|
||||
elif expected_size is not None and final_size != expected_size:
|
||||
integrity_error = (
|
||||
f"File size mismatch. Expected: {expected_size}, Got: {final_size}"
|
||||
)
|
||||
|
||||
if integrity_error is not None:
|
||||
logger.error(
|
||||
"Download integrity check failed for %s: %s",
|
||||
save_path,
|
||||
integrity_error,
|
||||
)
|
||||
|
||||
# Remove the corrupted payload so future attempts start fresh
|
||||
if os.path.exists(part_path):
|
||||
try:
|
||||
os.remove(part_path)
|
||||
except OSError as remove_error:
|
||||
logger.warning(
|
||||
"Failed to delete corrupted download %s: %s",
|
||||
part_path,
|
||||
remove_error,
|
||||
)
|
||||
if part_path != save_path and os.path.exists(save_path):
|
||||
try:
|
||||
os.remove(save_path)
|
||||
except OSError as remove_error:
|
||||
logger.warning(
|
||||
"Failed to delete target file %s after integrity error: %s",
|
||||
save_path,
|
||||
remove_error,
|
||||
)
|
||||
|
||||
retry_count += 1
|
||||
if retry_count <= self.max_retries:
|
||||
delay = self.base_delay * (2 ** (retry_count - 1))
|
||||
logger.info(
|
||||
"Retrying download in %s seconds due to integrity check failure",
|
||||
delay,
|
||||
)
|
||||
await asyncio.sleep(delay)
|
||||
resume_offset = 0
|
||||
total_size = 0
|
||||
await self._create_session()
|
||||
continue
|
||||
|
||||
return False, integrity_error
|
||||
|
||||
# Atomically rename .part to final file (only if using resume)
|
||||
if allow_resume and part_path != save_path:
|
||||
max_rename_attempts = 5
|
||||
@@ -382,7 +430,9 @@ class Downloader:
|
||||
else:
|
||||
logger.error(f"Failed to rename file after {max_rename_attempts} attempts: {e}")
|
||||
return False, f"Failed to finalize download: {str(e)}"
|
||||
|
||||
|
||||
final_size = os.path.getsize(save_path)
|
||||
|
||||
# Ensure 100% progress is reported
|
||||
if progress_callback:
|
||||
final_snapshot = DownloadProgress(
|
||||
|
||||
125
tests/services/test_downloader.py
Normal file
125
tests/services/test_downloader.py
Normal file
@@ -0,0 +1,125 @@
|
||||
import asyncio
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from py.services.downloader import Downloader
|
||||
|
||||
|
||||
class FakeStream:
|
||||
def __init__(self, chunks):
|
||||
self._chunks = list(chunks)
|
||||
|
||||
async def iter_chunked(self, _chunk_size):
|
||||
for chunk in self._chunks:
|
||||
await asyncio.sleep(0)
|
||||
yield chunk
|
||||
|
||||
|
||||
class FakeResponse:
|
||||
def __init__(self, status, headers, chunks):
|
||||
self.status = status
|
||||
self.headers = headers
|
||||
self.content = FakeStream(chunks)
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
|
||||
class FakeSession:
|
||||
def __init__(self, responses):
|
||||
self._responses = list(responses)
|
||||
self._get_calls = 0
|
||||
|
||||
def get(self, url, headers=None, allow_redirects=True, proxy=None): # noqa: D401 - signature mirrors aiohttp
|
||||
del url, headers, allow_redirects, proxy
|
||||
response_factory = self._responses[self._get_calls]
|
||||
self._get_calls += 1
|
||||
return response_factory()
|
||||
|
||||
async def close(self):
|
||||
return None
|
||||
|
||||
|
||||
def _build_downloader(responses, *, max_retries=0):
|
||||
downloader = Downloader()
|
||||
downloader.max_retries = max_retries
|
||||
downloader.base_delay = 0
|
||||
fake_session = FakeSession(responses)
|
||||
downloader._session = fake_session
|
||||
downloader._session_created_at = datetime.now()
|
||||
downloader._proxy_url = None
|
||||
return downloader
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_download_file_fails_when_size_mismatch(tmp_path):
|
||||
target_path = tmp_path / "model" / "file.bin"
|
||||
target_path.parent.mkdir()
|
||||
|
||||
responses = [
|
||||
lambda: FakeResponse(
|
||||
status=200,
|
||||
headers={"content-length": "10"},
|
||||
chunks=[b"abc"],
|
||||
)
|
||||
]
|
||||
|
||||
downloader = _build_downloader(responses)
|
||||
|
||||
success, message = await downloader.download_file("https://example.com/file", str(target_path))
|
||||
|
||||
assert success is False
|
||||
assert "mismatch" in message.lower()
|
||||
assert not target_path.exists()
|
||||
assert not Path(str(target_path) + ".part").exists()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_download_file_fails_when_zero_bytes(tmp_path):
|
||||
target_path = tmp_path / "model" / "file.bin"
|
||||
target_path.parent.mkdir()
|
||||
|
||||
responses = [
|
||||
lambda: FakeResponse(
|
||||
status=200,
|
||||
headers={"content-length": "0"},
|
||||
chunks=[],
|
||||
)
|
||||
]
|
||||
|
||||
downloader = _build_downloader(responses)
|
||||
|
||||
success, message = await downloader.download_file("https://example.com/file", str(target_path))
|
||||
|
||||
assert success is False
|
||||
assert "empty" in message.lower()
|
||||
assert not target_path.exists()
|
||||
assert not Path(str(target_path) + ".part").exists()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_download_file_succeeds_when_sizes_match(tmp_path):
|
||||
target_path = tmp_path / "model" / "file.bin"
|
||||
target_path.parent.mkdir()
|
||||
|
||||
payload = b"abcdef"
|
||||
responses = [
|
||||
lambda: FakeResponse(
|
||||
status=200,
|
||||
headers={"content-length": str(len(payload))},
|
||||
chunks=[payload],
|
||||
)
|
||||
]
|
||||
|
||||
downloader = _build_downloader(responses)
|
||||
|
||||
success, result_path = await downloader.download_file("https://example.com/file", str(target_path))
|
||||
|
||||
assert success is True
|
||||
assert Path(result_path).read_bytes() == payload
|
||||
assert not Path(str(target_path) + ".part").exists()
|
||||
Reference in New Issue
Block a user