Skip to content

avoid waker clones #3748

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions bindings/rust/integration/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ s2n-tls = { path = "../s2n-tls", features = ["testing"] }
s2n-tls-sys = { path = "../s2n-tls-sys" }
criterion = { version = "0.3", features = ["html_reports"] }
anyhow = "1"
futures-test = "0.3"

[[bench]]
name = "handshake"
Expand Down
3 changes: 2 additions & 1 deletion bindings/rust/integration/benches/handshake.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// SPDX-License-Identifier: Apache-2.0

use criterion::{criterion_group, criterion_main, Criterion};
use futures_test::task::noop_context;
use s2n_tls::{
security,
testing::{build_config, s2n_tls_pair},
Expand All @@ -15,7 +16,7 @@ pub fn handshake(c: &mut Criterion) {
group.bench_function(format!("handshake_{:?}", policy), move |b| {
// This does include connection initalization overhead.
// TODO: create a separate benchamrk that excludes this step.
b.iter(|| s2n_tls_pair(config.clone()));
b.iter(|| s2n_tls_pair(config.clone(), &mut noop_context()));
});
}

Expand Down
101 changes: 49 additions & 52 deletions bindings/rust/s2n-tls-tokio/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -105,18 +105,17 @@ where
{
type Output = Result<(), Error>;

fn poll(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Self::Output> {
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
// Retrieve a result, either from the stored error
// or by polling Connection::poll_negotiate().
// Connection::poll_negotiate() only completes once,
// regardless of how often this method is polled.
let result = match self.error.take() {
Some(err) => Err(err),
None => {
ready!(self.tls.with_io(ctx, |context| {
let conn = context.get_mut().as_mut();
conn.poll_negotiate().map(|r| r.map(|_| ()))
}))
ready!(self
.tls
.with_io(|conn| { conn.poll_negotiate(cx).map(|r| r.map(|_| ())) }))
}
};
// If the result isn't a fatal error, return it immediately.
Expand All @@ -128,7 +127,7 @@ where
match result {
Ok(r) => Ok(r).into(),
Err(e) if e.is_retryable() => Err(e).into(),
Err(e) => match Pin::new(&mut self.tls).poll_shutdown(ctx) {
Err(e) => match Pin::new(&mut self.tls).poll_shutdown(cx) {
Pending => {
self.error = Some(e);
Pending
Expand Down Expand Up @@ -181,30 +180,30 @@ where
Ok(tls)
}

fn with_io<F, R>(&mut self, ctx: &mut Context, action: F) -> Poll<Result<R, Error>>
fn with_io<F, R>(&mut self, action: F) -> Poll<Result<R, Error>>
where
F: FnOnce(Pin<&mut Self>) -> Poll<Result<R, Error>>,
F: FnOnce(&mut Connection) -> Poll<Result<R, Error>>,
{
// Setting contexts on a connection is considered unsafe
// because the raw pointers provide no lifetime or memory guarantees.
// We protect against this by pinning the stream during the action
// and clearing the context afterwards.
unsafe {
let context = self as *mut Self as *mut c_void;
let conn = self.as_mut();

self.as_mut().set_receive_callback(Some(Self::recv_io_cb))?;
self.as_mut().set_send_callback(Some(Self::send_io_cb))?;
self.as_mut().set_receive_context(context)?;
self.as_mut().set_send_context(context)?;
self.as_mut().set_waker(Some(ctx.waker()))?;
conn.set_receive_callback(Some(Self::recv_io_cb))?;
conn.set_send_callback(Some(Self::send_io_cb))?;
conn.set_receive_context(context)?;
conn.set_send_context(context)?;

let result = action(Pin::new(self));
let result = action(conn);

conn.set_receive_callback(None)?;
conn.set_send_callback(None)?;
conn.set_receive_context(std::ptr::null_mut())?;
conn.set_send_context(std::ptr::null_mut())?;

self.as_mut().set_receive_callback(None)?;
self.as_mut().set_send_callback(None)?;
self.as_mut().set_receive_context(std::ptr::null_mut())?;
self.as_mut().set_send_context(std::ptr::null_mut())?;
self.as_mut().set_waker(None)?;
result
}
}
Expand All @@ -215,18 +214,20 @@ where
{
debug_assert_ne!(ctx, std::ptr::null_mut());
let tls = unsafe { &mut *(ctx as *mut Self) };

let mut async_context = Context::from_waker(tls.conn.as_ref().waker().unwrap());
let stream = Pin::new(&mut tls.stream);

match action(stream, &mut async_context) {
Poll::Ready(Ok(len)) => len as c_int,
Poll::Pending => {
set_errno(Errno(libc::EWOULDBLOCK));
CallbackResult::Failure.into()
tls.conn.as_mut().with_async_context(|async_context| {
let async_context = async_context.unwrap();

match action(stream, async_context) {
Poll::Ready(Ok(len)) => len as c_int,
Poll::Pending => {
set_errno(Errno(libc::EWOULDBLOCK));
CallbackResult::Failure.into()
}
_ => CallbackResult::Failure.into(),
}
_ => CallbackResult::Failure.into(),
}
})
}

unsafe extern "C" fn recv_io_cb(ctx: *mut c_void, buf: *mut u8, len: u32) -> c_int {
Expand Down Expand Up @@ -288,14 +289,14 @@ where
/// internally) returns ready.
pub fn poll_blinding(
mut self: Pin<&mut Self>,
ctx: &mut Context<'_>,
cx: &mut Context<'_>,
) -> Poll<Result<(), Error>> {
self.as_mut().set_blinding_timer(Ok(()))?;

let tls = self.get_mut();

if let Some(blinding) = &mut tls.blinding {
ready!(blinding.as_mut().project().timer.as_mut().poll(ctx));
ready!(blinding.as_mut().project().timer.as_mut().poll(cx));

// Set blinding to None to ensure the next go can have blinding
let mut blinding = tls.blinding.take().unwrap();
Expand Down Expand Up @@ -340,17 +341,15 @@ where
{
fn poll_read(
self: Pin<&mut Self>,
ctx: &mut Context<'_>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
let tls = self.get_mut();
tls.with_io(ctx, |mut context| {
tls.with_io(|context| {
context
.conn
.as_mut()
// Safe since poll_recv_uninitialized does not
// deinitialize any bytes.
.poll_recv_uninitialized(unsafe { buf.unfilled_mut() })
.poll_recv_uninitialized(cx, unsafe { buf.unfilled_mut() })
.map_ok(|size| {
unsafe {
// Safe since poll_recv_uninitialized guaranteed
Expand All @@ -372,41 +371,39 @@ where
{
fn poll_write(
self: Pin<&mut Self>,
ctx: &mut Context<'_>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
let tls = self.get_mut();
tls.with_io(ctx, |mut context| context.conn.as_mut().poll_send(buf))
tls.with_io(|context| context.poll_send(cx, buf))
.map_err(io::Error::from)
}

fn poll_flush(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<io::Result<()>> {
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
let tls = self.get_mut();

ready!(tls.with_io(ctx, |mut context| {
context.conn.as_mut().poll_flush().map(|r| r.map(|_| ()))
}))
.map_err(io::Error::from)?;
ready!(tls.with_io(|context| { context.poll_flush(cx).map(|r| r.map(|_| ())) }))
.map_err(io::Error::from)?;

Pin::new(&mut tls.stream).poll_flush(ctx)
Pin::new(&mut tls.stream).poll_flush(cx)
}

fn poll_shutdown(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<io::Result<()>> {
ready!(self.as_mut().poll_blinding(ctx))?;
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
ready!(self.as_mut().poll_blinding(cx))?;

let status = ready!(self.as_mut().with_io(ctx, |mut context| {
context.conn.as_mut().poll_shutdown().map(|r| r.map(|_| ()))
}));
let status = ready!(self
.as_mut()
.with_io(|context| { context.poll_shutdown(cx).map(|r| r.map(|_| ())) }));

if let Err(e) = status {
// In case of an error shutting down, make sure you wait for
// the blinding timeout.
self.as_mut().set_blinding_timer(Err(e))?;
ready!(self.as_mut().poll_blinding(ctx))?;
ready!(self.as_mut().poll_blinding(cx))?;
unreachable!("should have returned the error we just put in!");
}

Pin::new(&mut self.as_mut().stream).poll_shutdown(ctx)
Pin::new(&mut self.as_mut().stream).poll_shutdown(cx)
}
}

Expand Down Expand Up @@ -437,7 +434,7 @@ where
{
type Output = Result<(), Error>;

fn poll(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Self::Output> {
Pin::new(&mut *self.as_mut().stream).poll_blinding(ctx)
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
Pin::new(&mut *self.as_mut().stream).poll_blinding(cx)
}
}
4 changes: 2 additions & 2 deletions bindings/rust/s2n-tls/src/callbacks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@
//! [`VerifyHostNameCallback`] as an example.
//! * "async" callbacks return a [Poll](`core::task::Poll`) and should not block the task performing the handshake.
//! They will be polled until they return [Poll::Ready](`core::task::Poll::Ready`).
//! [Connection::waker()](`crate::connection::Connection::waker()`)
//! can be used to register the task for wakeup. See [`ClientHelloCallback`] as an example.
//! The `cx` argument to `poll` can be used to register the task for wakeup.
//! See [`ClientHelloCallback`] as an example.

use crate::{config::Context, connection::Connection};
use core::{mem::ManuallyDrop, ptr::NonNull, time::Duration};
Expand Down
28 changes: 14 additions & 14 deletions bindings/rust/s2n-tls/src/callbacks/pkey.rs
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ mod tests {
config, connection, error, security, testing,
testing::{s2n_tls::*, *},
};
use core::task::{Poll, Waker};
use core::task::Poll;
use futures_test::task::new_count_waker;
use openssl::{ec::EcKey, ecdsa::EcdsaSig};

Expand All @@ -140,10 +140,7 @@ mod tests {
"/../../../tests/pems/ecdsa_p384_pkcs1_cert.pem"
));

fn new_pair<T>(
callback: T,
waker: Waker,
) -> Result<Pair<s2n_tls::Harness, s2n_tls::Harness>, Error>
fn new_pair<T>(callback: T) -> Result<Pair<s2n_tls::Harness, s2n_tls::Harness>, Error>
where
T: 'static + PrivateKeyCallback,
{
Expand All @@ -161,7 +158,6 @@ mod tests {
let server = {
let mut server = connection::Connection::new_server();
server.set_config(config.clone())?;
server.set_waker(Some(&waker))?;
Harness::new(server)
};

Expand Down Expand Up @@ -213,13 +209,14 @@ mod tests {
}

let (waker, wake_count) = new_count_waker();
let mut cx = std::task::Context::from_waker(&waker);
let counter = testing::Counter::default();
let callback = TestPkeyCallback(counter.clone());
let pair = new_pair(callback, waker)?;
let pair = new_pair(callback)?;

assert_eq!(counter.count(), 0);
assert_eq!(wake_count, 0);
poll_tls_pair(pair);
poll_tls_pair(pair, &mut cx);
assert_eq!(counter.count(), 1);
assert_eq!(wake_count, 0);

Expand Down Expand Up @@ -271,13 +268,14 @@ mod tests {
}

let (waker, wake_count) = new_count_waker();
let mut cx = std::task::Context::from_waker(&waker);
let counter = testing::Counter::default();
let callback = TestPkeyCallback(counter.clone());
let pair = new_pair(callback, waker)?;
let pair = new_pair(callback)?;

assert_eq!(counter.count(), 0);
assert_eq!(wake_count, 0);
poll_tls_pair(pair);
poll_tls_pair(pair, &mut cx);
assert_eq!(counter.count(), 1);
assert_eq!(wake_count, POLL_COUNT);

Expand All @@ -301,13 +299,14 @@ mod tests {
}

let (waker, wake_count) = new_count_waker();
let mut cx = std::task::Context::from_waker(&waker);
let counter = testing::Counter::default();
let callback = TestPkeyCallback(counter.clone());
let pair = new_pair(callback, waker)?;
let pair = new_pair(callback)?;

assert_eq!(counter.count(), 0);
assert_eq!(wake_count, 0);
let result = poll_tls_pair_result(pair);
let result = poll_tls_pair_result(pair, &mut cx);
assert_eq!(counter.count(), 1);
assert_eq!(wake_count, 0);

Expand Down Expand Up @@ -357,13 +356,14 @@ mod tests {
}

let (waker, wake_count) = new_count_waker();
let mut cx = std::task::Context::from_waker(&waker);
let counter = testing::Counter::default();
let callback = TestPkeyCallback(counter.clone());
let pair = new_pair(callback, waker)?;
let pair = new_pair(callback)?;

assert_eq!(counter.count(), 0);
assert_eq!(wake_count, 0);
let result = poll_tls_pair_result(pair);
let result = poll_tls_pair_result(pair, &mut cx);
assert_eq!(counter.count(), 1);
assert_eq!(wake_count, POLL_COUNT);

Expand Down
Loading