|
23 | 23 | from a2a.types import ( |
24 | 24 | InternalError, |
25 | 25 | Message, |
| 26 | + MessageSendConfiguration, |
26 | 27 | MessageSendParams, |
| 28 | + PushNotificationConfig, |
27 | 29 | Task, |
28 | 30 | TaskIdParams, |
29 | 31 | TaskNotFoundError, |
@@ -123,12 +125,10 @@ async def on_message_send( |
123 | 125 | task: Task | None = await task_manager.get_task() |
124 | 126 | if task: |
125 | 127 | task = task_manager.update_with_message(params.message, task) |
126 | | - if ( |
127 | | - self._push_notifier |
128 | | - and params.configuration |
129 | | - and params.configuration.pushNotificationConfig |
130 | | - and not params.configuration.blocking |
131 | | - ): |
| 128 | + if self.should_add_push_info(params): |
| 129 | + assert isinstance(self._push_notifier, PushNotifier) # For typechecker |
| 130 | + assert isinstance(params.configuration, MessageSendConfiguration) # For typechecker |
| 131 | + assert isinstance(params.configuration.pushNotificationConfig, PushNotificationConfig) # For typechecker |
132 | 132 | await self._push_notifier.set_info( |
133 | 133 | task.id, params.configuration.pushNotificationConfig |
134 | 134 | ) |
@@ -190,11 +190,10 @@ async def on_message_send_stream( |
190 | 190 | if task: |
191 | 191 | task = task_manager.update_with_message(params.message, task) |
192 | 192 |
|
193 | | - if ( |
194 | | - self._push_notifier |
195 | | - and params.configuration |
196 | | - and params.configuration.pushNotificationConfig |
197 | | - ): |
| 193 | + if self.should_add_push_info(params): |
| 194 | + assert isinstance(self._push_notifier, PushNotifier) # For typechecker |
| 195 | + assert isinstance(params.configuration, MessageSendConfiguration) # For typechecker |
| 196 | + assert isinstance(params.configuration.pushNotificationConfig, PushNotificationConfig) # For typechecker |
198 | 197 | await self._push_notifier.set_info( |
199 | 198 | task.id, params.configuration.pushNotificationConfig |
200 | 199 | ) |
@@ -319,3 +318,13 @@ async def on_resubscribe_to_task( |
319 | 318 | consumer = EventConsumer(queue) |
320 | 319 | async for event in result_aggregator.consume_and_emit(consumer): |
321 | 320 | yield event |
| 321 | + |
| 322 | + def should_add_push_info(self, params: MessageSendParams) -> bool: |
| 323 | + if ( |
| 324 | + self._push_notifier |
| 325 | + and params.configuration |
| 326 | + and params.configuration.pushNotificationConfig |
| 327 | + ): |
| 328 | + return True |
| 329 | + else: |
| 330 | + return False |
0 commit comments