Skip to content

Commit 5289d65

Browse files
authored
feat(driver,net)!: reuse socket on Windows (#887)
* feat(driver,net)!: accept with provided socket * feat(driver,net,iocp): disconnect socket and reuse * docs(net): disconnect * test(net): connect reuse
1 parent 6501181 commit 5289d65

6 files changed

Lines changed: 159 additions & 23 deletions

File tree

compio-driver/src/sys/op/socket/iocp.rs

Lines changed: 46 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,11 @@
1-
use std::os::windows::io::AsRawSocket;
2-
31
use rustix::net::RecvFlags;
42
use windows_sys::Win32::{
53
Networking::WinSock::{
6-
LPFN_ACCEPTEX, LPFN_CONNECTEX, LPFN_GETACCEPTEXSOCKADDRS, LPFN_WSARECVMSG,
7-
SO_UPDATE_ACCEPT_CONTEXT, SO_UPDATE_CONNECT_CONTEXT, SOCKADDR, SOCKADDR_STORAGE,
8-
SOL_SOCKET, WSAID_ACCEPTEX, WSAID_CONNECTEX, WSAID_GETACCEPTEXSOCKADDRS, WSAID_WSARECVMSG,
9-
WSAMSG, WSARecv, WSARecvFrom, WSASend, WSASendMsg, WSASendTo, closesocket, setsockopt,
10-
socklen_t,
4+
LPFN_ACCEPTEX, LPFN_CONNECTEX, LPFN_DISCONNECTEX, LPFN_GETACCEPTEXSOCKADDRS,
5+
LPFN_WSARECVMSG, SO_UPDATE_ACCEPT_CONTEXT, SO_UPDATE_CONNECT_CONTEXT, SOCKADDR,
6+
SOCKADDR_STORAGE, SOL_SOCKET, TF_REUSE_SOCKET, WSAID_ACCEPTEX, WSAID_CONNECTEX,
7+
WSAID_DISCONNECTEX, WSAID_GETACCEPTEXSOCKADDRS, WSAID_WSARECVMSG, WSAMSG, WSARecv,
8+
WSARecvFrom, WSASend, WSASendMsg, WSASendTo, closesocket, setsockopt, socklen_t,
119
},
1210
System::IO::OVERLAPPED,
1311
};
@@ -35,15 +33,15 @@ unsafe impl OpCode for CloseSocket {
3533
}
3634

3735
/// Accept a connection.
38-
pub struct Accept<S> {
36+
pub struct Accept<S, SA> {
3937
pub(crate) fd: S,
40-
pub(crate) accept_fd: socket2::Socket,
38+
pub(crate) accept_fd: SA,
4139
pub(crate) buffer: [u8; ACCEPT_BUFFER_SIZE],
4240
}
4341

44-
impl<S> Accept<S> {
42+
impl<S, SA> Accept<S, SA> {
4543
/// Create [`Accept`]. `accept_fd` should not be bound.
46-
pub fn new(fd: S, accept_fd: socket2::Socket) -> Self {
44+
pub fn new(fd: S, accept_fd: SA) -> Self {
4745
Self {
4846
fd,
4947
accept_fd,
@@ -52,14 +50,14 @@ impl<S> Accept<S> {
5250
}
5351
}
5452

55-
impl<S: AsFd> Accept<S> {
53+
impl<S: AsFd, SA: AsFd> Accept<S, SA> {
5654
/// Update accept context.
5755
pub fn update_context(&self) -> io::Result<()> {
5856
let fd = self.fd.as_fd().as_raw_fd();
5957
syscall!(
6058
SOCKET,
6159
setsockopt(
62-
self.accept_fd.as_raw_socket() as _,
60+
self.accept_fd.as_fd().as_raw_fd() as _,
6361
SOL_SOCKET,
6462
SO_UPDATE_ACCEPT_CONTEXT,
6563
&fd as *const _ as _,
@@ -70,7 +68,7 @@ impl<S: AsFd> Accept<S> {
7068
}
7169

7270
/// Get the remote address from the inner buffer.
73-
pub fn into_addr(self) -> io::Result<(socket2::Socket, SockAddr)> {
71+
pub fn into_addr(self) -> io::Result<(SA, SockAddr)> {
7472
let get_addrs_fn = GET_ADDRS
7573
.get_or_try_init(|| {
7674
get_wsa_fn(self.fd.as_fd().as_raw_fd(), WSAID_GETACCEPTEXSOCKADDRS)
@@ -109,7 +107,7 @@ impl<S: AsFd> Accept<S> {
109107
}
110108
}
111109

112-
unsafe impl<S: AsFd> OpCode for Accept<S> {
110+
unsafe impl<S: AsFd, SA: AsFd> OpCode for Accept<S, SA> {
113111
type Control = ();
114112

115113
unsafe fn operate(&mut self, _: &mut (), optr: *mut OVERLAPPED) -> Poll<io::Result<usize>> {
@@ -122,7 +120,7 @@ unsafe impl<S: AsFd> OpCode for Accept<S> {
122120
let res = unsafe {
123121
accept_fn(
124122
self.fd.as_fd().as_raw_fd() as _,
125-
self.accept_fd.as_raw_socket() as _,
123+
self.accept_fd.as_fd().as_raw_fd() as _,
126124
self.buffer.sys_slice_mut().ptr() as _,
127125
0,
128126
ACCEPT_ADDR_BUFFER_SIZE as _,
@@ -187,7 +185,38 @@ unsafe impl<S: AsFd> OpCode for Connect<S> {
187185
}
188186
}
189187

190-
/// Receive data from remote.
188+
/// Disconnect a connected socket and reuse it for another connection.
189+
pub struct Disconnect<S> {
190+
pub(crate) fd: S,
191+
}
192+
193+
impl<S> Disconnect<S> {
194+
/// Create [`Disconnect`].
195+
pub fn new(fd: S) -> Self {
196+
Self { fd }
197+
}
198+
}
199+
200+
static DISCONNECT_EX: OnceLock<LPFN_DISCONNECTEX> = OnceLock::new();
201+
202+
unsafe impl<S: AsFd> OpCode for Disconnect<S> {
203+
type Control = ();
204+
205+
unsafe fn operate(&mut self, _: &mut (), optr: *mut OVERLAPPED) -> Poll<io::Result<usize>> {
206+
let disconnect_fn = DISCONNECT_EX
207+
.get_or_try_init(|| get_wsa_fn(self.fd.as_fd().as_raw_fd(), WSAID_DISCONNECTEX))?
208+
.ok_or_else(|| {
209+
io::Error::new(io::ErrorKind::Unsupported, "cannot retrieve DisconnectEx")
210+
})?;
211+
let res =
212+
unsafe { disconnect_fn(self.fd.as_fd().as_raw_fd() as _, optr, TF_REUSE_SOCKET, 0) };
213+
win32_result(res, 0)
214+
}
215+
216+
fn cancel(&mut self, _: &mut (), optr: *mut OVERLAPPED) -> io::Result<()> {
217+
cancel(self.fd.as_fd().as_raw_fd(), optr)
218+
}
219+
}
191220

192221
#[derive(Default)]
193222
#[doc(hidden)]

compio-net/src/incoming/windows.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ use crate::Socket;
1414

1515
pub struct Incoming<'a> {
1616
listener: &'a Socket,
17-
state: Option<Submit<Accept<SharedFd<Socket2>>>>,
17+
state: Option<Submit<Accept<SharedFd<Socket2>, Socket2>>>,
1818
}
1919

2020
impl<'a> Incoming<'a> {

compio-net/src/socket/mod.rs

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,13 @@ use std::{
99
use compio_buf::{
1010
BufResult, IntoInner, IoBuf, IoBufMut, IoVectoredBuf, IoVectoredBufMut, SetLen, buf_try,
1111
};
12+
#[cfg(windows)]
13+
use compio_driver::op::Disconnect;
1214
#[cfg(unix)]
1315
use compio_driver::op::{Bind, CreateSocket, Listen, ShutdownSocket};
1416
use compio_driver::{
15-
AsRawFd, BufferRef, OpCode, RawFd, ResultTakeBuffer, SharedFd, TakeBuffer, ToSharedFd,
17+
AsFd, AsRawFd, BorrowedFd, BufferRef, OpCode, RawFd, ResultTakeBuffer, SharedFd, TakeBuffer,
18+
ToSharedFd,
1619
op::{
1720
Accept, BufResultExt, CloseSocket, Connect, Recv, RecvFlags, RecvFrom, RecvFromManaged,
1821
RecvFromMulti, RecvFromMultiResult, RecvFromVectored, RecvManaged, RecvMsg, RecvMsgManaged,
@@ -153,6 +156,12 @@ impl Socket {
153156
Ok(())
154157
}
155158

159+
#[cfg(windows)]
160+
pub async fn disconnect(&self) -> io::Result<()> {
161+
let op = Disconnect::new(self.to_shared_fd());
162+
compio_runtime::submit(op).await.0.map(|_| ())
163+
}
164+
156165
#[cfg(unix)]
157166
pub async fn accept(&self) -> io::Result<(Self, SockAddr)> {
158167
let op = Accept::new(self.to_shared_fd());
@@ -168,11 +177,16 @@ impl Socket {
168177
let ty = self.socket.r#type()?;
169178
let protocol = self.socket.protocol()?;
170179
let accept_sock = Socket2::new(domain, ty, protocol)?;
180+
self.accept_with(Self::from_socket2(accept_sock)?).await
181+
}
182+
183+
#[cfg(windows)]
184+
pub async fn accept_with(&self, accept_sock: Self) -> io::Result<(Self, SockAddr)> {
171185
let op = Accept::new(self.to_shared_fd(), accept_sock);
172186
let (_, op) = buf_try!(@try compio_runtime::submit(op).await);
173187
op.update_context()?;
174188
let (accept_sock, addr) = op.into_addr()?;
175-
Ok((Self::from_socket2(accept_sock)?, addr))
189+
Ok((accept_sock, addr))
176190
}
177191

178192
pub fn incoming(&self) -> Incoming<'_> {
@@ -642,9 +656,8 @@ impl AsRawFd for Socket {
642656
}
643657
}
644658

645-
#[cfg(unix)]
646-
impl std::os::fd::AsFd for Socket {
647-
fn as_fd(&self) -> std::os::fd::BorrowedFd<'_> {
659+
impl AsFd for Socket {
660+
fn as_fd(&self) -> BorrowedFd<'_> {
648661
self.socket.as_fd()
649662
}
650663
}

compio-net/src/tcp.rs

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,15 @@ impl TcpListener {
119119
Ok((stream, addr.as_socket().expect("should be SocketAddr")))
120120
}
121121

122+
/// Accepts a new incoming connection from this listener using the provided
123+
/// socket.
124+
#[cfg(windows)]
125+
pub async fn accept_with(&self, sock: TcpSocket) -> io::Result<(TcpStream, SocketAddr)> {
126+
let (socket, addr) = self.inner.accept_with(sock.inner).await?;
127+
let stream = TcpStream { inner: socket };
128+
Ok((stream, addr.as_socket().expect("should be SocketAddr")))
129+
}
130+
122131
/// Returns a stream of incoming connections to this listener.
123132
///
124133
/// ## Platform specific
@@ -334,6 +343,16 @@ impl TcpStream {
334343
self.inner.into_poll_fd()
335344
}
336345

346+
/// Close the connection of the socket, and reuse it to create a new
347+
/// connection. This method is useful when the socket is created by
348+
/// [`TcpListener::accept`], and will be reused in
349+
/// [`TcpListener::accept_with`] to accept a new connection.
350+
#[cfg(windows)]
351+
pub async fn disconnect(self) -> io::Result<TcpSocket> {
352+
self.inner.disconnect().await?;
353+
Ok(TcpSocket { inner: self.inner })
354+
}
355+
337356
/// Gets the value of the `TCP_NODELAY` option on this socket.
338357
///
339358
/// For more information about this option, see
@@ -1065,6 +1084,8 @@ impl TcpSocket {
10651084
/// The [`TcpSocket`] is consumed. Once the connection is established, a
10661085
/// connected [`TcpStream`] is returned. If the connection fails, the
10671086
/// encountered error is returned.
1087+
///
1088+
/// On Windows, the socket should be bound to an address before connecting.
10681089
pub async fn connect(self, addr: SocketAddr) -> io::Result<TcpStream> {
10691090
self.inner.connect_async(&addr.into()).await?;
10701091
Ok(TcpStream { inner: self.inner })

compio-net/src/unix.rs

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,15 @@ impl UnixListener {
116116
Ok((stream, addr))
117117
}
118118

119+
/// Accepts a new incoming connection from this listener using the provided
120+
/// socket.
121+
#[cfg(windows)]
122+
pub async fn accept_with(&self, sock: UnixSocket) -> io::Result<(UnixStream, SockAddr)> {
123+
let (socket, addr) = self.inner.accept_with(sock.inner).await?;
124+
let stream = UnixStream { inner: socket };
125+
Ok((stream, addr))
126+
}
127+
119128
/// Returns a stream of incoming connections to this listener.
120129
///
121130
/// ## Platform specific
@@ -288,6 +297,16 @@ impl UnixStream {
288297
self.inner.into_poll_fd()
289298
}
290299

300+
/// Close the connection of the socket, and reuse it to create a new
301+
/// connection. This method is useful when the socket is created by
302+
/// [`UnixListener::accept`], and will be reused in
303+
/// [`UnixListener::accept_with`] to accept a new connection.
304+
#[cfg(windows)]
305+
pub async fn disconnect(self) -> io::Result<UnixSocket> {
306+
self.inner.disconnect().await?;
307+
Ok(UnixSocket { inner: self.inner })
308+
}
309+
291310
/// Signifies whether the underlying socket was non-empty after the last
292311
/// receive operation.
293312
///

compio-net/tests/tcp_disconnect.rs

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
#![cfg(windows)]
2+
3+
use compio_io::AsyncWrite;
4+
use compio_net::{TcpListener, TcpStream};
5+
use compio_runtime::ResumeUnwind;
6+
7+
#[test]
8+
fn disconnect() {
9+
compio_runtime::Runtime::new().unwrap().block_on(async {
10+
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
11+
let addr = listener.local_addr().unwrap();
12+
let task = compio_runtime::spawn(async move {
13+
let (socket, _) = listener.accept().await.unwrap();
14+
let socket = socket.disconnect().await.unwrap();
15+
let (mut socket, _) = listener.accept_with(socket).await.unwrap();
16+
socket.shutdown().await.unwrap();
17+
socket.close().await.unwrap();
18+
});
19+
20+
for _i in 0..2 {
21+
let mut client = TcpStream::connect(addr).await.unwrap();
22+
client.shutdown().await.unwrap();
23+
client.close().await.unwrap();
24+
}
25+
26+
task.await.resume_unwind().expect("shouldn't be cancelled");
27+
})
28+
}
29+
30+
#[test]
31+
fn reuse() {
32+
compio_runtime::Runtime::new().unwrap().block_on(async {
33+
let listener1 = TcpListener::bind("127.0.0.1:0").await.unwrap();
34+
let addr1 = listener1.local_addr().unwrap();
35+
let listener2 = TcpListener::bind("127.0.0.1:0").await.unwrap();
36+
let addr2 = listener2.local_addr().unwrap();
37+
38+
let task = compio_runtime::spawn(async move {
39+
let (socket, _) = listener1.accept().await.unwrap();
40+
let socket = socket.disconnect().await.unwrap();
41+
let (mut socket, _) = listener2.accept_with(socket).await.unwrap();
42+
socket.shutdown().await.unwrap();
43+
socket.close().await.unwrap();
44+
});
45+
46+
let client = TcpStream::connect(addr1).await.unwrap();
47+
let client = client.disconnect().await.unwrap();
48+
let mut client = client.connect(addr2).await.unwrap();
49+
client.shutdown().await.unwrap();
50+
client.close().await.unwrap();
51+
52+
task.await.resume_unwind().expect("shouldn't be cancelled");
53+
})
54+
}

0 commit comments

Comments
 (0)