diff --git a/tests/conftest.py b/tests/conftest.py index 818aeb9b..006a1adc 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -42,8 +42,30 @@ sys.modules['nodes'] = nodes_mock def pytest_pyfunc_call(pyfuncitem): - if inspect.iscoroutinefunction(pyfuncitem.function): - asyncio.run(pyfuncitem.obj(**pyfuncitem.funcargs)) + """Allow bare async tests to run without pytest.mark.asyncio.""" + test_function = pyfuncitem.function + if inspect.iscoroutinefunction(test_function): + func = pyfuncitem.obj + signature = inspect.signature(func) + accepted_kwargs: Dict[str, Any] = {} + for name, parameter in signature.parameters.items(): + if parameter.kind is inspect.Parameter.VAR_POSITIONAL: + continue + if parameter.kind is inspect.Parameter.VAR_KEYWORD: + accepted_kwargs = dict(pyfuncitem.funcargs) + break + if name in pyfuncitem.funcargs: + accepted_kwargs[name] = pyfuncitem.funcargs[name] + + original_policy = asyncio.get_event_loop_policy() + policy = pyfuncitem.funcargs.get("event_loop_policy") + if policy is not None and policy is not original_policy: + asyncio.set_event_loop_policy(policy) + try: + asyncio.run(func(**accepted_kwargs)) + finally: + if policy is not None and policy is not original_policy: + asyncio.set_event_loop_policy(original_policy) return True return None @@ -196,3 +218,5 @@ def mock_scanner(mock_cache: MockCache, mock_hash_index: MockHashIndex) -> MockS @pytest.fixture def mock_service(mock_scanner: MockScanner) -> MockModelService: return MockModelService(scanner=mock_scanner) + +