Skip to content

Commit f2be7f7

Browse files
authored
Merge branch 'master' into shutdown-rpc-initiate-shutdown
2 parents 5dcbf1c + 33c6c5e commit f2be7f7

23 files changed

Lines changed: 733 additions & 162 deletions

File tree

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,3 +23,4 @@ src/protos/*.rs
2323
/.cloud_certs/
2424
cloud_envs.fish
2525
/.claude/settings.local.json
26+
.codex

Cargo.toml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@ license = "MIT"
1414
license-file = "LICENSE.txt"
1515

1616
[workspace.dependencies]
17-
bon = { version = "3", features = ["implied-bounds"] }
18-
derive_more = { version = "2.0", features = [
17+
bon = { version = "3", default-features = false, features = ["alloc", "implied-bounds"] }
18+
derive_more = { version = "2.0", default-features = false, features = [
1919
"constructor",
2020
"display",
2121
"from",
@@ -24,10 +24,10 @@ derive_more = { version = "2.0", features = [
2424
"try_into",
2525
] }
2626
thiserror = "2"
27-
tonic = "0.14"
27+
tonic = { version = "0.14", default-features = false }
2828
tonic-prost = "0.14"
2929
tonic-prost-build = "0.14"
30-
opentelemetry = { version = "0.31", features = ["metrics"] }
30+
opentelemetry = { version = "0.31", default-features = false, features = ["metrics"] }
3131
prost = "0.14"
3232
prost-types = { version = "0.7", package = "prost-wkt-types" }
3333
pbjson = "0.9"

crates/client/Cargo.toml

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ anyhow = "1.0"
2222
async-trait = "0.1"
2323
backoff = "0.4"
2424
base64 = "0.22"
25-
bon = "3"
25+
bon = { version = "3", default-features = false, features = ["alloc"] }
2626
derive_more = { workspace = true }
2727
dyn-clone = "1.0"
2828
bytes = "1.10"
@@ -35,12 +35,18 @@ hyper-util = "0.1.16"
3535
opentelemetry = { workspace = true, features = ["metrics"], optional = true }
3636
parking_lot = "0.12"
3737
thiserror = { workspace = true }
38-
tokio = "1.47"
39-
tonic = { workspace = true, features = ["tls-ring", "tls-native-roots"] }
38+
tokio = { version = "1.47", default-features = false, features = [
39+
"io-util",
40+
"net",
41+
"rt",
42+
"sync",
43+
"time",
44+
] }
45+
tonic = { workspace = true, default-features = false, features = ["tls-ring", "tls-native-roots", "channel"] }
4046
tower = { version = "0.5", features = ["util"] }
4147
tracing = "0.1"
4248
url = "2.5"
43-
uuid = { version = "1.18", features = ["v4"] }
49+
uuid = { version = "1.18", default-features = false, features = ["v4"] }
4450
rand = "0.10"
4551

4652
[dependencies.temporalio-common]
@@ -54,6 +60,14 @@ prost = "0.14"
5460
prost-types = { workspace = true }
5561
rstest = "0.26"
5662
tempfile = "3"
63+
tokio = { version = "1.47", default-features = false, features = [
64+
"io-util",
65+
"macros",
66+
"net",
67+
"rt",
68+
"sync",
69+
"time",
70+
] }
5771

5872
[lints]
5973
workspace = true

crates/client/src/dns.rs

Lines changed: 306 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,306 @@
1+
use crate::{
2+
add_tls_to_channel,
3+
errors::ClientConnectError,
4+
options_structs::{
5+
ClientKeepAliveOptions, ConnectionOptions, DnsLoadBalancingOptions, TlsOptions,
6+
},
7+
};
8+
use http::Uri;
9+
use std::{collections::HashSet, net::SocketAddr, sync::Arc, time::Duration};
10+
use tokio::sync::mpsc;
11+
use tonic::transport::{Channel, Endpoint, channel::Change};
12+
use url::Url;
13+
14+
/// Validates DNS load balancing configuration and returns the options if DNS LB should be used.
15+
///
16+
/// Returns `Err` if `dns_load_balancing` is set alongside `service_override` or
17+
/// `http_connect_proxy`. Returns `Ok(None)` if DNS LB is disabled or the target is an IP literal.
18+
pub(crate) fn validate_and_get_dns_lb(
19+
options: &ConnectionOptions,
20+
) -> Result<Option<&DnsLoadBalancingOptions>, ClientConnectError> {
21+
let Some(dns_opts) = options.dns_load_balancing.as_ref() else {
22+
return Ok(None);
23+
};
24+
25+
if options.service_override.is_some() {
26+
return Err(ClientConnectError::InvalidConfig(
27+
"dns_load_balancing cannot be used with service_override".to_owned(),
28+
));
29+
}
30+
if options.http_connect_proxy.is_some() {
31+
return Err(ClientConnectError::InvalidConfig(
32+
"dns_load_balancing cannot be used with http_connect_proxy".to_owned(),
33+
));
34+
}
35+
36+
let host = options
37+
.target
38+
.host()
39+
.ok_or_else(|| ClientConnectError::InvalidConfig("target URL has no host".to_owned()))?;
40+
41+
match host {
42+
url::Host::Domain("localhost") => Ok(None),
43+
url::Host::Domain(_) => Ok(Some(dns_opts)),
44+
url::Host::Ipv4(_) | url::Host::Ipv6(_) => Ok(None),
45+
}
46+
}
47+
48+
async fn resolve_host(host: &str, port: u16) -> Result<Vec<SocketAddr>, std::io::Error> {
49+
tokio::net::lookup_host(format!("{host}:{port}"))
50+
.await
51+
.map(|addrs| addrs.collect())
52+
}
53+
54+
fn endpoint_uri(addr: SocketAddr, scheme: &str) -> String {
55+
match addr {
56+
SocketAddr::V4(v4) => format!("{scheme}://{v4}"),
57+
SocketAddr::V6(v6) => format!("{scheme}://[{}]:{}", v6.ip(), v6.port()),
58+
}
59+
}
60+
61+
async fn build_endpoint(
62+
addr: SocketAddr,
63+
original_host: &str,
64+
scheme: &str,
65+
tls_options: Option<&TlsOptions>,
66+
keep_alive: Option<&ClientKeepAliveOptions>,
67+
override_origin: Option<&Uri>,
68+
) -> Result<Endpoint, ClientConnectError> {
69+
let uri = endpoint_uri(addr, scheme);
70+
let channel = Channel::from_shared(uri)?;
71+
72+
// When connecting to an IP with TLS, SNI must use the original hostname.
73+
let tls_for_ip = tls_options.map(|tls| {
74+
if tls.domain.is_some() {
75+
tls.clone()
76+
} else {
77+
let mut patched = tls.clone();
78+
patched.domain = Some(original_host.to_owned());
79+
patched
80+
}
81+
});
82+
let channel = add_tls_to_channel(tls_for_ip.as_ref().or(tls_options), channel).await?;
83+
84+
let channel = if let Some(keep_alive) = keep_alive {
85+
channel
86+
.keep_alive_while_idle(true)
87+
.http2_keep_alive_interval(keep_alive.interval)
88+
.keep_alive_timeout(keep_alive.timeout)
89+
} else {
90+
channel
91+
};
92+
93+
let channel = if let Some(origin) = override_origin.cloned() {
94+
channel.origin(origin)
95+
} else {
96+
channel
97+
};
98+
99+
Ok(channel)
100+
}
101+
102+
/// Creates a balanced channel backed by all DNS-resolved addresses for the target.
103+
pub(crate) async fn create_balanced_channel(
104+
options: &ConnectionOptions,
105+
) -> Result<(Channel, mpsc::Sender<Change<SocketAddr, Endpoint>>), ClientConnectError> {
106+
let host = options
107+
.target
108+
.host_str()
109+
.ok_or_else(|| ClientConnectError::InvalidConfig("target URL has no host".to_owned()))?;
110+
let port = options.target.port_or_known_default().unwrap_or(7233);
111+
let scheme = options.target.scheme();
112+
113+
let addrs = resolve_host(host, port).await.map_err(|source| {
114+
ClientConnectError::DnsResolutionError {
115+
host: host.to_owned(),
116+
source,
117+
}
118+
})?;
119+
if addrs.is_empty() {
120+
return Err(ClientConnectError::DnsResolutionError {
121+
host: host.to_owned(),
122+
source: std::io::Error::new(
123+
std::io::ErrorKind::NotFound,
124+
"DNS resolution returned no addresses",
125+
),
126+
});
127+
}
128+
129+
let (channel, sender) = Channel::balance_channel(addrs.len());
130+
131+
for addr in addrs {
132+
let endpoint = build_endpoint(
133+
addr,
134+
host,
135+
scheme,
136+
options.tls_options.as_ref(),
137+
options.keep_alive.as_ref(),
138+
options.override_origin.as_ref(),
139+
)
140+
.await?;
141+
// Unbounded-ish send into the freshly-created channel; can't realistically fail.
142+
let _ = sender.send(Change::Insert(addr, endpoint)).await;
143+
}
144+
145+
Ok((channel, sender))
146+
}
147+
148+
/// Handle that aborts the DNS re-resolution task when dropped.
149+
pub(crate) struct DnsReresolutionHandle {
150+
abort_handle: tokio::task::AbortHandle,
151+
}
152+
153+
impl Drop for DnsReresolutionHandle {
154+
fn drop(&mut self) {
155+
self.abort_handle.abort();
156+
}
157+
}
158+
159+
/// Spawns a background task that periodically re-resolves DNS and updates the balanced channel.
160+
pub(crate) fn spawn_dns_reresolution(
161+
sender: mpsc::Sender<Change<SocketAddr, Endpoint>>,
162+
target: Url,
163+
tls_options: Option<TlsOptions>,
164+
keep_alive: Option<ClientKeepAliveOptions>,
165+
override_origin: Option<Uri>,
166+
resolution_interval: Duration,
167+
) -> Arc<DnsReresolutionHandle> {
168+
let host = target.host_str().unwrap_or("").to_owned();
169+
let port = target.port_or_known_default().unwrap_or(7233);
170+
let scheme = target.scheme().to_owned();
171+
172+
let handle = tokio::spawn(async move {
173+
let mut current_addrs: HashSet<SocketAddr> = HashSet::new();
174+
// Populate initial set from the channel we already seeded
175+
if let Ok(initial) = resolve_host(&host, port).await {
176+
current_addrs.extend(initial);
177+
}
178+
179+
loop {
180+
tokio::time::sleep(resolution_interval).await;
181+
182+
let new_addrs = match resolve_host(&host, port).await {
183+
Ok(addrs) => addrs.into_iter().collect::<HashSet<_>>(),
184+
Err(e) => {
185+
warn!(
186+
host = %host,
187+
error = %e,
188+
"DNS re-resolution failed, keeping existing endpoints"
189+
);
190+
continue;
191+
}
192+
};
193+
194+
if new_addrs.is_empty() {
195+
warn!(
196+
host = %host,
197+
"DNS re-resolution returned no addresses, keeping existing endpoints"
198+
);
199+
continue;
200+
}
201+
202+
// Remove stale endpoints
203+
for addr in current_addrs.difference(&new_addrs) {
204+
if sender.send(Change::Remove(*addr)).await.is_err() {
205+
return;
206+
}
207+
}
208+
209+
// Add new endpoints
210+
for addr in new_addrs.difference(&current_addrs) {
211+
match build_endpoint(
212+
*addr,
213+
&host,
214+
&scheme,
215+
tls_options.as_ref(),
216+
keep_alive.as_ref(),
217+
override_origin.as_ref(),
218+
)
219+
.await
220+
{
221+
Ok(endpoint) => {
222+
if sender.send(Change::Insert(*addr, endpoint)).await.is_err() {
223+
return;
224+
}
225+
}
226+
Err(e) => {
227+
warn!(
228+
addr = %addr,
229+
error = %e,
230+
"Failed to build endpoint for resolved address"
231+
);
232+
}
233+
}
234+
}
235+
236+
current_addrs = new_addrs;
237+
}
238+
});
239+
240+
Arc::new(DnsReresolutionHandle {
241+
abort_handle: handle.abort_handle(),
242+
})
243+
}
244+
245+
#[cfg(test)]
246+
mod tests {
247+
use super::*;
248+
249+
#[test]
250+
fn ip_v4_target_returns_none() {
251+
let opts = ConnectionOptions::new(Url::parse("http://1.2.3.4:7233").unwrap()).build();
252+
assert!(validate_and_get_dns_lb(&opts).unwrap().is_none());
253+
}
254+
255+
#[test]
256+
fn ip_v6_target_returns_none() {
257+
let opts = ConnectionOptions::new(Url::parse("http://[::1]:7233").unwrap()).build();
258+
assert!(validate_and_get_dns_lb(&opts).unwrap().is_none());
259+
}
260+
261+
#[test]
262+
fn domain_target_returns_some() {
263+
let opts =
264+
ConnectionOptions::new(Url::parse("http://temporal.example.com:7233").unwrap()).build();
265+
assert!(validate_and_get_dns_lb(&opts).unwrap().is_some());
266+
}
267+
268+
#[test]
269+
fn disabled_returns_none() {
270+
let opts = ConnectionOptions::new(Url::parse("http://temporal.example.com:7233").unwrap())
271+
.dns_load_balancing(None)
272+
.build();
273+
assert!(validate_and_get_dns_lb(&opts).unwrap().is_none());
274+
}
275+
276+
#[test]
277+
fn service_override_with_dns_lb_is_error() {
278+
use crate::callback_based::CallbackBasedGrpcService;
279+
280+
let svc = CallbackBasedGrpcService {
281+
callback: Arc::new(|_| Box::pin(async { unreachable!() })),
282+
};
283+
let opts = ConnectionOptions::new(Url::parse("http://temporal.example.com:7233").unwrap())
284+
.service_override(svc)
285+
.build();
286+
assert!(validate_and_get_dns_lb(&opts).is_err());
287+
}
288+
289+
#[test]
290+
fn localhost_returns_none() {
291+
let opts = ConnectionOptions::new(Url::parse("http://localhost:7233").unwrap()).build();
292+
assert!(validate_and_get_dns_lb(&opts).unwrap().is_none());
293+
}
294+
295+
#[test]
296+
fn endpoint_uri_v4() {
297+
let addr: SocketAddr = "1.2.3.4:7233".parse().unwrap();
298+
assert_eq!(endpoint_uri(addr, "https"), "https://1.2.3.4:7233");
299+
}
300+
301+
#[test]
302+
fn endpoint_uri_v6() {
303+
let addr: SocketAddr = "[::1]:7233".parse().unwrap();
304+
assert_eq!(endpoint_uri(addr, "https"), "https://[::1]:7233");
305+
}
306+
}

0 commit comments

Comments
 (0)