Skip to content

Commit d813282

Browse files
authored
Implement nexus-delivered activity cancels (temporalio#2394)
1 parent 8497711 commit d813282

11 files changed

Lines changed: 592 additions & 45 deletions

activity/doc.go

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,8 @@ The first parameter to an activity function can be an optional context.Context.
5656
* The activity function returns.
5757
* The context deadline is exceeded. The deadline is calculated based on the minimum of the ScheduleToClose timeout plus
5858
the activity task scheduled time and the StartToClose timeout plus the activity task start time.
59-
* The activity calls RecordHeartbeat after being cancelled by the Temporal server.
59+
* The Temporal server requests cancellation. On supported servers this can be delivered directly to the worker; on older
60+
servers the activity observes cancellation when it calls RecordHeartbeat after the request.
6061
6162
# Failing the Activity
6263
@@ -97,7 +98,8 @@ payload containing progress information.
9798
9899
When an Activity is canceled (or its Workflow execution is completed or failed) the context passed into its function
99100
is canceled which sets its Done channel’s closed state. So an Activity can use that to perform any necessary cleanup
100-
and abort its execution. Currently cancellation is delivered only to Activities that call RecordHeartbeat.
101+
and abort its execution. On supported servers, cancellation can be delivered directly to the worker. Heartbeats are still
102+
recommended for long-running Activities and may be the cancellation delivery path on older servers.
101103
102104
# Async/Manual Activity Completion
103105

internal/activity.go

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -289,10 +289,9 @@ func GetWorkerStopChannel(ctx context.Context) <-chan struct{} {
289289
// If the activity is either canceled or workflow/activity doesn't exist, then we would cancel
290290
// the context with error context.Canceled.
291291
//
292-
// TODO: Implement automatic heartbeating with cancellation through ctx.
293-
//
294292
// details - The details that you provided here can be seen in the workflow when it receives TimeoutError. You
295-
// can check error TimeoutType()/Details().
293+
// can check error TimeoutType()/Details(). Heartbeat responses may also deliver server requests such as activity
294+
// cancellation, pause, and reset to the activity context.
296295
//
297296
// Exposed as: [go.temporal.io/sdk/activity.RecordHeartbeat]
298297
func RecordActivityHeartbeat(ctx context.Context, details ...interface{}) {

internal/internal_task_handlers.go

Lines changed: 46 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,12 @@ type (
145145
cache *WorkerCache
146146
deadlockDetectionTimeout time.Duration
147147
capabilities *workflowservice.GetSystemInfoResponse_Capabilities
148+
workerControlTaskQueue string
149+
}
150+
151+
activityCancellationCallbacks struct {
152+
sync.Mutex
153+
cancels map[string]context.CancelCauseFunc
148154
}
149155

150156
activityProvider func(name string) activity
@@ -172,6 +178,7 @@ type (
172178
inboundPayloadVisitor PayloadVisitor
173179
outboundPayloadVisitor PayloadVisitor
174180
payloadVisitorConcurrency int
181+
activityCancellationCallbacks *activityCancellationCallbacks
175182
}
176183

177184
// history wrapper method to help information about events.
@@ -587,6 +594,7 @@ func newWorkflowTaskHandler(params workerExecutionParameters, ppMgr pressurePoin
587594
cache: params.cache,
588595
deadlockDetectionTimeout: params.DeadlockDetectionTimeout,
589596
capabilities: params.capabilities,
597+
workerControlTaskQueue: params.workerControlTaskQueue,
590598
}
591599
}
592600

@@ -2007,6 +2015,7 @@ func (wth *workflowTaskHandlerImpl) completeWorkflow(
20072015
wth.useBuildIDForVersioning,
20082016
wth.workerDeploymentVersion,
20092017
),
2018+
WorkerControlTaskQueue: wth.workerControlTaskQueue,
20102019
}
20112020
if wth.capabilities != nil && wth.capabilities.BuildIdBasedVersioning {
20122021
//lint:ignore SA1019 ignore deprecated versioning APIs
@@ -2100,12 +2109,40 @@ func newActivityTaskHandlerWithCustomProvider(
21002109
params.UseBuildIDForVersioning,
21012110
params.DeploymentOptions.Version,
21022111
),
2103-
inboundPayloadVisitor: params.inboundPayloadVisitor,
2104-
outboundPayloadVisitor: params.outboundPayloadVisitor,
2105-
payloadVisitorConcurrency: params.payloadVisitorConcurrency,
2112+
inboundPayloadVisitor: params.inboundPayloadVisitor,
2113+
outboundPayloadVisitor: params.outboundPayloadVisitor,
2114+
payloadVisitorConcurrency: params.payloadVisitorConcurrency,
2115+
activityCancellationCallbacks: params.activityCancellationCallbacks,
2116+
}
2117+
}
2118+
2119+
func newActivityCancellationCallbacks() *activityCancellationCallbacks {
2120+
return &activityCancellationCallbacks{cancels: make(map[string]context.CancelCauseFunc)}
2121+
}
2122+
2123+
func (r *activityCancellationCallbacks) register(taskToken []byte, cancel context.CancelCauseFunc) func() {
2124+
key := string(taskToken)
2125+
r.Lock()
2126+
r.cancels[key] = cancel
2127+
r.Unlock()
2128+
return func() {
2129+
r.Lock()
2130+
delete(r.cancels, key)
2131+
r.Unlock()
21062132
}
21072133
}
21082134

