Skip to content

Commit 201120b

Browse files
authored
Use JWKS for fetching public keys (#584)
* Use JWKS for fetching public keys * Address comments * Additional tes coverage * Strict test * Remove non-ctx version
1 parent 5191fd8 commit 201120b

4 files changed

Lines changed: 349 additions & 34 deletions

File tree

go.mod

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ require (
77
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.11.1
88
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.6.0
99
github.com/Azure/azure-sdk-for-go/sdk/storage/azblob v1.0.0
10+
github.com/MicahParks/keyfunc/v3 v3.8.0
1011
github.com/auth0/go-auth0 v0.17.2
1112
github.com/aws/aws-sdk-go v1.49.6
1213
github.com/aws/aws-sdk-go-v2 v1.36.3
@@ -17,11 +18,11 @@ require (
1718
github.com/deckarep/golang-set/v2 v2.6.0
1819
github.com/dimchansky/utfbom v1.1.1
1920
github.com/flopp/go-staticmaps v0.0.0-20220221183018-c226716bec53
20-
github.com/form3tech-oss/jwt-go v3.2.5+incompatible
2121
github.com/getkin/kin-openapi v0.133.0
2222
github.com/go-chi/chi/v5 v5.2.2
2323
github.com/go-chi/cors v1.2.1
2424
github.com/go-redis/redis/v8 v8.11.5
25+
github.com/golang-jwt/jwt/v5 v5.3.1
2526
github.com/golang-migrate/migrate/v4 v4.18.3
2627
github.com/golang/geo v0.0.0-20210211234256-740aa86cb551
2728
github.com/google/uuid v1.6.0
@@ -62,6 +63,7 @@ require (
6263
cel.dev/expr v0.25.1 // indirect
6364
github.com/Azure/azure-sdk-for-go/sdk/internal v1.8.0 // indirect
6465
github.com/AzureAD/microsoft-authentication-library-for-go v1.2.2 // indirect
66+
github.com/MicahParks/jwkset v0.11.0 // indirect
6567
github.com/PuerkitoBio/rehttp v1.3.0 // indirect
6668
github.com/Yiling-J/theine-go v0.6.2 // indirect
6769
github.com/agnivade/levenshtein v1.2.1 // indirect
@@ -96,7 +98,6 @@ require (
9698
github.com/go-openapi/jsonpointer v0.21.0 // indirect
9799
github.com/go-openapi/swag v0.23.0 // indirect
98100
github.com/go-viper/mapstructure/v2 v2.4.0 // indirect
99-
github.com/golang-jwt/jwt/v5 v5.3.1 // indirect
100101
github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0 // indirect
101102
github.com/google/cel-go v0.27.0 // indirect
102103
github.com/gorilla/websocket v1.5.0 // indirect
@@ -173,6 +174,7 @@ require (
173174
golang.org/x/sys v0.40.0 // indirect
174175
golang.org/x/term v0.39.0 // indirect
175176
golang.org/x/text v0.34.0 // indirect
177+
golang.org/x/time v0.9.0 // indirect
176178
golang.org/x/tools v0.41.0 // indirect
177179
gonum.org/v1/gonum v0.17.0 // indirect
178180
google.golang.org/genproto/googleapis/api v0.0.0-20260209200024-4cfbd4190f57 // indirect

go.sum

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,10 @@ github.com/IBM/pgxpoolprometheus v1.1.2 h1:sHJwxoL5Lw4R79Zt+H4Uj1zZ4iqXJLdk7XDE7
2424
github.com/IBM/pgxpoolprometheus v1.1.2/go.mod h1:+vWzISN6S9ssgurhUNmm6AlXL9XLah3TdWJktquKTR8=
2525
github.com/Masterminds/squirrel v1.5.4 h1:uUcX/aBc8O7Fg9kaISIUsHXdKuqehiXAMQTYX8afzqM=
2626
github.com/Masterminds/squirrel v1.5.4/go.mod h1:NNaOrjSoIDfDA40n7sr2tPNZRfjzjA400rg+riTZj10=
27+
github.com/MicahParks/jwkset v0.11.0 h1:yc0zG+jCvZpWgFDFmvs8/8jqqVBG9oyIbmBtmjOhoyQ=
28+
github.com/MicahParks/jwkset v0.11.0/go.mod h1:U2oRhRaLgDCLjtpGL2GseNKGmZtLs/3O7p+OZaL5vo0=
29+
github.com/MicahParks/keyfunc/v3 v3.8.0 h1:Hx2dgIjAXGk9slakM6rV9BOeaWDPEXXZ4Us8guNBfds=
30+
github.com/MicahParks/keyfunc/v3 v3.8.0/go.mod h1:z66bkCviwqfg2YUp+Jcc/xRE9IXLcMq6DrgV/+Htru0=
2731
github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY=
2832
github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU=
2933
github.com/PuerkitoBio/rehttp v1.3.0 h1:w54Pb72MQn2eJrSdPsvGqXlAfiK1+NMTGDrOJJ4YvSU=
@@ -149,8 +153,6 @@ github.com/flopp/go-staticmaps v0.0.0-20220221183018-c226716bec53 h1:bpgLIxOpmht
149153
github.com/flopp/go-staticmaps v0.0.0-20220221183018-c226716bec53/go.mod h1:vGgI6wKa1TTiN9iumpzYZgNc/C7KxqsZbw9OH8O10iQ=
150154
github.com/fogleman/gg v1.3.0 h1:/7zJX8F6AaYQc57WQCyN9cAIz+4bCJGO9B+dyW29am8=
151155
github.com/fogleman/gg v1.3.0/go.mod h1:R/bRT+9gY/C5z7JzPU0zXsXHKM4/ayA+zqcVNZzPa1k=
152-
github.com/form3tech-oss/jwt-go v3.2.5+incompatible h1:/l4kBbb4/vGSsdtB5nUe8L7B9mImVMaBPw9L/0TBHU8=
153-
github.com/form3tech-oss/jwt-go v3.2.5+incompatible/go.mod h1:pbq4aXjuKjdthFRnoDwaVPLA+WlJuPGy+QneDUgJi2k=
154156
github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8=
155157
github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0=
156158
github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k=
@@ -552,6 +554,8 @@ golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
552554
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
553555
golang.org/x/text v0.34.0 h1:oL/Qq0Kdaqxa1KbNeMKwQq0reLCCaFtqu2eNuSeNHbk=
554556
golang.org/x/text v0.34.0/go.mod h1:homfLqTYRFyVYemLBFl5GgL/DWEiH5wcsQ5gSh1yziA=
557+
golang.org/x/time v0.9.0 h1:EsRrnYcQiGH+5FfbgvV4AP7qEZstoyrHB0DzarOQ4ZY=
558+
golang.org/x/time v0.9.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
555559
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
556560
golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
557561
golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY=

server/auth/mw/jwtcheck/jwtcheck.go

Lines changed: 85 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,40 +1,99 @@
11
package jwtcheck
22

33
import (
4-
"crypto/rsa"
4+
"context"
55
"encoding/json"
66
"errors"
7-
"io/ioutil"
7+
"fmt"
8+
"io"
89
"net/http"
10+
"os"
911
"strings"
1012

11-
"github.com/form3tech-oss/jwt-go"
13+
keyfunc "github.com/MicahParks/keyfunc/v3"
14+
"github.com/golang-jwt/jwt/v5"
1215
"github.com/interline-io/log"
1316
"github.com/interline-io/transitland-lib/server/auth/authn"
1417
)
1518

1619
// JWTMiddleware checks and pulls user information from JWT in Authorization header.
20+
// The JWT is validated against a static RSA public key loaded from pubKeyPath.
1721
func JWTMiddleware(jwtAudience string, jwtIssuer string, pubKeyPath string, useEmailAsId bool) (func(http.Handler) http.Handler, error) {
18-
var verifyKey *rsa.PublicKey
19-
verifyBytes, err := ioutil.ReadFile(pubKeyPath)
22+
verifyBytes, err := os.ReadFile(pubKeyPath)
2023
if err != nil {
2124
return nil, err
2225
}
23-
verifyKey, err = jwt.ParseRSAPublicKeyFromPEM(verifyBytes)
26+
verifyKey, err := jwt.ParseRSAPublicKeyFromPEM(verifyBytes)
2427
if err != nil {
2528
return nil, err
2629
}
30+
keyFunc := func(token *jwt.Token) (any, error) {
31+
if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok {
32+
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
33+
}
34+
return verifyKey, nil
35+
}
36+
return newJWTHandler(keyFunc, jwtAudience, jwtIssuer, useEmailAsId), nil
37+
}
38+
39+
// JWTMiddlewareOIDC checks and pulls user information from JWT in Authorization header.
40+
// The JWKS keys are discovered from the issuer's OpenID Connect discovery endpoint.
41+
// The context controls the lifetime of the background JWKS refresh goroutine.
42+
func JWTMiddlewareOIDC(ctx context.Context, jwtAudience string, jwtIssuer string, useEmailAsId bool) (func(http.Handler) http.Handler, error) {
43+
jwksURL, err := discoverJWKSURL(ctx, jwtIssuer)
44+
if err != nil {
45+
return nil, fmt.Errorf("OIDC discovery failed: %w", err)
46+
}
47+
kf, err := keyfunc.NewDefaultCtx(ctx, []string{jwksURL})
48+
if err != nil {
49+
return nil, fmt.Errorf("failed to create JWKS keyfunc: %w", err)
50+
}
51+
return newJWTHandler(kf.Keyfunc, jwtAudience, jwtIssuer, useEmailAsId), nil
52+
}
53+
54+
// discoverJWKSURL fetches the OIDC discovery document from the issuer and returns the jwks_uri.
55+
func discoverJWKSURL(ctx context.Context, issuer string) (string, error) {
56+
discoveryURL := strings.TrimRight(issuer, "/") + "/.well-known/openid-configuration"
57+
req, err := http.NewRequestWithContext(ctx, http.MethodGet, discoveryURL, nil)
58+
if err != nil {
59+
return "", fmt.Errorf("failed to create OIDC discovery request: %w", err)
60+
}
61+
resp, err := http.DefaultClient.Do(req)
62+
if err != nil {
63+
return "", fmt.Errorf("failed to fetch OIDC discovery document from %s: %w", discoveryURL, err)
64+
}
65+
defer resp.Body.Close()
66+
if resp.StatusCode != http.StatusOK {
67+
return "", fmt.Errorf("OIDC discovery endpoint %s returned status %d", discoveryURL, resp.StatusCode)
68+
}
69+
body, err := io.ReadAll(resp.Body)
70+
if err != nil {
71+
return "", fmt.Errorf("failed to read OIDC discovery response: %w", err)
72+
}
73+
var doc struct {
74+
JWKSURI string `json:"jwks_uri"`
75+
}
76+
if err := json.Unmarshal(body, &doc); err != nil {
77+
return "", fmt.Errorf("failed to parse OIDC discovery document: %w", err)
78+
}
79+
if doc.JWKSURI == "" {
80+
return "", errors.New("OIDC discovery document missing jwks_uri")
81+
}
82+
return doc.JWKSURI, nil
83+
}
84+
85+
func newJWTHandler(keyFunc jwt.Keyfunc, jwtAudience string, jwtIssuer string, useEmailAsId bool) func(http.Handler) http.Handler {
2786
return func(next http.Handler) http.Handler {
2887
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
29-
if tokenString := strings.Split(r.Header.Get("Authorization"), "Bearer "); len(tokenString) == 2 {
30-
claims, err := validateJwt(verifyKey, jwtAudience, jwtIssuer, tokenString[1])
88+
if tokenString, ok := strings.CutPrefix(r.Header.Get("Authorization"), "Bearer "); ok && tokenString != "" {
89+
claims, err := validateJwt(keyFunc, jwtAudience, jwtIssuer, tokenString)
3190
if err != nil {
3291
log.Error().Err(err).Msgf("invalid jwt token")
3392
writeJsonError(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
3493
return
3594
}
3695
if claims == nil {
37-
log.Error().Err(err).Msgf("no claims")
96+
log.Error().Msgf("no claims")
3897
writeJsonError(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
3998
return
4099
}
@@ -47,32 +106,31 @@ func JWTMiddleware(jwtAudience string, jwtIssuer string, pubKeyPath string, useE
47106
}
48107
next.ServeHTTP(w, r)
49108
})
50-
}, nil
109+
}
51110
}
52111

53-
type CustomClaimsExample struct {
54-
Email string
55-
jwt.StandardClaims
56-
}
112+
// CustomClaimsExample contains the JWT claims used by the middleware.
113+
type CustomClaimsExample = customClaims
57114

58-
func (c *CustomClaimsExample) Valid() error {
59-
return nil
115+
type customClaims struct {
116+
Email string `json:"email"`
117+
jwt.RegisteredClaims
60118
}
61119

62-
func validateJwt(rsaPublicKey *rsa.PublicKey, jwtAudience string, jwtIssuer string, tokenString string) (*CustomClaimsExample, error) {
63-
// Parse the token
64-
token, err := jwt.ParseWithClaims(tokenString, &CustomClaimsExample{}, func(token *jwt.Token) (interface{}, error) {
65-
return rsaPublicKey, nil
66-
})
120+
func validateJwt(keyFunc jwt.Keyfunc, jwtAudience string, jwtIssuer string, tokenString string) (*customClaims, error) {
121+
token, err := jwt.ParseWithClaims(
122+
tokenString,
123+
&customClaims{},
124+
keyFunc,
125+
jwt.WithAudience(jwtAudience),
126+
jwt.WithIssuer(jwtIssuer),
127+
)
67128
if err != nil {
68129
return nil, err
69130
}
70-
claims := token.Claims.(*CustomClaimsExample)
71-
if !claims.VerifyAudience(jwtAudience, true) {
72-
return nil, errors.New("invalid audience")
73-
}
74-
if !claims.VerifyIssuer(jwtIssuer, true) {
75-
return nil, errors.New("invalid issuer")
131+
claims, ok := token.Claims.(*customClaims)
132+
if !ok {
133+
return nil, errors.New("invalid claims type")
76134
}
77135
return claims, nil
78136
}

0 commit comments

Comments
 (0)