diff --git a/src/crypto/noise/mod.rs b/src/crypto/noise/mod.rs index 709ab082..f5775684 100644 --- a/src/crypto/noise/mod.rs +++ b/src/crypto/noise/mod.rs @@ -331,15 +331,14 @@ enum ReadState { } enum WriteState { - Ready { + /// No pending encrypted data, ready to accept new writes + Idle, + /// Writing encrypted data to socket + Writing { + /// Offset into encrypt_buffer that's been written to socket offset: usize, - size: usize, - encrypted_size: usize, - }, - WriteFrame { - offset: usize, - size: usize, - encrypted_size: usize, + /// Total length of encrypted data in encrypt_buffer + encrypted_len: usize, }, } @@ -378,11 +377,7 @@ impl NoiseSocket { nread: 0usize, offset: 0usize, current_frame_size: None, - write_state: WriteState::Ready { - offset: 0usize, - size: 0usize, - encrypted_size: 0usize, - }, + write_state: WriteState::Idle, encrypt_buffer: vec![0u8; max_write_buffer_size * (MAX_NOISE_MSG_LEN + 2)], decrypt_buffer: Some(vec![0u8; MAX_FRAME_LEN]), read_state: ReadState::ReadData { @@ -651,90 +646,150 @@ impl AsyncWrite for NoiseSocket { buf: &[u8], ) -> Poll> { let this = Pin::into_inner(self); - let mut chunks = buf.chunks(MAX_FRAME_LEN).peekable(); - loop { - match this.write_state { - WriteState::Ready { - offset, - size, - encrypted_size, - } => { - let Some(chunk) = chunks.next() else { + // Step 1. Attempt to drain any pending data. + if let WriteState::Writing { + offset, + encrypted_len, + } = &mut this.write_state + { + loop { + match Pin::new(&mut this.io) + .poll_write(cx, &this.encrypt_buffer[*offset..*encrypted_len]) + { + Poll::Ready(Ok(0)) => { + return Poll::Ready(Err(io::ErrorKind::WriteZero.into())); + } + Poll::Ready(Ok(n)) => { + *offset += n; + if offset == encrypted_len { + // Buffer fully drained! + this.write_state = WriteState::Idle; + break; + } + } + Poll::Ready(Err(e)) => return Poll::Ready(Err(e)), + Poll::Pending => { + // Socket is busy, move on to encryption. break; - }; + } + } + } + } - match this.noise.write_message(chunk, &mut this.encrypt_buffer[offset + 2..]) { - Err(error) => { - tracing::error!( - target: LOG_TARGET, - ?error, - ty = ?this.ty, - peer = ?this.peer, - "failed to encrypt message" - ); + // Step 2. Encrypt and buffer the new data. + let mut buffer_offset = match this.write_state { + WriteState::Idle => 0, + WriteState::Writing { encrypted_len, .. } => encrypted_len, + }; + // Nothing to do if there is no data to write. + if buf.is_empty() { + return Poll::Ready(Ok(0)); + } - return Poll::Ready(Err(io::ErrorKind::InvalidData.into())); - } - Ok(nwritten) => { - this.encrypt_buffer[offset] = (nwritten >> 8) as u8; - this.encrypt_buffer[offset + 1] = (nwritten & 0xff) as u8; - - if let Some(next_chunk) = chunks.peek() { - if next_chunk.len() + NOISE_EXTRA_ENCRYPT_SPACE + 2 - <= this.encrypt_buffer[offset + nwritten + 2..].len() - { - this.write_state = WriteState::Ready { - offset: offset + nwritten + 2, - size: size + chunk.len(), - encrypted_size: encrypted_size + nwritten + 2, - }; - continue; - } - } + let mut total_plaintext = 0usize; + // Encrypt as many chunks as fit in the remaining space + for chunk in buf.chunks(MAX_FRAME_LEN) { + // Check space for this specific chunk + overhead + // Note: overhead is 2 bytes length + 16 bytes auth tag + let overhead = 2 + NOISE_EXTRA_ENCRYPT_SPACE; + if buffer_offset + chunk.len() + overhead > this.encrypt_buffer.len() { + // Buffer is full, stop packing + break; + } - this.write_state = WriteState::WriteFrame { - offset: 0usize, - size: size + chunk.len(), - encrypted_size: encrypted_size + nwritten + 2, - }; - } - } + match this.noise.write_message(chunk, &mut this.encrypt_buffer[buffer_offset + 2..]) { + Ok(nwritten) => { + // Write frame length prefix + this.encrypt_buffer[buffer_offset] = (nwritten >> 8) as u8; + this.encrypt_buffer[buffer_offset + 1] = (nwritten & 0xff) as u8; + + buffer_offset += nwritten + 2; + total_plaintext += chunk.len(); } - WriteState::WriteFrame { - ref mut offset, - size, - encrypted_size, - } => loop { - match futures::ready!(Pin::new(&mut this.io) - .poll_write(cx, &this.encrypt_buffer[*offset..encrypted_size])) - { - Ok(nwritten) => { - *offset += nwritten; - - if offset == &encrypted_size { - this.write_state = WriteState::Ready { - offset: 0usize, - size: 0usize, - encrypted_size: 0usize, - }; - return Poll::Ready(Ok(size)); - } - } - Err(error) => return Poll::Ready(Err(error)), - } - }, + Err(error) => { + tracing::error!(target: LOG_TARGET, ?error, "failed to encrypt"); + return Poll::Ready(Err(io::ErrorKind::InvalidData.into())); + } + } + } + if total_plaintext == 0 { + // No data could be buffered because the buffer is full. + // + // This can only happen when we're in WriteState::Writing (buffer not empty). + // In step 1, the inner poll_write must have returned Pending (otherwise the + // buffer would have drained and we'd have space). That Pending registered + // the waker, so we'll be woken when the socket becomes writable again. + // + // This condition will always be satisfied, since the encrypted buffer + // is large enough (MAX_NOISE_MSG_LEN) to hold at least one chunk (MAX_FRAME_LEN) with + // overhead. + return Poll::Pending; + } + + // Step 3. Adjust state to writing and return number of bytes accepted. + // Without this step, we can cause higher-level panics in rust-yamux + // leading to unnecessary connection closures: + // - poll_write is called with buffer 512 bytes (we previously returned Pending but accepted + // and encrypted the buffer) + // - a future poll_write is called with a PONG frame (or smaller buffer) of 12 bytes + // - at this point we would have returned 512 from the previous call causing indexing out of + // bounds + + match this.write_state { + WriteState::Idle => { + this.write_state = WriteState::Writing { + offset: 0, + encrypted_len: buffer_offset, + }; + } + WriteState::Writing { + ref mut encrypted_len, + .. + } => { + *encrypted_len = buffer_offset; } } - Poll::Ready(Ok(0)) + // We have successfully buffered the data: + // - poll_flush or next poll_write will drain it. + Poll::Ready(Ok(total_plaintext)) } - fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Pin::new(&mut self.io).poll_flush(cx) + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = Pin::into_inner(self); + + // Flush internal buffer of encrypted messages + if let WriteState::Writing { + offset, + encrypted_len, + } = &mut this.write_state + { + loop { + match futures::ready!(Pin::new(&mut this.io) + .poll_write(cx, &this.encrypt_buffer[*offset..*encrypted_len])) + { + Ok(0) => return Poll::Ready(Err(io::ErrorKind::WriteZero.into())), + Ok(n) => { + *offset += n; + if offset == encrypted_len { + this.write_state = WriteState::Idle; + break; + } + } + Err(e) => return Poll::Ready(Err(e)), + } + } + } + + // Flush underlying socket + Pin::new(&mut this.io).poll_flush(cx) } fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + // Ensure buffer is flushed before closing + futures::ready!(self.as_mut().poll_flush(cx))?; + Pin::new(&mut self.io).poll_close(cx) } } @@ -918,7 +973,17 @@ mod tests { // verify the connection works by reading a string let mut buf = vec![0u8; 512]; + + // Calling AsyncWrite::write, followed by AsyncRead::read_exact can + // cause deadlocks because the "AsyncWrite::write" does not guarantee + // flushing. Therefore, this is a misuse of the API. let sent = res1.0.write(b"hello, world").await.unwrap(); + // Write ensures data reaches the buffers, flush ensures data is sent. + res1.0.flush().await.unwrap(); + + // At this point it is safe to read_exact. The test previously relied + // on the fact that `Noise::poll_write` would flush the data internally, + // causing head-of-line blocking and panics on different buffer sizes. res2.0.read_exact(&mut buf[..sent]).await.unwrap(); assert_eq!(std::str::from_utf8(&buf[..sent]), Ok("hello, world")); @@ -936,4 +1001,154 @@ mod tests { _ => panic!("invalid error"), } } + + /// Mock IO that returns Pending on first write, then Ready on subsequent writes + struct MockPendingIO { + write_count: usize, + buffer: Vec, + } + + impl MockPendingIO { + fn new() -> Self { + Self { + write_count: 0, + buffer: Vec::new(), + } + } + } + + impl AsyncRead for MockPendingIO { + fn poll_read( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + _buf: &mut [u8], + ) -> Poll> { + Poll::Ready(Ok(0)) + } + } + + impl AsyncWrite for MockPendingIO { + fn poll_write( + mut self: Pin<&mut Self>, + _cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + self.write_count += 1; + + // Return Pending on first write, Ready on subsequent writes + if self.write_count == 1 { + Poll::Pending + } else { + // Accept the write + self.buffer.extend_from_slice(buf); + Poll::Ready(Ok(buf.len())) + } + } + + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + } + + #[tokio::test] + async fn test_poll_write_wrong_size_panic() { + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let keypair1 = Keypair::generate(); + let keypair2 = Keypair::generate(); + + let peer1_id = PeerId::from_public_key(&keypair1.public().into()); + let peer2_id = PeerId::from_public_key(&keypair2.public().into()); + + let listener = TcpListener::bind("[::1]:0".parse::().unwrap()).await.unwrap(); + + let (stream1, stream2) = tokio::join!( + TcpStream::connect(listener.local_addr().unwrap()), + listener.accept() + ); + let (io1, io2) = { + let io1 = TokioAsyncReadCompatExt::compat(stream1.unwrap()).into_inner(); + let io1 = Box::new(TokioAsyncWriteCompatExt::compat_write(io1)); + let io2 = TokioAsyncReadCompatExt::compat(stream2.unwrap().0).into_inner(); + let io2 = Box::new(TokioAsyncWriteCompatExt::compat_write(io2)); + + (io1, io2) + }; + + // Perform handshake + let (res1, res2) = tokio::join!( + handshake( + io1, + &keypair1, + Role::Dialer, + MAX_READ_AHEAD_FACTOR, + MAX_WRITE_BUFFER_SIZE, + std::time::Duration::from_secs(10), + HandshakeTransport::Tcp, + ), + handshake( + io2, + &keypair2, + Role::Listener, + MAX_READ_AHEAD_FACTOR, + MAX_WRITE_BUFFER_SIZE, + std::time::Duration::from_secs(10), + HandshakeTransport::Tcp, + ) + ); + let (socket1, peer1) = res1.unwrap(); + let (_socket2, peer2) = res2.unwrap(); + + assert_eq!(peer1, peer2_id); + assert_eq!(peer2, peer1_id); + + // Wrap socket with MockPendingIO + let mock_io = MockPendingIO::new(); + let mut noise_socket = NoiseSocket::new( + mock_io, + socket1.noise, + MAX_READ_AHEAD_FACTOR, + MAX_WRITE_BUFFER_SIZE, + peer1, + HandshakeTransport::Tcp, + ); + + // First write with 512 bytes - this will encrypt data, buffer it and return Ok(512) + // However, the data is not yet flushed to the underlying IO. + let large_buffer = vec![0xAA; 512]; + let waker = futures::task::noop_waker(); + let mut cx = Context::from_waker(&waker); + + match Pin::new(&mut noise_socket).poll_write(&mut cx, &large_buffer) { + Poll::Ready(Ok(n)) if n == 512 => {} + state => panic!("Expected Ok(512), got {:?}", state), + } + + // Second write with 12 bytes (PONG frame). + // This previously flushes the first write and returned 512 instead of 12, causing a panic + // to rust-yamux when indexing the buffer. + // With the new implementation this will: flush any pending data (from first write), and + // then encrypt the small buffer. + let small_buffer = vec![0xBB; 12]; + match Pin::new(&mut noise_socket).poll_write(&mut cx, &small_buffer) { + Poll::Ready(Ok(n)) => { + println!( + "poll_write returned {} bytes, but buffer is only {} bytes", + n, + small_buffer.len() + ); + + // Safe to reference since the exact length is returned. + let _ = &small_buffer[n..]; + } + Poll::Pending => panic!("Expected Ready, got Pending"), + Poll::Ready(Err(e)) => panic!("Expected Ready, got error: {}", e), + } + } }