diff --git a/src/a2a/server/request_handlers/default_request_handler.py b/src/a2a/server/request_handlers/default_request_handler.py index 30d1ee89..4435b40d 100644 --- a/src/a2a/server/request_handlers/default_request_handler.py +++ b/src/a2a/server/request_handlers/default_request_handler.py @@ -330,6 +330,12 @@ async def push_notification_callback() -> None: except Exception: logger.exception('Agent execution failed') + # If the consumer fails, we must cancel the producer to prevent it from hanging + # on queue operations (e.g., waiting for the queue to drain). + producer_task.cancel() + # Force the queue to close immediately, discarding any pending events. + # This ensures that any producers waiting on the queue are unblocked. + await queue.close(immediate=True) raise finally: if interrupted_or_non_blocking: @@ -392,6 +398,12 @@ async def on_message_send_stream( bg_task.set_name(f'background_consume:{task_id}') self._track_background_task(bg_task) raise + except Exception: + # If the consumer fails (e.g., database error), we must clean up. 
+ logger.exception('Agent execution failed during streaming') + producer_task.cancel() + await queue.close(immediate=True) + raise finally: cleanup_task = asyncio.create_task( self._cleanup_producer(producer_task, task_id) @@ -435,7 +447,14 @@ async def _cleanup_producer( task_id: str, ) -> None: """Cleans up the agent execution task and queue manager entry.""" - await producer_task + try: + await producer_task + except asyncio.CancelledError: + logger.debug( + 'Producer task %s was cancelled during cleanup', task_id + ) + except Exception: + logger.exception('Producer task %s failed during cleanup', task_id) await self._queue_manager.close(task_id) async with self._running_agents_lock: self._running_agents.pop(task_id, None) diff --git a/tests/server/request_handlers/test_default_request_handler.py b/tests/server/request_handlers/test_default_request_handler.py index 88dd77ab..3e23beac 100644 --- a/tests/server/request_handlers/test_default_request_handler.py +++ b/tests/server/request_handlers/test_default_request_handler.py @@ -2644,3 +2644,91 @@ async def test_on_message_send_stream_task_id_provided_but_task_not_found(): f'Task {task_id} was specified but does not exist' in exc_info.value.error.message ) + + +@pytest.mark.asyncio +async def test_on_message_send_stream_consumer_error_cancels_producer_and_closes_queue(): + """Test that if the consumer (result aggregator) raises an exception, the producer is cancelled and queue is closed immediately.""" + mock_task_store = AsyncMock(spec=TaskStore) + mock_queue_manager = AsyncMock(spec=QueueManager) + mock_agent_executor = AsyncMock(spec=AgentExecutor) + mock_request_context_builder = AsyncMock(spec=RequestContextBuilder) + + task_id = 'error_cleanup_task' + context_id = 'error_cleanup_ctx' + + mock_request_context = MagicMock(spec=RequestContext) + mock_request_context.task_id = task_id + mock_request_context.context_id = context_id + mock_request_context_builder.build.return_value = mock_request_context + + 
mock_queue = AsyncMock(spec=EventQueue) + mock_queue_manager.create_or_tap.return_value = mock_queue + + request_handler = DefaultRequestHandler( + agent_executor=mock_agent_executor, + task_store=mock_task_store, + queue_manager=mock_queue_manager, + request_context_builder=mock_request_context_builder, + ) + + params = MessageSendParams( + message=Message( + role=Role.user, + message_id='msg_error_cleanup', + parts=[], + # Do NOT provide task_id here to avoid "Task ... was specified but does not exist" error + ) + ) + + # Mock ResultAggregator to raise exception + mock_result_aggregator_instance = MagicMock(spec=ResultAggregator) + + async def raise_error_gen(_consumer): + # Raise an exception to simulate consumer failure + raise ValueError('Consumer failed!') + yield # unreachable + + mock_result_aggregator_instance.consume_and_emit.side_effect = ( + raise_error_gen + ) + + # Capture the producer task to verify cancellation + captured_producer_task = None + original_register = request_handler._register_producer + + async def spy_register_producer(tid, task): + nonlocal captured_producer_task + captured_producer_task = task + # Wrap the cancel method to spy on it + task.cancel = MagicMock(wraps=task.cancel) + await original_register(tid, task) + + with ( + patch( + 'a2a.server.request_handlers.default_request_handler.ResultAggregator', + return_value=mock_result_aggregator_instance, + ), + patch( + 'a2a.server.request_handlers.default_request_handler.TaskManager.get_task', + return_value=None, + ), + patch.object( + request_handler, + '_register_producer', + side_effect=spy_register_producer, + ), + ): + # Act + with pytest.raises(ValueError, match='Consumer failed!'): + async for _ in request_handler.on_message_send_stream( + params, create_server_call_context() + ): + pass + + assert captured_producer_task is not None + # Verify producer was cancelled + captured_producer_task.cancel.assert_called() + + # Verify queue closed immediately + 
mock_queue.close.assert_awaited_with(immediate=True) diff --git a/tests/server/request_handlers/test_jsonrpc_handler.py b/tests/server/request_handlers/test_jsonrpc_handler.py index d1ead021..d10d544a 100644 --- a/tests/server/request_handlers/test_jsonrpc_handler.py +++ b/tests/server/request_handlers/test_jsonrpc_handler.py @@ -322,7 +322,6 @@ async def streaming_coro(): self.assertIsInstance(response.root, JSONRPCErrorResponse) assert response.root.error == UnsupportedOperationError() # type: ignore - mock_agent_executor.execute.assert_called_once() @patch( 'a2a.server.agent_execution.simple_request_context_builder.SimpleRequestContextBuilder.build'