Skip to content

Commit bbf11a4

Browse files
committed
Add context.Context to callback.Validator interface
Thread context through to callback validation so that deployments can check caller identity when validating callbacks. The base validator does not use the context.
1 parent b8c0ec1 commit bbf11a4

6 files changed

Lines changed: 39 additions & 28 deletions

File tree

chasm/lib/activity/frontend.go

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ var ErrStandaloneActivityDisabled = serviceerror.NewUnimplemented("Standalone ac
3838

3939
type frontendHandler struct {
4040
FrontendHandler
41-
callbackValidator *callback.Validator
41+
callbackValidator callback.Validator
4242
client activitypb.ActivityServiceClient
4343
config *Config
4444
logger log.Logger
@@ -50,7 +50,7 @@ type frontendHandler struct {
5050

5151
// NewFrontendHandler creates a new FrontendHandler instance for processing activity frontend requests.
5252
func NewFrontendHandler(
53-
callbackValidator *callback.Validator,
53+
callbackValidator callback.Validator,
5454
client activitypb.ActivityServiceClient,
5555
config *Config,
5656
logger log.Logger,
@@ -95,7 +95,7 @@ func (h *frontendHandler) StartActivityExecution(ctx context.Context, req *workf
9595
return nil, err
9696
}
9797

98-
modifiedReq, err := h.validateAndPopulateStartRequest(req, namespaceID)
98+
modifiedReq, err := h.validateAndPopulateStartRequest(ctx, req, namespaceID)
9999
if err != nil {
100100
return nil, err
101101
}
@@ -351,6 +351,7 @@ func (h *frontendHandler) RequestCancelActivityExecution(
351351
}
352352

353353
func (h *frontendHandler) validateAndPopulateStartRequest(
354+
ctx context.Context,
354355
req *workflowservice.StartActivityExecutionRequest,
355356
namespaceID namespace.ID,
356357
) (*workflowservice.StartActivityExecutionRequest, error) {
@@ -397,7 +398,7 @@ func (h *frontendHandler) validateAndPopulateStartRequest(
397398
}
398399

399400
if cbs := req.GetCompletionCallbacks(); len(cbs) > 0 {
400-
if err := h.callbackValidator.Validate(req.GetNamespace(), cbs); err != nil {
401+
if err := h.callbackValidator.Validate(ctx, req.GetNamespace(), cbs); err != nil {
401402
return nil, err
402403
}
403404
}

chasm/lib/activity/frontend_test.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package activity
22

33
import (
4+
"context"
45
"testing"
56
"time"
67

@@ -50,11 +51,11 @@ func TestRequestIdStableAcrossRetries(t *testing.T) {
5051
// validateAndPopulateStartRequest with the same request pointer.
5152
validateTwoAttempts := func(t *testing.T, req *workflowservice.StartActivityExecutionRequest) {
5253
t.Helper()
53-
clone1, err := h.validateAndPopulateStartRequest(req, nsID)
54+
clone1, err := h.validateAndPopulateStartRequest(context.Background(), req, nsID)
5455
require.NoError(t, err)
5556
require.NotEmpty(t, clone1.RequestId)
5657

57-
clone2, err := h.validateAndPopulateStartRequest(req, nsID)
58+
clone2, err := h.validateAndPopulateStartRequest(context.Background(), req, nsID)
5859
require.NoError(t, err)
5960
require.Equal(t, clone1.RequestId, clone2.RequestId)
6061
}

chasm/lib/callback/validator.go

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package callback
22

33
import (
4+
"context"
45
"fmt"
56
"strings"
67

@@ -11,7 +12,11 @@ import (
1112
)
1213

1314
// Validator validates completion callbacks attached to executions (workflows and standalone activities).
14-
type Validator struct {
15+
type Validator interface {
16+
Validate(ctx context.Context, namespaceName string, cbs []*commonpb.Callback) error
17+
}
18+
19+
type validator struct {
1520
maxCallbacksPerExecution dynamicconfig.IntPropertyFnWithNamespaceFilter
1621
urlMaxLength dynamicconfig.IntPropertyFnWithNamespaceFilter
1722
headerMaxSize dynamicconfig.IntPropertyFnWithNamespaceFilter
@@ -23,8 +28,8 @@ func NewValidator(
2328
urlMaxLength dynamicconfig.IntPropertyFnWithNamespaceFilter,
2429
headerMaxSize dynamicconfig.IntPropertyFnWithNamespaceFilter,
2530
endpointRules dynamicconfig.TypedPropertyFnWithNamespaceFilter[AddressMatchRules],
26-
) *Validator {
27-
return &Validator{
31+
) Validator {
32+
return &validator{
2833
maxCallbacksPerExecution: maxCallbacksPerExecution,
2934
urlMaxLength: urlMaxLength,
3035
headerMaxSize: headerMaxSize,
@@ -34,7 +39,7 @@ func NewValidator(
3439

3540
// Validate validates completion callbacks: count, URL length, endpoint allowlist, header size, and normalizes header
3641
// keys to lowercase.
37-
func (v *Validator) Validate(namespaceName string, cbs []*commonpb.Callback) error {
42+
func (v *validator) Validate(_ context.Context, namespaceName string, cbs []*commonpb.Callback) error {
3843
if len(cbs) > v.maxCallbacksPerExecution(namespaceName) {
3944
return serviceerror.NewInvalidArgumentf(
4045
"cannot attach more than %d callbacks to an execution", v.maxCallbacksPerExecution(namespaceName),

chasm/lib/callback/validator_test.go

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package callback
22

33
import (
4+
"context"
45
"regexp"
56
"testing"
67

@@ -31,7 +32,7 @@ func TestValidateCallbacks(t *testing.T) {
3132
},
3233
}},
3334
}
34-
err := v.Validate("ns", cbs)
35+
err := v.Validate(context.Background(), "ns", cbs)
3536
require.NoError(t, err)
3637
})
3738

@@ -46,7 +47,7 @@ func TestValidateCallbacks(t *testing.T) {
4647
{Variant: &commonpb.Callback_Nexus_{Nexus: &commonpb.Callback_Nexus{Url: "http://localhost/cb1"}}},
4748
{Variant: &commonpb.Callback_Nexus_{Nexus: &commonpb.Callback_Nexus{Url: "http://localhost/cb2"}}},
4849
}
49-
err := v.Validate("ns", cbs)
50+
err := v.Validate(context.Background(), "ns", cbs)
5051
var invalidArgErr *serviceerror.InvalidArgument
5152
require.ErrorAs(t, err, &invalidArgErr)
5253
require.Contains(t, err.Error(), "cannot attach more than 1 callbacks")
@@ -66,7 +67,7 @@ func TestValidateCallbacks(t *testing.T) {
6667
},
6768
}},
6869
}
69-
err := v.Validate("ns", cbs)
70+
err := v.Validate(context.Background(), "ns", cbs)
7071
var invalidArgErr *serviceerror.InvalidArgument
7172
require.ErrorAs(t, err, &invalidArgErr)
7273
require.Contains(t, err.Error(), "url length longer than max length allowed")
@@ -81,7 +82,7 @@ func TestValidateCallbacks(t *testing.T) {
8182
},
8283
}},
8384
}
84-
err := v.Validate("ns", cbs)
85+
err := v.Validate(context.Background(), "ns", cbs)
8586
var invalidArgErr *serviceerror.InvalidArgument
8687
require.ErrorAs(t, err, &invalidArgErr)
8788
require.Contains(t, err.Error(), "header size longer than max allowed size")
@@ -96,7 +97,7 @@ func TestValidateCallbacks(t *testing.T) {
9697
},
9798
}},
9899
}
99-
err := v.Validate("ns", cbs)
100+
err := v.Validate(context.Background(), "ns", cbs)
100101
require.NoError(t, err)
101102
nexus := cbs[0].GetNexus()
102103
require.Equal(t, "application/json", nexus.Header["content-type"])
@@ -119,7 +120,7 @@ func TestValidateCallbacks(t *testing.T) {
119120
},
120121
}},
121122
}
122-
err := v.Validate("ns", cbs)
123+
err := v.Validate(context.Background(), "ns", cbs)
123124
var invalidArgErr *serviceerror.InvalidArgument
124125
require.ErrorAs(t, err, &invalidArgErr)
125126
require.Contains(t, err.Error(), "does not match any configured callback address")
@@ -129,14 +130,14 @@ func TestValidateCallbacks(t *testing.T) {
129130
cbs := []*commonpb.Callback{
130131
{Variant: nil},
131132
}
132-
err := v.Validate("ns", cbs)
133+
err := v.Validate(context.Background(), "ns", cbs)
133134
var unimplementedErr *serviceerror.Unimplemented
134135
require.ErrorAs(t, err, &unimplementedErr)
135136
require.Contains(t, err.Error(), "unknown callback variant")
136137
})
137138

138139
t.Run("EmptyCallbacksNoError", func(t *testing.T) {
139-
err := v.Validate("ns", nil)
140+
err := v.Validate(context.Background(), "ns", nil)
140141
require.NoError(t, err)
141142
})
142143

@@ -146,7 +147,7 @@ func TestValidateCallbacks(t *testing.T) {
146147
Internal: &commonpb.Callback_Internal{},
147148
}},
148149
}
149-
err := v.Validate("ns", cbs)
150+
err := v.Validate(context.Background(), "ns", cbs)
150151
require.NoError(t, err)
151152
})
152153
}

