Skip to content

Commit a18c406

Browse files
fix(auth-callout): harden oauth jwt validation (#49)
Signed-off-by: Frank Spitulski <fspitulski@nvidia.com>
1 parent f318101 commit a18c406

14 files changed

Lines changed: 166 additions & 19 deletions

File tree

auth-callout/README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ nats:
3535
jwks:
3636
url: "https://keycloak/realms/master/protocol/openid-connect/certs"
3737
issuer: "https://keycloak/realms/master"
38+
audience: "dsx-exchange"
3839

3940
mtls:
4041
ca-path: "/etc/ssl/certs/ca.crt"

auth-callout/api/env.example

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,3 +8,4 @@ NATS_URL=nats://nats:4222
88
# JWKS URL for OAuth2 token validation
99
JWKS_URL=http://keycloak.127-0-0-1.nip.io:8080/realms/auth-callout/protocol/openid-connect/certs
1010
JWKS_ISSUER=http://keycloak.127-0-0-1.nip.io:8080/realms/auth-callout
11+
JWKS_AUDIENCE=dsx-exchange

auth-callout/deploy/README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ serviceConfig:
8787
jwks:
8888
url: "https://keycloak/realms/master/protocol/openid-connect/certs"
8989
issuer: "https://keycloak/realms/master"
90+
audience: "dsx-exchange"
9091
```
9192

9293
### mTLS Configuration

auth-callout/deploy/values.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,7 @@ serviceConfig:
201201
jwks:
202202
url: "" # JWKS endpoint URL (e.g., "https://auth.example.com/.well-known/jwks.json")
203203
issuer: "" # Expected JWT issuer
204+
audience: "" # Expected JWT audience
204205

205206
# mTLS configuration
206207
mtls:

auth-callout/devspace.yaml

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -225,7 +225,20 @@ deployments:
225225
"attributes": {
226226
"include.in.token.scope": "true",
227227
"display.on.consent.screen": "true"
228-
}
228+
},
229+
"protocolMappers": [
230+
{
231+
"name": "dsx-exchange-audience",
232+
"protocol": "openid-connect",
233+
"protocolMapper": "oidc-audience-mapper",
234+
"consentRequired": false,
235+
"config": {
236+
"included.custom.audience": "dsx-exchange",
237+
"id.token.claim": "false",
238+
"access.token.claim": "true"
239+
}
240+
}
241+
]
229242
}
230243
],
231244
"users": []
@@ -371,6 +384,7 @@ deployments:
371384
jwks:
372385
url: "http://keycloak-service.keycloak:8080/realms/auth-callout/protocol/openid-connect/certs"
373386
issuer: "http://keycloak-service.keycloak:8080/realms/auth-callout"
387+
audience: "dsx-exchange"
374388
permissions:
375389
oauth2:
376390
csc-admin:

auth-callout/env.example

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ NATS_URL=nats://nats:4222
88
# JWKS URL for OAuth2 token validation
99
JWKS_URL=http://keycloak.127-0-0-1.nip.io:8080/realms/auth-callout/protocol/openid-connect/certs
1010
JWKS_ISSUER=http://keycloak.127-0-0-1.nip.io:8080/realms/auth-callout
11+
JWKS_AUDIENCE=dsx-exchange
1112

1213
# NATS keys are secrets and must come from Vault or generated local files.
1314
# Generate development values with scripts/devspace-get-key.sh or nsc.

auth-callout/src/cmd/auth_callout/config/defaults.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ nats:
1919
jwks:
2020
url: "" # JWKS endpoint URL
2121
issuer: "" # Expected JWT issuer
22+
audience: "" # Expected JWT audience
2223

2324
# mTLS configuration
2425
mtls:

auth-callout/src/internal/appconfig/manager.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,7 @@ func applyAliases(k *koanf.Koanf) {
107107
setString(k, "nats.xkey-seed", "NATS_XKEY_SEED", envPrefix+"NATS_XKEY_SEED")
108108
setString(k, "jwks.url", "JWKS_URL", envPrefix+"JWKS_URL")
109109
setString(k, "jwks.issuer", "JWKS_ISSUER", envPrefix+"JWKS_ISSUER")
110+
setString(k, "jwks.audience", "JWKS_AUDIENCE", envPrefix+"JWKS_AUDIENCE")
110111
setString(k, "mtls.ca-path", "MTLS_CA_PATH", envPrefix+"MTLS_CA_PATH")
111112
setString(k, "permissions.file", "PERMISSIONS_FILE", envPrefix+"PERMISSIONS_FILE")
112113
setString(k, "observability.telemetry.service-name", envPrefix+"SERVICE_NAME")

auth-callout/src/internal/auth/auth_test.go

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@ func TestOAuth2Authentication(t *testing.T) {
8383
oauth2Auth, err := NewOAuth2Authenticator(
8484
jwksServer.URL,
8585
"https://auth.example.com/",
86+
"test-audience",
8687
pm,
8788
testLogger(),
8889
testServiceName,
@@ -179,6 +180,44 @@ func TestOAuth2Authentication(t *testing.T) {
179180
expectError: true,
180181
expectedErrMsg: "missing required scope: mqtt",
181182
},
183+
{
184+
name: "missing expiration fails",
185+
claims: gojwt.MapClaims{
186+
"iss": "https://auth.example.com/",
187+
"sub": "user@example.com",
188+
"aud": "test-audience",
189+
"iat": now.Unix(),
190+
"scope": "mqtt",
191+
},
192+
expectError: true,
193+
expectedErrMsg: "token is missing required claim: exp claim is required",
194+
},
195+
{
196+
name: "wrong issuer fails",
197+
claims: gojwt.MapClaims{
198+
"iss": "https://other.example.com/",
199+
"sub": "user@example.com",
200+
"aud": "test-audience",
201+
"exp": now.Add(1 * time.Hour).Unix(),
202+
"iat": now.Unix(),
203+
"scope": "mqtt",
204+
},
205+
expectError: true,
206+
expectedErrMsg: "token has invalid issuer",
207+
},
208+
{
209+
name: "wrong audience fails",
210+
claims: gojwt.MapClaims{
211+
"iss": "https://auth.example.com/",
212+
"sub": "user@example.com",
213+
"aud": "wrong-audience",
214+
"exp": now.Add(1 * time.Hour).Unix(),
215+
"iat": now.Unix(),
216+
"scope": "mqtt",
217+
},
218+
expectError: true,
219+
expectedErrMsg: "token has invalid audience",
220+
},
182221
}
183222

184223
for _, tt := range tests {
@@ -207,6 +246,66 @@ func TestOAuth2Authentication(t *testing.T) {
207246
}
208247
}
209248

249+
func TestOAuth2RejectsUnexpectedSigningMethod(t *testing.T) {
250+
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
251+
require.NoError(t, err)
252+
253+
jwkSet := jwkset.NewMemoryStorage()
254+
jwk, err := jwkset.NewJWKFromKey(privateKey, jwkset.JWKOptions{
255+
Metadata: jwkset.JWKMetadataOptions{
256+
KID: "test-key-1",
257+
},
258+
})
259+
require.NoError(t, err)
260+
require.NoError(t, jwkSet.KeyWrite(context.Background(), jwk))
261+
262+
jwksServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
263+
jwks, err := jwkSet.JSONPublic(context.Background())
264+
if err != nil {
265+
http.Error(w, "Failed to get JWKS", http.StatusInternalServerError)
266+
return
267+
}
268+
w.Header().Set("Content-Type", "application/json")
269+
if _, err := w.Write(jwks); err != nil {
270+
http.Error(w, "Failed to write JWKS", http.StatusInternalServerError)
271+
}
272+
}))
273+
defer jwksServer.Close()
274+
275+
permFile := createTestPermissionsFile(t)
276+
defer os.Remove(permFile)
277+
278+
pm, err := config.NewPermissionsManager(permFile, testLogger())
279+
require.NoError(t, err)
280+
defer pm.Close()
281+
282+
oauth2Auth, err := NewOAuth2Authenticator(
283+
jwksServer.URL,
284+
"https://auth.example.com/",
285+
"test-audience",
286+
pm,
287+
testLogger(),
288+
testServiceName,
289+
)
290+
require.NoError(t, err)
291+
defer oauth2Auth.Close()
292+
293+
token := gojwt.NewWithClaims(gojwt.SigningMethodHS256, gojwt.MapClaims{
294+
"iss": "https://auth.example.com/",
295+
"sub": "user@example.com",
296+
"aud": "test-audience",
297+
"exp": time.Now().Add(1 * time.Hour).Unix(),
298+
"scope": "mqtt",
299+
})
300+
token.Header["kid"] = "test-key-1"
301+
tokenString, err := token.SignedString([]byte("secret"))
302+
require.NoError(t, err)
303+
304+
_, err = oauth2Auth.Authenticate(context.Background(), tokenString)
305+
require.Error(t, err)
306+
assert.Contains(t, err.Error(), "signing method HS256 is invalid")
307+
}
308+
210309
// TestOAuth2RequiredScope tests per-client required scope validation
211310
func TestOAuth2RequiredScope(t *testing.T) {
212311
// Generate RSA key pair for JWT signing
@@ -278,6 +377,7 @@ func TestOAuth2RequiredScope(t *testing.T) {
278377
oauth2Auth, err := NewOAuth2Authenticator(
279378
jwksServer.URL,
280379
"https://auth.example.com/",
380+
"test-audience",
281381
pm,
282382
testLogger(),
283383
testServiceName,
@@ -299,6 +399,7 @@ func TestOAuth2RequiredScope(t *testing.T) {
299399
claims: gojwt.MapClaims{
300400
"iss": "https://auth.example.com/",
301401
"sub": "default@example.com",
402+
"aud": "test-audience",
302403
"exp": now.Add(1 * time.Hour).Unix(),
303404
"iat": now.Unix(),
304405
"scope": "mqtt openid",
@@ -311,6 +412,7 @@ func TestOAuth2RequiredScope(t *testing.T) {
311412
claims: gojwt.MapClaims{
312413
"iss": "https://auth.example.com/",
313414
"sub": "default@example.com",
415+
"aud": "test-audience",
314416
"exp": now.Add(1 * time.Hour).Unix(),
315417
"iat": now.Unix(),
316418
"scope": "openid profile",
@@ -324,6 +426,7 @@ func TestOAuth2RequiredScope(t *testing.T) {
324426
"iss": "https://auth.example.com/",
325427
"sub": "some-service",
326428
"azp": "custom-client-id",
429+
"aud": "test-audience",
327430
"exp": now.Add(1 * time.Hour).Unix(),
328431
"iat": now.Unix(),
329432
"scope": "nats:events openid",
@@ -337,6 +440,7 @@ func TestOAuth2RequiredScope(t *testing.T) {
337440
"iss": "https://auth.example.com/",
338441
"sub": "some-service",
339442
"azp": "custom-client-id",
443+
"aud": "test-audience",
340444
"exp": now.Add(1 * time.Hour).Unix(),
341445
"iat": now.Unix(),
342446
"scope": "mqtt openid",

auth-callout/src/internal/auth/oauth2.go

Lines changed: 19 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -26,14 +26,22 @@ type OAuth2Authenticator struct {
2626
jwks keyfunc.Keyfunc
2727
pm *config.PermissionsManager
2828
issuer string
29+
audience string
2930
jwksURL string
3031
logger *otelzap.Logger
3132
serviceName string
3233
cancel context.CancelFunc
3334
}
3435

3536
// NewOAuth2Authenticator creates a new OAuth2 authenticator
36-
func NewOAuth2Authenticator(jwksURL string, issuer string, pm *config.PermissionsManager, logger *otelzap.Logger, serviceName string) (*OAuth2Authenticator, error) {
37+
func NewOAuth2Authenticator(jwksURL string, issuer string, audience string, pm *config.PermissionsManager, logger *otelzap.Logger, serviceName string) (*OAuth2Authenticator, error) {
38+
if issuer == "" {
39+
return nil, fmt.Errorf("OAuth2 issuer is required")
40+
}
41+
if audience == "" {
42+
return nil, fmt.Errorf("OAuth2 audience is required")
43+
}
44+
3745
// Create JWKS client with automatic refresh - context controls lifecycle
3846
ctx, cancel := context.WithCancel(context.Background())
3947
k, err := keyfunc.NewDefaultCtx(ctx, []string{jwksURL})
@@ -48,6 +56,7 @@ func NewOAuth2Authenticator(jwksURL string, issuer string, pm *config.Permission
4856
jwks: k,
4957
pm: pm,
5058
issuer: issuer,
59+
audience: audience,
5160
jwksURL: jwksURL,
5261
logger: logger,
5362
serviceName: serviceName,
@@ -78,7 +87,15 @@ func (o *OAuth2Authenticator) Authenticate(ctx context.Context, token string) (*
7887
}
7988

8089
// Parse and validate the JWT token
81-
parsed, err := jwt.ParseWithClaims(token, &Claims{}, o.jwks.Keyfunc)
90+
parsed, err := jwt.ParseWithClaims(
91+
token,
92+
&Claims{},
93+
o.jwks.Keyfunc,
94+
jwt.WithValidMethods([]string{jwt.SigningMethodRS256.Alg()}),
95+
jwt.WithExpirationRequired(),
96+
jwt.WithIssuer(o.issuer),
97+
jwt.WithAudience(o.audience),
98+
)
8299
if err != nil {
83100
if counter, err := meter.Int64Counter("auth_oauth2_failures_total",
84101
metric.WithDescription("Total OAuth2 authentication failures")); err == nil {
@@ -113,18 +130,6 @@ func (o *OAuth2Authenticator) Authenticate(ctx context.Context, token string) (*
113130
return nil, fmt.Errorf("invalid claims type")
114131
}
115132

116-
// Validate issuer if configured
117-
if o.issuer != "" && claims.Issuer != o.issuer {
118-
if counter, err := meter.Int64Counter("auth_oauth2_failures_total",
119-
metric.WithDescription("Total OAuth2 authentication failures")); err == nil {
120-
counter.Add(ctx, 1, metric.WithAttributes(
121-
attribute.String("method", "oauth2"),
122-
attribute.String("reason", "invalid_issuer"),
123-
))
124-
}
125-
return nil, fmt.Errorf("invalid issuer: expected %s, got %s", o.issuer, claims.Issuer)
126-
}
127-
128133
// Look up user profile in permissions config using both subject and azp
129134
profile, requiredScope, ok := o.pm.GetOAuth2Profile(claims.Subject, claims.Azp)
130135
if !ok {

0 commit comments

Comments
 (0)