Skip to content

Commit 6351cd9

Browse files
committed
feat(dns): Add flag to specify if we should prefer IPv4 over IPv6
1 parent 90d378e commit 6351cd9

File tree

4 files changed

+123
-64
lines changed

4 files changed

+123
-64
lines changed

src/dns.rs

+71-26
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,11 @@ use std::time::Duration;
1515
use tokio::net::{TcpStream, UdpSocket};
1616
use url::{Host, Url};
1717

18-
// Interweave v4 and v6 addresses as per RFC8305.
18+
// Interleave v4 and v6 addresses as per RFC8305.
1919
// The first address is v6 if we have any v6 addresses.
20-
pub fn sort_socket_addrs(socket_addrs: &[SocketAddr]) -> impl Iterator<Item = &'_ SocketAddr> {
21-
let mut pick_v6 = false;
20+
#[inline]
21+
fn sort_socket_addrs(socket_addrs: &[SocketAddr], prefer_ipv6: bool) -> impl Iterator<Item = &'_ SocketAddr> {
22+
let mut pick_v6 = !prefer_ipv6;
2223
let mut v6 = socket_addrs.iter().filter(|s| matches!(s, SocketAddr::V6(_)));
2324
let mut v4 = socket_addrs.iter().filter(|s| matches!(s, SocketAddr::V4(_)));
2425
std::iter::from_fn(move || {
@@ -34,48 +35,63 @@ pub fn sort_socket_addrs(socket_addrs: &[SocketAddr]) -> impl Iterator<Item = &'
3435
#[derive(Clone)]
3536
pub enum DnsResolver {
3637
System,
37-
TrustDns(AsyncResolver<GenericConnector<TokioRuntimeProviderWithSoMark>>),
38+
TrustDns {
39+
resolver: AsyncResolver<GenericConnector<TokioRuntimeProviderWithSoMark>>,
40+
prefer_ipv6: bool,
41+
},
3842
}
3943

4044
impl DnsResolver {
4145
pub async fn lookup_host(&self, domain: &str, port: u16) -> anyhow::Result<Vec<SocketAddr>> {
4246
let addrs: Vec<SocketAddr> = match self {
4347
Self::System => tokio::net::lookup_host(format!("{}:{}", domain, port)).await?.collect(),
44-
Self::TrustDns(dns_resolver) => dns_resolver
45-
.lookup_ip(domain)
46-
.await?
47-
.into_iter()
48-
.map(|ip| match ip {
49-
IpAddr::V4(ip) => SocketAddr::V4(SocketAddrV4::new(ip, port)),
50-
IpAddr::V6(ip) => SocketAddr::V6(SocketAddrV6::new(ip, port, 0, 0)),
51-
})
52-
.collect(),
48+
Self::TrustDns { resolver, prefer_ipv6 } => {
49+
let addrs: Vec<_> = resolver
50+
.lookup_ip(domain)
51+
.await?
52+
.into_iter()
53+
.map(|ip| match ip {
54+
IpAddr::V4(ip) => SocketAddr::V4(SocketAddrV4::new(ip, port)),
55+
IpAddr::V6(ip) => SocketAddr::V6(SocketAddrV6::new(ip, port, 0, 0)),
56+
})
57+
.collect();
58+
sort_socket_addrs(&addrs, *prefer_ipv6).copied().collect()
59+
}
5360
};
5461

5562
Ok(addrs)
5663
}
5764

58-
pub fn new_from_urls(resolvers: &[Url], proxy: Option<Url>, so_mark: Option<u32>) -> anyhow::Result<Self> {
65+
pub fn new_from_urls(
66+
resolvers: &[Url],
67+
proxy: Option<Url>,
68+
so_mark: Option<u32>,
69+
prefer_ipv6: bool,
70+
) -> anyhow::Result<Self> {
5971
if resolvers.is_empty() {
6072
// no dns resolver specified, fall-back to default one
6173
let Ok((cfg, mut opts)) = hickory_resolver::system_conf::read_system_conf() else {
6274
warn!("Fall-backing to system dns resolver. You should consider specifying a dns resolver. To avoid performance issue");
6375
return Ok(Self::System);
6476
};
6577

66-
opts.timeout = std::time::Duration::from_secs(1);
78+
opts.ip_strategy = LookupIpStrategy::Ipv4AndIpv6;
79+
opts.timeout = Duration::from_secs(1);
6780
// Windows end-up with too many dns resolvers, which causes a performance issue
6881
// https://github.com/hickory-dns/hickory-dns/issues/1968
6982
#[cfg(target_os = "windows")]
7083
{
7184
opts.cache_size = 1024;
7285
opts.num_concurrent_reqs = cfg.name_servers().len();
7386
}
74-
return Ok(Self::TrustDns(AsyncResolver::new(
75-
cfg,
76-
opts,
77-
GenericConnector::new(TokioRuntimeProviderWithSoMark::new(proxy, so_mark)),
78-
)));
87+
return Ok(Self::TrustDns {
88+
resolver: AsyncResolver::new(
89+
cfg,
90+
opts,
91+
GenericConnector::new(TokioRuntimeProviderWithSoMark::new(proxy, so_mark)),
92+
),
93+
prefer_ipv6,
94+
});
7995
};
8096

8197
// if one is specified as system, use the default one from libc
@@ -127,13 +143,16 @@ impl DnsResolver {
127143
}
128144

129145
let mut opts = ResolverOpts::default();
130-
opts.timeout = std::time::Duration::from_secs(1);
146+
opts.timeout = Duration::from_secs(1);
131147
opts.ip_strategy = LookupIpStrategy::Ipv4AndIpv6;
132-
Ok(Self::TrustDns(AsyncResolver::new(
133-
cfg,
134-
opts,
135-
GenericConnector::new(TokioRuntimeProviderWithSoMark::new(proxy, so_mark)),
136-
)))
148+
Ok(Self::TrustDns {
149+
resolver: AsyncResolver::new(
150+
cfg,
151+
opts,
152+
GenericConnector::new(TokioRuntimeProviderWithSoMark::new(proxy, so_mark)),
153+
),
154+
prefer_ipv6,
155+
})
137156
}
138157
}
139158

@@ -235,3 +254,29 @@ impl RuntimeProvider for TokioRuntimeProviderWithSoMark {
235254
Box::pin(socket)
236255
}
237256
}
257+
258+
#[cfg(test)]
259+
mod tests {
260+
use crate::dns::sort_socket_addrs;
261+
use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
262+
263+
#[test]
264+
fn test_sort_socket_addrs() {
265+
let addrs = [
266+
SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), 1)),
267+
SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 2), 1)),
268+
SocketAddr::V6(SocketAddrV6::new(Ipv6Addr::new(0, 0, 0, 0, 127, 0, 0, 1), 1, 0, 0)),
269+
SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 3), 1)),
270+
SocketAddr::V6(SocketAddrV6::new(Ipv6Addr::new(0, 0, 0, 0, 127, 0, 0, 2), 1, 0, 0)),
271+
];
272+
let expected = [
273+
SocketAddr::V6(SocketAddrV6::new(Ipv6Addr::new(0, 0, 0, 0, 127, 0, 0, 1), 1, 0, 0)),
274+
SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), 1)),
275+
SocketAddr::V6(SocketAddrV6::new(Ipv6Addr::new(0, 0, 0, 0, 127, 0, 0, 2), 1, 0, 0)),
276+
SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 2), 1)),
277+
SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 3), 1)),
278+
];
279+
let actual: Vec<_> = sort_socket_addrs(&addrs, true).copied().collect();
280+
assert_eq!(expected, *actual);
281+
}
282+
}

