@@ -123,6 +123,11 @@ impl<S: AsyncRead + AsyncWrite + Unpin> MptpStream<S> {
123123 }
124124 }
125125 }
126+
127+ // Periodically cleanup closed subs if the list gets too long
128+ if self . subs . len ( ) > 16 {
129+ self . subs . retain ( |s| !s. closed ) ;
130+ }
126131 }
127132}
128133
@@ -281,7 +286,12 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AsyncRead for MptpStream<S> {
281286 // Future packet, buffer it
282287 // log::trace!("Buffering future PN: {} (expected {})", pn, this.expected_read_pn);
283288 if !this. reorder_buffer . contains_key ( & pn) {
284- this. reorder_buffer . insert ( pn, payload) ;
289+ // Limit reorder buffer size to 1024 packets or ~4MB
290+ if this. reorder_buffer . len ( ) < 1024 {
291+ this. reorder_buffer . insert ( pn, payload) ;
292+ } else {
293+ warn ! ( "Reorder buffer full, dropping PN {}" , pn) ;
294+ }
285295 }
286296 }
287297 }
@@ -345,17 +355,15 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AsyncWrite for MptpStream<S> {
345355 return Poll :: Pending ;
346356 }
347357
348- // Check backpressure (if any buffer is too full)
349- let mut needs_flush = false ;
358+ // Check backpressure (if ALL active buffers are too full)
359+ let mut all_full = true ;
350360 let mut active_subs = 0 ;
351361
352362 for sub in & this. subs {
353363 if !sub. closed {
354364 active_subs += 1 ;
355- if sub. write_buf . len ( ) > 64 * 1024 {
356- needs_flush = true ;
357- // Don't break, check all? Or break is fine.
358- break ;
365+ if sub. write_buf . len ( ) <= 64 * 1024 {
366+ all_full = false ;
359367 }
360368 }
361369 }
@@ -368,17 +376,10 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AsyncWrite for MptpStream<S> {
368376 ) ) ) ;
369377 }
370378
371- if needs_flush {
372- // Try to flush
379+ if all_full {
380+ // Try to flush - maybe it helps?
373381 let _ = Pin :: new ( & mut * this) . poll_flush ( cx) ;
374- }
375-
376- // If still full, return pending
377- // But only check active ones
378- for sub in & this. subs {
379- if !sub. closed && sub. write_buf . len ( ) > 64 * 1024 {
380- return Poll :: Pending ;
381- }
382+ return Poll :: Pending ;
382383 }
383384
384385 let pn = this. next_pn ;
@@ -393,13 +394,30 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AsyncWrite for MptpStream<S> {
393394 frame. encode ( & mut encoded) ;
394395 let encoded_bytes = encoded. freeze ( ) ;
395396
396- // Broadcast to all subs
397- for sub in & mut this. subs {
397+ // Broadcast to all non-full subs
398+ let mut sent_count = 0 ;
399+ for ( i, sub) in this. subs . iter_mut ( ) . enumerate ( ) {
398400 if !sub. closed {
399- sub. write_buf . extend_from_slice ( & encoded_bytes) ;
401+ if sub. write_buf . len ( ) <= 64 * 1024 {
402+ sub. write_buf . extend_from_slice ( & encoded_bytes) ;
403+ sent_count += 1 ;
404+ } else {
405+ warn ! (
406+ "Sub {} (CID={}) is full ({} bytes), skipping PN {}" ,
407+ i,
408+ sub. cid,
409+ sub. write_buf. len( ) ,
410+ pn
411+ ) ;
412+ }
400413 }
401414 }
402415
416+ if sent_count == 0 {
417+ // Should not happen due to all_full check above, but for safety:
418+ return Poll :: Pending ;
419+ }
420+
403421 // Try flush immediately
404422 let _ = Pin :: new ( & mut * this) . poll_flush ( cx) ;
405423
@@ -413,11 +431,14 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AsyncWrite for MptpStream<S> {
413431 this. poll_new_subs ( cx) ;
414432
415433 let mut all_flushed = true ;
434+ let mut any_flushed = false ;
435+ let mut active_subs = 0 ;
416436
417437 for ( i, sub) in this. subs . iter_mut ( ) . enumerate ( ) {
418438 if sub. closed {
419439 continue ;
420440 }
441+ active_subs += 1 ;
421442 while !sub. write_buf . is_empty ( ) {
422443 match Pin :: new ( & mut sub. stream ) . poll_write ( cx, & sub. write_buf ) {
423444 Poll :: Ready ( Ok ( n) ) => {
@@ -438,9 +459,16 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AsyncWrite for MptpStream<S> {
438459 }
439460 }
440461 }
462+ if !sub. closed && sub. write_buf . is_empty ( ) {
463+ any_flushed = true ;
464+ }
441465 }
442466
443- if all_flushed {
467+ if all_flushed || ( any_flushed && active_subs > 0 ) {
468+ // If at least one path is flushed, we consider the overall stream "flushed" enough
469+ // to continue, but we'll keep trying to flush others in future calls.
470+ Poll :: Ready ( Ok ( ( ) ) )
471+ } else if active_subs == 0 {
444472 Poll :: Ready ( Ok ( ( ) ) )
445473 } else {
446474 Poll :: Pending
@@ -464,21 +492,43 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AsyncWrite for MptpStream<S> {
464492 // We just append to write buf. poll_flush will send it.
465493 // But shutdown expects to close *now*.
466494 // However, standard AsyncWrite::poll_shutdown implies "flush pending writes and close".
467- sub. write_buf . extend_from_slice ( & encoded_bytes) ;
495+ if !sub. closed {
496+ sub. write_buf . extend_from_slice ( & encoded_bytes) ;
497+ }
468498 }
469499
470500 // Flush all buffers
471501 let _ = Pin :: new ( & mut * this) . poll_flush ( cx) ;
472502
473503 // Now shutdown underlying streams
474504 let mut all_done = true ;
505+ let mut any_done = false ;
506+ let mut active_subs = 0 ;
507+
475508 for sub in & mut this. subs {
476- if Pin :: new ( & mut sub. stream ) . poll_shutdown ( cx) . is_pending ( ) {
477- all_done = false ;
509+ if sub. closed {
510+ continue ;
511+ }
512+ active_subs += 1 ;
513+ match Pin :: new ( & mut sub. stream ) . poll_shutdown ( cx) {
514+ Poll :: Ready ( Ok ( ( ) ) ) => {
515+ any_done = true ;
516+ }
517+ Poll :: Ready ( Err ( _) ) => {
518+ // Just ignore error and mark closed?
519+ sub. closed = true ;
520+ any_done = true ;
521+ }
522+ Poll :: Pending => {
523+ all_done = false ;
524+ }
478525 }
479526 }
480527
481- if all_done {
528+ if all_done || ( any_done && active_subs > 0 ) {
529+ // If at least one path is shut down, we consider the overall stream "shut down" enough
530+ Poll :: Ready ( Ok ( ( ) ) )
531+ } else if active_subs == 0 {
482532 Poll :: Ready ( Ok ( ( ) ) )
483533 } else {
484534 Poll :: Pending
@@ -568,4 +618,28 @@ mod tests {
568618 let n = s1. read ( & mut buf) . await . unwrap ( ) ;
569619 assert ! ( n > 0 ) ;
570620 }
621+
622+ #[ tokio:: test]
623+ async fn test_resilience_to_stuck_sub ( ) {
624+ let ( c1, mut s1) = tokio:: io:: duplex ( 1024 ) ;
625+ let ( c2, _s2) = tokio:: io:: duplex ( 1024 ) ; // s2 is never read, so c2 will become full
626+ let mut mptp = MptpStream :: new ( vec ! [ c1, c2] , uuid:: Uuid :: new_v4 ( ) ) ;
627+
628+ // Read from s1 in a background task to keep c1 empty
629+ tokio:: spawn ( async move {
630+ let mut buf = [ 0u8 ; 1024 ] ;
631+ while let Ok ( n) = s1. read ( & mut buf) . await {
632+ if n == 0 { break ; }
633+ }
634+ } ) ;
635+
636+ // Write enough that c2 is definitely full but mptp shouldn't hang
637+ for i in 0 ..100 {
638+ let msg = format ! ( "hello {}" , i) ;
639+ mptp. write_all ( msg. as_bytes ( ) ) . await . unwrap ( ) ;
640+ }
641+
642+ // Flush should succeed because c1 is flushed
643+ mptp. flush ( ) . await . unwrap ( ) ;
644+ }
571645}
0 commit comments