77 "net/http"
88 "net/url"
99 "testing"
10- "time"
1110
1211 "golang.org/x/oauth2"
1312)
@@ -23,107 +22,100 @@ func (e *testHTTPError) Error() string { return fmt.Sprintf("http %d: %s",
2322func (e * testHTTPError ) HTTPStatusCode () int { return e .code }
2423
2524func TestRetryingTokenSource (t * testing.T ) {
26- validToken := & oauth2.Token {
27- AccessToken : "test-token" ,
28- Expiry : time .Now ().Add (time .Hour ),
29- }
25+ err401 := & testHTTPError {code : 401 , message : "unauthorized" }
26+ err400 := & testHTTPError {code : 400 , message : "bad request" }
27+ token := & oauth2.Token {AccessToken : "test-token" }
3028
3129 testCases := []struct {
32- name string
33- callErrors []error
34- wantToken bool
35- wantErr bool
36- wantCalls int
30+ name string
31+ callErrors []error
32+ wantToken * oauth2. Token
33+ wantErr error
34+ wantNumCalls int
3735 }{
3836 {
39- name : "success on first call" ,
40- callErrors : []error {nil },
41- wantToken : true ,
42- wantCalls : 1 ,
37+ name : "success on first call" ,
38+ callErrors : []error {nil },
39+ wantToken : token ,
40+ wantNumCalls : 1 ,
4341 },
4442 {
4543 name : "retry on http 429 then succeed" ,
4644 callErrors : []error {
4745 & testHTTPError {code : 429 , message : "rate limited" },
4846 nil ,
4947 },
50- wantToken : true ,
51- wantCalls : 2 ,
48+ wantToken : token ,
49+ wantNumCalls : 2 ,
5250 },
5351 {
5452 name : "retry on http 503 then succeed" ,
5553 callErrors : []error {
5654 & testHTTPError {code : 503 , message : "service unavailable" },
5755 nil ,
5856 },
59- wantToken : true ,
60- wantCalls : 2 ,
57+ wantToken : token ,
58+ wantNumCalls : 2 ,
6159 },
6260 {
6361 name : "retry on oauth2 retrieve error 429" ,
6462 callErrors : []error {
6563 & oauth2.RetrieveError {Response : & http.Response {StatusCode : 429 }},
6664 nil ,
6765 },
68- wantToken : true ,
69- wantCalls : 2 ,
66+ wantToken : token ,
67+ wantNumCalls : 2 ,
7068 },
7169 {
7270 name : "retry on transient network error" ,
7371 callErrors : []error {
7472 & url.Error {Op : "Post" , URL : "https://host/token" , Err : fmt .Errorf ("connection reset by peer" )},
7573 nil ,
7674 },
77- wantToken : true ,
78- wantCalls : 2 ,
75+ wantToken : token ,
76+ wantNumCalls : 2 ,
7977 },
8078 {
8179 name : "no retry on http 401" ,
8280 callErrors : []error {
83- & testHTTPError { code : 401 , message : "unauthorized" } ,
81+ err401 ,
8482 },
85- wantErr : true ,
86- wantCalls : 1 ,
83+ wantErr : err401 ,
84+ wantNumCalls : 1 ,
8785 },
8886 {
8987 name : "no retry on http 400" ,
9088 callErrors : []error {
91- & testHTTPError { code : 400 , message : "bad request" } ,
89+ err400 ,
9290 },
93- wantErr : true ,
94- wantCalls : 1 ,
91+ wantErr : err400 ,
92+ wantNumCalls : 1 ,
9593 },
9694 }
9795
9896 for _ , tc := range testCases {
9997 t .Run (tc .name , func (t * testing.T ) {
100- callCount := 0
98+ gotNumCalls := 0
10199 inner := TokenSourceFn (func (ctx context.Context ) (* oauth2.Token , error ) {
102- err := tc .callErrors [callCount ]
103- callCount ++
100+ err := tc .callErrors [gotNumCalls ]
101+ gotNumCalls ++
104102 if err != nil {
105103 return nil , err
106104 }
107- return validToken , nil
105+ return token , nil
108106 })
109107
110108 ts := NewRetryingTokenSource (inner )
111- tok , err := ts .Token (context .Background ())
109+ gotToken , gotErr := ts .Token (context .Background ())
112110
113- if callCount != tc .wantCalls {
114- t .Errorf ("got %d calls, want %d" , callCount , tc .wantCalls )
115- }
116- if tc .wantErr && err == nil {
117- t .Error ("got nil error, want error" )
118- }
119- if ! tc .wantErr && err != nil {
120- t .Errorf ("got error %v, want nil" , err )
111+ if gotNumCalls != tc .wantNumCalls {
112+ t .Errorf ("got %d calls, want %d" , gotNumCalls , tc .wantNumCalls )
121113 }
122- if tc . wantToken && tok == nil {
123- t .Error ("got nil token , want token" )
114+ if ! errors . Is ( gotErr , tc . wantErr ) {
115+ t .Errorf ("got error %v , want %v" , gotErr , tc . wantErr )
124116 }
125- if ! tc . wantToken && tok != nil {
126- t .Errorf ("got token %v, want nil " , tok )
117+ if gotToken != tc . wantToken {
118+ t .Errorf ("got token %v, want %v " , gotToken , tc . wantToken )
127119 }
128120 })
129121 }
0 commit comments