@@ -145,6 +145,12 @@ type (
145145 cache * WorkerCache
146146 deadlockDetectionTimeout time.Duration
147147 capabilities * workflowservice.GetSystemInfoResponse_Capabilities
148+ workerControlTaskQueue string
149+ }
150+
151+ activityCancellationCallbacks struct {
152+ sync.Mutex
153+ cancels map [string ]context.CancelCauseFunc
148154 }
149155
150156 activityProvider func (name string ) activity
@@ -172,6 +178,7 @@ type (
172178 inboundPayloadVisitor PayloadVisitor
173179 outboundPayloadVisitor PayloadVisitor
174180 payloadVisitorConcurrency int
181+ activityCancellationCallbacks * activityCancellationCallbacks
175182 }
176183
177184 // history wrapper method to help information about events.
@@ -587,6 +594,7 @@ func newWorkflowTaskHandler(params workerExecutionParameters, ppMgr pressurePoin
587594 cache : params .cache ,
588595 deadlockDetectionTimeout : params .DeadlockDetectionTimeout ,
589596 capabilities : params .capabilities ,
597+ workerControlTaskQueue : params .workerControlTaskQueue ,
590598 }
591599}
592600
@@ -2007,6 +2015,7 @@ func (wth *workflowTaskHandlerImpl) completeWorkflow(
20072015 wth .useBuildIDForVersioning ,
20082016 wth .workerDeploymentVersion ,
20092017 ),
2018+ WorkerControlTaskQueue : wth .workerControlTaskQueue ,
20102019 }
20112020 if wth .capabilities != nil && wth .capabilities .BuildIdBasedVersioning {
20122021 //lint:ignore SA1019 ignore deprecated versioning APIs
@@ -2100,12 +2109,40 @@ func newActivityTaskHandlerWithCustomProvider(
21002109 params .UseBuildIDForVersioning ,
21012110 params .DeploymentOptions .Version ,
21022111 ),
2103- inboundPayloadVisitor : params .inboundPayloadVisitor ,
2104- outboundPayloadVisitor : params .outboundPayloadVisitor ,
2105- payloadVisitorConcurrency : params .payloadVisitorConcurrency ,
2112+ inboundPayloadVisitor : params .inboundPayloadVisitor ,
2113+ outboundPayloadVisitor : params .outboundPayloadVisitor ,
2114+ payloadVisitorConcurrency : params .payloadVisitorConcurrency ,
2115+ activityCancellationCallbacks : params .activityCancellationCallbacks ,
2116+ }
2117+ }
2118+
2119+ func newActivityCancellationCallbacks () * activityCancellationCallbacks {
2120+ return & activityCancellationCallbacks {cancels : make (map [string ]context.CancelCauseFunc )}
2121+ }
2122+
2123+ func (r * activityCancellationCallbacks ) register (taskToken []byte , cancel context.CancelCauseFunc ) func () {
2124+ key := string (taskToken )
2125+ r .Lock ()
2126+ r .cancels [key ] = cancel
2127+ r .Unlock ()
2128+ return func () {
2129+ r .Lock ()
2130+ delete (r .cancels , key )
2131+ r .Unlock ()
21062132 }
21072133}
21082134
2135+ func (r * activityCancellationCallbacks ) cancel (taskToken []byte ) bool {
2136+ r .Lock ()
2137+ cancel , ok := r .cancels [string (taskToken )]
2138+ r .Unlock ()
2139+ if ! ok {
2140+ return false
2141+ }
2142+ cancel (NewCanceledError ())
2143+ return true
2144+ }
2145+
21092146// heartbeatVisitorError wraps an outbound payload visitor error from a heartbeat.
21102147// It is used as the context cancellation cause so Execute() can detect that
21112148// RespondActivityTaskFailed was already sent proactively and skip sending a second response.
@@ -2120,7 +2157,8 @@ type temporalInvoker struct {
21202157 service workflowservice.WorkflowServiceClient
21212158 metricsHandler metrics.Handler
21222159 taskToken []byte
2123- // cancelHandler is called when the activity is canceled by a heartbeat request.
2160+ // cancelHandler is called when the activity is canceled by a heartbeat response
2161+ // or worker command.
21242162 cancelHandler context.CancelCauseFunc
21252163 // Amount of time to wait between each pending heartbeat send
21262164 heartbeatThrottleInterval time.Duration
@@ -2351,6 +2389,10 @@ func (ath *activityTaskHandlerImpl) Execute(taskQueue string, t *workflowservice
23512389 }
23522390 canCtx , cancel := context .WithCancelCause (rootCtx )
23532391 defer cancel (nil )
2392+ if ath .activityCancellationCallbacks != nil {
2393+ unregister := ath .activityCancellationCallbacks .register (t .TaskToken , cancel )
2394+ defer unregister ()
2395+ }
23542396
23552397 if err := visitProtoPayloads (canCtx , ath .inboundPayloadVisitor , t , ath .payloadVisitorConcurrency ); err != nil {
23562398 return ath .visitorErrorToActivityFailure ("Activity task preprocess error: " , t , err ), nil
0 commit comments