Skip to content

Commit 63cb0e4

Browse files
committed
Address PR review: poison connections on timeout, error on truncated assembly
- Add poisoned flag (AtomicBool) to SmbClient; set on read timeout - read_exact_timeout checks poisoned state and fails fast with BrokenPipe - Pool skips poisoned connections in round-robin selection - Fix shutdown comment: stream.shutdown() closes write half, poisoned flag prevents further use of the connection - assemble_parts returns UnexpectedEof on empty chunks instead of silent truncation - Fix cargo fmt (collapsed function signatures)
1 parent de51556 commit 63cb0e4

3 files changed

Lines changed: 63 additions & 28 deletions

File tree

src/smb/client.rs

Lines changed: 30 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
use bytes::Buf;
44
use std::io;
55
use std::sync::Arc;
6-
use std::sync::atomic::{AtomicU64, Ordering};
6+
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
77
use std::time::Duration;
88
use tokio::io::{AsyncReadExt, AsyncWriteExt};
99
use tokio::net::TcpStream;
@@ -63,6 +63,8 @@ pub struct SmbClient {
6363
client_guid: [u8; 16],
6464
/// SMB 3.1.1 signing key (derived after auth)
6565
signing_key: Option<[u8; 16]>,
66+
/// Set on read timeout — connection framing is desynchronized.
67+
poisoned: AtomicBool,
6668
}
6769

6870
impl SmbClient {
@@ -129,32 +131,43 @@ impl SmbClient {
129131
compound_max_write_size: 65536,
130132
client_guid,
131133
signing_key: None,
134+
poisoned: AtomicBool::new(false),
132135
};
133136

134137
client.negotiate_and_auth().await?;
135138
Ok(Arc::new(client))
136139
}
137140

141+
/// Whether this connection has been poisoned by a timeout.
142+
pub fn is_poisoned(&self) -> bool {
143+
self.poisoned.load(Ordering::Relaxed)
144+
}
145+
138146
fn next_message_id(&self) -> u64 {
139147
self.message_id.fetch_add(1, Ordering::Relaxed)
140148
}
141149

