Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
219 changes: 190 additions & 29 deletions src/pipeline/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,22 @@ use std::task::{Context, Poll};
use std::{error, fmt};
use tower_service::Service;

#[derive(Debug)]
struct Inner<T, S> {
transport: Option<T>,
service: Option<S>,
}

impl <T, S> Inner<T, S> {
fn take(&mut self) -> (Option<T>, Option<S>) {
(self.transport.take(), self.service.take())
}

fn is_populated(&self) -> bool {
self.transport.is_some() & self.service.is_some()
}
}

/// This type provides an implementation of a Tower
/// [`Service`](https://docs.rs/tokio-service/0.1/tokio_service/trait.Service.html) on top of a
/// request-at-a-time protocol transport. In particular, it wraps a transport that implements
Expand All @@ -17,14 +33,12 @@ use tower_service::Service;
#[derive(Debug)]
pub struct Server<T, S>
where
T: Sink<S::Response> + TryStream,
T: Sink<S::Response> + TryStream + Unpin,
S: Service<<T as TryStream>::Ok>,
{
#[pin]
pending: FuturesOrdered<S::Future>,
#[pin]
transport: T,
service: S,
inner: Inner<T, S>,

in_flight: usize,
finish: bool,
Expand All @@ -33,7 +47,7 @@ where
/// An error that occurred while servicing a request.
pub enum Error<T, S>
where
T: Sink<S::Response> + TryStream,
T: Sink<S::Response> + TryStream + Unpin,
S: Service<<T as TryStream>::Ok>,
{
/// The underlying transport failed to produce a request.
Expand All @@ -44,11 +58,14 @@ where

/// The underlying service failed to process a request.
Service(S::Error),

/// The future has completed or errored and should now be discarded.
CompletedOrErrored,
}

impl<T, S> fmt::Display for Error<T, S>
where
T: Sink<S::Response> + TryStream,
T: Sink<S::Response> + TryStream + Unpin,
S: Service<<T as TryStream>::Ok>,
<T as Sink<S::Response>>::Error: fmt::Display,
<T as TryStream>::Error: fmt::Display,
Expand All @@ -59,13 +76,14 @@ where
Error::BrokenTransportRecv(ref se) => fmt::Display::fmt(se, f),
Error::BrokenTransportSend(ref se) => fmt::Display::fmt(se, f),
Error::Service(ref se) => fmt::Display::fmt(se, f),
Error::CompletedOrErrored => write!(f, "Completed or errored future"),
}
}
}

impl<T, S> fmt::Debug for Error<T, S>
where
T: Sink<S::Response> + TryStream,
T: Sink<S::Response> + TryStream + Unpin,
S: Service<<T as TryStream>::Ok>,
<T as Sink<S::Response>>::Error: fmt::Debug,
<T as TryStream>::Error: fmt::Debug,
Expand All @@ -76,13 +94,14 @@ where
Error::BrokenTransportRecv(ref se) => write!(f, "BrokenTransportRecv({:?})", se),
Error::BrokenTransportSend(ref se) => write!(f, "BrokenTransportSend({:?})", se),
Error::Service(ref se) => write!(f, "Service({:?})", se),
Error::CompletedOrErrored => write!(f, "Completed or errored future"),
}
}
}

impl<T, S> error::Error for Error<T, S>
where
T: Sink<S::Response> + TryStream,
T: Sink<S::Response> + TryStream + Unpin,
S: Service<<T as TryStream>::Ok>,
<T as Sink<S::Response>>::Error: error::Error,
<T as TryStream>::Error: error::Error,
Expand All @@ -93,6 +112,7 @@ where
Error::BrokenTransportSend(ref se) => Some(se),
Error::BrokenTransportRecv(ref se) => Some(se),
Error::Service(ref se) => Some(se),
Error::CompletedOrErrored => None,
}
}

Expand All @@ -102,13 +122,14 @@ where
Error::BrokenTransportSend(ref se) => se.description(),
Error::BrokenTransportRecv(ref se) => se.description(),
Error::Service(ref se) => se.description(),
Error::CompletedOrErrored => "Completed or errored future",
}
}
}

