Skip to content

Commit 1e4b574

Browse files
authored
perf: Improve performance and code style for proto_utils.py (#452)
- Pre-compile regular expressions - Use `cls` instead of `ClassName` - Change `ToProto.data()` to use `dict_to_struct()` - Reduce duplication by combining `ToProto.update_event()` and `ToProto.stream_response()` - Added missing conversion for type `MutualTlsSecurityScheme`
1 parent e3e5c4b commit 1e4b574

File tree

1 file changed

+80
-94
lines changed

1 file changed

+80
-94
lines changed

src/a2a/utils/proto_utils.py

Lines changed: 80 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,34 @@
1818

1919

2020
# Regexp patterns for matching
21-
_TASK_NAME_MATCH = r'tasks/([\w-]+)'
22-
_TASK_PUSH_CONFIG_NAME_MATCH = (
21+
_TASK_NAME_MATCH = re.compile(r'tasks/([\w-]+)')
22+
_TASK_PUSH_CONFIG_NAME_MATCH = re.compile(
2323
r'tasks/([\w-]+)/pushNotificationConfigs/([\w-]+)'
2424
)
2525

2626

27+
def dict_to_struct(dictionary: dict[str, Any]) -> struct_pb2.Struct:
28+
"""Converts a Python dict to a Struct proto.
29+
30+
Unfortunately, using `json_format.ParseDict` does not work because this
31+
wants the dictionary to be an exact match of the Struct proto with fields
32+
and keys and values, not the traditional Python dict structure.
33+
34+
Args:
35+
dictionary: The Python dict to convert.
36+
37+
Returns:
38+
The Struct proto.
39+
"""
40+
struct = struct_pb2.Struct()
41+
for key, val in dictionary.items():
42+
if isinstance(val, dict):
43+
struct[key] = dict_to_struct(val)
44+
else:
45+
struct[key] = val
46+
return struct
47+
48+
2749
class ToProto:
2850
"""Converts Python types to proto types."""
2951

@@ -33,11 +55,11 @@ def message(cls, message: types.Message | None) -> a2a_pb2.Message | None:
3355
return None
3456
return a2a_pb2.Message(
3557
message_id=message.message_id,
36-
content=[ToProto.part(p) for p in message.parts],
58+
content=[cls.part(p) for p in message.parts],
3759
context_id=message.context_id or '',
3860
task_id=message.task_id or '',
3961
role=cls.role(message.role),
40-
metadata=ToProto.metadata(message.metadata),
62+
metadata=cls.metadata(message.metadata),
4163
)
4264

4365
@classmethod
@@ -53,20 +75,14 @@ def part(cls, part: types.Part) -> a2a_pb2.Part:
5375
if isinstance(part.root, types.TextPart):
5476
return a2a_pb2.Part(text=part.root.text)
5577
if isinstance(part.root, types.FilePart):
56-
return a2a_pb2.Part(file=ToProto.file(part.root.file))
78+
return a2a_pb2.Part(file=cls.file(part.root.file))
5779
if isinstance(part.root, types.DataPart):
58-
return a2a_pb2.Part(data=ToProto.data(part.root.data))
80+
return a2a_pb2.Part(data=cls.data(part.root.data))
5981
raise ValueError(f'Unsupported part type: {part.root}')
6082

6183
@classmethod
6284
def data(cls, data: dict[str, Any]) -> a2a_pb2.DataPart:
63-
json_data = json.dumps(data)
64-
return a2a_pb2.DataPart(
65-
data=json_format.Parse(
66-
json_data,
67-
struct_pb2.Struct(),
68-
)
69-
)
85+
return a2a_pb2.DataPart(data=dict_to_struct(data))
7086

7187
@classmethod
7288
def file(
@@ -87,14 +103,14 @@ def task(cls, task: types.Task) -> a2a_pb2.Task:
87103
return a2a_pb2.Task(
88104
id=task.id,
89105
context_id=task.context_id,
90-
status=ToProto.task_status(task.status),
106+
status=cls.task_status(task.status),
91107
artifacts=(
92-
[ToProto.artifact(a) for a in task.artifacts]
108+
[cls.artifact(a) for a in task.artifacts]
93109
if task.artifacts
94110
else None
95111
),
96112
history=(
97-
[ToProto.message(h) for h in task.history] # type: ignore[misc]
113+
[cls.message(h) for h in task.history] # type: ignore[misc]
98114
if task.history
99115
else None
100116
),
@@ -103,8 +119,8 @@ def task(cls, task: types.Task) -> a2a_pb2.Task:
103119
@classmethod
104120
def task_status(cls, status: types.TaskStatus) -> a2a_pb2.TaskStatus:
105121
return a2a_pb2.TaskStatus(
106-
state=ToProto.task_state(status.state),
107-
update=ToProto.message(status.message),
122+
state=cls.task_state(status.state),
123+
update=cls.message(status.message),
108124
)
109125

110126
@classmethod
@@ -132,9 +148,9 @@ def artifact(cls, artifact: types.Artifact) -> a2a_pb2.Artifact:
132148
return a2a_pb2.Artifact(
133149
artifact_id=artifact.artifact_id,
134150
description=artifact.description,
135-
metadata=ToProto.metadata(artifact.metadata),
151+
metadata=cls.metadata(artifact.metadata),
136152
name=artifact.name,
137-
parts=[ToProto.part(p) for p in artifact.parts],
153+
parts=[cls.part(p) for p in artifact.parts],
138154
)
139155

140156
@classmethod
@@ -151,7 +167,7 @@ def push_notification_config(
151167
cls, config: types.PushNotificationConfig
152168
) -> a2a_pb2.PushNotificationConfig:
153169
auth_info = (
154-
ToProto.authentication_info(config.authentication)
170+
cls.authentication_info(config.authentication)
155171
if config.authentication
156172
else None
157173
)
@@ -169,8 +185,8 @@ def task_artifact_update_event(
169185
return a2a_pb2.TaskArtifactUpdateEvent(
170186
task_id=event.task_id,
171187
context_id=event.context_id,
172-
artifact=ToProto.artifact(event.artifact),
173-
metadata=ToProto.metadata(event.metadata),
188+
artifact=cls.artifact(event.artifact),
189+
metadata=cls.metadata(event.metadata),
174190
append=event.append or False,
175191
last_chunk=event.last_chunk or False,
176192
)
@@ -182,8 +198,8 @@ def task_status_update_event(
182198
return a2a_pb2.TaskStatusUpdateEvent(
183199
task_id=event.task_id,
184200
context_id=event.context_id,
185-
status=ToProto.task_status(event.status),
186-
metadata=ToProto.metadata(event.metadata),
201+
status=cls.task_status(event.status),
202+
metadata=cls.metadata(event.metadata),
187203
final=event.final,
188204
)
189205

@@ -195,7 +211,7 @@ def message_send_configuration(
195211
return a2a_pb2.SendMessageConfiguration()
196212
return a2a_pb2.SendMessageConfiguration(
197213
accepted_output_modes=config.accepted_output_modes,
198-
push_notification=ToProto.push_notification_config(
214+
push_notification=cls.push_notification_config(
199215
config.push_notification_config
200216
)
201217
if config.push_notification_config
@@ -213,19 +229,7 @@ def update_event(
213229
| types.TaskArtifactUpdateEvent,
214230
) -> a2a_pb2.StreamResponse:
215231
"""Converts a task, message, or task update event to a StreamResponse."""
216-
if isinstance(event, types.TaskStatusUpdateEvent):
217-
return a2a_pb2.StreamResponse(
218-
status_update=ToProto.task_status_update_event(event)
219-
)
220-
if isinstance(event, types.TaskArtifactUpdateEvent):
221-
return a2a_pb2.StreamResponse(
222-
artifact_update=ToProto.task_artifact_update_event(event)
223-
)
224-
if isinstance(event, types.Message):
225-
return a2a_pb2.StreamResponse(msg=ToProto.message(event))
226-
if isinstance(event, types.Task):
227-
return a2a_pb2.StreamResponse(task=ToProto.task(event))
228-
raise ValueError(f'Unsupported event type: {type(event)}')
232+
return cls.stream_response(event)
229233

230234
@classmethod
231235
def task_or_message(
@@ -257,9 +261,11 @@ def stream_response(
257261
return a2a_pb2.StreamResponse(
258262
status_update=cls.task_status_update_event(event),
259263
)
260-
return a2a_pb2.StreamResponse(
261-
artifact_update=cls.task_artifact_update_event(event),
262-
)
264+
if isinstance(event, types.TaskArtifactUpdateEvent):
265+
return a2a_pb2.StreamResponse(
266+
artifact_update=cls.task_artifact_update_event(event),
267+
)
268+
raise ValueError(f'Unsupported event type: {type(event)}')
263269

264270
@classmethod
265271
def task_push_notification_config(
@@ -480,11 +486,11 @@ class FromProto:
480486
def message(cls, message: a2a_pb2.Message) -> types.Message:
481487
return types.Message(
482488
message_id=message.message_id,
483-
parts=[FromProto.part(p) for p in message.content],
489+
parts=[cls.part(p) for p in message.content],
484490
context_id=message.context_id or None,
485491
task_id=message.task_id or None,
486-
role=FromProto.role(message.role),
487-
metadata=FromProto.metadata(message.metadata),
492+
role=cls.role(message.role),
493+
metadata=cls.metadata(message.metadata),
488494
)
489495

490496
@classmethod
@@ -498,13 +504,9 @@ def part(cls, part: a2a_pb2.Part) -> types.Part:
498504
if part.HasField('text'):
499505
return types.Part(root=types.TextPart(text=part.text))
500506
if part.HasField('file'):
501-
return types.Part(
502-
root=types.FilePart(file=FromProto.file(part.file))
503-
)
507+
return types.Part(root=types.FilePart(file=cls.file(part.file)))
504508
if part.HasField('data'):
505-
return types.Part(
506-
root=types.DataPart(data=FromProto.data(part.data))
507-
)
509+
return types.Part(root=types.DataPart(data=cls.data(part.data)))
508510
raise ValueError(f'Unsupported part type: {part}')
509511

510512
@classmethod
@@ -543,16 +545,16 @@ def task(cls, task: a2a_pb2.Task) -> types.Task:
543545
return types.Task(
544546
id=task.id,
545547
context_id=task.context_id,
546-
status=FromProto.task_status(task.status),
547-
artifacts=[FromProto.artifact(a) for a in task.artifacts],
548-
history=[FromProto.message(h) for h in task.history],
548+
status=cls.task_status(task.status),
549+
artifacts=[cls.artifact(a) for a in task.artifacts],
550+
history=[cls.message(h) for h in task.history],
549551
)
550552

551553
@classmethod
552554
def task_status(cls, status: a2a_pb2.TaskStatus) -> types.TaskStatus:
553555
return types.TaskStatus(
554-
state=FromProto.task_state(status.state),
555-
message=FromProto.message(status.update),
556+
state=cls.task_state(status.state),
557+
message=cls.message(status.update),
556558
)
557559

558560
@classmethod
@@ -580,9 +582,9 @@ def artifact(cls, artifact: a2a_pb2.Artifact) -> types.Artifact:
580582
return types.Artifact(
581583
artifact_id=artifact.artifact_id,
582584
description=artifact.description,
583-
metadata=FromProto.metadata(artifact.metadata),
585+
metadata=cls.metadata(artifact.metadata),
584586
name=artifact.name,
585-
parts=[FromProto.part(p) for p in artifact.parts],
587+
parts=[cls.part(p) for p in artifact.parts],
586588
)
587589

588590
@classmethod
@@ -592,8 +594,8 @@ def task_artifact_update_event(
592594
return types.TaskArtifactUpdateEvent(
593595
task_id=event.task_id,
594596
context_id=event.context_id,
595-
artifact=FromProto.artifact(event.artifact),
596-
metadata=FromProto.metadata(event.metadata),
597+
artifact=cls.artifact(event.artifact),
598+
metadata=cls.metadata(event.metadata),
597599
append=event.append,
598600
last_chunk=event.last_chunk,
599601
)
@@ -605,8 +607,8 @@ def task_status_update_event(
605607
return types.TaskStatusUpdateEvent(
606608
task_id=event.task_id,
607609
context_id=event.context_id,
608-
status=FromProto.task_status(event.status),
609-
metadata=FromProto.metadata(event.metadata),
610+
status=cls.task_status(event.status),
611+
metadata=cls.metadata(event.metadata),
610612
final=event.final,
611613
)
612614

@@ -618,7 +620,7 @@ def push_notification_config(
618620
id=config.id,
619621
url=config.url,
620622
token=config.token,
621-
authentication=FromProto.authentication_info(config.authentication)
623+
authentication=cls.authentication_info(config.authentication)
622624
if config.HasField('authentication')
623625
else None,
624626
)
@@ -638,7 +640,7 @@ def message_send_configuration(
638640
) -> types.MessageSendConfiguration:
639641
return types.MessageSendConfiguration(
640642
accepted_output_modes=list(config.accepted_output_modes),
641-
push_notification_config=FromProto.push_notification_config(
643+
push_notification_config=cls.push_notification_config(
642644
config.push_notification
643645
)
644646
if config.HasField('push_notification')
@@ -666,18 +668,16 @@ def task_id_params(
666668
| a2a_pb2.GetTaskPushNotificationConfigRequest
667669
),
668670
) -> types.TaskIdParams:
669-
# This is currently incomplete until the core sdk supports multiple
670-
# configs for a single task.
671671
if isinstance(request, a2a_pb2.GetTaskPushNotificationConfigRequest):
672-
m = re.match(_TASK_PUSH_CONFIG_NAME_MATCH, request.name)
672+
m = _TASK_PUSH_CONFIG_NAME_MATCH.match(request.name)
673673
if not m:
674674
raise ServerError(
675675
error=types.InvalidParamsError(
676676
message=f'No task for {request.name}'
677677
)
678678
)
679679
return types.TaskIdParams(id=m.group(1))
680-
m = re.match(_TASK_NAME_MATCH, request.name)
680+
m = _TASK_NAME_MATCH.match(request.name)
681681
if not m:
682682
raise ServerError(
683683
error=types.InvalidParamsError(
@@ -691,7 +691,7 @@ def task_push_notification_config_request(
691691
cls,
692692
request: a2a_pb2.CreateTaskPushNotificationConfigRequest,
693693
) -> types.TaskPushNotificationConfig:
694-
m = re.match(_TASK_NAME_MATCH, request.parent)
694+
m = _TASK_NAME_MATCH.match(request.parent)
695695
if not m:
696696
raise ServerError(
697697
error=types.InvalidParamsError(
@@ -710,7 +710,7 @@ def task_push_notification_config(
710710
cls,
711711
config: a2a_pb2.TaskPushNotificationConfig,
712712
) -> types.TaskPushNotificationConfig:
713-
m = re.match(_TASK_PUSH_CONFIG_NAME_MATCH, config.name)
713+
m = _TASK_PUSH_CONFIG_NAME_MATCH.match(config.name)
714714
if not m:
715715
raise ServerError(
716716
error=types.InvalidParamsError(
@@ -767,7 +767,7 @@ def task_query_params(
767767
cls,
768768
request: a2a_pb2.GetTaskRequest,
769769
) -> types.TaskQueryParams:
770-
m = re.match(_TASK_NAME_MATCH, request.name)
770+
m = _TASK_NAME_MATCH.match(request.name)
771771
if not m:
772772
raise ServerError(
773773
error=types.InvalidParamsError(
@@ -862,6 +862,12 @@ def security_scheme(
862862
flows=cls.oauth2_flows(scheme.oauth2_security_scheme.flows),
863863
)
864864
)
865+
if scheme.HasField('mtls_security_scheme'):
866+
return types.SecurityScheme(
867+
root=types.MutualTLSSecurityScheme(
868+
description=scheme.mtls_security_scheme.description,
869+
)
870+
)
865871
return types.SecurityScheme(
866872
root=types.OpenIdConnectSecurityScheme(
867873
description=scheme.open_id_connect_security_scheme.description,
@@ -920,7 +926,9 @@ def stream_response(
920926
return cls.task(response.task)
921927
if response.HasField('status_update'):
922928
return cls.task_status_update_event(response.status_update)
923-
return cls.task_artifact_update_event(response.artifact_update)
929+
if response.HasField('artifact_update'):
930+
return cls.task_artifact_update_event(response.artifact_update)
931+
raise ValueError('Unsupported StreamResponse type')
924932

925933
@classmethod
926934
def skill(cls, skill: a2a_pb2.AgentSkill) -> types.AgentSkill:
@@ -943,25 +951,3 @@ def role(cls, role: a2a_pb2.Role) -> types.Role:
943951
return types.Role.agent
944952
case _:
945953
return types.Role.agent
946-
947-
948-
def dict_to_struct(dictionary: dict[str, Any]) -> struct_pb2.Struct:
949-
"""Converts a Python dict to a Struct proto.
950-
951-
Unfortunately, using `json_format.ParseDict` does not work because this
952-
wants the dictionary to be an exact match of the Struct proto with fields
953-
and keys and values, not the traditional Python dict structure.
954-
955-
Args:
956-
dictionary: The Python dict to convert.
957-
958-
Returns:
959-
The Struct proto.
960-
"""
961-
struct = struct_pb2.Struct()
962-
for key, val in dictionary.items():
963-
if isinstance(val, dict):
964-
struct[key] = dict_to_struct(val)
965-
else:
966-
struct[key] = val
967-
return struct

0 commit comments

Comments
 (0)