Skip to content

Commit 6d2752e

Browse files
authored
Merge branch 'master' into remove-send-guard
2 parents 783a03a + d7ebff8 commit 6d2752e

32 files changed

Lines changed: 1488 additions & 100 deletions

File tree

.github/workflows/per-pr.yml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ jobs:
4646
test:
4747
name: Unit Tests
4848
# Give extra time to ensure pushes to main have time to populate the cache
49-
timeout-minutes: ${{ github.ref == 'refs/heads/master' && 20 || 15 }}
49+
timeout-minutes: ${{ matrix.timeoutMinutes || (github.ref == 'refs/heads/master' && 20 || 15) }}
5050
strategy:
5151
fail-fast: false
5252
matrix:
@@ -60,6 +60,7 @@ jobs:
6060
runsOn: macos-14
6161
- os: macos-intel
6262
runsOn: macos-15-intel
63+
timeoutMinutes: 20
6364
runs-on: ${{ matrix.runsOn || matrix.os }}
6465
steps:
6566
- uses: actions/checkout@v4

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ prost = "0.14"
3232
prost-types = { version = "0.7", package = "prost-wkt-types" }
3333
pbjson = "0.9"
3434
pbjson-build = "0.9"
35+
serde_json = "1.0"
3536

3637
[workspace.lints.rust]
3738
unreachable_pub = "warn"

crates/client/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ hyper-util = "0.1.16"
3535
opentelemetry = { workspace = true, features = ["metrics"], optional = true }
3636
parking_lot = "0.12"
3737
thiserror = { workspace = true }
38-
tokio = "1.47"
38+
tokio = { version = "1.47", features = ["net", "time"] }
3939
tonic = { workspace = true, features = ["tls-ring", "tls-native-roots"] }
4040
tower = { version = "0.5", features = ["util"] }
4141
tracing = "0.1"

crates/client/src/dns.rs

