fix(tests): update tests to match current download implementation

- Remove calculate_sha256 mocking from download_manager tests since
  SHA256 now comes from API metadata (not recalculated during download)
- Update chunk_size assertion from 4MB to 16MB in downloader config test
This commit is contained in:
Will Miao
2026-03-26 18:00:04 +08:00
parent 95e5bc26d1
commit 3b001a6cd8
2 changed files with 110 additions and 102 deletions

View File

@@ -281,8 +281,6 @@ async def test_execute_download_extracts_zip_single_model(monkeypatch, tmp_path)
DownloadManager, "_get_lora_scanner", AsyncMock(return_value=dummy_scanner) DownloadManager, "_get_lora_scanner", AsyncMock(return_value=dummy_scanner)
) )
monkeypatch.setattr(MetadataManager, "save_metadata", AsyncMock(return_value=True)) monkeypatch.setattr(MetadataManager, "save_metadata", AsyncMock(return_value=True))
hash_calculator = AsyncMock(return_value="hash-single")
monkeypatch.setattr(download_manager, "calculate_sha256", hash_calculator)
result = await manager._execute_download( result = await manager._execute_download(
download_urls=download_urls, download_urls=download_urls,
@@ -299,10 +297,10 @@ async def test_execute_download_extracts_zip_single_model(monkeypatch, tmp_path)
assert not zip_path.exists() assert not zip_path.exists()
extracted = save_dir / "model.safetensors" extracted = save_dir / "model.safetensors"
assert extracted.exists() assert extracted.exists()
assert hash_calculator.await_args.args[0] == str(extracted)
saved_call = MetadataManager.save_metadata.await_args saved_call = MetadataManager.save_metadata.await_args
assert saved_call.args[0] == str(extracted) assert saved_call.args[0] == str(extracted)
assert saved_call.args[1].sha256 == "hash-single" # SHA256 comes from metadata (API value), not recalculated
assert saved_call.args[1].sha256 == "sha256"
assert dummy_scanner.add_model_to_cache.await_count == 1 assert dummy_scanner.add_model_to_cache.await_count == 1
@@ -351,8 +349,6 @@ async def test_execute_download_extracts_zip_multiple_models(monkeypatch, tmp_pa
DownloadManager, "_get_lora_scanner", AsyncMock(return_value=dummy_scanner) DownloadManager, "_get_lora_scanner", AsyncMock(return_value=dummy_scanner)
) )
monkeypatch.setattr(MetadataManager, "save_metadata", AsyncMock(return_value=True)) monkeypatch.setattr(MetadataManager, "save_metadata", AsyncMock(return_value=True))
hash_calculator = AsyncMock(side_effect=["hash-one", "hash-two"])
monkeypatch.setattr(download_manager, "calculate_sha256", hash_calculator)
result = await manager._execute_download( result = await manager._execute_download(
download_urls=download_urls, download_urls=download_urls,
@@ -372,15 +368,15 @@ async def test_execute_download_extracts_zip_multiple_models(monkeypatch, tmp_pa
assert extracted_one.exists() assert extracted_one.exists()
assert extracted_two.exists() assert extracted_two.exists()
assert hash_calculator.await_count == 2
assert MetadataManager.save_metadata.await_count == 2 assert MetadataManager.save_metadata.await_count == 2
assert dummy_scanner.add_model_to_cache.await_count == 2 assert dummy_scanner.add_model_to_cache.await_count == 2
metadata_calls = MetadataManager.save_metadata.await_args_list metadata_calls = MetadataManager.save_metadata.await_args_list
assert metadata_calls[0].args[0] == str(extracted_one) assert metadata_calls[0].args[0] == str(extracted_one)
assert metadata_calls[0].args[1].sha256 == "hash-one" # SHA256 comes from metadata (API value), not recalculated
assert metadata_calls[0].args[1].sha256 == "sha256"
assert metadata_calls[1].args[0] == str(extracted_two) assert metadata_calls[1].args[0] == str(extracted_two)
assert metadata_calls[1].args[1].sha256 == "hash-two" assert metadata_calls[1].args[1].sha256 == "sha256"
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -427,8 +423,6 @@ async def test_execute_download_extracts_zip_pt_embedding(monkeypatch, tmp_path)
ServiceRegistry, "get_embedding_scanner", AsyncMock(return_value=dummy_scanner) ServiceRegistry, "get_embedding_scanner", AsyncMock(return_value=dummy_scanner)
) )
monkeypatch.setattr(MetadataManager, "save_metadata", AsyncMock(return_value=True)) monkeypatch.setattr(MetadataManager, "save_metadata", AsyncMock(return_value=True))
hash_calculator = AsyncMock(return_value="hash-pt")
monkeypatch.setattr(download_manager, "calculate_sha256", hash_calculator)
result = await manager._execute_download( result = await manager._execute_download(
download_urls=download_urls, download_urls=download_urls,
@@ -445,10 +439,10 @@ async def test_execute_download_extracts_zip_pt_embedding(monkeypatch, tmp_path)
assert not zip_path.exists() assert not zip_path.exists()
extracted = save_dir / "embedding.pt" extracted = save_dir / "embedding.pt"
assert extracted.exists() assert extracted.exists()
assert hash_calculator.await_args.args[0] == str(extracted)
saved_call = MetadataManager.save_metadata.await_args saved_call = MetadataManager.save_metadata.await_args
assert saved_call.args[0] == str(extracted) assert saved_call.args[0] == str(extracted)
assert saved_call.args[1].sha256 == "hash-pt" # SHA256 comes from metadata (API value), not recalculated
assert saved_call.args[1].sha256 == "sha256"
assert dummy_scanner.add_model_to_cache.await_count == 1 assert dummy_scanner.add_model_to_cache.await_count == 1

View File

@@ -9,95 +9,99 @@ from unittest.mock import AsyncMock, patch, MagicMock
import aiohttp import aiohttp
from py.services.downloader import Downloader, DownloadStalledError, DownloadRestartRequested from py.services.downloader import (
Downloader,
DownloadStalledError,
DownloadRestartRequested,
)
class TestDownloadStreamControl: class TestDownloadStreamControl:
"""Test DownloadStreamControl functionality.""" """Test DownloadStreamControl functionality."""
def test_pause_clears_event(self): def test_pause_clears_event(self):
"""Verify pause() clears the event.""" """Verify pause() clears the event."""
from py.services.downloader import DownloadStreamControl from py.services.downloader import DownloadStreamControl
control = DownloadStreamControl() control = DownloadStreamControl()
assert control.is_set() is True # Initially set assert control.is_set() is True # Initially set
control.pause() control.pause()
assert control.is_set() is False assert control.is_set() is False
assert control.is_paused() is True assert control.is_paused() is True
def test_resume_sets_event(self): def test_resume_sets_event(self):
"""Verify resume() sets the event.""" """Verify resume() sets the event."""
from py.services.downloader import DownloadStreamControl from py.services.downloader import DownloadStreamControl
control = DownloadStreamControl() control = DownloadStreamControl()
control.pause() control.pause()
assert control.is_set() is False assert control.is_set() is False
control.resume() control.resume()
assert control.is_set() is True assert control.is_set() is True
assert control.is_paused() is False assert control.is_paused() is False
def test_reconnect_request_tracking(self): def test_reconnect_request_tracking(self):
"""Verify reconnect request tracking works correctly.""" """Verify reconnect request tracking works correctly."""
from py.services.downloader import DownloadStreamControl from py.services.downloader import DownloadStreamControl
control = DownloadStreamControl() control = DownloadStreamControl()
assert control.has_reconnect_request() is False assert control.has_reconnect_request() is False
control.request_reconnect() control.request_reconnect()
assert control.has_reconnect_request() is True assert control.has_reconnect_request() is True
# Consume the request # Consume the request
consumed = control.consume_reconnect_request() consumed = control.consume_reconnect_request()
assert consumed is True assert consumed is True
assert control.has_reconnect_request() is False assert control.has_reconnect_request() is False
def test_mark_progress_clears_reconnect(self): def test_mark_progress_clears_reconnect(self):
"""Verify mark_progress clears reconnect requests.""" """Verify mark_progress clears reconnect requests."""
from py.services.downloader import DownloadStreamControl from py.services.downloader import DownloadStreamControl
control = DownloadStreamControl() control = DownloadStreamControl()
control.request_reconnect() control.request_reconnect()
assert control.has_reconnect_request() is True assert control.has_reconnect_request() is True
control.mark_progress() control.mark_progress()
assert control.has_reconnect_request() is False assert control.has_reconnect_request() is False
assert control.last_progress_timestamp is not None assert control.last_progress_timestamp is not None
def test_time_since_last_progress(self): def test_time_since_last_progress(self):
"""Verify time_since_last_progress calculation.""" """Verify time_since_last_progress calculation."""
from py.services.downloader import DownloadStreamControl from py.services.downloader import DownloadStreamControl
import time import time
control = DownloadStreamControl() control = DownloadStreamControl()
# Initially None # Initially None
assert control.time_since_last_progress() is None assert control.time_since_last_progress() is None
# After marking progress # After marking progress
now = time.time() now = time.time()
control.mark_progress(timestamp=now) control.mark_progress(timestamp=now)
elapsed = control.time_since_last_progress(now=now + 5) elapsed = control.time_since_last_progress(now=now + 5)
assert elapsed == 5.0 assert elapsed == 5.0
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_wait_for_resume(self): async def test_wait_for_resume(self):
"""Verify wait() blocks until resumed.""" """Verify wait() blocks until resumed."""
from py.services.downloader import DownloadStreamControl from py.services.downloader import DownloadStreamControl
import asyncio import asyncio
control = DownloadStreamControl() control = DownloadStreamControl()
control.pause() control.pause()
# Start a task that will wait # Start a task that will wait
wait_task = asyncio.create_task(control.wait()) wait_task = asyncio.create_task(control.wait())
# Give it a moment to start waiting # Give it a moment to start waiting
await asyncio.sleep(0.01) await asyncio.sleep(0.01)
assert not wait_task.done() assert not wait_task.done()
# Resume should unblock # Resume should unblock
control.resume() control.resume()
await asyncio.wait_for(wait_task, timeout=0.1) await asyncio.wait_for(wait_task, timeout=0.1)
@@ -105,75 +109,76 @@ class TestDownloadStreamControl:
class TestDownloaderConfiguration: class TestDownloaderConfiguration:
"""Test downloader configuration and initialization.""" """Test downloader configuration and initialization."""
def test_downloader_singleton_pattern(self): def test_downloader_singleton_pattern(self):
"""Verify Downloader follows singleton pattern.""" """Verify Downloader follows singleton pattern."""
# Reset first # Reset first
Downloader._instance = None Downloader._instance = None
# Both should return same instance # Both should return same instance
async def get_instances(): async def get_instances():
instance1 = await Downloader.get_instance() instance1 = await Downloader.get_instance()
instance2 = await Downloader.get_instance() instance2 = await Downloader.get_instance()
return instance1, instance2 return instance1, instance2
import asyncio import asyncio
instance1, instance2 = asyncio.run(get_instances()) instance1, instance2 = asyncio.run(get_instances())
assert instance1 is instance2 assert instance1 is instance2
# Cleanup # Cleanup
Downloader._instance = None Downloader._instance = None
def test_default_configuration_values(self): def test_default_configuration_values(self):
"""Verify default configuration values are set correctly.""" """Verify default configuration values are set correctly."""
Downloader._instance = None Downloader._instance = None
downloader = Downloader() downloader = Downloader()
assert downloader.chunk_size == 4 * 1024 * 1024 # 4MB assert downloader.chunk_size == 16 * 1024 * 1024 # 16MB
assert downloader.max_retries == 5 assert downloader.max_retries == 5
assert downloader.base_delay == 2.0 assert downloader.base_delay == 2.0
assert downloader.session_timeout == 300 assert downloader.session_timeout == 300
# Cleanup # Cleanup
Downloader._instance = None Downloader._instance = None
def test_default_headers_include_user_agent(self): def test_default_headers_include_user_agent(self):
"""Verify default headers include User-Agent.""" """Verify default headers include User-Agent."""
Downloader._instance = None Downloader._instance = None
downloader = Downloader() downloader = Downloader()
assert 'User-Agent' in downloader.default_headers assert "User-Agent" in downloader.default_headers
assert 'ComfyUI-LoRA-Manager' in downloader.default_headers['User-Agent'] assert "ComfyUI-LoRA-Manager" in downloader.default_headers["User-Agent"]
assert downloader.default_headers['Accept-Encoding'] == 'identity' assert downloader.default_headers["Accept-Encoding"] == "identity"
# Cleanup # Cleanup
Downloader._instance = None Downloader._instance = None
def test_stall_timeout_resolution(self): def test_stall_timeout_resolution(self):
"""Verify stall timeout is resolved correctly.""" """Verify stall timeout is resolved correctly."""
Downloader._instance = None Downloader._instance = None
downloader = Downloader() downloader = Downloader()
timeout = downloader._resolve_stall_timeout() timeout = downloader._resolve_stall_timeout()
# Should be at least 30 seconds # Should be at least 30 seconds
assert timeout >= 30.0 assert timeout >= 30.0
# Cleanup # Cleanup
Downloader._instance = None Downloader._instance = None
class TestDownloadProgress: class TestDownloadProgress:
"""Test DownloadProgress dataclass.""" """Test DownloadProgress dataclass."""
def test_download_progress_creation(self): def test_download_progress_creation(self):
"""Verify DownloadProgress can be created with correct values.""" """Verify DownloadProgress can be created with correct values."""
from py.services.downloader import DownloadProgress from py.services.downloader import DownloadProgress
from datetime import datetime from datetime import datetime
progress = DownloadProgress( progress = DownloadProgress(
percent_complete=50.0, percent_complete=50.0,
bytes_downloaded=500, bytes_downloaded=500,
@@ -181,7 +186,7 @@ class TestDownloadProgress:
bytes_per_second=100.5, bytes_per_second=100.5,
timestamp=datetime.now().timestamp(), timestamp=datetime.now().timestamp(),
) )
assert progress.percent_complete == 50.0 assert progress.percent_complete == 50.0
assert progress.bytes_downloaded == 500 assert progress.bytes_downloaded == 500
assert progress.total_bytes == 1000 assert progress.total_bytes == 1000
@@ -191,121 +196,130 @@ class TestDownloadProgress:
class TestDownloaderExceptions: class TestDownloaderExceptions:
"""Test custom exception classes.""" """Test custom exception classes."""
def test_download_stalled_error(self): def test_download_stalled_error(self):
"""Verify DownloadStalledError can be raised and caught.""" """Verify DownloadStalledError can be raised and caught."""
with pytest.raises(DownloadStalledError) as exc_info: with pytest.raises(DownloadStalledError) as exc_info:
raise DownloadStalledError("Download stalled for 120 seconds") raise DownloadStalledError("Download stalled for 120 seconds")
assert "stalled" in str(exc_info.value).lower() assert "stalled" in str(exc_info.value).lower()
def test_download_restart_requested_error(self): def test_download_restart_requested_error(self):
"""Verify DownloadRestartRequested can be raised and caught.""" """Verify DownloadRestartRequested can be raised and caught."""
with pytest.raises(DownloadRestartRequested) as exc_info: with pytest.raises(DownloadRestartRequested) as exc_info:
raise DownloadRestartRequested("Reconnect requested after resume") raise DownloadRestartRequested("Reconnect requested after resume")
assert "reconnect" in str(exc_info.value).lower() or "restart" in str(exc_info.value).lower() assert (
"reconnect" in str(exc_info.value).lower()
or "restart" in str(exc_info.value).lower()
)
class TestDownloaderAuthHeaders: class TestDownloaderAuthHeaders:
"""Test authentication header generation.""" """Test authentication header generation."""
def test_get_auth_headers_without_auth(self): def test_get_auth_headers_without_auth(self):
"""Verify auth headers without authentication.""" """Verify auth headers without authentication."""
Downloader._instance = None Downloader._instance = None
downloader = Downloader() downloader = Downloader()
headers = downloader._get_auth_headers(use_auth=False) headers = downloader._get_auth_headers(use_auth=False)
assert 'User-Agent' in headers assert "User-Agent" in headers
assert 'Authorization' not in headers assert "Authorization" not in headers
Downloader._instance = None Downloader._instance = None
def test_get_auth_headers_with_auth_no_api_key(self, monkeypatch): def test_get_auth_headers_with_auth_no_api_key(self, monkeypatch):
"""Verify auth headers with auth but no API key configured.""" """Verify auth headers with auth but no API key configured."""
Downloader._instance = None Downloader._instance = None
downloader = Downloader() downloader = Downloader()
# Mock settings manager to return no API key # Mock settings manager to return no API key
mock_settings = MagicMock() mock_settings = MagicMock()
mock_settings.get.return_value = None mock_settings.get.return_value = None
with patch('py.services.downloader.get_settings_manager', return_value=mock_settings): with patch(
"py.services.downloader.get_settings_manager", return_value=mock_settings
):
headers = downloader._get_auth_headers(use_auth=True) headers = downloader._get_auth_headers(use_auth=True)
# Should still have User-Agent but no Authorization # Should still have User-Agent but no Authorization
assert 'User-Agent' in headers assert "User-Agent" in headers
assert 'Authorization' not in headers assert "Authorization" not in headers
Downloader._instance = None Downloader._instance = None
def test_get_auth_headers_with_auth_and_api_key(self, monkeypatch): def test_get_auth_headers_with_auth_and_api_key(self, monkeypatch):
"""Verify auth headers with auth and API key configured.""" """Verify auth headers with auth and API key configured."""
Downloader._instance = None Downloader._instance = None
downloader = Downloader() downloader = Downloader()
# Mock settings manager to return API key # Mock settings manager to return API key
mock_settings = MagicMock() mock_settings = MagicMock()
mock_settings.get.return_value = "test-api-key-12345" mock_settings.get.return_value = "test-api-key-12345"
with patch('py.services.downloader.get_settings_manager', return_value=mock_settings): with patch(
"py.services.downloader.get_settings_manager", return_value=mock_settings
):
headers = downloader._get_auth_headers(use_auth=True) headers = downloader._get_auth_headers(use_auth=True)
# Should have both User-Agent and Authorization # Should have both User-Agent and Authorization
assert 'User-Agent' in headers assert "User-Agent" in headers
assert 'Authorization' in headers assert "Authorization" in headers
assert 'test-api-key-12345' in headers['Authorization'] assert "test-api-key-12345" in headers["Authorization"]
assert headers['Content-Type'] == 'application/json' assert headers["Content-Type"] == "application/json"
Downloader._instance = None Downloader._instance = None
class TestDownloaderSessionManagement: class TestDownloaderSessionManagement:
"""Test session management functionality.""" """Test session management functionality."""
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_should_refresh_session_when_none(self): async def test_should_refresh_session_when_none(self):
"""Verify session refresh is needed when session is None.""" """Verify session refresh is needed when session is None."""
Downloader._instance = None Downloader._instance = None
downloader = Downloader() downloader = Downloader()
# Initially should need refresh # Initially should need refresh
assert downloader._should_refresh_session() is True assert downloader._should_refresh_session() is True
Downloader._instance = None Downloader._instance = None
def test_should_not_refresh_new_session(self): def test_should_not_refresh_new_session(self):
"""Verify new session doesn't need refresh.""" """Verify new session doesn't need refresh."""
Downloader._instance = None Downloader._instance = None
downloader = Downloader() downloader = Downloader()
# Simulate a fresh session # Simulate a fresh session
downloader._session_created_at = MagicMock() downloader._session_created_at = MagicMock()
downloader._session = MagicMock() downloader._session = MagicMock()
# Mock datetime to return current time # Mock datetime to return current time
from datetime import datetime, timedelta from datetime import datetime, timedelta
current_time = datetime.now() current_time = datetime.now()
downloader._session_created_at = current_time downloader._session_created_at = current_time
# Should not need refresh for new session # Should not need refresh for new session
assert downloader._should_refresh_session() is False assert downloader._should_refresh_session() is False
Downloader._instance = None Downloader._instance = None
def test_should_refresh_old_session(self): def test_should_refresh_old_session(self):
"""Verify old session needs refresh.""" """Verify old session needs refresh."""
Downloader._instance = None Downloader._instance = None
downloader = Downloader() downloader = Downloader()
# Simulate an old session (older than timeout) # Simulate an old session (older than timeout)
from datetime import datetime, timedelta from datetime import datetime, timedelta
old_time = datetime.now() - timedelta(seconds=downloader.session_timeout + 1) old_time = datetime.now() - timedelta(seconds=downloader.session_timeout + 1)
downloader._session_created_at = old_time downloader._session_created_at = old_time
downloader._session = MagicMock() downloader._session = MagicMock()
# Should need refresh for old session # Should need refresh for old session
assert downloader._should_refresh_session() is True assert downloader._should_refresh_session() is True
Downloader._instance = None Downloader._instance = None