src/main.rs

+43-6
Original file line numberDiff line numberDiff line change
@@ -258,6 +258,17 @@ struct Client {
258258
/// **WARN** On windows you may want to specify explicitly the DNS resolver to avoid excessive DNS queries
259259
#[arg(long, verbatim_doc_comment)]
260260
dns_resolver: Vec<Url>,
261+
262+
/// Enable if you prefer the dns resolver to prioritize IPv4 over IPv6
263+
/// This is useful if you have a broken IPv6 connection, and want to avoid the delay of trying to connect to IPv6
264+
/// If you don't have any IPv6 this does not change anything.
265+
#[arg(
266+
long,
267+
default_value = "false",
268+
env = "WSTUNNEL_DNS_PREFER_IPV4",
269+
verbatim_doc_comment
270+
)]
271+
dns_resolver_prefer_ipv4: bool,
261272
}
262273

263274
#[derive(clap::Args, Debug)]
@@ -295,6 +306,17 @@ struct Server {
295306
#[arg(long, verbatim_doc_comment)]
296307
dns_resolver: Vec<Url>,
297308

309+
/// Enable if you prefer the dns resolver to prioritize IPv4 over IPv6
310+
/// This is useful if you have a broken IPv6 connection, and want to avoid the delay of trying to connect to IPv6
311+
/// If you don't have any IPv6 this does not change anything.
312+
#[arg(
313+
long,
314+
default_value = "false",
315+
env = "WSTUNNEL_DNS_PREFER_IPV4",
316+
verbatim_doc_comment
317+
)]
318+
dns_resolver_prefer_ipv4: bool,
319+
298320
/// Server will only accept connection from the specified tunnel information.
299321
/// Can be specified multiple time
300322
/// Example: --restrict-to "google.com:443" --restrict-to "localhost:22"
@@ -755,8 +777,13 @@ impl WsClientConfig {
755777
#[tokio::main]
756778
async fn main() {
757779
let args = Wstunnel::parse();
758-
let socket = UdpSocket::bind(SocketAddrV6::new(Ipv6Addr::UNSPECIFIED, 0, 0, 0)).await.unwrap();
759-
socket.connect("[2001:4810:0:3::78]:443".parse::<SocketAddr>().unwrap()).await.unwrap();
780+
let socket = UdpSocket::bind(SocketAddrV6::new(Ipv6Addr::UNSPECIFIED, 0, 0, 0))
781+
.await
782+
.unwrap();
783+
socket
784+
.connect("[2001:4810:0:3::78]:443".parse::<SocketAddr>().unwrap())
785+
.await
786+
.unwrap();
760787

761788
// Setup logging
762789
let mut env_filter = EnvFilter::builder().parse(&args.log_lvl).expect("Invalid log level");
@@ -902,8 +929,13 @@ async fn main() {
902929
websocket_mask_frame: args.websocket_mask_frame,
903930
cnx_pool: None,
904931
tls_reloader: None,
905-
dns_resolver: DnsResolver::new_from_urls(&args.dns_resolver, http_proxy.clone(), args.socket_so_mark)
906-
.expect("cannot create dns resolver"),
932+
dns_resolver: DnsResolver::new_from_urls(
933+
&args.dns_resolver,
934+
http_proxy.clone(),
935+
args.socket_so_mark,
936+
!args.dns_resolver_prefer_ipv4,
937+
)
938+
.expect("cannot create dns resolver"),
907939
http_proxy,
908940
};
909941

@@ -1324,8 +1356,13 @@ async fn main() {
13241356
timeout_connect: Duration::from_secs(10),
13251357
websocket_mask_frame: args.websocket_mask_frame,
13261358
tls: tls_config,
1327-
dns_resolver: DnsResolver::new_from_urls(&args.dns_resolver, None, args.socket_so_mark)
1328-
.expect("Cannot create DNS resolver"),
1359+
dns_resolver: DnsResolver::new_from_urls(
1360+
&args.dns_resolver,
1361+
None,
1362+
args.socket_so_mark,
1363+
!args.dns_resolver_prefer_ipv4,
1364+
)
1365+
.expect("Cannot create DNS resolver"),
13291366
restriction_config: args.restrict_config,
13301367
};
13311368

src/tcp.rs

+5-27
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ use anyhow::{anyhow, Context};
22
use std::{io, vec};
33
use tokio::task::JoinSet;
44

5-
use crate::dns::{self, DnsResolver};
5+
use crate::dns::DnsResolver;
66
use base64::Engine;
77
use bytes::BytesMut;
88
use log::warn;
@@ -73,14 +73,11 @@ pub async fn connect(
7373
let mut last_err = None;
7474
let mut join_set = JoinSet::new();
7575

76-
for (ix, addr) in dns::sort_socket_addrs(&socket_addrs).copied().enumerate() {
77-
debug!("Connecting to {}", addr);
78-
76+
for (ix, addr) in socket_addrs.into_iter().enumerate() {
7977
let socket = match &addr {
8078
SocketAddr::V4(_) => TcpSocket::new_v4()?,
8179
SocketAddr::V6(_) => TcpSocket::new_v6()?,
8280
};
83-
8481
configure_socket(socket2::SockRef::from(&socket), &so_mark)?;
8582

8683
// Spawn the connection attempt in the join set.
@@ -90,6 +87,7 @@ pub async fn connect(
9087
if ix > 0 {
9188
sleep(Duration::from_millis(250 * ix as u64)).await;
9289
}
90+
debug!("Connecting to {}", addr);
9391
match timeout(connect_timeout, socket.connect(addr)).await {
9492
Ok(Ok(s)) => Ok(Ok(s)),
9593
Ok(Err(e)) => Ok(Err((addr, e))),
@@ -107,7 +105,7 @@ pub async fn connect(
107105
match res? {
108106
Ok(Ok(stream)) => {
109107
// We've got a successful connection, so we can abort all other
110-
// on-going attempts.
108+
// ongoing attempts.
111109
join_set.abort_all();
112110

113111
debug!(
@@ -227,7 +225,7 @@ pub async fn run_server(bind: SocketAddr, ip_transparent: bool) -> Result<TcpLis
227225
mod tests {
228226
use super::*;
229227
use futures_util::pin_mut;
230-
use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr};
228+
use std::net::SocketAddr;
231229
use testcontainers::core::WaitFor;
232230
use testcontainers::runners::AsyncRunner;
233231
use testcontainers::{ContainerAsync, Image, ImageArgs, RunnableImage};
@@ -259,26 +257,6 @@ mod tests {
259257
}
260258
}
261259

262-
#[test]
263-
fn test_sort_socket_addrs() {
264-
let addrs = [
265-
SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), 1)),
266-
SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 2), 1)),
267-
SocketAddr::V6(SocketAddrV6::new(Ipv6Addr::new(0, 0, 0, 0, 127, 0, 0, 1), 1, 0, 0)),
268-
SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 3), 1)),
269-
SocketAddr::V6(SocketAddrV6::new(Ipv6Addr::new(0, 0, 0, 0, 127, 0, 0, 2), 1, 0, 0)),
270-
];
271-
let expected = [
272-
SocketAddr::V6(SocketAddrV6::new(Ipv6Addr::new(0, 0, 0, 0, 127, 0, 0, 1), 1, 0, 0)),
273-
SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), 1)),
274-
SocketAddr::V6(SocketAddrV6::new(Ipv6Addr::new(0, 0, 0, 0, 127, 0, 0, 2), 1, 0, 0)),
275-
SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 2), 1)),
276-
SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 3), 1)),
277-
];
278-
let actual: Vec<_> = dns::sort_socket_addrs(&addrs).copied().collect();
279-
assert_eq!(expected, *actual);
280-
}
281-
282260
#[tokio::test]
283261
async fn test_proxy_connection() {
284262
let server_addr: SocketAddr = "[::1]:1236".parse().unwrap();

src/udp.rs

+4-5
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
1919
use tokio::net::UdpSocket;
2020
use tokio::sync::futures::Notified;
2121

22-
use crate::dns::{self, DnsResolver};
22+
use crate::dns::DnsResolver;
2323
use tokio::sync::Notify;
2424
use tokio::time::{sleep, timeout, Interval};
2525
use tracing::{debug, error, info};
@@ -340,9 +340,7 @@ pub async fn connect(
340340
let mut last_err = None;
341341
let mut join_set = JoinSet::new();
342342

343-
for (ix, addr) in dns::sort_socket_addrs(&socket_addrs).copied().enumerate() {
344-
debug!("connecting to {}", addr);
345-
343+
for (ix, addr) in socket_addrs.into_iter().enumerate() {
346344
let socket = match &addr {
347345
SocketAddr::V4(_) => UdpSocket::bind(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0)).await,
348346
SocketAddr::V6(_) => UdpSocket::bind(SocketAddrV6::new(Ipv6Addr::UNSPECIFIED, 0, 0, 0)).await,
@@ -364,6 +362,7 @@ pub async fn connect(
364362
sleep(Duration::from_millis(250 * ix as u64)).await;
365363
}
366364

365+
debug!("connecting to {}", addr);
367366
match timeout(connect_timeout, socket.connect(addr)).await {
368367
Ok(Ok(())) => Ok(Ok(socket)),
369368
Ok(Err(e)) => Ok(Err((addr, e))),
@@ -381,7 +380,7 @@ pub async fn connect(
381380
match res? {
382381
Ok(Ok(socket)) => {
383382
// We've got a successful connection, so we can abort all other
384-
// on-going attempts.
383+
// ongoing attempts.
385384
join_set.abort_all();
386385

387386
debug!(

0 commit comments

Comments
 (0)