Skip to content

Commit 39cf672

Browse files
committed
Add support for handling expired AWS credentials in SigV4 RoundTripper
- Implement automatic credential refresh for ExpiredTokenException - Refactor signRequest method to improve code reusability - Add test cases for expired token and different 403 error scenarios
1 parent d4b9a8d commit 39cf672

File tree

2 files changed

+119
-9
lines changed

2 files changed

+119
-9
lines changed

sigv4.go

Lines changed: 34 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,25 @@ func (rt *sigV4RoundTripper) newBuf() interface{} {
102102
return bytes.NewBuffer(make([]byte, 0, 1024))
103103
}
104104

105+
func (rt *sigV4RoundTripper) signRequest(req *http.Request, signReq *http.Request, seeker io.ReadSeeker) error {
106+
_, _ = seeker.Seek(0, io.SeekStart)
107+
108+
headers, err := rt.signer.Sign(signReq, seeker, "aps", rt.region, time.Now().UTC())
109+
if err != nil {
110+
return fmt.Errorf("failed to sign request: %w", err)
111+
}
112+
113+
// Copy over signed headers. Authorization header is not returned by
114+
// rt.signer.Sign and needs to be copied separately.
115+
for k, v := range headers {
116+
req.Header[textproto.CanonicalMIMEHeaderKey(k)] = v
117+
}
118+
req.Header.Set("Authorization", signReq.Header.Get("Authorization"))
119+
120+
_, _ = seeker.Seek(0, io.SeekStart)
121+
return nil
122+
}
123+
105124
func (rt *sigV4RoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
106125
// rt.signer.Sign needs a seekable body, so we replace the body with a
107126
// buffered reader filled with the contents of original body.
@@ -136,17 +155,23 @@ func (rt *sigV4RoundTripper) RoundTrip(req *http.Request) (*http.Response, error
136155
signReq.Header.Del(header)
137156
}
138157

139-
headers, err := rt.signer.Sign(signReq, seeker, "aps", rt.region, time.Now().UTC())
140-
if err != nil {
141-
return nil, fmt.Errorf("failed to sign request: %w", err)
158+
if err := rt.signRequest(req, signReq, seeker); err != nil {
159+
return nil, err
142160
}
143161

144-
// Copy over signed headers. Authorization header is not returned by
145-
// rt.signer.Sign and needs to be copied separately.
146-
for k, v := range headers {
147-
req.Header[textproto.CanonicalMIMEHeaderKey(k)] = v
162+
resp, err := rt.next.RoundTrip(req)
163+
if err != nil {
164+
return nil, err
165+
}
166+
// Credentials might expire during the request.
167+
// To gracefully handle that we force a refresh of the credentials and retry the request.
168+
if resp.StatusCode == http.StatusForbidden && resp.Header.Get("X-Amzn-Errortype") == "ExpiredTokenException" {
169+
rt.signer.Credentials.Expire()
170+
if err := rt.signRequest(req, signReq, seeker); err != nil {
171+
return nil, err
172+
}
173+
return rt.next.RoundTrip(req)
148174
}
149-
req.Header.Set("Authorization", signReq.Header.Get("Authorization"))
150175

151-
return rt.next.RoundTrip(req)
176+
return resp, err
152177
}

sigv4_test.go

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,22 @@ func (rt RoundTripperFunc) RoundTrip(r *http.Request) (*http.Response, error) {
3333
return rt(r)
3434
}
3535

36+
// trackingProvider wraps a credentials.Provider and tracks how many times Retrieve() is called.
37+
// This helps us determine that after Expire() a refresh (via Get) is triggered.
38+
type trackingProvider struct {
39+
provider credentials.Provider
40+
count int
41+
}
42+
43+
func (tp *trackingProvider) Retrieve() (credentials.Value, error) {
44+
tp.count++
45+
return tp.provider.Retrieve()
46+
}
47+
48+
func (tp *trackingProvider) IsExpired() bool {
49+
return tp.provider.IsExpired()
50+
}
51+
3652
func TestSigV4_Inferred_Region(t *testing.T) {
3753
os.Setenv("AWS_ACCESS_KEY_ID", "secret")
3854
os.Setenv("AWS_SECRET_ACCESS_KEY", "token")
@@ -114,4 +130,73 @@ func TestSigV4RoundTripper(t *testing.T) {
114130
_, err = cli.Do(req)
115131
require.NoError(t, err)
116132
})
133+
134+
t.Run("Expired Token", func(t *testing.T) {
135+
t.Run("Successful retry", func(t *testing.T) {
136+
// Setup tracking credentials to verify that after Expire()
137+
// a refresh is triggered by observing an extra call to Retrieve().
138+
// Note: The first signing call will trigger a Retrieve() so we expect count >= 2.
139+
// For example, if count == 1 then no refresh happened after Expire().
140+
tp := &trackingProvider{provider: &credentials.StaticProvider{
141+
Value: credentials.Value{
142+
AccessKeyID: "test-id",
143+
SecretAccessKey: "secret",
144+
SessionToken: "token",
145+
},
146+
}}
147+
trackingCreds := credentials.NewCredentials(tp)
148+
rt.signer.Credentials = trackingCreds
149+
150+
callCount := 0
151+
originalNext := rt.next
152+
defer func() { rt.next = originalNext }()
153+
154+
rt.next = RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
155+
callCount++
156+
if callCount == 1 {
157+
return &http.Response{
158+
StatusCode: http.StatusForbidden,
159+
Header: http.Header{
160+
"X-Amzn-Errortype": []string{"ExpiredTokenException"},
161+
},
162+
}, nil
163+
}
164+
return &http.Response{StatusCode: http.StatusOK}, nil
165+
})
166+
167+
req, err := http.NewRequest(http.MethodPost, "https://example.com", strings.NewReader("Hello, world!"))
168+
require.NoError(t, err)
169+
170+
resp, err := cli.Do(req)
171+
require.NoError(t, err)
172+
173+
require.Equal(t, http.StatusOK, resp.StatusCode)
174+
require.Greater(t, tp.count, 1, "Expected credentials retrieval to be triggered at least twice (initial retrieval and after Expire())")
175+
})
176+
177+
t.Run("Different 403 error does not retry", func(t *testing.T) {
178+
callCount := 0
179+
originalNext := rt.next
180+
defer func() { rt.next = originalNext }()
181+
182+
rt.next = RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
183+
callCount++
184+
return &http.Response{
185+
StatusCode: http.StatusForbidden,
186+
Header: http.Header{
187+
"X-Amzn-Errortype": []string{"SomeOtherError"},
188+
},
189+
}, nil
190+
})
191+
192+
req, err := http.NewRequest(http.MethodPost, "https://example.com", strings.NewReader("Hello, world!"))
193+
require.NoError(t, err)
194+
195+
resp, err := cli.Do(req)
196+
require.NoError(t, err)
197+
require.Equal(t, http.StatusForbidden, resp.StatusCode)
198+
199+
require.Equal(t, 1, callCount, "Expected only one HTTP call")
200+
})
201+
})
117202
}

0 commit comments

Comments
 (0)