11use std:: {
22 collections:: VecDeque ,
33 fmt:: Debug ,
4- net:: { IpAddr , SocketAddr } ,
4+ net:: { IpAddr , SocketAddr , SocketAddrV6 } ,
55 pin:: { Pin , pin} ,
66 task:: { Context , Poll , Waker } ,
77 time:: { Duration , Instant } ,
@@ -10,7 +10,7 @@ use std::{
1010use compio_buf:: bytes:: Bytes ;
1111use compio_log:: Instrument ;
1212use compio_runtime:: JoinHandle ;
13- use flume:: { Receiver , Sender } ;
13+ use flume:: { Receiver , Sender , unbounded } ;
1414use futures_util:: {
1515 FutureExt , StreamExt ,
1616 future:: { self , Fuse , FusedFuture , LocalBoxFuture } ,
@@ -19,14 +19,15 @@ use futures_util::{
1919#[ cfg( rustls) ]
2020use noq_proto:: crypto:: rustls:: HandshakeData ;
2121use noq_proto:: {
22- ConnectionHandle , ConnectionStats , Dir , EndpointEvent , PathId , Side , StreamEvent , StreamId ,
23- VarInt , congestion:: Controller ,
22+ ConnectionHandle , ConnectionStats , Dir , EndpointEvent , FourTuple , PathError , PathEvent , PathId ,
23+ PathStats , PathStatus , Side , StreamEvent , StreamId , VarInt , congestion:: Controller ,
24+ n0_nat_traversal,
2425} ;
2526use rustc_hash:: FxHashMap as HashMap ;
2627use thiserror:: Error ;
2728
2829use crate :: {
29- RecvStream , SendStream , Socket ,
30+ OpenPath , Path , RecvStream , SendStream , Socket ,
3031 sync:: {
3132 mutex_blocking:: { Mutex , MutexGuard } ,
3233 shared:: Shared ,
@@ -44,14 +45,21 @@ pub(crate) struct ConnectionState {
4445 pub ( crate ) conn : noq_proto:: Connection ,
4546 pub ( crate ) error : Option < ConnectionError > ,
4647 connected : bool ,
48+ handshake_confirmed : bool ,
4749 worker : Option < JoinHandle < ( ) > > ,
4850 poller : Option < Waker > ,
4951 on_connected : Option < Waker > ,
5052 on_handshake_data : Option < Waker > ,
53+ on_handshake_confirmed : VecDeque < Waker > ,
5154 datagram_received : VecDeque < Waker > ,
5255 datagrams_unblocked : VecDeque < Waker > ,
5356 stream_opened : [ VecDeque < Waker > ; 2 ] ,
5457 stream_available : [ VecDeque < Waker > ; 2 ] ,
58+ open_path : HashMap < PathId , Sender < Result < ( ) , PathError > > > ,
59+ path_events : Vec < Sender < PathEvent > > ,
60+ observed_external_addr : Option < SocketAddr > ,
61+ nat_traversal_updates : Vec < Sender < n0_nat_traversal:: Event > > ,
62+ final_path_stats : HashMap < PathId , PathStats > ,
5563 pub ( crate ) writable : HashMap < StreamId , Waker > ,
5664 pub ( crate ) readable : HashMap < StreamId , Waker > ,
5765 pub ( crate ) stopped : HashMap < StreamId , Waker > ,
@@ -68,6 +76,7 @@ impl ConnectionState {
6876 if let Some ( waker) = self . on_connected . take ( ) {
6977 waker. wake ( )
7078 }
79+ self . on_handshake_confirmed . drain ( ..) . for_each ( Waker :: wake) ;
7180 self . datagram_received . drain ( ..) . for_each ( Waker :: wake) ;
7281 self . datagrams_unblocked . drain ( ..) . for_each ( Waker :: wake) ;
7382 for e in & mut self . stream_opened {
@@ -76,6 +85,9 @@ impl ConnectionState {
7685 for e in & mut self . stream_available {
7786 e. drain ( ..) . for_each ( Waker :: wake) ;
7887 }
88+ for tx in self . open_path . drain ( ) . map ( |( _, tx) | tx) {
89+ let _ = tx. send ( Err ( PathError :: ValidationFailed ) ) ;
90+ }
7991 wake_all_streams ( & mut self . writable ) ;
8092 wake_all_streams ( & mut self . readable ) ;
8193 wake_all_streams ( & mut self . stopped ) ;
@@ -104,6 +116,12 @@ impl ConnectionState {
104116 pub ( crate ) fn check_0rtt ( & self ) -> bool {
105117 self . conn . side ( ) . is_server ( ) || self . conn . is_handshaking ( ) || self . conn . accepted_0rtt ( )
106118 }
119+
120+ pub ( crate ) fn path_stats ( & mut self , path_id : PathId ) -> Option < PathStats > {
121+ self . conn
122+ . path_stats ( path_id)
123+ . or_else ( || self . final_path_stats . get ( & path_id) . copied ( ) )
124+ }
107125}
108126
109127fn wake_stream ( stream : StreamId , wakers : & mut HashMap < StreamId , Waker > ) {
@@ -116,6 +134,41 @@ fn wake_all_streams(wakers: &mut HashMap<StreamId, Waker>) {
116134 wakers. drain ( ) . for_each ( |( _, waker) | waker. wake ( ) )
117135}
118136
137+ fn wake_waiters ( wakers : & mut VecDeque < Waker > ) {
138+ wakers. drain ( ..) . for_each ( Waker :: wake)
139+ }
140+
141+ fn broadcast < T : Clone > ( listeners : & mut Vec < Sender < T > > , event : T ) {
142+ listeners. retain ( |tx| tx. send ( event. clone ( ) ) . is_ok ( ) ) ;
143+ }
144+
145+ fn normalize_remote_address (
146+ state : & ConnectionState ,
147+ addr : SocketAddr ,
148+ ) -> Result < SocketAddr , PathError > {
149+ let ipv6 = state
150+ . conn
151+ . paths ( )
152+ . iter ( )
153+ . filter_map ( |id| state. conn . network_path ( * id) . ok ( ) )
154+ . map ( |path| path. remote . is_ipv6 ( ) )
155+ . next ( )
156+ . unwrap_or_default ( ) ;
157+ if addr. is_ipv6 ( ) && !ipv6 {
158+ return Err ( PathError :: InvalidRemoteAddress ( addr) ) ;
159+ }
160+ Ok ( if ipv6 {
161+ SocketAddr :: V6 ( match addr {
162+ SocketAddr :: V4 ( addr) => {
163+ SocketAddrV6 :: new ( addr. ip ( ) . to_ipv6_mapped ( ) , addr. port ( ) , 0 , 0 )
164+ }
165+ SocketAddr :: V6 ( addr) => addr,
166+ } )
167+ } else {
168+ addr
169+ } )
170+ }
171+
119172#[ derive( Debug ) ]
120173pub ( crate ) struct ConnectionInner {
121174 state : Mutex < ConnectionState > ,
@@ -143,15 +196,22 @@ impl ConnectionInner {
143196 state : Mutex :: new ( ConnectionState {
144197 conn,
145198 connected : false ,
199+ handshake_confirmed : false ,
146200 error : None ,
147201 worker : None ,
148202 poller : None ,
149203 on_connected : None ,
150204 on_handshake_data : None ,
205+ on_handshake_confirmed : VecDeque :: new ( ) ,
151206 datagram_received : VecDeque :: new ( ) ,
152207 datagrams_unblocked : VecDeque :: new ( ) ,
153208 stream_opened : [ VecDeque :: new ( ) , VecDeque :: new ( ) ] ,
154209 stream_available : [ VecDeque :: new ( ) , VecDeque :: new ( ) ] ,
210+ open_path : HashMap :: default ( ) ,
211+ path_events : Vec :: new ( ) ,
212+ observed_external_addr : None ,
213+ nat_traversal_updates : Vec :: new ( ) ,
214+ final_path_stats : HashMap :: default ( ) ,
155215 writable : HashMap :: default ( ) ,
156216 readable : HashMap :: default ( ) ,
157217 stopped : HashMap :: default ( ) ,
@@ -284,13 +344,33 @@ impl ConnectionInner {
284344 DatagramsUnblocked => state. datagrams_unblocked . drain ( ..) . for_each ( Waker :: wake) ,
285345
286346 HandshakeConfirmed => {
287- todo ! ( )
347+ state. handshake_confirmed = true ;
348+ wake_waiters ( & mut state. on_handshake_confirmed ) ;
288349 }
289- Path ( _) => {
290- todo ! ( )
350+ Path ( event) => {
351+ match & event {
352+ PathEvent :: ObservedAddr { addr, .. } => {
353+ state. observed_external_addr = Some ( * addr) ;
354+ }
355+ PathEvent :: Opened { id } => {
356+ if let Some ( tx) = state. open_path . remove ( id) {
357+ let _ = tx. send ( Ok ( ( ) ) ) ;
358+ }
359+ }
360+ PathEvent :: Abandoned { id, .. } => {
361+ if let Some ( tx) = state. open_path . remove ( id) {
362+ let _ = tx. send ( Err ( PathError :: ValidationFailed ) ) ;
363+ }
364+ }
365+ PathEvent :: Discarded { id, path_stats } => {
366+ state. final_path_stats . insert ( * id, * path_stats) ;
367+ }
368+ PathEvent :: RemoteStatus { .. } => { }
369+ }
370+ broadcast ( & mut state. path_events , event) ;
291371 }
292- NatTraversal ( _ ) => {
293- todo ! ( )
372+ NatTraversal ( event ) => {
373+ broadcast ( & mut state . nat_traversal_updates , event ) ;
294374 }
295375 }
296376 }
@@ -674,6 +754,25 @@ impl Connection {
674754 . close ( error_code, Bytes :: copy_from_slice ( reason) ) ;
675755 }
676756
757+ /// Wait for the TLS handshake to be confirmed.
758+ pub async fn handshake_confirmed ( & self ) -> Result < ( ) , ConnectionError > {
759+ future:: poll_fn ( |cx| {
760+ let mut state = self . 0 . try_state ( ) ?;
761+ if state. handshake_confirmed {
762+ return Poll :: Ready ( Ok ( ( ) ) ) ;
763+ }
764+ if !state
765+ . on_handshake_confirmed
766+ . iter ( )
767+ . any ( |waker| waker. will_wake ( cx. waker ( ) ) )
768+ {
769+ state. on_handshake_confirmed . push_back ( cx. waker ( ) . clone ( ) ) ;
770+ }
771+ Poll :: Pending
772+ } )
773+ . await
774+ }
775+
677776 /// Wait for the connection to be closed for any reason.
678777 pub async fn closed ( & self ) -> ConnectionError {
679778 let worker = self . 0 . state ( ) . worker . take ( ) ;
@@ -691,6 +790,111 @@ impl Connection {
691790 self . 0 . try_state ( ) . err ( )
692791 }
693792
793+ /// Opens an additional path if multipath is negotiated.
794+ pub fn open_path ( & self , addr : SocketAddr , initial_status : PathStatus ) -> OpenPath {
795+ let mut state = self . 0 . state ( ) ;
796+ let addr = match normalize_remote_address ( & state, addr) {
797+ Ok ( addr) => addr,
798+ Err ( err) => return OpenPath :: rejected ( err) ,
799+ } ;
800+ let ( tx, rx) = flume:: bounded ( 1 ) ;
801+ let result = state. conn . open_path (
802+ FourTuple {
803+ remote : addr,
804+ local_ip : None ,
805+ } ,
806+ initial_status,
807+ Instant :: now ( ) ,
808+ ) ;
809+ match result {
810+ Ok ( path_id) => {
811+ state. open_path . insert ( path_id, tx) ;
812+ state. wake ( ) ;
813+ OpenPath :: new ( path_id, rx, self . 0 . clone ( ) )
814+ }
815+ Err ( err) => OpenPath :: rejected ( err) ,
816+ }
817+ }
818+
819+ /// Returns the path handle for an open path.
820+ pub fn path ( & self , id : PathId ) -> Option < Path > {
821+ Path :: new ( & self . 0 , id)
822+ }
823+
824+ /// Subscribe to path events for this connection.
825+ pub fn path_events ( & self ) -> Receiver < PathEvent > {
826+ let ( tx, rx) = unbounded ( ) ;
827+ self . 0 . state ( ) . path_events . push ( tx) ;
828+ rx
829+ }
830+
831+ /// Subscribe to NAT traversal updates for this connection.
832+ pub fn nat_traversal_updates ( & self ) -> Receiver < n0_nat_traversal:: Event > {
833+ let ( tx, rx) = unbounded ( ) ;
834+ self . 0 . state ( ) . nat_traversal_updates . push ( tx) ;
835+ rx
836+ }
837+
838+ /// The latest external address observed by the peer.
839+ pub fn observed_external_addr ( & self ) -> Option < SocketAddr > {
840+ self . 0 . state ( ) . observed_external_addr
841+ }
842+
843+ /// Statistics for a specific path.
844+ pub fn path_stats ( & self , path_id : PathId ) -> Option < PathStats > {
845+ self . 0 . state ( ) . path_stats ( path_id)
846+ }
847+
848+ /// Whether the multipath extension was negotiated for this connection.
849+ pub fn is_multipath_enabled ( & self ) -> bool {
850+ self . 0 . state ( ) . conn . is_multipath_negotiated ( )
851+ }
852+
853+ /// Registers a local address for the NAT traversal extension.
854+ pub fn add_nat_traversal_address (
855+ & self ,
856+ address : SocketAddr ,
857+ ) -> Result < ( ) , n0_nat_traversal:: Error > {
858+ let mut state = self . 0 . state ( ) ;
859+ state. conn . add_nat_traversal_address ( address) ?;
860+ state. wake ( ) ;
861+ Ok ( ( ) )
862+ }
863+
864+ /// Removes a local address from the NAT traversal extension set.
865+ pub fn remove_nat_traversal_address (
866+ & self ,
867+ address : SocketAddr ,
868+ ) -> Result < ( ) , n0_nat_traversal:: Error > {
869+ let mut state = self . 0 . state ( ) ;
870+ state. conn . remove_nat_traversal_address ( address) ?;
871+ state. wake ( ) ;
872+ Ok ( ( ) )
873+ }
874+
875+ /// Returns the local NAT traversal addresses known to this connection.
876+ pub fn get_local_nat_traversal_addresses (
877+ & self ,
878+ ) -> Result < Vec < SocketAddr > , n0_nat_traversal:: Error > {
879+ self . 0 . state ( ) . conn . get_local_nat_traversal_addresses ( )
880+ }
881+
882+ /// Returns the remote NAT traversal addresses known to this connection.
883+ pub fn get_remote_nat_traversal_addresses (
884+ & self ,
885+ ) -> Result < Vec < SocketAddr > , n0_nat_traversal:: Error > {
886+ self . 0 . state ( ) . conn . get_remote_nat_traversal_addresses ( )
887+ }
888+
889+ /// Initiates a NAT traversal round and returns the candidate addresses
890+ /// being probed.
891+ pub fn initiate_nat_traversal_round ( & self ) -> Result < Vec < SocketAddr > , n0_nat_traversal:: Error > {
892+ let mut state = self . 0 . state ( ) ;
893+ let addresses = state. conn . initiate_nat_traversal_round ( Instant :: now ( ) ) ?;
894+ state. wake ( ) ;
895+ Ok ( addresses)
896+ }
897+
694898 fn poll_recv_datagram ( & self , cx : & mut Context ) -> Poll < Result < Bytes , ConnectionError > > {
695899 let mut state = self . 0 . try_state ( ) ?;
696900 if let Some ( bytes) = state. conn . datagrams ( ) . recv ( ) {
0 commit comments