1818
1919
2020# Regexp patterns for matching
21- _TASK_NAME_MATCH = r'tasks/([\w-]+)'
22- _TASK_PUSH_CONFIG_NAME_MATCH = (
21+ _TASK_NAME_MATCH = re . compile ( r'tasks/([\w-]+)' )
22+ _TASK_PUSH_CONFIG_NAME_MATCH = re . compile (
2323 r'tasks/([\w-]+)/pushNotificationConfigs/([\w-]+)'
2424)
2525
2626
27+ def dict_to_struct (dictionary : dict [str , Any ]) -> struct_pb2 .Struct :
28+ """Converts a Python dict to a Struct proto.
29+
30+ Unfortunately, using `json_format.ParseDict` does not work because this
31+ wants the dictionary to be an exact match of the Struct proto with fields
32+ and keys and values, not the traditional Python dict structure.
33+
34+ Args:
35+ dictionary: The Python dict to convert.
36+
37+ Returns:
38+ The Struct proto.
39+ """
40+ struct = struct_pb2 .Struct ()
41+ for key , val in dictionary .items ():
42+ if isinstance (val , dict ):
43+ struct [key ] = dict_to_struct (val )
44+ else :
45+ struct [key ] = val
46+ return struct
47+
48+
2749class ToProto :
2850 """Converts Python types to proto types."""
2951
@@ -33,11 +55,11 @@ def message(cls, message: types.Message | None) -> a2a_pb2.Message | None:
3355 return None
3456 return a2a_pb2 .Message (
3557 message_id = message .message_id ,
36- content = [ToProto .part (p ) for p in message .parts ],
58+ content = [cls .part (p ) for p in message .parts ],
3759 context_id = message .context_id or '' ,
3860 task_id = message .task_id or '' ,
3961 role = cls .role (message .role ),
40- metadata = ToProto .metadata (message .metadata ),
62+ metadata = cls .metadata (message .metadata ),
4163 )
4264
4365 @classmethod
@@ -53,20 +75,14 @@ def part(cls, part: types.Part) -> a2a_pb2.Part:
5375 if isinstance (part .root , types .TextPart ):
5476 return a2a_pb2 .Part (text = part .root .text )
5577 if isinstance (part .root , types .FilePart ):
56- return a2a_pb2 .Part (file = ToProto .file (part .root .file ))
78+ return a2a_pb2 .Part (file = cls .file (part .root .file ))
5779 if isinstance (part .root , types .DataPart ):
58- return a2a_pb2 .Part (data = ToProto .data (part .root .data ))
80+ return a2a_pb2 .Part (data = cls .data (part .root .data ))
5981 raise ValueError (f'Unsupported part type: { part .root } ' )
6082
6183 @classmethod
6284 def data (cls , data : dict [str , Any ]) -> a2a_pb2 .DataPart :
63- json_data = json .dumps (data )
64- return a2a_pb2 .DataPart (
65- data = json_format .Parse (
66- json_data ,
67- struct_pb2 .Struct (),
68- )
69- )
85+ return a2a_pb2 .DataPart (data = dict_to_struct (data ))
7086
7187 @classmethod
7288 def file (
@@ -87,14 +103,14 @@ def task(cls, task: types.Task) -> a2a_pb2.Task:
87103 return a2a_pb2 .Task (
88104 id = task .id ,
89105 context_id = task .context_id ,
90- status = ToProto .task_status (task .status ),
106+ status = cls .task_status (task .status ),
91107 artifacts = (
92- [ToProto .artifact (a ) for a in task .artifacts ]
108+ [cls .artifact (a ) for a in task .artifacts ]
93109 if task .artifacts
94110 else None
95111 ),
96112 history = (
97- [ToProto .message (h ) for h in task .history ] # type: ignore[misc]
113+ [cls .message (h ) for h in task .history ] # type: ignore[misc]
98114 if task .history
99115 else None
100116 ),
@@ -103,8 +119,8 @@ def task(cls, task: types.Task) -> a2a_pb2.Task:
103119 @classmethod
104120 def task_status (cls , status : types .TaskStatus ) -> a2a_pb2 .TaskStatus :
105121 return a2a_pb2 .TaskStatus (
106- state = ToProto .task_state (status .state ),
107- update = ToProto .message (status .message ),
122+ state = cls .task_state (status .state ),
123+ update = cls .message (status .message ),
108124 )
109125
110126 @classmethod
@@ -132,9 +148,9 @@ def artifact(cls, artifact: types.Artifact) -> a2a_pb2.Artifact:
132148 return a2a_pb2 .Artifact (
133149 artifact_id = artifact .artifact_id ,
134150 description = artifact .description ,
135- metadata = ToProto .metadata (artifact .metadata ),
151+ metadata = cls .metadata (artifact .metadata ),
136152 name = artifact .name ,
137- parts = [ToProto .part (p ) for p in artifact .parts ],
153+ parts = [cls .part (p ) for p in artifact .parts ],
138154 )
139155
140156 @classmethod
@@ -151,7 +167,7 @@ def push_notification_config(
151167 cls , config : types .PushNotificationConfig
152168 ) -> a2a_pb2 .PushNotificationConfig :
153169 auth_info = (
154- ToProto .authentication_info (config .authentication )
170+ cls .authentication_info (config .authentication )
155171 if config .authentication
156172 else None
157173 )
@@ -169,8 +185,8 @@ def task_artifact_update_event(
169185 return a2a_pb2 .TaskArtifactUpdateEvent (
170186 task_id = event .task_id ,
171187 context_id = event .context_id ,
172- artifact = ToProto .artifact (event .artifact ),
173- metadata = ToProto .metadata (event .metadata ),
188+ artifact = cls .artifact (event .artifact ),
189+ metadata = cls .metadata (event .metadata ),
174190 append = event .append or False ,
175191 last_chunk = event .last_chunk or False ,
176192 )
@@ -182,8 +198,8 @@ def task_status_update_event(
182198 return a2a_pb2 .TaskStatusUpdateEvent (
183199 task_id = event .task_id ,
184200 context_id = event .context_id ,
185- status = ToProto .task_status (event .status ),
186- metadata = ToProto .metadata (event .metadata ),
201+ status = cls .task_status (event .status ),
202+ metadata = cls .metadata (event .metadata ),
187203 final = event .final ,
188204 )
189205
@@ -195,7 +211,7 @@ def message_send_configuration(
195211 return a2a_pb2 .SendMessageConfiguration ()
196212 return a2a_pb2 .SendMessageConfiguration (
197213 accepted_output_modes = config .accepted_output_modes ,
198- push_notification = ToProto .push_notification_config (
214+ push_notification = cls .push_notification_config (
199215 config .push_notification_config
200216 )
201217 if config .push_notification_config
@@ -213,19 +229,7 @@ def update_event(
213229 | types .TaskArtifactUpdateEvent ,
214230 ) -> a2a_pb2 .StreamResponse :
215231 """Converts a task, message, or task update event to a StreamResponse."""
216- if isinstance (event , types .TaskStatusUpdateEvent ):
217- return a2a_pb2 .StreamResponse (
218- status_update = ToProto .task_status_update_event (event )
219- )
220- if isinstance (event , types .TaskArtifactUpdateEvent ):
221- return a2a_pb2 .StreamResponse (
222- artifact_update = ToProto .task_artifact_update_event (event )
223- )
224- if isinstance (event , types .Message ):
225- return a2a_pb2 .StreamResponse (msg = ToProto .message (event ))
226- if isinstance (event , types .Task ):
227- return a2a_pb2 .StreamResponse (task = ToProto .task (event ))
228- raise ValueError (f'Unsupported event type: { type (event )} ' )
232+ return cls .stream_response (event )
229233
230234 @classmethod
231235 def task_or_message (
@@ -257,9 +261,11 @@ def stream_response(
257261 return a2a_pb2 .StreamResponse (
258262 status_update = cls .task_status_update_event (event ),
259263 )
260- return a2a_pb2 .StreamResponse (
261- artifact_update = cls .task_artifact_update_event (event ),
262- )
264+ if isinstance (event , types .TaskArtifactUpdateEvent ):
265+ return a2a_pb2 .StreamResponse (
266+ artifact_update = cls .task_artifact_update_event (event ),
267+ )
268+ raise ValueError (f'Unsupported event type: { type (event )} ' )
263269
264270 @classmethod
265271 def task_push_notification_config (
@@ -480,11 +486,11 @@ class FromProto:
480486 def message (cls , message : a2a_pb2 .Message ) -> types .Message :
481487 return types .Message (
482488 message_id = message .message_id ,
483- parts = [FromProto .part (p ) for p in message .content ],
489+ parts = [cls .part (p ) for p in message .content ],
484490 context_id = message .context_id or None ,
485491 task_id = message .task_id or None ,
486- role = FromProto .role (message .role ),
487- metadata = FromProto .metadata (message .metadata ),
492+ role = cls .role (message .role ),
493+ metadata = cls .metadata (message .metadata ),
488494 )
489495
490496 @classmethod
@@ -498,13 +504,9 @@ def part(cls, part: a2a_pb2.Part) -> types.Part:
498504 if part .HasField ('text' ):
499505 return types .Part (root = types .TextPart (text = part .text ))
500506 if part .HasField ('file' ):
501- return types .Part (
502- root = types .FilePart (file = FromProto .file (part .file ))
503- )
507+ return types .Part (root = types .FilePart (file = cls .file (part .file )))
504508 if part .HasField ('data' ):
505- return types .Part (
506- root = types .DataPart (data = FromProto .data (part .data ))
507- )
509+ return types .Part (root = types .DataPart (data = cls .data (part .data )))
508510 raise ValueError (f'Unsupported part type: { part } ' )
509511
510512 @classmethod
@@ -543,16 +545,16 @@ def task(cls, task: a2a_pb2.Task) -> types.Task:
543545 return types .Task (
544546 id = task .id ,
545547 context_id = task .context_id ,
546- status = FromProto .task_status (task .status ),
547- artifacts = [FromProto .artifact (a ) for a in task .artifacts ],
548- history = [FromProto .message (h ) for h in task .history ],
548+ status = cls .task_status (task .status ),
549+ artifacts = [cls .artifact (a ) for a in task .artifacts ],
550+ history = [cls .message (h ) for h in task .history ],
549551 )
550552
551553 @classmethod
552554 def task_status (cls , status : a2a_pb2 .TaskStatus ) -> types .TaskStatus :
553555 return types .TaskStatus (
554- state = FromProto .task_state (status .state ),
555- message = FromProto .message (status .update ),
556+ state = cls .task_state (status .state ),
557+ message = cls .message (status .update ),
556558 )
557559
558560 @classmethod
@@ -580,9 +582,9 @@ def artifact(cls, artifact: a2a_pb2.Artifact) -> types.Artifact:
580582 return types .Artifact (
581583 artifact_id = artifact .artifact_id ,
582584 description = artifact .description ,
583- metadata = FromProto .metadata (artifact .metadata ),
585+ metadata = cls .metadata (artifact .metadata ),
584586 name = artifact .name ,
585- parts = [FromProto .part (p ) for p in artifact .parts ],
587+ parts = [cls .part (p ) for p in artifact .parts ],
586588 )
587589
588590 @classmethod
@@ -592,8 +594,8 @@ def task_artifact_update_event(
592594 return types .TaskArtifactUpdateEvent (
593595 task_id = event .task_id ,
594596 context_id = event .context_id ,
595- artifact = FromProto .artifact (event .artifact ),
596- metadata = FromProto .metadata (event .metadata ),
597+ artifact = cls .artifact (event .artifact ),
598+ metadata = cls .metadata (event .metadata ),
597599 append = event .append ,
598600 last_chunk = event .last_chunk ,
599601 )
@@ -605,8 +607,8 @@ def task_status_update_event(
605607 return types .TaskStatusUpdateEvent (
606608 task_id = event .task_id ,
607609 context_id = event .context_id ,
608- status = FromProto .task_status (event .status ),
609- metadata = FromProto .metadata (event .metadata ),
610+ status = cls .task_status (event .status ),
611+ metadata = cls .metadata (event .metadata ),
610612 final = event .final ,
611613 )
612614
@@ -618,7 +620,7 @@ def push_notification_config(
618620 id = config .id ,
619621 url = config .url ,
620622 token = config .token ,
621- authentication = FromProto .authentication_info (config .authentication )
623+ authentication = cls .authentication_info (config .authentication )
622624 if config .HasField ('authentication' )
623625 else None ,
624626 )
@@ -638,7 +640,7 @@ def message_send_configuration(
638640 ) -> types .MessageSendConfiguration :
639641 return types .MessageSendConfiguration (
640642 accepted_output_modes = list (config .accepted_output_modes ),
641- push_notification_config = FromProto .push_notification_config (
643+ push_notification_config = cls .push_notification_config (
642644 config .push_notification
643645 )
644646 if config .HasField ('push_notification' )
@@ -666,18 +668,16 @@ def task_id_params(
666668 | a2a_pb2 .GetTaskPushNotificationConfigRequest
667669 ),
668670 ) -> types .TaskIdParams :
669- # This is currently incomplete until the core sdk supports multiple
670- # configs for a single task.
671671 if isinstance (request , a2a_pb2 .GetTaskPushNotificationConfigRequest ):
672- m = re .match (_TASK_PUSH_CONFIG_NAME_MATCH , request .name )
672+ m = _TASK_PUSH_CONFIG_NAME_MATCH .match (request .name )
673673 if not m :
674674 raise ServerError (
675675 error = types .InvalidParamsError (
676676 message = f'No task for { request .name } '
677677 )
678678 )
679679 return types .TaskIdParams (id = m .group (1 ))
680- m = re .match (_TASK_NAME_MATCH , request .name )
680+ m = _TASK_NAME_MATCH .match (request .name )
681681 if not m :
682682 raise ServerError (
683683 error = types .InvalidParamsError (
@@ -691,7 +691,7 @@ def task_push_notification_config_request(
691691 cls ,
692692 request : a2a_pb2 .CreateTaskPushNotificationConfigRequest ,
693693 ) -> types .TaskPushNotificationConfig :
694- m = re .match (_TASK_NAME_MATCH , request .parent )
694+ m = _TASK_NAME_MATCH .match (request .parent )
695695 if not m :
696696 raise ServerError (
697697 error = types .InvalidParamsError (
@@ -710,7 +710,7 @@ def task_push_notification_config(
710710 cls ,
711711 config : a2a_pb2 .TaskPushNotificationConfig ,
712712 ) -> types .TaskPushNotificationConfig :
713- m = re .match (_TASK_PUSH_CONFIG_NAME_MATCH , config .name )
713+ m = _TASK_PUSH_CONFIG_NAME_MATCH .match (config .name )
714714 if not m :
715715 raise ServerError (
716716 error = types .InvalidParamsError (
@@ -767,7 +767,7 @@ def task_query_params(
767767 cls ,
768768 request : a2a_pb2 .GetTaskRequest ,
769769 ) -> types .TaskQueryParams :
770- m = re .match (_TASK_NAME_MATCH , request .name )
770+ m = _TASK_NAME_MATCH .match (request .name )
771771 if not m :
772772 raise ServerError (
773773 error = types .InvalidParamsError (
@@ -862,6 +862,12 @@ def security_scheme(
862862 flows = cls .oauth2_flows (scheme .oauth2_security_scheme .flows ),
863863 )
864864 )
865+ if scheme .HasField ('mtls_security_scheme' ):
866+ return types .SecurityScheme (
867+ root = types .MutualTLSSecurityScheme (
868+ description = scheme .mtls_security_scheme .description ,
869+ )
870+ )
865871 return types .SecurityScheme (
866872 root = types .OpenIdConnectSecurityScheme (
867873 description = scheme .open_id_connect_security_scheme .description ,
@@ -920,7 +926,9 @@ def stream_response(
920926 return cls .task (response .task )
921927 if response .HasField ('status_update' ):
922928 return cls .task_status_update_event (response .status_update )
923- return cls .task_artifact_update_event (response .artifact_update )
929+ if response .HasField ('artifact_update' ):
930+ return cls .task_artifact_update_event (response .artifact_update )
931+ raise ValueError ('Unsupported StreamResponse type' )
924932
925933 @classmethod
926934 def skill (cls , skill : a2a_pb2 .AgentSkill ) -> types .AgentSkill :
@@ -943,25 +951,3 @@ def role(cls, role: a2a_pb2.Role) -> types.Role:
943951 return types .Role .agent
944952 case _:
945953 return types .Role .agent
946-
947-
948- def dict_to_struct (dictionary : dict [str , Any ]) -> struct_pb2 .Struct :
949- """Converts a Python dict to a Struct proto.
950-
951- Unfortunately, using `json_format.ParseDict` does not work because this
952- wants the dictionary to be an exact match of the Struct proto with fields
953- and keys and values, not the traditional Python dict structure.
954-
955- Args:
956- dictionary: The Python dict to convert.
957-
958- Returns:
959- The Struct proto.
960- """
961- struct = struct_pb2 .Struct ()
962- for key , val in dictionary .items ():
963- if isinstance (val , dict ):
964- struct [key ] = dict_to_struct (val )
965- else :
966- struct [key ] = val
967- return struct
0 commit comments