service/frontend/fx.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -827,7 +827,7 @@ func OperatorHandlerProvider(
827827
// so that existing operator configurations (component.callbacks.allowedAddresses) are honored.
828828
// TODO: Once HSM callbacks (components/callbacks) are removed, move this provider into
829829
// chasm/lib/callback/fx.go and read directly from callback.AllowedAddresses.
830-
func callbackValidatorProvider(dc *dynamicconfig.Collection) *callback.Validator {
830+
func callbackValidatorProvider(dc *dynamicconfig.Collection) callback.Validator {
831831
return callback.NewValidator(
832832
callback.MaxPerExecution.Get(dc),
833833
dynamicconfig.FrontendCallbackURLMaxLength.Get(dc),
@@ -876,7 +876,7 @@ func HandlerProvider(
876876
healthInterceptor *interceptor.HealthInterceptor,
877877
scheduleSpecBuilder *scheduler.SpecBuilder,
878878
activityHandler activity.FrontendHandler,
879-
callbackValidator *callback.Validator,
879+
callbackValidator callback.Validator,
880880
nexusOperationHandler chasmnexus.FrontendHandler,
881881
registry *chasm.Registry,
882882
frontendServiceResolver membership.ServiceResolver,

service/frontend/workflow_handler.go

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ type (
123123

124124
status int32
125125

126-
callbackValidator *callback.Validator
126+
callbackValidator callback.Validator
127127
tokenSerializer *tasktoken.Serializer
128128
config *Config
129129
versionChecker headers.VersionChecker
@@ -302,7 +302,7 @@ func (wh *WorkflowHandler) ValidateWorkerDeploymentVersionComputeConfig(
302302

303303
// NewWorkflowHandler creates a gRPC handler for workflowservice
304304
func NewWorkflowHandler(
305-
callbackValidator *callback.Validator,
305+
callbackValidator callback.Validator,
306306
config *Config,
307307
namespaceReplicationQueue persistence.NamespaceReplicationQueue,
308308
visibilityMgr manager.VisibilityManager,
@@ -545,7 +545,7 @@ func (wh *WorkflowHandler) StartWorkflowExecution(
545545
defer log.CapturePanic(wh.logger, &retError)
546546

547547
var err error
548-
if request, err = wh.prepareStartWorkflowRequest(request); err != nil {
548+
if request, err = wh.prepareStartWorkflowRequest(ctx, request); err != nil {
549549
return nil, err
550550
}
551551

@@ -603,6 +603,7 @@ func (wh *WorkflowHandler) convertToStartWorkflowExecutionResponse(
603603

604604
// Validates the request and sets default values where they are missing.
605605
func (wh *WorkflowHandler) prepareStartWorkflowRequest(
606+
ctx context.Context,
606607
request *workflowservice.StartWorkflowExecutionRequest,
607608
) (*workflowservice.StartWorkflowExecutionRequest, error) {
608609
if request == nil {
@@ -680,7 +681,7 @@ func (wh *WorkflowHandler) prepareStartWorkflowRequest(
680681
}
681682

682683
if cbs := request.GetCompletionCallbacks(); len(cbs) > 0 {
683-
if err := wh.callbackValidator.Validate(namespaceName.String(), cbs); err != nil {
684+
if err := wh.callbackValidator.Validate(ctx, namespaceName.String(), cbs); err != nil {
684685
return nil, err
685686
}
686687
}
@@ -791,7 +792,7 @@ func (wh *WorkflowHandler) ExecuteMultiOperation(
791792
return nil, errMultiOpNotStartAndUpdate
792793
}
793794

794-
historyReq, err := wh.convertToHistoryMultiOperationRequest(namespaceID, request)
795+
historyReq, err := wh.convertToHistoryMultiOperationRequest(ctx, namespaceID, request)
795796
if err != nil {
796797
return nil, err
797798
}
@@ -815,6 +816,7 @@ func (wh *WorkflowHandler) ExecuteMultiOperation(
815816
}
816817

817818
func (wh *WorkflowHandler) convertToHistoryMultiOperationRequest(
819+
ctx context.Context,
818820
namespaceID namespace.ID,
819821
request *workflowservice.ExecuteMultiOperationRequest,
820822
) (*historyservice.ExecuteMultiOperationRequest, error) {
@@ -825,7 +827,7 @@ func (wh *WorkflowHandler) convertToHistoryMultiOperationRequest(
825827
errs := make([]error, len(request.Operations))
826828

827829
for i, op := range request.Operations {
828-
convertedOp, opWorkflowID, err := wh.convertToHistoryMultiOperationItem(namespaceID, namespace.Name(request.Namespace), op)
830+
convertedOp, opWorkflowID, err := wh.convertToHistoryMultiOperationItem(ctx, namespaceID, namespace.Name(request.Namespace), op)
829831
if err != nil {
830832
hasError = true
831833
} else {
@@ -856,6 +858,7 @@ func (wh *WorkflowHandler) convertToHistoryMultiOperationRequest(
856858
}
857859

858860
func (wh *WorkflowHandler) convertToHistoryMultiOperationItem(
861+
ctx context.Context,
859862
namespaceID namespace.ID,
860863
namespaceName namespace.Name,
861864
op *workflowservice.ExecuteMultiOperationRequest_Operation,
@@ -868,7 +871,7 @@ func (wh *WorkflowHandler) convertToHistoryMultiOperationItem(
868871
return nil, "", errMultiOpNamespaceMismatch
869872
}
870873
var err error
871-
if startReq, err = wh.prepareStartWorkflowRequest(startReq); err != nil {
874+
if startReq, err = wh.prepareStartWorkflowRequest(ctx, startReq); err != nil {
872875
return nil, "", err
873876
}
874877
if len(startReq.CronSchedule) > 0 {

0 commit comments

Comments
 (0)