Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 45 additions & 21 deletions crates/invoker-impl/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,25 @@ use std::ops::RangeInclusive;
use std::time::Duration;
use tokio::task::JoinError;

/// Error produced by the invoker: an [`InvokerErrorKind`] paired with the
/// deployment it relates to, when one is known.
#[derive(Debug)]
pub struct InvokerError {
/// The underlying error describing what went wrong.
pub kind: InvokerErrorKind,
/// Deployment ID associated with the error, if any.
// NOTE(review): presumably nothing reads this field yet, hence the `dead_code` allow — confirm.
#[allow(dead_code)]
pub deployment_id: Option<DeploymentId>,
}

impl From<InvokerErrorKind> for InvokerError {
fn from(source: InvokerErrorKind) -> Self {
Self {
kind: source,
deployment_id: None,
}
}
}

#[derive(Debug, thiserror::Error, codederror::CodedError)]
pub(crate) enum InvokerError {
pub(crate) enum InvokerErrorKind {
#[error("no deployment was found to process the invocation")]
#[code(restate_errors::RT0011)]
NoDeploymentForService,
Expand Down Expand Up @@ -169,14 +186,21 @@ pub(crate) enum InvokerError {
ServiceUnavailable(http::StatusCode),
}

impl InvokerError {
impl InvokerErrorKind {
/// Consumes this error kind and wraps it into an [`InvokerError`] tagged with
/// the given `deployment_id`.
pub(crate) fn into_invoker_error(self, deployment_id: DeploymentId) -> InvokerError {
    let deployment_id = Some(deployment_id);
    InvokerError {
        kind: self,
        deployment_id,
    }
}

pub(crate) fn error_stacktrace(&self) -> Option<&str> {
match self {
InvokerError::Sdk(s) => s
InvokerErrorKind::Sdk(s) => s
.error
.stacktrace()
.and_then(|s| if s.is_empty() { None } else { Some(s) }),
InvokerError::SdkV2(s) => s
InvokerErrorKind::SdkV2(s) => s
.error
.stacktrace()
.and_then(|s| if s.is_empty() { None } else { Some(s) }),
Expand All @@ -185,30 +209,30 @@ impl InvokerError {
}

pub(crate) fn is_transient(&self) -> bool {
!matches!(self, InvokerError::NotInvoked)
!matches!(self, InvokerErrorKind::NotInvoked)
}

pub(crate) fn should_bump_start_message_retry_count_since_last_stored_entry(&self) -> bool {
!matches!(
self,
InvokerError::NotInvoked
| InvokerError::JournalReader(_)
| InvokerError::StateReader(_)
| InvokerError::NoDeploymentForService
| InvokerError::BadNegotiatedServiceProtocolVersion(_)
| InvokerError::UnknownDeployment(_)
| InvokerError::ResumeWithWrongServiceProtocolVersion(_)
| InvokerError::IncompatibleServiceEndpoint(_, _)
InvokerErrorKind::NotInvoked
| InvokerErrorKind::JournalReader(_)
| InvokerErrorKind::StateReader(_)
| InvokerErrorKind::NoDeploymentForService
| InvokerErrorKind::BadNegotiatedServiceProtocolVersion(_)
| InvokerErrorKind::UnknownDeployment(_)
| InvokerErrorKind::ResumeWithWrongServiceProtocolVersion(_)
| InvokerErrorKind::IncompatibleServiceEndpoint(_, _)
)
}

pub(crate) fn next_retry_interval_override(&self) -> Option<Duration> {
match self {
InvokerError::Sdk(SdkInvocationError {
InvokerErrorKind::Sdk(SdkInvocationError {
next_retry_interval_override,
..
}) => *next_retry_interval_override,
InvokerError::SdkV2(SdkInvocationErrorV2 {
InvokerErrorKind::SdkV2(SdkInvocationErrorV2 {
next_retry_interval_override,
..
}) => *next_retry_interval_override,
Expand All @@ -218,9 +242,9 @@ impl InvokerError {

pub(crate) fn into_invocation_error(self) -> InvocationError {
match self {
InvokerError::Sdk(sdk_error) => sdk_error.error,
InvokerError::SdkV2(sdk_error) => sdk_error.error,
InvokerError::EntryEnrichment(entry_index, entry_type, e) => {
InvokerErrorKind::Sdk(sdk_error) => sdk_error.error,
InvokerErrorKind::SdkV2(sdk_error) => sdk_error.error,
InvokerErrorKind::EntryEnrichment(entry_index, entry_type, e) => {
let msg = format!(
"Error when processing entry {} of type {}: {}",
entry_index,
Expand All @@ -233,7 +257,7 @@ impl InvokerError {
}
err
}
e @ InvokerError::BadNegotiatedServiceProtocolVersion(_) => {
e @ InvokerErrorKind::BadNegotiatedServiceProtocolVersion(_) => {
InvocationError::new(codes::UNSUPPORTED_MEDIA_TYPE, e)
}
e => InvocationError::internal(e),
Expand All @@ -243,11 +267,11 @@ impl InvokerError {
pub(crate) fn into_invocation_error_report(mut self) -> InvocationErrorReport {
let doc_error_code = codederror::CodedError::code(&self);
let maybe_related_entry = match self {
InvokerError::Sdk(SdkInvocationError {
InvokerErrorKind::Sdk(SdkInvocationError {
ref mut related_entry,
..
}) => related_entry.take(),
InvokerError::SdkV2(SdkInvocationErrorV2 {
InvokerErrorKind::SdkV2(SdkInvocationErrorV2 {
related_command: ref mut related_entry,
..
}) => related_entry
Expand Down
80 changes: 54 additions & 26 deletions crates/invoker-impl/src/invocation_task/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ mod service_protocol_runner_v4;

use super::Notification;

use crate::error::InvokerError;
use crate::error::{InvokerError, InvokerErrorKind};
use crate::invocation_task::service_protocol_runner::ServiceProtocolRunner;
use crate::metric_definitions::INVOKER_TASK_DURATION;
use bytes::Bytes;
Expand Down Expand Up @@ -154,19 +154,34 @@ pub(super) struct InvocationTask<IR, EE, DMR> {
}

/// Needed to split `run_internal` into multiple loop functions while still supporting short-circuiting.
enum TerminalLoopState<T> {
enum TerminalLoopState<T, E = InvokerErrorKind> {
Continue(T),
Closed,
Suspended(HashSet<EntryIndex>),
SuspendedV2(HashSet<NotificationId>),
Failed(InvokerError),
Failed(E),
}

impl<T, E> TerminalLoopState<T, E> {
    /// Converts the error type of this state.
    ///
    /// `f` is applied only to the error carried by `Failed`; every other
    /// variant is passed through with its payload unchanged.
    pub fn map_err<F>(self, f: impl FnOnce(E) -> F) -> TerminalLoopState<T, F> {
        match self {
            TerminalLoopState::Continue(value) => TerminalLoopState::Continue(value),
            TerminalLoopState::Closed => TerminalLoopState::Closed,
            TerminalLoopState::Suspended(entries) => TerminalLoopState::Suspended(entries),
            TerminalLoopState::SuspendedV2(notifications) => {
                TerminalLoopState::SuspendedV2(notifications)
            }
            TerminalLoopState::Failed(err) => TerminalLoopState::Failed(f(err)),
        }
    }
}

impl<T, E: Into<InvokerError>> From<Result<T, E>> for TerminalLoopState<T> {
impl<T, E, F> From<Result<T, E>> for TerminalLoopState<T, F>
where
F: From<E>,
{
fn from(value: Result<T, E>) -> Self {
match value {
Ok(v) => TerminalLoopState::Continue(v),
Err(e) => TerminalLoopState::Failed(e.into()),
Err(e) => TerminalLoopState::Failed(F::from(e)),
}
}
}
Expand Down Expand Up @@ -256,7 +271,7 @@ where
TerminalLoopState::Closed => InvocationTaskOutputInner::Closed,
TerminalLoopState::Suspended(v) => InvocationTaskOutputInner::Suspended(v),
TerminalLoopState::SuspendedV2(v) => InvocationTaskOutputInner::SuspendedV2(v),
TerminalLoopState::Failed(e) => InvocationTaskOutputInner::Failed(e),
TerminalLoopState::Failed(err) => InvocationTaskOutputInner::Failed(err),
};

self.send_invoker_tx(inner);
Expand All @@ -266,16 +281,16 @@ where
async fn select_protocol_version_and_run(
&mut self,
input_journal: InvokeInputJournal,
) -> TerminalLoopState<()> {
) -> TerminalLoopState<(), InvokerError> {
let mut txn = self.invocation_reader.transaction();
// Resolve journal and its metadata
let (journal_metadata, journal_stream) = match input_journal {
InvokeInputJournal::NoCachedJournal => {
let (journal_meta, journal_stream) = shortcircuit!(
txn.read_journal(&self.invocation_id)
.await
.map_err(|e| InvokerError::JournalReader(e.into()))
.and_then(|opt| opt.ok_or_else(|| InvokerError::NotInvoked))
.map_err(|e| InvokerErrorKind::JournalReader(e.into()))
.and_then(|opt| opt.ok_or_else(|| InvokerErrorKind::NotInvoked))
);
(journal_meta, future::Either::Left(journal_stream))
}
Expand All @@ -286,7 +301,7 @@ where
};

if self.invocation_epoch != journal_metadata.invocation_epoch {
shortcircuit!(Err(InvokerError::StaleJournalRead {
shortcircuit!(Err(InvokerErrorKind::StaleJournalRead {
actual: journal_metadata.invocation_epoch,
expected: self.invocation_epoch
}));
Expand All @@ -298,20 +313,28 @@ where
if let Some(pinned_deployment) = &journal_metadata.pinned_deployment {
// We have a pinned deployment that we can't change even if newer
// deployments have been registered for the same service.

// TODO(azmy): remove the deployment IDs from the InvokerErrorKind variants.
// Some variants already hold the deployment ID, and now that InvokerError
// carries it as well we don't want to duplicate it in the error kind.
let deployment_metadata = shortcircuit!(
schemas
.get_deployment(&pinned_deployment.deployment_id)
.ok_or_else(|| InvokerError::UnknownDeployment(
.ok_or_else(|| InvokerErrorKind::UnknownDeployment(
pinned_deployment.deployment_id
))
)
.into_invoker_error(pinned_deployment.deployment_id))
);

// todo: We should support resuming an invocation with a newer protocol version if
// the endpoint supports it
if !pinned_deployment.service_protocol_version.is_supported() {
shortcircuit!(Err(InvokerError::ResumeWithWrongServiceProtocolVersion(
pinned_deployment.service_protocol_version
)));
shortcircuit!(Err(
InvokerErrorKind::ResumeWithWrongServiceProtocolVersion(
pinned_deployment.service_protocol_version
)
.into_invoker_error(pinned_deployment.deployment_id)
));
}

(
Expand All @@ -327,15 +350,15 @@ where
.resolve_latest_deployment_for_service(
self.invocation_target.service_name()
)
.ok_or(InvokerError::NoDeploymentForService)
.ok_or(InvokerErrorKind::NoDeploymentForService)
);

let chosen_service_protocol_version = shortcircuit!(
ServiceProtocolVersion::choose_max_supported_version(
&deployment.metadata.supported_protocol_versions,
)
.ok_or_else(|| {
InvokerError::IncompatibleServiceEndpoint(
InvokerErrorKind::IncompatibleServiceEndpoint(
deployment.id,
deployment.metadata.supported_protocol_versions.clone(),
)
Expand Down Expand Up @@ -376,16 +399,17 @@ where
shortcircuit!(
txn.read_state(&keyed_service_id.unwrap())
.await
.map_err(|e| InvokerError::StateReader(e.into()))
.map_err(|e| InvokerErrorKind::StateReader(e.into()))
.map(|r| r.map(itertools::Either::Left))
)
};

// No need to read from Rocksdb anymore
drop(txn);

let deployment_id = deployment.id;
self.send_invoker_tx(InvocationTaskOutputInner::PinnedDeployment(
PinnedDeployment::new(deployment.id, chosen_service_protocol_version),
PinnedDeployment::new(deployment_id, chosen_service_protocol_version),
deployment_changed,
));

Expand All @@ -396,6 +420,7 @@ where
service_protocol_runner
.run(journal_metadata, deployment, journal_stream, state_iter)
.await
.map_err(|f| f.into_invoker_error(deployment_id))
} else {
// Protocol runner for service protocol v4+
let service_protocol_runner = service_protocol_runner_v4::ServiceProtocolRunner::new(
Expand All @@ -405,6 +430,7 @@ where
service_protocol_runner
.run(journal_metadata, deployment, journal_stream, state_iter)
.await
.map_err(|f| f.into_invoker_error(deployment_id))
}
}
}
Expand Down Expand Up @@ -468,16 +494,16 @@ impl ResponseStreamState {
fn poll_only_headers(
&mut self,
cx: &mut Context<'_>,
) -> Poll<Result<ResponseParts, InvokerError>> {
) -> Poll<Result<ResponseParts, InvokerErrorKind>> {
match self {
ResponseStreamState::WaitingHeaders(join_handle) => {
let http_response = match ready!(join_handle.poll_unpin(cx)) {
Ok(Ok(res)) => res,
Ok(Err(hyper_err)) => {
return Poll::Ready(Err(InvokerError::Client(Box::new(hyper_err))));
return Poll::Ready(Err(InvokerErrorKind::Client(Box::new(hyper_err))));
}
Err(join_err) => {
return Poll::Ready(Err(InvokerError::UnexpectedJoinError(join_err)));
return Poll::Ready(Err(InvokerErrorKind::UnexpectedJoinError(join_err)));
}
};

Expand All @@ -499,18 +525,20 @@ impl ResponseStreamState {
fn poll_next_chunk(
&mut self,
cx: &mut Context<'_>,
) -> Poll<Result<ResponseChunk, InvokerError>> {
) -> Poll<Result<ResponseChunk, InvokerErrorKind>> {
// Could be replaced by a Stream implementation
loop {
match self {
ResponseStreamState::WaitingHeaders(join_handle) => {
let http_response = match ready!(join_handle.poll_unpin(cx)) {
Ok(Ok(res)) => res,
Ok(Err(hyper_err)) => {
return Poll::Ready(Err(InvokerError::Client(Box::new(hyper_err))));
return Poll::Ready(Err(InvokerErrorKind::Client(Box::new(hyper_err))));
}
Err(join_err) => {
return Poll::Ready(Err(InvokerError::UnexpectedJoinError(join_err)));
return Poll::Ready(Err(InvokerErrorKind::UnexpectedJoinError(
join_err,
)));
}
};

Expand All @@ -533,7 +561,7 @@ impl ResponseStreamState {
Ok(ResponseChunk::Data(frame.into_data().unwrap()))
}
Ok(_) => Ok(ResponseChunk::End),
Err(err) => Err(InvokerError::ClientBody(err)),
Err(err) => Err(InvokerErrorKind::ClientBody(err)),
});
}
}
Expand Down
Loading
Loading