142150
/// Read exactly `buf.len()` bytes from the stream with a timeout.
143-
/// Returns `TimedOut` if the SMB server doesn't respond within the deadline.
144151
///
145-
/// A timeout may leave the stream mid-frame, so we shut it down to prevent
146-
/// desynchronized reuse.
147-
async fn read_exact_timeout(
148-
stream: &mut TcpStream,
149-
buf: &mut [u8],
150-
) -> io::Result<()> {
152+
/// On timeout the stream framing is desynchronized, so we poison the
153+
/// connection (all future operations fail fast) and drop the underlying
154+
/// socket to fully close both halves.
155+
async fn read_exact_timeout(&self, stream: &mut TcpStream, buf: &mut [u8]) -> io::Result<()> {
156+
if self.poisoned.load(Ordering::Relaxed) {
157+
return Err(io::Error::new(
158+
io::ErrorKind::BrokenPipe,
159+
"SMB connection poisoned by previous timeout",
160+
));
161+
}
151162
match tokio::time::timeout(SMB_READ_TIMEOUT, stream.read_exact(buf)).await {
152163
Ok(result) => result.map(|_| ()),
153164
Err(_) => {
165+
self.poisoned.store(true, Ordering::Relaxed);
166+
// Drop the socket to fully close both halves.
154167
let _ = stream.shutdown().await;
155168
Err(io::Error::new(
156169
io::ErrorKind::TimedOut,
157-
"SMB server read timed out; connection closed",
170+
"SMB server read timed out; connection poisoned",
158171
))
159172
}
160173
}
@@ -188,7 +201,7 @@ impl SmbClient {
188201
// Read responses, looping past STATUS_PENDING interim responses
189202
loop {
190203
let mut len_buf = [0u8; 4];
191-
Self::read_exact_timeout(&mut stream, &mut len_buf).await?;
204+
self.read_exact_timeout(&mut stream, &mut len_buf).await?;
192205
let msg_len = u32::from_be_bytes(len_buf) as usize;
193206

194207
if !(SMB2_HEADER_SIZE..=16 * 1024 * 1024).contains(&msg_len) {
@@ -200,7 +213,7 @@ impl SmbClient {
200213
}
201214

202215
let mut msg = vec![0u8; msg_len];
203-
Self::read_exact_timeout(&mut stream, &mut msg).await?;
216+
self.read_exact_timeout(&mut stream, &mut msg).await?;
204217

205218
let header = Header::decode(&msg).ok_or_else(|| {
206219
crate::serr!("[spiceio] smb invalid header");
@@ -532,7 +545,7 @@ impl SmbClient {
532545

533546
while received < count {
534547
let mut len_buf = [0u8; 4];
535-
Self::read_exact_timeout(&mut stream, &mut len_buf).await?;
548+
self.read_exact_timeout(&mut stream, &mut len_buf).await?;
536549
let msg_len = u32::from_be_bytes(len_buf) as usize;
537550

538551
if !(SMB2_HEADER_SIZE..=16 * 1024 * 1024).contains(&msg_len) {
@@ -543,7 +556,7 @@ impl SmbClient {
543556
}
544557

545558
let mut msg = vec![0u8; msg_len];
546-
Self::read_exact_timeout(&mut stream, &mut msg).await?;
559+
self.read_exact_timeout(&mut stream, &mut msg).await?;
547560

548561
let header = Header::decode(&msg)
549562
.ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "invalid SMB2 header"))?;
@@ -688,7 +701,7 @@ impl SmbClient {
688701
let mut received = 0usize;
689702
while received < n {
690703
let mut len_buf = [0u8; 4];
691-
Self::read_exact_timeout(&mut stream, &mut len_buf).await?;
704+
self.read_exact_timeout(&mut stream, &mut len_buf).await?;
692705
let msg_len = u32::from_be_bytes(len_buf) as usize;
693706

694707
if !(SMB2_HEADER_SIZE..=16 * 1024 * 1024).contains(&msg_len) {
@@ -699,7 +712,7 @@ impl SmbClient {
699712
}
700713

701714
let mut msg = vec![0u8; msg_len];
702-
Self::read_exact_timeout(&mut stream, &mut msg).await?;
715+
self.read_exact_timeout(&mut stream, &mut msg).await?;
703716

704717
let header = Header::decode(&msg)
705718
.ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "invalid SMB2 header"))?;
@@ -878,7 +891,7 @@ impl SmbClient {
878891
// Read response frames, skipping STATUS_PENDING interim responses
879892
loop {
880893
let mut len_buf = [0u8; 4];
881-
Self::read_exact_timeout(&mut stream, &mut len_buf).await?;
894+
self.read_exact_timeout(&mut stream, &mut len_buf).await?;
882895
let msg_len = u32::from_be_bytes(len_buf) as usize;
883896

884897
if !(SMB2_HEADER_SIZE..=16 * 1024 * 1024).contains(&msg_len) {
@@ -890,7 +903,7 @@ impl SmbClient {
890903
}
891904

892905
let mut msg = vec![0u8; msg_len];
893-
Self::read_exact_timeout(&mut stream, &mut msg).await?;
906+
self.read_exact_timeout(&mut stream, &mut msg).await?;
894907

895908
// Single STATUS_PENDING interim — skip
896909
if let Some(h) = Header::decode(&msg)

src/smb/ops.rs

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -466,11 +466,7 @@ impl ShareSession {
466466
/// Reads each temp part using pipelined reads and writes through a WalWriter
467467
/// (pipelined writes + atomic rename). Never holds more than one pipeline
468468
/// buffer in memory — supports arbitrarily large files.
469-
pub async fn assemble_parts(
470-
&self,
471-
key: &str,
472-
temp_paths: &[&str],
473-
) -> io::Result<ObjectMeta> {
469+
pub async fn assemble_parts(&self, key: &str, temp_paths: &[&str]) -> io::Result<ObjectMeta> {
474470
let mut wal = self.open_wal_write(key).await?;
475471
let max_read = self.pool.max_read_size;
476472

@@ -502,7 +498,14 @@ impl ShareSession {
502498
.pipelined_read(tree_id, &file_id, offset, max_read, batch)
503499
.await?;
504500
if chunks.is_empty() {
505-
break;
501+
let _ = client.close(tree_id, &file_id).await;
502+
return Err(io::Error::new(
503+
io::ErrorKind::UnexpectedEof,
504+
format!(
505+
"unexpected EOF assembling part '{}': read {} of {} bytes",
506+
temp_path, offset, file_size
507+
),
508+
));
506509
}
507510
for chunk in &chunks {
508511
wal.write(chunk).await?;

src/smb/pool.rs

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -66,15 +66,34 @@ impl SmbPool {
6666
}))
6767
}
6868

69-
/// Pick the next connection via round-robin.
69+
/// Pick the next healthy connection via round-robin, skipping poisoned ones.
70+
/// Falls back to a poisoned connection if all are poisoned (error will
71+
/// surface on the first I/O attempt).
7072
pub fn get(&self) -> &Arc<SmbClient> {
71-
let idx = self.next.fetch_add(1, Ordering::Relaxed) % self.clients.len();
72-
&self.clients[idx]
73+
let n = self.clients.len();
74+
let start = self.next.fetch_add(1, Ordering::Relaxed);
75+
for i in 0..n {
76+
let idx = (start + i) % n;
77+
if !self.clients[idx].is_poisoned() {
78+
return &self.clients[idx];
79+
}
80+
}
81+
// All poisoned — return round-robin pick; caller gets BrokenPipe on I/O
82+
&self.clients[start % n]
7383
}
7484

75-
/// Get the next round-robin index (and advance the counter).
85+
/// Get the next round-robin index (and advance the counter), preferring
86+
/// healthy connections.
7687
pub fn next_index(&self) -> usize {
77-
self.next.fetch_add(1, Ordering::Relaxed)
88+
let n = self.clients.len();
89+
let start = self.next.fetch_add(1, Ordering::Relaxed);
90+
for i in 0..n {
91+
let idx = (start + i) % n;
92+
if !self.clients[idx].is_poisoned() {
93+
return idx;
94+
}
95+
}
96+
start % n
7897
}
7998

8099
/// Access a specific connection by index.

0 commit comments

Comments
 (0)