2135+
func (r *activityCancellationCallbacks) cancel(taskToken []byte) bool {
2136+
r.Lock()
2137+
cancel, ok := r.cancels[string(taskToken)]
2138+
r.Unlock()
2139+
if !ok {
2140+
return false
2141+
}
2142+
cancel(NewCanceledError())
2143+
return true
2144+
}
2145+
21092146
// heartbeatVisitorError wraps an outbound payload visitor error from a heartbeat.
21102147
// It is used as the context cancellation cause so Execute() can detect that
21112148
// RespondActivityTaskFailed was already sent proactively and skip sending a second response.
@@ -2120,7 +2157,8 @@ type temporalInvoker struct {
21202157
service workflowservice.WorkflowServiceClient
21212158
metricsHandler metrics.Handler
21222159
taskToken []byte
2123-
// cancelHandler is called when the activity is canceled by a heartbeat request.
2160+
// cancelHandler is called when the activity is canceled by a heartbeat response
2161+
// or worker command.
21242162
cancelHandler context.CancelCauseFunc
21252163
// Amount of time to wait between each pending heartbeat send
21262164
heartbeatThrottleInterval time.Duration
@@ -2351,6 +2389,10 @@ func (ath *activityTaskHandlerImpl) Execute(taskQueue string, t *workflowservice
23512389
}
23522390
canCtx, cancel := context.WithCancelCause(rootCtx)
23532391
defer cancel(nil)
2392+
if ath.activityCancellationCallbacks != nil {
2393+
unregister := ath.activityCancellationCallbacks.register(t.TaskToken, cancel)
2394+
defer unregister()
2395+
}
23542396

23552397
if err := visitProtoPayloads(canCtx, ath.inboundPayloadVisitor, t, ath.payloadVisitorConcurrency); err != nil {
23562398
return ath.visitorErrorToActivityFailure("Activity task preprocess error: ", t, err), nil

internal/internal_task_handlers_test.go

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,19 @@ func TestTaskHandlersTestSuite(t *testing.T) {
135135
})
136136
}
137137

138+
func TestActivityCancellationCallbacksCancel(t *testing.T) {
139+
registry := newActivityCancellationCallbacks()
140+
taskToken := []byte{1, 2, 3}
141+
ctx, cancel := context.WithCancelCause(context.Background())
142+
143+
unregister := registry.register(taskToken, cancel)
144+
require.True(t, registry.cancel([]byte{1, 2, 3}))
145+
require.True(t, IsCanceledError(context.Cause(ctx)))
146+
147+
unregister()
148+
require.False(t, registry.cancel(taskToken))
149+
}
150+
138151
func createTestEventWorkflowExecutionCompleted(eventID int64, attr *historypb.WorkflowExecutionCompletedEventAttributes) *historypb.HistoryEvent {
139152
return &historypb.HistoryEvent{
140153
EventId: eventID,

internal/internal_task_pollers.go

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,8 @@ type (
8989
pollTimeTracker *pollTimeTracker
9090
// Unique identifier for worker
9191
workerInstanceKey string
92+
// Per-client queue used by the server to send worker commands.
93+
workerControlTaskQueue string
9294
// Server cancels polls on shutdown
9395
workerPollCompleteOnShutdown *atomic.Bool
9496
}
@@ -369,6 +371,7 @@ func newWorkflowTaskProcessor(
369371
capabilities: params.capabilities,
370372
pollTimeTracker: params.pollTimeTracker,
371373
workerInstanceKey: params.workerInstanceKey,
374+
workerControlTaskQueue: params.workerControlTaskQueue,
372375
workerPollCompleteOnShutdown: params.workerPollCompleteOnShutdown,
373376
},
374377
service: service,
@@ -1151,7 +1154,8 @@ func (wtp *workflowTaskPoller) getNextPollRequest() (request *workflowservice.Po
11511154
wtp.useBuildIDVersioning,
11521155
wtp.workerDeploymentVersion,
11531156
),
1154-
WorkerInstanceKey: wtp.workerInstanceKey,
1157+
WorkerInstanceKey: wtp.workerInstanceKey,
1158+
WorkerControlTaskQueue: wtp.workerControlTaskQueue,
11551159
}
11561160
if wtp.getCapabilities().BuildIdBasedVersioning {
11571161
//lint:ignore SA1019 ignore deprecated versioning APIs
@@ -1368,6 +1372,7 @@ func newActivityTaskPoller(taskHandler ActivityTaskHandler, service workflowserv
13681372
capabilities: params.capabilities,
13691373
pollTimeTracker: params.pollTimeTracker,
13701374
workerInstanceKey: params.workerInstanceKey,
1375+
workerControlTaskQueue: params.workerControlTaskQueue,
13711376
workerPollCompleteOnShutdown: params.workerPollCompleteOnShutdown,
13721377
},
13731378
taskHandler: taskHandler,
@@ -1408,7 +1413,8 @@ func (atp *activityTaskPoller) poll(ctx context.Context) (taskForWorker, error)
14081413
atp.useBuildIDVersioning,
14091414
atp.workerDeploymentVersion,
14101415
),
1411-
WorkerInstanceKey: atp.workerInstanceKey,
1416+
WorkerInstanceKey: atp.workerInstanceKey,
1417+
WorkerControlTaskQueue: atp.workerControlTaskQueue,
14121418
}
14131419

