11package jwtcheck
22
33import (
4- "crypto/rsa "
4+ "context "
55 "encoding/json"
66 "errors"
7- "io/ioutil"
7+ "fmt"
8+ "io"
89 "net/http"
10+ "os"
911 "strings"
1012
11- "github.com/form3tech-oss/jwt-go"
13+ keyfunc "github.com/MicahParks/keyfunc/v3"
14+ "github.com/golang-jwt/jwt/v5"
1215 "github.com/interline-io/log"
1316 "github.com/interline-io/transitland-lib/server/auth/authn"
1417)
1518
1619// JWTMiddleware checks and pulls user information from JWT in Authorization header.
20+ // The JWT is validated against a static RSA public key loaded from pubKeyPath.
1721func JWTMiddleware (jwtAudience string , jwtIssuer string , pubKeyPath string , useEmailAsId bool ) (func (http.Handler ) http.Handler , error ) {
18- var verifyKey * rsa.PublicKey
19- verifyBytes , err := ioutil .ReadFile (pubKeyPath )
22+ verifyBytes , err := os .ReadFile (pubKeyPath )
2023 if err != nil {
2124 return nil , err
2225 }
23- verifyKey , err = jwt .ParseRSAPublicKeyFromPEM (verifyBytes )
26+ verifyKey , err : = jwt .ParseRSAPublicKeyFromPEM (verifyBytes )
2427 if err != nil {
2528 return nil , err
2629 }
30+ keyFunc := func (token * jwt.Token ) (any , error ) {
31+ if _ , ok := token .Method .(* jwt.SigningMethodRSA ); ! ok {
32+ return nil , fmt .Errorf ("unexpected signing method: %v" , token .Header ["alg" ])
33+ }
34+ return verifyKey , nil
35+ }
36+ return newJWTHandler (keyFunc , jwtAudience , jwtIssuer , useEmailAsId ), nil
37+ }
38+
39+ // JWTMiddlewareOIDC checks and pulls user information from JWT in Authorization header.
40+ // The JWKS keys are discovered from the issuer's OpenID Connect discovery endpoint.
41+ // The context controls the lifetime of the background JWKS refresh goroutine.
42+ func JWTMiddlewareOIDC (ctx context.Context , jwtAudience string , jwtIssuer string , useEmailAsId bool ) (func (http.Handler ) http.Handler , error ) {
43+ jwksURL , err := discoverJWKSURL (ctx , jwtIssuer )
44+ if err != nil {
45+ return nil , fmt .Errorf ("OIDC discovery failed: %w" , err )
46+ }
47+ kf , err := keyfunc .NewDefaultCtx (ctx , []string {jwksURL })
48+ if err != nil {
49+ return nil , fmt .Errorf ("failed to create JWKS keyfunc: %w" , err )
50+ }
51+ return newJWTHandler (kf .Keyfunc , jwtAudience , jwtIssuer , useEmailAsId ), nil
52+ }
53+
54+ // discoverJWKSURL fetches the OIDC discovery document from the issuer and returns the jwks_uri.
55+ func discoverJWKSURL (ctx context.Context , issuer string ) (string , error ) {
56+ discoveryURL := strings .TrimRight (issuer , "/" ) + "/.well-known/openid-configuration"
57+ req , err := http .NewRequestWithContext (ctx , http .MethodGet , discoveryURL , nil )
58+ if err != nil {
59+ return "" , fmt .Errorf ("failed to create OIDC discovery request: %w" , err )
60+ }
61+ resp , err := http .DefaultClient .Do (req )
62+ if err != nil {
63+ return "" , fmt .Errorf ("failed to fetch OIDC discovery document from %s: %w" , discoveryURL , err )
64+ }
65+ defer resp .Body .Close ()
66+ if resp .StatusCode != http .StatusOK {
67+ return "" , fmt .Errorf ("OIDC discovery endpoint %s returned status %d" , discoveryURL , resp .StatusCode )
68+ }
69+ body , err := io .ReadAll (resp .Body )
70+ if err != nil {
71+ return "" , fmt .Errorf ("failed to read OIDC discovery response: %w" , err )
72+ }
73+ var doc struct {
74+ JWKSURI string `json:"jwks_uri"`
75+ }
76+ if err := json .Unmarshal (body , & doc ); err != nil {
77+ return "" , fmt .Errorf ("failed to parse OIDC discovery document: %w" , err )
78+ }
79+ if doc .JWKSURI == "" {
80+ return "" , errors .New ("OIDC discovery document missing jwks_uri" )
81+ }
82+ return doc .JWKSURI , nil
83+ }
84+
85+ func newJWTHandler (keyFunc jwt.Keyfunc , jwtAudience string , jwtIssuer string , useEmailAsId bool ) func (http.Handler ) http.Handler {
2786 return func (next http.Handler ) http.Handler {
2887 return http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
29- if tokenString := strings .Split (r .Header .Get ("Authorization" ), "Bearer " ); len ( tokenString ) == 2 {
30- claims , err := validateJwt (verifyKey , jwtAudience , jwtIssuer , tokenString [ 1 ] )
88+ if tokenString , ok := strings .CutPrefix (r .Header .Get ("Authorization" ), "Bearer " ); ok && tokenString != "" {
89+ claims , err := validateJwt (keyFunc , jwtAudience , jwtIssuer , tokenString )
3190 if err != nil {
3291 log .Error ().Err (err ).Msgf ("invalid jwt token" )
3392 writeJsonError (w , http .StatusText (http .StatusUnauthorized ), http .StatusUnauthorized )
3493 return
3594 }
3695 if claims == nil {
37- log .Error ().Err ( err ). Msgf ("no claims" )
96+ log .Error ().Msgf ("no claims" )
3897 writeJsonError (w , http .StatusText (http .StatusUnauthorized ), http .StatusUnauthorized )
3998 return
4099 }
@@ -47,32 +106,31 @@ func JWTMiddleware(jwtAudience string, jwtIssuer string, pubKeyPath string, useE
47106 }
48107 next .ServeHTTP (w , r )
49108 })
50- }, nil
109+ }
51110}
52111
53- type CustomClaimsExample struct {
54- Email string
55- jwt.StandardClaims
56- }
112+ // CustomClaimsExample contains the JWT claims used by the middleware.
113+ type CustomClaimsExample = customClaims
57114
58- func (c * CustomClaimsExample ) Valid () error {
59- return nil
115+ type customClaims struct {
116+ Email string `json:"email"`
117+ jwt.RegisteredClaims
60118}
61119
62- func validateJwt (rsaPublicKey * rsa.PublicKey , jwtAudience string , jwtIssuer string , tokenString string ) (* CustomClaimsExample , error ) {
63- // Parse the token
64- token , err := jwt .ParseWithClaims (tokenString , & CustomClaimsExample {}, func (token * jwt.Token ) (interface {}, error ) {
65- return rsaPublicKey , nil
66- })
120+ func validateJwt (keyFunc jwt.Keyfunc , jwtAudience string , jwtIssuer string , tokenString string ) (* customClaims , error ) {
121+ token , err := jwt .ParseWithClaims (
122+ tokenString ,
123+ & customClaims {},
124+ keyFunc ,
125+ jwt .WithAudience (jwtAudience ),
126+ jwt .WithIssuer (jwtIssuer ),
127+ )
67128 if err != nil {
68129 return nil , err
69130 }
70- claims := token .Claims .(* CustomClaimsExample )
71- if ! claims .VerifyAudience (jwtAudience , true ) {
72- return nil , errors .New ("invalid audience" )
73- }
74- if ! claims .VerifyIssuer (jwtIssuer , true ) {
75- return nil , errors .New ("invalid issuer" )
131+ claims , ok := token .Claims .(* customClaims )
132+ if ! ok {
133+ return nil , errors .New ("invalid claims type" )
76134 }
77135 return claims , nil
78136}
0 commit comments