Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 11 additions & 11 deletions internal/activity_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ func (s *activityTestSuite) TearDownTest() {
func (s *activityTestSuite) TestActivityHeartbeat() {
ctx, cancel := context.WithCancelCause(context.Background())
invoker := newServiceInvoker([]byte("task-token"), "identity", s.service, metrics.NopHandler, cancel,
1*time.Second, make(chan struct{}), s.namespace, &atomic.Bool{})
1*time.Second, make(chan struct{}), s.namespace, &atomic.Bool{}, nil, nil)
ctx, _ = newActivityContext(ctx, nil, &activityEnvironment{serviceInvoker: invoker})

s.service.EXPECT().RecordActivityTaskHeartbeat(gomock.Any(), gomock.Any(), gomock.Any()).
Expand All @@ -55,7 +55,7 @@ func (s *activityTestSuite) TestActivityHeartbeat() {
func (s *activityTestSuite) TestActivityHeartbeat_InternalError() {
ctx, cancel := context.WithCancelCause(context.Background())
invoker := newServiceInvoker([]byte("task-token"), "identity", s.service, metrics.NopHandler, cancel,
1*time.Second, make(chan struct{}), s.namespace, &atomic.Bool{})
1*time.Second, make(chan struct{}), s.namespace, &atomic.Bool{}, nil, nil)
ctx, _ = newActivityContext(ctx, nil, &activityEnvironment{
serviceInvoker: invoker,
logger: getLogger()})
Expand All @@ -72,7 +72,7 @@ func (s *activityTestSuite) TestActivityHeartbeat_InternalError() {
func (s *activityTestSuite) TestActivityHeartbeat_CancelRequested() {
ctx, cancel := context.WithCancelCause(context.Background())
invoker := newServiceInvoker([]byte("task-token"), "identity", s.service, metrics.NopHandler, cancel,
1*time.Second, make(chan struct{}), s.namespace, &atomic.Bool{})
1*time.Second, make(chan struct{}), s.namespace, &atomic.Bool{}, nil, nil)
ctx, _ = newActivityContext(ctx, nil, &activityEnvironment{
serviceInvoker: invoker,
logger: getLogger()})
Expand All @@ -88,7 +88,7 @@ func (s *activityTestSuite) TestActivityHeartbeat_CancelRequested() {
func (s *activityTestSuite) TestActivityHeartbeat_PauseRequested() {
ctx, cancel := context.WithCancelCause(context.Background())
invoker := newServiceInvoker([]byte("task-token"), "identity", s.service, metrics.NopHandler, cancel,
1*time.Second, make(chan struct{}), s.namespace, &atomic.Bool{})
1*time.Second, make(chan struct{}), s.namespace, &atomic.Bool{}, nil, nil)
ctx, _ = newActivityContext(ctx, nil, &activityEnvironment{
serviceInvoker: invoker,
logger: getLogger()})
Expand All @@ -105,7 +105,7 @@ func (s *activityTestSuite) TestActivityHeartbeat_PauseRequested() {
func (s *activityTestSuite) TestActivityHeartbeat_ResetRequested() {
ctx, cancel := context.WithCancelCause(context.Background())
invoker := newServiceInvoker([]byte("task-token"), "identity", s.service, metrics.NopHandler, cancel,
1*time.Second, make(chan struct{}), s.namespace, &atomic.Bool{})
1*time.Second, make(chan struct{}), s.namespace, &atomic.Bool{}, nil, nil)
ctx, _ = newActivityContext(ctx, nil, &activityEnvironment{
serviceInvoker: invoker,
logger: getLogger()})
Expand All @@ -122,7 +122,7 @@ func (s *activityTestSuite) TestActivityHeartbeat_ResetRequested() {
func (s *activityTestSuite) TestActivityHeartbeat_EntityNotExist() {
ctx, cancel := context.WithCancelCause(context.Background())
invoker := newServiceInvoker([]byte("task-token"), "identity", s.service, metrics.NopHandler, cancel,
1*time.Second, make(chan struct{}), s.namespace, &atomic.Bool{})
1*time.Second, make(chan struct{}), s.namespace, &atomic.Bool{}, nil, nil)
ctx, _ = newActivityContext(ctx, nil, &activityEnvironment{
serviceInvoker: invoker,
logger: getLogger()})
Expand All @@ -138,7 +138,7 @@ func (s *activityTestSuite) TestActivityHeartbeat_EntityNotExist() {
func (s *activityTestSuite) TestActivityHeartbeat_SuppressContinousInvokes() {
ctx, cancel := context.WithCancelCause(context.Background())
invoker := newServiceInvoker([]byte("task-token"), "identity", s.service, metrics.NopHandler, cancel,
2*time.Second, make(chan struct{}), s.namespace, &atomic.Bool{})
2*time.Second, make(chan struct{}), s.namespace, &atomic.Bool{}, nil, nil)
ctx, _ = newActivityContext(ctx, nil, &activityEnvironment{
serviceInvoker: invoker,
logger: getLogger()})
Expand All @@ -154,7 +154,7 @@ func (s *activityTestSuite) TestActivityHeartbeat_SuppressContinousInvokes() {
// High HB timeout configured.
service2 := workflowservicemock.NewMockWorkflowServiceClient(s.mockCtrl)
invoker2 := newServiceInvoker([]byte("task-token"), "identity", service2, metrics.NopHandler, cancel,
20*time.Second, make(chan struct{}), s.namespace, &atomic.Bool{})
20*time.Second, make(chan struct{}), s.namespace, &atomic.Bool{}, nil, nil)
ctx, _ = newActivityContext(ctx, nil, &activityEnvironment{
serviceInvoker: invoker2,
logger: getLogger()})
Expand All @@ -168,7 +168,7 @@ func (s *activityTestSuite) TestActivityHeartbeat_SuppressContinousInvokes() {
waitCh := make(chan struct{})
service3 := workflowservicemock.NewMockWorkflowServiceClient(s.mockCtrl)
invoker3 := newServiceInvoker([]byte("task-token"), "identity", service3, metrics.NopHandler, cancel,
2*time.Second, make(chan struct{}), s.namespace, &atomic.Bool{})
2*time.Second, make(chan struct{}), s.namespace, &atomic.Bool{}, nil, nil)
ctx, _ = newActivityContext(ctx, nil, &activityEnvironment{
serviceInvoker: invoker3,
logger: getLogger()})
Expand Down Expand Up @@ -199,7 +199,7 @@ func (s *activityTestSuite) TestActivityHeartbeat_SuppressContinousInvokes() {
waitCh2 := make(chan struct{})
service4 := workflowservicemock.NewMockWorkflowServiceClient(s.mockCtrl)
invoker4 := newServiceInvoker([]byte("task-token"), "identity", service4, metrics.NopHandler, cancel,
2*time.Second, make(chan struct{}), s.namespace, &atomic.Bool{})
2*time.Second, make(chan struct{}), s.namespace, &atomic.Bool{}, nil, nil)
ctx, _ = newActivityContext(ctx, nil, &activityEnvironment{
serviceInvoker: invoker4,
logger: getLogger()})
Expand All @@ -224,7 +224,7 @@ func (s *activityTestSuite) TestActivityHeartbeat_WorkerStop() {
ctx, cancel := context.WithCancelCause(context.Background())
workerStopChannel := make(chan struct{})
invoker := newServiceInvoker([]byte("task-token"), "identity", s.service, metrics.NopHandler, cancel,
5*time.Second, workerStopChannel, s.namespace, &atomic.Bool{})
5*time.Second, workerStopChannel, s.namespace, &atomic.Bool{}, nil, nil)
ctx, _ = newActivityContext(ctx, nil, &activityEnvironment{serviceInvoker: invoker})

heartBeatDetail := "testDetails"
Expand Down
70 changes: 49 additions & 21 deletions internal/internal_task_handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -2097,6 +2097,14 @@ func newActivityTaskHandlerWithCustomProvider(
}
}

// heartbeatVisitorError marks a heartbeat failure that came from the outbound
// payload visitor. It is handed to the cancel handler as the context cancellation
// cause, which lets Execute() recognise that RespondActivityTaskFailed has already
// been sent proactively and avoid sending a second response.
type heartbeatVisitorError struct {
	err error
}

// Error reports the message of the wrapped visitor error.
func (hve heartbeatVisitorError) Error() string {
	return hve.err.Error()
}

// Unwrap exposes the wrapped visitor error to errors.Is and errors.As.
func (hve heartbeatVisitorError) Unwrap() error {
	return hve.err
}

type temporalInvoker struct {
sync.Mutex
identity string
Expand All @@ -2113,6 +2121,8 @@ type temporalInvoker struct {
workerStopChannel <-chan struct{}
namespace string
excludeInternalFromRetry *atomic.Bool // borrowed from client in order to tell if internal errors are retriable
outboundPayloadVisitor PayloadVisitor
failureConverter converter.FailureConverter
}

func (i *temporalInvoker) Heartbeat(ctx context.Context, details *commonpb.Payloads, skipBatching bool) error {
Expand Down Expand Up @@ -2183,7 +2193,31 @@ func (i *temporalInvoker) internalHeartBeat(ctx context.Context, details *common
ctx, cancel := context.WithTimeout(ctx, recordTimeout)
defer cancel()

err := recordActivityHeartbeat(ctx, i.service, i.metricsHandler, i.identity, i.taskToken, details)
request := &workflowservice.RecordActivityTaskHeartbeatRequest{
TaskToken: i.taskToken,
Details: details,
Identity: i.identity,
Namespace: i.namespace,
}
var err error
if visitErr := visitProtoPayloads(ctx, i.outboundPayloadVisitor, request, 0); visitErr != nil {
// Proactively fail the task so the server can retry immediately rather than
// waiting for the heartbeat timeout. Errors are ignored — if the RPC fails the
// activity will still be timed out by the server.
Comment thread
jmaeagle99 marked this conversation as resolved.
failReq := &workflowservice.RespondActivityTaskFailedRequest{
TaskToken: i.taskToken,
Failure: i.failureConverter.ErrorToFailure(visitErr),
Identity: i.identity,
Namespace: i.namespace,
}
failCtx, failCancel := context.WithTimeout(context.Background(), recordTimeout)
defer failCancel()
_, _ = i.service.RespondActivityTaskFailed(failCtx, failReq)
err = heartbeatVisitorError{visitErr}
i.cancelHandler(err)
} else {
err = recordActivityHeartbeat(ctx, i.service, i.metricsHandler, request)
}

switch err.(type) {
case *CanceledError:
Expand Down Expand Up @@ -2248,6 +2282,8 @@ func newServiceInvoker(
workerStopChannel <-chan struct{},
namespace string,
excludeInternalFromRetry *atomic.Bool,
outboundPayloadVisitor PayloadVisitor,
failureConverter converter.FailureConverter,
) ServiceInvoker {
return &temporalInvoker{
taskToken: taskToken,
Expand All @@ -2260,6 +2296,8 @@ func newServiceInvoker(
workerStopChannel: workerStopChannel,
namespace: namespace,
excludeInternalFromRetry: excludeInternalFromRetry,
outboundPayloadVisitor: outboundPayloadVisitor,
failureConverter: failureConverter,
}
}

Expand Down Expand Up @@ -2308,7 +2346,7 @@ func (ath *activityTaskHandlerImpl) Execute(taskQueue string, t *workflowservice
heartbeatThrottleInterval := ath.getHeartbeatThrottleInterval(t.GetHeartbeatTimeout().AsDuration())
invoker := newServiceInvoker(
t.TaskToken, ath.identity, ath.client.workflowService, ath.metricsHandler, cancel, heartbeatThrottleInterval,
ath.workerStopCh, ath.namespace, ath.client.excludeInternalFromRetry)
ath.workerStopCh, ath.namespace, ath.client.excludeInternalFromRetry, ath.outboundPayloadVisitor, failureConverter)

workflowType := t.WorkflowType.GetName()
activityType := t.ActivityType.GetName()
Expand Down Expand Up @@ -2368,6 +2406,13 @@ func (ath *activityTaskHandlerImpl) Execute(taskQueue string, t *workflowservice
output, err := activityImplementation.Execute(ctx, t.Input)
// Check if context canceled at a higher level before we cancel it ourselves

// The heartbeat visitor failure path has already sent RespondActivityTaskFailed
// proactively, so skip sending another response regardless of what the activity returned.
var hbVisitorErr heartbeatVisitorError
if errors.As(context.Cause(canCtx), &hbVisitorErr) {
return nil, nil
}

// Cancels that don't originate from the server will have separate cancel reasons, like
// ErrWorkerShutdown or ErrActivityPaused
isActivityCanceled := ctx.Err() == context.Canceled && IsCanceledError(context.Cause(ctx))
Expand Down Expand Up @@ -2519,16 +2564,8 @@ func createNewCommandWithMetadata(commandType enumspb.CommandType, metadata *sdk
}

func recordActivityHeartbeat(ctx context.Context, service workflowservice.WorkflowServiceClient, metricsHandler metrics.Handler,
identity string, taskToken []byte, details *commonpb.Payloads,
request *workflowservice.RecordActivityTaskHeartbeatRequest,
) error {
namespace := getNamespaceFromActivityCtx(ctx)
request := &workflowservice.RecordActivityTaskHeartbeatRequest{
TaskToken: taskToken,
Details: details,
Identity: identity,
Namespace: namespace,
}

var heartbeatResponse *workflowservice.RecordActivityTaskHeartbeatResponse
grpcCtx, cancel := newGRPCContext(ctx,
grpcMetricsHandler(metricsHandler),
Expand All @@ -2549,17 +2586,8 @@ func recordActivityHeartbeat(ctx context.Context, service workflowservice.Workfl
}

func recordActivityHeartbeatByID(ctx context.Context, service workflowservice.WorkflowServiceClient, metricsHandler metrics.Handler,
identity, namespace, workflowID, runID, activityID string, details *commonpb.Payloads,
request *workflowservice.RecordActivityTaskHeartbeatByIdRequest,
) error {
request := &workflowservice.RecordActivityTaskHeartbeatByIdRequest{
Namespace: namespace,
WorkflowId: workflowID,
RunId: runID,
ActivityId: activityID,
Details: details,
Identity: identity,
}

var heartbeatResponse *workflowservice.RecordActivityTaskHeartbeatByIdResponse
grpcCtx, cancel := newGRPCContext(ctx,
grpcMetricsHandler(metricsHandler),
Expand Down
4 changes: 2 additions & 2 deletions internal/internal_task_handlers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1994,7 +1994,7 @@ func (t *TaskHandlersTestSuite) TestHeartBeat_NilResponseWithError() {

temporalInvoker := newServiceInvoker(
nil, "Test_Temporal_Invoker", mockService, metrics.NopHandler, func(err error) {}, 0,
make(chan struct{}), t.namespace, &atomic.Bool{})
make(chan struct{}), t.namespace, &atomic.Bool{}, nil, nil)

ctx, err := newActivityContext(context.Background(), nil, &activityEnvironment{serviceInvoker: temporalInvoker, logger: t.logger})
t.NoError(err)
Expand All @@ -2015,7 +2015,7 @@ func (t *TaskHandlersTestSuite) TestHeartBeat_NilResponseWithNamespaceNotActiveE

temporalInvoker := newServiceInvoker(
nil, "Test_Temporal_Invoker", mockService, metrics.NopHandler, cancelHandler,
0, make(chan struct{}), t.namespace, &atomic.Bool{})
0, make(chan struct{}), t.namespace, &atomic.Bool{}, nil, nil)

ctx, err := newActivityContext(context.Background(), nil, &activityEnvironment{serviceInvoker: temporalInvoker, logger: t.logger})
t.NoError(err)
Expand Down
8 changes: 0 additions & 8 deletions internal/internal_worker.go
Original file line number Diff line number Diff line change
Expand Up @@ -1147,14 +1147,6 @@ func getDataConverterFromActivityCtx(ctx context.Context) converter.DataConverte
return WithContext(ctx, dataConverter)
}

Comment thread
jmaeagle99 marked this conversation as resolved.
// getNamespaceFromActivityCtx returns the namespace stored on the activity
// environment attached to ctx, or the empty string when ctx carries no
// activity environment.
func getNamespaceFromActivityCtx(ctx context.Context) string {
	if activityEnv := getActivityEnvironmentFromCtx(ctx); activityEnv != nil {
		return activityEnv.namespace
	}
	return ""
}

func getActivityEnvironmentFromCtx(ctx context.Context) *activityEnvironment {
if ctx == nil || ctx.Value(activityEnvContextKey) == nil {
return nil
Expand Down
3 changes: 2 additions & 1 deletion internal/internal_worker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import (
"go.temporal.io/api/workflowservice/v1"
"go.temporal.io/api/workflowservicemock/v1"
"google.golang.org/grpc"
"google.golang.org/protobuf/proto"

"go.temporal.io/sdk/converter"
iconverter "go.temporal.io/sdk/internal/converter"
Expand Down Expand Up @@ -2344,7 +2345,7 @@ func (s *internalWorkerTestSuite) TestRecordActivityHeartbeatWithDataConverter()
s.service.EXPECT().RecordActivityTaskHeartbeat(gomock.Any(), gomock.Any(), gomock.Any()).Return(&heartbeatResponse, nil).
Do(func(ctx context.Context, request *workflowservice.RecordActivityTaskHeartbeatRequest, opts ...grpc.CallOption) {
heartbeatRequest = request
require.Equal(t, encodedDetail, request.Details)
require.True(t, proto.Equal(encodedDetail, request.Details), "details proto mismatch")
}).Times(1)

_ = wfClient.RecordActivityHeartbeat(context.Background(), nil, detail1, detail2, detail3)
Expand Down
24 changes: 22 additions & 2 deletions internal/internal_workflow_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -680,7 +680,16 @@ func (wc *WorkflowClient) RecordActivityHeartbeatWithOptions(ctx context.Context
if err != nil {
return err
}
return recordActivityHeartbeat(ctx, wc.workflowService, wc.metricsHandler, wc.identity, opts.TaskToken, data)
request := &workflowservice.RecordActivityTaskHeartbeatRequest{
TaskToken: opts.TaskToken,
Details: data,
Identity: wc.identity,
Namespace: cmp.Or(opts.Namespace, wc.namespace),
}
if err := visitProtoPayloads(ctx, wc.newOutboundPayloadVisitor(), request, 0); err != nil {
return err
}
return recordActivityHeartbeat(ctx, wc.workflowService, wc.metricsHandler, request)
}

// RecordActivityHeartbeatByID records heartbeat for an activity.
Expand Down Expand Up @@ -715,7 +724,18 @@ func (wc *WorkflowClient) RecordActivityHeartbeatByIDWithOptions(ctx context.Con
if err != nil {
return err
}
return recordActivityHeartbeatByID(ctx, wc.workflowService, wc.metricsHandler, wc.identity, opts.Namespace, opts.WorkflowID, opts.RunID, opts.ActivityID, data)
byIDRequest := &workflowservice.RecordActivityTaskHeartbeatByIdRequest{
Namespace: cmp.Or(opts.Namespace, wc.namespace),
WorkflowId: opts.WorkflowID,
RunId: opts.RunID,
ActivityId: opts.ActivityID,
Details: data,
Identity: wc.identity,
}
if err := visitProtoPayloads(ctx, wc.newOutboundPayloadVisitor(), byIDRequest, 0); err != nil {
return err
}
return recordActivityHeartbeatByID(ctx, wc.workflowService, wc.metricsHandler, byIDRequest)
}

// ListClosedWorkflow gets closed workflow executions based on request filters
Expand Down
62 changes: 62 additions & 0 deletions test/payload_limits_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -495,6 +495,68 @@ func (ts *PayloadLimitsTestSuite) TestPayloadSizeErrorDisabledWorkflowResult() {
ts.Contains(attributes.Reason, "BadScheduleActivityAttributes: CompleteWorkflowExecutionCommandAttributes.Result exceeds size limit.")
}

// assertActivityTaskFailed checks that the run finished with a workflow execution error
// wrapping an activity error, and that an ActivityTaskFailed event (not a timeout) is in
// the run's history.
//
// The nil/err guards below only matter when the suite's assertions are non-fatal: in that
// case the failure has already been recorded, and the guard stops the helper from
// dereferencing a nil pointer on the following line.
func (ts *PayloadLimitsTestSuite) assertActivityTaskFailed(ctx context.Context, run client.WorkflowRun) {
	err := run.Get(ctx, nil)
	var workflowExecutionErr *temporal.WorkflowExecutionError
	ts.ErrorAs(err, &workflowExecutionErr)
	if workflowExecutionErr == nil {
		return
	}
	var activityErr *temporal.ActivityError
	ts.ErrorAs(workflowExecutionErr.Unwrap(), &activityErr)

	eventIterator := ts.client.GetWorkflowHistory(ctx, run.GetID(), run.GetRunID(), false, enumspb.HISTORY_EVENT_FILTER_TYPE_ALL_EVENT)
	var actTaskFailedEvent *historypb.HistoryEvent
	for eventIterator.HasNext() {
		event, nextErr := eventIterator.Next()
		ts.NoError(nextErr)
		if nextErr != nil {
			// Stop scanning rather than reading a field of an event that was not returned.
			break
		}
		if event.EventType == enumspb.EVENT_TYPE_ACTIVITY_TASK_FAILED {
			// Only the presence of the event is asserted, so the first match is enough.
			actTaskFailedEvent = event
			break
		}
	}
	ts.NotNil(actTaskFailedEvent, "expected ActivityTaskFailed event in history, not a timeout")
}

// TestPayloadSizeErrorActivityHeartbeat runs a workflow whose only activity records a
// heartbeat with details larger than payloadSizeErrorLimit and then returns nil without
// looking at its context. It expects an ActivityTaskFailed event in the workflow history
// and the payload error message in the captured log, regardless of whether the activity
// observes the cancellation.
func (ts *PayloadLimitsTestSuite) TestPayloadSizeErrorActivityHeartbeat() {
	const (
		wfName  = "payload-size-error-activity-heartbeat"
		actName = "heartbeat-large-payload-activity"
	)

	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel()

	// Capture log output in memory so the payload error message can be asserted on.
	logger := ilog.NewMemoryLogger()
	ts.ResetClientAndWorker(func(opts *client.Options) {
		opts.Logger = logger
	}, nil)

	// Workflow: schedule the activity once (no retries) and return its result.
	heartbeatWorkflow := func(ctx workflow.Context) error {
		activityOpts := workflow.ActivityOptions{
			ScheduleToCloseTimeout: 20 * time.Second,
			HeartbeatTimeout:       15 * time.Second,
			RetryPolicy:            &temporal.RetryPolicy{MaximumAttempts: 1},
		}
		activityCtx := workflow.WithActivityOptions(ctx, activityOpts)
		return workflow.ExecuteActivity(activityCtx, actName).Get(ctx, nil)
	}
	ts.worker.RegisterWorkflowWithOptions(heartbeatWorkflow, workflow.RegisterOptions{Name: wfName})

	// Activity: heartbeat with oversized details, then succeed without checking ctx.Done().
	oversizedHeartbeatActivity := func(ctx context.Context) error {
		details := strings.Repeat("h", payloadSizeErrorLimit+1000)
		activity.RecordHeartbeat(ctx, details)
		return nil
	}
	ts.worker.RegisterActivityWithOptions(oversizedHeartbeatActivity, activity.RegisterOptions{Name: actName})

	run, err := ts.client.ExecuteWorkflow(ctx, ts.startWorkflowOptions(ts.T().Name()), wfName)
	ts.NoError(err)

	ts.assertActivityTaskFailed(ctx, run)
	ts.assertLogContains(logger, payloadErrorMessage)
}


func (ts *PayloadLimitsTestSuite) TestPayloadSizeWarningClientCustom() {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
Expand Down
Loading