Lines changed: 306 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,306 @@
1+
use crate::{
2+
add_tls_to_channel,
3+
errors::ClientConnectError,
4+
options_structs::{
5+
ClientKeepAliveOptions, ConnectionOptions, DnsLoadBalancingOptions, TlsOptions,
6+
},
7+
};
8+
use http::Uri;
9+
use std::{collections::HashSet, net::SocketAddr, sync::Arc, time::Duration};
10+
use tokio::sync::mpsc;
11+
use tonic::transport::{Channel, Endpoint, channel::Change};
12+
use url::Url;
13+
14+
/// Validates DNS load balancing configuration and returns the options if DNS LB should be used.
15+
///
16+
/// Returns `Err` if `dns_load_balancing` is set alongside `service_override` or
17+
/// `http_connect_proxy`. Returns `Ok(None)` if DNS LB is disabled or the target is an IP literal.
18+
pub(crate) fn validate_and_get_dns_lb(
19+
options: &ConnectionOptions,
20+
) -> Result<Option<&DnsLoadBalancingOptions>, ClientConnectError> {
21+
let Some(dns_opts) = options.dns_load_balancing.as_ref() else {
22+
return Ok(None);
23+
};
24+
25+
if options.service_override.is_some() {
26+
return Err(ClientConnectError::InvalidConfig(
27+
"dns_load_balancing cannot be used with service_override".to_owned(),
28+
));
29+
}
30+
if options.http_connect_proxy.is_some() {
31+
return Err(ClientConnectError::InvalidConfig(
32+
"dns_load_balancing cannot be used with http_connect_proxy".to_owned(),
33+
));
34+
}
35+
36+
let host = options
37+
.target
38+
.host()
39+
.ok_or_else(|| ClientConnectError::InvalidConfig("target URL has no host".to_owned()))?;
40+
41+
match host {
42+
url::Host::Domain("localhost") => Ok(None),
43+
url::Host::Domain(_) => Ok(Some(dns_opts)),
44+
url::Host::Ipv4(_) | url::Host::Ipv6(_) => Ok(None),
45+
}
46+
}
47+
48+
async fn resolve_host(host: &str, port: u16) -> Result<Vec<SocketAddr>, std::io::Error> {
49+
tokio::net::lookup_host(format!("{host}:{port}"))
50+
.await
51+
.map(|addrs| addrs.collect())
52+
}
53+
54+
/// Renders `addr` as a `scheme://host:port` URI string, bracketing IPv6 addresses.
fn endpoint_uri(addr: SocketAddr, scheme: &str) -> String {
    let port = addr.port();
    let host = if let SocketAddr::V6(v6) = addr {
        // IPv6 literals must be wrapped in brackets inside a URI authority.
        format!("[{}]", v6.ip())
    } else {
        addr.ip().to_string()
    };
    format!("{scheme}://{host}:{port}")
}
60+
61+
async fn build_endpoint(
62+
addr: SocketAddr,
63+
original_host: &str,
64+
scheme: &str,
65+
tls_options: Option<&TlsOptions>,
66+
keep_alive: Option<&ClientKeepAliveOptions>,
67+
override_origin: Option<&Uri>,
68+
) -> Result<Endpoint, ClientConnectError> {
69+
let uri = endpoint_uri(addr, scheme);
70+
let channel = Channel::from_shared(uri)?;
71+
72+
// When connecting to an IP with TLS, SNI must use the original hostname.
73+
let tls_for_ip = tls_options.map(|tls| {
74+
if tls.domain.is_some() {
75+
tls.clone()
76+
} else {
77+
let mut patched = tls.clone();
78+
patched.domain = Some(original_host.to_owned());
79+
patched
80+
}
81+
});
82+
let channel = add_tls_to_channel(tls_for_ip.as_ref().or(tls_options), channel).await?;
83+
84+
let channel = if let Some(keep_alive) = keep_alive {
85+
channel
86+
.keep_alive_while_idle(true)
87+
.http2_keep_alive_interval(keep_alive.interval)
88+
.keep_alive_timeout(keep_alive.timeout)
89+
} else {
90+
channel
91+
};
92+
93+
let channel = if let Some(origin) = override_origin.cloned() {
94+
channel.origin(origin)
95+
} else {
96+
channel
97+
};
98+
99+
Ok(channel)
100+
}
101+
102+
/// Creates a balanced channel backed by all DNS-resolved addresses for the target.
103+
pub(crate) async fn create_balanced_channel(
104+
options: &ConnectionOptions,
105+
) -> Result<(Channel, mpsc::Sender<Change<SocketAddr, Endpoint>>), ClientConnectError> {
106+
let host = options
107+
.target
108+
.host_str()
109+
.ok_or_else(|| ClientConnectError::InvalidConfig("target URL has no host".to_owned()))?;
110+
let port = options.target.port_or_known_default().unwrap_or(7233);
111+
let scheme = options.target.scheme();
112+
113+
let addrs = resolve_host(host, port).await.map_err(|source| {
114+
ClientConnectError::DnsResolutionError {
115+
host: host.to_owned(),
116+
source,
117+
}
118+
})?;
119+
if addrs.is_empty() {
120+
return Err(ClientConnectError::DnsResolutionError {
121+
host: host.to_owned(),
122+
source: std::io::Error::new(
123+
std::io::ErrorKind::NotFound,
124+
"DNS resolution returned no addresses",
125+
),
126+
});
127+
}
128+
129+
let (channel, sender) = Channel::balance_channel(addrs.len());
130+
131+
for addr in addrs {
132+
let endpoint = build_endpoint(
133+
addr,
134+
host,
135+
scheme,
136+
options.tls_options.as_ref(),
137+
options.keep_alive.as_ref(),
138+
options.override_origin.as_ref(),
139+
)
140+
.await?;
141+
// Unbounded-ish send into the freshly-created channel; can't realistically fail.
142+
let _ = sender.send(Change::Insert(addr, endpoint)).await;
143+
}
144+
145+
Ok((channel, sender))
146+
}
147+
148+
/// Handle that aborts the DNS re-resolution task when dropped.
149+
pub(crate) struct DnsReresolutionHandle {
150+
abort_handle: tokio::task::AbortHandle,
151+
}
152+
153+
impl Drop for DnsReresolutionHandle {
154+
fn drop(&mut self) {
155+
self.abort_handle.abort();
156+
}
157+
}
158+
159+
/// Spawns a background task that periodically re-resolves DNS and updates the balanced channel.
160+
pub(crate) fn spawn_dns_reresolution(
161+
sender: mpsc::Sender<Change<SocketAddr, Endpoint>>,
162+
target: Url,
163+
tls_options: Option<TlsOptions>,
164+
keep_alive: Option<ClientKeepAliveOptions>,
165+
override_origin: Option<Uri>,
166+
resolution_interval: Duration,
167+
) -> Arc<DnsReresolutionHandle> {
168+
let host = target.host_str().unwrap_or("").to_owned();
169+
let port = target.port_or_known_default().unwrap_or(7233);
170+
let scheme = target.scheme().to_owned();
171+
172+
let handle = tokio::spawn(async move {
173+
let mut current_addrs: HashSet<SocketAddr> = HashSet::new();
174+
// Populate initial set from the channel we already seeded
175+
if let Ok(initial) = resolve_host(&host, port).await {
176+
current_addrs.extend(initial);
177+
}
178+
179+
loop {
180+
tokio::time::sleep(resolution_interval).await;
181+
182+
let new_addrs = match resolve_host(&host, port).await {
183+
Ok(addrs) => addrs.into_iter().collect::<HashSet<_>>(),
184+
Err(e) => {
185+
warn!(
186+
host = %host,
187+
error = %e,
188+
"DNS re-resolution failed, keeping existing endpoints"
189+
);
190+
continue;
191+
}
192+
};
193+
194+
if new_addrs.is_empty() {
195+
warn!(
196+
host = %host,
197+
"DNS re-resolution returned no addresses, keeping existing endpoints"
198+
);
199+
continue;
200+
}
201+
202+
// Remove stale endpoints
203+
for addr in current_addrs.difference(&new_addrs) {
204+
if sender.send(Change::Remove(*addr)).await.is_err() {
205+
return;
206+
}
207+
}
208+
209+
// Add new endpoints
210+
for addr in new_addrs.difference(&current_addrs) {
211+
match build_endpoint(
212+
*addr,
213+
&host,
214+
&scheme,
215+
tls_options.as_ref(),
216+
keep_alive.as_ref(),
217+
override_origin.as_ref(),
218+
)
219+
.await
220+
{
221+
Ok(endpoint) => {
222+
if sender.send(Change::Insert(*addr, endpoint)).await.is_err() {
223+
return;
224+
}
225+
}
226+
Err(e) => {
227+
warn!(
228+
addr = %addr,
229+
error = %e,
230+
"Failed to build endpoint for resolved address"
231+
);
232+
}
233+
}
234+
}
235+
236+
current_addrs = new_addrs;
237+
}
238+
});
239+
240+
Arc::new(DnsReresolutionHandle {
241+
abort_handle: handle.abort_handle(),
242+
})
243+
}
244+
245+
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds default options for `target` and reports whether DNS load balancing is selected.
    fn lb_selected_for(target: &str) -> bool {
        let opts = ConnectionOptions::new(Url::parse(target).unwrap()).build();
        validate_and_get_dns_lb(&opts).unwrap().is_some()
    }

    #[test]
    fn ip_v4_target_returns_none() {
        assert!(!lb_selected_for("http://1.2.3.4:7233"));
    }

    #[test]
    fn ip_v6_target_returns_none() {
        assert!(!lb_selected_for("http://[::1]:7233"));
    }

    #[test]
    fn domain_target_returns_some() {
        assert!(lb_selected_for("http://temporal.example.com:7233"));
    }

    #[test]
    fn disabled_returns_none() {
        let target = Url::parse("http://temporal.example.com:7233").unwrap();
        let opts = ConnectionOptions::new(target)
            .dns_load_balancing(None)
            .build();
        assert!(validate_and_get_dns_lb(&opts).unwrap().is_none());
    }

    #[test]
    fn service_override_with_dns_lb_is_error() {
        use crate::callback_based::CallbackBasedGrpcService;

        // The callback is never invoked; validation fails before any request is made.
        let svc = CallbackBasedGrpcService {
            callback: Arc::new(|_| Box::pin(async { unreachable!() })),
        };
        let target = Url::parse("http://temporal.example.com:7233").unwrap();
        let opts = ConnectionOptions::new(target)
            .service_override(svc)
            .build();
        assert!(validate_and_get_dns_lb(&opts).is_err());
    }

    #[test]
    fn localhost_returns_none() {
        assert!(!lb_selected_for("http://localhost:7233"));
    }

    #[test]
    fn endpoint_uri_v4() {
        let addr = SocketAddr::from(([1, 2, 3, 4], 7233));
        assert_eq!(endpoint_uri(addr, "https"), "https://1.2.3.4:7233");
    }

    #[test]
    fn endpoint_uri_v6() {
        let addr = SocketAddr::from((std::net::Ipv6Addr::LOCALHOST, 7233));
        assert_eq!(endpoint_uri(addr, "https"), "https://[::1]:7233");
    }
}

crates/client/src/errors.rs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,18 @@ pub enum ClientConnectError {
2424
/// server capabilities / verify server is responding.
2525
#[error("`get_system_info` call error after connection: {0:?}")]
2626
SystemInfoCallError(tonic::Status),
27+
/// DNS resolution failed when attempting load-balanced connection.
28+
#[error("DNS resolution error for '{host}': {source}")]
29+
DnsResolutionError {
30+
/// The host that failed to resolve.
31+
host: String,
32+
/// The underlying IO error.
33+
#[source]
34+
source: std::io::Error,
35+
},
36+
/// Invalid client configuration.
37+
#[error("Invalid client configuration: {0}")]
38+
InvalidConfig(String),
2739
}
2840

2941
/// Errors thrown when a gRPC metadata header is invalid.

crates/client/src/grpc.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2094,6 +2094,7 @@ mod tests {
20942094
let opts = ConnectionOptions::new(url::Url::parse("http://localhost:7233").unwrap())
20952095
.skip_get_system_info(true)
20962096
.service_override(service_override)
2097+
.dns_load_balancing(None)
20972098
.build();
20982099
let mut connection = crate::Connection::connect(opts).await.unwrap();
20992100

0 commit comments

Comments
 (0)