Skip to content

Commit e7dfbf1

Browse files
committed
fix(runs): address review — gate header trust, guard enrichment, bound cache
- Add Config.TrustForwardedIdentityHeaders (default true). executed_by is derived from the proxy-forwarded, unverified JWTs only when set; turn off where the service may be reached without a trusted proxy. - Enricher: reject userinfo responses whose `sub` doesn't match the caller (token confusion), and evict expired cache entries (lazily + swept on store) so the map can't grow unbounded. - Fix the misleading "never blocks" doc — userinfo is a synchronous call on cache miss. - created_by column: VARCHAR(255) -> TEXT (OIDC sub length is IdP-dependent). - Clarify created_by is intentionally not in the API filter allowlist. - Tests: actionMetadataFromModel (executed_by populate / created_by fallback / corrupt bytes / none) and enricher subject-mismatch rejection. Signed-off-by: Kevin Su <pingsutw@apache.org>
1 parent 125d7a1 commit e7dfbf1

9 files changed

Lines changed: 123 additions & 12 deletions

File tree

runs/config/config.go

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,9 @@ var defaultConfig = &Config{
3232
ExecutionQPS: 10.0,
3333
ExecutionBurst: 20,
3434
},
35+
// Defaults on: the runs service is designed to sit behind an auth proxy/LB. Set
36+
// false for deployments where that guarantee does not hold.
37+
TrustForwardedIdentityHeaders: true,
3538
}
3639

3740
var configSection = config.MustRegisterSection(configSectionKey, defaultConfig)
@@ -68,6 +71,17 @@ type Config struct {
6871
// AuthMetadata configures the OAuth2 authorization-server metadata endpoint
6972
// (the GetOAuth2Metadata RPC and /.well-known/oauth-authorization-server).
7073
AuthMetadata AuthMetadataConfig `json:"authMetadata"`
74+
75+
// TrustForwardedIdentityHeaders controls whether run attribution (executed_by)
76+
// is derived from the auth headers the proxy forwards (X-Amzn-Oidc-*, Authorization).
77+
// Those JWTs are decoded but NOT signature-verified here — that is safe only when
78+
// the service sits behind a trusted proxy/LB that validates tokens and strips any
79+
// client-supplied copies of these headers. Set to false if the service can be
80+
// reached directly or through an untrusted proxy, in which case executed_by is left
81+
// unset rather than risk a spoofed identity. (Note: the runs service performs no
82+
// authorization itself, so a direct caller can already act unauthenticated — this
83+
// flag only governs whether to trust the forwarded identity for attribution.)
84+
TrustForwardedIdentityHeaders bool `json:"trustForwardedIdentityHeaders" pflag:",Derive executed_by from proxy-forwarded auth headers (requires a trusted proxy)"`
7185
}
7286

7387
// AuthMetadataConfig controls how the runs service serves OAuth2 authorization
Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
-- Add created_by to actions: the OIDC subject of the identity that created the run.
22
-- Captured from the auth headers the load balancer forwards (it enforces auth),
33
-- and used to populate ActionMetadata.executed_by on read.
4-
ALTER TABLE actions ADD COLUMN IF NOT EXISTS created_by VARCHAR(255);
4+
-- TEXT, not VARCHAR(n): the OIDC `sub` length is IdP-dependent and can exceed 255.
5+
ALTER TABLE actions ADD COLUMN IF NOT EXISTS created_by TEXT;

