Skip to content

Commit fbdea1c

Browse files
Add similar fix for batch and stream
Signed-off-by: Vaibhav Tiwari <vaibhav.tiwari33@gmail.com>
1 parent 43c662a commit fbdea1c

4 files changed

Lines changed: 155 additions & 95 deletions

File tree

rust/.cargo/config.toml

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
1-
[build]
2-
rustflags = ["-Dwarnings"]
3-
4-
[alias]
5-
clippy-ci = [
6-
"clippy",
7-
"--workspace",
8-
"--exclude", "numaflow-pb",
9-
"--exclude", "numaflow-models",
10-
"--no-deps",
11-
"--all-targets",
12-
"--all-features"
13-
]
1+
#[build]
2+
#rustflags = ["-Dwarnings"]
3+
#
4+
#[alias]
5+
#clippy-ci = [
6+
# "clippy",
7+
# "--workspace",
8+
# "--exclude", "numaflow-pb",
9+
# "--exclude", "numaflow-models",
10+
# "--no-deps",
11+
# "--all-targets",
12+
# "--all-features"
13+
#]

rust/numaflow-core/src/mapper/map/batch.rs

Lines changed: 50 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -12,21 +12,28 @@ use crate::error::{Error, Result};
1212
use crate::message::Message;
1313
use numaflow_pb::clients::map::{self, MapRequest, MapResponse, map_client::MapClient};
1414
use tokio::sync::{mpsc, oneshot};
15+
use tokio_stream::StreamExt;
1516
use tokio_util::task::AbortOnDropHandle;
1617
use tonic::Streaming;
1718
use tonic::transport::Channel;
1819
use tracing::error;
1920

2021
/// Type aliases
2122
type ResponseSenderMap =
22-
Arc<Mutex<HashMap<String, (ParentMessageInfo, oneshot::Sender<Result<Vec<Message>>>)>>>;
23+
HashMap<String, (ParentMessageInfo, oneshot::Sender<Result<Vec<Message>>>)>;
24+
25+
#[derive(Default)]
26+
pub(in crate::mapper) struct BatchSenderMapState {
27+
map: ResponseSenderMap,
28+
closed: bool,
29+
}
2330

2431
/// UserDefinedBatchMap is a grpc client that sends batch requests to the map server
2532
/// and forwards the responses.
2633
#[derive(Clone)]
2734
pub(in crate::mapper) struct UserDefinedBatchMap {
2835
read_tx: mpsc::Sender<MapRequest>,
29-
senders: ResponseSenderMap,
36+
senders: Arc<Mutex<BatchSenderMapState>>,
3037
_handle: Arc<AbortOnDropHandle<()>>,
3138
}
3239

