33use bytes:: Buf ;
44use std:: io;
55use std:: sync:: Arc ;
6- use std:: sync:: atomic:: { AtomicU64 , Ordering } ;
6+ use std:: sync:: atomic:: { AtomicBool , AtomicU64 , Ordering } ;
77use std:: time:: Duration ;
88use tokio:: io:: { AsyncReadExt , AsyncWriteExt } ;
99use tokio:: net:: TcpStream ;
@@ -63,6 +63,8 @@ pub struct SmbClient {
6363 client_guid : [ u8 ; 16 ] ,
6464 /// SMB 3.1.1 signing key (derived after auth)
6565 signing_key : Option < [ u8 ; 16 ] > ,
66+ /// Set on read timeout — connection framing is desynchronized.
67+ poisoned : AtomicBool ,
6668}
6769
6870impl SmbClient {
@@ -129,32 +131,43 @@ impl SmbClient {
129131 compound_max_write_size : 65536 ,
130132 client_guid,
131133 signing_key : None ,
134+ poisoned : AtomicBool :: new ( false ) ,
132135 } ;
133136
134137 client. negotiate_and_auth ( ) . await ?;
135138 Ok ( Arc :: new ( client) )
136139 }
137140
141+ /// Whether this connection has been poisoned by a timeout.
142+ pub fn is_poisoned ( & self ) -> bool {
143+ self . poisoned . load ( Ordering :: Relaxed )
144+ }
145+
138146 fn next_message_id ( & self ) -> u64 {
139147 self . message_id . fetch_add ( 1 , Ordering :: Relaxed )
140148 }
141149
142150 /// Read exactly `buf.len()` bytes from the stream with a timeout.
143- /// Returns `TimedOut` if the SMB server doesn't respond within the deadline.
144151 ///
145- /// A timeout may leave the stream mid-frame, so we shut it down to prevent
146- /// desynchronized reuse.
147- async fn read_exact_timeout (
148- stream : & mut TcpStream ,
149- buf : & mut [ u8 ] ,
150- ) -> io:: Result < ( ) > {
152+ /// On timeout the stream framing is desynchronized, so we poison the
153+ /// connection (all future operations fail fast) and drop the underlying
154+ /// socket to fully close both halves.
155+ async fn read_exact_timeout ( & self , stream : & mut TcpStream , buf : & mut [ u8 ] ) -> io:: Result < ( ) > {
156+ if self . poisoned . load ( Ordering :: Relaxed ) {
157+ return Err ( io:: Error :: new (
158+ io:: ErrorKind :: BrokenPipe ,
159+ "SMB connection poisoned by previous timeout" ,
160+ ) ) ;
161+ }
151162 match tokio:: time:: timeout ( SMB_READ_TIMEOUT , stream. read_exact ( buf) ) . await {
152163 Ok ( result) => result. map ( |_| ( ) ) ,
153164 Err ( _) => {
165+ self . poisoned . store ( true , Ordering :: Relaxed ) ;
166+ // Drop the socket to fully close both halves.
154167 let _ = stream. shutdown ( ) . await ;
155168 Err ( io:: Error :: new (
156169 io:: ErrorKind :: TimedOut ,
157- "SMB server read timed out; connection closed " ,
170+ "SMB server read timed out; connection poisoned " ,
158171 ) )
159172 }
160173 }
@@ -188,7 +201,7 @@ impl SmbClient {
188201 // Read responses, looping past STATUS_PENDING interim responses
189202 loop {
190203 let mut len_buf = [ 0u8 ; 4 ] ;
191- Self :: read_exact_timeout ( & mut stream, & mut len_buf) . await ?;
204+ self . read_exact_timeout ( & mut stream, & mut len_buf) . await ?;
192205 let msg_len = u32:: from_be_bytes ( len_buf) as usize ;
193206
194207 if !( SMB2_HEADER_SIZE ..=16 * 1024 * 1024 ) . contains ( & msg_len) {
@@ -200,7 +213,7 @@ impl SmbClient {
200213 }
201214
202215 let mut msg = vec ! [ 0u8 ; msg_len] ;
203- Self :: read_exact_timeout ( & mut stream, & mut msg) . await ?;
216+ self . read_exact_timeout ( & mut stream, & mut msg) . await ?;
204217
205218 let header = Header :: decode ( & msg) . ok_or_else ( || {
206219 crate :: serr!( "[spiceio] smb invalid header" ) ;
@@ -532,7 +545,7 @@ impl SmbClient {
532545
533546 while received < count {
534547 let mut len_buf = [ 0u8 ; 4 ] ;
535- Self :: read_exact_timeout ( & mut stream, & mut len_buf) . await ?;
548+ self . read_exact_timeout ( & mut stream, & mut len_buf) . await ?;
536549 let msg_len = u32:: from_be_bytes ( len_buf) as usize ;
537550
538551 if !( SMB2_HEADER_SIZE ..=16 * 1024 * 1024 ) . contains ( & msg_len) {
@@ -543,7 +556,7 @@ impl SmbClient {
543556 }
544557
545558 let mut msg = vec ! [ 0u8 ; msg_len] ;
546- Self :: read_exact_timeout ( & mut stream, & mut msg) . await ?;
559+ self . read_exact_timeout ( & mut stream, & mut msg) . await ?;
547560
548561 let header = Header :: decode ( & msg)
549562 . ok_or_else ( || io:: Error :: new ( io:: ErrorKind :: InvalidData , "invalid SMB2 header" ) ) ?;
@@ -688,7 +701,7 @@ impl SmbClient {
688701 let mut received = 0usize ;
689702 while received < n {
690703 let mut len_buf = [ 0u8 ; 4 ] ;
691- Self :: read_exact_timeout ( & mut stream, & mut len_buf) . await ?;
704+ self . read_exact_timeout ( & mut stream, & mut len_buf) . await ?;
692705 let msg_len = u32:: from_be_bytes ( len_buf) as usize ;
693706
694707 if !( SMB2_HEADER_SIZE ..=16 * 1024 * 1024 ) . contains ( & msg_len) {
@@ -699,7 +712,7 @@ impl SmbClient {
699712 }
700713
701714 let mut msg = vec ! [ 0u8 ; msg_len] ;
702- Self :: read_exact_timeout ( & mut stream, & mut msg) . await ?;
715+ self . read_exact_timeout ( & mut stream, & mut msg) . await ?;
703716
704717 let header = Header :: decode ( & msg)
705718 . ok_or_else ( || io:: Error :: new ( io:: ErrorKind :: InvalidData , "invalid SMB2 header" ) ) ?;
@@ -878,7 +891,7 @@ impl SmbClient {
878891 // Read response frames, skipping STATUS_PENDING interim responses
879892 loop {
880893 let mut len_buf = [ 0u8 ; 4 ] ;
881- Self :: read_exact_timeout ( & mut stream, & mut len_buf) . await ?;
894+ self . read_exact_timeout ( & mut stream, & mut len_buf) . await ?;
882895 let msg_len = u32:: from_be_bytes ( len_buf) as usize ;
883896
884897 if !( SMB2_HEADER_SIZE ..=16 * 1024 * 1024 ) . contains ( & msg_len) {
@@ -890,7 +903,7 @@ impl SmbClient {
890903 }
891904
892905 let mut msg = vec ! [ 0u8 ; msg_len] ;
893- Self :: read_exact_timeout ( & mut stream, & mut msg) . await ?;
906+ self . read_exact_timeout ( & mut stream, & mut msg) . await ?;
894907
895908 // Single STATUS_PENDING interim — skip
896909 if let Some ( h) = Header :: decode ( & msg)
0 commit comments