Skip to content

Commit 5238a5d

Browse files
committed
Address comments
1 parent a00a97f commit 5238a5d

1 file changed

Lines changed: 18 additions & 4 deletions

File tree

server/auth/mw/jwtcheck/jwtcheck.go

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
"net/http"
1010
"os"
1111
"strings"
12+
"time"
1213

1314
keyfunc "github.com/MicahParks/keyfunc/v3"
1415
"github.com/golang-jwt/jwt/v5"
@@ -27,7 +28,7 @@ func JWTMiddleware(jwtAudience string, jwtIssuer string, pubKeyPath string, useE
2728
if err != nil {
2829
return nil, err
2930
}
30-
keyFunc := func(token *jwt.Token) (interface{}, error) {
31+
keyFunc := func(token *jwt.Token) (any, error) {
3132
if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok {
3233
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
3334
}
@@ -66,8 +67,18 @@ func JWTMiddlewareOIDCCtx(ctx context.Context, jwtAudience string, jwtIssuer str
6667

6768
// discoverJWKSURL fetches the OIDC discovery document from the issuer and returns the jwks_uri.
6869
func discoverJWKSURL(issuer string) (string, error) {
70+
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
71+
defer cancel()
72+
return discoverJWKSURLWithContext(ctx, issuer)
73+
}
74+
75+
func discoverJWKSURLWithContext(ctx context.Context, issuer string) (string, error) {
6976
discoveryURL := strings.TrimRight(issuer, "/") + "/.well-known/openid-configuration"
70-
resp, err := http.Get(discoveryURL)
77+
req, err := http.NewRequestWithContext(ctx, http.MethodGet, discoveryURL, nil)
78+
if err != nil {
79+
return "", fmt.Errorf("failed to create OIDC discovery request: %w", err)
80+
}
81+
resp, err := http.DefaultClient.Do(req)
7182
if err != nil {
7283
return "", fmt.Errorf("failed to fetch OIDC discovery document from %s: %w", discoveryURL, err)
7384
}
@@ -94,8 +105,8 @@ func discoverJWKSURL(issuer string) (string, error) {
94105
func newJWTHandler(keyFunc jwt.Keyfunc, jwtAudience string, jwtIssuer string, useEmailAsId bool) func(http.Handler) http.Handler {
95106
return func(next http.Handler) http.Handler {
96107
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
97-
if tokenString := strings.Split(r.Header.Get("Authorization"), "Bearer "); len(tokenString) == 2 {
98-
claims, err := validateJwt(keyFunc, jwtAudience, jwtIssuer, tokenString[1])
108+
if tokenString, ok := strings.CutPrefix(r.Header.Get("Authorization"), "Bearer "); ok && tokenString != "" {
109+
claims, err := validateJwt(keyFunc, jwtAudience, jwtIssuer, tokenString)
99110
if err != nil {
100111
log.Error().Err(err).Msgf("invalid jwt token")
101112
writeJsonError(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
@@ -118,6 +129,9 @@ func newJWTHandler(keyFunc jwt.Keyfunc, jwtAudience string, jwtIssuer string, us
118129
}
119130
}
120131

132+
// CustomClaimsExample contains the JWT claims used by the middleware.
133+
type CustomClaimsExample = customClaims
134+
121135
type customClaims struct {
122136
Email string `json:"email"`
123137
jwt.RegisteredClaims

0 commit comments

Comments
 (0)