diff --git a/crates/invoker-impl/src/error.rs b/crates/invoker-impl/src/error.rs index 4fc6ac9a6c..571f994015 100644 --- a/crates/invoker-impl/src/error.rs +++ b/crates/invoker-impl/src/error.rs @@ -29,8 +29,25 @@ use std::ops::RangeInclusive; use std::time::Duration; use tokio::task::JoinError; +#[derive(Debug)] +pub struct InvokerError { + pub kind: InvokerErrorKind, + // Deployment ID associated with the error, if any. + #[allow(dead_code)] + pub deployment_id: Option, +} + +impl From for InvokerError { + fn from(source: InvokerErrorKind) -> Self { + Self { + kind: source, + deployment_id: None, + } + } +} + #[derive(Debug, thiserror::Error, codederror::CodedError)] -pub(crate) enum InvokerError { +pub(crate) enum InvokerErrorKind { #[error("no deployment was found to process the invocation")] #[code(restate_errors::RT0011)] NoDeploymentForService, @@ -169,14 +186,21 @@ pub(crate) enum InvokerError { ServiceUnavailable(http::StatusCode), } -impl InvokerError { +impl InvokerErrorKind { + pub(crate) fn into_invoker_error(self, deployment_id: DeploymentId) -> InvokerError { + InvokerError { + kind: self, + deployment_id: Some(deployment_id), + } + } + pub(crate) fn error_stacktrace(&self) -> Option<&str> { match self { - InvokerError::Sdk(s) => s + InvokerErrorKind::Sdk(s) => s .error .stacktrace() .and_then(|s| if s.is_empty() { None } else { Some(s) }), - InvokerError::SdkV2(s) => s + InvokerErrorKind::SdkV2(s) => s .error .stacktrace() .and_then(|s| if s.is_empty() { None } else { Some(s) }), @@ -185,30 +209,30 @@ impl InvokerError { } pub(crate) fn is_transient(&self) -> bool { - !matches!(self, InvokerError::NotInvoked) + !matches!(self, InvokerErrorKind::NotInvoked) } pub(crate) fn should_bump_start_message_retry_count_since_last_stored_entry(&self) -> bool { !matches!( self, - InvokerError::NotInvoked - | InvokerError::JournalReader(_) - | InvokerError::StateReader(_) - | InvokerError::NoDeploymentForService - | InvokerError::BadNegotiatedServiceProtocolVersion(_) - | InvokerError::UnknownDeployment(_) - | InvokerError::ResumeWithWrongServiceProtocolVersion(_) - | InvokerError::IncompatibleServiceEndpoint(_, _) + InvokerErrorKind::NotInvoked + | InvokerErrorKind::JournalReader(_) + | InvokerErrorKind::StateReader(_) + | InvokerErrorKind::NoDeploymentForService + | InvokerErrorKind::BadNegotiatedServiceProtocolVersion(_) + | InvokerErrorKind::UnknownDeployment(_) + | InvokerErrorKind::ResumeWithWrongServiceProtocolVersion(_) + | InvokerErrorKind::IncompatibleServiceEndpoint(_, _) ) } pub(crate) fn next_retry_interval_override(&self) -> Option { match self { - InvokerError::Sdk(SdkInvocationError { + InvokerErrorKind::Sdk(SdkInvocationError { next_retry_interval_override, .. }) => *next_retry_interval_override, - InvokerError::SdkV2(SdkInvocationErrorV2 { + InvokerErrorKind::SdkV2(SdkInvocationErrorV2 { next_retry_interval_override, .. }) => *next_retry_interval_override, @@ -218,9 +242,9 @@ impl InvokerError { pub(crate) fn into_invocation_error(self) -> InvocationError { match self { - InvokerError::Sdk(sdk_error) => sdk_error.error, - InvokerError::SdkV2(sdk_error) => sdk_error.error, - InvokerError::EntryEnrichment(entry_index, entry_type, e) => { + InvokerErrorKind::Sdk(sdk_error) => sdk_error.error, + InvokerErrorKind::SdkV2(sdk_error) => sdk_error.error, + InvokerErrorKind::EntryEnrichment(entry_index, entry_type, e) => { let msg = format!( "Error when processing entry {} of type {}: {}", entry_index, @@ -233,7 +257,7 @@ impl InvokerError { } err } - e @ InvokerError::BadNegotiatedServiceProtocolVersion(_) => { + e @ InvokerErrorKind::BadNegotiatedServiceProtocolVersion(_) => { InvocationError::new(codes::UNSUPPORTED_MEDIA_TYPE, e) } e => InvocationError::internal(e), @@ -243,11 +267,11 @@ impl InvokerError { pub(crate) fn into_invocation_error_report(mut self) -> InvocationErrorReport { let doc_error_code = codederror::CodedError::code(&self); let maybe_related_entry = match self { - InvokerError::Sdk(SdkInvocationError { + InvokerErrorKind::Sdk(SdkInvocationError { ref mut related_entry, .. }) => related_entry.take(), - InvokerError::SdkV2(SdkInvocationErrorV2 { + InvokerErrorKind::SdkV2(SdkInvocationErrorV2 { related_command: ref mut related_entry, .. }) => related_entry diff --git a/crates/invoker-impl/src/invocation_task/mod.rs b/crates/invoker-impl/src/invocation_task/mod.rs index 62caffa681..cb457ce501 100644 --- a/crates/invoker-impl/src/invocation_task/mod.rs +++ b/crates/invoker-impl/src/invocation_task/mod.rs @@ -13,7 +13,7 @@ mod service_protocol_runner_v4; use super::Notification; -use crate::error::InvokerError; +use crate::error::{InvokerError, InvokerErrorKind}; use crate::invocation_task::service_protocol_runner::ServiceProtocolRunner; use crate::metric_definitions::INVOKER_TASK_DURATION; use bytes::Bytes; @@ -154,19 +154,34 @@ pub(super) struct InvocationTask { } /// This is needed to split the run_internal in multiple loop functions and have shortcircuiting. -enum TerminalLoopState { +enum TerminalLoopState { Continue(T), Closed, Suspended(HashSet), SuspendedV2(HashSet), - Failed(InvokerError), + Failed(E), +} + +impl TerminalLoopState { + pub fn map_err(self, f: impl FnOnce(E) -> F) -> TerminalLoopState { + match self { + TerminalLoopState::Failed(e) => TerminalLoopState::Failed(f(e)), + TerminalLoopState::Closed => TerminalLoopState::Closed, + TerminalLoopState::Continue(v) => TerminalLoopState::Continue(v), + TerminalLoopState::Suspended(v) => TerminalLoopState::Suspended(v), + TerminalLoopState::SuspendedV2(v) => TerminalLoopState::SuspendedV2(v), + } + } } -impl> From> for TerminalLoopState { +impl From> for TerminalLoopState +where + F: From, +{ fn from(value: Result) -> Self { match value { Ok(v) => TerminalLoopState::Continue(v), - Err(e) => TerminalLoopState::Failed(e.into()), + Err(e) => TerminalLoopState::Failed(F::from(e)), } } } @@ -256,7 +271,7 @@ where TerminalLoopState::Closed => InvocationTaskOutputInner::Closed, TerminalLoopState::Suspended(v) => InvocationTaskOutputInner::Suspended(v), TerminalLoopState::SuspendedV2(v) => InvocationTaskOutputInner::SuspendedV2(v), - TerminalLoopState::Failed(e) => InvocationTaskOutputInner::Failed(e), + TerminalLoopState::Failed(err) => InvocationTaskOutputInner::Failed(err), }; self.send_invoker_tx(inner); @@ -266,7 +281,7 @@ where async fn select_protocol_version_and_run( &mut self, input_journal: InvokeInputJournal, - ) -> TerminalLoopState<()> { + ) -> TerminalLoopState<(), InvokerError> { let mut txn = self.invocation_reader.transaction(); // Resolve journal and its metadata let (journal_metadata, journal_stream) = match input_journal { @@ -274,8 +289,8 @@ where let (journal_meta, journal_stream) = shortcircuit!( txn.read_journal(&self.invocation_id) .await - .map_err(|e| InvokerError::JournalReader(e.into())) - .and_then(|opt| opt.ok_or_else(|| InvokerError::NotInvoked)) + .map_err(|e| InvokerErrorKind::JournalReader(e.into())) + .and_then(|opt| opt.ok_or_else(|| InvokerErrorKind::NotInvoked)) ); (journal_meta, future::Either::Left(journal_stream)) } @@ -286,7 +301,7 @@ where }; if self.invocation_epoch != journal_metadata.invocation_epoch { - shortcircuit!(Err(InvokerError::StaleJournalRead { + shortcircuit!(Err(InvokerErrorKind::StaleJournalRead { actual: journal_metadata.invocation_epoch, expected: self.invocation_epoch })); @@ -298,20 +313,28 @@ where if let Some(pinned_deployment) = &journal_metadata.pinned_deployment { // We have a pinned deployment that we can't change even if newer // deployments have been registered for the same service. + + // TODO(azmy): remove the deployment IDs from the error kind + // since now we have some variants in the InvokerErrorKind that already hold the deployment ID + // and we don't want to duplicate the deployment ID in the error kind let deployment_metadata = shortcircuit!( schemas .get_deployment(&pinned_deployment.deployment_id) - .ok_or_else(|| InvokerError::UnknownDeployment( + .ok_or_else(|| InvokerErrorKind::UnknownDeployment( pinned_deployment.deployment_id - )) + ) + .into_invoker_error(pinned_deployment.deployment_id)) ); // todo: We should support resuming an invocation with a newer protocol version if // the endpoint supports it if !pinned_deployment.service_protocol_version.is_supported() { - shortcircuit!(Err(InvokerError::ResumeWithWrongServiceProtocolVersion( - pinned_deployment.service_protocol_version - ))); + shortcircuit!(Err( + InvokerErrorKind::ResumeWithWrongServiceProtocolVersion( + pinned_deployment.service_protocol_version + ) + .into_invoker_error(pinned_deployment.deployment_id) + )); } ( @@ -327,7 +350,7 @@ where .resolve_latest_deployment_for_service( self.invocation_target.service_name() ) - .ok_or(InvokerError::NoDeploymentForService) + .ok_or(InvokerErrorKind::NoDeploymentForService) ); let chosen_service_protocol_version = shortcircuit!( @@ -335,7 +358,7 @@ where &deployment.metadata.supported_protocol_versions, ) .ok_or_else(|| { - InvokerError::IncompatibleServiceEndpoint( + InvokerErrorKind::IncompatibleServiceEndpoint( deployment.id, deployment.metadata.supported_protocol_versions.clone(), ) @@ -376,7 +399,7 @@ where shortcircuit!( txn.read_state(&keyed_service_id.unwrap()) .await - .map_err(|e| InvokerError::StateReader(e.into())) + .map_err(|e| InvokerErrorKind::StateReader(e.into())) .map(|r| r.map(itertools::Either::Left)) ) }; @@ -384,8 +407,9 @@ where // No need to read from Rocksdb anymore drop(txn); + let deployment_id = deployment.id; self.send_invoker_tx(InvocationTaskOutputInner::PinnedDeployment( - PinnedDeployment::new(deployment.id, chosen_service_protocol_version), + PinnedDeployment::new(deployment_id, chosen_service_protocol_version), deployment_changed, )); @@ -396,6 +420,7 @@ where service_protocol_runner .run(journal_metadata, deployment, journal_stream, state_iter) .await + .map_err(|f| f.into_invoker_error(deployment_id)) } else { // Protocol runner for service protocol v4+ let service_protocol_runner = service_protocol_runner_v4::ServiceProtocolRunner::new( @@ -405,6 +430,7 @@ where service_protocol_runner .run(journal_metadata, deployment, journal_stream, state_iter) .await + .map_err(|f| f.into_invoker_error(deployment_id)) } } } @@ -468,16 +494,16 @@ impl ResponseStreamState { fn poll_only_headers( &mut self, cx: &mut Context<'_>, - ) -> Poll> { + ) -> Poll> { match self { ResponseStreamState::WaitingHeaders(join_handle) => { let http_response = match ready!(join_handle.poll_unpin(cx)) { Ok(Ok(res)) => res, Ok(Err(hyper_err)) => { - return Poll::Ready(Err(InvokerError::Client(Box::new(hyper_err)))); + return Poll::Ready(Err(InvokerErrorKind::Client(Box::new(hyper_err)))); } Err(join_err) => { - return Poll::Ready(Err(InvokerError::UnexpectedJoinError(join_err))); + return Poll::Ready(Err(InvokerErrorKind::UnexpectedJoinError(join_err))); } }; @@ -499,7 +525,7 @@ impl ResponseStreamState { fn poll_next_chunk( &mut self, cx: &mut Context<'_>, - ) -> Poll> { + ) -> Poll> { // Could be replaced by a Stream implementation loop { match self { @@ -507,10 +533,12 @@ impl ResponseStreamState { let http_response = match ready!(join_handle.poll_unpin(cx)) { Ok(Ok(res)) => res, Ok(Err(hyper_err)) => { - return Poll::Ready(Err(InvokerError::Client(Box::new(hyper_err)))); + return Poll::Ready(Err(InvokerErrorKind::Client(Box::new(hyper_err)))); } Err(join_err) => { - return Poll::Ready(Err(InvokerError::UnexpectedJoinError(join_err))); + return Poll::Ready(Err(InvokerErrorKind::UnexpectedJoinError( + join_err, + ))); } }; @@ -533,7 +561,7 @@ impl ResponseStreamState { Ok(ResponseChunk::Data(frame.into_data().unwrap())) } Ok(_) => Ok(ResponseChunk::End), - Err(err) => Err(InvokerError::ClientBody(err)), + Err(err) => Err(InvokerErrorKind::ClientBody(err)), }); } } diff --git a/crates/invoker-impl/src/invocation_task/service_protocol_runner.rs b/crates/invoker-impl/src/invocation_task/service_protocol_runner.rs index 747ef15ff5..efdef770cd 100644 --- a/crates/invoker-impl/src/invocation_task/service_protocol_runner.rs +++ b/crates/invoker-impl/src/invocation_task/service_protocol_runner.rs @@ -9,7 +9,7 @@ // by the Apache License, Version 2.0. use crate::Notification; -use crate::error::{InvocationErrorRelatedEntry, InvokerError, SdkInvocationError}; +use crate::error::{InvocationErrorRelatedEntry, InvokerErrorKind, SdkInvocationError}; use crate::invocation_task::{ InvocationTask, InvocationTaskOutputInner, InvokerBodyStream, InvokerRequestStreamSender, ResponseChunk, ResponseStreamState, TerminalLoopState, X_RESTATE_SERVER, @@ -202,7 +202,7 @@ where // Sanity check of the stream decoder if self.decoder.has_remaining() { warn_it!( - InvokerError::WriteAfterEndOfStream, + InvokerErrorKind::WriteAfterEndOfStream, "The read buffer is non empty after the stream has been closed." ); } @@ -371,7 +371,7 @@ where ResponseChunk::Data(buf) => crate::shortcircuit!(self.handle_read(parent_span_context, buf)), ResponseChunk::End => { // Response stream was closed without SuspensionMessage, EndMessage or ErrorMessage - return TerminalLoopState::Failed(InvokerError::Sdk(SdkInvocationError::unknown())) + return TerminalLoopState::Failed(InvokerErrorKind::Sdk(SdkInvocationError::unknown())) } } }, @@ -398,13 +398,13 @@ where ResponseChunk::Data(buf) => crate::shortcircuit!(self.handle_read(parent_span_context, buf)), ResponseChunk::End => { // Response stream was closed without SuspensionMessage, EndMessage or ErrorMessage - return TerminalLoopState::Failed(InvokerError::Sdk(SdkInvocationError::unknown()) ) + return TerminalLoopState::Failed(InvokerErrorKind::Sdk(SdkInvocationError::unknown())) } } }, _ = tokio::time::sleep(self.invocation_task.abort_timeout) => { warn!("Inactivity detected, going to close invocation"); - return TerminalLoopState::Failed(InvokerError::AbortTimeoutFired(self.invocation_task.abort_timeout.into())) + return TerminalLoopState::Failed(InvokerErrorKind::AbortTimeoutFired(self.invocation_task.abort_timeout.into())) }, } } @@ -419,7 +419,7 @@ where state_entries: EagerState, retry_count_since_last_stored_entry: u32, duration_since_last_stored_entry: Duration, - ) -> Result<(), InvokerError> { + ) -> Result<(), InvokerErrorKind> { let is_partial = state_entries.is_partial(); // Send the invoke frame @@ -446,12 +446,12 @@ where &mut self, http_stream_tx: &mut InvokerRequestStreamSender, msg: ProtocolMessage, - ) -> Result<(), InvokerError> { + ) -> Result<(), InvokerErrorKind> { trace!(restate.protocol.message = ?msg, "Sending message"); let buf = self.encoder.encode(msg); if http_stream_tx.send(Ok(Frame::data(buf))).await.is_err() { - return Err(InvokerError::UnexpectedClosedRequestStream); + return Err(InvokerErrorKind::UnexpectedClosedRequestStream); }; Ok(()) } @@ -459,24 +459,24 @@ where fn handle_response_headers( &mut self, mut parts: http::response::Parts, - ) -> Result<(), InvokerError> { + ) -> Result<(), InvokerErrorKind> { // if service is running behind a gateway, the service can be down // but we still get a response code from the gateway itself. In that // case we still need to return the proper error if GATEWAY_ERRORS_CODES.contains(&parts.status) { - return Err(InvokerError::ServiceUnavailable(parts.status)); + return Err(InvokerErrorKind::ServiceUnavailable(parts.status)); } // otherwise we return generic UnexpectedResponse if !parts.status.is_success() { // Decorate the error in case of UNSUPPORTED_MEDIA_TYPE, as it probably is the incompatible protocol version if parts.status == StatusCode::UNSUPPORTED_MEDIA_TYPE { - return Err(InvokerError::BadNegotiatedServiceProtocolVersion( + return Err(InvokerErrorKind::BadNegotiatedServiceProtocolVersion( self.service_protocol_version, )); } - return Err(InvokerError::UnexpectedResponse(parts.status)); + return Err(InvokerErrorKind::UnexpectedResponse(parts.status)); } let content_type = parts.headers.remove(http::header::CONTENT_TYPE); @@ -487,14 +487,14 @@ where { #[allow(clippy::borrow_interior_mutable_const)] if ct != expected_content_type { - return Err(InvokerError::UnexpectedContentType( + return Err(InvokerErrorKind::UnexpectedContentType( Some(ct), expected_content_type, )); } } None => { - return Err(InvokerError::UnexpectedContentType( + return Err(InvokerErrorKind::UnexpectedContentType( None, expected_content_type, )); @@ -505,7 +505,7 @@ where self.invocation_task .send_invoker_tx(InvocationTaskOutputInner::ServerHeaderReceived( hv.to_str() - .map_err(|e| InvokerError::BadHeader(X_RESTATE_SERVER, e))? + .map_err(|e| InvokerErrorKind::BadHeader(X_RESTATE_SERVER, e))? .to_owned(), )) } @@ -536,23 +536,23 @@ where trace!(restate.protocol.message_header = ?mh, restate.protocol.message = ?message, "Received message"); match message { ProtocolMessage::Start { .. } => { - TerminalLoopState::Failed(InvokerError::UnexpectedMessage(MessageType::Start)) - } - ProtocolMessage::Completion(_) => { - TerminalLoopState::Failed(InvokerError::UnexpectedMessage(MessageType::Completion)) - } - ProtocolMessage::EntryAck(_) => { - TerminalLoopState::Failed(InvokerError::UnexpectedMessage(MessageType::EntryAck)) + TerminalLoopState::Failed(InvokerErrorKind::UnexpectedMessage(MessageType::Start)) } + ProtocolMessage::Completion(_) => TerminalLoopState::Failed( + InvokerErrorKind::UnexpectedMessage(MessageType::Completion), + ), + ProtocolMessage::EntryAck(_) => TerminalLoopState::Failed( + InvokerErrorKind::UnexpectedMessage(MessageType::EntryAck), + ), ProtocolMessage::Suspension(suspension) => { let suspension_indexes = HashSet::from_iter(suspension.entry_indexes); // We currently don't support empty suspension_indexes set if suspension_indexes.is_empty() { - return TerminalLoopState::Failed(InvokerError::EmptySuspensionMessage); + return TerminalLoopState::Failed(InvokerErrorKind::EmptySuspensionMessage); } // Sanity check on the suspension indexes if *suspension_indexes.iter().max().unwrap() >= self.next_journal_index { - return TerminalLoopState::Failed(InvokerError::BadSuspensionMessage( + return TerminalLoopState::Failed(InvokerErrorKind::BadSuspensionMessage( suspension_indexes, self.next_journal_index, )); @@ -560,7 +560,7 @@ where TerminalLoopState::Suspended(suspension_indexes) } ProtocolMessage::Error(e) => { - TerminalLoopState::Failed(InvokerError::Sdk(SdkInvocationError { + TerminalLoopState::Failed(InvokerErrorKind::Sdk(SdkInvocationError { related_entry: Some(InvocationErrorRelatedEntry { related_entry_index: e.related_entry_index, related_entry_name: e.related_entry_name.clone(), @@ -588,7 +588,7 @@ where &self.invocation_task.invocation_target, parent_span_context ) - .map_err(|e| InvokerError::EntryEnrichment( + .map_err(|e| InvokerErrorKind::EntryEnrichment( self.next_journal_index, entry_type, e diff --git a/crates/invoker-impl/src/invocation_task/service_protocol_runner_v4.rs b/crates/invoker-impl/src/invocation_task/service_protocol_runner_v4.rs index 10088273bd..4884bfc625 100644 --- a/crates/invoker-impl/src/invocation_task/service_protocol_runner_v4.rs +++ b/crates/invoker-impl/src/invocation_task/service_protocol_runner_v4.rs @@ -56,7 +56,8 @@ use restate_types::service_protocol::ServiceProtocolVersion; use crate::Notification; use crate::error::{ - CommandPreconditionError, InvocationErrorRelatedCommandV2, InvokerError, SdkInvocationErrorV2, + CommandPreconditionError, InvocationErrorRelatedCommandV2, InvokerErrorKind, + SdkInvocationErrorV2, }; use crate::invocation_task::{ InvocationTask, InvocationTaskOutputInner, InvokerBodyStream, InvokerRequestStreamSender, @@ -214,7 +215,7 @@ where // Sanity check of the stream decoder if self.decoder.has_remaining() { warn_it!( - InvokerError::WriteAfterEndOfStream, + InvokerErrorKind::WriteAfterEndOfStream, "The read buffer is non empty after the stream has been closed." ); } @@ -383,7 +384,7 @@ where ResponseChunk::Data(buf) => crate::shortcircuit!(self.handle_read(parent_span_context, buf)), ResponseChunk::End => { // Response stream was closed without SuspensionMessage, EndMessage or ErrorMessage - return TerminalLoopState::Failed(InvokerError::SdkV2(SdkInvocationErrorV2::unknown())) + return TerminalLoopState::Failed(InvokerErrorKind::SdkV2(SdkInvocationErrorV2::unknown())) } } }, @@ -410,13 +411,13 @@ where ResponseChunk::Data(buf) => crate::shortcircuit!(self.handle_read(parent_span_context, buf)), ResponseChunk::End => { // Response stream was closed without SuspensionMessage, EndMessage or ErrorMessage - return TerminalLoopState::Failed(InvokerError::SdkV2(SdkInvocationErrorV2::unknown())) + return TerminalLoopState::Failed(InvokerErrorKind::SdkV2(SdkInvocationErrorV2::unknown())) } } }, _ = tokio::time::sleep(self.invocation_task.abort_timeout) => { warn!("Inactivity detected, going to close invocation"); - return TerminalLoopState::Failed(InvokerError::AbortTimeoutFired(self.invocation_task.abort_timeout.into())) + return TerminalLoopState::Failed(InvokerErrorKind::AbortTimeoutFired(self.invocation_task.abort_timeout.into())) }, } } @@ -431,7 +432,7 @@ where state_entries: EagerState, retry_count_since_last_stored_entry: u32, duration_since_last_stored_entry: Duration, - ) -> Result<(), InvokerError> { + ) -> Result<(), InvokerErrorKind> { let is_partial = state_entries.is_partial(); // Send the invoke frame @@ -458,7 +459,7 @@ where &mut self, http_stream_tx: &mut InvokerRequestStreamSender, entry: RawEntry, - ) -> Result<(), InvokerError> { + ) -> Result<(), InvokerErrorKind> { // TODO(slinkydeveloper) could this code be improved a tad bit more introducing something to our magic macro in message_codec? match entry.inner { RawEntryInner::Command(cmd) => { @@ -489,12 +490,12 @@ where &mut self, http_stream_tx: &mut InvokerRequestStreamSender, msg: Message, - ) -> Result<(), InvokerError> { + ) -> Result<(), InvokerErrorKind> { trace!(restate.protocol.message = ?msg, "Sending message"); let buf = self.encoder.encode(msg); if http_stream_tx.send(Ok(Frame::data(buf))).await.is_err() { - return Err(InvokerError::UnexpectedClosedRequestStream); + return Err(InvokerErrorKind::UnexpectedClosedRequestStream); }; Ok(()) } @@ -504,12 +505,12 @@ where http_stream_tx: &mut InvokerRequestStreamSender, ty: MessageType, buf: Bytes, - ) -> Result<(), InvokerError> { + ) -> Result<(), InvokerErrorKind> { trace!(restate.protocol.message = ?ty, "Sending message"); let buf = self.encoder.encode_raw(ty, buf); if http_stream_tx.send(Ok(Frame::data(buf))).await.is_err() { - return Err(InvokerError::UnexpectedClosedRequestStream); + return Err(InvokerErrorKind::UnexpectedClosedRequestStream); }; Ok(()) } @@ -517,24 +518,24 @@ where fn handle_response_headers( &mut self, mut parts: http::response::Parts, - ) -> Result<(), InvokerError> { + ) -> Result<(), InvokerErrorKind> { // if service is running behind a gateway, the service can be down // but we still get a response code from the gateway itself. In that // case we still need to return the proper error if GATEWAY_ERRORS_CODES.contains(&parts.status) { - return Err(InvokerError::ServiceUnavailable(parts.status)); + return Err(InvokerErrorKind::ServiceUnavailable(parts.status)); } // otherwise we return generic UnexpectedResponse if !parts.status.is_success() { // Decorate the error in case of UNSUPPORTED_MEDIA_TYPE, as it probably is the incompatible protocol version if parts.status == StatusCode::UNSUPPORTED_MEDIA_TYPE { - return Err(InvokerError::BadNegotiatedServiceProtocolVersion( + return Err(InvokerErrorKind::BadNegotiatedServiceProtocolVersion( self.service_protocol_version, )); } - return Err(InvokerError::UnexpectedResponse(parts.status)); + return Err(InvokerErrorKind::UnexpectedResponse(parts.status)); } let content_type = parts.headers.remove(http::header::CONTENT_TYPE); @@ -545,14 +546,14 @@ where { #[allow(clippy::borrow_interior_mutable_const)] if ct != expected_content_type { - return Err(InvokerError::UnexpectedContentType( + return Err(InvokerErrorKind::UnexpectedContentType( Some(ct), expected_content_type, )); } } None => { - return Err(InvokerError::UnexpectedContentType( + return Err(InvokerErrorKind::UnexpectedContentType( None, expected_content_type, )); @@ -563,7 +564,7 @@ where self.invocation_task .send_invoker_tx(InvocationTaskOutputInner::ServerHeaderReceived( hv.to_str() - .map_err(|e| InvokerError::BadHeader(X_RESTATE_SERVER, e))? + .map_err(|e| InvokerErrorKind::BadHeader(X_RESTATE_SERVER, e))? .to_owned(), )) } @@ -611,11 +612,11 @@ where match message { // Control messages Message::Start { .. } => { - TerminalLoopState::Failed(InvokerError::UnexpectedMessageV4(MessageType::Start)) + TerminalLoopState::Failed(InvokerErrorKind::UnexpectedMessageV4(MessageType::Start)) } - Message::CommandAck(_) => TerminalLoopState::Failed(InvokerError::UnexpectedMessageV4( - MessageType::CommandAck, - )), + Message::CommandAck(_) => TerminalLoopState::Failed( + InvokerErrorKind::UnexpectedMessageV4(MessageType::CommandAck), + ), Message::Suspension(suspension) => self.handle_suspension_message(suspension), Message::Error(e) => self.handle_error_message(e), Message::End(_) => TerminalLoopState::Closed, @@ -627,7 +628,7 @@ where result: match crate::shortcircuit!( run_completion .result - .ok_or(InvokerError::MalformedProposeRunCompletion) + .ok_or(InvokerErrorKind::MalformedProposeRunCompletion) ) { proto::propose_run_completion_message::Result::Value(b) => { RunResult::Success(b) @@ -709,11 +710,13 @@ where span_relation: parent_span_context.as_linked() } ) - .map_err(|e| InvokerError::CommandPrecondition( - self.command_index, - EntryType::Command(CommandType::OneWayCall), - e - )) + .map_err( + |e| InvokerErrorKind::CommandPrecondition( + self.command_index, + EntryType::Command(CommandType::OneWayCall), + e + ) + ) ), invoke_time: cmd.invoke_time.into(), invocation_id_completion_id: cmd.invocation_id_notification_idx, @@ -745,11 +748,13 @@ where span_relation: parent_span_context.as_parent() } ) - .map_err(|e| InvokerError::CommandPrecondition( - self.command_index, - EntryType::Command(CommandType::Call), - e - )) + .map_err( + |e| InvokerErrorKind::CommandPrecondition( + self.command_index, + EntryType::Command(CommandType::Call), + e + ) + ) ), invocation_id_completion_id: cmd.invocation_id_notification_idx, result_completion_id: cmd.result_completion_id, @@ -891,51 +896,57 @@ where TerminalLoopState::Continue(()) } Message::SignalNotification(_) => TerminalLoopState::Failed( - InvokerError::UnexpectedMessageV4(MessageType::SignalNotification), + InvokerErrorKind::UnexpectedMessageV4(MessageType::SignalNotification), ), Message::GetInvocationOutputCompletionNotification(_) => { - TerminalLoopState::Failed(InvokerError::UnexpectedMessageV4( + TerminalLoopState::Failed(InvokerErrorKind::UnexpectedMessageV4( MessageType::GetInvocationOutputCompletionNotification, )) } Message::AttachInvocationCompletionNotification(_) => { - TerminalLoopState::Failed(InvokerError::UnexpectedMessageV4( + TerminalLoopState::Failed(InvokerErrorKind::UnexpectedMessageV4( MessageType::AttachInvocationCompletionNotification, )) } Message::RunCompletionNotification(_) => TerminalLoopState::Failed( - InvokerError::UnexpectedMessageV4(MessageType::RunCompletionNotification), + InvokerErrorKind::UnexpectedMessageV4(MessageType::RunCompletionNotification), ), Message::CallCompletionNotification(_) => TerminalLoopState::Failed( - InvokerError::UnexpectedMessageV4(MessageType::CallCompletionNotification), + InvokerErrorKind::UnexpectedMessageV4(MessageType::CallCompletionNotification), ), Message::CallInvocationIdCompletionNotification(_) => { - TerminalLoopState::Failed(InvokerError::UnexpectedMessageV4( + TerminalLoopState::Failed(InvokerErrorKind::UnexpectedMessageV4( MessageType::CallInvocationIdCompletionNotification, )) } Message::SleepCompletionNotification(_) => TerminalLoopState::Failed( - InvokerError::UnexpectedMessageV4(MessageType::SleepCompletionNotification), + InvokerErrorKind::UnexpectedMessageV4(MessageType::SleepCompletionNotification), ), Message::CompletePromiseCompletionNotification(_) => { - TerminalLoopState::Failed(InvokerError::UnexpectedMessageV4( + TerminalLoopState::Failed(InvokerErrorKind::UnexpectedMessageV4( MessageType::CompletePromiseCompletionNotification, )) } - Message::PeekPromiseCompletionNotification(_) => TerminalLoopState::Failed( - InvokerError::UnexpectedMessageV4(MessageType::PeekPromiseCompletionNotification), - ), - Message::GetPromiseCompletionNotification(_) => TerminalLoopState::Failed( - InvokerError::UnexpectedMessageV4(MessageType::GetPromiseCompletionNotification), - ), + Message::PeekPromiseCompletionNotification(_) => { + TerminalLoopState::Failed(InvokerErrorKind::UnexpectedMessageV4( + MessageType::PeekPromiseCompletionNotification, + )) + } + Message::GetPromiseCompletionNotification(_) => { + TerminalLoopState::Failed(InvokerErrorKind::UnexpectedMessageV4( + MessageType::GetPromiseCompletionNotification, + )) + } Message::GetLazyStateKeysCompletionNotification(_) => { - TerminalLoopState::Failed(InvokerError::UnexpectedMessageV4( + TerminalLoopState::Failed(InvokerErrorKind::UnexpectedMessageV4( MessageType::GetLazyStateKeysCompletionNotification, )) } - Message::GetLazyStateCompletionNotification(_) => TerminalLoopState::Failed( - InvokerError::UnexpectedMessageV4(MessageType::GetLazyStateCompletionNotification), - ), + Message::GetLazyStateCompletionNotification(_) => { + TerminalLoopState::Failed(InvokerErrorKind::UnexpectedMessageV4( + MessageType::GetLazyStateCompletionNotification, + )) + } Message::Custom(_, _) => { unimplemented!() } @@ -967,13 +978,13 @@ where .collect(); // We currently don't support empty suspension_indexes set if suspension_indexes.is_empty() { - return TerminalLoopState::Failed(InvokerError::EmptySuspensionMessage); + return TerminalLoopState::Failed(InvokerErrorKind::EmptySuspensionMessage); } TerminalLoopState::SuspendedV2(suspension_indexes) } fn handle_error_message(&mut self, error: proto::ErrorMessage) -> TerminalLoopState<()> { - TerminalLoopState::Failed(InvokerError::SdkV2(SdkInvocationErrorV2 { + TerminalLoopState::Failed(InvokerErrorKind::SdkV2(SdkInvocationErrorV2 { related_command: Some(InvocationErrorRelatedCommandV2 { related_command_index: error.related_command_index, related_command_name: error.related_command_name.clone(), @@ -1075,9 +1086,9 @@ fn check_workflow_type( command_index: CommandIndex, entry_type: &EntryType, service_type: &ServiceType, -) -> Result<(), InvokerError> { +) -> Result<(), InvokerErrorKind> { if *service_type != ServiceType::Workflow { - return Err(InvokerError::CommandPrecondition( + return Err(InvokerErrorKind::CommandPrecondition( command_index, *entry_type, CommandPreconditionError::NoWorkflowOperations, @@ -1091,9 +1102,9 @@ fn can_read_state( command_index: CommandIndex, entry_type: &EntryType, invocation_target_type: &InvocationTargetType, -) -> Result<(), InvokerError> { +) -> Result<(), InvokerErrorKind> { if !invocation_target_type.can_read_state() { - return Err(InvokerError::CommandPrecondition( + return Err(InvokerErrorKind::CommandPrecondition( command_index, *entry_type, CommandPreconditionError::NoStateOperations, @@ -1107,10 +1118,10 @@ fn can_write_state( command_index: CommandIndex, entry_type: &EntryType, invocation_target_type: &InvocationTargetType, -) -> Result<(), InvokerError> { +) -> Result<(), InvokerErrorKind> { can_read_state(command_index, entry_type, invocation_target_type)?; if !invocation_target_type.can_write_state() { - return Err(InvokerError::CommandPrecondition( + return Err(InvokerErrorKind::CommandPrecondition( command_index, *entry_type, CommandPreconditionError::NoWriteStateOperations, diff --git a/crates/invoker-impl/src/lib.rs b/crates/invoker-impl/src/lib.rs index f1b2696353..cfedd030a1 100644 --- a/crates/invoker-impl/src/lib.rs +++ b/crates/invoker-impl/src/lib.rs @@ -52,12 +52,12 @@ use tokio::task::{AbortHandle, JoinSet}; use tracing::{debug, trace}; use tracing::{error, instrument}; -use crate::error::SdkInvocationErrorV2; +use crate::error::{InvokerError, SdkInvocationErrorV2}; use crate::metric_definitions::{ INVOKER_ENQUEUE, INVOKER_INVOCATION_TASKS, TASK_OP_COMPLETED, TASK_OP_FAILED, TASK_OP_STARTED, TASK_OP_SUSPENDED, }; -use error::InvokerError; +use error::InvokerErrorKind; pub use input_command::ChannelStatusReader; pub use input_command::InvokerHandle; use restate_invoker_api::invocation_reader::InvocationReader; @@ -1090,7 +1090,7 @@ where .remove_invocation_with_epoch(partition, &invocation_id, invocation_epoch) { debug_assert_eq!(invocation_epoch, ism.invocation_epoch); - self.handle_error_event(options, partition, invocation_id, error, ism) + self.handle_error_event(options, partition, invocation_id, error.kind, ism) .await; } else { // If no state machine, this might be a result for an aborted invocation. @@ -1176,7 +1176,7 @@ where options: &InvokerOptions, partition: PartitionLeaderEpoch, invocation_id: InvocationId, - error: InvokerError, + error: InvokerErrorKind, mut ism: InvocationStateMachine, ) { match ism.handle_task_error( @@ -1211,7 +1211,7 @@ where let next_retry_at = SystemTime::now() + next_retry_timer_duration; let journal_v2_related_command_type = - if let InvokerError::SdkV2(SdkInvocationErrorV2 { + if let InvokerErrorKind::SdkV2(SdkInvocationErrorV2 { related_command: Some(ref related_entry), .. }) = error @@ -1436,7 +1436,7 @@ mod tests { use restate_types::schema::invocation_target::InvocationTargetMetadata; use restate_types::schema::service::{InvocationAttemptOptions, ServiceMetadata}; - use crate::error::{InvokerError, SdkInvocationErrorV2}; + use crate::error::{InvokerErrorKind, SdkInvocationErrorV2}; use crate::quota::InvokerConcurrencyQuota; // -- Mocks @@ -1893,7 +1893,7 @@ mod tests { MOCK_PARTITION, invocation_id, 0, - InvokerError::EmptySuspensionMessage, /* any error is fine */ + InvokerErrorKind::EmptySuspensionMessage.into(), /* any error is fine */ ) .await; @@ -1947,7 +1947,7 @@ mod tests { MOCK_PARTITION, invocation_id, 0, - InvokerError::SdkV2(SdkInvocationErrorV2::unknown()), + InvokerErrorKind::SdkV2(SdkInvocationErrorV2::unknown()).into(), ) .await; assert_eq!(