-
Notifications
You must be signed in to change notification settings - Fork 70
Fix unnecessary cancel requests #136
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1138,27 +1138,34 @@ func (qr *driverRows) Close() error { | |
if qr.err == sql.ErrNoRows || qr.err == io.EOF { | ||
return nil | ||
} | ||
|
||
qr.err = io.EOF | ||
hs := make(http.Header) | ||
if qr.stmt.user != "" { | ||
hs.Add(trinoUserHeader, qr.stmt.user) | ||
} | ||
ctx, cancel := context.WithTimeout(context.WithoutCancel(qr.ctx), DefaultCancelQueryTimeout) | ||
defer cancel() | ||
req, err := qr.stmt.conn.newRequest(ctx, "DELETE", qr.stmt.conn.baseURL+"/v1/query/"+url.PathEscape(qr.queryID), nil, hs) | ||
if err != nil { | ||
return err | ||
} | ||
resp, err := qr.stmt.conn.roundTrip(ctx, req) | ||
if err != nil { | ||
qferr, ok := err.(*ErrQueryFailed) | ||
if ok && qferr.StatusCode == http.StatusNoContent { | ||
qr.nextURI = "" | ||
return nil | ||
|
||
if qr.nextURI != "" { | ||
hs := make(http.Header) | ||
if qr.stmt.user != "" { | ||
hs.Add(trinoUserHeader, qr.stmt.user) | ||
} | ||
return err | ||
|
||
ctx, cancel := context.WithTimeout(context.WithoutCancel(qr.ctx), DefaultCancelQueryTimeout) | ||
defer cancel() | ||
req, err := qr.stmt.conn.newRequest(ctx, "DELETE", qr.nextURI, nil, hs) | ||
if err != nil { | ||
return err | ||
} | ||
resp, err := qr.stmt.conn.roundTrip(ctx, req) | ||
if err != nil { | ||
qferr, ok := err.(*ErrQueryFailed) | ||
if ok && qferr.StatusCode == http.StatusNoContent { | ||
qr.nextURI = "" | ||
return nil | ||
} | ||
return err | ||
} | ||
resp.Body.Close() | ||
|
||
} | ||
resp.Body.Close() | ||
|
||
return qr.err | ||
} | ||
|
||
|
@@ -1205,6 +1212,7 @@ func (qr *driverRows) Next(dest []driver.Value) error { | |
if qr.err != nil { | ||
return qr.err | ||
} | ||
|
||
if qr.columns == nil || qr.rowindex >= len(qr.data) { | ||
if qr.nextURI == "" { | ||
qr.err = io.EOF | ||
|
@@ -1215,22 +1223,34 @@ func (qr *driverRows) Next(dest []driver.Value) error { | |
return err | ||
} | ||
} | ||
|
||
if len(qr.coltype) == 0 { | ||
qr.err = sql.ErrNoRows | ||
return qr.err | ||
} | ||
for i, v := range qr.coltype { | ||
if i > len(dest)-1 { | ||
|
||
row := qr.data[qr.rowindex] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Move any refactor not directly related to the issue this PR fixes to a separate commit. |
||
for i, colType := range qr.coltype { | ||
if i >= len(dest) { | ||
break | ||
} | ||
vv, err := v.ConvertValue(qr.data[qr.rowindex][i]) | ||
val, err := colType.ConvertValue(row[i]) | ||
if err != nil { | ||
qr.err = err | ||
return err | ||
} | ||
dest[i] = vv | ||
dest[i] = val | ||
} | ||
|
||
qr.rowindex++ | ||
|
||
// Prefetch next set of rows | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Useless comment - it explains the what, which is pretty obvious, but not the why, which is not. Why do we need to do this? If the user doesn't request any more data, we'll waste resources on one round-trip. Can we do this only when cancelling the query? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I have to rethink this PR. Now that I'm more familiar with the project, we can't approach it this way—it's a bit of a mess. |
||
if qr.rowindex == len(qr.data) && qr.nextURI != "" { | ||
if err := qr.fetch(); err != nil { | ||
qr.err = err | ||
} | ||
} | ||
|
||
return nil | ||
} | ||
|
||
|
@@ -1330,6 +1350,7 @@ func (qr *driverRows) fetch() error { | |
return err | ||
} | ||
qr.rowindex = 0 | ||
qr.nextURI = qresp.NextURI | ||
qr.data = qresp.Data | ||
qr.rowsAffected = qresp.UpdateCount | ||
qr.scheduleProgressUpdate(qresp.ID, qresp.Stats) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1973,3 +1973,198 @@ func TestForwardAuthorizationHeader(t *testing.T) { | |
|
||
assert.NoError(t, db.Close()) | ||
} | ||
|
||
func TestPagination(t *testing.T) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you come up with a better name? I'm not sure what this test is for. |
||
var buf, buf2, buf3 *bytes.Buffer | ||
var ts *httptest.Server | ||
ts = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||
if r.URL.Path == "/v1/statement" { | ||
if buf == nil { | ||
buf = new(bytes.Buffer) | ||
|
||
json.NewEncoder(buf).Encode(&stmtResponse{ | ||
ID: "fake-query", | ||
NextURI: ts.URL + "/v1/statement/20210817_140827_00000_arvdv/1", | ||
Stats: stmtStats{ | ||
State: "QUEUED", | ||
}, | ||
}) | ||
} | ||
w.WriteHeader(http.StatusOK) | ||
w.Write(buf.Bytes()) | ||
return | ||
} | ||
|
||
if r.URL.Path == "/v1/statement/20210817_140827_00000_arvdv/1" { | ||
if buf2 == nil { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why are separate buffers needed here? Can we initialize them early, and remove this condition? |
||
buf2 = new(bytes.Buffer) | ||
json.NewEncoder(buf2).Encode(&queryResponse{ | ||
ID: "fake-query", | ||
NextURI: ts.URL + "/v1/statement/20210817_140827_00000_arvdv/2", | ||
Columns: []queryColumn{ | ||
{ | ||
Name: "_col0", | ||
Type: "integer", | ||
TypeSignature: typeSignature{ | ||
RawType: "integer", | ||
Arguments: []typeArgument{}, | ||
}, | ||
}, | ||
}, | ||
Data: []queryData{ | ||
{1}, | ||
}, | ||
Stats: stmtStats{ | ||
State: "FINISHED", | ||
}, | ||
}) | ||
} | ||
w.WriteHeader(http.StatusOK) | ||
w.Write(buf2.Bytes()) | ||
return | ||
} | ||
|
||
if r.URL.Path == "/v1/statement/20210817_140827_00000_arvdv/2" { | ||
if buf3 == nil { | ||
buf3 = new(bytes.Buffer) | ||
json.NewEncoder(buf3).Encode(&queryResponse{ | ||
ID: "fake-query", | ||
Columns: []queryColumn{ | ||
{ | ||
Name: "_col1", | ||
Type: "integer", | ||
TypeSignature: typeSignature{ | ||
RawType: "integer", | ||
Arguments: []typeArgument{}, | ||
}, | ||
}, | ||
}, | ||
Data: []queryData{ | ||
{2}, | ||
}, | ||
Stats: stmtStats{ | ||
State: "FINISHED", | ||
}, | ||
}) | ||
} | ||
w.WriteHeader(http.StatusOK) | ||
w.Write(buf3.Bytes()) | ||
return | ||
} | ||
|
||
w.WriteHeader(http.StatusInternalServerError) | ||
json.NewEncoder(w).Encode(ErrTrino{ErrorName: "Unexpected request"}) | ||
})) | ||
|
||
defer ts.Close() | ||
|
||
db, err := sql.Open("trino", ts.URL) | ||
require.NoError(t, err) | ||
defer db.Close() | ||
|
||
// Run a query | ||
rows, err := db.Query("SELECT 1") | ||
|
||
var results []int | ||
for rows.Next() { | ||
var value int | ||
err := rows.Scan(&value) | ||
require.NoError(t, err) | ||
results = append(results, value) | ||
} | ||
|
||
// Ensure no error in iteration | ||
require.NoError(t, rows.Err()) | ||
|
||
// Assert expected results | ||
assert.Equal(t, []int{1, 2}, results, "Expected query results to match") | ||
} | ||
|
||
func TestQuerySingleRowDoesNotTriggerDeleteRequest(t *testing.T) { | ||
var buf, buf2, buf3 *bytes.Buffer | ||
var ts *httptest.Server | ||
var methodUsed string | ||
|
||
ts = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||
if r.URL.Path == "/v1/statement" { | ||
if buf == nil { | ||
buf = new(bytes.Buffer) | ||
|
||
json.NewEncoder(buf).Encode(&stmtResponse{ | ||
ID: "fake-query", | ||
NextURI: ts.URL + "/v1/statement/20210817_140827_00000_arvdv/1", | ||
Stats: stmtStats{ | ||
State: "QUEUED", | ||
}, | ||
}) | ||
} | ||
w.WriteHeader(http.StatusOK) | ||
w.Write(buf.Bytes()) | ||
return | ||
} | ||
|
||
if r.URL.Path == "/v1/statement/20210817_140827_00000_arvdv/1" { | ||
if buf2 == nil { | ||
buf2 = new(bytes.Buffer) | ||
json.NewEncoder(buf2).Encode(&queryResponse{ | ||
ID: "fake-query", | ||
NextURI: ts.URL + "/v1/statement/20210817_140827_00000_arvdv/2", | ||
Columns: []queryColumn{ | ||
{ | ||
Name: "_col0", | ||
Type: "integer", | ||
TypeSignature: typeSignature{ | ||
RawType: "integer", | ||
Arguments: []typeArgument{}, | ||
}, | ||
}, | ||
}, | ||
Data: []queryData{ | ||
{1}, | ||
}, | ||
Stats: stmtStats{ | ||
State: "FINISHED", | ||
}, | ||
}) | ||
} | ||
w.WriteHeader(http.StatusOK) | ||
w.Write(buf2.Bytes()) | ||
return | ||
} | ||
|
||
if r.URL.Path == "/v1/statement/20210817_140827_00000_arvdv/2" { | ||
methodUsed = r.Method | ||
|
||
if buf3 == nil { | ||
buf3 = new(bytes.Buffer) | ||
json.NewEncoder(buf3).Encode(&queryResponse{ | ||
Stats: stmtStats{ | ||
State: "FINISHED", | ||
}, | ||
}) | ||
} | ||
w.WriteHeader(http.StatusOK) | ||
w.Write(buf3.Bytes()) | ||
return | ||
} | ||
|
||
w.WriteHeader(http.StatusInternalServerError) | ||
json.NewEncoder(w).Encode(ErrTrino{ErrorName: "Unexpected request"}) | ||
})) | ||
|
||
defer ts.Close() | ||
|
||
db, err := sql.Open("trino", ts.URL) | ||
require.NoError(t, err) | ||
defer db.Close() | ||
|
||
var v int | ||
|
||
err = db.QueryRow("SELECT 1").Scan(&v) | ||
|
||
require.NoError(t, err) | ||
|
||
assert.Equal(t, 1, v, "Expected query results to match") | ||
|
||
assert.NotEqual(t, http.MethodDelete, methodUsed, "Expected HTTP method to be GET") | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Reverse this condition and do an early return to avoid indenting the remaining code.