@@ -10,11 +10,10 @@ pub use self::{read_packet::ReadPacket, write_packet::WritePacket};
1010
1111use bytes:: BytesMut ;
1212use futures_core:: { ready, stream} ;
13- use futures_util:: stream:: { FuturesUnordered , StreamExt } ;
14- use mio:: net:: { TcpKeepalive , TcpSocket } ;
1513use mysql_common:: proto:: codec:: PacketCodec as PacketCodecInner ;
1614use native_tls:: { Certificate , Identity , TlsConnector } ;
1715use pin_project:: pin_project;
16+ use socket2:: { Socket as Socket2Socket , TcpKeepalive } ;
1817#[ cfg( unix) ]
1918use tokio:: io:: AsyncWriteExt ;
2019use tokio:: {
@@ -35,14 +34,17 @@ use std::{
3534 Read ,
3635 } ,
3736 mem:: replace,
38- net:: { SocketAddr , ToSocketAddrs } ,
3937 ops:: { Deref , DerefMut } ,
4038 pin:: Pin ,
4139 task:: { Context , Poll } ,
4240 time:: Duration ,
4341} ;
4442
45- use crate :: { buffer_pool:: PooledBuf , error:: IoError , opts:: SslOpts } ;
43+ use crate :: {
44+ buffer_pool:: PooledBuf ,
45+ error:: IoError ,
46+ opts:: { HostPortOrUrl , SslOpts , DEFAULT_PORT } ,
47+ } ;
4648
4749#[ cfg( unix) ]
4850use crate :: io:: socket:: Socket ;
@@ -208,6 +210,7 @@ impl Endpoint {
208210 . map ( |x| vec ! [ x] )
209211 . or_else ( |_| {
210212 pem:: parse_many ( & * root_cert_data)
213+ . unwrap_or_default ( )
211214 . iter ( )
212215 . map ( pem:: encode)
213216 . map ( |s| Certificate :: from_pem ( s. as_bytes ( ) ) )
@@ -354,108 +357,41 @@ impl Stream {
354357 }
355358 }
356359
357- pub ( crate ) async fn connect_tcp < S > ( addr : S , keepalive : Option < Duration > ) -> io:: Result < Stream >
358- where
359- S : ToSocketAddrs ,
360- {
361- // TODO: Use tokio to setup keepalive (see tokio-rs/tokio#3082)
362- async fn connect_stream (
363- addr : SocketAddr ,
364- keepalive_opts : Option < TcpKeepalive > ,
365- ) -> io:: Result < TcpStream > {
366- let socket = if addr. is_ipv6 ( ) {
367- TcpSocket :: new_v6 ( ) ?
368- } else {
369- TcpSocket :: new_v4 ( ) ?
370- } ;
371-
372- if let Some ( keepalive_opts) = keepalive_opts {
373- socket. set_keepalive_params ( keepalive_opts) ?;
360+ pub ( crate ) async fn connect_tcp (
361+ addr : & HostPortOrUrl ,
362+ keepalive : Option < Duration > ,
363+ ) -> io:: Result < Stream > {
364+ let tcp_stream = match addr {
365+ HostPortOrUrl :: HostPort ( host, port) => {
366+ TcpStream :: connect ( ( host. as_str ( ) , * port) ) . await ?
374367 }
368+ HostPortOrUrl :: Url ( url) => {
369+ let addrs = url. socket_addrs ( || Some ( DEFAULT_PORT ) ) ?;
370+ TcpStream :: connect ( & * addrs) . await ?
371+ }
372+ } ;
375373
376- let stream = tokio:: task:: spawn_blocking ( move || {
377- let mut stream = socket. connect ( addr) ?;
378- let mut poll = mio:: Poll :: new ( ) ?;
379- let mut events = mio:: Events :: with_capacity ( 1024 ) ;
380-
381- poll. registry ( )
382- . register ( & mut stream, mio:: Token ( 0 ) , mio:: Interest :: WRITABLE ) ?;
383-
384- loop {
385- poll. poll ( & mut events, None ) ?;
386-
387- for event in & events {
388- if event. token ( ) == mio:: Token ( 0 ) && event. is_error ( ) {
389- return Err ( io:: Error :: new (
390- io:: ErrorKind :: ConnectionRefused ,
391- "Connection refused" ,
392- ) ) ;
393- }
394-
395- if event. token ( ) == mio:: Token ( 0 ) && event. is_writable ( ) {
396- // The socket connected (probably, it could still be a spurious
397- // wakeup)
398- return Ok :: < _ , io:: Error > ( stream) ;
399- }
400- }
401- }
402- } )
403- . await ??;
404-
374+ if let Some ( duration) = keepalive {
405375 #[ cfg( unix) ]
406- let std_stream = unsafe {
376+ let socket = unsafe {
407377 use std:: os:: unix:: prelude:: * ;
408- let fd = stream . into_raw_fd ( ) ;
409- std :: net :: TcpStream :: from_raw_fd ( fd)
378+ let fd = tcp_stream . as_raw_fd ( ) ;
379+ Socket2Socket :: from_raw_fd ( fd)
410380 } ;
411-
412381 #[ cfg( windows) ]
413- let std_stream = unsafe {
382+ let socket = unsafe {
414383 use std:: os:: windows:: prelude:: * ;
415- let fd = stream . into_raw_socket ( ) ;
416- std :: net :: TcpStream :: from_raw_socket ( fd )
384+ let sock = tcp_stream . as_raw_socket ( ) ;
385+ Socket2Socket :: from_raw_socket ( sock )
417386 } ;
418-
419- Ok ( TcpStream :: from_std ( std_stream ) ? )
387+ socket . set_tcp_keepalive ( & TcpKeepalive :: new ( ) . with_time ( duration ) ) ? ;
388+ std :: mem :: forget ( socket ) ;
420389 }
421390
422- let keepalive_opts = keepalive. map ( |time| TcpKeepalive :: new ( ) . with_time ( time) ) ;
423-
424- match addr. to_socket_addrs ( ) {
425- Ok ( addresses) => {
426- let mut streams = FuturesUnordered :: new ( ) ;
427-
428- for address in addresses {
429- streams. push ( connect_stream ( address, keepalive_opts. clone ( ) ) ) ;
430- }
431-
432- let mut err = None ;
433- while let Some ( stream) = streams. next ( ) . await {
434- match stream {
435- Err ( e) => {
436- err = Some ( e) ;
437- }
438- Ok ( stream) => {
439- return Ok ( Stream {
440- closed : false ,
441- codec : Box :: new ( Framed :: new ( stream. into ( ) , PacketCodec :: default ( ) ) )
442- . into ( ) ,
443- } ) ;
444- }
445- }
446- }
447-
448- if let Some ( e) = err {
449- Err ( e)
450- } else {
451- Err ( io:: Error :: new (
452- io:: ErrorKind :: InvalidInput ,
453- "could not resolve to any address" ,
454- ) )
455- }
456- }
457- Err ( err) => Err ( err) ,
458- }
391+ Ok ( Stream {
392+ closed : false ,
393+ codec : Box :: new ( Framed :: new ( tcp_stream. into ( ) , PacketCodec :: default ( ) ) ) . into ( ) ,
394+ } )
459395 }
460396
461397 #[ cfg( unix) ]
0 commit comments