@@ -33,6 +33,22 @@ func (rt RoundTripperFunc) RoundTrip(r *http.Request) (*http.Response, error) {
3333 return rt (r )
3434}
3535
36+ // trackingProvider wraps a credentials.Provider and tracks how many times Retrieve() is called.
37+ // This helps us determine that after Expire() a refresh (via Get) is triggered.
38+ type trackingProvider struct {
39+ provider credentials.Provider
40+ count int
41+ }
42+
43+ func (tp * trackingProvider ) Retrieve () (credentials.Value , error ) {
44+ tp .count ++
45+ return tp .provider .Retrieve ()
46+ }
47+
48+ func (tp * trackingProvider ) IsExpired () bool {
49+ return tp .provider .IsExpired ()
50+ }
51+
3652func TestSigV4_Inferred_Region (t * testing.T ) {
3753 os .Setenv ("AWS_ACCESS_KEY_ID" , "secret" )
3854 os .Setenv ("AWS_SECRET_ACCESS_KEY" , "token" )
@@ -114,4 +130,73 @@ func TestSigV4RoundTripper(t *testing.T) {
114130 _ , err = cli .Do (req )
115131 require .NoError (t , err )
116132 })
133+
134+ t .Run ("Expired Token" , func (t * testing.T ) {
135+ t .Run ("Successful retry" , func (t * testing.T ) {
136+ // Setup tracking credentials to verify that after Expire()
137+ // a refresh is triggered by observing an extra call to Retrieve().
138+ // Note: The first signing call will trigger a Retrieve() so we expect count >= 2.
139+ // For example, if count == 1 then no refresh happened after Expire().
140+ tp := & trackingProvider {provider : & credentials.StaticProvider {
141+ Value : credentials.Value {
142+ AccessKeyID : "test-id" ,
143+ SecretAccessKey : "secret" ,
144+ SessionToken : "token" ,
145+ },
146+ }}
147+ trackingCreds := credentials .NewCredentials (tp )
148+ rt .signer .Credentials = trackingCreds
149+
150+ callCount := 0
151+ originalNext := rt .next
152+ defer func () { rt .next = originalNext }()
153+
154+ rt .next = RoundTripperFunc (func (req * http.Request ) (* http.Response , error ) {
155+ callCount ++
156+ if callCount == 1 {
157+ return & http.Response {
158+ StatusCode : http .StatusForbidden ,
159+ Header : http.Header {
160+ "X-Amzn-Errortype" : []string {"ExpiredTokenException" },
161+ },
162+ }, nil
163+ }
164+ return & http.Response {StatusCode : http .StatusOK }, nil
165+ })
166+
167+ req , err := http .NewRequest (http .MethodPost , "https://example.com" , strings .NewReader ("Hello, world!" ))
168+ require .NoError (t , err )
169+
170+ resp , err := cli .Do (req )
171+ require .NoError (t , err )
172+
173+ require .Equal (t , http .StatusOK , resp .StatusCode )
174+ require .Greater (t , tp .count , 1 , "Expected credentials retrieval to be triggered at least twice (initial retrieval and after Expire())" )
175+ })
176+
177+ t .Run ("Different 403 error does not retry" , func (t * testing.T ) {
178+ callCount := 0
179+ originalNext := rt .next
180+ defer func () { rt .next = originalNext }()
181+
182+ rt .next = RoundTripperFunc (func (req * http.Request ) (* http.Response , error ) {
183+ callCount ++
184+ return & http.Response {
185+ StatusCode : http .StatusForbidden ,
186+ Header : http.Header {
187+ "X-Amzn-Errortype" : []string {"SomeOtherError" },
188+ },
189+ }, nil
190+ })
191+
192+ req , err := http .NewRequest (http .MethodPost , "https://example.com" , strings .NewReader ("Hello, world!" ))
193+ require .NoError (t , err )
194+
195+ resp , err := cli .Do (req )
196+ require .NoError (t , err )
197+ require .Equal (t , http .StatusForbidden , resp .StatusCode )
198+
199+ require .Equal (t , 1 , callCount , "Expected only one HTTP call" )
200+ })
201+ })
117202}
0 commit comments