Skip to content

Commit 51d23ce

Browse files
committed
fix(d1): buffer request body for retries and fix response body leaks
- Buffer request body so it can be reused across retry attempts - Properly drain and close response bodies before retries and on errors - Fix incorrect JSON tag on d1ExecResponse.Duration ("uuid" → "duration") - Remove redundant validation check in rows parsing - Add deferred response body close in StreamRowsContext
1 parent e14ec3a commit 51d23ce

1 file changed

Lines changed: 46 additions & 6 deletions

File tree

core/dbio/database/database_d1.go

Lines changed: 46 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package database
22

33
import (
4+
"bytes"
45
"context"
56
"database/sql"
67
"encoding/json"
@@ -61,10 +62,24 @@ func (conn *D1Conn) makeRequest(ctx context.Context, method, route string, body
6162
URL = g.F("%s/%s/d1/database/%s/%s", urlBase, conn.AccountID, conn.UUID, route)
6263
}
6364

65+
// buffer body for potential retries (readers are single-use)
66+
var bodyBytes []byte
67+
if body != nil {
68+
bodyBytes, err = io.ReadAll(body)
69+
if err != nil {
70+
return nil, g.Error(err, "could not read request body for %s @ %s", method, URL)
71+
}
72+
}
73+
6474
retry:
6575
tries++
6676
g.Trace("request #%d for %s @ %s", tries, method, URL)
67-
req, err := http.NewRequestWithContext(ctx, method, URL, body)
77+
78+
var reqBody io.Reader
79+
if bodyBytes != nil {
80+
reqBody = bytes.NewReader(bodyBytes)
81+
}
82+
req, err := http.NewRequestWithContext(ctx, method, URL, reqBody)
6883
if err != nil {
6984
return nil, g.Error(err, "could not make request for %s @ %s", method, URL)
7085
}
@@ -79,16 +94,24 @@ retry:
7994
return
8095
}
8196

82-
// retry logic
97+
// retry logic for transient server errors / rate limits
8398
if (resp.StatusCode >= 502 || resp.StatusCode == 429) && tries <= 4 {
8499
delay := tries * 5
85100
g.Debug("d1 request failed %d: %s. Retrying in %d seconds.", resp.StatusCode, resp.Status, delay)
101+
// drain and close body before retry to avoid leaks
102+
if resp.Body != nil {
103+
io.Copy(io.Discard, resp.Body)
104+
resp.Body.Close()
105+
}
86106
time.Sleep(time.Duration(delay * int(time.Second)))
87107
goto retry
88108
}
89109

90110
if resp.StatusCode >= 400 || resp.StatusCode < 200 {
91111
respBytes, _ := io.ReadAll(resp.Body)
112+
if resp.Body != nil {
113+
resp.Body.Close()
114+
}
92115
err = g.Error("Unexpected Response %d: %s (%s) => %s", resp.StatusCode, resp.Status, URL, string(respBytes))
93116
return
94117
}
@@ -152,6 +175,9 @@ func (conn *D1Conn) GetDatabases() (data iop.Dataset, err error) {
152175
}
153176

154177
respBytes, err := io.ReadAll(resp.Body)
178+
if resp.Body != nil {
179+
resp.Body.Close()
180+
}
155181
if err != nil {
156182
return data, g.Error(err, "could not read from request body")
157183
}
@@ -179,7 +205,7 @@ type d1ExecResponse struct {
179205
Result []struct {
180206
Meta struct {
181207
ServedBy string `json:"served_by"`
182-
Duration float64 `json:"uuid"`
208+
Duration float64 `json:"duration"`
183209
Changes int64 `json:"changes"`
184210
LastRowID any `json:"last_row_id"`
185211
ChangedDB bool `json:"changed_db"`
@@ -236,6 +262,9 @@ func (conn *D1Conn) ExecContext(ctx context.Context, q string, args ...interface
236262
}
237263

238264
respBytes, err := io.ReadAll(resp.Body)
265+
if resp.Body != nil {
266+
resp.Body.Close()
267+
}
239268
if err != nil {
240269
return nil, g.Error(err, "could not read from request body")
241270
}
@@ -285,10 +314,14 @@ func (conn *D1Conn) StreamRowsContext(ctx context.Context, query string, options
285314
// g.Warn(string(respBytes))
286315
// return nil, g.Error("stopping")
287316

288-
decoder := json.NewDecoder(resp.Body)
317+
respBody := resp.Body
318+
decoder := json.NewDecoder(respBody)
289319

290320
// Read opening object
291321
if t, err := decoder.Token(); err != nil || t != json.Delim('{') {
322+
if respBody != nil {
323+
respBody.Close()
324+
}
292325
return nil, g.Error(err, "invalid JSON structure: expected opening brace")
293326
}
294327

@@ -350,8 +383,6 @@ func (conn *D1Conn) StreamRowsContext(ctx context.Context, query string, options
350383
t, err = decoder.Token()
351384
if err != nil || cast.ToString(t) != "rows" {
352385
return nil, g.Error(err, "invalid JSON structure: expected rows array inside results")
353-
} else if cast.ToString(t) != "rows" {
354-
return nil, g.Error("invalid JSON structure: expected rows array inside results")
355386
}
356387

357388
// Read opening bracket of rows array
@@ -405,17 +436,26 @@ func (conn *D1Conn) StreamRowsContext(ctx context.Context, query string, options
405436

406437
nextFunc, err := makeNextFunc()
407438
if err != nil {
439+
if respBody != nil {
440+
respBody.Close()
441+
}
408442
return ds, err
409443
}
410444

411445
ds = iop.NewDatastreamIt(queryContext.Ctx, fetchedColumns, nextFunc)
412446
ds.NoDebug = strings.Contains(query, noDebugKey)
413447
ds.SetMetadata(conn.GetProp("METADATA"))
414448
ds.SetConfig(conn.Props())
449+
if respBody != nil {
450+
ds.Defer(func() { respBody.Close() })
451+
}
415452

416453
err = ds.Start()
417454
if err != nil {
418455
queryContext.Cancel()
456+
if respBody != nil {
457+
respBody.Close()
458+
}
419459
return ds, g.Error(err, "could start datastream")
420460
}
421461

0 commit comments

Comments
 (0)