@@ -40,7 +47,7 @@ impl UserDefinedBatchMap {
4047
let resp_stream = create_response_stream(read_tx.clone(), read_rx, &mut client).await?;
4148

4249
// map to track the oneshot response sender for each request along with the message info
43-
let sender_map = Arc::new(Mutex::new(HashMap::new()));
50+
let sender_map = Arc::new(Mutex::new(BatchSenderMapState::default()));
4451

4552
// background task to receive responses from the server and send them to the appropriate
4653
// oneshot response sender based on the id
@@ -58,9 +65,13 @@ impl UserDefinedBatchMap {
5865
}
5966

6067
/// Broadcasts a batch map gRPC error to all pending senders and records error metrics.
61-
fn broadcast_error(sender_map: &ResponseSenderMap, error: tonic::Status) {
62-
let senders =
63-
std::mem::take(&mut *sender_map.lock().expect("failed to acquire poisoned lock"));
68+
fn broadcast_error(sender_map: &Arc<Mutex<BatchSenderMapState>>, error: tonic::Status) {
69+
let mut sender_guard = sender_map.lock().expect("failed to acquire poisoned lock");
70+
sender_guard.closed = true;
71+
let senders = std::mem::take(&mut sender_guard.map);
72+
73+
// avoid holding the lock while sending errors
74+
drop(sender_guard);
6475

6576
for (_, (_, sender)) in senders {
6677
let _ = sender.send(Err(Error::Grpc(Box::new(error.clone()))));
@@ -71,45 +82,45 @@ impl UserDefinedBatchMap {
7182
/// receive responses from the server and gets the corresponding oneshot response sender from the map
7283
/// and sends the response.
7384
async fn receive_batch_responses(
74-
sender_map: ResponseSenderMap,
85+
sender_map: Arc<Mutex<BatchSenderMapState>>,
7586
mut resp_stream: Streaming<MapResponse>,
7687
) {
77-
loop {
78-
let resp = match resp_stream.message().await {
79-
Ok(Some(message)) => message,
80-
Ok(None) => break,
88+
while let Some(resp) = resp_stream.next().await {
89+
match resp {
90+
Ok(message) => {
91+
if let Some(map::TransmissionStatus { eot: true }) = message.status {
92+
if !sender_map
93+
.lock()
94+
.expect("failed to acquire poisoned lock")
95+
.map
96+
.is_empty()
97+
{
98+
error!("received EOT but not all responses have been received");
99+
critical_error!(VERTEX_TYPE_MAP_UDF, "eot_received_from_map");
100+
}
101+
update_udf_process_time_metric(is_mono_vertex());
102+
continue;
103+
}
104+
105+
Self::process_response(&sender_map, message).await
106+
}
81107
Err(e) => {
82108
error!(?e, "Error reading message from batch map gRPC stream");
83109
Self::broadcast_error(&sender_map, e);
84-
break;
85-
}
86-
};
87-
88-
if let Some(map::TransmissionStatus { eot: true }) = resp.status {
89-
if !sender_map
90-
.lock()
91-
.expect("failed to acquire poisoned lock")
92-
.is_empty()
93-
{
94-
error!("received EOT but not all responses have been received");
95-
critical_error!(VERTEX_TYPE_MAP_UDF, "eot_received_from_map");
96110
}
97-
update_udf_process_time_metric(is_mono_vertex());
98-
continue;
99111
}
100-
101-
Self::process_response(&sender_map, resp).await
102112
}
103113
}
104114

105115
/// Processes the response from the server and sends it to the appropriate oneshot sender
106116
/// based on the message id entry in the map.
107-
async fn process_response(sender_map: &ResponseSenderMap, resp: MapResponse) {
117+
async fn process_response(sender_map: &Arc<Mutex<BatchSenderMapState>>, resp: MapResponse) {
108118
let msg_id = resp.id;
109119

110120
let sender_entry = sender_map
111121
.lock()
112122
.expect("failed to acquire poisoned lock")
123+
.map
113124
.remove(&msg_id);
114125

115126
if let Some((msg_info, sender)) = sender_entry {
@@ -147,9 +158,18 @@ impl UserDefinedBatchMap {
147158
return;
148159
}
149160

150-
self.senders
161+
let mut senders_guard = self
162+
.senders
151163
.lock()
152-
.expect("failed to acquire poisoned lock")
164+
.expect("failed to acquire poisoned lock");
165+
166+
if senders_guard.closed {
167+
let _ = respond_to.send(Err(Error::Mapper("mapper closed".to_string())));
168+
return;
169+
}
170+
171+
senders_guard
172+
.map
153173
.insert(key.clone(), (msg_info, respond_to));
154174
}
155175

rust/numaflow-core/src/mapper/map/stream.rs

Lines changed: 86 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ use crate::error::{Error, Result};
77
use crate::message::Message;
88
use numaflow_pb::clients::map::{self, MapRequest, MapResponse, map_client::MapClient};
99
use tokio::sync::mpsc;
10+
use tokio_stream::StreamExt;
1011
use tokio_util::task::AbortOnDropHandle;
1112
use tonic::Streaming;
1213
use tonic::transport::Channel;
@@ -17,14 +18,19 @@ use super::{
1718
update_udf_process_time_metric, update_udf_read_metric, update_udf_write_only_metric,
1819
};
1920

20-
type StreamResponseSenderMap =
21-
Arc<Mutex<HashMap<String, (ParentMessageInfo, mpsc::Sender<Result<Message>>)>>>;
21+
type StreamResponseSenderMap = HashMap<String, (ParentMessageInfo, mpsc::Sender<Result<Message>>)>;
22+
23+
#[derive(Default)]
24+
pub(in crate::mapper) struct StreamSenderMapState {
25+
map: StreamResponseSenderMap,
26+
closed: bool,
27+
}
2228

2329
/// UserDefinedStreamMap is a grpc client that sends stream requests to the map server
2430
#[derive(Clone)]
2531
pub(in crate::mapper) struct UserDefinedStreamMap {
2632
read_tx: mpsc::Sender<MapRequest>,
27-
senders: StreamResponseSenderMap,
33+
senders: Arc<Mutex<StreamSenderMapState>>,
2834
_handle: Arc<AbortOnDropHandle<()>>,
2935
}
3036

@@ -38,7 +44,7 @@ impl UserDefinedStreamMap {
3844
let resp_stream = create_response_stream(read_tx.clone(), read_rx, &mut client).await?;
3945

4046
// map to track the oneshot response sender for each request along with the message info
41-
let sender_map = Arc::new(Mutex::new(HashMap::new()));
47+
let sender_map = Arc::new(Mutex::new(StreamSenderMapState::default()));
4248

4349
// background task to receive responses from the server and send them to the appropriate
4450
// mpsc sender based on the id
@@ -56,9 +62,15 @@ impl UserDefinedStreamMap {
5662
}
5763

5864
/// Broadcasts a gRPC error to all pending senders and records error metrics.
59-
async fn broadcast_error(sender_map: &StreamResponseSenderMap, error: tonic::Status) {
60-
let senders =
61-
std::mem::take(&mut *sender_map.lock().expect("failed to acquire poisoned lock"));
65+
async fn broadcast_error(sender_map: &Arc<Mutex<StreamSenderMapState>>, error: tonic::Status) {
66+
// Force dropping the sender_guard by moving it out of the scope
67+
// Using `drop(sender_guard)` here doesn't satisfy the borrow checker since it assumes it is
68+
// still in use across await calls for some reason.
69+
let senders = {
70+
let mut sender_guard = sender_map.lock().expect("failed to acquire poisoned lock");
71+
sender_guard.closed = true;
72+
std::mem::take(&mut sender_guard.map)
73+
};
6274

6375
for (_, (_, sender)) in senders {
6476
let _ = sender.send(Err(Error::Grpc(Box::new(error.clone())))).await;
@@ -69,41 +81,40 @@ impl UserDefinedStreamMap {
6981
/// receive responses from the server and gets the corresponding oneshot sender from the map
7082
/// and sends the response.
7183
async fn receive_stream_responses(
72-
sender_map: StreamResponseSenderMap,
84+
sender_map: Arc<Mutex<StreamSenderMapState>>,
7385
mut resp_stream: Streaming<MapResponse>,
7486
) {
75-
loop {
76-
let resp = match resp_stream.message().await {
77-
Ok(Some(message)) => message,
78-
Ok(None) => break,
87+
while let Some(resp) = resp_stream.next().await {
88+
match resp {
89+
Ok(message) => {
90+
let (message_info, response_sender) = sender_map
91+
.lock()
92+
.expect("failed to acquire poisoned lock")
93+
.map
94+
.remove(&message.id)
95+
.expect("map entry should always be present");
96+
97+
// once we get eot, we can drop the sender to let the callee
98+
// know that we are done sending responses
99+
if let Some(map::TransmissionStatus { eot: true }) = message.status {
100+
update_udf_process_time_metric(is_mono_vertex());
101+
continue;
102+
}
103+
104+
Self::process_stream_response(
105+
&sender_map,
106+
message.id,
107+
message_info,
108+
response_sender,
109+
message.results,
110+
)
111+
.await;
112+
}
79113
Err(e) => {
80114
error!(?e, "Error reading message from stream map gRPC stream");
81115
Self::broadcast_error(&sender_map, e).await;
82-
break;
83116
}
84-
};
85-
86-
let (message_info, response_sender) = sender_map
87-
.lock()
88-
.expect("failed to acquire poisoned lock")
89-
.remove(&resp.id)
90-
.expect("map entry should always be present");
91-
92-
// once we get eot, we can drop the sender to let the callee
93-
// know that we are done sending responses
94-
if let Some(map::TransmissionStatus { eot: true }) = resp.status {
95-
update_udf_process_time_metric(is_mono_vertex());
96-
continue;
97117
}
98-
99-
Self::process_stream_response(
100-
&sender_map,
101-
resp.id,
102-
message_info,
103-
response_sender,
104-
resp.results,
105-
)
106-
.await;
107118
}
108119
}
109120

@@ -129,15 +140,32 @@ impl UserDefinedStreamMap {
129140
return;
130141
}
131142

132-
self.senders
133-
.lock()
134-
.expect("failed to acquire poisoned lock")
135-
.insert(key.clone(), (msg_info, respond_to));
143+
// move the senders_guard out of the scope to drop the guard before sending the response
144+
let mapper_closed = {
145+
let mut senders_guard = self
146+
.senders
147+
.lock()
148+
.expect("failed to acquire poisoned lock");
149+
if !senders_guard.closed {
150+
// Write the sender back to the map, because we need to send
151+
// more responses for the same request
152+
senders_guard
153+
.map
154+
.insert(key.clone(), (msg_info, respond_to.clone()));
155+
}
156+
senders_guard.closed
157+
};
158+
159+
if mapper_closed {
160+
let _ = respond_to
161+
.send(Err(Error::Mapper("mapper closed".to_string())))
162+
.await;
163+
}
136164
}
137165

138166
/// Processes stream responses and sends them to the appropriate mpsc sender
139167
async fn process_stream_response(
140-
sender_map: &StreamResponseSenderMap,
168+
sender_map: &Arc<Mutex<StreamSenderMapState>>,
141169
msg_id: String,
142170
mut message_info: ParentMessageInfo,
143171
response_sender: mpsc::Sender<Result<Message>>,
@@ -158,12 +186,24 @@ impl UserDefinedStreamMap {
158186
update_udf_write_only_metric(is_mono_vertex());
159187
}
160188

161-
// Write the sender back to the map, because we need to send
162-
// more responses for the same request
163-
sender_map
164-
.lock()
165-
.expect("failed to acquire poisoned lock")
166-
.insert(msg_id, (message_info, response_sender));
189+
// move the senders_guard out of the scope to drop the guard before sending the response
190+
let mapper_closed = {
191+
let mut senders_guard = sender_map.lock().expect("failed to acquire poisoned lock");
192+
if !senders_guard.closed {
193+
// Write the sender back to the map, because we need to send
194+
// more responses for the same request
195+
senders_guard
196+
.map
197+
.insert(msg_id, (message_info, response_sender.clone()));
198+
}
199+
senders_guard.closed
200+
};
201+
202+
if mapper_closed {
203+
let _ = response_sender
204+
.send(Err(Error::Mapper("mapper closed".to_string())))
205+
.await;
206+
}
167207
}
168208
}
169209

0 commit comments

Comments
 (0)