impl<T, S> Error<T, S>
where
T: Sink<S::Response> + TryStream,
T: Sink<S::Response> + TryStream + Unpin,
S: Service<<T as TryStream>::Ok>,
{
fn from_sink_error(e: <T as Sink<S::Response>>::Error) -> Self {
Expand All @@ -124,9 +145,112 @@ where
}
}

impl <T, S> From<ErrorWithInner<T, S, Error<T, S>>> for Error<T, S>
where
T: Sink<S::Response> + TryStream + Unpin,
S: Service<<T as TryStream>::Ok>,
<T as Sink<S::Response>>::Error: error::Error,
<T as TryStream>::Error: error::Error,
S::Error: error::Error,
{
fn from(source: ErrorWithInner<T, S, Error<T, S>>) -> Self {
let ErrorWithInner { error, .. } = source;
error
}
}

trait MapErrWithInner<I, E>
where E: fmt::Display + fmt::Debug + error::Error,
{
fn map_err_with_inner<T, S>(self, inner: &mut Inner<T, S>)
-> Result<I, ErrorWithInner<T, S, E>>;
}

impl <I, E> MapErrWithInner<I, E> for Result<I, E>
where E: fmt::Display + fmt::Debug + error::Error,
{
fn map_err_with_inner<T, S>(self, inner: &mut Inner<T, S>)
-> Result<I, ErrorWithInner<T, S, E>>
{
match self {
Ok(t) => Ok(t),
Err(e) => {
let (transport, service) = inner.take();
Err(ErrorWithInner { error: e, transport, service })
}
}
}
}

trait MapPollErrWithInner<I, E>
where E: fmt::Display + fmt::Debug + error::Error,
{
fn map_err_with_inner<T, S>(self, inner: &mut Inner<T, S>)
-> Poll<Result<I, ErrorWithInner<T, S, E>>>;
}

impl <I, E> MapPollErrWithInner<I, E> for Poll<Result<I, E>>
where E: fmt::Display + fmt::Debug + error::Error,
{
fn map_err_with_inner<T, S>(self, inner: &mut Inner<T, S>)
-> Poll<Result<I, ErrorWithInner<T, S, E>>>
{
match self {
Poll::Ready(res) => Poll::Ready(res.map_err_with_inner(inner)),
Poll::Pending => Poll::Pending,
}
}
}

/// Error type encapsulates the inner transport and service as well as the
/// error.
pub struct ErrorWithInner<T, S, E>
{
/// Wrapped error
pub error: E,
/// Inner transport
pub transport: Option<T>,
/// Inner service
pub service: Option<S>,
}

impl<T, S, E> fmt::Display for ErrorWithInner<T, S, E>
where
E: fmt::Display + error::Error,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
<E as std::fmt::Display>::fmt(&self.error, f)
}
}

impl<T, S, E> fmt::Debug for ErrorWithInner<T, S, E>
where
E: fmt::Debug + error::Error,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
<E as std::fmt::Debug>::fmt(&self.error, f)
}
}

impl<T, S, E> error::Error for ErrorWithInner<T, S, E>
where
E: error::Error,
{
#[allow(deprecated)]
fn cause(&self) -> Option<&dyn error::Error> {
<E as error::Error>::cause(&self.error)
}

#[allow(deprecated)]
fn description(&self) -> &str {
<E as error::Error>::description(&self.error)
}
}