runs/repository/models/action.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,9 @@ type Action struct {
5353
RunSource string `db:"run_source" json:"run_source,omitempty"`
5454

5555
// CreatedBy is the OIDC subject of the identity that created this run, captured
56-
// from the auth headers the load balancer forwards. Kept for querying/filtering.
57-
// NULL for runs created without an authenticated identity.
56+
// from the auth headers the load balancer forwards. NULL for runs created without
57+
// an authenticated identity. Not exposed for API-level filtering/sorting — it is
58+
// intentionally absent from ActionColumnsSet; add it there only if that's desired.
5859
CreatedBy sql.NullString `db:"created_by" json:"created_by,omitempty"`
5960

6061
// ExecutedBy is the serialized common.EnrichedIdentity of the run's creator

runs/service/executed_by_test.go

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
package service
2+
3+
import (
4+
"testing"
5+
6+
"github.com/stretchr/testify/assert"
7+
"github.com/stretchr/testify/require"
8+
"google.golang.org/protobuf/proto"
9+
10+
"github.com/flyteorg/flyte/v2/gen/go/flyteidl2/common"
11+
"github.com/flyteorg/flyte/v2/runs/repository/models"
12+
)
13+
14+
func mustMarshalIdentity(t *testing.T, id *common.EnrichedIdentity) []byte {
15+
t.Helper()
16+
b, err := proto.Marshal(id)
17+
require.NoError(t, err)
18+
return b
19+
}
20+
21+
func fullIdentity(sub, first, last, email string) *common.EnrichedIdentity {
22+
return &common.EnrichedIdentity{Principal: &common.EnrichedIdentity_User{User: &common.User{
23+
Id: &common.UserIdentifier{Subject: sub},
24+
Spec: &common.UserSpec{FirstName: first, LastName: last, Email: email},
25+
}}}
26+
}
27+
28+
func TestActionMetadataFromModel_ExecutedBy(t *testing.T) {
29+
t.Run("full identity from executed_by", func(t *testing.T) {
30+
m := &models.Action{ExecutedBy: mustMarshalIdentity(t, fullIdentity("00u1", "Kevin", "Su", "kevin@union.ai"))}
31+
eb := actionMetadataFromModel(m).GetExecutedBy().GetUser()
32+
assert.Equal(t, "00u1", eb.GetId().GetSubject())
33+
assert.Equal(t, "Kevin", eb.GetSpec().GetFirstName())
34+
assert.Equal(t, "Su", eb.GetSpec().GetLastName())
35+
assert.Equal(t, "kevin@union.ai", eb.GetSpec().GetEmail())
36+
})
37+
38+
t.Run("falls back to subject-only from created_by", func(t *testing.T) {
39+
m := &models.Action{}
40+
m.CreatedBy.Valid, m.CreatedBy.String = true, "00u2"
41+
eb := actionMetadataFromModel(m).GetExecutedBy().GetUser()
42+
assert.Equal(t, "00u2", eb.GetId().GetSubject())
43+
assert.Nil(t, eb.GetSpec())
44+
})
45+
46+
t.Run("corrupt executed_by falls back to created_by", func(t *testing.T) {
47+
m := &models.Action{ExecutedBy: []byte("not a valid proto\xff\xfe")}
48+
m.CreatedBy.Valid, m.CreatedBy.String = true, "00u3"
49+
assert.Equal(t, "00u3", actionMetadataFromModel(m).GetExecutedBy().GetUser().GetId().GetSubject())
50+
})
51+
52+
t.Run("no identity yields nil executed_by", func(t *testing.T) {
53+
assert.Nil(t, actionMetadataFromModel(&models.Action{}).GetExecutedBy())
54+
})
55+
}

runs/service/identity_enricher.go

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -55,8 +55,10 @@ func newIdentityEnricher(authServerBaseURL string) *identityEnricher {
5555
// enrich fills any profile fields (email, first/last name) missing from base with
5656
// userinfo claims fetched using the access token. Fields already present on base
5757
// (e.g. from x-amzn-oidc-data) are authoritative and kept. userinfo is queried only
58-
// when the profile is incomplete and not cached. base is returned unchanged on any
59-
// miss or error — enrichment never blocks or fails run creation.
58+
// when the profile is incomplete and not cached. On a cache miss it makes a
59+
// synchronous userinfo call (bounded by userinfoHTTPTimeout), which adds latency to
60+
// run creation; on any error or timeout it returns base unchanged — enrichment is
61+
// best-effort and never fails run creation.
6062
func (e *identityEnricher) enrich(ctx context.Context, accessToken string, base *common.EnrichedIdentity) *common.EnrichedIdentity {
6163
if e == nil || base.GetUser() == nil {
6264
return base
@@ -77,22 +79,41 @@ func (e *identityEnricher) enrich(ctx context.Context, accessToken string, base
7779
logger.Warnf(ctx, "identity enrichment: userinfo fetch failed for subject %q: %v", subject, err)
7880
return base
7981
}
82+
// Guard against token confusion / IdP misconfiguration: never associate a profile
83+
// fetched for a different subject with this run.
84+
if claims.Sub != "" && claims.Sub != subject {
85+
logger.Warnf(ctx, "identity enrichment: userinfo subject %q does not match caller %q; ignoring", claims.Sub, subject)
86+
return base
87+
}
8088
e.store(subject, claims)
8189
return mergeClaims(base, claims)
8290
}
8391

8492
func (e *identityEnricher) cachedFor(subject string) *oidcClaims {
8593
e.mu.Lock()
8694
defer e.mu.Unlock()
87-
if c, ok := e.cache[subject]; ok && time.Now().Before(c.expires) {
88-
return c.claims
95+
c, ok := e.cache[subject]
96+
if !ok {
97+
return nil
8998
}
90-
return nil
99+
if time.Now().After(c.expires) {
100+
// Drop the stale entry so the map does not accumulate dead keys.
101+
delete(e.cache, subject)
102+
return nil
103+
}
104+
return c.claims
91105
}
92106

93107
func (e *identityEnricher) store(subject string, claims *oidcClaims) {
94108
e.mu.Lock()
95109
defer e.mu.Unlock()
110+
// Opportunistically evict any other expired entries to bound the map size.
111+
now := time.Now()
112+
for k, c := range e.cache {
113+
if now.After(c.expires) {
114+
delete(e.cache, k)
115+
}
116+
}
96117
e.cache[subject] = cachedClaims{claims: claims, expires: time.Now().Add(identityCacheTTL)}
97118
}
98119

runs/service/identity_enricher_test.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,16 @@ func TestEnrich_FillsOnlyMissingFields(t *testing.T) {
8383
assert.Equal(t, "kevin@union.ai", spec.GetEmail()) // header email preserved
8484
}
8585

86+
func TestEnrich_RejectsSubjectMismatch(t *testing.T) {
87+
// userinfo returns a different subject than the caller — must not be trusted.
88+
srv, _ := newTestIdP(t, `{"sub":"someone-else","email":"evil@x.com","given_name":"Mallory"}`, http.StatusOK)
89+
e := newIdentityEnricher(srv.URL)
90+
91+
got := e.enrich(context.Background(), "access-tok", subjectOnlyIdentity("00u2"))
92+
assert.Nil(t, got.GetUser().GetSpec())
93+
assert.Equal(t, "00u2", got.GetUser().GetId().GetSubject())
94+
}
95+
8696
func TestEnrich_UserinfoErrorFallsBackToBase(t *testing.T) {
8797
srv, _ := newTestIdP(t, `nope`, http.StatusUnauthorized)
8898
e := newIdentityEnricher(srv.URL)

runs/service/run_service.go

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@ type RunService struct {
4747
dataStore *storage.DataStore
4848
abortReconciler *AbortReconciler
4949
enricher *identityEnricher
50+
// trustHeaders gates deriving executed_by from proxy-forwarded auth headers.
51+
trustHeaders bool
5052
}
5153

5254
type actionDataClient interface {
@@ -154,6 +156,7 @@ func NewRunService(
154156
dataStore *storage.DataStore,
155157
reconciler *AbortReconciler,
156158
authServerBaseURL string,
159+
trustForwardedIdentityHeaders bool,
157160
) *RunService {
158161
return &RunService{
159162
repo: repo,
@@ -164,6 +167,7 @@ func NewRunService(
164167
dataStore: dataStore,
165168
abortReconciler: reconciler,
166169
enricher: newIdentityEnricher(authServerBaseURL),
170+
trustHeaders: trustForwardedIdentityHeaders,
167171
}
168172
}
169173

@@ -331,8 +335,13 @@ func (s *RunService) CreateRun(
331335
// Capture who created the run from the auth headers the load balancer forwards
332336
// (it enforces auth upstream). nil when there is no authenticated identity. On the
333337
// Bearer path the token carries only the subject, so enrich name/email via userinfo.
334-
executedBy := identityFromHeaders(req.Header())
335-
executedBy = s.enricher.enrich(ctx, accessTokenFromHeaders(req.Header()), executedBy)
338+
// Only trust the proxy-forwarded identity headers when configured to (the JWTs are
339+
// decoded, not signature-verified — see Config.TrustForwardedIdentityHeaders).
340+
var executedBy *common.EnrichedIdentity
341+
if s.trustHeaders {
342+
executedBy = identityFromHeaders(req.Header())
343+
executedBy = s.enricher.enrich(ctx, accessTokenFromHeaders(req.Header()), executedBy)
344+
}
336345

337346
// Persist task spec and create a run model
338347
run, err := s.persistRunModel(ctx, runId, taskID, taskSpec, inputPrefix, runOutputBase, runSpec, request.GetSource(), triggerName, triggerTaskName, triggerRevision, triggerType, executedBy)

runs/setup.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ func Setup(ctx context.Context, sc *app.SetupContext) error {
9393
return abortReconciler.Run(ctx)
9494
})
9595

96-
runsSvc := service.NewRunService(repo, actionsClient, dataProxyClient, projectClient, cfg.StoragePrefix, sc.DataStore, abortReconciler, cfg.AuthMetadata.ExternalAuthServerBaseURL)
96+
runsSvc := service.NewRunService(repo, actionsClient, dataProxyClient, projectClient, cfg.StoragePrefix, sc.DataStore, abortReconciler, cfg.AuthMetadata.ExternalAuthServerBaseURL, cfg.TrustForwardedIdentityHeaders)
9797
taskSvc := service.NewTaskService(repo, projectClient)
9898

9999
runsPath, runsHandler := workflowconnect.NewRunServiceHandler(runsSvc, connect.WithInterceptors(otelInterceptor))

runs/test/api/setup_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ func TestMain(m *testing.M) {
113113
// Create RunService with a no-op actions client (points at test server; not used by watch tests)
114114
endpointURL := fmt.Sprintf("http://localhost:%d", testPort)
115115
actionsClient := actionsconnect.NewActionsServiceClient(http.DefaultClient, endpointURL)
116-
runSvc := service.NewRunService(repo, actionsClient, nil, nil, "", nil, nil, "")
116+
runSvc := service.NewRunService(repo, actionsClient, nil, nil, "", nil, nil, "", true)
117117

118118
// Setup HTTP server
119119
mux := http.NewServeMux()

0 commit comments

Comments
 (0)