Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 20 additions & 1 deletion src/a2a/server/request_handlers/default_request_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,12 @@ async def push_notification_callback() -> None:

except Exception:
logger.exception('Agent execution failed')
# If the consumer fails, we must cancel the producer to prevent it from hanging
# on queue operations (e.g., waiting for the queue to drain).
producer_task.cancel()
# Force the queue to close immediately, discarding any pending events.
# This ensures that any producers waiting on the queue are unblocked.
await queue.close(immediate=True)
raise
finally:
if interrupted_or_non_blocking:
Expand Down Expand Up @@ -392,6 +398,12 @@ async def on_message_send_stream(
bg_task.set_name(f'background_consume:{task_id}')
self._track_background_task(bg_task)
raise
except Exception:
# If the consumer fails (e.g. database error), we must cleanup.
logger.exception('Agent execution failed during streaming')
producer_task.cancel()
await queue.close(immediate=True)
raise
finally:
cleanup_task = asyncio.create_task(
self._cleanup_producer(producer_task, task_id)
Expand Down Expand Up @@ -435,7 +447,14 @@ async def _cleanup_producer(
task_id: str,
) -> None:
"""Cleans up the agent execution task and queue manager entry."""
await producer_task
try:
await producer_task
except asyncio.CancelledError:
logger.debug(
'Producer task %s was cancelled during cleanup', task_id
)
except Exception:
logger.exception('Producer task %s failed during cleanup', task_id)
await self._queue_manager.close(task_id)
async with self._running_agents_lock:
self._running_agents.pop(task_id, None)
Expand Down
88 changes: 88 additions & 0 deletions tests/server/request_handlers/test_default_request_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2644,3 +2644,91 @@ async def test_on_message_send_stream_task_id_provided_but_task_not_found():
f'Task {task_id} was specified but does not exist'
in exc_info.value.error.message
)


@pytest.mark.asyncio
async def test_on_message_send_stream_consumer_error_cancels_producer_and_closes_queue():
"""Test that if the consumer (result aggregator) raises an exception, the producer is cancelled and queue is closed immediately."""
mock_task_store = AsyncMock(spec=TaskStore)
mock_queue_manager = AsyncMock(spec=QueueManager)
mock_agent_executor = AsyncMock(spec=AgentExecutor)
mock_request_context_builder = AsyncMock(spec=RequestContextBuilder)

task_id = 'error_cleanup_task'
context_id = 'error_cleanup_ctx'

mock_request_context = MagicMock(spec=RequestContext)
mock_request_context.task_id = task_id
mock_request_context.context_id = context_id
mock_request_context_builder.build.return_value = mock_request_context

mock_queue = AsyncMock(spec=EventQueue)
mock_queue_manager.create_or_tap.return_value = mock_queue

request_handler = DefaultRequestHandler(
agent_executor=mock_agent_executor,
task_store=mock_task_store,
queue_manager=mock_queue_manager,
request_context_builder=mock_request_context_builder,
)

params = MessageSendParams(
message=Message(
role=Role.user,
message_id='msg_error_cleanup',
parts=[],
# Do NOT provide task_id here to avoid "Task ... was specified but does not exist" error
)
)

# Mock ResultAggregator to raise exception
mock_result_aggregator_instance = MagicMock(spec=ResultAggregator)

async def raise_error_gen(_consumer):
# Raise an exception to simulate consumer failure
raise ValueError('Consumer failed!')
yield # unreachable

mock_result_aggregator_instance.consume_and_emit.side_effect = (
raise_error_gen
)

# Capture the producer task to verify cancellation
captured_producer_task = None
original_register = request_handler._register_producer

async def spy_register_producer(tid, task):
nonlocal captured_producer_task
captured_producer_task = task
# Wrap the cancel method to spy on it
task.cancel = MagicMock(wraps=task.cancel)
await original_register(tid, task)

with (
patch(
'a2a.server.request_handlers.default_request_handler.ResultAggregator',
return_value=mock_result_aggregator_instance,
),
patch(
'a2a.server.request_handlers.default_request_handler.TaskManager.get_task',
return_value=None,
),
patch.object(
request_handler,
'_register_producer',
side_effect=spy_register_producer,
),
):
# Act
with pytest.raises(ValueError, match='Consumer failed!'):
async for _ in request_handler.on_message_send_stream(
params, create_server_call_context()
):
pass

assert captured_producer_task is not None
# Verify producer was cancelled
captured_producer_task.cancel.assert_called()

# Verify queue closed immediately
mock_queue.close.assert_awaited_with(immediate=True)
1 change: 0 additions & 1 deletion tests/server/request_handlers/test_jsonrpc_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,6 @@ async def streaming_coro():

self.assertIsInstance(response.root, JSONRPCErrorResponse)
assert response.root.error == UnsupportedOperationError() # type: ignore
mock_agent_executor.execute.assert_called_once()

@patch(
'a2a.server.agent_execution.simple_request_context_builder.SimpleRequestContextBuilder.build'
Expand Down
Loading