diff --git a/crates/client/src/errors.rs b/crates/client/src/errors.rs index 02b297bfa..b7cfd9418 100644 --- a/crates/client/src/errors.rs +++ b/crates/client/src/errors.rs @@ -1,9 +1,10 @@ //! Contains errors that can be returned by clients. use http::uri::InvalidUri; +use std::{error::Error, fmt}; use temporalio_common::{ - data_converters::PayloadConversionError, - protos::temporal::api::{common::v1::Payload, failure::v1::Failure, query::v1::QueryRejected}, + data_converters::{PayloadConversionError, TemporalError}, + protos::temporal::api::{common::v1::Payload, query::v1::QueryRejected}, }; use tonic::Code; @@ -123,8 +124,8 @@ pub enum WorkflowUpdateError { NotFound(#[source] tonic::Status), /// The update failed with an application-level failure. - #[error("Update failed: {0:?}")] - Failed(Box), + #[error("Update failed: {}", TemporalErrorChain(.0))] + Failed(#[source] TemporalError), /// Error serializing input or deserializing output. #[error("Payload conversion error: {0}")] @@ -154,8 +155,8 @@ impl WorkflowUpdateError { #[non_exhaustive] pub enum WorkflowGetResultError { /// The workflow finished in failure. - #[error("Workflow failed: {0:?}")] - Failed(Box), + #[error("Workflow failed: {}", TemporalErrorChain(.0))] + Failed(#[source] TemporalError), /// The workflow was cancelled. #[error("Workflow cancelled")] @@ -173,7 +174,7 @@ pub enum WorkflowGetResultError { /// The workflow timed out. #[error("Workflow timed out")] - TimedOut, + Timeout, /// The workflow continued as new. #[error("Workflow continued as new")] @@ -217,12 +218,26 @@ impl WorkflowGetResultError { Self::Failed(_) | Self::Cancelled { .. } | Self::Terminated { .. 
} - | Self::TimedOut + | Self::Timeout | Self::ContinuedAsNew ) } } +struct TemporalErrorChain<'a>(&'a TemporalError); + +impl fmt::Display for TemporalErrorChain<'_> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.0)?; + let mut source = self.0.source(); + while let Some(err) = source { + write!(f, ": {err}")?; + source = err.source(); + } + Ok(()) + } +} + /// Errors returned by client methods that don't need more specific error types. #[derive(thiserror::Error, Debug)] #[non_exhaustive] @@ -292,3 +307,66 @@ impl AsyncActivityError { #[derive(Debug, thiserror::Error)] #[non_exhaustive] pub enum ClientNewError {} + +#[cfg(test)] +mod tests { + use super::{WorkflowGetResultError, WorkflowUpdateError}; + use temporalio_common::{ + data_converters::TemporalError, protos::temporal::api::enums::v1::RetryState, + }; + + #[test] + fn workflow_get_result_error_includes_nested_activity_cause_message() { + let error = WorkflowGetResultError::Failed(TemporalError::Activity { + message: "Activity task failed".into(), + stack_trace: String::new(), + scheduled_event_id: 1, + started_event_id: 2, + identity: "worker".into(), + activity_type: "test-activity".into(), + activity_id: "activity-id".into(), + retry_state: RetryState::NonRetryableFailure, + cause: Some(Box::new(TemporalError::Application { + message: "boom".into(), + stack_trace: String::new(), + r#type: "TestError".into(), + non_retryable: false, + details: None, + next_retry_delay: None, + cause: None, + })), + }); + + let rendered = error.to_string(); + assert!(rendered.contains("Workflow failed: Activity task failed")); + assert!(rendered.contains("boom")); + } + + #[test] + fn workflow_update_error_includes_nested_child_workflow_cause_message() { + let error = WorkflowUpdateError::Failed(TemporalError::ChildWorkflow { + message: "Child workflow task failed".into(), + stack_trace: String::new(), + namespace: "default".into(), + workflow_id: "child-id".into(), + run_id: 
"child-run".into(), + workflow_type: "child-type".into(), + initiated_event_id: 3, + started_event_id: 4, + retry_state: RetryState::InProgress, + cause: Some(Box::new(TemporalError::Application { + message: "child boom".into(), + stack_trace: String::new(), + r#type: "ChildError".into(), + non_retryable: false, + details: None, + next_retry_delay: None, + cause: None, + })), + }); + + let rendered = error.to_string(); + assert!(rendered.contains("Update failed: Child workflow task failed")); + assert!(rendered.contains("child boom")); + } +} diff --git a/crates/client/src/lib.rs b/crates/client/src/lib.rs index 26bc766ef..92b905448 100644 --- a/crates/client/src/lib.rs +++ b/crates/client/src/lib.rs @@ -35,6 +35,7 @@ pub use async_activity_handle::{ ActivityHeartbeatResponse, ActivityIdentifier, AsyncActivityHandle, }; +pub use errors::*; pub use metrics::{LONG_REQUEST_LATENCY_HISTOGRAM_NAME, REQUEST_LATENCY_HISTOGRAM_NAME}; pub use options_structs::*; pub use replaceable::SharedReplaceableClient; @@ -42,8 +43,8 @@ pub use retry::RetryOptions; pub use tonic; pub use workflow_handle::{ UntypedQuery, UntypedSignal, UntypedUpdate, UntypedWorkflow, UntypedWorkflowHandle, - WorkflowExecutionDescription, WorkflowExecutionInfo, WorkflowExecutionResult, WorkflowHandle, - WorkflowHistory, WorkflowUpdateHandle, + WorkflowExecutionDescription, WorkflowExecutionInfo, WorkflowHandle, WorkflowHistory, + WorkflowUpdateHandle, }; use crate::{ @@ -55,7 +56,6 @@ use crate::{ request_extensions::RequestExt, worker::ClientWorkerSet, }; -use errors::*; use futures_util::{stream, stream::Stream}; use http::Uri; use parking_lot::RwLock; diff --git a/crates/client/src/workflow_handle.rs b/crates/client/src/workflow_handle.rs index 661269c87..a1981d12a 100644 --- a/crates/client/src/workflow_handle.rs +++ b/crates/client/src/workflow_handle.rs @@ -41,7 +41,7 @@ use uuid::Uuid; /// Enumerates terminal states for a particular workflow execution #[derive(Debug)] 
#[allow(clippy::large_enum_variant)] -pub enum WorkflowExecutionResult { +pub(crate) enum WorkflowExecutionResult { /// The workflow finished successfully Succeeded(T), /// The workflow finished in failure @@ -267,14 +267,21 @@ where let raw = self.get_result_raw(opts).await?; match raw { WorkflowExecutionResult::Succeeded(v) => Ok(v), - WorkflowExecutionResult::Failed(f) => Err(WorkflowGetResultError::Failed(Box::new(f))), + WorkflowExecutionResult::Failed(f) => { + let err = self + .client + .data_converter() + .decode_failure(f, &SerializationContextData::Workflow) + .await; + Err(WorkflowGetResultError::Failed(err)) + } WorkflowExecutionResult::Cancelled { details } => { Err(WorkflowGetResultError::Cancelled { details }) } WorkflowExecutionResult::Terminated { details } => { Err(WorkflowGetResultError::Terminated { details }) } - WorkflowExecutionResult::TimedOut => Err(WorkflowGetResultError::TimedOut), + WorkflowExecutionResult::TimedOut => Err(WorkflowGetResultError::Timeout), WorkflowExecutionResult::ContinuedAsNew => Err(WorkflowGetResultError::ContinuedAsNew), } } @@ -805,7 +812,12 @@ where .await .map_err(WorkflowUpdateError::from), Some(update::v1::outcome::Value::Failure(failure)) => { - Err(WorkflowUpdateError::Failed(Box::new(failure))) + let err = self + .client + .data_converter() + .decode_failure(failure, &SerializationContextData::Workflow) + .await; + Err(WorkflowUpdateError::Failed(err)) } None => Err(WorkflowUpdateError::Other( "Update returned no outcome value".into(), diff --git a/crates/common/Cargo.toml b/crates/common/Cargo.toml index f88c5c8ab..a1d39b204 100644 --- a/crates/common/Cargo.toml +++ b/crates/common/Cargo.toml @@ -104,6 +104,8 @@ pbjson-build = { workspace = true } workspace = true [dev-dependencies] +assert_matches = "1.5" futures-util = { version = "0.3", default-features = false } +rstest = "0.26" tempfile = "3.21" tokio = { version = "1.47", features = ["macros", "rt"] } diff --git a/crates/common/src/data_converters.rs 
b/crates/common/src/data_converters.rs index e58f37dad..88d1a8c96 100644 --- a/crates/common/src/data_converters.rs +++ b/crates/common/src/data_converters.rs @@ -1,16 +1,25 @@ //! Contains traits for and default implementations of data converters, codecs, and other //! serialization related functionality. -use crate::protos::temporal::api::{common::v1::Payload, failure::v1::Failure}; +use crate::protos::temporal::api::{ + common::v1::{Payload, Payloads}, + enums::v1::{NexusHandlerErrorRetryBehavior, RetryState, TimeoutType}, + failure::v1::{ + ActivityFailureInfo, ApplicationFailureInfo, CanceledFailureInfo, + ChildWorkflowExecutionFailureInfo, Failure, NexusHandlerFailureInfo, + NexusOperationFailureInfo, ServerFailureInfo, TerminatedFailureInfo, TimeoutFailureInfo, + failure::FailureInfo, + }, +}; use futures::{FutureExt, future::BoxFuture}; use std::{collections::HashMap, sync::Arc}; +use tracing::warn; /// Combines a [`PayloadConverter`], [`FailureConverter`], and [`PayloadCodec`] to handle all /// serialization needs for communicating with the Temporal server. #[derive(Clone)] pub struct DataConverter { payload_converter: PayloadConverter, - #[allow(dead_code)] // Will be used for failure conversion failure_converter: Arc, codec: Arc, } @@ -105,10 +114,139 @@ impl DataConverter { &self.payload_converter } + /// Returns the failure converter component of this data converter. + pub fn failure_converter(&self) -> &(dyn FailureConverter + Send + Sync) { + self.failure_converter.as_ref() + } + /// Returns the codec component of this data converter. pub fn codec(&self) -> &(dyn PayloadCodec + Send + Sync) { self.codec.as_ref() } + + /// Convert a [`Failure`] proto into an error using only the + /// failure converter (no codec). Use [`decode_failure`](Self::decode_failure) + /// for the full pipeline including codec. 
+ pub fn to_error(&self, failure: Failure, context: &SerializationContextData) -> TemporalError { + self.failure_converter + .to_error(failure, &self.payload_converter, context) + } + + /// Convert an error into a [`Failure`] proto using only the failure + /// converter (no codec). The codec is applied separately by the + /// `PayloadVisitable` visitor on outgoing completions. + pub fn to_failure( + &self, + error: Box, + context: &SerializationContextData, + ) -> Failure { + self.failure_converter + .to_failure(error, &self.payload_converter, context) + } + + /// Decode a [`Failure`] proto into an error, applying the codec + /// to embedded payloads before running the failure converter. + pub async fn decode_failure( + &self, + failure: Failure, + context: &SerializationContextData, + ) -> TemporalError { + let decoded = Self::apply_codec_to_failure(failure, context, |ctx, payloads| { + self.codec.decode(ctx, payloads) + }) + .await; + self.failure_converter + .to_error(decoded, &self.payload_converter, context) + } + + /// Recursively apply a codec operation (encode or decode) to all payloads + /// embedded in a [`Failure`]: `encoded_attributes`, detail payloads in + /// failure info variants, and causes. + async fn apply_codec_to_failure( + failure: Failure, + context: &SerializationContextData, + codec_fn: F, + ) -> Failure + where + F: Fn(&SerializationContextData, Vec) -> Fut + Copy, + Fut: std::future::Future>, + { + let Failure { + message, + source, + stack_trace, + encoded_attributes, + cause, + failure_info, + .. 
+ } = failure; + let cause_context = Self::nested_failure_context(*context, failure_info.as_ref()); + + let encoded_attributes = match encoded_attributes { + Some(ea) => codec_fn(context, vec![ea]).await.into_iter().next(), + None => None, + }; + + let failure_info = match failure_info { + Some(FailureInfo::ApplicationFailureInfo(mut app)) => { + let mut d = app.details.take().unwrap_or_default(); + d.payloads = codec_fn(context, d.payloads).await; + app.details = Some(d); + Some(FailureInfo::ApplicationFailureInfo(app)) + } + Some(FailureInfo::TimeoutFailureInfo(mut t)) => { + let mut d = t.last_heartbeat_details.take().unwrap_or_default(); + d.payloads = codec_fn(context, d.payloads).await; + t.last_heartbeat_details = Some(d); + Some(FailureInfo::TimeoutFailureInfo(t)) + } + Some(FailureInfo::CanceledFailureInfo(mut c)) => { + let mut d = c.details.take().unwrap_or_default(); + d.payloads = codec_fn(context, d.payloads).await; + c.details = Some(d); + Some(FailureInfo::CanceledFailureInfo(c)) + } + Some(FailureInfo::ResetWorkflowFailureInfo(mut r)) => { + let mut d = r.last_heartbeat_details.take().unwrap_or_default(); + d.payloads = codec_fn(context, d.payloads).await; + r.last_heartbeat_details = Some(d); + Some(FailureInfo::ResetWorkflowFailureInfo(r)) + } + other => other, + }; + + let cause = match cause { + Some(c) => Some(Box::new( + Box::pin(Self::apply_codec_to_failure(*c, &cause_context, codec_fn)).await, + )), + None => None, + }; + + Failure { + message, + source, + stack_trace, + encoded_attributes, + cause, + failure_info, + } + } + + fn nested_failure_context( + context: SerializationContextData, + failure_info: Option<&FailureInfo>, + ) -> SerializationContextData { + match failure_info { + Some(FailureInfo::ActivityFailureInfo(_)) => SerializationContextData::Activity, + Some(FailureInfo::ChildWorkflowExecutionFailureInfo(_)) => { + SerializationContextData::Workflow + } + Some(FailureInfo::NexusOperationExecutionFailureInfo(_)) => { + 
SerializationContextData::Nexus + } + _ => context, + } + } } /// Data about the serialization context, indicating where the serialization is occurring. @@ -197,14 +335,19 @@ impl std::error::Error for PayloadConversionError { } /// Converts between Rust errors and Temporal [`Failure`] protobufs. +/// +/// Implementations must be infallible — conversion should always succeed, +/// falling back to a reasonable default (e.g. wrapping as +/// `ApplicationFailureInfo`) rather than returning an error. This matches +/// every other Temporal SDK. pub trait FailureConverter { /// Convert an error into a Temporal failure protobuf. fn to_failure( &self, - error: Box, + error: Box, payload_converter: &PayloadConverter, context: &SerializationContextData, - ) -> Result; + ) -> Failure; /// Convert a Temporal failure protobuf back into a Rust error. fn to_error( @@ -212,10 +355,575 @@ pub trait FailureConverter { failure: Failure, payload_converter: &PayloadConverter, context: &SerializationContextData, - ) -> Result, PayloadConversionError>; + ) -> TemporalError; } -/// Default (currently unimplemented) failure converter. +/// Default failure converter that maps between Temporal [`Failure`] protobufs +/// and Rust error types. pub struct DefaultFailureConverter; + +/// An error produced by the failure converter, representing a Temporal +/// [`Failure`] proto as an error. +#[derive(Debug, thiserror::Error)] +pub enum TemporalError { + /// Application-level failure — the primary error type users throw from + /// workflows and activities. + #[error("{message}")] + Application { + /// Human-readable error message. + message: String, + /// Stack trace from the originating SDK, if available. + stack_trace: String, + /// Application error type string. + r#type: String, + /// Whether this error is non-retryable. + non_retryable: bool, + /// Serialized detail payloads. + details: Option, + /// Override for the next retry delay. + next_retry_delay: Option, + /// Recursive cause. 
+ #[source] + cause: Option>, + }, + /// A timeout occurred (activity start-to-close, schedule-to-close, etc.). + #[error("{message}")] + Timeout { + /// Human-readable error message. + message: String, + /// Stack trace from the originating SDK, if available. + stack_trace: String, + /// Which kind of timeout. + timeout_type: TimeoutType, + /// Last heartbeat details before the timeout. + last_heartbeat_details: Option, + /// Recursive cause. + #[source] + cause: Option>, + }, + /// The operation was cancelled. + #[error("{message}")] + Cancelled { + /// Human-readable error message. + message: String, + /// Stack trace from the originating SDK, if available. + stack_trace: String, + /// Cancellation detail payloads. + details: Option, + /// Recursive cause. + #[source] + cause: Option>, + }, + /// The workflow or activity was terminated. + #[error("{message}")] + Terminated { + /// Human-readable error message. + message: String, + /// Stack trace from the originating SDK, if available. + stack_trace: String, + /// Recursive cause. + #[source] + cause: Option>, + }, + /// An error originated at the Temporal server. + #[error("{message}")] + Server { + /// Human-readable error message. + message: String, + /// Stack trace from the originating SDK, if available. + stack_trace: String, + /// Whether this error is non-retryable. + non_retryable: bool, + /// Recursive cause. + #[source] + cause: Option>, + }, + /// An activity execution failed. The original error is available as the + /// cause. + #[error("{message}")] + Activity { + /// Human-readable error message. + message: String, + /// Stack trace from the originating SDK, if available. + stack_trace: String, + /// Scheduled event ID. + scheduled_event_id: i64, + /// Started event ID. + started_event_id: i64, + /// Worker identity. + identity: String, + /// Activity type name. + activity_type: String, + /// Activity ID. + activity_id: String, + /// Retry state at the time of failure. 
+ retry_state: RetryState, + /// Recursive cause (typically the underlying application error). + #[source] + cause: Option>, + }, + /// A child workflow execution failed. + #[error("{message}")] + ChildWorkflow { + /// Human-readable error message. + message: String, + /// Stack trace from the originating SDK, if available. + stack_trace: String, + /// Child workflow namespace. + namespace: String, + /// Child workflow ID. + workflow_id: String, + /// Child workflow run ID. + run_id: String, + /// Child workflow type name. + workflow_type: String, + /// Initiated event ID. + initiated_event_id: i64, + /// Started event ID. + started_event_id: i64, + /// Retry state at the time of failure. + retry_state: RetryState, + /// Recursive cause. + #[source] + cause: Option>, + }, + /// A Nexus operation failed. + #[error("{message}")] + NexusOperation { + /// Human-readable error message. + message: String, + /// Stack trace from the originating SDK, if available. + stack_trace: String, + /// Scheduled event ID. + scheduled_event_id: i64, + /// Nexus endpoint name. + endpoint: String, + /// Nexus service name. + service: String, + /// Nexus operation name. + operation: String, + /// Operation token (may be empty for sync completions). + operation_token: String, + /// Recursive cause. + #[source] + cause: Option>, + }, + /// A Nexus handler produced an error. + #[error("{message}")] + NexusHandler { + /// Human-readable error message. + message: String, + /// Stack trace from the originating SDK, if available. + stack_trace: String, + /// Nexus error type. + r#type: String, + /// Retry behavior. + retry_behavior: NexusHandlerErrorRetryBehavior, + /// Recursive cause. + #[source] + cause: Option>, + }, + /// A failure with no specific `failure_info`, or an unmodeled variant. + #[error("{message}")] + Generic { + /// Human-readable error message. + message: String, + /// Stack trace from the originating SDK, if available. + stack_trace: String, + /// Recursive cause. 
+ #[source] + cause: Option>, + }, + /// An opaque error from a custom failure converter that doesn't map to a + /// known Temporal failure type. + #[error(transparent)] + Other(Box), +} + +impl TemporalError { + /// The human-readable error message, if this is a known Temporal failure + /// variant. Returns `None` for [`Other`](Self::Other) — use `Display` to + /// get a message from any variant. + pub fn message(&self) -> Option<&str> { + match self { + Self::Application { message, .. } + | Self::Timeout { message, .. } + | Self::Cancelled { message, .. } + | Self::Terminated { message, .. } + | Self::Server { message, .. } + | Self::Activity { message, .. } + | Self::ChildWorkflow { message, .. } + | Self::NexusOperation { message, .. } + | Self::NexusHandler { message, .. } + | Self::Generic { message, .. } => Some(message), + Self::Other(_) => None, + } + } + + /// The stack trace, if this is a known Temporal failure variant. Returns + /// `None` for [`Other`](Self::Other). + pub fn stack_trace(&self) -> Option<&str> { + match self { + Self::Application { stack_trace, .. } + | Self::Timeout { stack_trace, .. } + | Self::Cancelled { stack_trace, .. } + | Self::Terminated { stack_trace, .. } + | Self::Server { stack_trace, .. } + | Self::Activity { stack_trace, .. } + | Self::ChildWorkflow { stack_trace, .. } + | Self::NexusOperation { stack_trace, .. } + | Self::NexusHandler { stack_trace, .. } + | Self::Generic { stack_trace, .. } => Some(stack_trace), + Self::Other(_) => None, + } + } + + /// Returns the cause of this error, if any. + pub fn cause(&self) -> Option<&TemporalError> { + match self { + Self::Application { cause, .. } + | Self::Timeout { cause, .. } + | Self::Cancelled { cause, .. } + | Self::Terminated { cause, .. } + | Self::Server { cause, .. } + | Self::Activity { cause, .. } + | Self::ChildWorkflow { cause, .. } + | Self::NexusOperation { cause, .. } + | Self::NexusHandler { cause, .. } + | Self::Generic { cause, .. 
} => cause.as_deref(), + Self::Other(_) => None, + } + } + + fn from_failure(mut failure: Failure, payload_converter: &PayloadConverter) -> Self { + // If encoded_attributes is present, decode message (and stack_trace) + // from the payload — the top-level fields were cleared for encryption. + if let Some(payload) = failure.encoded_attributes.take() { + let ctx = SerializationContext { + data: &SerializationContextData::None, + converter: payload_converter, + }; + match payload_converter.from_payload::(&ctx, payload) { + Ok(attrs) => { + if let Some(msg) = attrs.get("message").and_then(|v| v.as_str()) { + failure.message = msg.to_owned(); + } + if let Some(st) = attrs.get("stack_trace").and_then(|v| v.as_str()) { + failure.stack_trace = st.to_owned(); + } + } + Err(e) => { + warn!( + error = %e, + "Failed to decode encoded_attributes on Failure proto, \ + falling back to top-level message" + ); + } + } + } + + let cause = failure + .cause + .map(|c| Box::new(Self::from_failure(*c, payload_converter))); + + let stack_trace = failure.stack_trace; + + match failure.failure_info { + Some(FailureInfo::ApplicationFailureInfo(info)) => Self::Application { + message: failure.message, + stack_trace, + r#type: info.r#type, + non_retryable: info.non_retryable, + details: info.details, + next_retry_delay: info.next_retry_delay, + cause, + }, + Some(FailureInfo::TimeoutFailureInfo(info)) => Self::Timeout { + message: failure.message, + stack_trace, + timeout_type: info.timeout_type(), + last_heartbeat_details: info.last_heartbeat_details, + cause, + }, + Some(FailureInfo::CanceledFailureInfo(info)) => Self::Cancelled { + message: failure.message, + stack_trace, + details: info.details, + cause, + }, + Some(FailureInfo::TerminatedFailureInfo(_)) => Self::Terminated { + message: failure.message, + stack_trace, + cause, + }, + Some(FailureInfo::ServerFailureInfo(info)) => Self::Server { + message: failure.message, + stack_trace, + non_retryable: info.non_retryable, + cause, + }, + 
Some(FailureInfo::ActivityFailureInfo(info)) => { + let retry_state = info.retry_state(); + Self::Activity { + message: failure.message, + stack_trace, + scheduled_event_id: info.scheduled_event_id, + started_event_id: info.started_event_id, + identity: info.identity, + activity_type: info.activity_type.map(|t| t.name).unwrap_or_default(), + activity_id: info.activity_id, + retry_state, + cause, + } + } + Some(FailureInfo::ChildWorkflowExecutionFailureInfo(info)) => { + let retry_state = info.retry_state(); + let (workflow_id, run_id) = info + .workflow_execution + .map(|e| (e.workflow_id, e.run_id)) + .unwrap_or_default(); + Self::ChildWorkflow { + message: failure.message, + stack_trace, + namespace: info.namespace, + workflow_id, + run_id, + workflow_type: info.workflow_type.map(|t| t.name).unwrap_or_default(), + initiated_event_id: info.initiated_event_id, + started_event_id: info.started_event_id, + retry_state, + cause, + } + } + Some(FailureInfo::NexusOperationExecutionFailureInfo(info)) => Self::NexusOperation { + message: failure.message, + stack_trace, + scheduled_event_id: info.scheduled_event_id, + endpoint: info.endpoint, + service: info.service, + operation: info.operation, + operation_token: info.operation_token, + cause, + }, + Some(FailureInfo::NexusHandlerFailureInfo(info)) => { + let retry_behavior = info.retry_behavior(); + Self::NexusHandler { + message: failure.message, + stack_trace, + r#type: info.r#type, + retry_behavior, + cause, + } + } + Some(FailureInfo::ResetWorkflowFailureInfo(_)) | None => Self::Generic { + message: failure.message, + stack_trace, + cause, + }, + } + } + + fn into_failure(self) -> Failure { + let (message, stack_trace, failure_info, cause) = match self { + Self::Application { + message, + stack_trace, + r#type, + non_retryable, + details, + next_retry_delay, + cause, + } => ( + message, + stack_trace, + Some(FailureInfo::ApplicationFailureInfo( + ApplicationFailureInfo { + r#type, + non_retryable, + details, + 
next_retry_delay, + ..Default::default() + }, + )), + cause, + ), + Self::Timeout { + message, + stack_trace, + timeout_type, + last_heartbeat_details, + cause, + } => ( + message, + stack_trace, + Some(FailureInfo::TimeoutFailureInfo(TimeoutFailureInfo { + timeout_type: timeout_type.into(), + last_heartbeat_details, + })), + cause, + ), + Self::Cancelled { + message, + stack_trace, + details, + cause, + } => ( + message, + stack_trace, + Some(FailureInfo::CanceledFailureInfo(CanceledFailureInfo { + details, + })), + cause, + ), + Self::Terminated { + message, + stack_trace, + cause, + } => ( + message, + stack_trace, + Some(FailureInfo::TerminatedFailureInfo(TerminatedFailureInfo {})), + cause, + ), + Self::Server { + message, + stack_trace, + non_retryable, + cause, + } => ( + message, + stack_trace, + Some(FailureInfo::ServerFailureInfo(ServerFailureInfo { + non_retryable, + })), + cause, + ), + Self::Activity { + message, + stack_trace, + scheduled_event_id, + started_event_id, + identity, + activity_type, + activity_id, + retry_state, + cause, + } => ( + message, + stack_trace, + Some(FailureInfo::ActivityFailureInfo(ActivityFailureInfo { + scheduled_event_id, + started_event_id, + identity, + activity_type: Some(crate::protos::temporal::api::common::v1::ActivityType { + name: activity_type, + }), + activity_id, + retry_state: retry_state.into(), + })), + cause, + ), + Self::ChildWorkflow { + message, + stack_trace, + namespace, + workflow_id, + run_id, + workflow_type, + initiated_event_id, + started_event_id, + retry_state, + cause, + } => ( + message, + stack_trace, + Some(FailureInfo::ChildWorkflowExecutionFailureInfo( + ChildWorkflowExecutionFailureInfo { + namespace, + workflow_execution: Some( + crate::protos::temporal::api::common::v1::WorkflowExecution { + workflow_id, + run_id, + }, + ), + workflow_type: Some( + crate::protos::temporal::api::common::v1::WorkflowType { + name: workflow_type, + }, + ), + initiated_event_id, + started_event_id, + 
retry_state: retry_state.into(), + }, + )), + cause, + ), + Self::NexusOperation { + message, + stack_trace, + scheduled_event_id, + endpoint, + service, + operation, + operation_token, + cause, + } => ( + message, + stack_trace, + Some(FailureInfo::NexusOperationExecutionFailureInfo( + NexusOperationFailureInfo { + scheduled_event_id, + endpoint, + service, + operation, + operation_token, + ..Default::default() + }, + )), + cause, + ), + Self::NexusHandler { + message, + stack_trace, + r#type, + retry_behavior, + cause, + } => ( + message, + stack_trace, + Some(FailureInfo::NexusHandlerFailureInfo( + NexusHandlerFailureInfo { + r#type, + retry_behavior: retry_behavior.into(), + }, + )), + cause, + ), + Self::Generic { + message, + stack_trace, + cause, + } => (message, stack_trace, None, cause), + Self::Other(e) => ( + e.to_string(), + String::new(), + Some(FailureInfo::ApplicationFailureInfo( + ApplicationFailureInfo::default(), + )), + None, + ), + }; + + Failure { + message, + stack_trace, + source: "RustSDK".into(), + failure_info, + cause: cause.map(|c| Box::new(c.into_failure())), + ..Default::default() + } + } +} + /// Encodes and decodes payloads, enabling encryption or compression. pub trait PayloadCodec { /// Encode payloads before they are sent to the server. 
@@ -672,24 +1380,47 @@ impl Default for DataConverter { ) } } + impl FailureConverter for DefaultFailureConverter { fn to_failure( &self, - _: Box, - _: &PayloadConverter, - _: &SerializationContextData, - ) -> Result { - todo!() + error: Box, + _payload_converter: &PayloadConverter, + _context: &SerializationContextData, + ) -> Failure { + match error.downcast::() { + Ok(tf) => tf.into_failure(), + Err(error) => { + let mut failure = generic_error_to_application_failure(error.as_ref()); + failure.source = "RustSDK".into(); + failure + } + } } + fn to_error( &self, - _: Failure, - _: &PayloadConverter, - _: &SerializationContextData, - ) -> Result, PayloadConversionError> { - todo!() + failure: Failure, + payload_converter: &PayloadConverter, + _context: &SerializationContextData, + ) -> TemporalError { + TemporalError::from_failure(failure, payload_converter) + } +} + +fn generic_error_to_application_failure(error: &dyn std::error::Error) -> Failure { + Failure { + message: error.to_string(), + failure_info: Some(FailureInfo::ApplicationFailureInfo( + ApplicationFailureInfo::default(), + )), + cause: error + .source() + .map(|cause| Box::new(generic_error_to_application_failure(cause))), + ..Default::default() } } + impl PayloadCodec for DefaultPayloadCodec { fn encode( &self, @@ -769,6 +1500,13 @@ impl_multi_args!(MultiArgs6; 6; 0: A, 1: B, 2: C, 3: D, 4: E, 5: F); #[cfg(test)] mod tests { use super::*; + use crate::protos::temporal::api::failure::v1::{ + ActivityFailureInfo, ApplicationFailureInfo, CanceledFailureInfo, ServerFailureInfo, + TerminatedFailureInfo, TimeoutFailureInfo, failure::FailureInfo, + }; + use assert_matches::assert_matches; + use rstest::rstest; + use std::error::Error as _; #[test] fn test_empty_payloads_as_unit_type() { @@ -866,4 +1604,619 @@ mod tests { let args: MultiArgs2 = ("hello".to_string(), 42i32).into(); assert_eq!(args, MultiArgs2("hello".to_string(), 42)); } + + #[test] + fn plain_error_becomes_application_failure() { + let 
err: Box = + "something went wrong".to_string().into(); + let failure = DefaultFailureConverter.to_failure( + err, + &PayloadConverter::default(), + &SerializationContextData::Workflow, + ); + + assert_eq!(failure.message, "something went wrong"); + assert_matches!( + failure.failure_info, + Some(FailureInfo::ApplicationFailureInfo(_)) + ); + } + + #[test] + fn source_field_is_set() { + let err: Box = "test".to_string().into(); + let failure = DefaultFailureConverter.to_failure( + err, + &PayloadConverter::default(), + &SerializationContextData::Workflow, + ); + + assert_eq!(failure.source, "RustSDK"); + } + + #[rstest] + #[case::application( + "app error", + FailureInfo::ApplicationFailureInfo(ApplicationFailureInfo { + r#type: "MyError".into(), + non_retryable: true, + ..Default::default() + }) + )] + #[case::timeout( + "timed out", + FailureInfo::TimeoutFailureInfo(TimeoutFailureInfo { + timeout_type: 1, + ..Default::default() + }) + )] + #[case::canceled( + "canceled", + FailureInfo::CanceledFailureInfo(CanceledFailureInfo::default()) + )] + #[case::terminated( + "terminated", + FailureInfo::TerminatedFailureInfo(TerminatedFailureInfo {}) + )] + #[case::server( + "server error", + FailureInfo::ServerFailureInfo(ServerFailureInfo { non_retryable: true }) + )] + fn failure_type_round_trips(#[case] message: &str, #[case] info: FailureInfo) { + let converter = DefaultFailureConverter; + let pc = PayloadConverter::default(); + let ctx = SerializationContextData::Workflow; + + let original = Failure { + message: message.into(), + failure_info: Some(info.clone()), + ..Default::default() + }; + + let error = converter.to_error(original.clone(), &pc, &ctx); + assert_eq!(error.to_string(), message); + + let round_tripped = converter.to_failure(Box::new(error), &pc, &ctx); + assert_eq!(round_tripped.failure_info, Some(info)); + assert_eq!(round_tripped.message, original.message); + } + + // -- cause chain -- + + #[test] + fn cause_chain_is_preserved() { + let inner = 
Failure { + message: "root cause".into(), + failure_info: Some(FailureInfo::ApplicationFailureInfo( + ApplicationFailureInfo::default(), + )), + ..Default::default() + }; + let outer = Failure { + message: "outer error".into(), + cause: Some(Box::new(inner)), + failure_info: Some(FailureInfo::ApplicationFailureInfo( + ApplicationFailureInfo::default(), + )), + ..Default::default() + }; + + let error = DefaultFailureConverter.to_error( + outer, + &PayloadConverter::default(), + &SerializationContextData::Workflow, + ); + + let source = error.source().expect("error should have a source"); + assert_eq!(source.to_string(), "root cause"); + } + + #[derive(Debug, thiserror::Error)] + #[error("inner cause")] + struct InnerCause; + + #[derive(Debug, thiserror::Error)] + #[error("outer wrapper")] + struct OuterCause { + #[source] + source: InnerCause, + } + + #[test] + fn plain_error_source_chain_is_preserved_in_failure() { + let pc = PayloadConverter::default(); + let ctx = SerializationContextData::Workflow; + let error = OuterCause { source: InnerCause }; + + let failure = DefaultFailureConverter.to_failure(Box::new(error), &pc, &ctx); + + assert_eq!(failure.message, "outer wrapper"); + assert_matches!( + failure.failure_info, + Some(FailureInfo::ApplicationFailureInfo(_)) + ); + let cause = failure + .cause + .as_ref() + .expect("failure should preserve causes"); + assert_eq!(cause.message, "inner cause"); + assert_matches!( + cause.failure_info, + Some(FailureInfo::ApplicationFailureInfo(_)) + ); + + let round_tripped = DefaultFailureConverter.to_error(failure, &pc, &ctx); + let source = round_tripped.source().expect("error should have a source"); + assert_eq!(source.to_string(), "inner cause"); + } + + #[test] + fn deeply_nested_cause_chain() { + let pc = PayloadConverter::default(); + let ctx = SerializationContextData::Workflow; + + let level0 = Failure { + message: "level0".into(), + failure_info: Some(FailureInfo::ApplicationFailureInfo( + 
ApplicationFailureInfo::default(), + )), + ..Default::default() + }; + let level1 = Failure { + message: "level1".into(), + cause: Some(Box::new(level0)), + failure_info: Some(FailureInfo::ApplicationFailureInfo( + ApplicationFailureInfo::default(), + )), + ..Default::default() + }; + let level2 = Failure { + message: "level2".into(), + cause: Some(Box::new(level1)), + failure_info: Some(FailureInfo::ApplicationFailureInfo( + ApplicationFailureInfo::default(), + )), + ..Default::default() + }; + + let error = DefaultFailureConverter.to_error(level2, &pc, &ctx); + + let e1 = error.source().expect("should have level1"); + assert_eq!(e1.to_string(), "level1"); + let e0 = e1.source().expect("should have level0"); + assert_eq!(e0.to_string(), "level0"); + } + + // -- cross-SDK -- + + #[test] + fn cross_sdk_failure_deserializes() { + let foreign = Failure { + message: "something failed in TypeScript".into(), + source: "TypeScriptSDK".into(), + stack_trace: "at someFunction (file.ts:10:5)".into(), + failure_info: Some(FailureInfo::ApplicationFailureInfo( + ApplicationFailureInfo { + r#type: "Error".into(), + ..Default::default() + }, + )), + ..Default::default() + }; + + let error = DefaultFailureConverter.to_error( + foreign, + &PayloadConverter::default(), + &SerializationContextData::Workflow, + ); + assert_eq!(error.to_string(), "something failed in TypeScript"); + + assert_eq!(error.stack_trace(), Some("at someFunction (file.ts:10:5)")); + } + + #[test] + fn failure_with_no_info_deserializes() { + let bare = Failure { + message: "bare failure".into(), + ..Default::default() + }; + + let error = DefaultFailureConverter.to_error( + bare, + &PayloadConverter::default(), + &SerializationContextData::Workflow, + ); + assert_eq!(error.to_string(), "bare failure"); + } + + struct UpperCaseFailureConverter; + impl FailureConverter for UpperCaseFailureConverter { + fn to_failure( + &self, + error: Box, + _: &PayloadConverter, + _: &SerializationContextData, + ) -> Failure { + 
Failure { + message: error.to_string().to_uppercase(), + failure_info: Some(FailureInfo::ApplicationFailureInfo( + ApplicationFailureInfo::default(), + )), + ..Default::default() + } + } + + fn to_error( + &self, + failure: Failure, + _: &PayloadConverter, + _: &SerializationContextData, + ) -> TemporalError { + TemporalError::Other(failure.message.to_lowercase().into()) + } + } + + #[test] + fn custom_failure_converter_is_used() { + let custom = UpperCaseFailureConverter; + let pc = PayloadConverter::default(); + let ctx = SerializationContextData::Workflow; + + let err: Box = "hello world".to_string().into(); + let failure = custom.to_failure(err, &pc, &ctx); + assert_eq!(failure.message, "HELLO WORLD"); + + let error = custom.to_error(failure, &pc, &ctx); + assert_eq!(error.to_string(), "hello world"); + } + + #[test] + fn encoded_attributes_hides_message_and_stack_trace() { + let encoded_msg = serde_json::to_vec(&serde_json::json!({ + "message": "secret error", + "stack_trace": "at secret_fn (secret.rs:42)" + })) + .unwrap(); + + let failure = Failure { + message: "Encoded failure".into(), + stack_trace: String::new(), + encoded_attributes: Some(Payload { + metadata: { + let mut hm = HashMap::new(); + hm.insert("encoding".to_string(), b"json/plain".to_vec()); + hm + }, + data: encoded_msg, + external_payloads: vec![], + }), + failure_info: Some(FailureInfo::ApplicationFailureInfo( + ApplicationFailureInfo::default(), + )), + ..Default::default() + }; + + let error = DefaultFailureConverter.to_error( + failure, + &PayloadConverter::default(), + &SerializationContextData::Workflow, + ); + + assert_eq!(error.to_string(), "secret error"); + } + + #[test] + fn non_json_encoded_attributes_falls_back_to_message() { + let failure = Failure { + message: "fallback message".into(), + encoded_attributes: Some(Payload { + metadata: { + let mut hm = HashMap::new(); + hm.insert("encoding".to_string(), b"binary/protobuf".to_vec()); + hm + }, + data: vec![0xFF, 0xFE, 0x00], + 
external_payloads: vec![], + }), + failure_info: Some(FailureInfo::ApplicationFailureInfo( + ApplicationFailureInfo::default(), + )), + ..Default::default() + }; + + let error = DefaultFailureConverter.to_error( + failure, + &PayloadConverter::default(), + &SerializationContextData::Workflow, + ); + + assert_eq!(error.to_string(), "fallback message"); + } + + #[test] + fn stack_trace_round_trips() { + let converter = DefaultFailureConverter; + let pc = PayloadConverter::default(); + let ctx = SerializationContextData::Workflow; + + let original = Failure { + message: "oops".into(), + stack_trace: "at my_fn (lib.rs:42)".into(), + failure_info: Some(FailureInfo::ApplicationFailureInfo( + ApplicationFailureInfo::default(), + )), + ..Default::default() + }; + + let error = converter.to_error(original, &pc, &ctx); + assert_eq!(error.stack_trace(), Some("at my_fn (lib.rs:42)")); + + let round_tripped = converter.to_failure(Box::new(error), &pc, &ctx); + assert_eq!(round_tripped.stack_trace, "at my_fn (lib.rs:42)"); + } + + /// A codec that XOR-encodes payload data, used to verify that + /// `DataConverter::to_error`/`to_failure` apply the codec to failure payloads. 
+ struct XorFailureCodec(u8); + impl PayloadCodec for XorFailureCodec { + fn encode( + &self, + _: &SerializationContextData, + payloads: Vec, + ) -> BoxFuture<'static, Vec> { + let key = self.0; + async move { + payloads + .into_iter() + .map(|mut p| { + p.data.iter_mut().for_each(|b| *b ^= key); + p + }) + .collect() + } + .boxed() + } + fn decode( + &self, + _: &SerializationContextData, + payloads: Vec, + ) -> BoxFuture<'static, Vec> { + // XOR is its own inverse + let key = self.0; + async move { + payloads + .into_iter() + .map(|mut p| { + p.data.iter_mut().for_each(|b| *b ^= key); + p + }) + .collect() + } + .boxed() + } + } + + struct ContextAwareFailureCodec; + impl PayloadCodec for ContextAwareFailureCodec { + fn encode( + &self, + _: &SerializationContextData, + payloads: Vec, + ) -> BoxFuture<'static, Vec> { + async move { payloads }.boxed() + } + + fn decode( + &self, + context: &SerializationContextData, + payloads: Vec, + ) -> BoxFuture<'static, Vec> { + let prefix = match context { + SerializationContextData::Workflow => b"wf:".as_slice(), + SerializationContextData::Activity => b"act:".as_slice(), + SerializationContextData::Nexus => b"nex:".as_slice(), + SerializationContextData::None => b"none:".as_slice(), + } + .to_vec(); + + async move { + payloads + .into_iter() + .map(|mut payload| { + if payload.data.starts_with(&prefix) { + payload.data.drain(..prefix.len()); + } + payload + }) + .collect() + } + .boxed() + } + } + + #[tokio::test] + async fn decode_failure_applies_codec_to_detail_payloads() { + let dc = DataConverter::new( + PayloadConverter::default(), + DefaultFailureConverter, + XorFailureCodec(0xAB), + ); + let ctx = SerializationContextData::Workflow; + + // Build a Failure proto with XOR-encoded detail payloads (simulating + // what the server would send after the send-path PayloadVisitable + // visitor encoded them). 
+ let plaintext = b"\"some detail\"".to_vec(); + let encoded_data: Vec = plaintext.iter().map(|b| b ^ 0xAB).collect(); + let encoded_payload = Payload { + metadata: { + let mut hm = HashMap::new(); + hm.insert("encoding".to_string(), b"json/plain".to_vec()); + hm + }, + data: encoded_data, + external_payloads: vec![], + }; + let failure = Failure { + message: "test".into(), + failure_info: Some(FailureInfo::ApplicationFailureInfo( + ApplicationFailureInfo { + r#type: "TestError".into(), + details: Some(Payloads { + payloads: vec![encoded_payload], + }), + ..Default::default() + }, + )), + ..Default::default() + }; + + // decode_failure should decode the XOR'd payloads back to plaintext + let error = dc.decode_failure(failure, &ctx).await; + match &error { + TemporalError::Application { details, .. } => { + let payloads = details.as_ref().unwrap(); + assert_eq!( + payloads.payloads[0].data, plaintext, + "decoded detail payload should match original plaintext" + ); + } + other => panic!("expected Application, got {other:?}"), + } + } + + #[tokio::test] + async fn decode_failure_uses_nested_activity_context_for_activity_causes() { + let dc = DataConverter::new( + PayloadConverter::default(), + DefaultFailureConverter, + ContextAwareFailureCodec, + ); + let plaintext = b"\"activity detail\"".to_vec(); + let mut encoded_data = b"act:".to_vec(); + encoded_data.extend_from_slice(&plaintext); + let failure = Failure { + message: "Activity task failed".into(), + cause: Some(Box::new(Failure { + message: "application failure".into(), + failure_info: Some(FailureInfo::ApplicationFailureInfo( + ApplicationFailureInfo { + r#type: "TestError".into(), + details: Some(Payloads { + payloads: vec![Payload { + metadata: { + let mut hm = HashMap::new(); + hm.insert("encoding".to_string(), b"json/plain".to_vec()); + hm + }, + data: encoded_data, + external_payloads: vec![], + }], + }), + ..Default::default() + }, + )), + ..Default::default() + })), + failure_info: 
Some(FailureInfo::ActivityFailureInfo(ActivityFailureInfo { + activity_type: Some(crate::protos::temporal::api::common::v1::ActivityType { + name: "test-activity".into(), + }), + activity_id: "activity-id".into(), + ..Default::default() + })), + ..Default::default() + }; + + let error = dc + .decode_failure(failure, &SerializationContextData::Workflow) + .await; + + match error { + TemporalError::Activity { + cause: Some(cause), .. + } => match cause.as_ref() { + TemporalError::Application { + details: Some(details), + .. + } => { + assert_eq!(details.payloads[0].data, plaintext); + } + other => panic!("expected Application cause, got {other:?}"), + }, + other => panic!("expected Activity, got {other:?}"), + } + } + + /// Mimics the proto structure produced by the local activity state machine + /// when a local activity is cancelled via TryCancel. The outer failure has + /// `ActivityFailureInfo` and the cause has `CanceledFailureInfo`. + #[test] + fn la_cancel_produces_activity_with_cancelled_cause() { + let converter = DefaultFailureConverter; + let pc = PayloadConverter::default(); + let ctx = SerializationContextData::Workflow; + + // Inner failure: what Cancellation::from_details(None) produces + let cancel_cause = Failure { + message: "Activity cancelled".into(), + failure_info: Some(FailureInfo::CanceledFailureInfo(CanceledFailureInfo { + details: None, + })), + ..Default::default() + }; + + // Outer failure: what wrap_fail! 
produces around the cancel cause + let wrapped = Failure { + message: "Local Activity cancelled".into(), + cause: Some(Box::new(cancel_cause)), + failure_info: Some(FailureInfo::ActivityFailureInfo(ActivityFailureInfo { + activity_type: Some(crate::protos::temporal::api::common::v1::ActivityType { + name: "echo".into(), + }), + activity_id: "1".into(), + retry_state: RetryState::CancelRequested.into(), + ..Default::default() + })), + ..Default::default() + }; + + let te = converter.to_error(wrapped, &pc, &ctx); + + // Should be Activity { cause: Some(Cancelled { .. }) } + assert_matches!(&te, TemporalError::Activity { + message, + activity_type, + cause: Some(cause), + .. + } => { + assert_eq!(message, "Local Activity cancelled"); + assert_eq!(activity_type, "echo"); + assert_matches!(cause.as_ref(), TemporalError::Cancelled { message, .. } => { + assert_eq!(message, "Activity cancelled"); + }); + }); + } + + /// `Cancellation::from_details(None)` produces a Failure with + /// CanceledFailureInfo, which correctly converts to TemporalError::Cancelled. + #[test] + fn cancellation_from_details_produces_cancelled() { + let converter = DefaultFailureConverter; + let pc = PayloadConverter::default(); + let ctx = SerializationContextData::Workflow; + + let cancel = Failure { + message: "Activity cancelled".into(), + failure_info: Some(FailureInfo::CanceledFailureInfo(CanceledFailureInfo { + details: None, + })), + ..Default::default() + }; + let te = converter.to_error(cancel, &pc, &ctx); + + assert_matches!(&te, TemporalError::Cancelled { message, .. 
} => { + assert_eq!(message, "Activity cancelled"); + }); + } } diff --git a/crates/sdk-core/tests/integ_tests/async_activity_client_tests.rs b/crates/sdk-core/tests/integ_tests/async_activity_client_tests.rs index 4548d48f3..543b735fc 100644 --- a/crates/sdk-core/tests/integ_tests/async_activity_client_tests.rs +++ b/crates/sdk-core/tests/integ_tests/async_activity_client_tests.rs @@ -135,10 +135,8 @@ async fn async_activity_completions( } Outcome::Failure => { let err = activity_result.expect_err("expected failure"); - if let ActivityExecutionError::Failed(failure) = err { - // The failure we sent is wrapped as the cause - let cause = failure.cause.expect("cause should be present"); - assert_eq!(cause.message, "async failure reason"); + if let ActivityExecutionError::Failed { message, .. } = &err { + assert_eq!(message, "async failure reason"); } else { panic!("expected Failed, got {err:?}"); } @@ -146,7 +144,7 @@ async fn async_activity_completions( Outcome::Cancellation => { let err = activity_result.expect_err("expected cancellation"); assert!( - matches!(err, ActivityExecutionError::Cancelled(_)), + matches!(err, ActivityExecutionError::Cancelled { .. 
}), "expected Cancelled, got {err:?}" ); } diff --git a/crates/sdk-core/tests/integ_tests/data_converter_tests.rs b/crates/sdk-core/tests/integ_tests/data_converter_tests.rs index 5b1a1442c..1c09ede21 100644 --- a/crates/sdk-core/tests/integ_tests/data_converter_tests.rs +++ b/crates/sdk-core/tests/integ_tests/data_converter_tests.rs @@ -1,4 +1,4 @@ -use crate::common::{CoreWfStarter, get_integ_connection, integ_namespace}; +use crate::common::{CoreWfStarter, TestWorker, get_integ_connection, integ_namespace}; use futures::{FutureExt, future::BoxFuture}; use std::{ sync::{ @@ -7,21 +7,33 @@ use std::{ }, time::Duration, }; -use temporalio_client::{Client, ClientOptions, UntypedWorkflow, WorkflowStartOptions}; +use temporalio_client::{ + Client, ClientOptions, UntypedWorkflow, WorkflowGetResultError, WorkflowStartOptions, +}; use temporalio_common::{ data_converters::{ - DataConverter, DefaultFailureConverter, MultiArgs2, PayloadCodec, PayloadConversionError, - PayloadConverter, SerializationContext, SerializationContextData, TemporalDeserializable, - TemporalSerializable, + DataConverter, DefaultFailureConverter, FailureConverter, MultiArgs2, PayloadCodec, + PayloadConversionError, PayloadConverter, RawValue, SerializationContext, + SerializationContextData, TemporalDeserializable, TemporalError, TemporalSerializable, + }, + protos::{ + DEFAULT_WORKFLOW_TYPE, canned_histories, + temporal::api::{ + common::v1::Payload, enums::v1::WorkflowTaskFailedCause, + failure::v1::failure::FailureInfo, history::v1::history_event::Attributes, + }, }, - protos::temporal::api::{common::v1::Payload, history::v1::history_event::Attributes}, worker::WorkerTaskTypes, }; use temporalio_macros::{activities, workflow, workflow_methods}; use temporalio_sdk::{ - ActivityOptions, WorkflowContext, WorkflowResult, + ActivityOptions, CancellableFuture, ChildWorkflowExecutionError, ChildWorkflowOptions, + WorkflowContext, WorkflowResult, activities::{ActivityContext, ActivityError}, }; +use 
temporalio_sdk_core::test_help::{ + MockPollCfg, build_mock_pollers, mock_worker, mock_worker_client, +}; #[derive(Clone, Debug)] struct TrackedWrapper(TrackedValue); @@ -249,6 +261,122 @@ async fn multi_args_serializes_as_multiple_payloads() { assert_eq!(second_payload_data, 42); } +#[workflow] +#[derive(Default)] +struct FailingWorkflow; +#[workflow_methods] +impl FailingWorkflow { + #[run] + async fn run(_ctx: &mut WorkflowContext, message: String) -> WorkflowResult { + Err(temporalio_sdk::WorkflowTermination::Failed( + anyhow::anyhow!("{}", message), + )) + } +} + +#[tokio::test] +async fn failing_workflow_produces_failure_with_message() { + let wf_name = FailingWorkflow::name(); + let mut starter = CoreWfStarter::new(wf_name); + starter.sdk_config.register_workflow::(); + starter.sdk_config.task_types = WorkerTaskTypes::workflow_only(); + let mut worker = starter.worker().await; + + let task_queue = starter.get_task_queue().to_owned(); + let handle = worker + .submit_workflow( + FailingWorkflow::run, + "intentional failure".to_string(), + WorkflowStartOptions::new(task_queue, wf_name.to_owned()).build(), + ) + .await + .unwrap(); + worker.run_until_done().await.unwrap(); + + let res = handle.get_result(Default::default()).await.unwrap_err(); + if let WorkflowGetResultError::Failed(ref te) = res { + let message = te.message().unwrap(); + assert!( + message.contains("intentional failure"), + "failure message should contain the workflow error, got: {message}", + ); + } else { + panic!("expected Failed, got: {res:?}"); + } +} + +#[workflow] +#[derive(Default)] +struct FailingActivityWorkflow; +#[workflow_methods] +impl FailingActivityWorkflow { + #[run] + async fn run(ctx: &mut WorkflowContext, _input: String) -> WorkflowResult { + ctx.start_activity( + FailingActivities::always_fail, + "activity input".to_string(), + ActivityOptions { + start_to_close_timeout: Some(Duration::from_secs(5)), + retry_policy: Some( + 
temporalio_common::protos::temporal::api::common::v1::RetryPolicy { + maximum_attempts: 1, + ..Default::default() + }, + ), + ..Default::default() + }, + ) + .await + .map_err(|e| { + temporalio_sdk::WorkflowTermination::Failed(anyhow::anyhow!("activity failed: {}", e)) + }) + } +} + +struct FailingActivities; +#[activities] +impl FailingActivities { + #[activity] + async fn always_fail(_ctx: ActivityContext, _input: String) -> Result { + Err(ActivityError::NonRetryable( + "activity went boom".to_string().into(), + )) + } +} + +#[tokio::test] +async fn activity_failure_propagates_through_workflow() { + let wf_name = FailingActivityWorkflow::name(); + let mut starter = CoreWfStarter::new(wf_name); + starter + .sdk_config + .register_workflow::(); + starter.sdk_config.register_activities(FailingActivities); + let mut worker = starter.worker().await; + + let task_queue = starter.get_task_queue().to_owned(); + let handle = worker + .submit_workflow( + FailingActivityWorkflow::run, + "trigger".to_string(), + WorkflowStartOptions::new(task_queue, wf_name.to_owned()).build(), + ) + .await + .unwrap(); + worker.run_until_done().await.unwrap(); + + let res = handle.get_result(Default::default()).await.unwrap_err(); + if let WorkflowGetResultError::Failed(ref te) = res { + let message = te.message().unwrap(); + assert!( + message.contains("activity failed"), + "workflow failure should mention the activity error, got: {message}", + ); + } else { + panic!("expected Failed, got: {res:?}"); + } +} + /// A codec that XORs payload data with a key and tracks encode/decode operations. struct XorCodec { key: u8, @@ -379,3 +507,1080 @@ async fn codec_encodes_and_decodes_payloads() { "Codec should have decoded payloads, but decode_count was 0" ); } + +/// Apply `f` to every failure in the cause chain. 
+fn walk_failure_chain( + failure: &mut temporalio_common::protos::temporal::api::failure::v1::Failure, + mut f: impl FnMut(&mut temporalio_common::protos::temporal::api::failure::v1::Failure), +) { + let mut curr = Some(failure); + while let Some(node) = curr { + f(node); + curr = node.cause.as_deref_mut(); + } +} + +/// Custom failure converter that wraps the default, uppercasing outgoing +/// messages and lowercasing incoming ones. +struct UpperCaseFailureConverter; + +impl temporalio_common::data_converters::FailureConverter for UpperCaseFailureConverter { + fn to_failure( + &self, + error: Box, + payload_converter: &PayloadConverter, + context: &SerializationContextData, + ) -> temporalio_common::protos::temporal::api::failure::v1::Failure { + let mut failure = DefaultFailureConverter.to_failure(error, payload_converter, context); + walk_failure_chain(&mut failure, |f| f.message = f.message.to_uppercase()); + failure + } + + fn to_error( + &self, + mut failure: temporalio_common::protos::temporal::api::failure::v1::Failure, + payload_converter: &PayloadConverter, + context: &SerializationContextData, + ) -> temporalio_common::data_converters::TemporalError { + walk_failure_chain(&mut failure, |f| f.message = f.message.to_lowercase()); + DefaultFailureConverter.to_error(failure, payload_converter, context) + } +} + +struct PrefixFailureConverter { + outgoing_prefix: &'static str, + incoming_prefix: &'static str, +} + +impl FailureConverter for PrefixFailureConverter { + fn to_failure( + &self, + error: Box, + payload_converter: &PayloadConverter, + context: &SerializationContextData, + ) -> temporalio_common::protos::temporal::api::failure::v1::Failure { + let mut failure = DefaultFailureConverter.to_failure(error, payload_converter, context); + if !self.outgoing_prefix.is_empty() { + walk_failure_chain(&mut failure, |f| { + f.message = format!("{}{}", self.outgoing_prefix, f.message); + }); + } + failure + } + + fn to_error( + &self, + mut failure: 
temporalio_common::protos::temporal::api::failure::v1::Failure, + payload_converter: &PayloadConverter, + context: &SerializationContextData, + ) -> TemporalError { + if !self.incoming_prefix.is_empty() { + walk_failure_chain(&mut failure, |f| { + f.message = format!("{}{}", self.incoming_prefix, f.message); + }); + } + DefaultFailureConverter.to_error(failure, payload_converter, context) + } +} + +struct GenericFailureConverter; + +impl FailureConverter for GenericFailureConverter { + fn to_failure( + &self, + error: Box, + payload_converter: &PayloadConverter, + context: &SerializationContextData, + ) -> temporalio_common::protos::temporal::api::failure::v1::Failure { + let mut failure = DefaultFailureConverter.to_failure(error, payload_converter, context); + failure.message = format!("generic:{}", failure.message); + failure.failure_info = None; + failure + } + + fn to_error( + &self, + failure: temporalio_common::protos::temporal::api::failure::v1::Failure, + payload_converter: &PayloadConverter, + context: &SerializationContextData, + ) -> TemporalError { + DefaultFailureConverter.to_error(failure, payload_converter, context) + } +} + +#[workflow] +#[derive(Default)] +struct CustomConverterFailWorkflow; +#[workflow_methods] +impl CustomConverterFailWorkflow { + #[run] + async fn run(_ctx: &mut WorkflowContext, message: String) -> WorkflowResult { + Err(temporalio_sdk::WorkflowTermination::Failed( + anyhow::anyhow!("{}", message), + )) + } +} + +#[tokio::test] +async fn custom_failure_converter_applied_to_workflow_failure() { + let wf_name = CustomConverterFailWorkflow::name(); + + let connection = get_integ_connection(None).await; + let data_converter = DataConverter::new( + PayloadConverter::default(), + UpperCaseFailureConverter, + temporalio_common::data_converters::DefaultPayloadCodec, + ); + let client_opts = ClientOptions::new(integ_namespace()) + .data_converter(data_converter) + .build(); + let client = Client::new(connection, client_opts).unwrap(); + 
+ let mut starter = CoreWfStarter::new_with_overrides(wf_name, None, Some(client)); + starter + .sdk_config + .register_workflow::(); + starter.sdk_config.task_types = WorkerTaskTypes::workflow_only(); + let wf_id = starter.get_task_queue().to_owned(); + let mut worker = starter.worker().await; + + let task_queue = starter.get_task_queue().to_owned(); + worker + .submit_workflow( + CustomConverterFailWorkflow::run, + "should be uppercased".to_string(), + WorkflowStartOptions::new(task_queue, wf_id.clone()).build(), + ) + .await + .unwrap(); + worker.run_until_done().await.unwrap(); + + // Fetch the raw failure from history to see what the worker actually sent. + let client = starter.get_client().await; + let events = client + .get_workflow_handle::(&wf_id) + .fetch_history(Default::default()) + .await + .unwrap() + .into_events(); + + let failed_attrs = events + .iter() + .find_map(|e| { + if let Attributes::WorkflowExecutionFailedEventAttributes(attrs) = + e.attributes.as_ref().unwrap() + { + Some(attrs) + } else { + None + } + }) + .expect("Should find WorkflowExecutionFailed event"); + + let failure = failed_attrs.failure.as_ref().expect("should have failure"); + + // The custom converter uppercases, so the server-side failure message must + // be all-uppercase. Without the converter wired up, this would contain the + // raw lowercase message. 
+ assert_eq!( + failure.message, "SHOULD BE UPPERCASED", + "Failure message on server should reflect custom converter (uppercase), got: {}", + failure.message, + ); +} + +#[workflow] +struct PanickingOnceWorkflow { + did_panic: Arc, +} + +#[workflow_methods(factory_only)] +impl PanickingOnceWorkflow { + #[run(name = DEFAULT_WORKFLOW_TYPE)] + async fn run(ctx: &mut WorkflowContext) -> WorkflowResult<()> { + ctx.timer(Duration::from_secs(1)).await; + if ctx.state(|wf| wf.did_panic.fetch_add(1, Ordering::SeqCst)) == 0 { + panic!("workflow panic marker"); + } + Ok(()) + } +} + +#[tokio::test] +async fn custom_failure_converter_applied_to_workflow_panic_failures() { + let wf_id = "workflow-panic-failure-converter"; + let history = canned_histories::workflow_fails_with_failure_after_timer("1"); + let mock_client = mock_worker_client(); + let mut mock_cfg = MockPollCfg::from_resp_batches(wf_id, history, [1, 2, 2], mock_client); + mock_cfg.using_rust_sdk = true; + mock_cfg.num_expected_fails = 1; + mock_cfg.expect_fail_wft_matcher = Box::new(|_, cause, failure| { + *cause == WorkflowTaskFailedCause::WorkflowWorkerUnhandledFailure + && failure.as_ref().is_some_and(|failure| { + failure.message + == "panic-converted:Workflow function panicked: workflow panic marker" + && matches!( + failure.failure_info, + Some(FailureInfo::ApplicationFailureInfo(_)) + ) + }) + }); + + let mut mock = build_mock_pollers(mock_cfg); + mock.worker_cfg(|cfg| { + cfg.max_cached_workflows = 1; + cfg.ignore_evicts_on_shutdown = false; + }); + let core = mock_worker(mock); + let data_converter = DataConverter::new( + PayloadConverter::default(), + PrefixFailureConverter { + outgoing_prefix: "panic-converted:", + incoming_prefix: "", + }, + temporalio_common::data_converters::DefaultPayloadCodec, + ); + let mut worker = TestWorker::new(temporalio_sdk::Worker::new_from_core( + Arc::new(core), + data_converter, + )); + + let did_panic = Arc::new(AtomicUsize::new(0)); + let did_panic_clone = 
did_panic.clone(); + worker.register_workflow_with_factory(move || PanickingOnceWorkflow { + did_panic: did_panic_clone.clone(), + }); + worker + .submit_wf( + DEFAULT_WORKFLOW_TYPE, + vec![], + WorkflowStartOptions::new("fake_tq".to_owned(), wf_id.to_owned()).build(), + ) + .await + .unwrap(); + worker.run_until_done().await.unwrap(); +} + +#[workflow] +#[derive(Default)] +struct FailWithDetailsWorkflow; +#[workflow_methods] +impl FailWithDetailsWorkflow { + #[run] + async fn run(_ctx: &mut WorkflowContext, _input: String) -> WorkflowResult { + use temporalio_common::{ + data_converters::TemporalError, protos::temporal::api::common::v1::Payloads, + }; + + // Build detail payloads with a known plaintext marker. + let detail_payload = Payload { + metadata: { + let mut hm = std::collections::HashMap::new(); + hm.insert("encoding".to_string(), b"json/plain".to_vec()); + hm + }, + data: b"\"detail-marker-plaintext\"".to_vec(), + external_payloads: vec![], + }; + + let tf = TemporalError::Application { + message: "fail with details".to_string(), + stack_trace: String::new(), + r#type: String::new(), + non_retryable: false, + details: Some(Payloads { + payloads: vec![detail_payload], + }), + next_retry_delay: None, + cause: None, + }; + + Err(temporalio_sdk::WorkflowTermination::Failed(tf.into())) + } +} + +#[tokio::test] +async fn codec_applied_to_outgoing_workflow_failure_payloads() { + let wf_name = FailWithDetailsWorkflow::name(); + let codec = Arc::new(XorCodec::new(0x42)); + + let connection = get_integ_connection(None).await; + let data_converter = DataConverter::new( + PayloadConverter::default(), + DefaultFailureConverter, + codec.clone(), + ); + let client_opts = ClientOptions::new(integ_namespace()) + .data_converter(data_converter) + .build(); + let client = Client::new(connection, client_opts).unwrap(); + + let mut starter = CoreWfStarter::new_with_overrides(wf_name, None, Some(client)); + starter + .sdk_config + .register_workflow::(); + 
starter.sdk_config.task_types = WorkerTaskTypes::workflow_only(); + let wf_id = starter.get_task_queue().to_owned(); + let mut worker = starter.worker().await; + + let task_queue = starter.get_task_queue().to_owned(); + worker + .submit_workflow( + FailWithDetailsWorkflow::run, + "trigger".to_string(), + WorkflowStartOptions::new(task_queue, wf_id.clone()).build(), + ) + .await + .unwrap(); + worker.run_until_done().await.unwrap(); + + // Fetch raw history — the server stores whatever the worker sent, so if the + // codec was applied the detail payloads will NOT contain our plaintext marker. + let client = starter.get_client().await; + let events = client + .get_workflow_handle::(&wf_id) + .fetch_history(Default::default()) + .await + .unwrap() + .into_events(); + + let failed_attrs = events + .iter() + .find_map(|e| { + if let Attributes::WorkflowExecutionFailedEventAttributes(attrs) = + e.attributes.as_ref().unwrap() + { + Some(attrs) + } else { + None + } + }) + .expect("Should find WorkflowExecutionFailed event"); + + let failure = failed_attrs.failure.as_ref().expect("should have failure"); + + // Dig into ApplicationFailureInfo to find the detail payloads. + use temporalio_common::protos::temporal::api::failure::v1::failure::FailureInfo; + let details = match failure.failure_info.as_ref() { + Some(FailureInfo::ApplicationFailureInfo(info)) => info.details.as_ref(), + other => panic!("Expected ApplicationFailureInfo, got: {:?}", other), + }; + + let detail_payloads = details.expect("should have detail payloads"); + assert!( + !detail_payloads.payloads.is_empty(), + "should have at least one detail payload" + ); + + // The plaintext marker should NOT appear in the raw data if the codec was applied. 
+ let raw_data = &detail_payloads.payloads[0].data; + let raw_str = String::from_utf8_lossy(raw_data); + assert!( + !raw_str.contains("detail-marker-plaintext"), + "Detail payload on server should be encoded by the codec, but found plaintext: {}", + raw_str, + ); +} + +// --------------------------------------------------------------------------- +// Send path: codec is applied to outgoing activity failure payloads +// --------------------------------------------------------------------------- + +struct FailWithDetailsActivities; +#[activities] +impl FailWithDetailsActivities { + #[activity] + async fn fail_with_details( + _ctx: ActivityContext, + _input: String, + ) -> Result { + use temporalio_common::{ + data_converters::TemporalError, protos::temporal::api::common::v1::Payloads, + }; + + let detail_payload = Payload { + metadata: { + let mut hm = std::collections::HashMap::new(); + hm.insert("encoding".to_string(), b"json/plain".to_vec()); + hm + }, + data: b"\"activity-detail-plaintext\"".to_vec(), + external_payloads: vec![], + }; + + let tf = TemporalError::Application { + message: "activity fail with details".to_string(), + stack_trace: String::new(), + r#type: String::new(), + non_retryable: true, + details: Some(Payloads { + payloads: vec![detail_payload], + }), + next_retry_delay: None, + cause: None, + }; + + Err(ActivityError::NonRetryable(tf.into())) + } +} + +struct PanickingActivities; +#[activities] +impl PanickingActivities { + #[activity] + async fn always_panic(_ctx: ActivityContext, _input: String) -> Result { + panic!("activity panic marker"); + } +} + +#[workflow] +#[derive(Default)] +struct ActivityFailDetailsWorkflow; +#[workflow_methods] +impl ActivityFailDetailsWorkflow { + #[run] + async fn run(ctx: &mut WorkflowContext, _input: String) -> WorkflowResult { + ctx.start_activity( + FailWithDetailsActivities::fail_with_details, + "trigger".to_string(), + ActivityOptions { + start_to_close_timeout: Some(Duration::from_secs(5)), + 
retry_policy: Some( + temporalio_common::protos::temporal::api::common::v1::RetryPolicy { + maximum_attempts: 1, + ..Default::default() + }, + ), + ..Default::default() + }, + ) + .await + .map_err(|e| { + temporalio_sdk::WorkflowTermination::Failed(anyhow::anyhow!("activity failed: {}", e)) + }) + } +} + +#[workflow] +#[derive(Default)] +struct ActivityPanicWorkflow; +#[workflow_methods] +impl ActivityPanicWorkflow { + #[run] + async fn run(ctx: &mut WorkflowContext, _input: String) -> WorkflowResult { + let err = ctx + .start_activity( + PanickingActivities::always_panic, + "trigger".to_string(), + ActivityOptions { + start_to_close_timeout: Some(Duration::from_secs(5)), + retry_policy: Some( + temporalio_common::protos::temporal::api::common::v1::RetryPolicy { + maximum_attempts: 1, + ..Default::default() + }, + ), + ..Default::default() + }, + ) + .await + .unwrap_err(); + + Ok(err.to_string()) + } +} + +#[tokio::test] +async fn codec_applied_to_outgoing_activity_failure_payloads() { + let wf_name = ActivityFailDetailsWorkflow::name(); + let codec = Arc::new(XorCodec::new(0x42)); + + let connection = get_integ_connection(None).await; + let data_converter = DataConverter::new( + PayloadConverter::default(), + DefaultFailureConverter, + codec.clone(), + ); + let client_opts = ClientOptions::new(integ_namespace()) + .data_converter(data_converter) + .build(); + let client = Client::new(connection, client_opts).unwrap(); + + let mut starter = CoreWfStarter::new_with_overrides(wf_name, None, Some(client)); + starter + .sdk_config + .register_workflow::(); + starter + .sdk_config + .register_activities(FailWithDetailsActivities); + let wf_id = starter.get_task_queue().to_owned(); + let mut worker = starter.worker().await; + + let task_queue = starter.get_task_queue().to_owned(); + worker + .submit_workflow( + ActivityFailDetailsWorkflow::run, + "trigger".to_string(), + WorkflowStartOptions::new(task_queue, wf_id.clone()).build(), + ) + .await + .unwrap(); + 
worker.run_until_done().await.unwrap(); + + // Fetch raw history to inspect the ActivityTaskFailed event. + let client = starter.get_client().await; + let events = client + .get_workflow_handle::(&wf_id) + .fetch_history(Default::default()) + .await + .unwrap() + .into_events(); + + let activity_failed_attrs = events + .iter() + .find_map(|e| { + if let Attributes::ActivityTaskFailedEventAttributes(attrs) = + e.attributes.as_ref().unwrap() + { + Some(attrs) + } else { + None + } + }) + .expect("Should find ActivityTaskFailed event"); + + let failure = activity_failed_attrs + .failure + .as_ref() + .expect("should have failure"); + + // Walk through cause chain — the activity failure wraps an application failure. + use temporalio_common::protos::temporal::api::failure::v1::failure::FailureInfo; + let app_failure = failure + .cause + .as_ref() + .map(|c| c.as_ref()) + .unwrap_or(failure); + + let details = match app_failure.failure_info.as_ref() { + Some(FailureInfo::ApplicationFailureInfo(info)) => info.details.as_ref(), + other => panic!("Expected ApplicationFailureInfo, got: {:?}", other), + }; + + let detail_payloads = details.expect("should have detail payloads"); + assert!( + !detail_payloads.payloads.is_empty(), + "should have at least one detail payload" + ); + + let raw_data = &detail_payloads.payloads[0].data; + let raw_str = String::from_utf8_lossy(raw_data); + assert!( + !raw_str.contains("activity-detail-plaintext"), + "Activity failure detail payload on server should be encoded by codec, \ + but found plaintext: {}", + raw_str, + ); +} + +#[tokio::test] +async fn custom_failure_converter_applied_to_activity_panic_failures() { + let wf_name = ActivityPanicWorkflow::name(); + + let connection = get_integ_connection(None).await; + let data_converter = DataConverter::new( + PayloadConverter::default(), + PrefixFailureConverter { + outgoing_prefix: "activity-converted:", + incoming_prefix: "", + }, + temporalio_common::data_converters::DefaultPayloadCodec, + 
); + let client_opts = ClientOptions::new(integ_namespace()) + .data_converter(data_converter) + .build(); + let client = Client::new(connection, client_opts).unwrap(); + + let mut starter = CoreWfStarter::new_with_overrides(wf_name, None, Some(client)); + starter + .sdk_config + .register_workflow::(); + starter.sdk_config.register_activities(PanickingActivities); + let wf_id = starter.get_task_queue().to_owned(); + let mut worker = starter.worker().await; + + let task_queue = starter.get_task_queue().to_owned(); + let handle = worker + .submit_workflow( + ActivityPanicWorkflow::run, + "trigger".to_string(), + WorkflowStartOptions::new(task_queue, wf_id.clone()).build(), + ) + .await + .unwrap(); + worker.run_until_done().await.unwrap(); + handle.get_result(Default::default()).await.unwrap(); + + let client = starter.get_client().await; + let events = client + .get_workflow_handle::(&wf_id) + .fetch_history(Default::default()) + .await + .unwrap() + .into_events(); + + let activity_failed_attrs = events + .iter() + .find_map(|e| { + if let Attributes::ActivityTaskFailedEventAttributes(attrs) = + e.attributes.as_ref().unwrap() + { + Some(attrs) + } else { + None + } + }) + .expect("Should find ActivityTaskFailed event"); + + let failure = activity_failed_attrs + .failure + .as_ref() + .expect("should have failure"); + let app_failure = failure.cause.as_deref().unwrap_or(failure); + assert_eq!( + app_failure.message, + "activity-converted:Activity function panicked: activity panic marker", + ); +} + +// --------------------------------------------------------------------------- +// Receive path: activity failures are converted through the failure converter +// before reaching workflow code +// --------------------------------------------------------------------------- + +/// Workflow that captures the activity error message into its result so the +/// test can inspect it from outside. 
+#[workflow] +#[derive(Default)] +struct ActivityFailConverterWorkflow; +#[workflow_methods] +impl ActivityFailConverterWorkflow { + #[run] + async fn run(ctx: &mut WorkflowContext, _input: String) -> WorkflowResult { + let err = ctx + .start_activity( + FailingActivities::always_fail, + "input".to_string(), + ActivityOptions { + start_to_close_timeout: Some(Duration::from_secs(5)), + retry_policy: Some( + temporalio_common::protos::temporal::api::common::v1::RetryPolicy { + maximum_attempts: 1, + ..Default::default() + }, + ), + ..Default::default() + }, + ) + .await + .unwrap_err(); + + Ok(err.to_string()) + } +} + +#[workflow] +#[derive(Default)] +struct NonRetryableActivityWithCustomConverterWorkflow; +#[workflow_methods] +impl NonRetryableActivityWithCustomConverterWorkflow { + #[run] + async fn run(ctx: &mut WorkflowContext, _input: String) -> WorkflowResult { + let err = ctx + .start_activity( + FailingActivities::always_fail, + "input".to_string(), + ActivityOptions { + start_to_close_timeout: Some(Duration::from_secs(5)), + retry_policy: Some( + temporalio_common::protos::temporal::api::common::v1::RetryPolicy { + initial_interval: Some(temporalio_common::prost_dur!(from_millis(10))), + backoff_coefficient: 1.0, + maximum_attempts: 5, + ..Default::default() + }, + ), + ..Default::default() + }, + ) + .await + .unwrap_err(); + + Ok(err.to_string()) + } +} + +#[workflow] +#[derive(Default)] +struct ChildStartCancelledFromHistoryWorkflow; + +#[workflow_methods] +impl ChildStartCancelledFromHistoryWorkflow { + #[run(name = DEFAULT_WORKFLOW_TYPE)] + async fn run(ctx: &mut WorkflowContext) -> WorkflowResult<()> { + let start = ctx.child_workflow( + temporalio_common::UntypedWorkflow::new("child"), + RawValue::new(vec![]), + ChildWorkflowOptions { + workflow_id: "child-id-1".to_owned(), + cancel_type: + temporalio_common::protos::coresdk::child_workflow::ChildWorkflowCancellationType::WaitCancellationCompleted, + ..Default::default() + }, + ); + + start.cancel(); 
+ let err = start + .await + .expect_err("child start should resolve as cancelled"); + + match err { + ChildWorkflowExecutionError::Cancelled { source, .. } => match source.as_ref() { + TemporalError::ChildWorkflow { + workflow_id, + cause: Some(cause), + .. + } => { + assert_eq!(workflow_id, "child-id-1"); + assert!( + cause.message().unwrap_or_default().starts_with("decoded:"), + "expected decoded cancellation cause, got {cause:?}" + ); + } + other => panic!("expected child workflow metadata, got {other:?}"), + }, + other => panic!("expected cancelled child start, got {other:?}"), + } + + Ok(()) + } +} + +#[tokio::test] +async fn activity_failure_converted_through_failure_converter() { + let wf_name = ActivityFailConverterWorkflow::name(); + + let connection = get_integ_connection(None).await; + let data_converter = DataConverter::new( + PayloadConverter::default(), + UpperCaseFailureConverter, + temporalio_common::data_converters::DefaultPayloadCodec, + ); + let client_opts = ClientOptions::new(integ_namespace()) + .data_converter(data_converter) + .build(); + let client = Client::new(connection, client_opts).unwrap(); + + let mut starter = CoreWfStarter::new_with_overrides(wf_name, None, Some(client)); + starter + .sdk_config + .register_workflow::(); + starter.sdk_config.register_activities(FailingActivities); + let wf_id = starter.get_task_queue().to_owned(); + let mut worker = starter.worker().await; + + let task_queue = starter.get_task_queue().to_owned(); + let handle = worker + .submit_workflow( + ActivityFailConverterWorkflow::run, + "trigger".to_string(), + WorkflowStartOptions::new(task_queue, wf_id).build(), + ) + .await + .unwrap(); + worker.run_until_done().await.unwrap(); + + let result = handle.get_result(Default::default()).await.unwrap(); + assert_eq!(result, "Activity failed: activity went boom"); +} + +#[tokio::test] +async fn custom_failure_converter_preserves_non_retryable_activity_errors() { + let wf_name = 
NonRetryableActivityWithCustomConverterWorkflow::name(); + + let connection = get_integ_connection(None).await; + let data_converter = DataConverter::new( + PayloadConverter::default(), + GenericFailureConverter, + temporalio_common::data_converters::DefaultPayloadCodec, + ); + let client_opts = ClientOptions::new(integ_namespace()) + .data_converter(data_converter) + .build(); + let client = Client::new(connection, client_opts).unwrap(); + + let mut starter = CoreWfStarter::new_with_overrides(wf_name, None, Some(client)); + starter + .sdk_config + .register_workflow::(); + starter.sdk_config.register_activities(FailingActivities); + let wf_id = starter.get_task_queue().to_owned(); + let mut worker = starter.worker().await; + + let task_queue = starter.get_task_queue().to_owned(); + let handle = worker + .submit_workflow( + NonRetryableActivityWithCustomConverterWorkflow::run, + "trigger".to_string(), + WorkflowStartOptions::new(task_queue, wf_id.clone()).build(), + ) + .await + .unwrap(); + worker.run_until_done().await.unwrap(); + + let result = handle.get_result(Default::default()).await.unwrap(); + assert!( + result.contains("activity went boom"), + "workflow should still observe the activity failure, got: {result}", + ); + + let client = starter.get_client().await; + let events = client + .get_workflow_handle::(&wf_id) + .fetch_history(Default::default()) + .await + .unwrap() + .into_events(); + + let activity_started_count = events + .iter() + .filter(|e| { + matches!( + e.attributes.as_ref(), + Some(Attributes::ActivityTaskStartedEventAttributes(_)) + ) + }) + .count(); + assert_eq!( + activity_started_count, 1, + "ActivityError::NonRetryable must suppress retries even when a custom failure converter \ + returns a non-application failure", + ); + + let activity_failed_attrs = events + .iter() + .find_map(|e| { + if let Attributes::ActivityTaskFailedEventAttributes(attrs) = + e.attributes.as_ref().unwrap() + { + Some(attrs) + } else { + None + } + }) + 
.expect("Should find ActivityTaskFailed event"); + let failure = activity_failed_attrs + .failure + .as_ref() + .expect("should have failure"); + let app_info = std::iter::successors(Some(failure), |f| f.cause.as_deref()) + .find_map(|f| match f.failure_info.as_ref() { + Some(FailureInfo::ApplicationFailureInfo(info)) => Some(info), + _ => None, + }) + .expect("Should find ApplicationFailureInfo in failure chain"); + assert!( + app_info.non_retryable, + "ActivityError::NonRetryable must leave a non-retryable marker on the failure chain", + ); +} + +#[tokio::test] +async fn child_start_cancellation_converted_through_failure_converter() { + let mut history = temporalio_common::protos::TestHistoryBuilder::default(); + history.add_by_type( + temporalio_common::protos::temporal::api::enums::v1::EventType::WorkflowExecutionStarted, + ); + history.add_full_wf_task(); + history.add_workflow_execution_completed(); + + let mut mock_cfg = MockPollCfg::from_hist_builder(history); + mock_cfg.completion_asserts_from_expectations(|mut asserts| { + asserts.then(|wft| { + assert_eq!(wft.commands.len(), 1); + assert_matches!( + wft.commands[0].command_type(), + temporalio_common::protos::temporal::api::enums::v1::CommandType::CompleteWorkflowExecution + ); + }); + }); + mock_cfg.using_rust_sdk = true; + + let mut mock = build_mock_pollers(mock_cfg); + mock.worker_cfg(|cfg| { + cfg.max_cached_workflows = 1; + cfg.ignore_evicts_on_shutdown = false; + }); + let core = mock_worker(mock); + let data_converter = DataConverter::new( + PayloadConverter::default(), + PrefixFailureConverter { + outgoing_prefix: "", + incoming_prefix: "decoded:", + }, + temporalio_common::data_converters::DefaultPayloadCodec, + ); + let mut worker = TestWorker::new(temporalio_sdk::Worker::new_from_core( + Arc::new(core), + data_converter, + )); + worker.register_workflow::(); + worker.run_until_done().await.unwrap(); +} + +// --------------------------------------------------------------------------- +// 
Receive path: codec decodes activity failure payloads before the workflow +// sees them +// --------------------------------------------------------------------------- + +/// Activity that fails with detail payloads (used by the codec receive test). +struct CodecFailActivities; +#[activities] +impl CodecFailActivities { + #[activity] + async fn fail_with_encoded_details( + _ctx: ActivityContext, + _input: String, + ) -> Result { + use temporalio_common::{ + data_converters::TemporalError, protos::temporal::api::common::v1::Payloads, + }; + + let detail_payload = Payload { + metadata: { + let mut hm = std::collections::HashMap::new(); + hm.insert("encoding".to_string(), b"json/plain".to_vec()); + hm + }, + data: b"\"readable-detail\"".to_vec(), + external_payloads: vec![], + }; + + let tf = TemporalError::Application { + message: "activity with details".to_string(), + stack_trace: String::new(), + r#type: String::new(), + non_retryable: true, + details: Some(Payloads { + payloads: vec![detail_payload], + }), + next_retry_delay: None, + cause: None, + }; + + Err(ActivityError::NonRetryable(tf.into())) + } +} + +/// Workflow that catches the activity error and extracts the detail payload +/// to return as its result, so the test can verify the payload was decoded. 
+#[workflow] +#[derive(Default)] +struct ActivityCodecDecodeWorkflow; +#[workflow_methods] +impl ActivityCodecDecodeWorkflow { + #[run] + async fn run(ctx: &mut WorkflowContext, _input: String) -> WorkflowResult { + use temporalio_common::data_converters::TemporalError; + + let err = ctx + .start_activity( + CodecFailActivities::fail_with_encoded_details, + "input".to_string(), + ActivityOptions { + start_to_close_timeout: Some(Duration::from_secs(5)), + retry_policy: Some( + temporalio_common::protos::temporal::api::common::v1::RetryPolicy { + maximum_attempts: 1, + ..Default::default() + }, + ), + ..Default::default() + }, + ) + .await + .unwrap_err(); + + // Try to extract the detail payload from the error. If the failure + // converter + codec are wired up on the receive path, the error will be + // a TemporalError with decoded detail payloads. + let err_str = format!("{}", err); + if let temporalio_sdk::ActivityExecutionError::Failed { source: e, .. } = err + && let TemporalError::Activity { + cause: Some(tf), .. + } = e.as_ref() + && let TemporalError::Application { + details: Some(payloads), + .. 
+ } = tf.as_ref() + && let Some(p) = payloads.payloads.first() + { + let detail_str = String::from_utf8_lossy(&p.data); + return Ok(format!("detail:{}", detail_str)); + } + + Ok(format!("raw_error:{}", err_str)) + } +} + +#[tokio::test] +async fn codec_decodes_activity_failure_payloads_on_receive() { + let wf_name = ActivityCodecDecodeWorkflow::name(); + let codec = Arc::new(XorCodec::new(0x42)); + + let connection = get_integ_connection(None).await; + let data_converter = DataConverter::new( + PayloadConverter::default(), + DefaultFailureConverter, + codec.clone(), + ); + let client_opts = ClientOptions::new(integ_namespace()) + .data_converter(data_converter) + .build(); + let client = Client::new(connection, client_opts).unwrap(); + + let mut starter = CoreWfStarter::new_with_overrides(wf_name, None, Some(client)); + starter + .sdk_config + .register_workflow::(); + starter.sdk_config.register_activities(CodecFailActivities); + let wf_id = starter.get_task_queue().to_owned(); + let mut worker = starter.worker().await; + + let task_queue = starter.get_task_queue().to_owned(); + let handle = worker + .submit_workflow( + ActivityCodecDecodeWorkflow::run, + "trigger".to_string(), + WorkflowStartOptions::new(task_queue, wf_id).build(), + ) + .await + .unwrap(); + worker.run_until_done().await.unwrap(); + + let result = handle.get_result(Default::default()).await.unwrap(); + + // If the codec + failure converter are wired up on the receive path, the + // workflow should see a TemporalError with the decoded detail payload + // containing the original "readable-detail" string. 
+ assert!( + result.starts_with("detail:"), + "Workflow should receive a TemporalError with decoded detail payloads, got: {}", + result, + ); + assert!( + result.contains("readable-detail"), + "Detail payload should be decoded (readable), got: {}", + result, + ); +} diff --git a/crates/sdk-core/tests/integ_tests/heartbeat_tests.rs b/crates/sdk-core/tests/integ_tests/heartbeat_tests.rs index 80551b3ec..cc23aca5b 100644 --- a/crates/sdk-core/tests/integ_tests/heartbeat_tests.rs +++ b/crates/sdk-core/tests/integ_tests/heartbeat_tests.rs @@ -54,10 +54,10 @@ impl ActivityDoesntHeartbeatHitsTimeoutThenCompletesWf { ) .await; let err = res.unwrap_err(); - if let ActivityExecutionError::Failed(f) = &err { - assert_eq!(f.is_timeout(), Some(TimeoutType::Heartbeat)); + if let ActivityExecutionError::Timeout { timeout_type, .. } = &err { + assert_eq!(*timeout_type, TimeoutType::Heartbeat); } else { - panic!("expected Failed, got {err:?}"); + panic!("expected TimedOut, got {err:?}"); } Ok(()) } diff --git a/crates/sdk-core/tests/integ_tests/workflow_tests/activities.rs b/crates/sdk-core/tests/integ_tests/workflow_tests/activities.rs index 41c2b92a3..3bd33a321 100644 --- a/crates/sdk-core/tests/integ_tests/workflow_tests/activities.rs +++ b/crates/sdk-core/tests/integ_tests/workflow_tests/activities.rs @@ -45,7 +45,8 @@ use temporalio_common::{ }; use temporalio_macros::{activities, workflow, workflow_methods}; use temporalio_sdk::{ - ActivityOptions, CancellableFuture, WorkflowContext, WorkflowResult, WorkflowTermination, + ActivityExecutionError, ActivityOptions, CancellableFuture, ContinueAsNewOptions, + WorkflowContext, WorkflowResult, WorkflowTermination, activities::{ActivityContext, ActivityError}, }; use temporalio_sdk_core::{ @@ -218,71 +219,111 @@ async fn activity_workflow() { core.complete_execution(&task.run_id).await; } +struct NonRetryableActivities; +#[activities] +impl NonRetryableActivities { + #[activity] + async fn non_retryable_fail( + _ctx: ActivityContext, 
+ _input: String, + ) -> Result { + Err(ActivityError::NonRetryable( + "activity went boom".to_string().into(), + )) + } +} + +#[workflow] +#[derive(Default)] +struct NonRetryableActivityWorkflow; +#[workflow_methods] +impl NonRetryableActivityWorkflow { + #[run] + async fn run(ctx: &mut WorkflowContext, _input: String) -> WorkflowResult { + ctx.start_activity( + NonRetryableActivities::non_retryable_fail, + "trigger".to_string(), + ActivityOptions { + start_to_close_timeout: Some(Duration::from_secs(5)), + retry_policy: Some(RetryPolicy { + maximum_attempts: 1, + ..Default::default() + }), + ..Default::default() + }, + ) + .await + .map_err(|e| WorkflowTermination::Failed(anyhow::anyhow!("activity failed: {}", e))) + } +} + #[tokio::test] async fn activity_non_retryable_failure() { - let mut starter = init_core_and_create_wf("activity_non_retryable_failure").await; - let core = starter.get_worker().await; - let task_q = starter.get_task_queue(); - let activity_id = "act-1"; - let task = core.poll_workflow_activation().await.unwrap(); - // Complete workflow task and schedule activity - core.complete_workflow_activation( - schedule_activity_cmd( - 0, - task_q, - activity_id, - ActivityCancellationType::TryCancel, - Duration::from_secs(60), - Duration::from_secs(60), + let wf_name = NonRetryableActivityWorkflow::name(); + let mut starter = CoreWfStarter::new(wf_name); + starter + .sdk_config + .register_workflow::(); + starter + .sdk_config + .register_activities(NonRetryableActivities); + let mut worker = starter.worker().await; + + let task_queue = starter.get_task_queue().to_owned(); + let wf_id = format!("{wf_name}-{}", uuid::Uuid::new_v4()); + let handle = worker + .submit_workflow( + NonRetryableActivityWorkflow::run, + "trigger".to_string(), + WorkflowStartOptions::new(task_queue, wf_id.clone()).build(), ) - .into_completion(task.run_id), - ) - .await - .unwrap(); - // Poll activity and verify that it's been scheduled - let task = 
core.poll_activity_task().await.unwrap(); - assert_matches!(task.variant, Some(act_task::Variant::Start(_))); - // Fail activity with non-retryable error - let failure = Failure::application_failure("activity failed".to_string(), true); - core.complete_activity_task(ActivityTaskCompletion { - task_token: task.task_token, - result: Some(ActivityExecutionResult::fail(failure.clone())), - }) - .await - .unwrap(); - // Poll workflow task and verify that activity has failed. - let task = core.poll_workflow_activation().await.unwrap(); - assert_matches!( - task.jobs.as_slice(), - [ - WorkflowActivationJob { - variant: Some(workflow_activation_job::Variant::ResolveActivity( - ResolveActivity {seq, result: Some(ActivityResolution{ - status: Some(act_res::Status::Failed(activity_result::Failure{ - failure: Some(f), - }))}),..} - )), - }, - ] => { - assert_eq!(*seq, 0); - assert_eq!(f, &Failure{ - message: "Activity task failed".to_owned(), - cause: Some(Box::new(failure)), - failure_info: Some(FailureInfo::ActivityFailureInfo(ActivityFailureInfo{ - activity_id: "act-1".to_owned(), - activity_type: Some(ActivityType { - name: DEFAULT_ACTIVITY_TYPE.to_owned(), - }), - scheduled_event_id: 5, - started_event_id: 6, - identity: INTEG_CLIENT_IDENTITY.to_owned(), - retry_state: RetryState::NonRetryableFailure as i32, - })), - ..Default::default() - }); - } + .await + .unwrap(); + worker.run_until_done().await.unwrap(); + let _result = handle.get_result(Default::default()).await; + + // Fetch history and find the ActivityTaskFailed event to inspect the failure + // the worker actually sent. 
+ let client = starter.get_client().await; + let events = client + .get_workflow_handle::(&wf_id) + .fetch_history(Default::default()) + .await + .unwrap() + .into_events(); + + use temporalio_common::protos::temporal::api::history::v1::history_event::Attributes; + let activity_failed = events + .iter() + .find_map(|e| { + if let Attributes::ActivityTaskFailedEventAttributes(attrs) = + e.attributes.as_ref().unwrap() + { + Some(attrs) + } else { + None + } + }) + .expect("Should find ActivityTaskFailed event"); + + let failure = activity_failed + .failure + .as_ref() + .expect("should have failure"); + + // The server wraps the activity error in an ApplicationFailureInfo directly + // on the failure proto. Walk through the failure and its cause to find it. + let app_info = std::iter::successors(Some(failure), |f| f.cause.as_deref()) + .find_map(|f| match f.failure_info.as_ref() { + Some(FailureInfo::ApplicationFailureInfo(info)) => Some(info), + _ => None, + }) + .expect("Should find ApplicationFailureInfo in failure chain"); + + assert!( + app_info.non_retryable, + "ActivityError::NonRetryable must set non_retryable on the failure proto" ); - core.complete_execution(&task.run_id).await; } #[tokio::test] @@ -1197,7 +1238,11 @@ async fn activity_can_be_cancelled_by_local_timeout() { }, ) .await; - assert!(res.is_err_and(|e| e.is_timeout())); + let err = res.unwrap_err(); + assert!( + matches!(err, ActivityExecutionError::Timeout { .. 
}), + "Expected timeout got {err:?}" + ); Ok(()) } } @@ -1267,7 +1312,7 @@ async fn long_activity_timeout_repro() { ctx.timer(Duration::from_secs(60 * 3)).await; iter += 1; if iter > 5000 { - return Err(WorkflowTermination::continue_as_new(Default::default())); + ctx.continue_as_new(&(), ContinueAsNewOptions::default())?; } } } diff --git a/crates/sdk-core/tests/integ_tests/workflow_tests/child_workflows.rs b/crates/sdk-core/tests/integ_tests/workflow_tests/child_workflows.rs index b167d8e3e..ee34d0a3d 100644 --- a/crates/sdk-core/tests/integ_tests/workflow_tests/child_workflows.rs +++ b/crates/sdk-core/tests/integ_tests/workflow_tests/child_workflows.rs @@ -478,7 +478,7 @@ impl ParentCancelsChildWf { .result() .await .expect_err("child should be cancelled"); - assert_matches!(err, ChildWorkflowExecutionError::Cancelled(_)); + assert_matches!(err, ChildWorkflowExecutionError::Cancelled { .. }); Ok(()) } } @@ -739,7 +739,7 @@ impl ParentWf { let started = start_res.map_err(|e| anyhow!(e))?; match (expectation, started.result().await) { (Expectation::Success, Ok(_)) => Ok(()), - (Expectation::Failure, Err(ChildWorkflowExecutionError::Failed(_))) => Ok(()), + (Expectation::Failure, Err(ChildWorkflowExecutionError::Failed { .. })) => Ok(()), _ => Err(anyhow!("Unexpected child WF status").into()), } } @@ -842,8 +842,8 @@ impl CancelBeforeSendWf { ); start.cancel(); match start.await { - Err(ChildWorkflowExecutionError::Cancelled(_)) => Ok(()), - _ => Err(anyhow!("Unexpected start status").into()), + Err(ChildWorkflowExecutionError::Cancelled { .. 
}) => Ok(()), + other => Err(anyhow!("Unexpected start status: {other:?}").into()), } } } @@ -1406,3 +1406,135 @@ async fn cancel_child_after_cancel_external_uses_correct_seq() { worker.run_until_done().await.unwrap(); } + +#[workflow] +#[derive(Default)] +struct FailingChildWf; +#[workflow_methods] +impl FailingChildWf { + #[run] + async fn run(_ctx: &mut WorkflowContext) -> WorkflowResult { + Err(WorkflowTermination::Failed(anyhow::anyhow!( + "child went boom" + ))) + } +} + +/// Parent that starts a failing child and asserts `result()` yields +/// `ChildWorkflowExecutionError::Failed`. +#[workflow] +#[derive(Default)] +struct ParentOfFailingChild; +#[workflow_methods] +impl ParentOfFailingChild { + #[run] + async fn run(ctx: &mut WorkflowContext) -> WorkflowResult<()> { + let started = ctx + .child_workflow( + FailingChildWf::run, + (), + ChildWorkflowOptions { + workflow_id: format!("{}-child", ctx.task_queue()), + ..Default::default() + }, + ) + .await + .map_err(|e| anyhow!(e))?; + let err = started.result().await.unwrap_err(); + assert_matches!( + err, + ChildWorkflowExecutionError::Failed { ref message, .. } + if message.contains("child went boom"), + "expected Failed with child's message, got: {err:?}" + ); + Ok(()) + } +} + +#[tokio::test] +async fn child_workflow_failure_produces_failed_error() { + let wf_name = "child-wf-failure-error"; + let mut starter = CoreWfStarter::new(wf_name); + starter.sdk_config.task_types = WorkerTaskTypes::workflow_only(); + let mut worker = starter.worker().await; + + worker.register_workflow::(); + worker.register_workflow::(); + + let task_queue = starter.get_task_queue().to_owned(); + worker + .submit_workflow( + ParentOfFailingChild::run, + (), + WorkflowStartOptions::new(task_queue, wf_name.to_owned()).build(), + ) + .await + .unwrap(); + worker.run_until_done().await.unwrap(); +} + +/// Child workflow that blocks until cancelled. 
+#[workflow] +#[derive(Default)] +struct CancellableChildWf; +#[workflow_methods] +impl CancellableChildWf { + #[run] + async fn run(ctx: &mut WorkflowContext) -> WorkflowResult<()> { + ctx.wait_condition(|_| false).await; + Ok(()) + } +} + +/// Parent that starts a child, cancels it, and asserts `result()` yields +/// `ChildWorkflowExecutionError::Cancelled`. +#[workflow] +#[derive(Default)] +struct ParentCancelsChild; +#[workflow_methods] +impl ParentCancelsChild { + #[run] + async fn run(ctx: &mut WorkflowContext) -> WorkflowResult<()> { + let started = ctx + .child_workflow( + CancellableChildWf::run, + (), + ChildWorkflowOptions { + workflow_id: format!("{}-child", ctx.task_queue()), + ..Default::default() + }, + ) + .await + .map_err(|e| anyhow!(e))?; + started.cancel("test cancel".to_string()); + let err = started.result().await.unwrap_err(); + assert_matches!( + err, + ChildWorkflowExecutionError::Cancelled { .. }, + "expected Cancelled, got: {err:?}" + ); + Ok(()) + } +} + +#[tokio::test] +async fn child_workflow_cancel_produces_cancelled_error() { + let wf_name = "child-wf-cancel-error"; + let mut starter = CoreWfStarter::new(wf_name); + starter.sdk_config.task_types = WorkerTaskTypes::workflow_only(); + let mut worker = starter.worker().await; + + worker.register_workflow::(); + worker.register_workflow::(); + + let task_queue = starter.get_task_queue().to_owned(); + worker + .submit_workflow( + ParentCancelsChild::run, + (), + WorkflowStartOptions::new(task_queue, wf_name.to_owned()).build(), + ) + .await + .unwrap(); + worker.run_until_done().await.unwrap(); +} diff --git a/crates/sdk-core/tests/integ_tests/workflow_tests/continue_as_new.rs b/crates/sdk-core/tests/integ_tests/workflow_tests/continue_as_new.rs index 9b2566b86..aa0a5dd27 100644 --- a/crates/sdk-core/tests/integ_tests/workflow_tests/continue_as_new.rs +++ b/crates/sdk-core/tests/integ_tests/workflow_tests/continue_as_new.rs @@ -4,7 +4,7 @@ use temporalio_client::WorkflowStartOptions; use 
temporalio_common::{ protos::{ DEFAULT_WORKFLOW_TYPE, canned_histories, - coresdk::{AsJsonPayloadExt, workflow_commands::ContinueAsNewWorkflowExecution}, + coresdk::workflow_commands::ContinueAsNewWorkflowExecution, temporal::api::{ command::v1::command::Attributes, enums::v1::{CommandType, ContinueAsNewVersioningBehavior}, @@ -14,7 +14,7 @@ use temporalio_common::{ worker::WorkerTaskTypes, }; use temporalio_macros::{workflow, workflow_methods}; -use temporalio_sdk::{WorkflowContext, WorkflowResult, WorkflowTermination}; +use temporalio_sdk::{ContinueAsNewOptions, WorkflowContext, WorkflowResult, WorkflowTermination}; use temporalio_sdk_core::{TunerHolder, test_help::MockPollCfg}; #[workflow] @@ -27,15 +27,9 @@ impl ContinueAsNewWf { async fn run(ctx: &mut WorkflowContext, run_ct: u8) -> WorkflowResult<()> { ctx.timer(Duration::from_millis(500)).await; if run_ct < 5 { - Err(WorkflowTermination::continue_as_new( - ContinueAsNewWorkflowExecution { - arguments: vec![(run_ct + 1).as_json_payload().unwrap()], - ..Default::default() - }, - )) - } else { - Ok(()) + ctx.continue_as_new(&(run_ct + 1), ContinueAsNewOptions::default())?; } + Ok(()) } } @@ -145,12 +139,8 @@ impl ContinueAsNewSuggestedWf { ctx.timer(Duration::from_millis(500)).await; // Second WFT: flag should be true (set on WFT started event 8) assert!(ctx.continue_as_new_suggested()); - Err(WorkflowTermination::continue_as_new( - ContinueAsNewWorkflowExecution { - arguments: vec![[1].into()], - ..Default::default() - }, - )) + ctx.continue_as_new(&(), ContinueAsNewOptions::default())?; + Ok(()) } } diff --git a/crates/sdk-core/tests/integ_tests/workflow_tests/local_activities.rs b/crates/sdk-core/tests/integ_tests/workflow_tests/local_activities.rs index 2dff951a0..601340255 100644 --- a/crates/sdk-core/tests/integ_tests/workflow_tests/local_activities.rs +++ b/crates/sdk-core/tests/integ_tests/workflow_tests/local_activities.rs @@ -39,7 +39,7 @@ use temporalio_common::{ 
command::v1::{RecordMarkerCommandAttributes, command}, common::v1::RetryPolicy, enums::v1::{CommandType, EventType, TimeoutType, WorkflowTaskFailedCause}, - failure::v1::{Failure, failure::FailureInfo}, + failure::v1::Failure, history::v1::history_event::Attributes::MarkerRecordedEventAttributes, query::v1::WorkflowQuery, }, @@ -296,7 +296,7 @@ impl LocalActRetryTimerBackoff { }, ) .await; - assert!(matches!(res, Err(ActivityExecutionError::Failed(_)))); + assert!(matches!(res, Err(ActivityExecutionError::Failed { .. }))); Ok(()) } } @@ -391,10 +391,11 @@ async fn cancel_immediate(#[case] cancel_type: ActivityCancellationType) { ); la.cancel(); let resolution = la.await; - assert!(matches!( - resolution, - Err(ActivityExecutionError::Cancelled(_)) - )); + assert!( + matches!(resolution, Err(ActivityExecutionError::Cancelled { .. })), + "got: {:?}", + resolution + ); Ok(()) } } @@ -546,10 +547,11 @@ async fn cancel_after_act_starts( // resolving the LA with cancel on replay ctx.timer(Duration::from_secs(1)).await; let resolution = la.await; - assert!(matches!( - resolution, - Err(ActivityExecutionError::Cancelled(_)) - )); + assert!( + matches!(resolution, Err(ActivityExecutionError::Cancelled { .. })), + "got: {:?}", + resolution + ); Ok(()) } } @@ -639,13 +641,11 @@ async fn x_to_close_timeout(#[case] is_schedule: bool) { ) .await; let err = res.unwrap_err(); - if let ActivityExecutionError::Failed(f) = &err { - assert_eq!( - f.is_timeout(), - Some(TimeoutType::try_from(timeout_type).unwrap()) - ); + let expected_tt = TimeoutType::try_from(timeout_type).unwrap(); + if let ActivityExecutionError::Timeout { timeout_type, .. 
} = &err { + assert_eq!(*timeout_type, expected_tt); } else { - return Err(anyhow!("expected Failed, got {err:?}").into()); + return Err(anyhow!("expected timeout, got {err:?}").into()); } Ok(()) } @@ -716,12 +716,16 @@ async fn schedule_to_close_timeout_across_timer_backoff(#[case] cached: bool) { }, ) .await; - let err = res.unwrap_err(); - if let ActivityExecutionError::Failed(f) = &err { - assert_eq!(f.is_timeout(), Some(TimeoutType::ScheduleToClose)); - } else { - panic!("expected Failed, got {err:?}"); - } + assert!( + matches!( + res, + Err(ActivityExecutionError::Timeout { + timeout_type: TimeoutType::ScheduleToClose, + .. + }) + ), + "expected timeout error, got {res:?}" + ); Ok(()) } } @@ -832,8 +836,8 @@ async fn timer_backoff_concurrent_with_non_timer_backoff() { }, ); let (r1, r2) = temporalio_sdk::workflows::join!(r1, r2); - assert!(matches!(r1, Err(ActivityExecutionError::Failed(_)))); - assert!(matches!(r2, Err(ActivityExecutionError::Failed(_)))); + assert!(matches!(r1, Err(ActivityExecutionError::Failed { .. }))); + assert!(matches!(r2, Err(ActivityExecutionError::Failed { .. }))); Ok(()) } } @@ -1525,7 +1529,7 @@ async fn local_act_fail_and_retry(#[case] eventually_pass: bool) { if eventually_pass { assert!(la_res.is_ok()) } else { - assert!(matches!(la_res, Err(ActivityExecutionError::Failed(_)))) + assert!(matches!(la_res, Err(ActivityExecutionError::Failed { .. }))) } Ok(()) } @@ -1621,7 +1625,7 @@ async fn local_act_retry_long_backoff_uses_timer() { }, ) .await; - assert!(matches!(la_res, Err(ActivityExecutionError::Failed(_)))); + assert!(matches!(la_res, Err(ActivityExecutionError::Failed { .. 
}))); ctx.timer(Duration::from_secs(1)).await; Ok(()) } @@ -1988,14 +1992,11 @@ async fn test_schedule_to_start_timeout() { }, ) .await; - assert!(la_res.is_err()); - if let Err(ActivityExecutionError::Failed(ref fail)) = la_res { - assert_eq!(fail.is_timeout(), Some(TimeoutType::ScheduleToStart)); - assert_matches!(fail.failure_info, Some(FailureInfo::ActivityFailureInfo(_))); - assert_matches!( - fail.cause.as_ref().unwrap().failure_info, - Some(FailureInfo::TimeoutFailureInfo(_)) - ); + let err = la_res.unwrap_err(); + if let ActivityExecutionError::Timeout { timeout_type, .. } = &err { + assert_eq!(*timeout_type, TimeoutType::ScheduleToStart); + } else { + panic!("expected TimedOut, got {err:?}"); } Ok(()) } @@ -2087,8 +2088,10 @@ async fn test_schedule_to_start_timeout_not_based_on_original_time( .await; if is_sched_to_start { assert!(la_res.is_ok()); - } else if let Err(ActivityExecutionError::Failed(ref fail)) = la_res { - assert_eq!(fail.is_timeout(), Some(TimeoutType::ScheduleToClose)); + } else if let Err(ActivityExecutionError::Timeout { timeout_type, .. }) = &la_res { + assert_eq!(*timeout_type, TimeoutType::ScheduleToClose); + } else { + panic!("expected TimedOut, got {la_res:?}"); } Ok(()) } @@ -2157,8 +2160,10 @@ async fn start_to_close_timeout_allows_retries(#[values(true, false)] la_complet .await; if la_completes { assert!(la_res.is_ok(), "Result should be ok was {la_res:?}"); - } else if let Err(ActivityExecutionError::Failed(ref fail)) = la_res { - assert_eq!(fail.is_timeout(), Some(TimeoutType::StartToClose)); + } else if let Err(ActivityExecutionError::Timeout { timeout_type, .. 
}) = &la_res { + assert_eq!(*timeout_type, TimeoutType::StartToClose); + } else { + panic!("expected TimedOut, got {la_res:?}"); } Ok(()) } @@ -3275,20 +3280,11 @@ async fn cancel_after_act_starts_canned( la.cancel(); ctx.timer(Duration::from_secs(1)).await; let resolution = la.await; - assert!(matches!( - resolution, - Err(ActivityExecutionError::Cancelled(_)) - )); - if let Err(ActivityExecutionError::Cancelled(rfail)) = resolution { - assert_matches!( - rfail.failure_info, - Some(FailureInfo::ActivityFailureInfo(_)) - ); - assert_matches!( - rfail.cause.unwrap().failure_info, - Some(FailureInfo::CanceledFailureInfo(_)) - ); - } + let err = resolution.unwrap_err(); + assert!( + matches!(err, ActivityExecutionError::Cancelled { .. }), + "expected Cancelled, got: {err}" + ); Ok(()) } } diff --git a/crates/sdk-core/tests/integ_tests/workflow_tests/nexus.rs b/crates/sdk-core/tests/integ_tests/workflow_tests/nexus.rs index 8a2319e51..8ad0a11e6 100644 --- a/crates/sdk-core/tests/integ_tests/workflow_tests/nexus.rs +++ b/crates/sdk-core/tests/integ_tests/workflow_tests/nexus.rs @@ -498,7 +498,7 @@ async fn nexus_async( Some(nexus_operation_result::Status::Failed(f)) => f ); assert_eq!(f.message, "nexus operation completed unsuccessfully"); - assert_eq!(f.cause.unwrap().message, "Workflow execution error: broken"); + assert_eq!(f.cause.unwrap().message, "broken"); } Outcome::Cancel | Outcome::CancelAfterRecordedBeforeStarted => { let f = assert_matches!( diff --git a/crates/sdk-core/tests/integ_tests/workflow_tests/upsert_search_attrs.rs b/crates/sdk-core/tests/integ_tests/workflow_tests/upsert_search_attrs.rs index 2eda35ce3..076738f90 100644 --- a/crates/sdk-core/tests/integ_tests/workflow_tests/upsert_search_attrs.rs +++ b/crates/sdk-core/tests/integ_tests/workflow_tests/upsert_search_attrs.rs @@ -17,7 +17,7 @@ use temporalio_common::{ worker::WorkerTaskTypes, }; use temporalio_macros::{workflow, workflow_methods}; -use temporalio_sdk::{WorkflowContext, 
WorkflowResult, WorkflowTermination}; +use temporalio_sdk::{ContinueAsNewOptions, WorkflowContext, WorkflowResult}; use temporalio_sdk_core::test_help::MockPollCfg; use uuid::Uuid; @@ -42,10 +42,9 @@ impl SearchAttrUpdater { (SEARCH_ATTR_INT.to_string(), int_val), ]); if orig_val == 49 { - Err(WorkflowTermination::continue_as_new(Default::default())) - } else { - Ok(()) + ctx.continue_as_new(&(), ContinueAsNewOptions::default())?; } + Ok(()) } } diff --git a/crates/sdk/src/activities.rs b/crates/sdk/src/activities.rs index 6a69c53fe..e74f06998 100644 --- a/crates/sdk/src/activities.rs +++ b/crates/sdk/src/activities.rs @@ -251,7 +251,7 @@ pub enum ActivityError { } impl ActivityError { - /// Construct a cancelled error without details + /// Construct a cancelled error without details. pub fn cancelled() -> Self { Self::Cancelled { details: None } } diff --git a/crates/sdk/src/lib.rs b/crates/sdk/src/lib.rs index d8561a53b..e7af798b4 100644 --- a/crates/sdk/src/lib.rs +++ b/crates/sdk/src/lib.rs @@ -97,12 +97,14 @@ macro_rules! 
__temporal_join { use workflow_future::WorkflowFunction; pub use temporalio_client::Namespace; +pub use temporalio_common::protos::temporal::api::enums::v1::{RetryState, TimeoutType}; pub use workflow_context::{ ActivityExecutionError, ActivityOptions, BaseWorkflowContext, CancellableFuture, ChildWorkflowExecutionError, ChildWorkflowOptions, ChildWorkflowSignalError, - ExternalWorkflowHandle, LocalActivityOptions, NexusOperationOptions, ParentWorkflowInfo, - RootWorkflowInfo, Signal, SignalData, StartChildWorkflowExecutionFailedCause, - StartedChildWorkflow, SyncWorkflowContext, TimerOptions, WorkflowContext, WorkflowContextView, + ContinueAsNewOptions, ExternalWorkflowHandle, LocalActivityOptions, NexusOperationOptions, + ParentWorkflowInfo, RootWorkflowInfo, Signal, SignalData, + StartChildWorkflowExecutionFailedCause, StartedChildWorkflow, SyncWorkflowContext, + TimerOptions, WorkflowContext, WorkflowContextView, }; use crate::{ @@ -133,7 +135,7 @@ use std::{ use temporalio_client::{Client, NamespacedClient}; use temporalio_common::{ ActivityDefinition, WorkflowDefinition, - data_converters::{DataConverter, SerializationContextData}, + data_converters::{DataConverter, SerializationContextData, TemporalError}, payload_visitor::{decode_payloads, encode_payloads}, protos::{ TaskToken, @@ -156,7 +158,7 @@ use temporalio_common::{ temporal::api::{ common::v1::Payload, enums::v1::WorkflowTaskFailedCause, - failure::v1::{Failure, failure}, + failure::v1::{ApplicationFailureInfo, Failure, failure}, }, }, worker::{WorkerDeploymentOptions, WorkerTaskTypes, build_id_from_current_exe}, @@ -803,7 +805,6 @@ impl WorkflowHalf { _ => None, }) { let workflow_type = sw.workflow_type.clone(); - let payload_converter = common.data_converter.payload_converter().clone(); let (wff, activations) = { if let Some(factory) = self.workflow_definitions.get_workflow(&workflow_type) { match WorkflowFunction::from_invocation(factory).start_workflow( @@ -812,7 +813,7 @@ impl WorkflowHalf { 
run_id.clone(), std::mem::take(sw), completions_tx.clone(), - payload_converter, + common.data_converter.clone(), self.detect_nondeterministic_futures, ) { Ok(result) => result, @@ -953,9 +954,20 @@ impl ActivityHalf { .instrument(span); let output = AssertUnwindSafe(act_fut).catch_unwind().await; let result = match output { - Err(e) => ActivityExecutionResult::fail(Failure::application_failure( - format!("Activity function panicked: {}", panic_formatter(e)), - true, + Err(e) => ActivityExecutionResult::fail(codec_data_converter.to_failure( + Box::new(TemporalError::Application { + message: format!( + "Activity function panicked: {}", + panic_formatter(e) + ), + stack_trace: String::new(), + r#type: String::new(), + non_retryable: true, + details: None, + next_retry_delay: None, + cause: None, + }), + &SerializationContextData::Activity, )), Ok(Ok(p)) => ActivityExecutionResult::ok(p), Ok(Err(err)) => match err { @@ -963,10 +975,8 @@ impl ActivityHalf { source, explicit_delay, } => ActivityExecutionResult::fail({ - let mut f = Failure::application_failure_from_error( - anyhow::Error::from_boxed(source), - false, - ); + let mut f = codec_data_converter + .to_failure(source, &SerializationContextData::Activity); if let Some(d) = explicit_delay && let Some(failure::FailureInfo::ApplicationFailureInfo(fi)) = f.failure_info.as_mut() @@ -978,12 +988,12 @@ impl ActivityHalf { ActivityError::Cancelled { details } => { ActivityExecutionResult::cancel_from_details(details) } - ActivityError::NonRetryable(nre) => ActivityExecutionResult::fail( - Failure::application_failure_from_error( - anyhow::Error::from_boxed(nre), - true, - ), - ), + ActivityError::NonRetryable(nre) => { + ActivityExecutionResult::fail(force_failure_non_retryable( + codec_data_converter + .to_failure(nre, &SerializationContextData::Activity), + )) + } ActivityError::WillCompleteAsync => { ActivityExecutionResult::will_complete_async() } @@ -1340,6 +1350,32 @@ impl PrintablePanicType for 
EndPrintingAttempts { type NextType = EndPrintingAttempts; } +fn force_failure_non_retryable(mut failure: Failure) -> Failure { + match failure.failure_info.as_mut() { + Some(failure::FailureInfo::ApplicationFailureInfo(fi)) => { + fi.non_retryable = true; + failure + } + Some(failure::FailureInfo::ServerFailureInfo(fi)) => { + fi.non_retryable = true; + failure + } + _ => Failure { + // Activities marked NonRetryable must keep suppressing retries even + // when a custom converter chooses a failure kind without that flag. + message: failure.message.clone(), + failure_info: Some(failure::FailureInfo::ApplicationFailureInfo( + ApplicationFailureInfo { + non_retryable: true, + ..Default::default() + }, + )), + cause: Some(Box::new(failure)), + ..Default::default() + }, + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/crates/sdk/src/workflow_context.rs b/crates/sdk/src/workflow_context.rs index 7359b35c5..78b641d25 100644 --- a/crates/sdk/src/workflow_context.rs +++ b/crates/sdk/src/workflow_context.rs @@ -1,15 +1,15 @@ mod options; pub use options::{ - ActivityOptions, ChildWorkflowOptions, LocalActivityOptions, NexusOperationOptions, Signal, - SignalData, TimerOptions, + ActivityOptions, ChildWorkflowOptions, ContinueAsNewOptions, LocalActivityOptions, + NexusOperationOptions, Signal, SignalData, TimerOptions, }; pub use temporalio_common::protos::coresdk::child_workflow::StartChildWorkflowExecutionFailedCause; use crate::{ CancelExternalWfResult, CancellableID, CancellableIDWithReason, CommandCreateRequest, CommandSubscribeChildWorkflowCompletion, NexusStartResult, RustWfCmd, SignalExternalWfResult, - SupportsCancelReason, TimerResult, UnblockEvent, Unblockable, + SupportsCancelReason, TimerResult, UnblockEvent, Unblockable, WorkflowTermination, workflow_context::options::IntoWorkflowCommand, workflow_executor::SdkWakeGuard, }; use futures_util::{ @@ -35,12 +35,12 @@ use std::{ use temporalio_common::{ ActivityDefinition, SignalDefinition, 
WorkflowDefinition, data_converters::{ - GenericPayloadConverter, PayloadConversionError, PayloadConverter, SerializationContext, - SerializationContextData, TemporalDeserializable, + DataConverter, GenericPayloadConverter, PayloadConversionError, PayloadConverter, + SerializationContext, SerializationContextData, TemporalDeserializable, TemporalError, }, protos::{ coresdk::{ - activity_result::{ActivityResolution, activity_resolution}, + activity_result::{ActivityResolution, Cancellation, activity_resolution}, child_workflow::ChildWorkflowResult, common::NamespacedWorkflowExecution, nexus::NexusOperationResult, @@ -56,7 +56,8 @@ use temporalio_common::{ }, }, temporal::api::{ - common::v1::{Memo, Payload, SearchAttributes}, + common::v1::{Memo, Payload, Payloads, SearchAttributes}, + enums::v1::{RetryState, TimeoutType}, failure::v1::Failure, sdk::v1::UserMetadata, }, @@ -97,7 +98,7 @@ struct WorkflowContextInner { am_cancelled: watch::Receiver>, shared: RefCell, seq_nums: RefCell, - payload_converter: PayloadConverter, + data_converter: DataConverter, state_mutated: Cell, } @@ -281,37 +282,200 @@ impl WorkflowContextView { /// Error type for activity execution outcomes. #[derive(Debug, thiserror::Error)] pub enum ActivityExecutionError { - /// The activity failed with the given failure details. - #[error("Activity failed: {}", .0.message)] - Failed(Box), + /// The activity failed with an application error. + #[error("Activity failed: {message}")] + Failed { + /// Human-readable error message from the application failure. + message: String, + /// Application error type string. + error_type: String, + /// Whether this error is non-retryable. + non_retryable: bool, + /// Activity type name. + activity_type: String, + /// Activity ID. + activity_id: String, + /// Retry state at the time of failure. + retry_state: RetryState, + /// The full [`TemporalError`] cause chain. + #[source] + source: Box, + }, + /// The activity timed out. 
+ #[error("Activity timed out ({timeout_type:?}): {activity_type} ({activity_id})")] + Timeout { + /// Which kind of timeout. + timeout_type: TimeoutType, + /// Activity type name. + activity_type: String, + /// Activity ID. + activity_id: String, + /// Retry state at the time of timeout. + retry_state: RetryState, + /// The full [`TemporalError`] cause chain. + #[source] + source: Box, + }, /// The activity was cancelled. - #[error("Activity cancelled: {}", .0.message)] - Cancelled(Box), - // TODO: Timed out variant + #[error("Activity cancelled")] + Cancelled { + /// Human-readable cancellation message. + message: String, + /// Cancellation detail payloads. + details: Option, + /// The full [`TemporalError`] cause chain. + #[source] + source: Box, + }, + /// The activity was terminated. + #[error("Activity terminated")] + Terminated { + /// The full [`TemporalError`] cause chain. + #[source] + source: Box, + }, /// Failed to serialize input or deserialize result payload. #[error("Payload conversion failed: {0}")] Serialization(#[from] PayloadConversionError), } impl ActivityExecutionError { - /// Returns true if this error represents a timeout. - pub fn is_timeout(&self) -> bool { - match self { - ActivityExecutionError::Failed(f) => f.is_timeout().is_some(), - _ => false, + /// Construct from a [`TemporalError`], classifying into the appropriate variant + /// based on the error's shape. + fn from_temporal_error(te: TemporalError) -> Self { + match &te { + TemporalError::Activity { + activity_type, + activity_id, + retry_state, + cause: Some(cause), + .. + } => { + let activity_type = activity_type.clone(); + let activity_id = activity_id.clone(); + let retry_state = *retry_state; + match cause.as_ref() { + TemporalError::Application { + message, + r#type, + non_retryable, + .. 
+ } => Self::Failed { + message: message.clone(), + error_type: r#type.clone(), + non_retryable: *non_retryable, + activity_type, + activity_id, + retry_state, + source: Box::new(te), + }, + TemporalError::Timeout { timeout_type, .. } => Self::Timeout { + timeout_type: *timeout_type, + activity_type, + activity_id, + retry_state, + source: Box::new(te), + }, + TemporalError::Cancelled { + message, details, .. + } => Self::Cancelled { + message: message.clone(), + details: details.clone(), + source: Box::new(te), + }, + TemporalError::Terminated { .. } => Self::Terminated { + source: Box::new(te), + }, + _ => Self::failed_fallback(te), + } + } + TemporalError::Cancelled { + message, details, .. + } => Self::Cancelled { + message: message.clone(), + details: details.clone(), + source: Box::new(te), + }, + _ => Self::failed_fallback(te), + } + } + + fn failed_fallback(te: TemporalError) -> Self { + let message = te + .message() + .map(str::to_owned) + .unwrap_or_else(|| te.to_string()); + Self::Failed { + message, + error_type: String::new(), + non_retryable: false, + activity_type: String::new(), + activity_id: String::new(), + retry_state: RetryState::Unspecified, + source: Box::new(te), } } } -/// Error returned when a child workflow execution fails. +/// Error returned when a workflow awaits a child workflow result. #[derive(Debug, thiserror::Error)] pub enum ChildWorkflowExecutionError { /// The child workflow failed. - #[error("Child workflow failed: {}", .0.message)] - Failed(Box), + #[error("Child workflow failed: {message}")] + Failed { + /// Human-readable error message from the child workflow failure. + message: String, + /// Child workflow type name. + workflow_type: String, + /// Child workflow ID. + workflow_id: String, + /// Child workflow run ID. + run_id: String, + /// Retry state at the time of failure. + retry_state: RetryState, + /// The full [`TemporalError`] cause chain. + #[source] + source: Box, + }, + /// The child workflow timed out. 
+ #[error("Child workflow timed out ({timeout_type:?}): {workflow_type} ({workflow_id})")] + Timeout { + /// Which kind of timeout. + timeout_type: TimeoutType, + /// Child workflow type name. + workflow_type: String, + /// Child workflow ID. + workflow_id: String, + /// Child workflow run ID. + run_id: String, + /// The full [`TemporalError`] cause chain. + #[source] + source: Box, + }, /// The child workflow was cancelled. - #[error("Child workflow cancelled: {}", .0.message)] - Cancelled(Box), + #[error("Child workflow cancelled")] + Cancelled { + /// Human-readable cancellation message. + message: String, + /// Cancellation detail payloads. + details: Option, + /// The full [`TemporalError`] cause chain. + #[source] + source: Box, + }, + /// The child workflow was terminated. + #[error("Child workflow terminated")] + Terminated { + /// Child workflow type name. + workflow_type: String, + /// Child workflow ID. + workflow_id: String, + /// Child workflow run ID. + run_id: String, + /// The full [`TemporalError`] cause chain. + #[source] + source: Box, + }, /// The child workflow failed to start (e.g., workflow ID already exists). #[error( "Child workflow start failed: workflow_id={workflow_id}, workflow_type={workflow_type}, cause={cause:?}" @@ -324,17 +488,93 @@ pub enum ChildWorkflowExecutionError { /// The cause of the start failure. cause: StartChildWorkflowExecutionFailedCause, }, - /// Failed to serialize input or deserialize the child workflow result payload. + /// Failed to deserialize the child workflow result payload. #[error("Payload conversion failed: {0}")] Serialization(#[from] PayloadConversionError), } +impl ChildWorkflowExecutionError { + /// Construct from a [`TemporalError`], classifying into the appropriate variant + /// based on the error's shape. + fn from_temporal_error(te: TemporalError) -> Self { + match &te { + TemporalError::ChildWorkflow { + workflow_type, + workflow_id, + run_id, + retry_state, + cause: Some(cause), + .. 
+ } => { + let workflow_type = workflow_type.clone(); + let workflow_id = workflow_id.clone(); + let run_id = run_id.clone(); + let retry_state = *retry_state; + match cause.as_ref() { + TemporalError::Application { message, .. } => Self::Failed { + message: message.clone(), + workflow_type, + workflow_id, + run_id, + retry_state, + source: Box::new(te), + }, + TemporalError::Timeout { timeout_type, .. } => Self::Timeout { + timeout_type: *timeout_type, + workflow_type, + workflow_id, + run_id, + source: Box::new(te), + }, + TemporalError::Cancelled { + message, details, .. + } => Self::Cancelled { + message: message.clone(), + details: details.clone(), + source: Box::new(te), + }, + TemporalError::Terminated { .. } => Self::Terminated { + workflow_type, + workflow_id, + run_id, + source: Box::new(te), + }, + _ => Self::failed_fallback(te), + } + } + TemporalError::Cancelled { + message, details, .. + } => Self::Cancelled { + message: message.clone(), + details: details.clone(), + source: Box::new(te), + }, + _ => Self::failed_fallback(te), + } + } + + fn failed_fallback(te: TemporalError) -> Self { + let message = te + .message() + .map(str::to_owned) + .unwrap_or_else(|| te.to_string()); + Self::Failed { + message, + workflow_type: String::new(), + workflow_id: String::new(), + run_id: String::new(), + retry_state: RetryState::Unspecified, + source: Box::new(te), + } + } +} + /// Error returned when signaling a child workflow fails. #[derive(Debug, thiserror::Error)] pub enum ChildWorkflowSignalError { /// The signal delivery failed. - #[error("Child workflow signal failed: {}", .0.message)] - Failed(Box), + #[error("Child workflow signal failed: {0}")] + Failed(#[source] Box), /// Failed to serialize the signal input payload. 
#[error("Signal payload conversion failed: {0}")] Serialization(#[from] PayloadConversionError), @@ -349,7 +589,7 @@ impl BaseWorkflowContext { run_id: String, init_workflow_job: InitializeWorkflow, am_cancelled: watch::Receiver>, - payload_converter: PayloadConverter, + data_converter: DataConverter, ) -> (Self, Receiver) { // The receiving side is non-async let (chan, rx) = std::sync::mpsc::channel(); @@ -378,7 +618,7 @@ impl BaseWorkflowContext { next_signal_external_wf_sequence_number: 1, next_nexus_op_sequence_number: 1, }), - payload_converter, + data_converter, state_mutated: Cell::new(false), }), }, @@ -455,9 +695,14 @@ impl BaseWorkflowContext { let input = input.into(); let ctx = SerializationContext { data: &SerializationContextData::Workflow, - converter: &self.inner.payload_converter, + converter: self.inner.data_converter.payload_converter(), }; - let payloads = match self.inner.payload_converter.to_payloads(&ctx, &input) { + let payloads = match self + .inner + .data_converter + .payload_converter() + .to_payloads(&ctx, &input) + { Ok(p) => p, Err(e) => { return ActivityFut::eager(e.into()); @@ -476,7 +721,7 @@ impl BaseWorkflowContext { } .into(), ); - ActivityFut::running(cmd, self.inner.payload_converter.clone()) + ActivityFut::running(cmd, self.inner.data_converter.clone()) } /// Request to run a local activity @@ -492,9 +737,14 @@ impl BaseWorkflowContext { let input = input.into(); let ctx = SerializationContext { data: &SerializationContextData::Workflow, - converter: &self.inner.payload_converter, + converter: self.inner.data_converter.payload_converter(), }; - let payloads = match self.inner.payload_converter.to_payloads(&ctx, &input) { + let payloads = match self + .inner + .data_converter + .payload_converter() + .to_payloads(&ctx, &input) + { Ok(p) => p, Err(e) => { return ActivityFut::eager(e.into()); @@ -502,7 +752,7 @@ impl BaseWorkflowContext { }; ActivityFut::running( LATimerBackoffFut::new(AD::name().to_string(), payloads, opts, 
self.clone()), - self.inner.payload_converter.clone(), + self.inner.data_converter.clone(), ) } @@ -517,11 +767,12 @@ impl BaseWorkflowContext { WD::Output: TemporalDeserializable, { let input = input.into(); + let pc = self.inner.data_converter.payload_converter(); let ctx = SerializationContext { data: &SerializationContextData::Workflow, - converter: &self.inner.payload_converter, + converter: pc, }; - let payloads = match self.inner.payload_converter.to_payloads(&ctx, &input) { + let payloads = match pc.to_payloads(&ctx, &input) { Ok(p) => p, Err(e) => { return ChildWorkflowStartFut::eager(e.into()); @@ -550,7 +801,7 @@ impl BaseWorkflowContext { child_seq, result_future: result_cmd, base_ctx: self.clone(), - payload_converter: self.inner.payload_converter.clone(), + data_converter: self.inner.data_converter.clone(), }; let (cmd, unblocker) = CancellableWFCommandFut::new_with_dat( @@ -700,7 +951,7 @@ impl SyncWorkflowContext { /// Returns the [PayloadConverter] currently used by the worker running this workflow. pub fn payload_converter(&self) -> &PayloadConverter { - &self.base.inner.payload_converter + self.base.inner.data_converter.payload_converter() } /// Return various information that the workflow was initialized with. Will eventually become @@ -1115,6 +1366,35 @@ impl WorkflowContext { result } + /// Signal that this workflow should continue as a new execution with the given input and + /// options. 
+ /// + /// This always returns an `Err` which should be propagated + /// + /// ```ignore + /// ctx.continue_as_new(&new_input, ContinueAsNewOptions::default())?; + /// ``` + pub fn continue_as_new( + &self, + input: &::Input, + opts: ContinueAsNewOptions, + ) -> Result + where + W: crate::workflows::WorkflowImplementation, + { + let pc = self.sync.base.inner.data_converter.payload_converter(); + let ctx = SerializationContext { + data: &SerializationContextData::Workflow, + converter: pc, + }; + let arguments = pc + .to_payloads(&ctx, input) + .map_err(WorkflowTermination::failed)?; + let workflow_type = self.sync.workflow_initial_info().workflow_type.clone(); + let proto = opts.into_proto(workflow_type, arguments); + Err(WorkflowTermination::continue_as_new(proto)) + } + /// Wait for some condition on workflow state to become true, yielding the workflow if not. /// /// The condition closure receives an immutable reference to the workflow state, @@ -1387,9 +1667,9 @@ impl Future for LATimerBackoffFut { } else { self.terminated = true; Poll::Ready(ActivityResolution { - status: Some( - activity_resolution::Status::Cancelled(Default::default()), - ), + status: Some(activity_resolution::Status::Cancelled( + Cancellation::from_details(None), + )), }) } } @@ -1406,7 +1686,9 @@ impl Future for LATimerBackoffFut { if self.did_cancel.load(Ordering::Acquire) { self.terminated = true; return Poll::Ready(ActivityResolution { - status: Some(activity_resolution::Status::Cancelled(Default::default())), + status: Some(activity_resolution::Status::Cancelled( + Cancellation::from_details(None), + )), }); } @@ -1452,7 +1734,7 @@ enum ActivityFut { /// Running activity that will deserialize output on completion.
Running { inner: F, - payload_converter: PayloadConverter, + data_converter: DataConverter, _phantom: PhantomData, }, Terminated, @@ -1466,10 +1748,10 @@ impl ActivityFut { } } - fn running(inner: F, payload_converter: PayloadConverter) -> Self { + fn running(inner: F, data_converter: DataConverter) -> Self { Self::Running { inner, - payload_converter, + data_converter, _phantom: PhantomData, } } @@ -1492,36 +1774,44 @@ where } ActivityFut::Running { inner, - payload_converter, + data_converter, .. } => match Pin::new(inner).poll(cx) { Poll::Pending => Poll::Pending, Poll::Ready(resolution) => Poll::Ready({ let status = resolution.status.ok_or_else(|| { - ActivityExecutionError::Failed(Box::new(Failure { + ActivityExecutionError::from_temporal_error(TemporalError::Application { message: "Activity completed without a status".to_string(), - ..Default::default() - })) + stack_trace: String::new(), + r#type: String::new(), + non_retryable: false, + details: None, + next_retry_delay: None, + cause: None, + }) })?; + let ctx = &SerializationContextData::Workflow; + let pc = data_converter.payload_converter(); match status { activity_resolution::Status::Completed(success) => { let payload = success.result.unwrap_or_default(); - let ctx = SerializationContext { - data: &SerializationContextData::Workflow, - converter: payload_converter, + let ser_ctx = SerializationContext { + data: ctx, + converter: pc, }; - payload_converter - .from_payload::(&ctx, payload) + pc.from_payload::(&ser_ctx, payload) .map_err(ActivityExecutionError::Serialization) } - activity_resolution::Status::Failed(f) => Err( - ActivityExecutionError::Failed(Box::new(f.failure.unwrap_or_default())), - ), + activity_resolution::Status::Failed(f) => { + let failure = f.failure.unwrap_or_default(); + let te = data_converter.to_error(failure, ctx); + Err(ActivityExecutionError::from_temporal_error(te)) + } activity_resolution::Status::Cancelled(c) => { - Err(ActivityExecutionError::Cancelled(Box::new( - 
c.failure.unwrap_or_default(), - ))) + let failure = c.failure.unwrap_or_default(); + let te = data_converter.to_error(failure, ctx); + Err(ActivityExecutionError::from_temporal_error(te)) } activity_resolution::Status::Backoff(_) => { panic!("DoBackoff should be handled by LATimerBackoffFut") @@ -1565,7 +1855,7 @@ pub(crate) struct ChildWfCommon { child_seq: u32, result_future: CancellableWFCommandFut, base_ctx: BaseWorkflowContext, - payload_converter: PayloadConverter, + data_converter: DataConverter, } /// Child workflow in pending state. Internal type used during the start handshake; @@ -1594,7 +1884,7 @@ pub struct StartedChildWorkflow { enum ChildWorkflowFut { Running { inner: F, - payload_converter: PayloadConverter, + data_converter: DataConverter, _phantom: PhantomData, }, Terminated, @@ -1614,38 +1904,46 @@ where let poll = match this { ChildWorkflowFut::Running { inner, - payload_converter, + data_converter, .. } => match Pin::new(inner).poll(cx) { Poll::Pending => Poll::Pending, Poll::Ready(result) => Poll::Ready({ use temporalio_common::protos::coresdk::child_workflow::child_workflow_result; let status = result.status.ok_or_else(|| { - ChildWorkflowExecutionError::Failed(Box::new(Failure { - message: "Child workflow completed without a status".to_string(), - ..Default::default() - })) + ChildWorkflowExecutionError::from_temporal_error( + TemporalError::Application { + message: "Child workflow completed without a status".to_string(), + stack_trace: String::new(), + r#type: String::new(), + non_retryable: false, + details: None, + next_retry_delay: None, + cause: None, + }, + ) })?; + let ctx = &SerializationContextData::Workflow; + let pc = data_converter.payload_converter(); match status { child_workflow_result::Status::Completed(success) => { - let payloads = success.result.into_iter().collect(); - let ctx = SerializationContext { - data: &SerializationContextData::Workflow, - converter: payload_converter, + let payloads: Vec<_> = 
success.result.into_iter().collect(); + let ser_ctx = SerializationContext { + data: ctx, + converter: pc, }; - payload_converter - .from_payloads::(&ctx, payloads) + pc.from_payloads::(&ser_ctx, payloads) .map_err(ChildWorkflowExecutionError::Serialization) } child_workflow_result::Status::Failed(f) => { - Err(ChildWorkflowExecutionError::Failed(Box::new( - f.failure.unwrap_or_default(), - ))) + let failure = f.failure.unwrap_or_default(); + let te = data_converter.to_error(failure, ctx); + Err(ChildWorkflowExecutionError::from_temporal_error(te)) } child_workflow_result::Status::Cancelled(c) => { - Err(ChildWorkflowExecutionError::Cancelled(Box::new( - c.failure.unwrap_or_default(), - ))) + let failure = c.failure.unwrap_or_default(); + let te = data_converter.to_error(failure, ctx); + Err(ChildWorkflowExecutionError::from_temporal_error(te)) } } }), @@ -1733,26 +2031,31 @@ where } ChildWorkflowStartFut::Running(inner) => match Pin::new(inner).poll(cx) { Poll::Pending => Poll::Pending, - Poll::Ready(pending) => Poll::Ready(match pending.status { - ChildWorkflowStartStatus::Succeeded(s) => Ok(StartedChildWorkflow { - run_id: s.run_id, - common: pending.common, - _phantom: PhantomData, - }), - ChildWorkflowStartStatus::Failed(f) => { - Err(ChildWorkflowExecutionError::StartFailed { - workflow_id: f.workflow_id, - workflow_type: f.workflow_type, - cause: StartChildWorkflowExecutionFailedCause::try_from(f.cause) - .unwrap_or(StartChildWorkflowExecutionFailedCause::Unspecified), - }) - } - ChildWorkflowStartStatus::Cancelled(c) => { - Err(ChildWorkflowExecutionError::Cancelled(Box::new( - c.failure.unwrap_or_default(), - ))) - } - }), + Poll::Ready(pending) => { + let PendingChildWorkflow { status, common, .. 
} = pending; + Poll::Ready(match status { + ChildWorkflowStartStatus::Succeeded(s) => Ok(StartedChildWorkflow { + run_id: s.run_id, + common, + _phantom: PhantomData, + }), + ChildWorkflowStartStatus::Failed(f) => { + Err(ChildWorkflowExecutionError::StartFailed { + workflow_id: f.workflow_id, + workflow_type: f.workflow_type, + cause: StartChildWorkflowExecutionFailedCause::try_from(f.cause) + .unwrap_or(StartChildWorkflowExecutionFailedCause::Unspecified), + }) + } + ChildWorkflowStartStatus::Cancelled(c) => { + let failure = c.failure.unwrap_or_default(); + let te = common + .data_converter + .to_error(failure, &SerializationContextData::Workflow); + Err(ChildWorkflowExecutionError::from_temporal_error(te)) + } + }) + } }, ChildWorkflowStartFut::Terminated => panic!("polled after termination"), }; @@ -1807,7 +2110,7 @@ enum SignalChildFut { Errored { error: Option, }, - Running(F), + Running(F, DataConverter), Terminated, } @@ -1815,8 +2118,13 @@ impl SignalChildFut { fn eager(err: ChildWorkflowSignalError) -> Self { Self::Errored { error: Some(err) } } + + fn running(inner: F, data_converter: DataConverter) -> Self { + Self::Running(inner, data_converter) + } } +// SAFETY: DataConverter contains only Arc fields, so it is Unpin. 
impl Unpin for SignalChildFut where F: Unpin {} impl Future for SignalChildFut @@ -1831,11 +2139,12 @@ where SignalChildFut::Errored { error } => { Poll::Ready(Err(error.take().expect("polled after completion"))) } - SignalChildFut::Running(inner) => match Pin::new(inner).poll(cx) { + SignalChildFut::Running(inner, dc) => match Pin::new(inner).poll(cx) { Poll::Pending => Poll::Pending, Poll::Ready(Ok(_)) => Poll::Ready(Ok(())), Poll::Ready(Err(failure)) => { - Poll::Ready(Err(ChildWorkflowSignalError::Failed(Box::new(failure)))) + let te = dc.to_error(failure, &SerializationContextData::Workflow); + Poll::Ready(Err(ChildWorkflowSignalError::Failed(Box::new(te)))) } }, SignalChildFut::Terminated => panic!("polled after termination"), @@ -1861,7 +2170,7 @@ where F: CancellableFuture + Unpin, { fn cancel(&self) { - if let SignalChildFut::Running(inner) = self { + if let SignalChildFut::Running(inner, _) = self { inner.cancel() } } @@ -1878,7 +2187,7 @@ where ) -> impl CancellableFutureWithReason> { ChildWorkflowFut::Running { inner: self.common.result_future, - payload_converter: self.common.payload_converter, + data_converter: self.common.data_converter.clone(), _phantom: PhantomData, } } @@ -1900,11 +2209,12 @@ where signal: S, input: S::Input, ) -> impl CancellableFuture> + 'static { + let pc = self.common.data_converter.payload_converter(); let ctx = SerializationContext { data: &SerializationContextData::Workflow, - converter: &self.common.payload_converter, + converter: pc, }; - let payloads = match self.common.payload_converter.to_payloads(&ctx, &input) { + let payloads = match pc.to_payloads(&ctx, &input) { Ok(p) => p, Err(e) => { return SignalChildFut::eager(e.into()); @@ -1912,7 +2222,10 @@ where }; let signal = Signal::new(S::name(&signal), payloads); let target = sig_we::Target::ChildWorkflowId(self.common.workflow_id.clone()); - SignalChildFut::Running(self.common.base_ctx.clone().send_signal_wf(target, signal)) + SignalChildFut::running( + 
self.common.base_ctx.clone().send_signal_wf(target, signal), + self.common.data_converter.clone(), + ) } } @@ -1946,16 +2259,12 @@ impl ExternalWorkflowHandle { signal: S, input: S::Input, ) -> impl CancellableFuture + 'static { + let pc = self.base_ctx.inner.data_converter.payload_converter(); let ctx = SerializationContext { data: &SerializationContextData::Workflow, - converter: &self.base_ctx.inner.payload_converter, + converter: pc, }; - let payloads = match self - .base_ctx - .inner - .payload_converter - .to_payloads(&ctx, &input) - { + let payloads = match pc.to_payloads(&ctx, &input) { Ok(p) => p, Err(e) => { return SignalExternalFut::SerializationError(Some(e)); diff --git a/crates/sdk/src/workflow_context/options.rs b/crates/sdk/src/workflow_context/options.rs index 74b797b11..5d6e917b9 100644 --- a/crates/sdk/src/workflow_context/options.rs +++ b/crates/sdk/src/workflow_context/options.rs @@ -5,10 +5,12 @@ use temporalio_common::protos::{ coresdk::{ AsJsonPayloadExt, child_workflow::ChildWorkflowCancellationType, + common::VersioningIntent, nexus::NexusOperationCancellationType, workflow_commands::{ - ActivityCancellationType, ScheduleActivity, ScheduleLocalActivity, - ScheduleNexusOperation, StartChildWorkflowExecution, WorkflowCommand, + ActivityCancellationType, ContinueAsNewWorkflowExecution, ScheduleActivity, + ScheduleLocalActivity, ScheduleNexusOperation, StartChildWorkflowExecution, + WorkflowCommand, }, }, temporal::api::{ @@ -430,6 +432,58 @@ impl IntoWorkflowCommand for NexusOperationOptions { } } +/// Options for continuing a workflow as a new execution. +/// +/// All fields are optional. Unset fields inherit the current workflow's values where applicable. +#[derive(Default, Debug, bon::Builder)] +#[non_exhaustive] +pub struct ContinueAsNewOptions { + /// Override the workflow type for the new execution. If `None`, reuses the current type. + pub workflow_type: Option, + /// Task queue for the new execution. 
If `None`, reuses the current task queue. + pub task_queue: Option, + /// Timeout for a single run of the new workflow. + pub run_timeout: Option, + /// Timeout of a single workflow task. + pub task_timeout: Option, + /// If set, the new workflow will have this memo. If `None`, reuses the current memo. + pub memo: Option>, + /// If set, the new workflow will have these headers. + pub headers: Option>, + /// If set, the new workflow will have these search attributes. If `None`, reuses the current + /// search attributes. + pub search_attributes: Option, + /// If set, the new workflow will have this retry policy. If `None`, reuses the current policy. + pub retry_policy: Option, + /// Whether the new workflow should run on a worker with a compatible build id. + pub versioning_intent: Option, +} + +impl ContinueAsNewOptions { + pub(crate) fn into_proto( + self, + workflow_type: String, + arguments: Vec, + ) -> ContinueAsNewWorkflowExecution { + ContinueAsNewWorkflowExecution { + workflow_type: self.workflow_type.unwrap_or(workflow_type), + task_queue: self.task_queue.unwrap_or_default(), + arguments, + workflow_run_timeout: self.run_timeout.and_then(|t| t.try_into().ok()), + workflow_task_timeout: self.task_timeout.and_then(|t| t.try_into().ok()), + memo: self.memo.unwrap_or_default(), + headers: self.headers.unwrap_or_default(), + search_attributes: self.search_attributes, + retry_policy: self.retry_policy, + versioning_intent: self + .versioning_intent + .unwrap_or(VersioningIntent::Unspecified) + .into(), + ..Default::default() + } + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/crates/sdk/src/workflow_future.rs b/crates/sdk/src/workflow_future.rs index 5cb3bcae0..2effeb7c0 100644 --- a/crates/sdk/src/workflow_future.rs +++ b/crates/sdk/src/workflow_future.rs @@ -16,7 +16,7 @@ use std::{ task::{Context, Poll}, }; use temporalio_common::{ - data_converters::PayloadConverter, + data_converters::{DataConverter, SerializationContextData, TemporalError}, 
protos::{ coresdk::{ workflow_activation::{ @@ -38,7 +38,6 @@ use temporalio_common::{ temporal::api::{ common::v1::{Payload, Payloads}, enums::v1::{VersioningBehavior, WorkflowTaskFailedCause}, - failure::v1::Failure, }, utilities::TryIntoOrNone, }, @@ -67,7 +66,7 @@ impl WorkflowFunction { run_id: String, init_workflow_job: InitializeWorkflow, outgoing_completions: UnboundedSender, - payload_converter: PayloadConverter, + data_converter: DataConverter, detect_nondeterministic: bool, ) -> Result< ( @@ -90,12 +89,16 @@ impl WorkflowFunction { run_id, init_workflow_job, cancel_rx, - payload_converter.clone(), + data_converter.clone(), ); // Create the workflow execution using the factory - let execution = (self.factory)(input, payload_converter.clone(), base_ctx.clone()) - .context("Failed to create workflow execution")?; + let execution = (self.factory)( + input, + data_converter.payload_converter().clone(), + base_ctx.clone(), + ) + .context("Failed to create workflow execution")?; let wake_tracking = if detect_nondeterministic { Some(WakeTracker::new()) @@ -114,7 +117,7 @@ impl WorkflowFunction { incoming_activations, command_status: Default::default(), cancel_sender: cancel_tx, - payload_converter, + data_converter, update_futures: Default::default(), signal_futures: Default::default(), wake_tracking, @@ -145,8 +148,8 @@ pub(crate) struct WorkflowFuture { cancel_sender: watch::Sender>, /// Base workflow context for sending commands base_ctx: BaseWorkflowContext, - /// Payload converter for serialization/deserialization - payload_converter: PayloadConverter, + /// Data converter for serialization/deserialization and failure conversion + data_converter: DataConverter, /// Stores in-progress update futures update_futures: Vec<( String, @@ -260,7 +263,7 @@ impl WorkflowFuture { payloads: q.arguments, }, headers: q.headers, - converter: &self.payload_converter, + converter: self.data_converter.payload_converter(), }; let dispatch_result = match 
panic::catch_unwind(AssertUnwindSafe(|| { @@ -279,14 +282,14 @@ impl WorkflowFuture { response: Some(payload), }), // TODO [rust-sdk-branch]: Return list of known queries in error - None => query_result::Variant::Failed(Failure { - message: format!("No query handler for '{}'", query_type), - ..Default::default() - }), - Some(Err(e)) => query_result::Variant::Failed(Failure { - message: e.to_string(), - ..Default::default() - }), + None => query_result::Variant::Failed(self.data_converter.to_failure( + format!("No query handler for '{}'", query_type).into(), + &SerializationContextData::Workflow, + )), + Some(Err(e)) => query_result::Variant::Failed( + self.data_converter + .to_failure(Box::new(e), &SerializationContextData::Workflow), + ), }; outgoing_cmds.push( @@ -312,7 +315,7 @@ impl WorkflowFuture { payloads: sig.input, }, headers: sig.headers, - converter: &self.payload_converter, + converter: self.data_converter.payload_converter(), }; let dispatch_result = match panic::catch_unwind(AssertUnwindSafe(|| { @@ -344,7 +347,7 @@ impl WorkflowFuture { let data = DispatchData { payloads: Payloads { payloads: u.input }, headers: u.headers, - converter: &self.payload_converter, + converter: self.data_converter.payload_converter(), }; let trait_val_result = if u.run_validator { @@ -378,10 +381,13 @@ impl WorkflowFuture { } } Some(Err(e)) => { + let failure = self + .data_converter + .to_failure(Box::new(e), &SerializationContextData::Workflow); outgoing_cmds.push( update_response( protocol_instance_id.clone(), - update_response::Response::Rejected(anyhow!(e).into()), + update_response::Response::Rejected(failure), ) .into(), ); @@ -391,16 +397,14 @@ impl WorkflowFuture { } } if not_found { + let failure = self.data_converter.to_failure( + format!("No update handler registered for update name {}", name).into(), + &SerializationContextData::Workflow, + ); outgoing_cmds.push( update_response( protocol_instance_id, - update_response::Response::Rejected( - format!( - "No 
update handler registered for update name {}", - name - ) - .into(), - ), + update_response::Response::Rejected(failure), ) .into(), ); @@ -581,7 +585,12 @@ impl WorkflowFuture { instance_id, match v { Ok(v) => update_response::Response::Completed(v), - Err(e) => update_response::Response::Rejected(e.into()), + Err(e) => update_response::Response::Rejected( + self.data_converter.to_failure( + Box::new(e), + &SerializationContextData::Workflow, + ), + ), }, ) .into(), @@ -621,16 +630,21 @@ impl WorkflowFuture { Ok(r) => r, Err(e) => { let errmsg = format!("Workflow function panicked: {}", panic_formatter(e)); + let failure = self.data_converter.to_failure( + Box::new(TemporalError::Application { + message: errmsg.clone(), + stack_trace: String::new(), + r#type: String::new(), + non_retryable: true, + details: None, + next_retry_delay: None, + cause: None, + }), + &SerializationContextData::Workflow, + ); warn!("{}", errmsg); self.outgoing_completions - .send(WorkflowActivationCompletion::fail( - run_id, - Failure { - message: errmsg, - ..Default::default() - }, - None, - )) + .send(WorkflowActivationCompletion::fail(run_id, failure, None)) .expect("Completion channel intact"); // Loop back up because we're about to get evicted return Ok(true); @@ -762,11 +776,11 @@ impl WorkflowFuture { panic!("Don't explicitly return WorkflowTermination::Evicted") } WorkflowTermination::Failed(e) => { + let failure = self + .data_converter + .to_failure(anyhow_to_box(e), &SerializationContextData::Workflow); workflow_command::Variant::FailWorkflowExecution(FailWorkflowExecution { - failure: Some(Failure { - message: format!("Workflow execution error: {e}"), - ..Default::default() - }), + failure: Some(failure), }) } }, @@ -789,6 +803,18 @@ enum CommandID { NexusOpComplete(u32), } +/// Convert an `anyhow::Error` into `Box`, preserving the inner +/// type for downcasting by the failure converter. 
`anyhow::Error` doesn't +/// implement `std::error::Error`, so a plain `.into()` would lose the +/// wrapped type. Instead we try to extract a `TemporalError` first. +fn anyhow_to_box(e: anyhow::Error) -> Box { + use temporalio_common::data_converters::TemporalError; + match e.downcast::() { + Ok(te) => Box::new(te), + Err(e) => e.into_boxed_dyn_error(), + } +} + fn update_response( instance_id: String, resp: update_response::Response, diff --git a/crates/sdk/src/workflows.rs b/crates/sdk/src/workflows.rs index ce3fdd9f6..ee8c003d3 100644 --- a/crates/sdk/src/workflows.rs +++ b/crates/sdk/src/workflows.rs @@ -109,10 +109,7 @@ use temporalio_common::{ GenericPayloadConverter, PayloadConversionError, PayloadConverter, SerializationContext, SerializationContextData, TemporalDeserializable, TemporalSerializable, }, - protos::temporal::api::{ - common::v1::{Payload, Payloads}, - failure::v1::Failure, - }, + protos::temporal::api::common::v1::{Payload, Payloads}, }; /// Error type for workflow operations @@ -127,15 +124,6 @@ pub enum WorkflowError { Execution(#[from] anyhow::Error), } -impl From for Failure { - fn from(err: WorkflowError) -> Self { - Failure { - message: err.to_string(), - ..Default::default() - } - } -} - /// Trait implemented by workflow structs to enable execution by the worker. /// /// This trait is typically generated by the `#[workflow_methods]` macro and should not