14141420
response, err := atp.pollActivityTaskQueue(ctx, request)

internal/internal_task_pollers_test.go

Lines changed: 58 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,12 @@ import (
44
"context"
55
"encoding/binary"
66
"errors"
7-
"github.com/google/uuid"
87
"sync/atomic"
98
"testing"
109
"time"
1110

1211
"github.com/golang/mock/gomock"
12+
"github.com/google/uuid"
1313
"github.com/stretchr/testify/require"
1414
commonpb "go.temporal.io/api/common/v1"
1515
enumspb "go.temporal.io/api/enums/v1"
@@ -40,6 +40,63 @@ func (wth *countingTaskHandler) ProcessWorkflowTask(
4040
return wth.WorkflowTaskHandler.ProcessWorkflowTask(task, wfctx, hb)
4141
}
4242

43+
func TestPollRequestsIncludeWorkerControlTaskQueue(t *testing.T) {
44+
t.Parallel()
45+
46+
ctrl := gomock.NewController(t)
47+
service := workflowservicemock.NewMockWorkflowServiceClient(ctrl)
48+
const (
49+
namespace = "test-ns"
50+
taskQueue = "test-task-queue"
51+
identity = "test-worker"
52+
controlQueue = "temporal-sys/worker-commands/test-ns/grouping-key"
53+
)
54+
55+
service.EXPECT().PollWorkflowTaskQueue(gomock.Any(), gomock.Any(), gomock.Any()).
56+
DoAndReturn(func(_ context.Context, req *workflowservice.PollWorkflowTaskQueueRequest, _ ...grpc.CallOption) (*workflowservice.PollWorkflowTaskQueueResponse, error) {
57+
require.Equal(t, controlQueue, req.WorkerControlTaskQueue)
58+
require.Equal(t, taskQueue, req.TaskQueue.GetName())
59+
return &workflowservice.PollWorkflowTaskQueueResponse{}, nil
60+
})
61+
62+
base := basePoller{
63+
metricsHandler: metrics.NopHandler,
64+
workerBuildID: "test-build-id",
65+
workerControlTaskQueue: controlQueue,
66+
}
67+
wtp := &workflowTaskPoller{
68+
basePoller: base,
69+
mode: NonSticky,
70+
namespace: namespace,
71+
taskQueueName: taskQueue,
72+
identity: identity,
73+
service: service,
74+
logger: ilog.NewDefaultLogger(),
75+
numNormalPollerMetric: newNumPollerMetric(metrics.NopHandler, metrics.PollerTypeWorkflowTask),
76+
}
77+
_, err := wtp.poll(context.Background())
78+
require.NoError(t, err)
79+
80+
service.EXPECT().PollActivityTaskQueue(gomock.Any(), gomock.Any(), gomock.Any()).
81+
DoAndReturn(func(_ context.Context, req *workflowservice.PollActivityTaskQueueRequest, _ ...grpc.CallOption) (*workflowservice.PollActivityTaskQueueResponse, error) {
82+
require.Equal(t, controlQueue, req.WorkerControlTaskQueue)
83+
require.Equal(t, taskQueue, req.TaskQueue.GetName())
84+
return &workflowservice.PollActivityTaskQueueResponse{}, nil
85+
})
86+
87+
atp := &activityTaskPoller{
88+
basePoller: base,
89+
namespace: namespace,
90+
taskQueueName: taskQueue,
91+
identity: identity,
92+
service: service,
93+
logger: ilog.NewDefaultLogger(),
94+
numPollerMetric: newNumPollerMetric(metrics.NopHandler, metrics.PollerTypeActivityTask),
95+
}
96+
_, err = atp.poll(context.Background())
97+
require.NoError(t, err)
98+
}
99+
43100
func TestWFTRacePrevention(t *testing.T) {
44101
params := workerExecutionParameters{cache: NewWorkerCache()}
45102
ensureRequiredParams(&params)

internal/internal_worker.go

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,10 @@ type (
228228

229229
workerInstanceKey string
230230

231+
workerControlTaskQueue string
232+
233+
activityCancellationCallbacks *activityCancellationCallbacks
234+
231235
workerPollCompleteOnShutdown *atomic.Bool
232236

233237
// Set to true during start() when the namespace has the poller_autoscaling capability.
@@ -314,6 +318,10 @@ func (params *workerExecutionParameters) isInternalWorker() bool {
314318
return params.Namespace == "temporal-system" || params.TaskQueue == "temporal-sys-per-ns-tq"
315319
}
316320

321+
func workerControlTaskQueue(namespace, groupingKey string) string {
322+
return fmt.Sprintf("temporal-sys/worker-commands/%s/%s", namespace, groupingKey)
323+
}
324+
317325
func newWorkflowWorkerInternal(client *WorkflowClient, params workerExecutionParameters, ppMgr pressurePointMgr, overrides *workerOverrides, registry *registry) *workflowWorker {
318326
workerStopChannel := make(chan struct{})
319327
params.WorkerStopChannel = getReadOnlyChannel(workerStopChannel)
@@ -1224,7 +1232,7 @@ type AggregatedWorker struct {
12241232
// stopC is created in NewAggregatedWorker and closed by AggregatedWorker.Stop()
12251233
// to mark the aggregated worker stopped, unblock Run(), and prevent restart.
12261234
// Child worker stop channels are closed later by their own Stop methods.
1227-
stopC chan struct{}
1235+
stopC chan struct{}
12281236
fatalErr error
12291237
fatalErrLock sync.Mutex
12301238
capabilities *workflowservice.GetSystemInfoResponse_Capabilities
@@ -2269,6 +2277,12 @@ func NewAggregatedWorker(client *WorkflowClient, taskQueue string, options Worke
22692277
// All worker systems that depend on the capabilities to process workflow/activity tasks
22702278
// should take a pointer to this struct and wait for it to be populated when the worker is run.
22712279
var capabilities workflowservice.GetSystemInfoResponse_Capabilities
2280+
var activityCancellationCallbacks *activityCancellationCallbacks
2281+
if client.heartbeatManager != nil {
2282+
activityCancellationCallbacks = client.heartbeatManager.
2283+
sharedNamespaceWorkerFor(client.namespace).
2284+
activityCancellationCallbacks
2285+
}
22722286

22732287
baseMetricsHandler := client.metricsHandler.WithTags(metrics.TaskQueueTags(taskQueue))
22742288
var metricsHandler metrics.Handler
@@ -2338,12 +2352,14 @@ func NewAggregatedWorker(client *WorkflowClient, taskQueue string, options Worke
23382352
taskQueue: taskQueue,
23392353
maxConcurrent: options.MaxConcurrentEagerActivityExecutionSize,
23402354
}),
2341-
capabilities: &capabilities,
2342-
pollTimeTracker: &pollTimeTracker{},
2343-
workerInstanceKey: workerInstanceKey,
2344-
workerPollCompleteOnShutdown: workerPollCompleteOnShutdown,
2345-
serverSupportsAutoscaling: &atomic.Bool{},
2346-
inboundPayloadVisitor: extstore.NewExternalRetrievalVisitor(client.storageParams),
2355+
capabilities: &capabilities,
2356+
pollTimeTracker: &pollTimeTracker{},
2357+
workerInstanceKey: workerInstanceKey,
2358+
workerControlTaskQueue: workerControlTaskQueue(client.namespace, client.workerGroupingKey),
2359+
activityCancellationCallbacks: activityCancellationCallbacks,
2360+
workerPollCompleteOnShutdown: workerPollCompleteOnShutdown,
2361+
serverSupportsAutoscaling: &atomic.Bool{},
2362+
inboundPayloadVisitor: extstore.NewExternalRetrievalVisitor(client.storageParams),
23472363
outboundPayloadVisitor: newCompositePayloadVisitor(
23482364
extstore.NewExternalStorageVisitor(client.storageParams),
23492365
payloadLimitVisitor,

0 commit comments

Comments
 (0)