diff --git a/client.go b/client.go index adbdd92..a8f731a 100644 --- a/client.go +++ b/client.go @@ -535,12 +535,42 @@ func PassthroughErrorHandler(resp *http.Response, err error, _ int) (*http.Respo // Do wraps calling an HTTP method with retries. func (c *Client) Do(req *Request) (*http.Response, error) { + return c.DoFunc(func() (*Request, error) { + // Make shallow copy of http Request so that we can modify its body + // without racing against the closeBody call in persistConn.writeLoop. + httpreq := *req.Request + req.Request = &httpreq + return req, nil + }) +} + +// Do wraps calling an HTTP method with retries. +// +// Usage: +// import ( +// ... +// v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4" +// ) +// ... +// +// resp, err := client.DoFunc(func()(*retryablehttp.Request, error) { +// req, err := http.NewRequestWithContext(...) +// signer := v4.NewSigner(...) +// signer.Sign(ctx, req, ...) +// return retryablehttp.FromRequest(req) +// }) +func (c *Client) DoFunc(f func() (*Request, error)) (*http.Response, error) { c.clientInit.Do(func() { if c.HTTPClient == nil { c.HTTPClient = cleanhttp.DefaultPooledClient() } }) + req, err := f() + if err != nil { + return nil, err + } + logger := c.logger() if logger != nil { @@ -655,10 +685,10 @@ func (c *Client) Do(req *Request) (*http.Response, error) { case <-time.After(wait): } - // Make shallow copy of http Request so that we can modify its body - // without racing against the closeBody call in persistConn.writeLoop. - httpreq := *req.Request - req.Request = &httpreq + req, err = f() + if err != nil { + return nil, err + } } // this is the closest we have to success criteria @@ -668,7 +698,7 @@ func (c *Client) Do(req *Request) (*http.Response, error) { defer c.HTTPClient.CloseIdleConnections() - err := doErr + err = doErr if checkErr != nil { err = checkErr }