Skip to content

Commit a2e324a

Browse files
authored
[Nexus] Dynamic config to filter Nexus request headers (temporalio#6973)
## What changed? <!-- Describe what has changed in this PR --> Nexus request header sanitization: headers that match the configured blacklist are removed from a Nexus request before the request is sent to the client's operation handler. ## Why? <!-- Tell your future self why you have made these changes --> ## How did you test it? <!-- How have you verified this change? Tested locally? Added a unit test? Checked in staging env? --> Unit test. ## Potential risks <!-- Assuming the worst case, what can be broken when deploying this change to production? --> ## Documentation <!-- Have you made sure this change doesn't falsify anything currently stated in `docs/`? If significant new behavior is added, have you described that in `docs/`? --> ## Is hotfix candidate? <!-- Is this PR a hotfix candidate or does it require a notification to be sent to the broader community? (Yes/No) -->
1 parent 4f9d034 commit a2e324a

File tree

5 files changed

+135
-10
lines changed

5 files changed

+135
-10
lines changed

common/dynamicconfig/constants.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -898,6 +898,12 @@ used when the first cache layer has a miss. Requires server restart for change t
898898
30*time.Second,
899899
`The TTL of the Nexus endpoint registry's readthrough LRU cache - the cache is a secondary cache and is only
900900
used when the first cache layer has a miss. Requires server restart for change to be applied.`,
901+
)
902+
FrontendNexusRequestHeadersBlacklist = NewGlobalTypedSetting(
903+
"frontend.nexusRequestHeadersBlacklist",
904+
[]string(nil),
905+
`Nexus request headers to be removed before being sent to a user handler.
906+
Wildcards (*) are expanded to allow any substring. By default blacklist is empty.`,
901907
)
902908
FrontendCallbackURLMaxLength = NewNamespaceIntSetting(
903909
"frontend.callbackURLMaxLength",

service/frontend/nexus_handler.go

Lines changed: 38 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ import (
2828
"fmt"
2929
"net/http"
3030
"net/url"
31+
"regexp"
3132
"runtime/debug"
3233
"strconv"
3334
"strings"
@@ -92,6 +93,7 @@ type operationContext struct {
9293
telemetryInterceptor *interceptor.TelemetryInterceptor
9394
redirectionInterceptor *interceptor.Redirection
9495
forwardingEnabledForNamespace dynamicconfig.BoolPropertyFnWithNamespaceFilter
96+
headersBlacklist *dynamicconfig.GlobalCachedTypedValue[*regexp.Regexp]
9597
cleanupFunctions []func(map[string]string, error)
9698
}
9799

@@ -152,7 +154,11 @@ func (c *operationContext) augmentContext(ctx context.Context, header nexus.Head
152154
return ctx
153155
}
154156

155-
func (c *operationContext) interceptRequest(ctx context.Context, request *matchingservice.DispatchNexusTaskRequest, header nexus.Header) error {
157+
func (c *operationContext) interceptRequest(
158+
ctx context.Context,
159+
request *matchingservice.DispatchNexusTaskRequest,
160+
header nexus.Header,
161+
) error {
156162
err := c.auth.Authorize(ctx, c.claims, &authorization.CallTarget{
157163
APIName: c.apiName,
158164
Namespace: c.namespaceName,
@@ -171,13 +177,18 @@ func (c *operationContext) interceptRequest(ctx context.Context, request *matchi
171177

172178
if !c.namespace.ActiveInCluster(c.clusterMetadata.GetCurrentClusterName()) {
173179
if c.shouldForwardRequest(ctx, header) {
174-
// Handler methods should have special logic to forward requests if this method returns a serviceerror.NamespaceNotActive error.
180+
// Handler methods should have special logic to forward requests if this method returns
181+
// a serviceerror.NamespaceNotActive error.
175182
c.metricsHandler = c.metricsHandler.WithTags(metrics.OutcomeTag("request_forwarded"))
176183
handler, forwardStartTime := c.redirectionInterceptor.BeforeCall(c.apiName)
177184
c.cleanupFunctions = append(c.cleanupFunctions, func(_ map[string]string, retErr error) {
178185
c.redirectionInterceptor.AfterCall(handler, forwardStartTime, c.namespace.ActiveClusterName(), retErr)
179186
})
180-
return serviceerror.NewNamespaceNotActive(c.namespaceName, c.clusterMetadata.GetCurrentClusterName(), c.namespace.ActiveClusterName())
187+
return serviceerror.NewNamespaceNotActive(
188+
c.namespaceName,
189+
c.clusterMetadata.GetCurrentClusterName(),
190+
c.namespace.ActiveClusterName(),
191+
)
181192
}
182193
c.metricsHandler = c.metricsHandler.WithTags(metrics.OutcomeTag("namespace_inactive_forwarding_disabled"))
183194
return nexus.HandlerErrorf(nexus.HandlerErrorTypeUnavailable, "cluster inactive")
@@ -198,7 +209,12 @@ func (c *operationContext) interceptRequest(ctx context.Context, request *matchi
198209
}
199210
})
200211

201-
cleanup, err := c.namespaceConcurrencyLimitInterceptor.Allow(c.namespace.Name(), c.apiName, c.metricsHandlerForInterceptors, request)
212+
cleanup, err := c.namespaceConcurrencyLimitInterceptor.Allow(
213+
c.namespace.Name(),
214+
c.apiName,
215+
c.metricsHandlerForInterceptors,
216+
request,
217+
)
202218
c.cleanupFunctions = append(c.cleanupFunctions, func(map[string]string, error) { cleanup() })
203219
if err != nil {
204220
c.metricsHandler = c.metricsHandler.WithTags(metrics.OutcomeTag("namespace_concurrency_limited"))
@@ -221,6 +237,22 @@ func (c *operationContext) interceptRequest(ctx context.Context, request *matchi
221237
return converted
222238
}
223239

240+
// THIS MUST BE THE LAST STEP IN interceptRequest.
241+
// Sanitize headers.
242+
if request.GetRequest().GetHeader() != nil {
243+
// Making a copy to ensure the original map is not modified as it might be used somewhere else.
244+
sanitizedHeaders := make(map[string]string, len(request.Request.Header))
245+
headersBlacklist := c.headersBlacklist.Get()
246+
for name, value := range request.Request.Header {
247+
if !headersBlacklist.MatchString(name) {
248+
sanitizedHeaders[name] = value
249+
}
250+
}
251+
request.Request.Header = sanitizedHeaders
252+
}
253+
254+
// DO NOT ADD ANY STEPS HERE. ALL STEPS MUST BE BEFORE HEADERS SANITIZATION.
255+
224256
return nil
225257
}
226258

@@ -259,6 +291,7 @@ type nexusHandler struct {
259291
forwardingEnabledForNamespace dynamicconfig.BoolPropertyFnWithNamespaceFilter
260292
forwardingClients *cluster.FrontendHTTPClientCache
261293
payloadSizeLimit dynamicconfig.IntPropertyFnWithNamespaceFilter
294+
headersBlacklist *dynamicconfig.GlobalCachedTypedValue[*regexp.Regexp]
262295
}
263296

264297
// Extracts a nexusContext from the given ctx and returns an operationContext with tagged metrics and logging.
@@ -277,6 +310,7 @@ func (h *nexusHandler) getOperationContext(ctx context.Context, method string) (
277310
telemetryInterceptor: h.telemetryInterceptor,
278311
redirectionInterceptor: h.redirectionInterceptor,
279312
forwardingEnabledForNamespace: h.forwardingEnabledForNamespace,
313+
headersBlacklist: h.headersBlacklist,
280314
cleanupFunctions: make([]func(map[string]string, error), 0),
281315
}
282316
oc.metricsHandlerForInterceptors = h.metricsHandler.WithTags(

service/frontend/nexus_handler_test.go

Lines changed: 73 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,13 +25,15 @@ package frontend
2525
import (
2626
"context"
2727
"errors"
28+
"regexp"
2829
"testing"
2930
"time"
3031

3132
"github.com/google/uuid"
3233
"github.com/nexus-rpc/sdk-go/nexus"
3334
"github.com/stretchr/testify/require"
3435
enumspb "go.temporal.io/api/enums/v1"
36+
nexuspb "go.temporal.io/api/nexus/v1"
3537
"go.temporal.io/api/serviceerror"
3638
"go.temporal.io/server/api/matchingservice/v1"
3739
persistencespb "go.temporal.io/server/api/persistence/v1"
@@ -48,6 +50,7 @@ import (
4850
"go.temporal.io/server/common/primitives/timestamp"
4951
"go.temporal.io/server/common/quotas"
5052
"go.temporal.io/server/common/rpc/interceptor"
53+
"go.temporal.io/server/common/util"
5154
)
5255

5356
type mockAuthorizer struct{}
@@ -96,6 +99,7 @@ type contextOptions struct {
9699
namespaceRateLimitAllow bool
97100
rateLimitAllow bool
98101
redirectAllow bool
102+
headersBlacklist []string
99103
}
100104

101105
func newOperationContext(options contextOptions) *operationContext {
@@ -147,12 +151,47 @@ func newOperationContext(options contextOptions) *operationContext {
147151
oc.apiName: 1,
148152
},
149153
)
150-
oc.namespaceRateLimitInterceptor = interceptor.NewNamespaceRateLimitInterceptor(nil, mockRateLimiter{options.namespaceRateLimitAllow}, make(map[string]int))
151-
oc.rateLimitInterceptor = interceptor.NewRateLimitInterceptor(mockRateLimiter{options.rateLimitAllow}, make(map[string]int))
154+
oc.namespaceRateLimitInterceptor = interceptor.NewNamespaceRateLimitInterceptor(
155+
nil,
156+
mockRateLimiter{options.namespaceRateLimitAllow},
157+
make(map[string]int),
158+
)
159+
oc.rateLimitInterceptor = interceptor.NewRateLimitInterceptor(
160+
mockRateLimiter{options.rateLimitAllow},
161+
make(map[string]int),
162+
)
152163

153-
oc.clusterMetadata = clustertest.NewMetadataForTest(cluster.NewTestClusterMetadataConfig(true, !options.namespacePassive))
154-
oc.forwardingEnabledForNamespace = dynamicconfig.GetBoolPropertyFnFilteredByNamespace(options.redirectAllow)
155-
oc.redirectionInterceptor = interceptor.NewRedirection(nil, nil, config.DCRedirectionPolicy{Policy: interceptor.DCRedirectionPolicyAllAPIsForwarding}, oc.logger, nil, oc.metricsHandlerForInterceptors, clock.NewRealTimeSource(), oc.clusterMetadata)
164+
oc.clusterMetadata = clustertest.NewMetadataForTest(
165+
cluster.NewTestClusterMetadataConfig(true, !options.namespacePassive),
166+
)
167+
oc.forwardingEnabledForNamespace = dynamicconfig.GetBoolPropertyFnFilteredByNamespace(
168+
options.redirectAllow,
169+
)
170+
oc.headersBlacklist = dynamicconfig.NewGlobalCachedTypedValue(
171+
dynamicconfig.NewCollection(
172+
&dynamicconfig.StaticClient{
173+
dynamicconfig.FrontendNexusRequestHeadersBlacklist.Key(): options.headersBlacklist,
174+
},
175+
nil,
176+
),
177+
dynamicconfig.FrontendNexusRequestHeadersBlacklist,
178+
func(patterns []string) (*regexp.Regexp, error) {
179+
if len(patterns) == 0 {
180+
return matchNothing, nil
181+
}
182+
return util.WildCardStringsToRegexp(patterns)
183+
},
184+
)
185+
oc.redirectionInterceptor = interceptor.NewRedirection(
186+
nil,
187+
nil,
188+
config.DCRedirectionPolicy{Policy: interceptor.DCRedirectionPolicyAllAPIsForwarding},
189+
oc.logger,
190+
nil,
191+
oc.metricsHandlerForInterceptors,
192+
clock.NewRealTimeSource(),
193+
oc.clusterMetadata,
194+
)
156195

157196
return oc
158197
}
@@ -328,3 +367,32 @@ func TestNexusInterceptRequest_InvalidSDKVersion_ResultsInBadRequest(t *testing.
328367
require.Equal(t, 1, len(snap["test"]))
329368
require.Equal(t, map[string]string{"outcome": "unsupported_client"}, snap["test"][0].Tags)
330369
}
370+
371+
func TestNexusInterceptRequest_HeadersSanitization(t *testing.T) {
372+
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
373+
defer cancel()
374+
var err error
375+
oc := newOperationContext(contextOptions{
376+
namespaceState: enumspb.NAMESPACE_STATE_REGISTERED,
377+
namespacePassive: false,
378+
quota: 1,
379+
namespaceRateLimitAllow: true,
380+
rateLimitAllow: true,
381+
headersBlacklist: []string{"delete-*", "remove-*"},
382+
})
383+
initialHeader := nexus.Header{
384+
"ok-header": "ok",
385+
"delete-foo": "foo",
386+
"delete-bar": "bar",
387+
"remove-zzz": "zzz",
388+
}
389+
header := util.CloneMapNonNil(initialHeader)
390+
ctx = oc.augmentContext(ctx, header)
391+
request := &matchingservice.DispatchNexusTaskRequest{
392+
Request: &nexuspb.Request{Header: header},
393+
}
394+
err = oc.interceptRequest(ctx, request, header)
395+
require.NoError(t, err)
396+
require.Equal(t, initialHeader, header)
397+
require.Equal(t, map[string]string{"ok-header": "ok"}, request.Request.Header)
398+
}

service/frontend/nexus_http_handler.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,7 @@ func NewNexusHTTPHandler(
109109
forwardingEnabledForNamespace: serviceConfig.EnableNamespaceNotActiveAutoForwarding,
110110
forwardingClients: clientCache,
111111
payloadSizeLimit: serviceConfig.BlobSizeLimitError,
112+
headersBlacklist: serviceConfig.NexusRequestHeadersBlacklist,
112113
},
113114
GetResultTimeout: serviceConfig.KeepAliveMaxConnectionIdle(),
114115
Logger: log.NewSlogLogger(logger),

service/frontend/service.go

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,10 @@ import (
5050
"google.golang.org/grpc/reflection"
5151
)
5252

53-
var matchAny = regexp.MustCompile(".*")
53+
var (
54+
matchAny = regexp.MustCompile(".*")
55+
matchNothing = regexp.MustCompile(".^")
56+
)
5457

5558
// Config represents configuration for frontend service
5659
type Config struct {
@@ -207,6 +210,8 @@ type Config struct {
207210
MaxCallbacksPerWorkflow dynamicconfig.IntPropertyFnWithNamespaceFilter
208211
CallbackEndpointConfigs dynamicconfig.TypedPropertyFnWithNamespaceFilter[[]callbacks.AddressMatchRule]
209212

213+
NexusRequestHeadersBlacklist *dynamicconfig.GlobalCachedTypedValue[*regexp.Regexp]
214+
210215
LinkMaxSize dynamicconfig.IntPropertyFnWithNamespaceFilter
211216
MaxLinksPerRequest dynamicconfig.IntPropertyFnWithNamespaceFilter
212217

@@ -335,6 +340,17 @@ func NewConfig(
335340
CallbackHeaderMaxSize: dynamicconfig.FrontendCallbackHeaderMaxSize.Get(dc),
336341
MaxCallbacksPerWorkflow: dynamicconfig.MaxCallbacksPerWorkflow.Get(dc),
337342

343+
NexusRequestHeadersBlacklist: dynamicconfig.NewGlobalCachedTypedValue(
344+
dc,
345+
dynamicconfig.FrontendNexusRequestHeadersBlacklist,
346+
func(patterns []string) (*regexp.Regexp, error) {
347+
if len(patterns) == 0 {
348+
return matchNothing, nil
349+
}
350+
return util.WildCardStringsToRegexp(patterns)
351+
},
352+
),
353+
338354
LinkMaxSize: dynamicconfig.FrontendLinkMaxSize.Get(dc),
339355
MaxLinksPerRequest: dynamicconfig.FrontendMaxLinksPerRequest.Get(dc),
340356

0 commit comments

Comments
 (0)