Skip to content

Commit 6e438fd

Browse files
authored
Fix erroneous context canceled errors (Azure#17797)
When per-try timeouts are enabled, the context can only be canceled after the body has been read and closed. Our tests missed this due to very small response body sizes that fit in the transport's read buffer. Made NoClosingBytesReader internally available.
1 parent 082b8e1 commit 6e438fd

File tree

7 files changed

+222
-118
lines changed

7 files changed

+222
-118
lines changed

sdk/azcore/CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
* `Poller[T].PollUntilDone()` now takes an `options *PollUntilDoneOptions` param instead of `freq time.Duration`
1313

1414
### Bugs Fixed
15+
* When per-try timeouts are enabled, only cancel the context after the body has been read and closed.
1516

1617
### Other Changes
1718
* The functionality in `arm/runtime/poller.go` has been merged into `runtime/poller.go` so it should be used instead.

sdk/azcore/internal/exported/exported.go

Lines changed: 2 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
package exported
88

99
import (
10-
"errors"
1110
"io"
1211
"io/ioutil"
1312
"net/http"
@@ -49,69 +48,14 @@ func HasStatusCode(resp *http.Response, statusCodes ...int) bool {
4948
// Exported as runtime.Payload().
5049
func Payload(resp *http.Response) ([]byte, error) {
5150
// r.Body won't be a nopClosingBytesReader if downloading was skipped
52-
if buf, ok := resp.Body.(*nopClosingBytesReader); ok {
51+
if buf, ok := resp.Body.(*shared.NopClosingBytesReader); ok {
5352
return buf.Bytes(), nil
5453
}
5554
bytesBody, err := ioutil.ReadAll(resp.Body)
5655
resp.Body.Close()
5756
if err != nil {
5857
return nil, err
5958
}
60-
resp.Body = &nopClosingBytesReader{s: bytesBody, i: 0}
59+
resp.Body = shared.NewNopClosingBytesReader(bytesBody)
6160
return bytesBody, nil
6261
}
63-
64-
// NopClosingBytesReader is an io.ReadSeekCloser around a byte slice.
65-
// It also provides direct access to the byte slice to avoid rereading.
66-
type nopClosingBytesReader struct {
67-
s []byte
68-
i int64
69-
}
70-
71-
// Bytes returns the underlying byte slice.
72-
func (r *nopClosingBytesReader) Bytes() []byte {
73-
return r.s
74-
}
75-
76-
// Close implements the io.Closer interface.
77-
func (*nopClosingBytesReader) Close() error {
78-
return nil
79-
}
80-
81-
// Read implements the io.Reader interface.
82-
func (r *nopClosingBytesReader) Read(b []byte) (n int, err error) {
83-
if r.i >= int64(len(r.s)) {
84-
return 0, io.EOF
85-
}
86-
n = copy(b, r.s[r.i:])
87-
r.i += int64(n)
88-
return
89-
}
90-
91-
// Set replaces the existing byte slice with the specified byte slice and resets the reader.
92-
func (r *nopClosingBytesReader) Set(b []byte) {
93-
r.s = b
94-
r.i = 0
95-
}
96-
97-
// Seek implements the io.Seeker interface.
98-
func (r *nopClosingBytesReader) Seek(offset int64, whence int) (int64, error) {
99-
var i int64
100-
switch whence {
101-
case io.SeekStart:
102-
i = offset
103-
case io.SeekCurrent:
104-
i = r.i + offset
105-
case io.SeekEnd:
106-
i = int64(len(r.s)) + offset
107-
default:
108-
return 0, errors.New("nopClosingBytesReader: invalid whence")
109-
}
110-
if i < 0 {
111-
return 0, errors.New("nopClosingBytesReader: negative position")
112-
}
113-
r.i = i
114-
return i, nil
115-
}
116-
117-
var _ shared.BytesSetter = (*nopClosingBytesReader)(nil)

sdk/azcore/internal/exported/exported_test.go

Lines changed: 0 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -55,62 +55,3 @@ func TestPayload(t *testing.T) {
5555
t.Fatalf("got %s, want %s", string(b), val)
5656
}
5757
}
58-
59-
func TestNopClosingBytesReader(t *testing.T) {
60-
const val1 = "the data"
61-
ncbr := &nopClosingBytesReader{s: []byte(val1)}
62-
b, err := io.ReadAll(ncbr)
63-
if err != nil {
64-
t.Fatal(err)
65-
}
66-
if string(b) != val1 {
67-
t.Fatalf("got %s, want %s", string(b), val1)
68-
}
69-
const val2 = "something else"
70-
ncbr.Set([]byte(val2))
71-
b, err = io.ReadAll(ncbr)
72-
if err != nil {
73-
t.Fatal(err)
74-
}
75-
if string(b) != val2 {
76-
t.Fatalf("got %s, want %s", string(b), val2)
77-
}
78-
if err = ncbr.Close(); err != nil {
79-
t.Fatal(err)
80-
}
81-
// seek to beginning and read again
82-
i, err := ncbr.Seek(0, io.SeekStart)
83-
if err != nil {
84-
t.Fatal(err)
85-
}
86-
if i != 0 {
87-
t.Fatalf("got %d, want %d", i, 0)
88-
}
89-
b, err = io.ReadAll(ncbr)
90-
if err != nil {
91-
t.Fatal(err)
92-
}
93-
if string(b) != val2 {
94-
t.Fatalf("got %s, want %s", string(b), val2)
95-
}
96-
// seek to middle from the end
97-
i, err = ncbr.Seek(-4, io.SeekEnd)
98-
if err != nil {
99-
t.Fatal(err)
100-
}
101-
if l := int64(len(val2)) - 4; i != l {
102-
t.Fatalf("got %d, want %d", l, i)
103-
}
104-
b, err = io.ReadAll(ncbr)
105-
if err != nil {
106-
t.Fatal(err)
107-
}
108-
if string(b) != "else" {
109-
t.Fatalf("got %s, want %s", string(b), "else")
110-
}
111-
// underflow
112-
_, err = ncbr.Seek(-int64(len(val2)+1), io.SeekCurrent)
113-
if err == nil {
114-
t.Fatal("unexpected nil error")
115-
}
116-
}

sdk/azcore/internal/shared/shared.go

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ package shared
88

99
import (
1010
"context"
11+
"errors"
12+
"io"
1113
"net/http"
1214
"reflect"
1315
"strconv"
@@ -63,3 +65,63 @@ func TypeOfT[T any]() reflect.Type {
6365
type BytesSetter interface {
6466
Set(b []byte)
6567
}
68+
69+
// NewNopClosingBytesReader creates a new *NopClosingBytesReader for the specified slice.
70+
func NewNopClosingBytesReader(data []byte) *NopClosingBytesReader {
71+
return &NopClosingBytesReader{s: data}
72+
}
73+
74+
// NopClosingBytesReader is an io.ReadSeekCloser around a byte slice.
75+
// It also provides direct access to the byte slice to avoid rereading.
76+
type NopClosingBytesReader struct {
77+
s []byte
78+
i int64
79+
}
80+
81+
// Bytes returns the underlying byte slice.
82+
func (r *NopClosingBytesReader) Bytes() []byte {
83+
return r.s
84+
}
85+
86+
// Close implements the io.Closer interface.
87+
func (*NopClosingBytesReader) Close() error {
88+
return nil
89+
}
90+
91+
// Read implements the io.Reader interface.
92+
func (r *NopClosingBytesReader) Read(b []byte) (n int, err error) {
93+
if r.i >= int64(len(r.s)) {
94+
return 0, io.EOF
95+
}
96+
n = copy(b, r.s[r.i:])
97+
r.i += int64(n)
98+
return
99+
}
100+
101+
// Set replaces the existing byte slice with the specified byte slice and resets the reader.
102+
func (r *NopClosingBytesReader) Set(b []byte) {
103+
r.s = b
104+
r.i = 0
105+
}
106+
107+
// Seek implements the io.Seeker interface.
108+
func (r *NopClosingBytesReader) Seek(offset int64, whence int) (int64, error) {
109+
var i int64
110+
switch whence {
111+
case io.SeekStart:
112+
i = offset
113+
case io.SeekCurrent:
114+
i = r.i + offset
115+
case io.SeekEnd:
116+
i = int64(len(r.s)) + offset
117+
default:
118+
return 0, errors.New("nopClosingBytesReader: invalid whence")
119+
}
120+
if i < 0 {
121+
return 0, errors.New("nopClosingBytesReader: negative position")
122+
}
123+
r.i = i
124+
return i, nil
125+
}
126+
127+
var _ BytesSetter = (*NopClosingBytesReader)(nil)

sdk/azcore/internal/shared/shared_test.go

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ package shared
88

99
import (
1010
"context"
11+
"io"
1112
"net/http"
1213
"reflect"
1314
"testing"
@@ -63,3 +64,62 @@ func TestTypeOfT(t *testing.T) {
6364
t.Fatal("didn't expect types to match")
6465
}
6566
}
67+
68+
func TestNopClosingBytesReader(t *testing.T) {
69+
const val1 = "the data"
70+
ncbr := &NopClosingBytesReader{s: []byte(val1)}
71+
b, err := io.ReadAll(ncbr)
72+
if err != nil {
73+
t.Fatal(err)
74+
}
75+
if string(b) != val1 {
76+
t.Fatalf("got %s, want %s", string(b), val1)
77+
}
78+
const val2 = "something else"
79+
ncbr.Set([]byte(val2))
80+
b, err = io.ReadAll(ncbr)
81+
if err != nil {
82+
t.Fatal(err)
83+
}
84+
if string(b) != val2 {
85+
t.Fatalf("got %s, want %s", string(b), val2)
86+
}
87+
if err = ncbr.Close(); err != nil {
88+
t.Fatal(err)
89+
}
90+
// seek to beginning and read again
91+
i, err := ncbr.Seek(0, io.SeekStart)
92+
if err != nil {
93+
t.Fatal(err)
94+
}
95+
if i != 0 {
96+
t.Fatalf("got %d, want %d", i, 0)
97+
}
98+
b, err = io.ReadAll(ncbr)
99+
if err != nil {
100+
t.Fatal(err)
101+
}
102+
if string(b) != val2 {
103+
t.Fatalf("got %s, want %s", string(b), val2)
104+
}
105+
// seek to middle from the end
106+
i, err = ncbr.Seek(-4, io.SeekEnd)
107+
if err != nil {
108+
t.Fatal(err)
109+
}
110+
if l := int64(len(val2)) - 4; i != l {
111+
t.Fatalf("got %d, want %d", l, i)
112+
}
113+
b, err = io.ReadAll(ncbr)
114+
if err != nil {
115+
t.Fatal(err)
116+
}
117+
if string(b) != "else" {
118+
t.Fatalf("got %s, want %s", string(b), "else")
119+
}
120+
// underflow
121+
_, err = ncbr.Seek(-int64(len(val2)+1), io.SeekCurrent)
122+
if err == nil {
123+
t.Fatal("unexpected nil error")
124+
}
125+
}

sdk/azcore/runtime/policy_retry.go

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,15 @@ func (p *retryPolicy) Do(req *policy.Request) (resp *http.Response, err error) {
127127
tryCtx, tryCancel := context.WithTimeout(req.Raw().Context(), options.TryTimeout)
128128
clone := req.Clone(tryCtx)
129129
resp, err = clone.Next() // Make the request
130-
tryCancel()
130+
// if the body was already downloaded or there was an error it's safe to cancel the context now
131+
if err != nil {
132+
tryCancel()
133+
} else if _, ok := resp.Body.(*shared.NopClosingBytesReader); ok {
134+
tryCancel()
135+
} else {
136+
// must cancel the context after the body has been read and closed
137+
resp.Body = &contextCancelReadCloser{cf: tryCancel, body: resp.Body}
138+
}
131139
}
132140
if err == nil {
133141
log.Writef(log.EventRetryPolicy, "response %d", resp.StatusCode)
@@ -213,3 +221,22 @@ func (b *retryableRequestBody) realClose() error {
213221
}
214222
return nil
215223
}
224+
225+
// ********** The following type/methods implement the contextCancelReadCloser
226+
227+
// contextCancelReadCloser combines an io.ReadCloser with a cancel func.
228+
// it ensures the cancel func is invoked once the body has been read and closed.
229+
type contextCancelReadCloser struct {
230+
cf context.CancelFunc
231+
body io.ReadCloser
232+
}
233+
234+
func (rc *contextCancelReadCloser) Read(p []byte) (n int, err error) {
235+
return rc.body.Read(p)
236+
}
237+
238+
func (rc *contextCancelReadCloser) Close() error {
239+
err := rc.body.Close()
240+
rc.cf()
241+
return err
242+
}

0 commit comments

Comments
 (0)