Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit c690994

Browse files
authoredJan 6, 2022
Merge pull request #134 from auth0/patch/improve-concurrency
Improve concurrency
2 parents 985899d + 9df4394 commit c690994

File tree

7 files changed

+39
-18
lines changed

7 files changed

+39
-18
lines changed
 

‎Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
.PHONY: test
44
test: ## Run tests.
5-
go test -cover -covermode=atomic -coverprofile=coverage.out ./...
5+
go test -race -cover -covermode=atomic -coverprofile=coverage.out ./...
66

77
.PHONY: lint
88
lint: ## Run golangci-lint.

‎examples/http-example/main.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,9 @@ func main() {
4848

4949
// We want this struct to be filled in with
5050
// our custom claims from the token.
51-
customClaims := &CustomClaimsExample{}
51+
customClaims := func() validator.CustomClaims {
52+
return &CustomClaimsExample{}
53+
}
5254

5355
// Set up the validator.
5456
jwtValidator, err := validator.New(

‎extractor_test.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,8 @@ func Test_AuthHeaderTokenExtractor(t *testing.T) {
4444

4545
for _, testCase := range testCases {
4646
t.Run(testCase.name, func(t *testing.T) {
47+
t.Parallel()
48+
4749
gotToken, err := AuthHeaderTokenExtractor(testCase.request)
4850
if testCase.wantError != "" {
4951
assert.EqualError(t, err, testCase.wantError)
@@ -96,6 +98,8 @@ func Test_CookieTokenExtractor(t *testing.T) {
9698

9799
for _, testCase := range testCases {
98100
t.Run(testCase.name, func(t *testing.T) {
101+
t.Parallel()
102+
99103
request, err := http.NewRequest(http.MethodGet, "https://example.com", nil)
100104
require.NoError(t, err)
101105

‎middleware_test.go

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ func Test_CheckJWT(t *testing.T) {
5151
name: "it can successfully validate a token",
5252
validateToken: jwtValidator.ValidateToken,
5353
token: validToken,
54+
method: http.MethodGet,
5455
wantToken: tokenClaims,
5556
wantStatusCode: http.StatusOK,
5657
wantBody: `{"message":"Authenticated."}`,
@@ -67,19 +68,22 @@ func Test_CheckJWT(t *testing.T) {
6768
{
6869
name: "it fails to validate a token with a bad format",
6970
token: "bad",
71+
method: http.MethodGet,
7072
wantStatusCode: http.StatusInternalServerError,
7173
wantBody: `{"message":"Something went wrong while checking the JWT."}`,
7274
},
7375
{
7476
name: "it fails to validate if token is missing and credentials are not optional",
7577
token: "",
78+
method: http.MethodGet,
7679
wantStatusCode: http.StatusBadRequest,
7780
wantBody: `{"message":"JWT is missing."}`,
7881
},
7982
{
8083
name: "it fails to validate an invalid token",
8184
validateToken: jwtValidator.ValidateToken,
8285
token: invalidToken,
86+
method: http.MethodGet,
8387
wantStatusCode: http.StatusUnauthorized,
8488
wantBody: `{"message":"JWT is invalid."}`,
8589
},
@@ -100,6 +104,7 @@ func Test_CheckJWT(t *testing.T) {
100104
return "", errors.New("token extractor error")
101105
}),
102106
},
107+
method: http.MethodGet,
103108
wantStatusCode: http.StatusInternalServerError,
104109
wantBody: `{"message":"Something went wrong while checking the JWT."}`,
105110
},
@@ -111,6 +116,7 @@ func Test_CheckJWT(t *testing.T) {
111116
return "", nil
112117
}),
113118
},
119+
method: http.MethodGet,
114120
wantStatusCode: http.StatusOK,
115121
wantBody: `{"message":"Authenticated."}`,
116122
},
@@ -123,16 +129,15 @@ func Test_CheckJWT(t *testing.T) {
123129
return "", nil
124130
}),
125131
},
132+
method: http.MethodGet,
126133
wantStatusCode: http.StatusBadRequest,
127134
wantBody: `{"message":"JWT is missing."}`,
128135
},
129136
}
130137

131138
for _, testCase := range testCases {
132139
t.Run(testCase.name, func(t *testing.T) {
133-
if testCase.method == "" {
134-
testCase.method = http.MethodGet
135-
}
140+
t.Parallel()
136141

137142
middleware := New(testCase.validateToken, testCase.options...)
138143

‎validator/option.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,8 @@ func WithAllowedClockSkew(skew time.Duration) Option {
2121
// CustomClaims that will be unmarshalled into and on which
2222
// Validate is called on for custom validation. If this option
2323
// is not used the Validator will do nothing for custom claims.
24-
func WithCustomClaims(c CustomClaims) Option {
24+
func WithCustomClaims(f func() CustomClaims) Option {
2525
return func(v *Validator) {
26-
v.customClaims = c
26+
v.customClaims = f
2727
}
2828
}

‎validator/validator.go

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ type Validator struct {
3131
keyFunc func(context.Context) (interface{}, error) // Required.
3232
signatureAlgorithm SignatureAlgorithm // Required.
3333
expectedClaims jwt.Expected // Internal.
34-
customClaims CustomClaims // Optional.
34+
customClaims func() CustomClaims // Optional.
3535
allowedClockSkew time.Duration // Optional.
3636
}
3737

@@ -114,16 +114,17 @@ func (v *Validator) ValidateToken(ctx context.Context, tokenString string) (inte
114114

115115
claimDest := []interface{}{&jwt.Claims{}}
116116
if v.customClaims != nil {
117-
claimDest = append(claimDest, v.customClaims)
117+
claimDest = append(claimDest, v.customClaims())
118118
}
119119

120120
if err = token.Claims(key, claimDest...); err != nil {
121121
return nil, fmt.Errorf("could not get token claims: %w", err)
122122
}
123123

124124
registeredClaims := *claimDest[0].(*jwt.Claims)
125-
v.expectedClaims.Time = time.Now()
126-
if err = registeredClaims.ValidateWithLeeway(v.expectedClaims, v.allowedClockSkew); err != nil {
125+
expectedClaims := v.expectedClaims
126+
expectedClaims.Time = time.Now()
127+
if err = registeredClaims.ValidateWithLeeway(expectedClaims, v.allowedClockSkew); err != nil {
127128
return nil, fmt.Errorf("expected claims not validated: %w", err)
128129
}
129130

‎validator/validator_test.go

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ func TestValidator_ValidateToken(t *testing.T) {
3030
token string
3131
keyFunc func(context.Context) (interface{}, error)
3232
algorithm SignatureAlgorithm
33-
customClaims CustomClaims
33+
customClaims func() CustomClaims
3434
expectedError error
3535
expectedClaims *ValidatedClaims
3636
}{
@@ -40,6 +40,7 @@ func TestValidator_ValidateToken(t *testing.T) {
4040
keyFunc: func(context.Context) (interface{}, error) {
4141
return []byte("secret"), nil
4242
},
43+
algorithm: HS256,
4344
expectedClaims: &ValidatedClaims{
4445
RegisteredClaims: RegisteredClaims{
4546
Issuer: issuer,
@@ -54,7 +55,10 @@ func TestValidator_ValidateToken(t *testing.T) {
5455
keyFunc: func(context.Context) (interface{}, error) {
5556
return []byte("secret"), nil
5657
},
57-
customClaims: &testClaims{},
58+
algorithm: HS256,
59+
customClaims: func() CustomClaims {
60+
return &testClaims{}
61+
},
5862
expectedClaims: &ValidatedClaims{
5963
RegisteredClaims: RegisteredClaims{
6064
Issuer: issuer,
@@ -81,6 +85,7 @@ func TestValidator_ValidateToken(t *testing.T) {
8185
keyFunc: func(context.Context) (interface{}, error) {
8286
return []byte("secret"), nil
8387
},
88+
algorithm: HS256,
8489
expectedError: errors.New("could not parse the token: square/go-jose: compact JWS format must have three parts"),
8590
},
8691
{
@@ -89,6 +94,7 @@ func TestValidator_ValidateToken(t *testing.T) {
8994
keyFunc: func(context.Context) (interface{}, error) {
9095
return nil, errors.New("key func error message")
9196
},
97+
algorithm: HS256,
9298
expectedError: errors.New("error getting the keys from the key func: key func error message"),
9399
},
94100
{
@@ -97,6 +103,7 @@ func TestValidator_ValidateToken(t *testing.T) {
97103
keyFunc: func(context.Context) (interface{}, error) {
98104
return []byte("secret"), nil
99105
},
106+
algorithm: HS256,
100107
expectedError: errors.New("could not get token claims: square/go-jose: error in cryptographic primitive"),
101108
},
102109
{
@@ -105,6 +112,7 @@ func TestValidator_ValidateToken(t *testing.T) {
105112
keyFunc: func(context.Context) (interface{}, error) {
106113
return []byte("secret"), nil
107114
},
115+
algorithm: HS256,
108116
expectedError: errors.New("expected claims not validated: square/go-jose/jwt: validation failed, invalid audience claim (aud)"),
109117
},
110118
{
@@ -113,18 +121,19 @@ func TestValidator_ValidateToken(t *testing.T) {
113121
keyFunc: func(context.Context) (interface{}, error) {
114122
return []byte("secret"), nil
115123
},
116-
customClaims: &testClaims{
117-
ReturnError: errors.New("custom claims error message"),
124+
algorithm: HS256,
125+
customClaims: func() CustomClaims {
126+
return &testClaims{
127+
ReturnError: errors.New("custom claims error message"),
128+
}
118129
},
119130
expectedError: errors.New("custom claims not validated: custom claims error message"),
120131
},
121132
}
122133

123134
for _, testCase := range testCases {
124135
t.Run(testCase.name, func(t *testing.T) {
125-
if testCase.algorithm == "" {
126-
testCase.algorithm = HS256
127-
}
136+
t.Parallel()
128137

129138
validator, err := New(
130139
testCase.keyFunc,

0 commit comments

Comments
 (0)
Please sign in to comment.