Skip to content

Commit a149a09

Browse files
feat: use StreamResponse as push notifications payload (#724)
## Description As per the 1.0 spec update (see [4.3.3. Push Notification Payload](https://a2a-protocol.org/latest/specification/#433-push-notification-payload)) use `StreamResponse` as push notifications payload. Fixes #678 --------- Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
1 parent e67934b commit a149a09

File tree

10 files changed

+163
-63
lines changed

10 files changed

+163
-63
lines changed

src/a2a/server/request_handlers/default_request_handler.py

Lines changed: 12 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from a2a.server.request_handlers.request_handler import RequestHandler
2222
from a2a.server.tasks import (
2323
PushNotificationConfigStore,
24+
PushNotificationEvent,
2425
PushNotificationSender,
2526
ResultAggregator,
2627
TaskManager,
@@ -319,13 +320,15 @@ def _validate_task_id_match(self, task_id: str, event_task_id: str) -> None:
319320
)
320321

321322
async def _send_push_notification_if_needed(
322-
self, task_id: str, result_aggregator: ResultAggregator
323+
self, task_id: str, event: Event
323324
) -> None:
324-
"""Sends push notification if configured and task is available."""
325-
if self._push_sender and task_id:
326-
latest_task = await result_aggregator.current_result
327-
if isinstance(latest_task, Task):
328-
await self._push_sender.send_notification(latest_task)
325+
"""Sends push notification if configured."""
326+
if (
327+
self._push_sender
328+
and task_id
329+
and isinstance(event, PushNotificationEvent)
330+
):
331+
await self._push_sender.send_notification(task_id, event)
329332

330333
async def on_message_send(
331334
self,
@@ -357,10 +360,8 @@ async def on_message_send(
357360
interrupted_or_non_blocking = False
358361
try:
359362
# Create async callback for push notifications
360-
async def push_notification_callback() -> None:
361-
await self._send_push_notification_if_needed(
362-
task_id, result_aggregator
363-
)
363+
async def push_notification_callback(event: Event) -> None:
364+
await self._send_push_notification_if_needed(task_id, event)
364365

365366
(
366367
result,
@@ -393,8 +394,6 @@ async def push_notification_callback() -> None:
393394
if params.configuration:
394395
result = apply_history_length(result, params.configuration)
395396

396-
await self._send_push_notification_if_needed(task_id, result_aggregator)
397-
398397
return result
399398

400399
async def on_message_send_stream(
@@ -422,9 +421,7 @@ async def on_message_send_stream(
422421
if isinstance(event, Task):
423422
self._validate_task_id_match(task_id, event.id)
424423

425-
await self._send_push_notification_if_needed(
426-
task_id, result_aggregator
427-
)
424+
await self._send_push_notification_if_needed(task_id, event)
428425
yield event
429426
except (asyncio.CancelledError, GeneratorExit):
430427
# Client disconnected: continue consuming and persisting events in the background

src/a2a/server/tasks/__init__.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,10 @@
1212
from a2a.server.tasks.push_notification_config_store import (
1313
PushNotificationConfigStore,
1414
)
15-
from a2a.server.tasks.push_notification_sender import PushNotificationSender
15+
from a2a.server.tasks.push_notification_sender import (
16+
PushNotificationEvent,
17+
PushNotificationSender,
18+
)
1619
from a2a.server.tasks.result_aggregator import ResultAggregator
1720
from a2a.server.tasks.task_manager import TaskManager
1821
from a2a.server.tasks.task_store import TaskStore
@@ -72,6 +75,7 @@ def __init__(self, *args, **kwargs):
7275
'InMemoryPushNotificationConfigStore',
7376
'InMemoryTaskStore',
7477
'PushNotificationConfigStore',
78+
'PushNotificationEvent',
7579
'PushNotificationSender',
7680
'ResultAggregator',
7781
'TaskManager',

src/a2a/server/tasks/base_push_notification_sender.py

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,12 @@
88
from a2a.server.tasks.push_notification_config_store import (
99
PushNotificationConfigStore,
1010
)
11-
from a2a.server.tasks.push_notification_sender import PushNotificationSender
12-
from a2a.types.a2a_pb2 import PushNotificationConfig, StreamResponse, Task
11+
from a2a.server.tasks.push_notification_sender import (
12+
PushNotificationEvent,
13+
PushNotificationSender,
14+
)
15+
from a2a.types.a2a_pb2 import PushNotificationConfig
16+
from a2a.utils.proto_utils import to_stream_response
1317

1418

1519
logger = logging.getLogger(__name__)
@@ -32,44 +36,50 @@ def __init__(
3236
self._client = httpx_client
3337
self._config_store = config_store
3438

35-
async def send_notification(self, task: Task) -> None:
36-
"""Sends a push notification for a task if configuration exists."""
37-
push_configs = await self._config_store.get_info(task.id)
39+
async def send_notification(
40+
self, task_id: str, event: PushNotificationEvent
41+
) -> None:
42+
"""Sends a push notification for an event if configuration exists."""
43+
push_configs = await self._config_store.get_info(task_id)
3844
if not push_configs:
3945
return
4046

4147
awaitables = [
42-
self._dispatch_notification(task, push_info)
48+
self._dispatch_notification(event, push_info, task_id)
4349
for push_info in push_configs
4450
]
4551
results = await asyncio.gather(*awaitables)
4652

4753
if not all(results):
4854
logger.warning(
49-
'Some push notifications failed to send for task_id=%s', task.id
55+
'Some push notifications failed to send for task_id=%s', task_id
5056
)
5157

5258
async def _dispatch_notification(
53-
self, task: Task, push_info: PushNotificationConfig
59+
self,
60+
event: PushNotificationEvent,
61+
push_info: PushNotificationConfig,
62+
task_id: str,
5463
) -> bool:
5564
url = push_info.url
5665
try:
5766
headers = None
5867
if push_info.token:
5968
headers = {'X-A2A-Notification-Token': push_info.token}
69+
6070
response = await self._client.post(
6171
url,
62-
json=MessageToDict(StreamResponse(task=task)),
72+
json=MessageToDict(to_stream_response(event)),
6373
headers=headers,
6474
)
6575
response.raise_for_status()
6676
logger.info(
67-
'Push-notification sent for task_id=%s to URL: %s', task.id, url
77+
'Push-notification sent for task_id=%s to URL: %s', task_id, url
6878
)
6979
except Exception:
7080
logger.exception(
7181
'Error sending push-notification for task_id=%s to URL: %s.',
72-
task.id,
82+
task_id,
7383
url,
7484
)
7585
return False
Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,20 @@
11
from abc import ABC, abstractmethod
22

3-
from a2a.types.a2a_pb2 import Task
3+
from a2a.types.a2a_pb2 import (
4+
Task,
5+
TaskArtifactUpdateEvent,
6+
TaskStatusUpdateEvent,
7+
)
8+
9+
10+
PushNotificationEvent = Task | TaskStatusUpdateEvent | TaskArtifactUpdateEvent
411

512

613
class PushNotificationSender(ABC):
714
"""Interface for sending push notifications for tasks."""
815

916
@abstractmethod
10-
async def send_notification(self, task: Task) -> None:
17+
async def send_notification(
18+
self, task_id: str, event: PushNotificationEvent
19+
) -> None:
1120
"""Sends a push notification containing the latest task state."""

src/a2a/server/tasks/result_aggregator.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ async def consume_and_break_on_interrupt(
9898
self,
9999
consumer: EventConsumer,
100100
blocking: bool = True,
101-
event_callback: Callable[[], Awaitable[None]] | None = None,
101+
event_callback: Callable[[Event], Awaitable[None]] | None = None,
102102
) -> tuple[Task | Message | None, bool]:
103103
"""Processes the event stream until completion or an interruptible state is encountered.
104104
@@ -131,6 +131,9 @@ async def consume_and_break_on_interrupt(
131131
return event, False
132132
await self.task_manager.process(event)
133133

134+
if event_callback:
135+
await event_callback(event)
136+
134137
should_interrupt = False
135138
is_auth_required = (
136139
isinstance(event, Task | TaskStatusUpdateEvent)
@@ -169,7 +172,7 @@ async def consume_and_break_on_interrupt(
169172
async def _continue_consuming(
170173
self,
171174
event_stream: AsyncIterator[Event],
172-
event_callback: Callable[[], Awaitable[None]] | None = None,
175+
event_callback: Callable[[Event], Awaitable[None]] | None = None,
173176
) -> None:
174177
"""Continues processing an event stream in a background task.
175178
@@ -183,4 +186,4 @@ async def _continue_consuming(
183186
async for event in event_stream:
184187
await self.task_manager.process(event)
185188
if event_callback:
186-
await event_callback()
189+
await event_callback(event)

tests/e2e/push_notifications/notifications_app.py

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
class Notification(BaseModel):
1313
"""Encapsulates default push notification data."""
1414

15-
task: dict[str, Any]
15+
event: dict[str, Any]
1616
token: str
1717

1818

@@ -36,20 +36,33 @@ async def add_notification(request: Request):
3636
try:
3737
json_data = await request.json()
3838
stream_response = ParseDict(json_data, StreamResponse())
39-
if not stream_response.HasField('task'):
39+
40+
payload_name = stream_response.WhichOneof('payload')
41+
task_id = None
42+
if payload_name:
43+
event_payload = getattr(stream_response, payload_name)
44+
# The 'Task' message uses 'id', while event messages use 'task_id'.
45+
task_id = getattr(
46+
event_payload, 'task_id', getattr(event_payload, 'id', None)
47+
)
48+
49+
if not task_id:
4050
raise HTTPException(
41-
status_code=400, detail='Missing task in StreamResponse'
51+
status_code=400,
52+
detail='Missing "task_id" in push notification.',
4253
)
43-
task = stream_response.task
54+
4455
except Exception as e:
4556
raise HTTPException(status_code=400, detail=str(e))
4657

4758
async with store_lock:
48-
if task.id not in store:
49-
store[task.id] = []
50-
store[task.id].append(
59+
if task_id not in store:
60+
store[task_id] = []
61+
store[task_id].append(
5162
Notification(
52-
task=MessageToDict(task, preserving_proto_field_name=True),
63+
event=MessageToDict(
64+
stream_response, preserving_proto_field_name=True
65+
),
5366
token=token,
5467
)
5568
)

tests/e2e/push_notifications/test_default_push_notification_support.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -139,12 +139,22 @@ async def test_notification_triggering_with_in_message_config_e2e(
139139
notifications = await wait_for_n_notifications(
140140
http_client,
141141
f'{notifications_server}/{task.id}/notifications',
142-
n=1,
142+
n=2,
143143
)
144144
assert notifications[0].token == token
145-
# Notification.task is a dict from proto serialization
146-
assert notifications[0].task['id'] == task.id
147-
assert notifications[0].task['status']['state'] == 'TASK_STATE_COMPLETED'
145+
146+
# Verify exactly two consecutive events: SUBMITTED -> COMPLETED
147+
assert len(notifications) == 2
148+
149+
# 1. First event: SUBMITTED (Task)
150+
event0 = notifications[0].event
151+
state0 = event0['task'].get('status', {}).get('state')
152+
assert state0 == 'TASK_STATE_SUBMITTED'
153+
154+
# 2. Second event: COMPLETED (TaskStatusUpdateEvent)
155+
event1 = notifications[1].event
156+
state1 = event1['status_update'].get('status', {}).get('state')
157+
assert state1 == 'TASK_STATE_COMPLETED'
148158

149159

150160
@pytest.mark.asyncio
@@ -220,9 +230,9 @@ async def test_notification_triggering_after_config_change_e2e(
220230
f'{notifications_server}/{task.id}/notifications',
221231
n=1,
222232
)
223-
# Notification.task is a dict from proto serialization
224-
assert notifications[0].task['id'] == task.id
225-
assert notifications[0].task['status']['state'] == 'TASK_STATE_COMPLETED'
233+
event = notifications[0].event
234+
state = event['status_update'].get('status', {}).get('state', '')
235+
assert state == 'TASK_STATE_COMPLETED'
226236
assert notifications[0].token == token
227237

228238

tests/server/request_handlers/test_default_request_handler.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import asyncio
22
import contextlib
33
import logging
4+
import uuid
45
import time
56
import uuid
67
from typing import cast
@@ -669,6 +670,8 @@ async def mock_consume_and_break_on_interrupt(
669670
nonlocal event_callback_passed, event_callback_received
670671
event_callback_passed = event_callback is not None
671672
event_callback_received = event_callback
673+
if event_callback_received:
674+
await event_callback_received(final_task)
672675
return initial_task, True # interrupted = True for non-blocking
673676

674677
mock_result_aggregator_instance.consume_and_break_on_interrupt = (
@@ -706,7 +709,7 @@ async def mock_consume_and_break_on_interrupt(
706709
)
707710

708711
# Verify that the push notification was sent with the final task
709-
mock_push_sender.send_notification.assert_called_with(final_task)
712+
mock_push_sender.send_notification.assert_called_with(task_id, final_task)
710713

711714
# Verify that the push notification config was stored
712715
mock_push_notification_store.set_info.assert_awaited_once_with(
@@ -1418,8 +1421,12 @@ def sync_get_event_stream_gen_for_prop_test(*args, **kwargs):
14181421

14191422
# 2. send_notification called for each task event yielded by aggregator
14201423
assert mock_push_sender.send_notification.await_count == 2
1421-
mock_push_sender.send_notification.assert_any_await(event1_task_update)
1422-
mock_push_sender.send_notification.assert_any_await(event2_final_task)
1424+
mock_push_sender.send_notification.assert_any_await(
1425+
task_id, event1_task_update
1426+
)
1427+
mock_push_sender.send_notification.assert_any_await(
1428+
task_id, event2_final_task
1429+
)
14231430

14241431
mock_agent_executor.execute.assert_awaited_once()
14251432

tests/server/tasks/test_inmemory_push_notifications.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ async def test_send_notification_success(self) -> None:
156156
mock_response.status_code = 200
157157
self.mock_httpx_client.post.return_value = mock_response
158158

159-
await self.notifier.send_notification(task_data) # Pass only task_data
159+
await self.notifier.send_notification(task_id, task_data)
160160

161161
self.mock_httpx_client.post.assert_awaited_once()
162162
called_args, called_kwargs = self.mock_httpx_client.post.call_args
@@ -183,7 +183,7 @@ async def test_send_notification_with_token_success(self) -> None:
183183
mock_response.status_code = 200
184184
self.mock_httpx_client.post.return_value = mock_response
185185

186-
await self.notifier.send_notification(task_data) # Pass only task_data
186+
await self.notifier.send_notification(task_id, task_data)
187187

188188
self.mock_httpx_client.post.assert_awaited_once()
189189
called_args, called_kwargs = self.mock_httpx_client.post.call_args
@@ -205,7 +205,7 @@ async def test_send_notification_no_config(self) -> None:
205205
task_id = 'task_send_no_config'
206206
task_data = create_sample_task(task_id=task_id)
207207

208-
await self.notifier.send_notification(task_data) # Pass only task_data
208+
await self.notifier.send_notification(task_id, task_data)
209209

210210
self.mock_httpx_client.post.assert_not_called()
211211

@@ -229,7 +229,7 @@ async def test_send_notification_http_status_error(
229229
self.mock_httpx_client.post.side_effect = http_error
230230

231231
# The method should catch the error and log it, not re-raise
232-
await self.notifier.send_notification(task_data) # Pass only task_data
232+
await self.notifier.send_notification(task_id, task_data)
233233

234234
self.mock_httpx_client.post.assert_awaited_once()
235235
mock_logger.exception.assert_called_once()
@@ -251,7 +251,7 @@ async def test_send_notification_request_error(
251251
request_error = httpx.RequestError('Network issue', request=MagicMock())
252252
self.mock_httpx_client.post.side_effect = request_error
253253

254-
await self.notifier.send_notification(task_data) # Pass only task_data
254+
await self.notifier.send_notification(task_id, task_data)
255255

256256
self.mock_httpx_client.post.assert_awaited_once()
257257
mock_logger.exception.assert_called_once()
@@ -281,7 +281,7 @@ async def test_send_notification_with_auth(
281281
mock_response.status_code = 200
282282
self.mock_httpx_client.post.return_value = mock_response
283283

284-
await self.notifier.send_notification(task_data) # Pass only task_data
284+
await self.notifier.send_notification(task_id, task_data)
285285

286286
self.mock_httpx_client.post.assert_awaited_once()
287287
called_args, called_kwargs = self.mock_httpx_client.post.call_args

0 commit comments

Comments
 (0)