Skip to content

Commit 90d378e

Browse files
authored
Do DNS queries for both A and AAAA simultaneously (#302)
* Do DNS queries for both A and AAAA simultaneously We implement a basic version of RFC8305 (happy eyeballs) to establish the connection afterwards. * Try to connect to UDP sockets simultaneously
1 parent 4f570dc commit 90d378e

File tree

3 files changed

+122
-21
lines changed

3 files changed

+122
-21
lines changed

src/dns.rs

+18-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
use crate::tcp;
22
use anyhow::{anyhow, Context};
33
use futures_util::{FutureExt, TryFutureExt};
4-
use hickory_resolver::config::{NameServerConfig, Protocol, ResolverConfig, ResolverOpts};
4+
use hickory_resolver::config::{LookupIpStrategy, NameServerConfig, Protocol, ResolverConfig, ResolverOpts};
55
use hickory_resolver::name_server::{GenericConnector, RuntimeProvider, TokioRuntimeProvider};
66
use hickory_resolver::proto::iocompat::AsyncIoTokioAsStd;
77
use hickory_resolver::proto::TokioTime;
@@ -15,6 +15,22 @@ use std::time::Duration;
1515
use tokio::net::{TcpStream, UdpSocket};
1616
use url::{Host, Url};
1717

18+
// Interweave v4 and v6 addresses as per RFC8305.
19+
// The first address is v6 if we have any v6 addresses.
20+
pub fn sort_socket_addrs(socket_addrs: &[SocketAddr]) -> impl Iterator<Item = &'_ SocketAddr> {
21+
let mut pick_v6 = false;
22+
let mut v6 = socket_addrs.iter().filter(|s| matches!(s, SocketAddr::V6(_)));
23+
let mut v4 = socket_addrs.iter().filter(|s| matches!(s, SocketAddr::V4(_)));
24+
std::iter::from_fn(move || {
25+
pick_v6 = !pick_v6;
26+
if pick_v6 {
27+
v6.next().or_else(|| v4.next())
28+
} else {
29+
v4.next().or_else(|| v6.next())
30+
}
31+
})
32+
}
33+
1834
#[derive(Clone)]
1935
pub enum DnsResolver {
2036
System,
@@ -112,6 +128,7 @@ impl DnsResolver {
112128

113129
let mut opts = ResolverOpts::default();
114130
opts.timeout = std::time::Duration::from_secs(1);
131+
opts.ip_strategy = LookupIpStrategy::Ipv4AndIpv6;
115132
Ok(Self::TrustDns(AsyncResolver::new(
116133
cfg,
117134
opts,

src/tcp.rs

+61-9
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
use anyhow::{anyhow, Context};
22
use std::{io, vec};
3+
use tokio::task::JoinSet;
34

4-
use crate::dns::DnsResolver;
5+
use crate::dns::{self, DnsResolver};
56
use base64::Engine;
67
use bytes::BytesMut;
78
use log::warn;
@@ -11,7 +12,7 @@ use std::net::{SocketAddr, SocketAddrV4, SocketAddrV6};
1112
use std::time::Duration;
1213
use tokio::io::{AsyncReadExt, AsyncWriteExt};
1314
use tokio::net::{TcpListener, TcpSocket, TcpStream};
14-
use tokio::time::timeout;
15+
use tokio::time::{sleep, timeout};
1516
use tokio_stream::wrappers::TcpListenerStream;
1617
use tracing::log::info;
1718
use tracing::{debug, instrument};
@@ -70,7 +71,9 @@ pub async fn connect(
7071

7172
let mut cnx = None;
7273
let mut last_err = None;
73-
for addr in socket_addrs {
74+
let mut join_set = JoinSet::new();
75+
76+
for (ix, addr) in dns::sort_socket_addrs(&socket_addrs).copied().enumerate() {
7477
debug!("Connecting to {}", addr);
7578

7679
let socket = match &addr {
@@ -79,16 +82,45 @@ pub async fn connect(
7982
};
8083

8184
configure_socket(socket2::SockRef::from(&socket), &so_mark)?;
82-
match timeout(connect_timeout, socket.connect(addr)).await {
85+
86+
// Spawn the connection attempt in the join set.
87+
// We include a delay of ix * 250 milliseconds, as per RFC8305.
88+
// See https://datatracker.ietf.org/doc/html/rfc8305#section-5
89+
let fut = async move {
90+
if ix > 0 {
91+
sleep(Duration::from_millis(250 * ix as u64)).await;
92+
}
93+
match timeout(connect_timeout, socket.connect(addr)).await {
94+
Ok(Ok(s)) => Ok(Ok(s)),
95+
Ok(Err(e)) => Ok(Err((addr, e))),
96+
Err(e) => Err((addr, e)),
97+
}
98+
};
99+
join_set.spawn(fut);
100+
}
101+
102+
// Wait for the next future that finishes in the join set, until we got one
103+
// that resulted in a successful connection.
104+
// If cnx is no longer None, we exit the loop, since this means that we got
105+
// a successful connection.
106+
while let (None, Some(res)) = (&cnx, join_set.join_next().await) {
107+
match res? {
83108
Ok(Ok(stream)) => {
109+
// We've got a successful connection, so we can abort all other
110+
// on-going attempts.
111+
join_set.abort_all();
112+
113+
debug!(
114+
"Connected to tcp endpoint {}, aborted all other connection attempts",
115+
stream.peer_addr()?
116+
);
84117
cnx = Some(stream);
85-
break;
86118
}
87-
Ok(Err(err)) => {
88-
warn!("Cannot connect to tcp endpoint {addr} reason {err}");
119+
Ok(Err((addr, err))) => {
120+
debug!("Cannot connect to tcp endpoint {addr} reason {err}");
89121
last_err = Some(err);
90122
}
91-
Err(_) => {
123+
Err((addr, _)) => {
92124
warn!(
93125
"Cannot connect to tcp endpoint {addr} due to timeout of {}s elapsed",
94126
connect_timeout.as_secs()
@@ -195,7 +227,7 @@ pub async fn run_server(bind: SocketAddr, ip_transparent: bool) -> Result<TcpLis
195227
mod tests {
196228
use super::*;
197229
use futures_util::pin_mut;
198-
use std::net::SocketAddr;
230+
use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr};
199231
use testcontainers::core::WaitFor;
200232
use testcontainers::runners::AsyncRunner;
201233
use testcontainers::{ContainerAsync, Image, ImageArgs, RunnableImage};
@@ -227,6 +259,26 @@ mod tests {
227259
}
228260
}
229261

262+
#[test]
263+
fn test_sort_socket_addrs() {
264+
let addrs = [
265+
SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), 1)),
266+
SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 2), 1)),
267+
SocketAddr::V6(SocketAddrV6::new(Ipv6Addr::new(0, 0, 0, 0, 127, 0, 0, 1), 1, 0, 0)),
268+
SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 3), 1)),
269+
SocketAddr::V6(SocketAddrV6::new(Ipv6Addr::new(0, 0, 0, 0, 127, 0, 0, 2), 1, 0, 0)),
270+
];
271+
let expected = [
272+
SocketAddr::V6(SocketAddrV6::new(Ipv6Addr::new(0, 0, 0, 0, 127, 0, 0, 1), 1, 0, 0)),
273+
SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), 1)),
274+
SocketAddr::V6(SocketAddrV6::new(Ipv6Addr::new(0, 0, 0, 0, 127, 0, 0, 2), 1, 0, 0)),
275+
SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 2), 1)),
276+
SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 3), 1)),
277+
];
278+
let actual: Vec<_> = dns::sort_socket_addrs(&addrs).copied().collect();
279+
assert_eq!(expected, *actual);
280+
}
281+
230282
#[tokio::test]
231283
async fn test_proxy_connection() {
232284
let server_addr: SocketAddr = "[::1]:1236".parse().unwrap();

src/udp.rs

+43-11
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ use std::future::Future;
88
use std::io;
99
use std::io::{Error, ErrorKind};
1010
use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
11+
use tokio::task::JoinSet;
1112

1213
use log::warn;
1314
use std::pin::{pin, Pin};
@@ -18,9 +19,9 @@ use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
1819
use tokio::net::UdpSocket;
1920
use tokio::sync::futures::Notified;
2021

21-
use crate::dns::DnsResolver;
22+
use crate::dns::{self, DnsResolver};
2223
use tokio::sync::Notify;
23-
use tokio::time::{timeout, Interval};
24+
use tokio::time::{sleep, timeout, Interval};
2425
use tracing::{debug, error, info};
2526
use url::Host;
2627

@@ -337,7 +338,9 @@ pub async fn connect(
337338

338339
let mut cnx = None;
339340
let mut last_err = None;
340-
for addr in socket_addrs {
341+
let mut join_set = JoinSet::new();
342+
343+
for (ix, addr) in dns::sort_socket_addrs(&socket_addrs).copied().enumerate() {
341344
debug!("connecting to {}", addr);
342345

343346
let socket = match &addr {
@@ -353,18 +356,47 @@ pub async fn connect(
353356
}
354357
};
355358

356-
match timeout(connect_timeout, socket.connect(addr)).await {
357-
Ok(Ok(_)) => {
359+
// Spawn the connection attempt in the join set.
360+
// We include a delay of ix * 250 milliseconds, as per RFC8305.
361+
// See https://datatracker.ietf.org/doc/html/rfc8305#section-5
362+
let fut = async move {
363+
if ix > 0 {
364+
sleep(Duration::from_millis(250 * ix as u64)).await;
365+
}
366+
367+
match timeout(connect_timeout, socket.connect(addr)).await {
368+
Ok(Ok(())) => Ok(Ok(socket)),
369+
Ok(Err(e)) => Ok(Err((addr, e))),
370+
Err(e) => Err((addr, e)),
371+
}
372+
};
373+
join_set.spawn(fut);
374+
}
375+
376+
// Wait for the next future that finishes in the join set, until we got one
377+
// that resulted in a successful connection.
378+
// If cnx is no longer None, we exit the loop, since this means that we got
379+
// a successful connection.
380+
while let (None, Some(res)) = (&cnx, join_set.join_next().await) {
381+
match res? {
382+
Ok(Ok(socket)) => {
383+
// We've got a successful connection, so we can abort all other
384+
// on-going attempts.
385+
join_set.abort_all();
386+
387+
debug!(
388+
"Connected to udp endpoint {}, aborted all other connection attempts",
389+
socket.peer_addr()?
390+
);
358391
cnx = Some(socket);
359-
break;
360392
}
361-
Ok(Err(err)) => {
362-
debug!("Cannot connect udp socket to specified peer {addr} reason {err}");
393+
Ok(Err((addr, err))) => {
394+
debug!("Cannot connect to udp endpoint {addr} reason {err}");
363395
last_err = Some(err);
364396
}
365-
Err(_) => {
366-
debug!(
367-
"Cannot connect udp socket to specified peer {addr} due to timeout of {}s elapsed",
397+
Err((addr, _)) => {
398+
warn!(
399+
"Cannot connect to udp endpoint {addr} due to timeout of {}s elapsed",
368400
connect_timeout.as_secs()
369401
);
370402
}

0 commit comments

Comments
 (0)