diff --git a/core/src/core_tests/activity_tasks.rs b/core/src/core_tests/activity_tasks.rs index 045ada976..516f9750f 100644 --- a/core/src/core_tests/activity_tasks.rs +++ b/core/src/core_tests/activity_tasks.rs @@ -1105,8 +1105,9 @@ async fn graceful_shutdown(#[values(true, false)] at_max_outstanding: bool) { assert_matches!( cancel.variant, Some(activity_task::Variant::Cancel(Cancel { - reason: r - })) if r == ActivityCancelReason::WorkerShutdown as i32 + reason, + details + })) if reason == ActivityCancelReason::WorkerShutdown as i32 && details.as_ref().is_some_and(|d| d.is_worker_shutdown) ); seen_tts.insert(cancel.task_token); } @@ -1241,3 +1242,163 @@ async fn pass_activity_summary_to_metadata() { .unwrap(); worker.run_until_done().await.unwrap(); } + +#[tokio::test] +async fn heartbeat_response_can_be_paused() { + let mut mock_client = mock_workflow_client(); + // First heartbeat returns pause only + mock_client + .expect_record_activity_heartbeat() + .times(1) + .returning(|_, _| { + Ok(RecordActivityTaskHeartbeatResponse { + cancel_requested: false, + activity_paused: true, + }) + }); + // Second heartbeat returns cancel only + mock_client + .expect_record_activity_heartbeat() + .times(1) + .returning(|_, _| { + Ok(RecordActivityTaskHeartbeatResponse { + cancel_requested: true, + activity_paused: false, + }) + }); + // Third heartbeat returns both + mock_client + .expect_record_activity_heartbeat() + .times(1) + .returning(|_, _| { + Ok(RecordActivityTaskHeartbeatResponse { + cancel_requested: true, + activity_paused: true, + }) + }); + mock_client + .expect_cancel_activity_task() + .times(3) + .returning(|_, _| Ok(RespondActivityTaskCanceledResponse::default())); + + let core = mock_worker(MocksHolder::from_client_with_activities( + mock_client, + [ + PollActivityTaskQueueResponse { + task_token: vec![1], + activity_id: "act1".to_string(), + heartbeat_timeout: Some(prost_dur!(from_millis(1))), + ..Default::default() + } + .into(), + PollActivityTaskQueueResponse { + task_token: vec![2], + activity_id: "act2".to_string(), + heartbeat_timeout: Some(prost_dur!(from_millis(1))), + ..Default::default() + } + .into(), + PollActivityTaskQueueResponse { + task_token: vec![3], + activity_id: "act3".to_string(), + heartbeat_timeout: Some(prost_dur!(from_millis(1))), + ..Default::default() + } + .into(), + ], + )); + + // The general testing pattern for each of these cases is: + // 1. Poll for activity task + // 2. Record activity heartbeat, get mocked heartbeat response + // 3. Sleep for 10ms (waiting for heartbeat request to be flushed) + // (i.e. sleep enough for the heartbeat flush interval to have elapsed) + // 4. Poll for activity task. + // We expect a cancellation activity task as they are prioritized (i.e. ordered before) + // regular activity tasks. + // 5. Assert that the received activity task is indeed a cancellation, with the reason + // and details we expect. + // 6. Complete the activity with a cancellation result. + // + // Repeat for subsequent test case(s). + + // Test pause only + let act = core.poll_activity_task().await.unwrap(); + core.record_activity_heartbeat(ActivityHeartbeat { + task_token: act.task_token.clone(), + details: vec![vec![1_u8, 2, 3].into()], + }); + sleep(Duration::from_millis(10)).await; + let act = core.poll_activity_task().await.unwrap(); + assert_matches!( + &act, + ActivityTask { + task_token, + variant: Some(activity_task::Variant::Cancel(Cancel { reason, details })), + } if + task_token == &vec![1] && + *reason == ActivityCancelReason::Paused as i32 && + details.as_ref().is_some_and(|d| d.is_paused) && + details.as_ref().is_some_and(|d| !d.is_cancelled) + ); + core.complete_activity_task(ActivityTaskCompletion { + task_token: act.task_token, + result: Some(ActivityExecutionResult::cancel_from_details(None)), + }) + .await + .unwrap(); + + // Test cancel only + let act = core.poll_activity_task().await.unwrap(); + core.record_activity_heartbeat(ActivityHeartbeat { + task_token: act.task_token.clone(), + details: vec![vec![1_u8, 2, 3].into()], + }); + sleep(Duration::from_millis(10)).await; + let act = core.poll_activity_task().await.unwrap(); + assert_matches!( + &act, + ActivityTask { + task_token, + variant: Some(activity_task::Variant::Cancel(Cancel { reason, details })), + } if + task_token == &vec![2] && + *reason == ActivityCancelReason::Cancelled as i32 && + details.as_ref().is_some_and(|d| !d.is_paused) && + details.as_ref().is_some_and(|d| d.is_cancelled) + ); + core.complete_activity_task(ActivityTaskCompletion { + task_token: act.task_token, + result: Some(ActivityExecutionResult::cancel_from_details(None)), + }) + .await + .unwrap(); + + // Test both pause and cancel (should prioritize cancel) + let act = core.poll_activity_task().await.unwrap(); + core.record_activity_heartbeat(ActivityHeartbeat { + task_token: act.task_token.clone(), + details: vec![vec![1_u8, 2, 3].into()], + }); + sleep(Duration::from_millis(10)).await; + let act = core.poll_activity_task().await.unwrap(); + assert_matches!( + &act, + ActivityTask { + task_token, + variant: Some(activity_task::Variant::Cancel(Cancel { reason, details })), + } if + task_token == &vec![3] && + *reason == ActivityCancelReason::Cancelled as i32 && + details.as_ref().is_some_and(|d| d.is_paused) && + details.as_ref().is_some_and(|d| d.is_cancelled) + ); + core.complete_activity_task(ActivityTaskCompletion { + task_token: act.task_token, + result: Some(ActivityExecutionResult::cancel_from_details(None)), + }) + .await + .unwrap(); + + core.drain_activity_poller_and_shutdown().await; +} diff --git a/core/src/worker/activities.rs b/core/src/worker/activities.rs index cb955315e..b506008a2 100644 --- a/core/src/worker/activities.rs +++ b/core/src/worker/activities.rs @@ -40,7 +40,7 @@ use temporal_sdk_core_protos::{ coresdk::{ ActivityHeartbeat, ActivitySlotInfo, activity_result::{self as ar, activity_execution_result as aer}, - activity_task::{ActivityCancelReason, ActivityTask}, + activity_task::{ActivityCancelReason, ActivityCancellationDetails, ActivityTask}, }, temporal::api::{ failure::v1::{ApplicationFailureInfo, CanceledFailureInfo, Failure, failure::FailureInfo}, @@ -65,16 +65,19 @@ type OutstandingActMap = Arc>; struct PendingActivityCancel { task_token: TaskToken, reason: ActivityCancelReason, - /// Set true if we should assume the server has already forgotten about this activity - consider_not_found: bool, + details: ActivityCancellationDetails, } impl PendingActivityCancel { - fn new(task_token: TaskToken, reason: ActivityCancelReason) -> Self { + fn new( + task_token: TaskToken, + reason: ActivityCancelReason, + details: ActivityCancellationDetails, + ) -> Self { Self { task_token, reason, - consider_not_found: false, + details, } } } @@ -508,13 +511,14 @@ where } else { details.issued_cancel_to_lang = Some(next_pc.reason); if next_pc.reason == ActivityCancelReason::NotFound - || next_pc.consider_not_found + || next_pc.details.is_not_found { details.known_not_found = true; } Some(Ok(ActivityTask::cancel_from_ids( next_pc.task_token.0, next_pc.reason, + next_pc.details, ))) } } else { @@ -566,6 +570,9 @@ where let _ = cancels_tx.send(PendingActivityCancel::new( tt, ActivityCancelReason::WorkerShutdown, + ActivityTask::primary_reason_to_cancellation_details( + ActivityCancelReason::WorkerShutdown, + ), )); } else { // Fire off task to keep track of local timeouts. We do this so that @@ -611,11 +618,15 @@ where "Timing out activity due to elapsed local \ {timeout_type} timer" ); - let _ = cancel_tx.send(PendingActivityCancel { - task_token: tt, - reason: ActivityCancelReason::TimedOut, - consider_not_found: true, - }); + let _ = cancel_tx.send(PendingActivityCancel::new( + tt, + ActivityCancelReason::TimedOut, + ActivityCancellationDetails { + is_not_found: true, + is_timed_out: true, + ..Default::default() + }, + )); })); outstanding_info.timeout_resetter = resetter; } @@ -639,6 +650,9 @@ where let _ = self.cancels_tx.send(PendingActivityCancel::new( mapref.key().clone(), ActivityCancelReason::WorkerShutdown, + ActivityTask::primary_reason_to_cancellation_details( + ActivityCancelReason::WorkerShutdown, + ), )); } } diff --git a/core/src/worker/activities/activity_heartbeat_manager.rs b/core/src/worker/activities/activity_heartbeat_manager.rs index 0bd928a12..c461f4353 100644 --- a/core/src/worker/activities/activity_heartbeat_manager.rs +++ b/core/src/worker/activities/activity_heartbeat_manager.rs @@ -10,7 +10,10 @@ use std::{ time::{Duration, Instant}, }; use temporal_sdk_core_protos::{ - coresdk::{ActivityHeartbeat, IntoPayloadsExt, activity_task::ActivityCancelReason}, + coresdk::{ + ActivityHeartbeat, IntoPayloadsExt, + activity_task::{ActivityCancelReason, ActivityCancellationDetails, ActivityTask}, + }, temporal::api::{ common::v1::Payload, workflowservice::v1::RecordActivityTaskHeartbeatResponse, }, @@ -142,12 +145,23 @@ impl ActivityHeartbeatManager { .record_activity_heartbeat(tt.clone(), details.into_payloads()) .await { - Ok(RecordActivityTaskHeartbeatResponse { cancel_requested, activity_paused: _ }) => { - if cancel_requested { + Ok(RecordActivityTaskHeartbeatResponse { cancel_requested, activity_paused }) => { + if cancel_requested || activity_paused { + // Prioritize Cancel over Pause + let reason = if cancel_requested { + ActivityCancelReason::Cancelled + } else { + ActivityCancelReason::Paused + }; cancels_tx .send(PendingActivityCancel::new( tt.clone(), - ActivityCancelReason::Cancelled, + reason, + ActivityCancellationDetails { + is_cancelled: cancel_requested, + is_paused: activity_paused, + ..Default::default() + } )) .expect( "Receive half of heartbeat cancels not blocked", @@ -164,6 +178,7 @@ impl ActivityHeartbeatManager { .send(PendingActivityCancel::new( tt.clone(), ActivityCancelReason::NotFound, + ActivityTask::primary_reason_to_cancellation_details(ActivityCancelReason::NotFound) )) .expect("Receive half of heartbeat cancels not blocked"); } diff --git a/core/src/worker/activities/local_activities.rs b/core/src/worker/activities/local_activities.rs index 7f34b5679..a63bfe4d6 100644 --- a/core/src/worker/activities/local_activities.rs +++ b/core/src/worker/activities/local_activities.rs @@ -22,7 +22,7 @@ use temporal_sdk_core_protos::{ coresdk::{ LocalActivitySlotInfo, activity_result::{Cancellation, Failure as ActFail, Success}, - activity_task::{ActivityCancelReason, ActivityTask, Cancel, Start, activity_task}, + activity_task::{ActivityCancelReason, ActivityTask, Start, activity_task}, }, temporal::api::{ common::v1::WorkflowExecution, @@ -629,12 +629,13 @@ impl LocalActivityManager { }; // We want to generate a cancel task if the reason for failure was a timeout. let task = if is_timeout { - Some(ActivityTask { - task_token: task_token.clone().0, - variant: Some(activity_task::Variant::Cancel(Cancel { - reason: ActivityCancelReason::TimedOut as i32, - })), - }) + Some(ActivityTask::cancel_from_ids( + task_token.clone().0, + ActivityCancelReason::TimedOut, + ActivityTask::primary_reason_to_cancellation_details( + ActivityCancelReason::TimedOut, + ), + )) } else { None }; @@ -786,12 +787,13 @@ impl LocalActivityManager { } self.cancels_req_tx - .send(CancelOrTimeout::Cancel(ActivityTask { - task_token: lai.task_token.0.clone(), - variant: Some(activity_task::Variant::Cancel(Cancel { - reason: ActivityCancelReason::Cancelled as i32, - })), - })) + .send(CancelOrTimeout::Cancel(ActivityTask::cancel_from_ids( + lai.task_token.0.clone(), + ActivityCancelReason::Cancelled, + ActivityTask::primary_reason_to_cancellation_details( + ActivityCancelReason::Cancelled, + ), + ))) .expect("Receive half of LA cancel channel cannot be dropped"); None } diff --git a/sdk-core-protos/protos/local/temporal/sdk/core/activity_task/activity_task.proto b/sdk-core-protos/protos/local/temporal/sdk/core/activity_task/activity_task.proto index 88b955ae2..0a3a53a8f 100644 --- a/sdk-core-protos/protos/local/temporal/sdk/core/activity_task/activity_task.proto +++ b/sdk-core-protos/protos/local/temporal/sdk/core/activity_task/activity_task.proto @@ -67,7 +67,18 @@ message Start { // Attempt to cancel a running activity message Cancel { + // Primary cancellation reason ActivityCancelReason reason = 1; + // Activity cancellation details, surfaces all cancellation reasons. + ActivityCancellationDetails details = 2; +} + +message ActivityCancellationDetails { + bool is_not_found = 1; + bool is_cancelled = 2; + bool is_paused = 3; + bool is_timed_out = 4; + bool is_worker_shutdown = 5; } enum ActivityCancelReason { @@ -79,6 +90,8 @@ enum ActivityCancelReason { TIMED_OUT = 2; // Core is shutting down and the graceful timeout has elapsed WORKER_SHUTDOWN = 3; + // Activity was paused + PAUSED = 4; } diff --git a/sdk-core-protos/src/lib.rs b/sdk-core-protos/src/lib.rs index 0bb448bd0..beb5d5e66 100644 --- a/sdk-core-protos/src/lib.rs +++ b/sdk-core-protos/src/lib.rs @@ -69,23 +69,42 @@ pub mod coresdk { tonic::include_proto!("coresdk.activity_task"); impl ActivityTask { - pub fn cancel_from_ids(task_token: Vec, reason: ActivityCancelReason) -> Self { + pub fn cancel_from_ids( + task_token: Vec, + reason: ActivityCancelReason, + details: ActivityCancellationDetails, + ) -> Self { Self { task_token, variant: Some(activity_task::Variant::Cancel(Cancel { reason: reason as i32, + details: Some(details), })), } } + // Checks if both the primary reason or details have a timeout cancellation. pub fn is_timeout(&self) -> bool { match &self.variant { - Some(activity_task::Variant::Cancel(Cancel { reason })) => { + Some(activity_task::Variant::Cancel(Cancel { reason, details })) => { *reason == ActivityCancelReason::TimedOut as i32 + || details.as_ref().is_some_and(|d| d.is_timed_out) } _ => false, } } + + pub fn primary_reason_to_cancellation_details( + reason: ActivityCancelReason, + ) -> ActivityCancellationDetails { + ActivityCancellationDetails { + is_not_found: reason == ActivityCancelReason::NotFound, + is_cancelled: reason == ActivityCancelReason::Cancelled, + is_paused: reason == ActivityCancelReason::Paused, + is_timed_out: reason == ActivityCancelReason::TimedOut, + is_worker_shutdown: reason == ActivityCancelReason::WorkerShutdown, + } + } } impl Display for ActivityTaskCompletion {