Skip to content

Commit ebb20f3

Browse files
authored
Merge branch 'main' into kthota/typechanges
2 parents cd7cafa + d1869bb commit ebb20f3

File tree

21 files changed

+768
-43
lines changed

21 files changed

+768
-43
lines changed

.github/actions/spelling/excludes.txt

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,8 +81,7 @@
8181
\.xz$
8282
\.zip$
8383
^\.github/actions/spelling/
84-
^\Q.github/workflows/spelling.yaml\E$
85-
^\Q.github/workflows/linter.yaml\E$
84+
^\.github/workflows/
8685
\.gitignore\E$
8786
\.vscode/
8887
noxfile.py

.github/linters/.mypy.ini

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
[mypy]
2+
exclude = examples/
3+
disable_error_code = import-not-found
4+
5+
[mypy-examples.*]
6+
follow_imports = skip

.github/workflows/linter.yaml

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,9 @@ name: Lint Code Base
1313
#############################
1414
# Start the job on all push #
1515
#############################
16-
# on:
17-
# pull_request:
18-
# branches: [main]
19-
on: workflow_dispatch
16+
on:
17+
pull_request:
18+
branches: [main]
2019

2120
###############
2221
# Set the Job #
@@ -64,3 +63,5 @@ jobs:
6463
VALIDATE_TYPESCRIPT_STANDARD: false
6564
VALIDATE_GIT_COMMITLINT: false
6665
MARKDOWN_CONFIG_FILE: .markdownlint.json
66+
PYTHON_MYPY_CONFIG_FILE: .mypy.ini
67+
FILTER_REGEX_EXCLUDE: "^examples/.*"

.github/workflows/python-publish.yml

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ jobs:
2424

2525
- name: Build
2626
run: uv build
27-
27+
2828
- name: Upload distributions
2929
uses: actions/upload-artifact@v4
3030
with:
@@ -49,6 +49,3 @@ jobs:
4949
uses: pypa/gh-action-pypi-publish@release/v1
5050
with:
5151
packages-dir: dist/
52-
53-
54-

examples/google_adk/birthday_planner/adk_agent_executor.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
# mypy: ignore-errors
12
import asyncio
23
import logging
34

examples/google_adk/calendar_agent/adk_agent_executor.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
# mypy: ignore-errors
12
import asyncio
23
import logging
34

@@ -53,7 +54,7 @@ def __init__(self, runner: Runner, card: AgentCard):
5354

5455
def _run_agent(
5556
self, session_id, new_message: types.Content
56-
) -> AsyncGenerator[Event, None]:
57+
) -> AsyncGenerator[Event]:
5758
return self.runner.run_async(
5859
session_id=session_id, user_id='self', new_message=new_message
5960
)

examples/langgraph/__main__.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,15 @@
22
import sys
33

44
import click
5+
import httpx
56

67
from agent import CurrencyAgent
78
from agent_executor import CurrencyAgentExecutor
89
from dotenv import load_dotenv
910

