@@ -7,6 +7,7 @@ use crate::error::{Error, Result};
77use crate :: message:: Message ;
88use numaflow_pb:: clients:: map:: { self , MapRequest , MapResponse , map_client:: MapClient } ;
99use tokio:: sync:: mpsc;
10+ use tokio_stream:: StreamExt ;
1011use tokio_util:: task:: AbortOnDropHandle ;
1112use tonic:: Streaming ;
1213use tonic:: transport:: Channel ;
@@ -17,14 +18,19 @@ use super::{
1718 update_udf_process_time_metric, update_udf_read_metric, update_udf_write_only_metric,
1819} ;
1920
20- type StreamResponseSenderMap =
21- Arc < Mutex < HashMap < String , ( ParentMessageInfo , mpsc:: Sender < Result < Message > > ) > > > ;
21+ type StreamResponseSenderMap = HashMap < String , ( ParentMessageInfo , mpsc:: Sender < Result < Message > > ) > ;
22+
23+ #[ derive( Default ) ]
24+ pub ( in crate :: mapper) struct StreamSenderMapState {
25+ map : StreamResponseSenderMap ,
26+ closed : bool ,
27+ }
2228
2329/// UserDefinedStreamMap is a grpc client that sends stream requests to the map server
2430#[ derive( Clone ) ]
2531pub ( in crate :: mapper) struct UserDefinedStreamMap {
2632 read_tx : mpsc:: Sender < MapRequest > ,
27- senders : StreamResponseSenderMap ,
33+ senders : Arc < Mutex < StreamSenderMapState > > ,
2834 _handle : Arc < AbortOnDropHandle < ( ) > > ,
2935}
3036
@@ -38,7 +44,7 @@ impl UserDefinedStreamMap {
3844 let resp_stream = create_response_stream ( read_tx. clone ( ) , read_rx, & mut client) . await ?;
3945
4046 // map to track the oneshot response sender for each request along with the message info
41- let sender_map = Arc :: new ( Mutex :: new ( HashMap :: new ( ) ) ) ;
47+ let sender_map = Arc :: new ( Mutex :: new ( StreamSenderMapState :: default ( ) ) ) ;
4248
4349 // background task to receive responses from the server and send them to the appropriate
4450 // mpsc sender based on the id
@@ -56,9 +62,15 @@ impl UserDefinedStreamMap {
5662 }
5763
5864 /// Broadcasts a gRPC error to all pending senders and records error metrics.
59- async fn broadcast_error ( sender_map : & StreamResponseSenderMap , error : tonic:: Status ) {
60- let senders =
61- std:: mem:: take ( & mut * sender_map. lock ( ) . expect ( "failed to acquire poisoned lock" ) ) ;
65+ async fn broadcast_error ( sender_map : & Arc < Mutex < StreamSenderMapState > > , error : tonic:: Status ) {
66+ // Force dropping the sender_guard by moving it out of the scope
67+ // Using `drop(sender_guard)` here doesn't satisfy the borrow checker since it assumes it is
68+ // still in use across await calls for some reason.
69+ let senders = {
70+ let mut sender_guard = sender_map. lock ( ) . expect ( "failed to acquire poisoned lock" ) ;
71+ sender_guard. closed = true ;
72+ std:: mem:: take ( & mut sender_guard. map )
73+ } ;
6274
6375 for ( _, ( _, sender) ) in senders {
6476 let _ = sender. send ( Err ( Error :: Grpc ( Box :: new ( error. clone ( ) ) ) ) ) . await ;
@@ -69,41 +81,40 @@ impl UserDefinedStreamMap {
6981 /// receive responses from the server and gets the corresponding oneshot sender from the map
7082 /// and sends the response.
7183 async fn receive_stream_responses (
72- sender_map : StreamResponseSenderMap ,
84+ sender_map : Arc < Mutex < StreamSenderMapState > > ,
7385 mut resp_stream : Streaming < MapResponse > ,
7486 ) {
75- loop {
76- let resp = match resp_stream. message ( ) . await {
77- Ok ( Some ( message) ) => message,
78- Ok ( None ) => break ,
87+ while let Some ( resp) = resp_stream. next ( ) . await {
88+ match resp {
89+ Ok ( message) => {
90+ let ( message_info, response_sender) = sender_map
91+ . lock ( )
92+ . expect ( "failed to acquire poisoned lock" )
93+ . map
94+ . remove ( & message. id )
95+ . expect ( "map entry should always be present" ) ;
96+
97+ // once we get eot, we can drop the sender to let the callee
98+ // know that we are done sending responses
99+ if let Some ( map:: TransmissionStatus { eot : true } ) = message. status {
100+ update_udf_process_time_metric ( is_mono_vertex ( ) ) ;
101+ continue ;
102+ }
103+
104+ Self :: process_stream_response (
105+ & sender_map,
106+ message. id ,
107+ message_info,
108+ response_sender,
109+ message. results ,
110+ )
111+ . await ;
112+ }
79113 Err ( e) => {
80114 error ! ( ?e, "Error reading message from stream map gRPC stream" ) ;
81115 Self :: broadcast_error ( & sender_map, e) . await ;
82- break ;
83116 }
84- } ;
85-
86- let ( message_info, response_sender) = sender_map
87- . lock ( )
88- . expect ( "failed to acquire poisoned lock" )
89- . remove ( & resp. id )
90- . expect ( "map entry should always be present" ) ;
91-
92- // once we get eot, we can drop the sender to let the callee
93- // know that we are done sending responses
94- if let Some ( map:: TransmissionStatus { eot : true } ) = resp. status {
95- update_udf_process_time_metric ( is_mono_vertex ( ) ) ;
96- continue ;
97117 }
98-
99- Self :: process_stream_response (
100- & sender_map,
101- resp. id ,
102- message_info,
103- response_sender,
104- resp. results ,
105- )
106- . await ;
107118 }
108119 }
109120
@@ -129,15 +140,32 @@ impl UserDefinedStreamMap {
129140 return ;
130141 }
131142
132- self . senders
133- . lock ( )
134- . expect ( "failed to acquire poisoned lock" )
135- . insert ( key. clone ( ) , ( msg_info, respond_to) ) ;
143+ // move the senders_guard out of the scope to drop the guard before sending the response
144+ let mapper_closed = {
145+ let mut senders_guard = self
146+ . senders
147+ . lock ( )
148+ . expect ( "failed to acquire poisoned lock" ) ;
149+ if !senders_guard. closed {
150+ // Write the sender back to the map, because we need to send
151+ // more responses for the same request
152+ senders_guard
153+ . map
154+ . insert ( key. clone ( ) , ( msg_info, respond_to. clone ( ) ) ) ;
155+ }
156+ senders_guard. closed
157+ } ;
158+
159+ if mapper_closed {
160+ let _ = respond_to
161+ . send ( Err ( Error :: Mapper ( "mapper closed" . to_string ( ) ) ) )
162+ . await ;
163+ }
136164 }
137165
138166 /// Processes stream responses and sends them to the appropriate mpsc sender
139167 async fn process_stream_response (
140- sender_map : & StreamResponseSenderMap ,
168+ sender_map : & Arc < Mutex < StreamSenderMapState > > ,
141169 msg_id : String ,
142170 mut message_info : ParentMessageInfo ,
143171 response_sender : mpsc:: Sender < Result < Message > > ,
@@ -158,12 +186,24 @@ impl UserDefinedStreamMap {
158186 update_udf_write_only_metric ( is_mono_vertex ( ) ) ;
159187 }
160188
161- // Write the sender back to the map, because we need to send
162- // more responses for the same request
163- sender_map
164- . lock ( )
165- . expect ( "failed to acquire poisoned lock" )
166- . insert ( msg_id, ( message_info, response_sender) ) ;
189+ // move the senders_guard out of the scope to drop the guard before sending the response
190+ let mapper_closed = {
191+ let mut senders_guard = sender_map. lock ( ) . expect ( "failed to acquire poisoned lock" ) ;
192+ if !senders_guard. closed {
193+ // Write the sender back to the map, because we need to send
194+ // more responses for the same request
195+ senders_guard
196+ . map
197+ . insert ( msg_id, ( message_info, response_sender. clone ( ) ) ) ;
198+ }
199+ senders_guard. closed
200+ } ;
201+
202+ if mapper_closed {
203+ let _ = response_sender
204+ . send ( Err ( Error :: Mapper ( "mapper closed" . to_string ( ) ) ) )
205+ . await ;
206+ }
167207 }
168208}
169209
0 commit comments