1
1
use anyhow:: { anyhow, Context } ;
2
2
use std:: { io, vec} ;
3
+ use tokio:: task:: JoinSet ;
3
4
4
- use crate :: dns:: DnsResolver ;
5
+ use crate :: dns:: { self , DnsResolver } ;
5
6
use base64:: Engine ;
6
7
use bytes:: BytesMut ;
7
8
use log:: warn;
@@ -11,7 +12,7 @@ use std::net::{SocketAddr, SocketAddrV4, SocketAddrV6};
11
12
use std:: time:: Duration ;
12
13
use tokio:: io:: { AsyncReadExt , AsyncWriteExt } ;
13
14
use tokio:: net:: { TcpListener , TcpSocket , TcpStream } ;
14
- use tokio:: time:: timeout;
15
+ use tokio:: time:: { sleep , timeout} ;
15
16
use tokio_stream:: wrappers:: TcpListenerStream ;
16
17
use tracing:: log:: info;
17
18
use tracing:: { debug, instrument} ;
@@ -70,7 +71,9 @@ pub async fn connect(
70
71
71
72
let mut cnx = None ;
72
73
let mut last_err = None ;
73
- for addr in socket_addrs {
74
+ let mut join_set = JoinSet :: new ( ) ;
75
+
76
+ for ( ix, addr) in dns:: sort_socket_addrs ( & socket_addrs) . copied ( ) . enumerate ( ) {
74
77
debug ! ( "Connecting to {}" , addr) ;
75
78
76
79
let socket = match & addr {
@@ -79,16 +82,45 @@ pub async fn connect(
79
82
} ;
80
83
81
84
configure_socket ( socket2:: SockRef :: from ( & socket) , & so_mark) ?;
82
- match timeout ( connect_timeout, socket. connect ( addr) ) . await {
85
+
86
+ // Spawn the connection attempt in the join set.
87
+ // We include a delay of ix * 250 milliseconds, as per RFC8305.
88
+ // See https://datatracker.ietf.org/doc/html/rfc8305#section-5
89
+ let fut = async move {
90
+ if ix > 0 {
91
+ sleep ( Duration :: from_millis ( 250 * ix as u64 ) ) . await ;
92
+ }
93
+ match timeout ( connect_timeout, socket. connect ( addr) ) . await {
94
+ Ok ( Ok ( s) ) => Ok ( Ok ( s) ) ,
95
+ Ok ( Err ( e) ) => Ok ( Err ( ( addr, e) ) ) ,
96
+ Err ( e) => Err ( ( addr, e) ) ,
97
+ }
98
+ } ;
99
+ join_set. spawn ( fut) ;
100
+ }
101
+
102
+ // Wait for the next future that finishes in the join set, until we got one
103
+ // that resulted in a successful connection.
104
+ // If cnx is no longer None, we exit the loop, since this means that we got
105
+ // a successful connection.
106
+ while let ( None , Some ( res) ) = ( & cnx, join_set. join_next ( ) . await ) {
107
+ match res? {
83
108
Ok ( Ok ( stream) ) => {
109
+ // We've got a successful connection, so we can abort all other
110
+ // on-going attempts.
111
+ join_set. abort_all ( ) ;
112
+
113
+ debug ! (
114
+ "Connected to tcp endpoint {}, aborted all other connection attempts" ,
115
+ stream. peer_addr( ) ?
116
+ ) ;
84
117
cnx = Some ( stream) ;
85
- break ;
86
118
}
87
- Ok ( Err ( err) ) => {
88
- warn ! ( "Cannot connect to tcp endpoint {addr} reason {err}" ) ;
119
+ Ok ( Err ( ( addr , err) ) ) => {
120
+ debug ! ( "Cannot connect to tcp endpoint {addr} reason {err}" ) ;
89
121
last_err = Some ( err) ;
90
122
}
91
- Err ( _ ) => {
123
+ Err ( ( addr , _ ) ) => {
92
124
warn ! (
93
125
"Cannot connect to tcp endpoint {addr} due to timeout of {}s elapsed" ,
94
126
connect_timeout. as_secs( )
@@ -195,7 +227,7 @@ pub async fn run_server(bind: SocketAddr, ip_transparent: bool) -> Result<TcpLis
195
227
mod tests {
196
228
use super :: * ;
197
229
use futures_util:: pin_mut;
198
- use std:: net:: SocketAddr ;
230
+ use std:: net:: { Ipv4Addr , Ipv6Addr , SocketAddr } ;
199
231
use testcontainers:: core:: WaitFor ;
200
232
use testcontainers:: runners:: AsyncRunner ;
201
233
use testcontainers:: { ContainerAsync , Image , ImageArgs , RunnableImage } ;
@@ -227,6 +259,26 @@ mod tests {
227
259
}
228
260
}
229
261
262
+ #[ test]
263
+ fn test_sort_socket_addrs ( ) {
264
+ let addrs = [
265
+ SocketAddr :: V4 ( SocketAddrV4 :: new ( Ipv4Addr :: new ( 127 , 0 , 0 , 1 ) , 1 ) ) ,
266
+ SocketAddr :: V4 ( SocketAddrV4 :: new ( Ipv4Addr :: new ( 127 , 0 , 0 , 2 ) , 1 ) ) ,
267
+ SocketAddr :: V6 ( SocketAddrV6 :: new ( Ipv6Addr :: new ( 0 , 0 , 0 , 0 , 127 , 0 , 0 , 1 ) , 1 , 0 , 0 ) ) ,
268
+ SocketAddr :: V4 ( SocketAddrV4 :: new ( Ipv4Addr :: new ( 127 , 0 , 0 , 3 ) , 1 ) ) ,
269
+ SocketAddr :: V6 ( SocketAddrV6 :: new ( Ipv6Addr :: new ( 0 , 0 , 0 , 0 , 127 , 0 , 0 , 2 ) , 1 , 0 , 0 ) ) ,
270
+ ] ;
271
+ let expected = [
272
+ SocketAddr :: V6 ( SocketAddrV6 :: new ( Ipv6Addr :: new ( 0 , 0 , 0 , 0 , 127 , 0 , 0 , 1 ) , 1 , 0 , 0 ) ) ,
273
+ SocketAddr :: V4 ( SocketAddrV4 :: new ( Ipv4Addr :: new ( 127 , 0 , 0 , 1 ) , 1 ) ) ,
274
+ SocketAddr :: V6 ( SocketAddrV6 :: new ( Ipv6Addr :: new ( 0 , 0 , 0 , 0 , 127 , 0 , 0 , 2 ) , 1 , 0 , 0 ) ) ,
275
+ SocketAddr :: V4 ( SocketAddrV4 :: new ( Ipv4Addr :: new ( 127 , 0 , 0 , 2 ) , 1 ) ) ,
276
+ SocketAddr :: V4 ( SocketAddrV4 :: new ( Ipv4Addr :: new ( 127 , 0 , 0 , 3 ) , 1 ) ) ,
277
+ ] ;
278
+ let actual: Vec < _ > = dns:: sort_socket_addrs ( & addrs) . copied ( ) . collect ( ) ;
279
+ assert_eq ! ( expected, * actual) ;
280
+ }
281
+
230
282
#[ tokio:: test]
231
283
async fn test_proxy_connection ( ) {
232
284
let server_addr: SocketAddr = "[::1]:1236" . parse ( ) . unwrap ( ) ;
0 commit comments