1011
from a2a.server.apps import A2AStarletteApplication
1112
from a2a.server.request_handlers import DefaultRequestHandler
12-
from a2a.server.tasks import InMemoryTaskStore
13+
from a2a.server.tasks import InMemoryPushNotifier, InMemoryTaskStore
1314
from a2a.types import (
1415
AgentAuthentication,
1516
AgentCapabilities,
@@ -29,9 +30,11 @@ def main(host: str, port: int):
2930
print('GOOGLE_API_KEY environment variable not set.')
3031
sys.exit(1)
3132

33+
client = httpx.AsyncClient()
3234
request_handler = DefaultRequestHandler(
3335
agent_executor=CurrencyAgentExecutor(),
3436
task_store=InMemoryTaskStore(),
37+
push_notifier=InMemoryPushNotifier(client),
3538
)
3639

3740
server = A2AStarletteApplication(

noxfile.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,9 +114,15 @@ def format(session):
114114
'pyupgrade',
115115
'autoflake',
116116
'ruff',
117+
'no_implicit_optional',
117118
)
118119

119120
if lint_paths_py:
121+
session.run(
122+
'no_implicit_optional',
123+
'--use-union-or',
124+
*lint_paths_py,
125+
)
120126
if not format_all:
121127
session.run(
122128
'pyupgrade',

src/a2a/server/agent_execution/context.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@
33
from a2a.types import (
44
InvalidParamsError,
55
Message,
6-
MessageSendParams,
76
MessageSendConfiguration,
7+
MessageSendParams,
88
Task,
99
)
1010
from a2a.utils import get_message_text
@@ -82,6 +82,8 @@ def context_id(self) -> str | None:
8282

8383
@property
8484
def configuration(self) -> MessageSendConfiguration | None:
85+
if not self._params:
86+
return None
8587
return self._params.configuration
8688

8789
def _check_or_generate_task_id(self) -> None:

src/a2a/server/request_handlers/default_request_handler.py

Lines changed: 86 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,18 @@
1414
TaskQueueExists,
1515
)
1616
from a2a.server.request_handlers.request_handler import RequestHandler
17-
from a2a.server.tasks import ResultAggregator, TaskManager, TaskStore
17+
from a2a.server.tasks import (
18+
PushNotifier,
19+
ResultAggregator,
20+
TaskManager,
21+
TaskStore,
22+
)
1823
from a2a.types import (
1924
InternalError,
2025
Message,
26+
MessageSendConfiguration,
2127
MessageSendParams,
28+
PushNotificationConfig,
2229
Task,
2330
TaskIdParams,
2431
TaskNotFoundError,
@@ -44,10 +51,12 @@ def __init__(
4451
agent_executor: AgentExecutor,
4552
task_store: TaskStore,
4653
queue_manager: QueueManager | None = None,
54+
push_notifier: PushNotifier | None = None,
4755
) -> None:
4856
self.agent_executor = agent_executor
4957
self.task_store = task_store
5058
self._queue_manager = queue_manager or InMemoryQueueManager()
59+
self._push_notifier = push_notifier
5160
# TODO: Likely want an interface for managing this, like AgentExecutionManager.
5261
self._running_agents = {}
5362
self._running_agents_lock = asyncio.Lock()
@@ -118,6 +127,18 @@ async def on_message_send(
118127
task: Task | None = await task_manager.get_task()
119128
if task:
120129
task = task_manager.update_with_message(params.message, task)
130+
if self.should_add_push_info(params):
131+
assert isinstance(self._push_notifier, PushNotifier)
132+
assert isinstance(
133+
params.configuration, MessageSendConfiguration
134+
)
135+
assert isinstance(
136+
params.configuration.pushNotificationConfig,
137+
PushNotificationConfig,
138+
)
139+
await self._push_notifier.set_info(
140+
task.id, params.configuration.pushNotificationConfig
141+
)
121142
request_context = RequestContext(
122143
params,
123144
task.id if task else None,
@@ -176,6 +197,20 @@ async def on_message_send_stream(
176197
if task:
177198
task = task_manager.update_with_message(params.message, task)
178199

200+
if self.should_add_push_info(params):
201+
assert isinstance(self._push_notifier, PushNotifier)
202+
assert isinstance(
203+
params.configuration, MessageSendConfiguration
204+
)
205+
assert isinstance(
206+
params.configuration.pushNotificationConfig,
207+
PushNotificationConfig,
208+
)
209+
await self._push_notifier.set_info(
210+
task.id, params.configuration.pushNotificationConfig
211+
)
212+
else:
213+
queue = EventQueue()
179214
result_aggregator = ResultAggregator(task_manager)
180215
request_context = RequestContext(
181216
params,
@@ -202,12 +237,26 @@ async def on_message_send_stream(
202237
f'Agent generated task_id={event.id} does not match the RequestContext task_id={task_id}.'
203238
)
204239
try:
205-
await self._queue_manager.add(event.id, queue)
206-
task_id = event.id
240+
created_task: Task = event
241+
await self._queue_manager.add(created_task.id, queue)
242+
task_id = created_task.id
207243
except TaskQueueExists:
208244
logging.info(
209245
'Multiple Task objects created in event stream.'
210246
)
247+
if (
248+
self._push_notifier
249+
and params.configuration
250+
and params.configuration.pushNotificationConfig
251+
):
252+
await self._push_notifier.set_info(
253+
created_task.id,
254+
params.configuration.pushNotificationConfig,
255+
)
256+
if self._push_notifier and task_id:
257+
latest_task = await result_aggregator.current_result
258+
if isinstance(latest_task, Task):
259+
await self._push_notifier.send_notification(latest_task)
211260
yield event
212261
finally:
213262
await self._cleanup_producer(producer_task, task_id)
@@ -226,13 +275,38 @@ async def on_set_task_push_notification_config(
226275
self, params: TaskPushNotificationConfig
227276
) -> TaskPushNotificationConfig:
228277
"""Default handler for 'tasks/pushNotificationConfig/set'."""
229-
raise ServerError(error=UnsupportedOperationError())
278+
if not self._push_notifier:
279+
raise ServerError(error=UnsupportedOperationError())
280+
281+
task: Task | None = await self.task_store.get(params.taskId)
282+
if not task:
283+
raise ServerError(error=TaskNotFoundError())
284+
285+
await self._push_notifier.set_info(
286+
params.taskId,
287+
params.pushNotificationConfig,
288+
)
289+
290+
return params
230291

231292
async def on_get_task_push_notification_config(
232293
self, params: TaskIdParams
233294
) -> TaskPushNotificationConfig:
234295
"""Default handler for 'tasks/pushNotificationConfig/get'."""
235-
raise ServerError(error=UnsupportedOperationError())
296+
if not self._push_notifier:
297+
raise ServerError(error=UnsupportedOperationError())
298+
299+
task: Task | None = await self.task_store.get(params.id)
300+
if not task:
301+
raise ServerError(error=TaskNotFoundError())
302+
303+
push_notification_config = await self._push_notifier.get_info(params.id)
304+
if not push_notification_config:
305+
raise ServerError(error=InternalError())
306+
307+
return TaskPushNotificationConfig(
308+
taskId=params.id, pushNotificationConfig=push_notification_config
309+
)
236310

237311
async def on_resubscribe_to_task(
238312
self, params: TaskIdParams
@@ -258,3 +332,10 @@ async def on_resubscribe_to_task(
258332
consumer = EventConsumer(queue)
259333
async for event in result_aggregator.consume_and_emit(consumer):
260334
yield event
335+
336+
def should_add_push_info(self, params: MessageSendParams) -> bool:
337+
return bool(
338+
self._push_notifier
339+
and params.configuration
340+
and params.configuration.pushNotificationConfig
341+
)

0 commit comments

Comments
 (0)