impl<T, S> Server<T, S>
where
T: Sink<S::Response> + TryStream,
T: Sink<S::Response> + TryStream + Unpin,
S: Service<<T as TryStream>::Ok>,
{
/// Construct a new [`Server`] over the given `transport` that services requests using the
Expand All @@ -139,8 +263,7 @@ where
pub fn new(transport: T, service: S) -> Self {
Server {
pending: FuturesOrdered::new(),
transport,
service,
inner: Inner { transport: Some(transport), service: Some(service) },
in_flight: 0,
finish: false,
}
Expand Down Expand Up @@ -178,10 +301,13 @@ where

impl<T, S> Future for Server<T, S>
where
T: Sink<S::Response> + TryStream,
T: Sink<S::Response> + TryStream + Unpin,
S: Service<<T as TryStream>::Ok>,
<T as Sink<S::Response>>::Error: error::Error,
<T as TryStream>::Error: error::Error,
<S as Service<<T as TryStream>::Ok>>::Error: error::Error,
{
type Output = Result<(), Error<T, S>>;
type Output = Result<(Option<T>, Option<S>), ErrorWithInner<T, S, Error<T, S>>>;

fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let span = tracing::trace_span!("poll");
Expand All @@ -191,19 +317,31 @@ where
// go through the deref so we can do partial borrows
let this = self.project();

// we never move transport or pending, nor do we ever hand out &mut to it
let mut transport: Pin<_> = this.transport;
// we never move pending, nor do we ever hand out &mut to it
let mut pending: Pin<_> = this.pending;

let inner: &mut Inner<T, S> = this.inner;

if !inner.is_populated() {
return Poll::Ready(
Err(Error::CompletedOrErrored)
.map_err_with_inner(inner)
)
}

// track how many times we have iterated
let mut i = 0;

loop {
// first, poll pending futures to see if any have produced responses
// note that we only poll for completed service futures if we can send the response
while let Poll::Ready(r) = transport.as_mut().poll_ready(cx) {
while let Poll::Ready(r) = Pin::new(inner.transport.as_mut().unwrap()).poll_ready(cx) {
if let Err(e) = r {
return Poll::Ready(Err(Error::from_sink_error(e)));
return Poll::Ready(
Err(Error::from_sink_error(e))
.map_err_with_inner(inner)
);
//return Poll::Ready(Err(Error::from_sink_error(e)));
}

tracing::trace!(
Expand All @@ -213,15 +351,18 @@ where
);
match pending.as_mut().try_poll_next(cx) {
Poll::Ready(Some(Err(e))) => {
return Poll::Ready(Err(Error::from_service_error(e)));
return Poll::Ready(
Err(Error::from_service_error(e))
.map_err_with_inner(inner)
);
}
Poll::Ready(Some(Ok(rsp))) => {
tracing::trace!("transport.start_send");
// try to send the response!
transport
.as_mut()
Pin::new(inner.transport.as_mut().unwrap())
.start_send(rsp)
.map_err(Error::from_sink_error)?;
.map_err(Error::from_sink_error)
.map_err_with_inner(inner)?;
*this.in_flight -= 1;
}
_ => {
Expand All @@ -233,15 +374,16 @@ where

// also try to make progress on sending
tracing::trace!(finish = *this.finish, "transport.poll_flush");
if let Poll::Ready(()) = transport
if let Poll::Ready(()) = Pin::new(inner.transport.as_mut().unwrap())
.as_mut()
.poll_flush(cx)
.map_err(Error::from_sink_error)?
.map_err(Error::from_sink_error)
.map_err_with_inner(inner)?
{
if *this.finish && pending.as_mut().is_empty() {
// there are no more requests
// and we've finished all the work!
return Poll::Ready(Ok(()));
return Poll::Ready(Ok(inner.take()));
}
}

Expand All @@ -262,16 +404,35 @@ where

// is the service ready?
tracing::trace!("service.poll_ready");
ready!(this.service.poll_ready(cx)).map_err(Error::from_service_error)?;
ready!(inner
.service
.as_mut()
.unwrap()
.poll_ready(cx)
).map_err(Error::from_service_error).map_err_with_inner(inner)?;

tracing::trace!("transport.poll_next");
let rq = ready!(transport.as_mut().try_poll_next(cx))
let rq = ready!(
Pin::new(
inner
.transport
.as_mut()
.unwrap()
).try_poll_next(cx))
.transpose()
.map_err(Error::from_stream_error)?;
.map_err(Error::from_stream_error)
.map_err_with_inner(inner)?;

if let Some(rq) = rq {
// the service is ready, and we have another request!
// you know what that means:
pending.push(this.service.call(rq));
pending.push(
inner
.service
.as_mut()
.unwrap()
.call(rq)
);
*this.in_flight += 1;
} else {
// there are no more requests coming
Expand Down