Skip to content

Commit 99800ab

Browse files
committed
transport: Make accept async to close the gap on service races
Signed-off-by: Alexandru Vasile <[email protected]>
1 parent 7e048bf commit 99800ab

File tree

11 files changed

+208
-83
lines changed

11 files changed

+208
-83
lines changed

src/transport/dummy.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ use crate::{
2525
types::ConnectionId,
2626
};
2727

28-
use futures::Stream;
28+
use futures::{future::BoxFuture, Stream};
2929
use multiaddr::Multiaddr;
3030

3131
use std::{
@@ -71,8 +71,8 @@ impl Transport for DummyTransport {
7171
Ok(())
7272
}
7373

74-
fn accept(&mut self, _: ConnectionId) -> crate::Result<()> {
75-
Ok(())
74+
fn accept(&mut self, _: ConnectionId) -> crate::Result<BoxFuture<'static, crate::Result<()>>> {
75+
Ok(Box::pin(async { Ok(()) }))
7676
}
7777

7878
fn accept_pending(&mut self, _connection_id: ConnectionId) -> crate::Result<()> {

src/transport/manager/mod.rs

Lines changed: 54 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ use crate::{
3939
};
4040

4141
use address::{scores, AddressStore};
42-
use futures::{Stream, StreamExt};
42+
use futures::{future::BoxFuture, Stream, StreamExt};
4343
use indexmap::IndexMap;
4444
use multiaddr::{Multiaddr, Protocol};
4545
use multihash::Multihash;
@@ -252,6 +252,11 @@ pub struct TransportManager {
252252

253253
/// Opening connections errors.
254254
opening_errors: HashMap<ConnectionId, Vec<(Multiaddr, DialError)>>,
255+
256+
/// Pending accept future with associated connection information.
257+
/// When a connection is accepted, we must wait for the accept future to complete
258+
/// (which notifies all protocols) before emitting the ConnectionEstablished event.
259+
pending_accept: Option<(PeerId, Endpoint, BoxFuture<'static, crate::Result<()>>)>,
255260
}
256261

257262
/// Builder for [`crate::transport::manager::TransportManager`].
@@ -365,6 +370,7 @@ impl TransportManagerBuilder {
365370
pending_connections: HashMap::new(),
366371
connection_limits: limits::ConnectionLimits::new(self.connection_limits_config),
367372
opening_errors: HashMap::new(),
373+
pending_accept: None,
368374
}
369375
}
370376
}
@@ -1090,6 +1096,35 @@ impl TransportManager {
10901096
/// Poll next event from [`crate::transport::manager::TransportManager`].
10911097
pub async fn next(&mut self) -> Option<TransportEvent> {
10921098
loop {
1099+
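// First, drive any pending accept future to completion before polling for new events.
// NOTE(review): the future is moved out of `pending_accept` via `take()` before it is
// awaited, so if this `next()` future is dropped mid-await the accept future is dropped
// with it and no ConnectionEstablished event is emitted — confirm callers never cancel `next()`.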
1100+
if let Some((peer, endpoint, mut future)) = self.pending_accept.take() {
1101+
match future.as_mut().await {
1102+
Ok(()) => {
1103+
tracing::trace!(
1104+
target: LOG_TARGET,
1105+
?peer,
1106+
?endpoint,
1107+
"connection accepted and protocols notified",
1108+
);
1109+
1110+
return Some(TransportEvent::ConnectionEstablished {
1111+
peer,
1112+
endpoint,
1113+
});
1114+
}
1115+
Err(error) => {
1116+
tracing::debug!(
1117+
target: LOG_TARGET,
1118+
?peer,
1119+
?endpoint,
1120+
?error,
1121+
"failed to notify protocols about connection",
1122+
);
1123+
// If notification failed, we don't emit the ConnectionEstablished event and instead
// fall through to poll for the next event below.
1124+
}
1125+
}
1126+
}
1127+
10931128
tokio::select! {
10941129
event = self.event_rx.recv() => {
10951130
let Some(event) = event else {
@@ -1270,16 +1305,27 @@ impl TransportManager {
12701305
"accept connection",
12711306
);
12721307

1273-
let _ = self
1308+
match self
12741309
.transports
12751310
.get_mut(&transport)
12761311
.expect("transport to exist")
1277-
.accept(endpoint.connection_id());
1278-
1279-
return Some(TransportEvent::ConnectionEstablished {
1280-
peer,
1281-
endpoint,
1282-
});
1312+
.accept(endpoint.connection_id())
1313+
{
1314+
Ok(future) => {
1315+
// Store the accept future to be polled in the next iteration
1316+
// This ensures protocols are notified before we emit ConnectionEstablished
1317+
self.pending_accept = Some((peer, endpoint, future));
1318+
}
1319+
Err(error) => {
1320+
tracing::debug!(
1321+
target: LOG_TARGET,
1322+
?peer,
1323+
?endpoint,
1324+
?error,
1325+
"failed to accept connection",
1326+
);
1327+
}
1328+
}
12831329
}
12841330
Ok(ConnectionEstablishedResult::Reject) => {
12851331
tracing::trace!(

src/transport/mod.rs

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
2323
use crate::{error::DialError, transport::manager::TransportHandle, types::ConnectionId, PeerId};
2424

25-
use futures::Stream;
25+
use futures::{future::BoxFuture, Stream};
2626
use hickory_resolver::TokioResolver;
2727
use multiaddr::Multiaddr;
2828

@@ -194,7 +194,12 @@ pub(crate) trait Transport: Stream + Unpin + Send {
194194
fn dial(&mut self, connection_id: ConnectionId, address: Multiaddr) -> crate::Result<()>;
195195

196196
/// Accept negotiated connection.
197-
fn accept(&mut self, connection_id: ConnectionId) -> crate::Result<()>;
197+
///
198+
/// Returns a future that completes when the connection has been fully established
199+
/// and all installed protocols have been notified via their event channels.
200+
/// This ensures that by the time the caller receives a ConnectionEstablished event,
201+
/// protocols are ready to handle substream operations.
202+
fn accept(&mut self, connection_id: ConnectionId) -> crate::Result<BoxFuture<'static, crate::Result<()>>>;
198203

199204
/// Accept pending connection.
200205
fn accept_pending(&mut self, connection_id: ConnectionId) -> crate::Result<()>;

src/transport/quic/connection.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -237,6 +237,12 @@ impl QuicConnection {
237237
.report_connection_established(self.peer, self.endpoint.clone())
238238
.await?;
239239

240+
self.start_event_loop().await
241+
}
242+
243+
/// Start the connection event loop without notifying protocols.
244+
/// This is used when protocols have already been notified during accept().
245+
pub(crate) async fn start_event_loop(mut self) -> crate::Result<()> {
240246
loop {
241247
tokio::select! {
242248
event = self.connection.accept_bi() => match event {

src/transport/quic/mod.rs

Lines changed: 27 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -306,35 +306,47 @@ impl Transport for QuicTransport {
306306
Ok(())
307307
}
308308

309-
fn accept(&mut self, connection_id: ConnectionId) -> crate::Result<()> {
309+
fn accept(&mut self, connection_id: ConnectionId) -> crate::Result<BoxFuture<'static, crate::Result<()>>> {
310310
let (connection, endpoint) = self
311311
.pending_open
312312
.remove(&connection_id)
313313
.ok_or(Error::ConnectionDoesntExist(connection_id))?;
314314
let bandwidth_sink = self.context.bandwidth_sink.clone();
315-
let protocol_set = self.context.protocol_set(connection_id);
315+
let mut protocol_set = self.context.protocol_set(connection_id);
316316
let substream_open_timeout = self.config.substream_open_timeout;
317+
let executor = self.context.executor.clone();
317318

318319
tracing::trace!(
319320
target: LOG_TARGET,
320321
?connection_id,
321322
"start connection",
322323
);
323324

324-
self.context.executor.run(Box::pin(async move {
325-
let _ = QuicConnection::new(
326-
connection.peer,
327-
endpoint,
328-
connection.connection,
329-
protocol_set,
330-
bandwidth_sink,
331-
substream_open_timeout,
332-
)
333-
.start()
334-
.await;
335-
}));
325+
let peer = connection.peer;
326+
let endpoint_clone = endpoint.clone();
327+
328+
Ok(Box::pin(async move {
329+
// First, notify all protocols about the connection establishment
330+
protocol_set
331+
.report_connection_established(peer, endpoint_clone)
332+
.await?;
333+
334+
// After protocols are notified, spawn the connection event loop
335+
executor.run(Box::pin(async move {
336+
let _ = QuicConnection::new(
337+
peer,
338+
endpoint,
339+
connection.connection,
340+
protocol_set,
341+
bandwidth_sink,
342+
substream_open_timeout,
343+
)
344+
.start_event_loop()
345+
.await;
346+
}));
336347

337-
Ok(())
348+
Ok(())
349+
}))
338350
}
339351

340352
fn reject(&mut self, connection_id: ConnectionId) -> crate::Result<()> {

src/transport/tcp/connection.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -736,6 +736,12 @@ impl TcpConnection {
736736
.report_connection_established(self.peer, self.endpoint.clone())
737737
.await?;
738738

739+
self.start_event_loop().await
740+
}
741+
742+
/// Start the connection event loop without notifying protocols.
743+
/// This is used when protocols have already been notified during accept().
744+
pub(crate) async fn start_event_loop(mut self) -> crate::Result<()> {
739745
loop {
740746
tokio::select! {
741747
substream = self.connection.next() => {

src/transport/tcp/mod.rs

Lines changed: 30 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -389,37 +389,50 @@ impl Transport for TcpTransport {
389389
)
390390
}
391391

392-
fn accept(&mut self, connection_id: ConnectionId) -> crate::Result<()> {
392+
fn accept(&mut self, connection_id: ConnectionId) -> crate::Result<BoxFuture<'static, crate::Result<()>>> {
393393
let context = self
394394
.pending_open
395395
.remove(&connection_id)
396396
.ok_or(Error::ConnectionDoesntExist(connection_id))?;
397-
let protocol_set = self.context.protocol_set(connection_id);
397+
let mut protocol_set = self.context.protocol_set(connection_id);
398398
let bandwidth_sink = self.context.bandwidth_sink.clone();
399399
let next_substream_id = self.context.next_substream_id.clone();
400+
let executor = self.context.executor.clone();
400401

401402
tracing::trace!(
402403
target: LOG_TARGET,
403404
?connection_id,
404405
"start connection",
405406
);
406407

407-
self.context.executor.run(Box::pin(async move {
408-
if let Err(error) =
409-
TcpConnection::new(context, protocol_set, bandwidth_sink, next_substream_id)
410-
.start()
411-
.await
412-
{
413-
tracing::debug!(
414-
target: LOG_TARGET,
415-
?connection_id,
416-
?error,
417-
"connection exited with error",
418-
);
419-
}
420-
}));
408+
let peer = context.peer();
409+
let endpoint = context.endpoint().clone();
410+
411+
Ok(Box::pin(async move {
412+
// First, notify all protocols about the connection establishment
413+
// This ensures that when the accept() future completes, protocols are ready
414+
protocol_set
415+
.report_connection_established(peer, endpoint)
416+
.await?;
417+
418+
// After protocols are notified, spawn the connection event loop
419+
executor.run(Box::pin(async move {
420+
if let Err(error) =
421+
TcpConnection::new(context, protocol_set, bandwidth_sink, next_substream_id)
422+
.start_event_loop()
423+
.await
424+
{
425+
tracing::debug!(
426+
target: LOG_TARGET,
427+
?connection_id,
428+
?error,
429+
"connection exited with error",
430+
);
431+
}
432+
}));
421433

422-
Ok(())
434+
Ok(())
435+
}))
423436
}
424437

425438
fn reject(&mut self, connection_id: ConnectionId) -> crate::Result<()> {

src/transport/webrtc/connection.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -680,6 +680,12 @@ impl WebRtcConnection {
680680
.report_connection_established(self.peer, self.endpoint.clone())
681681
.await;
682682

683+
self.run_event_loop().await;
684+
}
685+
686+
/// Start the connection event loop without notifying protocols.
687+
/// This is used when protocols have already been notified during accept().
688+
pub async fn run_event_loop(mut self) {
683689
loop {
684690
// poll output until we get a timeout
685691
let timeout = match self.rtc.poll_output().unwrap() {

src/transport/webrtc/mod.rs

Lines changed: 29 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -532,7 +532,7 @@ impl Transport for WebRtcTransport {
532532
))
533533
}
534534

535-
fn accept(&mut self, connection_id: ConnectionId) -> crate::Result<()> {
535+
fn accept(&mut self, connection_id: ConnectionId) -> crate::Result<BoxFuture<'static, crate::Result<()>>> {
536536
tracing::trace!(
537537
target: LOG_TARGET,
538538
?connection_id,
@@ -562,19 +562,13 @@ impl Transport for WebRtcTransport {
562562

563563
let rtc = connection.on_accept()?;
564564
let (tx, rx) = channel(self.datagram_buffer_size);
565-
let protocol_set = self.context.protocol_set(connection_id);
565+
let mut protocol_set = self.context.protocol_set(connection_id);
566566
let connection_id = endpoint.connection_id();
567+
let endpoint_clone = endpoint.clone();
568+
let executor = self.context.executor.clone();
569+
let socket = Arc::clone(&self.socket);
570+
let listen_address = self.listen_address;
567571

568-
let connection = WebRtcConnection::new(
569-
rtc,
570-
peer,
571-
source,
572-
self.listen_address,
573-
Arc::clone(&self.socket),
574-
protocol_set,
575-
endpoint,
576-
rx,
577-
);
578572
self.open.insert(
579573
source,
580574
ConnectionContext {
@@ -584,11 +578,30 @@ impl Transport for WebRtcTransport {
584578
},
585579
);
586580

587-
self.context.executor.run(Box::pin(async move {
588-
connection.run().await;
589-
}));
581+
Ok(Box::pin(async move {
582+
// First, notify all protocols about the connection establishment
583+
protocol_set
584+
.report_connection_established(peer, endpoint_clone)
585+
.await?;
590586

591-
Ok(())
587+
// After protocols are notified, create connection and spawn event loop
588+
let connection = WebRtcConnection::new(
589+
rtc,
590+
peer,
591+
source,
592+
listen_address,
593+
socket,
594+
protocol_set,
595+
endpoint,
596+
rx,
597+
);
598+
599+
executor.run(Box::pin(async move {
600+
connection.run_event_loop().await;
601+
}));
602+
603+
Ok(())
604+
}))
592605
}
593606

594607
fn reject(&mut self, connection_id: ConnectionId) -> crate::Result<()> {

src/transport/websocket/connection.rs

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -456,9 +456,15 @@ impl WebSocketConnection {
456456
/// Notify protocols that the connection was established, then run the connection event loop.
457457
pub(crate) async fn start(mut self) -> crate::Result<()> {
458458
self.protocol_set
459-
.report_connection_established(self.peer, self.endpoint)
459+
.report_connection_established(self.peer, self.endpoint.clone())
460460
.await?;
461461

462+
self.start_event_loop().await
463+
}
464+
465+
/// Start the connection event loop without notifying protocols.
466+
/// This is used when protocols have already been notified during accept().
467+
pub(crate) async fn start_event_loop(mut self) -> crate::Result<()> {
462468
loop {
463469
tokio::select! {
464470
substream = self.connection.next() => match substream {

0 commit comments

Comments
 (0)