Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 46 additions & 17 deletions compio-driver/src/sys/op/socket/iocp.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
use std::os::windows::io::AsRawSocket;

use rustix::net::RecvFlags;
use windows_sys::Win32::{
Networking::WinSock::{
LPFN_ACCEPTEX, LPFN_CONNECTEX, LPFN_GETACCEPTEXSOCKADDRS, LPFN_WSARECVMSG,
SO_UPDATE_ACCEPT_CONTEXT, SO_UPDATE_CONNECT_CONTEXT, SOCKADDR, SOCKADDR_STORAGE,
SOL_SOCKET, WSAID_ACCEPTEX, WSAID_CONNECTEX, WSAID_GETACCEPTEXSOCKADDRS, WSAID_WSARECVMSG,
WSAMSG, WSARecv, WSARecvFrom, WSASend, WSASendMsg, WSASendTo, closesocket, setsockopt,
socklen_t,
LPFN_ACCEPTEX, LPFN_CONNECTEX, LPFN_DISCONNECTEX, LPFN_GETACCEPTEXSOCKADDRS,
LPFN_WSARECVMSG, SO_UPDATE_ACCEPT_CONTEXT, SO_UPDATE_CONNECT_CONTEXT, SOCKADDR,
SOCKADDR_STORAGE, SOL_SOCKET, TF_REUSE_SOCKET, WSAID_ACCEPTEX, WSAID_CONNECTEX,
WSAID_DISCONNECTEX, WSAID_GETACCEPTEXSOCKADDRS, WSAID_WSARECVMSG, WSAMSG, WSARecv,
WSARecvFrom, WSASend, WSASendMsg, WSASendTo, closesocket, setsockopt, socklen_t,
},
System::IO::OVERLAPPED,
};
Expand Down Expand Up @@ -35,15 +33,15 @@ unsafe impl OpCode for CloseSocket {
}

