Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 34 additions & 9 deletions sigv4.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,25 @@ func (rt *sigV4RoundTripper) newBuf() interface{} {
return bytes.NewBuffer(make([]byte, 0, 1024))
}

func (rt *sigV4RoundTripper) signRequest(req *http.Request, signReq *http.Request, seeker io.ReadSeeker) error {
_, _ = seeker.Seek(0, io.SeekStart)

headers, err := rt.signer.Sign(signReq, seeker, "aps", rt.region, time.Now().UTC())
if err != nil {
return fmt.Errorf("failed to sign request: %w", err)
}

// Copy over signed headers. Authorization header is not returned by
// rt.signer.Sign and needs to be copied separately.
for k, v := range headers {
req.Header[textproto.CanonicalMIMEHeaderKey(k)] = v
}
req.Header.Set("Authorization", signReq.Header.Get("Authorization"))

_, _ = seeker.Seek(0, io.SeekStart)
return nil
}

func (rt *sigV4RoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
// rt.signer.Sign needs a seekable body, so we replace the body with a
// buffered reader filled with the contents of original body.
Expand Down Expand Up @@ -136,17 +155,23 @@ func (rt *sigV4RoundTripper) RoundTrip(req *http.Request) (*http.Response, error
signReq.Header.Del(header)
}

headers, err := rt.signer.Sign(signReq, seeker, "aps", rt.region, time.Now().UTC())
if err != nil {
return nil, fmt.Errorf("failed to sign request: %w", err)
if err := rt.signRequest(req, signReq, seeker); err != nil {
return nil, err
}

// Copy over signed headers. Authorization header is not returned by
// rt.signer.Sign and needs to be copied separately.
for k, v := range headers {
req.Header[textproto.CanonicalMIMEHeaderKey(k)] = v
resp, err := rt.next.RoundTrip(req)
if err != nil {
return nil, err
}
// Credentials might expire during the request.
// To gracefully handle that we force a refresh of the credentials and retry the request.
if resp.StatusCode == http.StatusForbidden && resp.Header.Get("X-Amzn-Errortype") == "ExpiredTokenException" {
rt.signer.Credentials.Expire()
if err := rt.signRequest(req, signReq, seeker); err != nil {
return nil, err
}
return rt.next.RoundTrip(req)
}
req.Header.Set("Authorization", signReq.Header.Get("Authorization"))

return rt.next.RoundTrip(req)
return resp, err
}
85 changes: 85 additions & 0 deletions sigv4_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,22 @@ func (rt RoundTripperFunc) RoundTrip(r *http.Request) (*http.Response, error) {
return rt(r)
}

// trackingProvider wraps a credentials.Provider and tracks how many times Retrieve() is called.
// This helps us determine that after Expire() a refresh (via Get) is triggered.
type trackingProvider struct {
provider credentials.Provider
count int
}

func (tp *trackingProvider) Retrieve() (credentials.Value, error) {
tp.count++
return tp.provider.Retrieve()
}

func (tp *trackingProvider) IsExpired() bool {
return tp.provider.IsExpired()
}

func TestSigV4_Inferred_Region(t *testing.T) {
os.Setenv("AWS_ACCESS_KEY_ID", "secret")
os.Setenv("AWS_SECRET_ACCESS_KEY", "token")
Expand Down Expand Up @@ -114,4 +130,73 @@ func TestSigV4RoundTripper(t *testing.T) {
_, err = cli.Do(req)
require.NoError(t, err)
})

t.Run("Expired Token", func(t *testing.T) {
t.Run("Successful retry", func(t *testing.T) {
// Setup tracking credentials to verify that after Expire()
// a refresh is triggered by observing an extra call to Retrieve().
// Note: The first signing call will trigger a Retrieve() so we expect count >= 2.
// For example, if count == 1 then no refresh happened after Expire().
tp := &trackingProvider{provider: &credentials.StaticProvider{
Value: credentials.Value{
AccessKeyID: "test-id",
SecretAccessKey: "secret",
SessionToken: "token",
},
}}
trackingCreds := credentials.NewCredentials(tp)
rt.signer.Credentials = trackingCreds

callCount := 0
originalNext := rt.next
defer func() { rt.next = originalNext }()

rt.next = RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
callCount++
if callCount == 1 {
return &http.Response{
StatusCode: http.StatusForbidden,
Header: http.Header{
"X-Amzn-Errortype": []string{"ExpiredTokenException"},
},
}, nil
}
return &http.Response{StatusCode: http.StatusOK}, nil
})

req, err := http.NewRequest(http.MethodPost, "https://example.com", strings.NewReader("Hello, world!"))
require.NoError(t, err)

resp, err := cli.Do(req)
require.NoError(t, err)

require.Equal(t, http.StatusOK, resp.StatusCode)
require.Greater(t, tp.count, 1, "Expected credentials retrieval to be triggered at least twice (initial retrieval and after Expire())")
})

t.Run("Different 403 error does not retry", func(t *testing.T) {
callCount := 0
originalNext := rt.next
defer func() { rt.next = originalNext }()

rt.next = RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
callCount++
return &http.Response{
StatusCode: http.StatusForbidden,
Header: http.Header{
"X-Amzn-Errortype": []string{"SomeOtherError"},
},
}, nil
})

req, err := http.NewRequest(http.MethodPost, "https://example.com", strings.NewReader("Hello, world!"))
require.NoError(t, err)

resp, err := cli.Do(req)
require.NoError(t, err)
require.Equal(t, http.StatusForbidden, resp.StatusCode)

require.Equal(t, 1, callCount, "Expected only one HTTP call")
})
})
}