Skip to content

Commit fcb8db1

Browse files
committed
feat: implement full noq features
1 parent 04e2261 commit fcb8db1

4 files changed

Lines changed: 465 additions & 17 deletions

File tree

compio-quic/src/connection.rs

Lines changed: 214 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
use std::{
22
collections::VecDeque,
33
fmt::Debug,
4-
net::{IpAddr, SocketAddr},
4+
net::{IpAddr, SocketAddr, SocketAddrV6},
55
pin::{Pin, pin},
66
task::{Context, Poll, Waker},
77
time::{Duration, Instant},
@@ -10,7 +10,7 @@ use std::{
1010
use compio_buf::bytes::Bytes;
1111
use compio_log::Instrument;
1212
use compio_runtime::JoinHandle;
13-
use flume::{Receiver, Sender};
13+
use flume::{Receiver, Sender, unbounded};
1414
use futures_util::{
1515
FutureExt, StreamExt,
1616
future::{self, Fuse, FusedFuture, LocalBoxFuture},
@@ -19,14 +19,15 @@ use futures_util::{
1919
#[cfg(rustls)]
2020
use noq_proto::crypto::rustls::HandshakeData;
2121
use noq_proto::{
22-
ConnectionHandle, ConnectionStats, Dir, EndpointEvent, PathId, Side, StreamEvent, StreamId,
23-
VarInt, congestion::Controller,
22+
ConnectionHandle, ConnectionStats, Dir, EndpointEvent, FourTuple, PathError, PathEvent, PathId,
23+
PathStats, PathStatus, Side, StreamEvent, StreamId, VarInt, congestion::Controller,
24+
n0_nat_traversal,
2425
};
2526
use rustc_hash::FxHashMap as HashMap;
2627
use thiserror::Error;
2728

2829
use crate::{
29-
RecvStream, SendStream, Socket,
30+
OpenPath, Path, RecvStream, SendStream, Socket,
3031
sync::{
3132
mutex_blocking::{Mutex, MutexGuard},
3233
shared::Shared,
@@ -44,14 +45,21 @@ pub(crate) struct ConnectionState {
4445
pub(crate) conn: noq_proto::Connection,
4546
pub(crate) error: Option<ConnectionError>,
4647
connected: bool,
48+
handshake_confirmed: bool,
4749
worker: Option<JoinHandle<()>>,
4850
poller: Option<Waker>,
4951
on_connected: Option<Waker>,
5052
on_handshake_data: Option<Waker>,
53+
on_handshake_confirmed: VecDeque<Waker>,
5154
datagram_received: VecDeque<Waker>,
5255
datagrams_unblocked: VecDeque<Waker>,
5356
stream_opened: [VecDeque<Waker>; 2],
5457
stream_available: [VecDeque<Waker>; 2],
58+
open_path: HashMap<PathId, Sender<Result<(), PathError>>>,
59+
path_events: Vec<Sender<PathEvent>>,
60+
observed_external_addr: Option<SocketAddr>,
61+
nat_traversal_updates: Vec<Sender<n0_nat_traversal::Event>>,
62+
final_path_stats: HashMap<PathId, PathStats>,
5563
pub(crate) writable: HashMap<StreamId, Waker>,
5664
pub(crate) readable: HashMap<StreamId, Waker>,
5765
pub(crate) stopped: HashMap<StreamId, Waker>,
@@ -68,6 +76,7 @@ impl ConnectionState {
6876
if let Some(waker) = self.on_connected.take() {
6977
waker.wake()
7078
}
79+
self.on_handshake_confirmed.drain(..).for_each(Waker::wake);
7180
self.datagram_received.drain(..).for_each(Waker::wake);
7281
self.datagrams_unblocked.drain(..).for_each(Waker::wake);
7382
for e in &mut self.stream_opened {
@@ -76,6 +85,9 @@ impl ConnectionState {
7685
for e in &mut self.stream_available {
7786
e.drain(..).for_each(Waker::wake);
7887
}
88+
for tx in self.open_path.drain().map(|(_, tx)| tx) {
89+
let _ = tx.send(Err(PathError::ValidationFailed));
90+
}
7991
wake_all_streams(&mut self.writable);
8092
wake_all_streams(&mut self.readable);
8193
wake_all_streams(&mut self.stopped);
@@ -104,6 +116,12 @@ impl ConnectionState {
104116
pub(crate) fn check_0rtt(&self) -> bool {
105117
self.conn.side().is_server() || self.conn.is_handshaking() || self.conn.accepted_0rtt()
106118
}
119+
120+
pub(crate) fn path_stats(&mut self, path_id: PathId) -> Option<PathStats> {
121+
self.conn
122+
.path_stats(path_id)
123+
.or_else(|| self.final_path_stats.get(&path_id).copied())
124+
}
107125
}
108126

109127
fn wake_stream(stream: StreamId, wakers: &mut HashMap<StreamId, Waker>) {
@@ -116,6 +134,41 @@ fn wake_all_streams(wakers: &mut HashMap<StreamId, Waker>) {
116134
wakers.drain().for_each(|(_, waker)| waker.wake())
117135
}
118136

137+
fn wake_waiters(wakers: &mut VecDeque<Waker>) {
138+
wakers.drain(..).for_each(Waker::wake)
139+
}
140+
141+
fn broadcast<T: Clone>(listeners: &mut Vec<Sender<T>>, event: T) {
142+
listeners.retain(|tx| tx.send(event.clone()).is_ok());
143+
}
144+
145+
fn normalize_remote_address(
146+
state: &ConnectionState,
147+
addr: SocketAddr,
148+
) -> Result<SocketAddr, PathError> {
149+
let ipv6 = state
150+
.conn
151+
.paths()
152+
.iter()
153+
.filter_map(|id| state.conn.network_path(*id).ok())
154+
.map(|path| path.remote.is_ipv6())
155+
.next()
156+
.unwrap_or_default();
157+
if addr.is_ipv6() && !ipv6 {
158+
return Err(PathError::InvalidRemoteAddress(addr));
159+
}
160+
Ok(if ipv6 {
161+
SocketAddr::V6(match addr {
162+
SocketAddr::V4(addr) => {
163+
SocketAddrV6::new(addr.ip().to_ipv6_mapped(), addr.port(), 0, 0)
164+
}
165+
SocketAddr::V6(addr) => addr,
166+
})
167+
} else {
168+
addr
169+
})
170+
}
171+
119172
#[derive(Debug)]
120173
pub(crate) struct ConnectionInner {
121174
state: Mutex<ConnectionState>,
@@ -143,15 +196,22 @@ impl ConnectionInner {
143196
state: Mutex::new(ConnectionState {
144197
conn,
145198
connected: false,
199+
handshake_confirmed: false,
146200
error: None,
147201
worker: None,
148202
poller: None,
149203
on_connected: None,
150204
on_handshake_data: None,
205+
on_handshake_confirmed: VecDeque::new(),
151206
datagram_received: VecDeque::new(),
152207
datagrams_unblocked: VecDeque::new(),
153208
stream_opened: [VecDeque::new(), VecDeque::new()],
154209
stream_available: [VecDeque::new(), VecDeque::new()],
210+
open_path: HashMap::default(),
211+
path_events: Vec::new(),
212+
observed_external_addr: None,
213+
nat_traversal_updates: Vec::new(),
214+
final_path_stats: HashMap::default(),
155215
writable: HashMap::default(),
156216
readable: HashMap::default(),
157217
stopped: HashMap::default(),
@@ -284,13 +344,33 @@ impl ConnectionInner {
284344
DatagramsUnblocked => state.datagrams_unblocked.drain(..).for_each(Waker::wake),
285345

286346
HandshakeConfirmed => {
287-
todo!()
347+
state.handshake_confirmed = true;
348+
wake_waiters(&mut state.on_handshake_confirmed);
288349
}
289-
Path(_) => {
290-
todo!()
350+
Path(event) => {
351+
match &event {
352+
PathEvent::ObservedAddr { addr, .. } => {
353+
state.observed_external_addr = Some(*addr);
354+
}
355+
PathEvent::Opened { id } => {
356+
if let Some(tx) = state.open_path.remove(id) {
357+
let _ = tx.send(Ok(()));
358+
}
359+
}
360+
PathEvent::Abandoned { id, .. } => {
361+
if let Some(tx) = state.open_path.remove(id) {
362+
let _ = tx.send(Err(PathError::ValidationFailed));
363+
}
364+
}
365+
PathEvent::Discarded { id, path_stats } => {
366+
state.final_path_stats.insert(*id, *path_stats);
367+
}
368+
PathEvent::RemoteStatus { .. } => {}
369+
}
370+
broadcast(&mut state.path_events, event);
291371
}
292-
NatTraversal(_) => {
293-
todo!()
372+
NatTraversal(event) => {
373+
broadcast(&mut state.nat_traversal_updates, event);
294374
}
295375
}
296376
}
@@ -674,6 +754,25 @@ impl Connection {
674754
.close(error_code, Bytes::copy_from_slice(reason));
675755
}
676756

757+
/// Wait for the TLS handshake to be confirmed.
758+
pub async fn handshake_confirmed(&self) -> Result<(), ConnectionError> {
759+
future::poll_fn(|cx| {
760+
let mut state = self.0.try_state()?;
761+
if state.handshake_confirmed {
762+
return Poll::Ready(Ok(()));
763+
}
764+
if !state
765+
.on_handshake_confirmed
766+
.iter()
767+
.any(|waker| waker.will_wake(cx.waker()))
768+
{
769+
state.on_handshake_confirmed.push_back(cx.waker().clone());
770+
}
771+
Poll::Pending
772+
})
773+
.await
774+
}
775+
677776
/// Wait for the connection to be closed for any reason.
678777
pub async fn closed(&self) -> ConnectionError {
679778
let worker = self.0.state().worker.take();
@@ -691,6 +790,111 @@ impl Connection {
691790
self.0.try_state().err()
692791
}
693792

793+
/// Opens an additional path if multipath is negotiated.
794+
pub fn open_path(&self, addr: SocketAddr, initial_status: PathStatus) -> OpenPath {
795+
let mut state = self.0.state();
796+
let addr = match normalize_remote_address(&state, addr) {
797+
Ok(addr) => addr,
798+
Err(err) => return OpenPath::rejected(err),
799+
};
800+
let (tx, rx) = flume::bounded(1);
801+
let result = state.conn.open_path(
802+
FourTuple {
803+
remote: addr,
804+
local_ip: None,
805+
},
806+
initial_status,
807+
Instant::now(),
808+
);
809+
match result {
810+
Ok(path_id) => {
811+
state.open_path.insert(path_id, tx);
812+
state.wake();
813+
OpenPath::new(path_id, rx, self.0.clone())
814+
}
815+
Err(err) => OpenPath::rejected(err),
816+
}
817+
}
818+
819+
/// Returns the path handle for an open path.
820+
pub fn path(&self, id: PathId) -> Option<Path> {
821+
Path::new(&self.0, id)
822+
}
823+
824+
/// Subscribe to path events for this connection.
825+
pub fn path_events(&self) -> Receiver<PathEvent> {
826+
let (tx, rx) = unbounded();
827+
self.0.state().path_events.push(tx);
828+
rx
829+
}
830+
831+
/// Subscribe to NAT traversal updates for this connection.
832+
pub fn nat_traversal_updates(&self) -> Receiver<n0_nat_traversal::Event> {
833+
let (tx, rx) = unbounded();
834+
self.0.state().nat_traversal_updates.push(tx);
835+
rx
836+
}
837+
838+
/// The latest external address observed by the peer.
839+
pub fn observed_external_addr(&self) -> Option<SocketAddr> {
840+
self.0.state().observed_external_addr
841+
}
842+
843+
/// Statistics for a specific path.
844+
pub fn path_stats(&self, path_id: PathId) -> Option<PathStats> {
845+
self.0.state().path_stats(path_id)
846+
}
847+
848+
/// Whether the multipath extension was negotiated for this connection.
849+
pub fn is_multipath_enabled(&self) -> bool {
850+
self.0.state().conn.is_multipath_negotiated()
851+
}
852+
853+
/// Registers a local address for the NAT traversal extension.
854+
pub fn add_nat_traversal_address(
855+
&self,
856+
address: SocketAddr,
857+
) -> Result<(), n0_nat_traversal::Error> {
858+
let mut state = self.0.state();
859+
state.conn.add_nat_traversal_address(address)?;
860+
state.wake();
861+
Ok(())
862+
}
863+
864+
/// Removes a local address from the NAT traversal extension set.
865+
pub fn remove_nat_traversal_address(
866+
&self,
867+
address: SocketAddr,
868+
) -> Result<(), n0_nat_traversal::Error> {
869+
let mut state = self.0.state();
870+
state.conn.remove_nat_traversal_address(address)?;
871+
state.wake();
872+
Ok(())
873+
}
874+
875+
/// Returns the local NAT traversal addresses known to this connection.
876+
pub fn get_local_nat_traversal_addresses(
877+
&self,
878+
) -> Result<Vec<SocketAddr>, n0_nat_traversal::Error> {
879+
self.0.state().conn.get_local_nat_traversal_addresses()
880+
}
881+
882+
/// Returns the remote NAT traversal addresses known to this connection.
883+
pub fn get_remote_nat_traversal_addresses(
884+
&self,
885+
) -> Result<Vec<SocketAddr>, n0_nat_traversal::Error> {
886+
self.0.state().conn.get_remote_nat_traversal_addresses()
887+
}
888+
889+
/// Initiates a NAT traversal round and returns the candidate addresses
890+
/// being probed.
891+
pub fn initiate_nat_traversal_round(&self) -> Result<Vec<SocketAddr>, n0_nat_traversal::Error> {
892+
let mut state = self.0.state();
893+
let addresses = state.conn.initiate_nat_traversal_round(Instant::now())?;
894+
state.wake();
895+
Ok(addresses)
896+
}
897+
694898
fn poll_recv_datagram(&self, cx: &mut Context) -> Poll<Result<Bytes, ConnectionError>> {
695899
let mut state = self.0.try_state()?;
696900
if let Some(bytes) = state.conn.datagrams().recv() {

compio-quic/src/lib.rs

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,14 @@
1414
)]
1515

1616
pub use noq_proto::{
17-
AckFrequencyConfig, ApplicationClose, Chunk, ClientConfig, ClosedStream, ConfigError,
18-
ConnectError, ConnectionClose, ConnectionId, ConnectionIdGenerator, ConnectionStats, Dir,
19-
EcnCodepoint, EndpointConfig, FrameStats, FrameType, IdleTimeout, MtuDiscoveryConfig,
20-
NoneTokenLog, NoneTokenStore, PathStats, ServerConfig, Side, StdSystemTime, StreamId,
21-
TimeSource, TokenLog, TokenMemoryCache, TokenReuseError, TokenStore, Transmit, TransportConfig,
22-
TransportErrorCode, UdpStats, ValidationTokenConfig, VarInt, VarIntBoundsExceeded, Written,
23-
congestion, crypto,
17+
AckFrequencyConfig, ApplicationClose, Chunk, ClientConfig, ClosePathError, ClosedPath,
18+
ClosedStream, ConfigError, ConnectError, ConnectionClose, ConnectionId, ConnectionIdGenerator,
19+
ConnectionStats, Dir, EcnCodepoint, EndpointConfig, FrameStats, FrameType, IdleTimeout,
20+
MtuDiscoveryConfig, MultipathNotNegotiated, NoneTokenLog, NoneTokenStore, PathAbandonReason,
21+
PathError, PathEvent, PathId, PathStats, PathStatus, ServerConfig, SetPathStatusError, Side,
22+
StdSystemTime, StreamId, TimeSource, TokenLog, TokenMemoryCache, TokenReuseError, TokenStore,
23+
Transmit, TransportConfig, TransportErrorCode, UdpStats, ValidationTokenConfig, VarInt,
24+
VarIntBoundsExceeded, Written, congestion, crypto, n0_nat_traversal,
2425
};
2526
#[cfg(feature = "qlog")]
2627
pub use noq_proto::{QlogConfig, QlogStream};
@@ -30,6 +31,7 @@ mod builder;
3031
mod connection;
3132
mod endpoint;
3233
mod incoming;
34+
mod path;
3335
mod recv_stream;
3436
mod send_stream;
3537
mod socket;
@@ -39,6 +41,7 @@ pub use builder::{ClientBuilder, ServerBuilder};
3941
pub use connection::{Connecting, Connection, ConnectionError, OpenStreamError, SendDatagramError};
4042
pub use endpoint::{Endpoint, EndpointStats};
4143
pub use incoming::{Incoming, IncomingFuture};
44+
pub use path::{OpenPath, Path};
4245
pub use recv_stream::{ReadError, ReadExactError, RecvStream, ResetError};
4346
pub use send_stream::{SendStream, StoppedError, WriteError};
4447
#[cfg(feature = "sync")]

0 commit comments

Comments
 (0)