1313# limitations under the License.
1414
1515import base64
16+ import uuid
1617from dataclasses import dataclass , field
1718from io import BytesIO
1819from typing import Any , Dict , Generic , Literal , Optional , Sequence , TypeVar , get_args
@@ -55,19 +56,25 @@ def __init__(
5556 payload : HRIPayload ,
5657 metadata : Optional [Dict [str , Any ]] = None ,
5758 message_author : Literal ["ai" , "human" ] = "ai" ,
59+ communication_id : Optional [str ] = None ,
60+ seq_no : int = 0 ,
61+ seq_end : bool = False ,
5862 ** kwargs ,
5963 ):
6064 super ().__init__ (payload , metadata )
6165 self .message_author = message_author
6266 self .text = payload .text
6367 self .images = payload .images
6468 self .audios = payload .audios
69+ self .communication_id = communication_id
70+ self .seq_no = seq_no
71+ self .seq_end = seq_end
6572
6673 def __bool__ (self ) -> bool :
6774 return bool (self .text or self .images or self .audios )
6875
6976 def __repr__ (self ):
70- return f"HRIMessage(type={ self .message_author } , text={ self .text } , images={ self .images } , audios={ self .audios } )"
77+ return f"HRIMessage(type={ self .message_author } , text={ self .text } , images={ self .images } , audios={ self .audios } , communication_id= { self . communication_id } , seq_no= { self . seq_no } , seq_end= { self . seq_end } )"
7178
7279 def _image_to_base64 (self , image : ImageType ) -> str :
7380 buffered = BytesIO ()
@@ -115,6 +122,7 @@ def to_langchain(self) -> LangchainBaseMessage:
115122 def from_langchain (
116123 cls ,
117124 message : LangchainBaseMessage | RAIMultimodalMessage ,
125+ communication_id : Optional [str ] = None ,
118126 ) -> "HRIMessage" :
119127 if isinstance (message , RAIMultimodalMessage ):
120128 text = message .text
@@ -137,8 +145,14 @@ def from_langchain(
137145 ),
138146 ),
139147 message_author = message .type , # type: ignore
148+ communication_id = communication_id ,
140149 )
141150
151+ @classmethod
152+ def generate_communication_id (cls ) -> str :
153+ """Generate a unique communication ID."""
154+ return str (uuid .uuid1 ())
155+
142156
143157T = TypeVar ("T" , bound = HRIMessage )
144158
@@ -167,12 +181,21 @@ def __init__(
167181 def _build_message (
168182 self ,
169183 message : LangchainBaseMessage | RAIMultimodalMessage ,
184+ communication_id : Optional [str ] = None ,
185+ seq_no : int = 0 ,
186+ seq_end : bool = False ,
170187 ) -> T :
171- return self .T_class .from_langchain (message )
188+ return self .T_class .from_langchain (message , communication_id , seq_no , seq_end )
172189
173- def send_all_targets (self , message : LangchainBaseMessage | RAIMultimodalMessage ):
190+ def send_all_targets (
191+ self ,
192+ message : LangchainBaseMessage | RAIMultimodalMessage ,
193+ communication_id : Optional [str ] = None ,
194+ seq_no : int = 0 ,
195+ seq_end : bool = False ,
196+ ):
174197 for target in self .configured_targets :
175- to_send = self ._build_message (message )
198+ to_send = self ._build_message (message , communication_id , seq_no , seq_end )
176199 self .send_message (to_send , target )
177200
178201 def receive_all_sources (self , timeout_sec : float = 1.0 ) -> dict [str , T ]:
0 commit comments