mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-21 21:22:11 -03:00
test(backend): Phase 1 - Improve testing infrastructure and add error path tests
## Changes ### pytest-asyncio Integration - Add pytest-asyncio>=0.21.0 to requirements-dev.txt - Update pytest.ini with asyncio_mode=auto and fixture loop scope - Remove custom pytest_pyfunc_call handler from conftest.py - Add @pytest.mark.asyncio to 21 async test functions ### Error Path Tests - Create test_downloader_error_paths.py with 19 new tests covering: - DownloadStreamControl state management (6 tests) - Downloader configuration and initialization (4 tests) - DownloadProgress dataclass validation (1 test) - Custom exception handling (2 tests) - Authentication header generation (3 tests) - Session management (3 tests) ### Documentation - Update backend-testing-improvement-plan.md with Phase 1 completion status ## Test Results - All 458 service tests pass - No regressions introduced Relates to backend testing improvement plan Phase 1
This commit is contained in:
@@ -105,35 +105,6 @@ def _isolate_settings_dir(tmp_path_factory, monkeypatch, request):
|
||||
settings_manager_module.reset_settings_manager()
|
||||
|
||||
|
||||
def pytest_pyfunc_call(pyfuncitem):
|
||||
"""Allow bare async tests to run without pytest.mark.asyncio."""
|
||||
test_function = pyfuncitem.function
|
||||
if inspect.iscoroutinefunction(test_function):
|
||||
func = pyfuncitem.obj
|
||||
signature = inspect.signature(func)
|
||||
accepted_kwargs: Dict[str, Any] = {}
|
||||
for name, parameter in signature.parameters.items():
|
||||
if parameter.kind is inspect.Parameter.VAR_POSITIONAL:
|
||||
continue
|
||||
if parameter.kind is inspect.Parameter.VAR_KEYWORD:
|
||||
accepted_kwargs = dict(pyfuncitem.funcargs)
|
||||
break
|
||||
if name in pyfuncitem.funcargs:
|
||||
accepted_kwargs[name] = pyfuncitem.funcargs[name]
|
||||
|
||||
original_policy = asyncio.get_event_loop_policy()
|
||||
policy = pyfuncitem.funcargs.get("event_loop_policy")
|
||||
if policy is not None and policy is not original_policy:
|
||||
asyncio.set_event_loop_policy(policy)
|
||||
try:
|
||||
asyncio.run(func(**accepted_kwargs))
|
||||
finally:
|
||||
if policy is not None and policy is not original_policy:
|
||||
asyncio.set_event_loop_policy(original_policy)
|
||||
return True
|
||||
return None
|
||||
|
||||
|
||||
@dataclass
|
||||
class MockHashIndex:
|
||||
"""Minimal hash index stub mirroring the scanner contract."""
|
||||
|
||||
@@ -149,6 +149,7 @@ def noop_cleanup(monkeypatch):
|
||||
monkeypatch.setattr(DownloadManager, "_cleanup_download_record", _cleanup)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_download_requires_identifier():
|
||||
manager = DownloadManager()
|
||||
result = await manager.download_from_civitai()
|
||||
@@ -158,6 +159,7 @@ async def test_download_requires_identifier():
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_successful_download_uses_defaults(
|
||||
monkeypatch, scanners, metadata_provider, tmp_path
|
||||
):
|
||||
@@ -218,6 +220,7 @@ async def test_successful_download_uses_defaults(
|
||||
assert captured["download_urls"] == ["https://example.invalid/file.safetensors"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_download_uses_active_mirrors(
|
||||
monkeypatch, scanners, metadata_provider, tmp_path
|
||||
):
|
||||
@@ -283,6 +286,7 @@ async def test_download_uses_active_mirrors(
|
||||
assert captured["download_urls"] == ["https://mirror.example/file.safetensors"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_download_aborts_when_version_exists(
|
||||
monkeypatch, scanners, metadata_provider
|
||||
):
|
||||
@@ -301,6 +305,7 @@ async def test_download_aborts_when_version_exists(
|
||||
assert execute_mock.await_count == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_download_handles_metadata_errors(monkeypatch, scanners):
|
||||
async def failing_provider(*_args, **_kwargs):
|
||||
return None
|
||||
@@ -322,6 +327,7 @@ async def test_download_handles_metadata_errors(monkeypatch, scanners):
|
||||
assert "download_id" in result
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_download_rejects_unsupported_model_type(monkeypatch, scanners):
|
||||
class Provider:
|
||||
async def get_model_version(self, *_args, **_kwargs):
|
||||
@@ -394,6 +400,7 @@ def test_relative_path_sanitizes_model_and_version_placeholders():
|
||||
assert relative_path == "Fancy_Model/Version_One"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_download_retries_urls(monkeypatch, tmp_path):
|
||||
manager = DownloadManager()
|
||||
|
||||
@@ -479,6 +486,7 @@ async def test_execute_download_retries_urls(monkeypatch, tmp_path):
|
||||
assert dummy_scanner.calls # ensure cache updated
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_download_adjusts_checkpoint_sub_type(monkeypatch, tmp_path):
|
||||
manager = DownloadManager()
|
||||
|
||||
@@ -578,6 +586,7 @@ async def test_execute_download_adjusts_checkpoint_sub_type(monkeypatch, tmp_pat
|
||||
assert cached_entry["sub_type"] == "diffusion_model"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_download_extracts_zip_single_model(monkeypatch, tmp_path):
|
||||
manager = DownloadManager()
|
||||
save_dir = tmp_path / "downloads"
|
||||
@@ -645,6 +654,7 @@ async def test_execute_download_extracts_zip_single_model(monkeypatch, tmp_path)
|
||||
assert dummy_scanner.add_model_to_cache.await_count == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_download_extracts_zip_multiple_models(monkeypatch, tmp_path):
|
||||
manager = DownloadManager()
|
||||
save_dir = tmp_path / "downloads"
|
||||
@@ -720,6 +730,7 @@ async def test_execute_download_extracts_zip_multiple_models(monkeypatch, tmp_pa
|
||||
assert metadata_calls[1].args[1].sha256 == "hash-two"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_download_extracts_zip_pt_embedding(monkeypatch, tmp_path):
|
||||
manager = DownloadManager()
|
||||
save_dir = tmp_path / "downloads"
|
||||
@@ -824,6 +835,7 @@ def test_distribute_preview_to_entries_keeps_existing_file(tmp_path):
|
||||
assert Path(targets[1]).read_bytes() == b"preview"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pause_download_updates_state():
|
||||
manager = DownloadManager()
|
||||
|
||||
@@ -845,6 +857,7 @@ async def test_pause_download_updates_state():
|
||||
assert manager._active_downloads[download_id]["bytes_per_second"] == 0.0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pause_download_rejects_unknown_task():
|
||||
manager = DownloadManager()
|
||||
|
||||
@@ -853,6 +866,7 @@ async def test_pause_download_rejects_unknown_task():
|
||||
assert result == {"success": False, "error": "Download task not found"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resume_download_sets_event_and_status():
|
||||
manager = DownloadManager()
|
||||
|
||||
@@ -873,6 +887,7 @@ async def test_resume_download_sets_event_and_status():
|
||||
assert manager._active_downloads[download_id]["status"] == "downloading"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resume_download_requests_reconnect_for_stalled_stream():
|
||||
manager = DownloadManager()
|
||||
|
||||
@@ -893,6 +908,7 @@ async def test_resume_download_requests_reconnect_for_stalled_stream():
|
||||
assert pause_control.has_reconnect_request() is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resume_download_rejects_when_not_paused():
|
||||
manager = DownloadManager()
|
||||
|
||||
@@ -1131,6 +1147,7 @@ async def test_execute_download_respects_blur_setting(monkeypatch, tmp_path):
|
||||
assert stored_preview and stored_preview.endswith(".jpeg")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_civarchive_source_uses_civarchive_provider(
|
||||
monkeypatch, scanners, tmp_path
|
||||
):
|
||||
@@ -1235,6 +1252,7 @@ async def test_civarchive_source_uses_civarchive_provider(
|
||||
assert captured["version_info"]["source"] == "civarchive"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_civarchive_source_prioritizes_non_civitai_urls(
|
||||
monkeypatch, scanners, tmp_path
|
||||
):
|
||||
@@ -1323,6 +1341,7 @@ async def test_civarchive_source_prioritizes_non_civitai_urls(
|
||||
assert captured["download_urls"][1] == "https://another-mirror.org/file.safetensors"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_civarchive_source_fallback_to_default_provider(
|
||||
monkeypatch, scanners, tmp_path
|
||||
):
|
||||
|
||||
311
tests/services/test_downloader_error_paths.py
Normal file
311
tests/services/test_downloader_error_paths.py
Normal file
@@ -0,0 +1,311 @@
|
||||
"""Error path tests for downloader module.
|
||||
|
||||
Tests HTTP error handling and network error scenarios.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, patch, MagicMock
|
||||
|
||||
import aiohttp
|
||||
|
||||
from py.services.downloader import Downloader, DownloadStalledError, DownloadRestartRequested
|
||||
|
||||
|
||||
class TestDownloadStreamControl:
|
||||
"""Test DownloadStreamControl functionality."""
|
||||
|
||||
def test_pause_clears_event(self):
|
||||
"""Verify pause() clears the event."""
|
||||
from py.services.downloader import DownloadStreamControl
|
||||
|
||||
control = DownloadStreamControl()
|
||||
assert control.is_set() is True # Initially set
|
||||
|
||||
control.pause()
|
||||
assert control.is_set() is False
|
||||
assert control.is_paused() is True
|
||||
|
||||
def test_resume_sets_event(self):
|
||||
"""Verify resume() sets the event."""
|
||||
from py.services.downloader import DownloadStreamControl
|
||||
|
||||
control = DownloadStreamControl()
|
||||
control.pause()
|
||||
assert control.is_set() is False
|
||||
|
||||
control.resume()
|
||||
assert control.is_set() is True
|
||||
assert control.is_paused() is False
|
||||
|
||||
def test_reconnect_request_tracking(self):
|
||||
"""Verify reconnect request tracking works correctly."""
|
||||
from py.services.downloader import DownloadStreamControl
|
||||
|
||||
control = DownloadStreamControl()
|
||||
assert control.has_reconnect_request() is False
|
||||
|
||||
control.request_reconnect()
|
||||
assert control.has_reconnect_request() is True
|
||||
|
||||
# Consume the request
|
||||
consumed = control.consume_reconnect_request()
|
||||
assert consumed is True
|
||||
assert control.has_reconnect_request() is False
|
||||
|
||||
def test_mark_progress_clears_reconnect(self):
|
||||
"""Verify mark_progress clears reconnect requests."""
|
||||
from py.services.downloader import DownloadStreamControl
|
||||
|
||||
control = DownloadStreamControl()
|
||||
control.request_reconnect()
|
||||
assert control.has_reconnect_request() is True
|
||||
|
||||
control.mark_progress()
|
||||
assert control.has_reconnect_request() is False
|
||||
assert control.last_progress_timestamp is not None
|
||||
|
||||
def test_time_since_last_progress(self):
|
||||
"""Verify time_since_last_progress calculation."""
|
||||
from py.services.downloader import DownloadStreamControl
|
||||
import time
|
||||
|
||||
control = DownloadStreamControl()
|
||||
|
||||
# Initially None
|
||||
assert control.time_since_last_progress() is None
|
||||
|
||||
# After marking progress
|
||||
now = time.time()
|
||||
control.mark_progress(timestamp=now)
|
||||
|
||||
elapsed = control.time_since_last_progress(now=now + 5)
|
||||
assert elapsed == 5.0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_wait_for_resume(self):
|
||||
"""Verify wait() blocks until resumed."""
|
||||
from py.services.downloader import DownloadStreamControl
|
||||
import asyncio
|
||||
|
||||
control = DownloadStreamControl()
|
||||
control.pause()
|
||||
|
||||
# Start a task that will wait
|
||||
wait_task = asyncio.create_task(control.wait())
|
||||
|
||||
# Give it a moment to start waiting
|
||||
await asyncio.sleep(0.01)
|
||||
assert not wait_task.done()
|
||||
|
||||
# Resume should unblock
|
||||
control.resume()
|
||||
await asyncio.wait_for(wait_task, timeout=0.1)
|
||||
|
||||
|
||||
class TestDownloaderConfiguration:
|
||||
"""Test downloader configuration and initialization."""
|
||||
|
||||
def test_downloader_singleton_pattern(self):
|
||||
"""Verify Downloader follows singleton pattern."""
|
||||
# Reset first
|
||||
Downloader._instance = None
|
||||
|
||||
# Both should return same instance
|
||||
async def get_instances():
|
||||
instance1 = await Downloader.get_instance()
|
||||
instance2 = await Downloader.get_instance()
|
||||
return instance1, instance2
|
||||
|
||||
import asyncio
|
||||
instance1, instance2 = asyncio.run(get_instances())
|
||||
|
||||
assert instance1 is instance2
|
||||
|
||||
# Cleanup
|
||||
Downloader._instance = None
|
||||
|
||||
def test_default_configuration_values(self):
|
||||
"""Verify default configuration values are set correctly."""
|
||||
Downloader._instance = None
|
||||
|
||||
downloader = Downloader()
|
||||
|
||||
assert downloader.chunk_size == 4 * 1024 * 1024 # 4MB
|
||||
assert downloader.max_retries == 5
|
||||
assert downloader.base_delay == 2.0
|
||||
assert downloader.session_timeout == 300
|
||||
|
||||
# Cleanup
|
||||
Downloader._instance = None
|
||||
|
||||
def test_default_headers_include_user_agent(self):
|
||||
"""Verify default headers include User-Agent."""
|
||||
Downloader._instance = None
|
||||
|
||||
downloader = Downloader()
|
||||
|
||||
assert 'User-Agent' in downloader.default_headers
|
||||
assert 'ComfyUI-LoRA-Manager' in downloader.default_headers['User-Agent']
|
||||
assert downloader.default_headers['Accept-Encoding'] == 'identity'
|
||||
|
||||
# Cleanup
|
||||
Downloader._instance = None
|
||||
|
||||
def test_stall_timeout_resolution(self):
|
||||
"""Verify stall timeout is resolved correctly."""
|
||||
Downloader._instance = None
|
||||
|
||||
downloader = Downloader()
|
||||
timeout = downloader._resolve_stall_timeout()
|
||||
|
||||
# Should be at least 30 seconds
|
||||
assert timeout >= 30.0
|
||||
|
||||
# Cleanup
|
||||
Downloader._instance = None
|
||||
|
||||
|
||||
class TestDownloadProgress:
|
||||
"""Test DownloadProgress dataclass."""
|
||||
|
||||
def test_download_progress_creation(self):
|
||||
"""Verify DownloadProgress can be created with correct values."""
|
||||
from py.services.downloader import DownloadProgress
|
||||
from datetime import datetime
|
||||
|
||||
progress = DownloadProgress(
|
||||
percent_complete=50.0,
|
||||
bytes_downloaded=500,
|
||||
total_bytes=1000,
|
||||
bytes_per_second=100.5,
|
||||
timestamp=datetime.now().timestamp(),
|
||||
)
|
||||
|
||||
assert progress.percent_complete == 50.0
|
||||
assert progress.bytes_downloaded == 500
|
||||
assert progress.total_bytes == 1000
|
||||
assert progress.bytes_per_second == 100.5
|
||||
assert progress.timestamp is not None
|
||||
|
||||
|
||||
class TestDownloaderExceptions:
|
||||
"""Test custom exception classes."""
|
||||
|
||||
def test_download_stalled_error(self):
|
||||
"""Verify DownloadStalledError can be raised and caught."""
|
||||
with pytest.raises(DownloadStalledError) as exc_info:
|
||||
raise DownloadStalledError("Download stalled for 120 seconds")
|
||||
|
||||
assert "stalled" in str(exc_info.value).lower()
|
||||
|
||||
def test_download_restart_requested_error(self):
|
||||
"""Verify DownloadRestartRequested can be raised and caught."""
|
||||
with pytest.raises(DownloadRestartRequested) as exc_info:
|
||||
raise DownloadRestartRequested("Reconnect requested after resume")
|
||||
|
||||
assert "reconnect" in str(exc_info.value).lower() or "restart" in str(exc_info.value).lower()
|
||||
|
||||
|
||||
class TestDownloaderAuthHeaders:
|
||||
"""Test authentication header generation."""
|
||||
|
||||
def test_get_auth_headers_without_auth(self):
|
||||
"""Verify auth headers without authentication."""
|
||||
Downloader._instance = None
|
||||
downloader = Downloader()
|
||||
|
||||
headers = downloader._get_auth_headers(use_auth=False)
|
||||
|
||||
assert 'User-Agent' in headers
|
||||
assert 'Authorization' not in headers
|
||||
|
||||
Downloader._instance = None
|
||||
|
||||
def test_get_auth_headers_with_auth_no_api_key(self, monkeypatch):
|
||||
"""Verify auth headers with auth but no API key configured."""
|
||||
Downloader._instance = None
|
||||
downloader = Downloader()
|
||||
|
||||
# Mock settings manager to return no API key
|
||||
mock_settings = MagicMock()
|
||||
mock_settings.get.return_value = None
|
||||
|
||||
with patch('py.services.downloader.get_settings_manager', return_value=mock_settings):
|
||||
headers = downloader._get_auth_headers(use_auth=True)
|
||||
|
||||
# Should still have User-Agent but no Authorization
|
||||
assert 'User-Agent' in headers
|
||||
assert 'Authorization' not in headers
|
||||
|
||||
Downloader._instance = None
|
||||
|
||||
def test_get_auth_headers_with_auth_and_api_key(self, monkeypatch):
|
||||
"""Verify auth headers with auth and API key configured."""
|
||||
Downloader._instance = None
|
||||
downloader = Downloader()
|
||||
|
||||
# Mock settings manager to return API key
|
||||
mock_settings = MagicMock()
|
||||
mock_settings.get.return_value = "test-api-key-12345"
|
||||
|
||||
with patch('py.services.downloader.get_settings_manager', return_value=mock_settings):
|
||||
headers = downloader._get_auth_headers(use_auth=True)
|
||||
|
||||
# Should have both User-Agent and Authorization
|
||||
assert 'User-Agent' in headers
|
||||
assert 'Authorization' in headers
|
||||
assert 'test-api-key-12345' in headers['Authorization']
|
||||
assert headers['Content-Type'] == 'application/json'
|
||||
|
||||
Downloader._instance = None
|
||||
|
||||
|
||||
class TestDownloaderSessionManagement:
|
||||
"""Test session management functionality."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_should_refresh_session_when_none(self):
|
||||
"""Verify session refresh is needed when session is None."""
|
||||
Downloader._instance = None
|
||||
downloader = Downloader()
|
||||
|
||||
# Initially should need refresh
|
||||
assert downloader._should_refresh_session() is True
|
||||
|
||||
Downloader._instance = None
|
||||
|
||||
def test_should_not_refresh_new_session(self):
|
||||
"""Verify new session doesn't need refresh."""
|
||||
Downloader._instance = None
|
||||
downloader = Downloader()
|
||||
|
||||
# Simulate a fresh session
|
||||
downloader._session_created_at = MagicMock()
|
||||
downloader._session = MagicMock()
|
||||
|
||||
# Mock datetime to return current time
|
||||
from datetime import datetime, timedelta
|
||||
current_time = datetime.now()
|
||||
downloader._session_created_at = current_time
|
||||
|
||||
# Should not need refresh for new session
|
||||
assert downloader._should_refresh_session() is False
|
||||
|
||||
Downloader._instance = None
|
||||
|
||||
def test_should_refresh_old_session(self):
|
||||
"""Verify old session needs refresh."""
|
||||
Downloader._instance = None
|
||||
downloader = Downloader()
|
||||
|
||||
# Simulate an old session (older than timeout)
|
||||
from datetime import datetime, timedelta
|
||||
old_time = datetime.now() - timedelta(seconds=downloader.session_timeout + 1)
|
||||
downloader._session_created_at = old_time
|
||||
downloader._session = MagicMock()
|
||||
|
||||
# Should need refresh for old session
|
||||
assert downloader._should_refresh_session() is True
|
||||
|
||||
Downloader._instance = None
|
||||
Reference in New Issue
Block a user