/// Accept a connection.
pub struct Accept<S> {
pub struct Accept<S, SA> {
pub(crate) fd: S,
pub(crate) accept_fd: socket2::Socket,
pub(crate) accept_fd: SA,
pub(crate) buffer: [u8; ACCEPT_BUFFER_SIZE],
Comment thread
Berrysoft marked this conversation as resolved.
}

impl<S> Accept<S> {
impl<S, SA> Accept<S, SA> {
/// Create [`Accept`]. `accept_fd` should not be bound.
pub fn new(fd: S, accept_fd: socket2::Socket) -> Self {
pub fn new(fd: S, accept_fd: SA) -> Self {
Self {
fd,
accept_fd,
Expand All @@ -52,14 +50,14 @@ impl<S> Accept<S> {
}
}

impl<S: AsFd> Accept<S> {
impl<S: AsFd, SA: AsFd> Accept<S, SA> {
/// Update accept context.
pub fn update_context(&self) -> io::Result<()> {
let fd = self.fd.as_fd().as_raw_fd();
syscall!(
SOCKET,
setsockopt(
self.accept_fd.as_raw_socket() as _,
self.accept_fd.as_fd().as_raw_fd() as _,
SOL_SOCKET,
SO_UPDATE_ACCEPT_CONTEXT,
&fd as *const _ as _,
Expand All @@ -70,7 +68,7 @@ impl<S: AsFd> Accept<S> {
}

/// Get the remote address from the inner buffer.
pub fn into_addr(self) -> io::Result<(socket2::Socket, SockAddr)> {
pub fn into_addr(self) -> io::Result<(SA, SockAddr)> {
let get_addrs_fn = GET_ADDRS
.get_or_try_init(|| {
get_wsa_fn(self.fd.as_fd().as_raw_fd(), WSAID_GETACCEPTEXSOCKADDRS)
Expand Down Expand Up @@ -109,7 +107,7 @@ impl<S: AsFd> Accept<S> {
}
}

unsafe impl<S: AsFd> OpCode for Accept<S> {
unsafe impl<S: AsFd, SA: AsFd> OpCode for Accept<S, SA> {
type Control = ();

unsafe fn operate(&mut self, _: &mut (), optr: *mut OVERLAPPED) -> Poll<io::Result<usize>> {
Expand All @@ -122,7 +120,7 @@ unsafe impl<S: AsFd> OpCode for Accept<S> {
let res = unsafe {
accept_fn(
self.fd.as_fd().as_raw_fd() as _,
self.accept_fd.as_raw_socket() as _,
self.accept_fd.as_fd().as_raw_fd() as _,
self.buffer.sys_slice_mut().ptr() as _,
0,
ACCEPT_ADDR_BUFFER_SIZE as _,
Expand Down Expand Up @@ -187,7 +185,38 @@ unsafe impl<S: AsFd> OpCode for Connect<S> {
}
}

/// Receive data from remote.
/// Disconnect a connected socket and reuse it for another connection.
pub struct Disconnect<S> {
pub(crate) fd: S,
}

impl<S> Disconnect<S> {
/// Create [`Disconnect`].
pub fn new(fd: S) -> Self {
Self { fd }
}
}

static DISCONNECT_EX: OnceLock<LPFN_DISCONNECTEX> = OnceLock::new();

unsafe impl<S: AsFd> OpCode for Disconnect<S> {
type Control = ();

unsafe fn operate(&mut self, _: &mut (), optr: *mut OVERLAPPED) -> Poll<io::Result<usize>> {
let disconnect_fn = DISCONNECT_EX
.get_or_try_init(|| get_wsa_fn(self.fd.as_fd().as_raw_fd(), WSAID_DISCONNECTEX))?
.ok_or_else(|| {
io::Error::new(io::ErrorKind::Unsupported, "cannot retrieve DisconnectEx")
})?;
let res =
unsafe { disconnect_fn(self.fd.as_fd().as_raw_fd() as _, optr, TF_REUSE_SOCKET, 0) };
win32_result(res, 0)
}

fn cancel(&mut self, _: &mut (), optr: *mut OVERLAPPED) -> io::Result<()> {
cancel(self.fd.as_fd().as_raw_fd(), optr)
}
}

#[derive(Default)]
#[doc(hidden)]
Expand Down
2 changes: 1 addition & 1 deletion compio-net/src/incoming/windows.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ use crate::Socket;

pub struct Incoming<'a> {
listener: &'a Socket,
state: Option<Submit<Accept<SharedFd<Socket2>>>>,
state: Option<Submit<Accept<SharedFd<Socket2>, Socket2>>>,
}

impl<'a> Incoming<'a> {
Expand Down
23 changes: 18 additions & 5 deletions compio-net/src/socket/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,13 @@ use std::{
use compio_buf::{
BufResult, IntoInner, IoBuf, IoBufMut, IoVectoredBuf, IoVectoredBufMut, SetLen, buf_try,
};
#[cfg(windows)]
use compio_driver::op::Disconnect;
#[cfg(unix)]
use compio_driver::op::{Bind, CreateSocket, Listen, ShutdownSocket};
use compio_driver::{
AsRawFd, BufferRef, OpCode, RawFd, ResultTakeBuffer, SharedFd, TakeBuffer, ToSharedFd,
AsFd, AsRawFd, BorrowedFd, BufferRef, OpCode, RawFd, ResultTakeBuffer, SharedFd, TakeBuffer,
ToSharedFd,
op::{
Accept, BufResultExt, CloseSocket, Connect, Recv, RecvFlags, RecvFrom, RecvFromManaged,
RecvFromMulti, RecvFromMultiResult, RecvFromVectored, RecvManaged, RecvMsg, RecvMsgManaged,
Expand Down Expand Up @@ -153,6 +156,12 @@ impl Socket {
Ok(())
}

#[cfg(windows)]
pub async fn disconnect(&self) -> io::Result<()> {
let op = Disconnect::new(self.to_shared_fd());
compio_runtime::submit(op).await.0.map(|_| ())
}

#[cfg(unix)]
pub async fn accept(&self) -> io::Result<(Self, SockAddr)> {
let op = Accept::new(self.to_shared_fd());
Expand All @@ -168,11 +177,16 @@ impl Socket {
let ty = self.socket.r#type()?;
let protocol = self.socket.protocol()?;
let accept_sock = Socket2::new(domain, ty, protocol)?;
self.accept_with(Self::from_socket2(accept_sock)?).await
}

#[cfg(windows)]
pub async fn accept_with(&self, accept_sock: Self) -> io::Result<(Self, SockAddr)> {
let op = Accept::new(self.to_shared_fd(), accept_sock);
let (_, op) = buf_try!(@try compio_runtime::submit(op).await);
op.update_context()?;
let (accept_sock, addr) = op.into_addr()?;
Ok((Self::from_socket2(accept_sock)?, addr))
Ok((accept_sock, addr))
}

pub fn incoming(&self) -> Incoming<'_> {
Expand Down Expand Up @@ -642,9 +656,8 @@ impl AsRawFd for Socket {
}
}

#[cfg(unix)]
impl std::os::fd::AsFd for Socket {
fn as_fd(&self) -> std::os::fd::BorrowedFd<'_> {
impl AsFd for Socket {
fn as_fd(&self) -> BorrowedFd<'_> {
self.socket.as_fd()
}
}
Expand Down
21 changes: 21 additions & 0 deletions compio-net/src/tcp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,15 @@ impl TcpListener {
Ok((stream, addr.as_socket().expect("should be SocketAddr")))
}

/// Accepts a new incoming connection from this listener using the provided
/// socket.
#[cfg(windows)]
pub async fn accept_with(&self, sock: TcpSocket) -> io::Result<(TcpStream, SocketAddr)> {
let (socket, addr) = self.inner.accept_with(sock.inner).await?;
let stream = TcpStream { inner: socket };
Ok((stream, addr.as_socket().expect("should be SocketAddr")))
}

/// Returns a stream of incoming connections to this listener.
///
/// ## Platform specific
Expand Down Expand Up @@ -334,6 +343,16 @@ impl TcpStream {
self.inner.into_poll_fd()
}

/// Close the connection of the socket, and reuse it to create a new
/// connection. This method is useful when the socket is created by
/// [`TcpListener::accept`], and will be reused in
/// [`TcpListener::accept_with`] to accept a new connection.
#[cfg(windows)]
pub async fn disconnect(self) -> io::Result<TcpSocket> {
self.inner.disconnect().await?;
Ok(TcpSocket { inner: self.inner })
}

/// Gets the value of the `TCP_NODELAY` option on this socket.
///
/// For more information about this option, see
Expand Down Expand Up @@ -1065,6 +1084,8 @@ impl TcpSocket {
/// The [`TcpSocket`] is consumed. Once the connection is established, a
/// connected [`TcpStream`] is returned. If the connection fails, the
/// encountered error is returned.
///
/// On Windows, the socket should be bound to an address before connecting.
pub async fn connect(self, addr: SocketAddr) -> io::Result<TcpStream> {
self.inner.connect_async(&addr.into()).await?;
Ok(TcpStream { inner: self.inner })
Expand Down
19 changes: 19 additions & 0 deletions compio-net/src/unix.rs
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,15 @@ impl UnixListener {
Ok((stream, addr))
}

/// Accepts a new incoming connection from this listener using the provided
/// socket.
#[cfg(windows)]
pub async fn accept_with(&self, sock: UnixSocket) -> io::Result<(UnixStream, SockAddr)> {
let (socket, addr) = self.inner.accept_with(sock.inner).await?;
let stream = UnixStream { inner: socket };
Ok((stream, addr))
}

/// Returns a stream of incoming connections to this listener.
///
/// ## Platform specific
Expand Down Expand Up @@ -288,6 +297,16 @@ impl UnixStream {
self.inner.into_poll_fd()
}

/// Close the connection of the socket, and reuse it to create a new
/// connection. This method is useful when the socket is created by
/// [`UnixListener::accept`], and will be reused in
/// [`UnixListener::accept_with`] to accept a new connection.
#[cfg(windows)]
pub async fn disconnect(self) -> io::Result<UnixSocket> {
self.inner.disconnect().await?;
Ok(UnixSocket { inner: self.inner })
}
Comment thread
Berrysoft marked this conversation as resolved.

/// Signifies whether the underlying socket was non-empty after the last
/// receive operation.
///
Expand Down
54 changes: 54 additions & 0 deletions compio-net/tests/tcp_disconnect.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
#![cfg(windows)]

use compio_io::AsyncWrite;
use compio_net::{TcpListener, TcpStream};
use compio_runtime::ResumeUnwind;

#[test]
fn disconnect() {
compio_runtime::Runtime::new().unwrap().block_on(async {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
let task = compio_runtime::spawn(async move {
let (socket, _) = listener.accept().await.unwrap();
let socket = socket.disconnect().await.unwrap();
let (mut socket, _) = listener.accept_with(socket).await.unwrap();
socket.shutdown().await.unwrap();
socket.close().await.unwrap();
});

for _i in 0..2 {
let mut client = TcpStream::connect(addr).await.unwrap();
client.shutdown().await.unwrap();
client.close().await.unwrap();
}

task.await.resume_unwind().expect("shouldn't be cancelled");
})
}

#[test]
fn reuse() {
compio_runtime::Runtime::new().unwrap().block_on(async {
let listener1 = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr1 = listener1.local_addr().unwrap();
let listener2 = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr2 = listener2.local_addr().unwrap();

let task = compio_runtime::spawn(async move {
let (socket, _) = listener1.accept().await.unwrap();
let socket = socket.disconnect().await.unwrap();
let (mut socket, _) = listener2.accept_with(socket).await.unwrap();
socket.shutdown().await.unwrap();
socket.close().await.unwrap();
});

let client = TcpStream::connect(addr1).await.unwrap();
let client = client.disconnect().await.unwrap();
let mut client = client.connect(addr2).await.unwrap();
client.shutdown().await.unwrap();
Comment thread
Berrysoft marked this conversation as resolved.
client.close().await.unwrap();

task.await.resume_unwind().expect("shouldn't be cancelled");
})
}