From 183390a19b2829597dc90e09453f664a1733e658 Mon Sep 17 00:00:00 2001 From: Luke Kim <80174+lukekim@users.noreply.github.com> Date: Thu, 14 May 2026 17:25:29 +0900 Subject: [PATCH] Coalesced pipelined SMB I/O for higher 10G throughput MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Reworks the SMB pipelined-read and pipelined-write paths to build all batch packets into one contiguous BytesMut, sign each in-place, and emit a single write_all per batch — eliminating 64 per-packet to_vec allocations and collapsing 64 write_all syscalls per batch into 1. Adds a zero-copy read response decoder that slices an owned Vec into Bytes without the prior body.to_vec() — saves ~4 MiB of memcpy per 64-deep batch at 64 KiB chunks. Sizes the GetObject streaming channel to READ_PIPELINE_DEPTH so a full pipeline batch can dump into the channel without blocking, letting back-to-back SMB batches overlap HTTP draining. Extends bench-live.sh with concurrent multi-stream PUT/GET (BENCH_CONCURRENCY) and an optional raw mount_smbfs baseline (BENCH_MOUNT_BASELINE) to quantify the spiceio translation overhead against the link ceiling. Adds matching protocol micro-benches. Microbench (pipelined_write_encode, d64 x 64 KiB): 154 us -> 49 us, ~3.1x faster on the CPU side, on top of the 64 -> 1 syscall reduction. --- Cargo.lock | 2 +- Cargo.toml | 2 +- benches/protocol_bench.rs | 79 ++++++++++++++++++++ scripts/bench-live.sh | 148 ++++++++++++++++++++++++++++++++++++-- src/s3/router.rs | 11 ++- src/smb/client.rs | 98 ++++++++++++++++--------- src/smb/ops.rs | 7 +- src/smb/protocol.rs | 55 ++++++++++++++ 8 files changed, 357 insertions(+), 45 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index b4e421b..203f56b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -656,7 +656,7 @@ dependencies = [ [[package]] name = "spiceio" -version = "0.5.2" +version = "0.5.3" dependencies = [ "bytes", "criterion", diff --git a/Cargo.toml b/Cargo.toml index ab98577..ebaa17d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "spiceio" -version = "0.5.2" +version = "0.5.3" edition = "2024" description = "S3-compatible API proxy to SMB file shares" license = "Apache-2.0" diff --git a/benches/protocol_bench.rs b/benches/protocol_bench.rs index f0fe7bf..520e261 100644 --- a/benches/protocol_bench.rs +++ b/benches/protocol_bench.rs @@ -242,6 +242,42 @@ fn bench_pipelined_read_decode(c: &mut Criterion) { group.finish(); } +/// Bench the zero-copy `decode_read_response_from_msg` path used after the +/// pipelined-read optimization. Compared to `bench_pipelined_read_decode` this +/// avoids the per-response body `to_vec()` — for a 64-deep 64 KiB batch that's +/// ~4 MiB of memcpy per batch eliminated. +fn bench_pipelined_read_decode_zerocopy(c: &mut Criterion) { + let mut group = c.benchmark_group("pipelined_read_decode_zerocopy"); + let cases = [(8usize, 65536usize), (64, 65536), (64, 8192)]; + for (depth, chunk_size) in cases { + let base_msg_id = 1_000u64; + let messages: Vec> = (0..depth) + .map(|i| build_read_response_msg(base_msg_id + i as u64, chunk_size)) + .collect(); + group.throughput(criterion::Throughput::Bytes((depth * chunk_size) as u64)); + group.bench_with_input( + criterion::BenchmarkId::from_parameter(format!("d{depth}_c{chunk_size}")), + &messages, + |b, messages| { + b.iter(|| { + let n = messages.len(); + let mut slots: Vec> = (0..n).map(|_| None).collect(); + for msg in messages.iter() { + let header = Header::decode(black_box(msg)).unwrap(); + let slot = header.message_id.wrapping_sub(base_msg_id) as usize; + // Clone to simulate ownership transfer from the read + // path — the production code reads directly into a + // fresh Vec each response. + slots[slot] = decode_read_response_from_msg(msg.clone()); + } + slots + }); + }, + ); + } + group.finish(); +} + /// Bench the CPU-bound per-batch work of `pipelined_write`: header construction /// (with credit charge), `encode_write_request`, and `build_request` framing. /// This is the inner loop of WAL pipelined writes before any I/O happens. @@ -279,6 +315,47 @@ fn bench_pipelined_write_encode(c: &mut Criterion) { group.finish(); } +/// Bench the coalesced equivalent: build all packets directly into a single +/// `BytesMut`, the way `pipelined_write` does post-optimization. Comparable +/// to `bench_pipelined_write_encode` — captures the win from eliminating +/// per-packet allocations and from a single contiguous buffer. +fn bench_pipelined_write_encode_coalesced(c: &mut Criterion) { + use bytes::BufMut; + let mut group = c.benchmark_group("pipelined_write_encode_coalesced"); + let file_id = [1u8; 16]; + let cases = [(8usize, 65536usize), (64, 65536), (64, 1024 * 1024)]; + const WRITE_REQUEST_FIXED: usize = 48; + for (depth, chunk_size) in cases { + let chunk = vec![0u8; chunk_size]; + group.throughput(criterion::Throughput::Bytes((depth * chunk_size) as u64)); + group.bench_with_input( + criterion::BenchmarkId::from_parameter(format!("d{depth}_c{chunk_size}")), + &chunk, + |b, chunk| { + b.iter(|| { + let total_bytes = + depth * (4 + SMB2_HEADER_SIZE + WRITE_REQUEST_FIXED + chunk.len()); + let mut buf = BytesMut::with_capacity(total_bytes); + let mut offset = 0u64; + for i in 0..depth { + let mut hdr = Header::new(Command::Write, i as u64) + .with_credit_charge(chunk.len() as u32); + hdr.tree_id = 42; + hdr.session_id = 0xdead_beef; + let packet_total = SMB2_HEADER_SIZE + WRITE_REQUEST_FIXED + chunk.len(); + buf.put_u32((packet_total as u32) & 0x00FF_FFFF); + hdr.encode(&mut buf); + encode_write_request(&mut buf, &file_id, offset, black_box(chunk)); + offset += chunk.len() as u64; + } + buf + }); + }, + ); + } + group.finish(); +} + fn bench_parse_directory_entries(c: &mut Criterion) { // Build 50 entries let mut data = Vec::new(); @@ -321,7 +398,9 @@ criterion_group!( bench_build_request, bench_parse_compound_response, bench_pipelined_read_decode, + bench_pipelined_read_decode_zerocopy, bench_pipelined_write_encode, + bench_pipelined_write_encode_coalesced, bench_parse_directory_entries, ); criterion_main!(benches); diff --git a/scripts/bench-live.sh b/scripts/bench-live.sh index 3fb6677..c83640d 100755 --- a/scripts/bench-live.sh +++ b/scripts/bench-live.sh @@ -5,7 +5,15 @@ set -euo pipefail # # Usage: SPICEIO_SMB_USER=user SPICEIO_SMB_PASS=pass ./scripts/bench-live.sh # -# Runs write and read throughput tests at various file sizes. +# Runs write and read throughput tests at various file sizes, plus +# concurrent multi-stream tests intended to saturate a 10G link. +# +# Environment knobs: +# BENCH_CONCURRENCY parallel streams in the concurrent tests (default 8) +# BENCH_MOUNT_BASELINE 1 to also benchmark a raw mount_smbfs mount of the +# same share — gives a hard ceiling on what the link +# can do, so we can see spiceio's translation overhead +# # Requires: aws cli, dd, curl, bc, perl (Time::HiRes). SMB_SERVER="${SPICEIO_SMB_SERVER:-192.168.3.148}" @@ -15,6 +23,8 @@ SMB_DOMAIN="${SPICEIO_SMB_DOMAIN:-}" REGION="${SPICEIO_REGION:-us-east-1}" BUCKET="${SPICEIO_BUCKET:-bench}" BIND="${SPICEIO_BIND:-127.0.0.1:18334}" +CONCURRENCY="${BENCH_CONCURRENCY:-8}" +MOUNT_BASELINE="${BENCH_MOUNT_BASELINE:-0}" : "${SPICEIO_SMB_USER:?SPICEIO_SMB_USER is required}" : "${SPICEIO_SMB_PASS:?SPICEIO_SMB_PASS is required}" @@ -32,6 +42,7 @@ fi # ── Cleanup ───────────────────────────────────────────────────────────── SPICEIO_PID="" +MOUNT_POINT="" cleanup() { echo "" echo "[bench] cleaning up..." @@ -40,6 +51,10 @@ cleanup() { kill "$SPICEIO_PID" 2>/dev/null || true wait "$SPICEIO_PID" 2>/dev/null || true fi + if [[ -n "$MOUNT_POINT" && -d "$MOUNT_POINT" ]]; then + umount "$MOUNT_POINT" 2>/dev/null || true + rmdir "$MOUNT_POINT" 2>/dev/null || true + fi rm -f /tmp/spiceio-bench-* } trap cleanup EXIT @@ -122,6 +137,114 @@ bench_multi_write() { rm -f "$file" } +# Concurrent single-file PUT: N parallel uploads of `size_bytes`-each. +# Aggregate throughput is what hits the link — this is the test that +# meaningfully exercises a 10G NAS pipe. +bench_concurrent_write() { + local concurrency=$1 size_bytes=$2 label=$3 + local total=$((concurrency * size_bytes)) + local file="/tmp/spiceio-bench-cwrite-${label}" + gen_file "$file" "$size_bytes" + + local start end elapsed mbps + start=$(perl -MTime::HiRes=time -e 'printf "%.6f\n", time') + local pids=() + for i in $(seq 1 "$concurrency"); do + $AWS s3 cp "$file" "s3://${BUCKET}/${PREFIX}/cw-${label}-${i}" --quiet 2>/dev/null & + pids+=($!) + done + for pid in "${pids[@]}"; do + wait "$pid" + done + end=$(perl -MTime::HiRes=time -e 'printf "%.6f\n", time') + elapsed=$(echo "$end - $start" | bc -l) + mbps=$(echo "$total / $elapsed / 1048576" | bc -l) + printf " PUT x%-3d %-5s %6.2fs %7.1f MiB/s (%.2f Gbit/s)\n" \ + "$concurrency" "$label" "$elapsed" "$mbps" \ + "$(echo "$mbps * 8 / 1024" | bc -l)" + rm -f "$file" +} + +bench_concurrent_read() { + local concurrency=$1 size_bytes=$2 label=$3 + local total=$((concurrency * size_bytes)) + + local start end elapsed mbps + start=$(perl -MTime::HiRes=time -e 'printf "%.6f\n", time') + local pids=() + for i in $(seq 1 "$concurrency"); do + $AWS s3 cp "s3://${BUCKET}/${PREFIX}/cw-${label}-${i}" "/tmp/spiceio-bench-cread-${label}-${i}" \ + --quiet 2>/dev/null & + pids+=($!) + done + for pid in "${pids[@]}"; do + wait "$pid" + done + end=$(perl -MTime::HiRes=time -e 'printf "%.6f\n", time') + elapsed=$(echo "$end - $start" | bc -l) + mbps=$(echo "$total / $elapsed / 1048576" | bc -l) + printf " GET x%-3d %-5s %6.2fs %7.1f MiB/s (%.2f Gbit/s)\n" \ + "$concurrency" "$label" "$elapsed" "$mbps" \ + "$(echo "$mbps * 8 / 1024" | bc -l)" + rm -f /tmp/spiceio-bench-cread-${label}-* +} + +# Optional raw-SMB baseline via mount_smbfs. Mounts the same share locally +# and runs the same dd-based write/read tests. Establishes the hard +# ceiling for what the link can do, so we can attribute spiceio's +# translation overhead. +bench_mount_baseline() { + local user="$SPICEIO_SMB_USER" + local pass="$SPICEIO_SMB_PASS" + local server="$SMB_SERVER" + local share="$SMB_SHARE" + + MOUNT_POINT="/tmp/spiceio-bench-mount-$$" + mkdir -p "$MOUNT_POINT" + local escaped_pass + escaped_pass=$(printf '%s' "$pass" | perl -MURI::Escape -ne 'print uri_escape($_)') + if ! mount_smbfs -N "//${user}:${escaped_pass}@${server}/${share}" "$MOUNT_POINT" 2>/dev/null; then + echo " (mount_smbfs failed — skipping baseline)" + rmdir "$MOUNT_POINT" 2>/dev/null + MOUNT_POINT="" + return + fi + + local target="${MOUNT_POINT}/${PREFIX}-mount-baseline" + mkdir -p "$target" + + local label sizes labels + sizes=(104857600 524288000) + labels=("100M" "500M") + for idx in "${!sizes[@]}"; do + local size_bytes=${sizes[$idx]} + label=${labels[$idx]} + local file="/tmp/spiceio-bench-mountin-${label}" + gen_file "$file" "$size_bytes" + + local start end elapsed mbps + start=$(perl -MTime::HiRes=time -e 'printf "%.6f\n", time') + cp "$file" "${target}/${label}" + end=$(perl -MTime::HiRes=time -e 'printf "%.6f\n", time') + elapsed=$(echo "$end - $start" | bc -l) + mbps=$(echo "$size_bytes / $elapsed / 1048576" | bc -l) + printf " PUT mount %-5s %6.2fs %7.1f MiB/s\n" "$label" "$elapsed" "$mbps" + + start=$(perl -MTime::HiRes=time -e 'printf "%.6f\n", time') + cp "${target}/${label}" "${file}.out" + end=$(perl -MTime::HiRes=time -e 'printf "%.6f\n", time') + elapsed=$(echo "$end - $start" | bc -l) + mbps=$(echo "$size_bytes / $elapsed / 1048576" | bc -l) + printf " GET mount %-5s %6.2fs %7.1f MiB/s\n" "$label" "$elapsed" "$mbps" + rm -f "$file" "${file}.out" + done + + rm -rf "$target" 2>/dev/null + umount "$MOUNT_POINT" 2>/dev/null + rmdir "$MOUNT_POINT" 2>/dev/null + MOUNT_POINT="" +} + # ── Run benchmarks ────────────────────────────────────────────────────── echo "" echo "═══════════════════════════════════════════════════════════════" @@ -156,11 +279,24 @@ bench_multi_write 100 1048576 "1M" bench_multi_write 20 10485760 "10M" bench_multi_write 10 52428800 "50M" -# Total: 1685 (write) + 1685 (read) + 800 (multi-write) = 4170 MiB transferred +# Concurrent single-stream tests. Single-stream uploads top out at one TCP +# connection's worth of pipe; aggregate concurrent uploads is the test +# that actually saturates a 10G link. +echo "" +echo "── Concurrent write throughput (x${CONCURRENCY} parallel) ──" +bench_concurrent_write "$CONCURRENCY" 104857600 "100M" +bench_concurrent_write "$CONCURRENCY" 524288000 "500M" + echo "" -echo "── Aggregate ──" -echo " Total written: 2485 MiB (single-file + multi-file)" -echo " Total read: 1685 MiB" -echo " Total I/O: 4170 MiB" +echo "── Concurrent read throughput (x${CONCURRENCY} parallel) ──" +bench_concurrent_read "$CONCURRENCY" 104857600 "100M" +bench_concurrent_read "$CONCURRENCY" 524288000 "500M" + +if [[ "$MOUNT_BASELINE" == "1" ]]; then + echo "" + echo "── Raw mount_smbfs baseline (link ceiling) ──" + bench_mount_baseline +fi + echo "" echo "═══════════════════════════════════════════════════════════════" diff --git a/src/s3/router.rs b/src/s3/router.rs index 3c50cec..bd5f97c 100644 --- a/src/s3/router.rs +++ b/src/s3/router.rs @@ -591,8 +591,15 @@ async fn handle_get_object( let content_length = end - start + 1; - // Build response with streaming body - let (body, tx) = SpiceioBody::channel(4); + // Build response with streaming body. + // + // Channel capacity is sized to match the SMB pipeline depth so a full + // batch of reads can dump into the channel without blocking the producer. + // That lets the SMB-reading task immediately issue the next pipelined + // batch (incurring its round-trip) while the HTTP-sending task drains + // the previous batch into the wire — back-to-back batches overlap, which + // is the difference between filling and starving the 10G link. + let (body, tx) = SpiceioBody::channel(crate::smb::ops::READ_PIPELINE_DEPTH); let chunk_size = handle.max_chunk; // Spawn background task to stream pipelined SMB reads into the channel. diff --git a/src/smb/client.rs b/src/smb/client.rs index 62bfae1..70423b5 100644 --- a/src/smb/client.rs +++ b/src/smb/client.rs @@ -202,9 +202,12 @@ impl SmbClient { async fn send_recv_inner(&self, packet: &[u8]) -> io::Result<(Header, Vec, Vec)> { let mut stream = self.stream.lock().await; - // Sign the packet if we have a signing key + // Sign the packet if we have a signing key. We need a writable buffer + // to sign in-place; `BytesMut::from(&[u8])` is one alloc + one copy + // (same cost as the previous `to_vec`, but expressed as a typed buffer + // that mirrors what the pipelined paths do). if let Some(ref key) = self.signing_key { - let mut signed = packet.to_vec(); + let mut signed = BytesMut::from(packet); sign_packet(&mut signed, key); stream.write_all(&signed).await?; } else { @@ -507,6 +510,10 @@ impl SmbClient { /// Holds the stream lock for the entire batch, eliminating per-request /// round-trip latency. Returns chunks in offset order. Stops early on EOF. /// + /// Coalesces all request packets into a single contiguous buffer and signs + /// each in-place — one allocation, one `write_all` syscall for the whole + /// batch of request headers (only the responses carry bulk data). + /// /// Responses may arrive out of order (SMB2 does not guarantee response /// ordering). Each response is matched to its request slot via message_id. pub async fn pipelined_read( @@ -525,31 +532,38 @@ impl SmbClient { // response.message_id → slot index via simple subtraction. let base_msg_id = self.message_id.fetch_add(count as u64, Ordering::Relaxed); - let mut packets = Vec::with_capacity(count); + // Each request: 4 (NetBIOS length) + SMB2_HEADER_SIZE (64) + 49 + // (read request fixed part incl. 1-byte buffer pad). + const READ_REQUEST_FIXED: usize = 49; + let per_packet = 4 + SMB2_HEADER_SIZE + READ_REQUEST_FIXED; + let mut buf = BytesMut::with_capacity(per_packet * count); + let mut packet_starts: Vec = Vec::with_capacity(count + 1); + for i in 0..count { + packet_starts.push(buf.len()); let offset = start_offset + (i as u64) * (chunk_size as u64); let msg_id = base_msg_id + i as u64; let mut hdr = Header::new(Command::Read, msg_id).with_credit_charge(chunk_size); hdr.session_id = self.session_id; hdr.tree_id = tree_id; - let packet = build_request(&hdr, |buf| { - encode_read_request(buf, file_id, offset, chunk_size); - }); - packets.push(packet); - } - let mut stream = self.stream.lock().await; + let packet_smb_total = SMB2_HEADER_SIZE + READ_REQUEST_FIXED; + buf.put_u32((packet_smb_total as u32) & 0x00FF_FFFF); + hdr.encode(&mut buf); + encode_read_request(&mut buf, file_id, offset, chunk_size); + } + packet_starts.push(buf.len()); - // Send all requests - for packet in &packets { - if let Some(ref key) = self.signing_key { - let mut signed = packet.to_vec(); - sign_packet(&mut signed, key); - stream.write_all(&signed).await?; - } else { - stream.write_all(packet).await?; + if let Some(ref key) = self.signing_key { + for i in 0..count { + let start = packet_starts[i]; + let end = packet_starts[i + 1]; + sign_packet(&mut buf[start..end], key); } } + + let mut stream = self.stream.lock().await; + stream.write_all(&buf).await?; stream.flush().await?; // Receive responses into ordered slots (handles out-of-order delivery). @@ -605,8 +619,10 @@ impl SmbClient { ))); } - let body = msg[SMB2_HEADER_SIZE..].to_vec(); - let data = decode_read_response_owned(body).ok_or_else(|| { + // Zero-copy: hand the full `msg` Vec to the decoder, which slices + // into it as `Bytes` without an extra body copy. For 64KB chunks + // pipelined 64 deep this saves ~4 MiB of memcpy per batch. + let data = decode_read_response_from_msg(msg).ok_or_else(|| { io::Error::new(io::ErrorKind::InvalidData, "invalid read response") })?; slots[slot] = Some(data); @@ -666,6 +682,8 @@ impl SmbClient { /// all responses. Holds the stream lock for the entire batch, eliminating /// per-request round-trip latency. Returns total bytes written. /// + /// Coalesces all packets into a single contiguous buffer and signs each + /// in-place — one allocation, one `write_all` syscall for the whole batch. /// Responses may arrive out of order; each is matched by message_id. pub async fn pipelined_write( &self, @@ -681,33 +699,45 @@ impl SmbClient { let n = chunks.len(); let base_msg_id = self.message_id.fetch_add(n as u64, Ordering::Relaxed); - let mut packets = Vec::with_capacity(n); + // Each packet: 4 (NetBIOS length) + SMB2_HEADER_SIZE (64) + 48 + // (write request fixed part) + chunk data. + const WRITE_REQUEST_FIXED: usize = 48; + let total_bytes: usize = chunks + .iter() + .map(|c| 4 + SMB2_HEADER_SIZE + WRITE_REQUEST_FIXED + c.len()) + .sum(); + let mut buf = BytesMut::with_capacity(total_bytes); + let mut packet_starts: Vec = Vec::with_capacity(n + 1); + let mut offset = start_offset; for (i, chunk) in chunks.iter().enumerate() { + packet_starts.push(buf.len()); let msg_id = base_msg_id + i as u64; let mut hdr = Header::new(Command::Write, msg_id).with_credit_charge(chunk.len() as u32); hdr.session_id = self.session_id; hdr.tree_id = tree_id; - let packet = build_request(&hdr, |buf| { - encode_write_request(buf, file_id, offset, chunk); - }); - packets.push(packet); + + let packet_smb_total = SMB2_HEADER_SIZE + WRITE_REQUEST_FIXED + chunk.len(); + buf.put_u32((packet_smb_total as u32) & 0x00FF_FFFF); + hdr.encode(&mut buf); + encode_write_request(&mut buf, file_id, offset, chunk); offset += chunk.len() as u64; } + packet_starts.push(buf.len()); - let mut stream = self.stream.lock().await; - - // Send all requests - for packet in &packets { - if let Some(ref key) = self.signing_key { - let mut signed = packet.to_vec(); - sign_packet(&mut signed, key); - stream.write_all(&signed).await?; - } else { - stream.write_all(packet).await?; + // Sign each packet in-place. We pre-allocated exact capacity, so the + // earlier slices are still valid (no realloc could have moved them). + if let Some(ref key) = self.signing_key { + for i in 0..n { + let start = packet_starts[i]; + let end = packet_starts[i + 1]; + sign_packet(&mut buf[start..end], key); } } + + let mut stream = self.stream.lock().await; + stream.write_all(&buf).await?; stream.flush().await?; // Receive all responses (handles out-of-order delivery) diff --git a/src/smb/ops.rs b/src/smb/ops.rs index 487658e..a3b869b 100644 --- a/src/smb/ops.rs +++ b/src/smb/ops.rs @@ -712,7 +712,12 @@ impl ShareSession { } /// Number of read requests to pipeline in a single batch. -const PIPELINE_DEPTH: usize = 64; +/// +/// Public so the HTTP streaming path can size its response channel to the +/// same depth — back-to-back read batches overlap with HTTP draining when +/// the channel can hold a full batch (see `s3::router::handle_get_object`). +pub const READ_PIPELINE_DEPTH: usize = 64; +const PIPELINE_DEPTH: usize = READ_PIPELINE_DEPTH; impl FileHandle { /// Read a chunk at the given offset. Returns empty bytes at EOF. diff --git a/src/smb/protocol.rs b/src/smb/protocol.rs index b770317..d6bbe87 100644 --- a/src/smb/protocol.rs +++ b/src/smb/protocol.rs @@ -487,6 +487,28 @@ pub fn decode_read_response_owned(body: Vec) -> Option { Some(bytes) } +/// Zero-copy decode that takes ownership of the full SMB2 message (header + +/// body) and returns a `Bytes` slice referencing the response payload. Avoids +/// the extra body-copy that `decode_read_response_owned` would require if the +/// caller had to split body off first. +pub fn decode_read_response_from_msg(msg: Vec) -> Option { + if msg.len() < SMB2_HEADER_SIZE + 17 { + return None; + } + let body = &msg[SMB2_HEADER_SIZE..]; + let data_offset = u16::from_le_bytes(body[2..4].try_into().unwrap()) as usize; + let data_length = u32::from_le_bytes(body[4..8].try_into().unwrap()) as usize; + + // `data_offset` is from the start of the SMB2 message, not the body. + let start = data_offset; + let end = start.checked_add(data_length)?; + if end > msg.len() { + return None; + } + let bytes = Bytes::from(msg); + Some(bytes.slice(start..end)) +} + // ── Write ─────────────────────────────────────────────────────────────────── pub fn encode_write_request(buf: &mut BytesMut, file_id: &[u8; 16], offset: u64, data: &[u8]) { @@ -871,6 +893,39 @@ mod tests { assert!(decode_read_response(&[0u8; 5]).is_none()); } + #[test] + fn decode_read_response_from_msg_valid() { + // Build a complete SMB2 message: 64-byte header + body. data_offset + // and data_length are measured from the start of the SMB2 message, + // matching the wire format. + let mut msg = vec![0u8; SMB2_HEADER_SIZE + 32]; + let body = &mut msg[SMB2_HEADER_SIZE..]; + let data_offset = (SMB2_HEADER_SIZE + 16) as u16; + body[2..4].copy_from_slice(&data_offset.to_le_bytes()); + body[4..8].copy_from_slice(&5u32.to_le_bytes()); + body[16..21].copy_from_slice(b"hello"); + + let data = decode_read_response_from_msg(msg).unwrap(); + assert_eq!(&data[..], b"hello"); + } + + #[test] + fn decode_read_response_from_msg_too_short() { + assert!(decode_read_response_from_msg(vec![0u8; SMB2_HEADER_SIZE + 5]).is_none()); + assert!(decode_read_response_from_msg(vec![0u8; 10]).is_none()); + } + + #[test] + fn decode_read_response_from_msg_rejects_overflow_length() { + // data_length that would extend past the buffer is rejected. + let mut msg = vec![0u8; SMB2_HEADER_SIZE + 32]; + let body = &mut msg[SMB2_HEADER_SIZE..]; + let data_offset = (SMB2_HEADER_SIZE + 16) as u16; + body[2..4].copy_from_slice(&data_offset.to_le_bytes()); + body[4..8].copy_from_slice(&1_000_000u32.to_le_bytes()); + assert!(decode_read_response_from_msg(msg).is_none()); + } + #[test] fn decode_write_response_valid() { let mut body = vec![0u8; 16];