Skip to content

Commit 30c3148

Browse files
Clean up retrying token source tests
Simplify test assertions: compare token and error values directly instead of using boolean flags. Rename callCount to gotNumCalls for clarity. Co-authored-by: Isaac Signed-off-by: Ubuntu <renaud.hartert@databricks.com>
1 parent dd57752 commit 30c3148

File tree

1 file changed

+37
-45
lines changed

1 file changed

+37
-45
lines changed

config/experimental/auth/retrying_token_source_test.go

Lines changed: 37 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@ import (
77
"net/http"
88
"net/url"
99
"testing"
10-
"time"
1110

1211
"golang.org/x/oauth2"
1312
)
@@ -23,107 +22,100 @@ func (e *testHTTPError) Error() string { return fmt.Sprintf("http %d: %s",
2322
func (e *testHTTPError) HTTPStatusCode() int { return e.code }
2423

2524
func TestRetryingTokenSource(t *testing.T) {
26-
validToken := &oauth2.Token{
27-
AccessToken: "test-token",
28-
Expiry: time.Now().Add(time.Hour),
29-
}
25+
err401 := &testHTTPError{code: 401, message: "unauthorized"}
26+
err400 := &testHTTPError{code: 400, message: "bad request"}
27+
token := &oauth2.Token{AccessToken: "test-token"}
3028

3129
testCases := []struct {
32-
name string
33-
callErrors []error
34-
wantToken bool
35-
wantErr bool
36-
wantCalls int
30+
name string
31+
callErrors []error
32+
wantToken *oauth2.Token
33+
wantErr error
34+
wantNumCalls int
3735
}{
3836
{
39-
name: "success on first call",
40-
callErrors: []error{nil},
41-
wantToken: true,
42-
wantCalls: 1,
37+
name: "success on first call",
38+
callErrors: []error{nil},
39+
wantToken: token,
40+
wantNumCalls: 1,
4341
},
4442
{
4543
name: "retry on http 429 then succeed",
4644
callErrors: []error{
4745
&testHTTPError{code: 429, message: "rate limited"},
4846
nil,
4947
},
50-
wantToken: true,
51-
wantCalls: 2,
48+
wantToken: token,
49+
wantNumCalls: 2,
5250
},
5351
{
5452
name: "retry on http 503 then succeed",
5553
callErrors: []error{
5654
&testHTTPError{code: 503, message: "service unavailable"},
5755
nil,
5856
},
59-
wantToken: true,
60-
wantCalls: 2,
57+
wantToken: token,
58+
wantNumCalls: 2,
6159
},
6260
{
6361
name: "retry on oauth2 retrieve error 429",
6462
callErrors: []error{
6563
&oauth2.RetrieveError{Response: &http.Response{StatusCode: 429}},
6664
nil,
6765
},
68-
wantToken: true,
69-
wantCalls: 2,
66+
wantToken: token,
67+
wantNumCalls: 2,
7068
},
7169
{
7270
name: "retry on transient network error",
7371
callErrors: []error{
7472
&url.Error{Op: "Post", URL: "https://host/token", Err: fmt.Errorf("connection reset by peer")},
7573
nil,
7674
},
77-
wantToken: true,
78-
wantCalls: 2,
75+
wantToken: token,
76+
wantNumCalls: 2,
7977
},
8078
{
8179
name: "no retry on http 401",
8280
callErrors: []error{
83-
&testHTTPError{code: 401, message: "unauthorized"},
81+
err401,
8482
},
85-
wantErr: true,
86-
wantCalls: 1,
83+
wantErr: err401,
84+
wantNumCalls: 1,
8785
},
8886
{
8987
name: "no retry on http 400",
9088
callErrors: []error{
91-
&testHTTPError{code: 400, message: "bad request"},
89+
err400,
9290
},
93-
wantErr: true,
94-
wantCalls: 1,
91+
wantErr: err400,
92+
wantNumCalls: 1,
9593
},
9694
}
9795

9896
for _, tc := range testCases {
9997
t.Run(tc.name, func(t *testing.T) {
100-
callCount := 0
98+
gotNumCalls := 0
10199
inner := TokenSourceFn(func(ctx context.Context) (*oauth2.Token, error) {
102-
err := tc.callErrors[callCount]
103-
callCount++
100+
err := tc.callErrors[gotNumCalls]
101+
gotNumCalls++
104102
if err != nil {
105103
return nil, err
106104
}
107-
return validToken, nil
105+
return token, nil
108106
})
109107

110108
ts := NewRetryingTokenSource(inner)
111-
tok, err := ts.Token(context.Background())
109+
gotToken, gotErr := ts.Token(context.Background())
112110

113-
if callCount != tc.wantCalls {
114-
t.Errorf("got %d calls, want %d", callCount, tc.wantCalls)
115-
}
116-
if tc.wantErr && err == nil {
117-
t.Error("got nil error, want error")
118-
}
119-
if !tc.wantErr && err != nil {
120-
t.Errorf("got error %v, want nil", err)
111+
if gotNumCalls != tc.wantNumCalls {
112+
t.Errorf("got %d calls, want %d", gotNumCalls, tc.wantNumCalls)
121113
}
122-
if tc.wantToken && tok == nil {
123-
t.Error("got nil token, want token")
114+
if !errors.Is(gotErr, tc.wantErr) {
115+
t.Errorf("got error %v, want %v", gotErr, tc.wantErr)
124116
}
125-
if !tc.wantToken && tok != nil {
126-
t.Errorf("got token %v, want nil", tok)
117+
if gotToken != tc.wantToken {
118+
t.Errorf("got token %v, want %v", gotToken, tc.wantToken)
127119
}
128120
})
129121
}

0 commit comments

Comments
 (0)