@@ -12,7 +12,7 @@ use std::{
12
12
use bytes:: { Bytes , BytesMut } ;
13
13
use serde:: { Deserialize , Serialize } ;
14
14
15
- use futures:: { ready, stream, Sink , Stream } ;
15
+ use futures:: { ready, stream, Sink , Stream , SinkExt , sink :: Buffer } ;
16
16
17
17
use async_recursion:: async_recursion;
18
18
use async_trait:: async_trait;
@@ -395,9 +395,12 @@ impl ConnectedSink for ConnectedBidi {
395
395
}
396
396
}
397
397
398
- pub struct ConnectedDemux < T : ConnectedSink > {
398
+ pub type BufferedDrain < S , I > = DemuxDrain < I , Buffer < S , I > > ;
399
+
400
+ pub struct ConnectedDemux < T : ConnectedSink >
401
+ where <T as ConnectedSink >:: Input : Sync {
399
402
pub keys : Vec < u32 > ,
400
- sink : Option < DemuxDrain < T :: Input , T :: Sink > > ,
403
+ sink : Option < BufferedDrain < T :: Sink , T :: Input > > ,
401
404
}
402
405
403
406
#[ pin_project]
@@ -460,7 +463,7 @@ where
460
463
for ( id, pipe) in demux {
461
464
connected_demux. insert (
462
465
id,
463
- Box :: pin ( T :: from_defn ( ServerOrBound :: Server ( pipe) ) . await . into_sink ( ) ) ,
466
+ Box :: pin ( T :: from_defn ( ServerOrBound :: Server ( pipe) ) . await . into_sink ( ) . buffer ( 1024 ) ) ,
464
467
) ;
465
468
}
466
469
@@ -481,7 +484,7 @@ where
481
484
for ( id, bound) in demux {
482
485
connected_demux. insert (
483
486
id,
484
- Box :: pin ( T :: from_defn ( ServerOrBound :: Bound ( bound) ) . await . into_sink ( ) ) ,
487
+ Box :: pin ( T :: from_defn ( ServerOrBound :: Bound ( bound) ) . await . into_sink ( ) . buffer ( 1024 ) ) ,
485
488
) ;
486
489
}
487
490
@@ -505,7 +508,7 @@ where
505
508
<T as ConnectedSink >:: Input : ' static + Sync ,
506
509
{
507
510
type Input = ( u32 , T :: Input ) ;
508
- type Sink = DemuxDrain < T :: Input , T :: Sink > ;
511
+ type Sink = BufferedDrain < T :: Sink , T :: Input > ;
509
512
510
513
fn into_sink ( mut self ) -> Self :: Sink {
511
514
self . sink . take ( ) . unwrap ( )
0 commit comments