Skip to content

Commit 2cea5c3

Browse files
authored
Lowercaseify Nexus request headers (#1006)
1 parent 30a2741 commit 2cea5c3

3 files changed

Lines changed: 36 additions & 3 deletions

File tree

core/src/worker/nexus.rs

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,11 @@ use temporal_sdk_core_protos::{
3232
CancelNexusTask, NexusTask, NexusTaskCancelReason, nexus_task, nexus_task_completion,
3333
},
3434
},
35-
temporal::api::nexus::v1::{request::Variant, response, start_operation_response},
35+
temporal::api::nexus::{
36+
self,
37+
v1::{request::Variant, response, start_operation_response},
38+
},
39+
utilities::normalize_http_headers,
3640
};
3741
use tokio::{
3842
join,
@@ -42,7 +46,7 @@ use tokio::{
4246
use tokio_stream::wrappers::UnboundedReceiverStream;
4347
use tokio_util::sync::CancellationToken;
4448

45-
static REQUEST_TIMEOUT_HEADER: &str = "Request-Timeout";
49+
static REQUEST_TIMEOUT_HEADER: &str = "request-timeout";
4650

4751
/// Centralizes all state related to received nexus tasks
4852
pub(super) struct NexusManager {
@@ -245,11 +249,18 @@ where
245249
.filter_map(move |t| {
246250
let res = match t {
247251
TaskStreamInput::Poll(t) => match *t {
248-
Ok(t) => {
252+
Ok(mut t) => {
249253
if let Some(dur) = t.resp.sched_to_start() {
250254
self.metrics.nexus_task_sched_to_start_latency(dur);
251255
};
252256

257+
if let Some(ref mut req) = t.resp.request {
258+
req.header = normalize_http_headers(std::mem::take(&mut req.header));
259+
if let Some(nexus::v1::request::Variant::StartOperation(ref mut sor)) = req.variant {
260+
sor.callback_header = normalize_http_headers(std::mem::take(&mut sor.callback_header));
261+
}
262+
}
263+
253264
let tt = TaskToken(t.resp.task_token.clone());
254265
let mut timeout_task = None;
255266
if let Some(timeout_str) = t

sdk-core-protos/src/utilities.rs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
use std::collections::HashMap;
2+
13
use prost::{EncodeError, Message};
24

35
pub trait TryIntoOrNone<F, T> {
@@ -26,3 +28,13 @@ pub fn pack_any<T: Message>(
2628
Message::encode(msg, &mut value)?;
2729
Ok(prost_wkt_types::Any { type_url, value })
2830
}
31+
32+
/// Given a header map, lowercase all the keys and return it as a new map.
33+
/// Any keys that are duplicated after lowercasing will clobber each other in undefined ordering.
34+
pub fn normalize_http_headers(headers: HashMap<String, String>) -> HashMap<String, String> {
35+
let mut new_headers = HashMap::new();
36+
for (header_key, val) in headers.into_iter() {
37+
new_headers.insert(header_key.to_lowercase(), val);
38+
}
39+
new_headers
40+
}

tests/integ_tests/workflow_tests/nexus.rs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -258,6 +258,16 @@ async fn nexus_async(
258258
let client = starter.get_client().await.get_client().clone();
259259
let nexus_task_handle = async {
260260
let mut nt = core_worker.poll_nexus_task().await.unwrap().unwrap_task();
261+
// Verify request header key for timeout exists and is lowercase
262+
if outcome == Outcome::Timeout {
263+
assert!(
264+
nt.request
265+
.as_ref()
266+
.unwrap()
267+
.header
268+
.contains_key("request-timeout")
269+
);
270+
}
261271
let start_req = assert_matches!(
262272
nt.request.unwrap().variant.unwrap(),
263273
request::Variant::StartOperation(sr) => sr

0 commit comments

Comments
 (0)