Skip to content

Ensure Query Cancellation on Context Deadline Expiry #140

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
172 changes: 104 additions & 68 deletions trino/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1008,82 +1008,118 @@ func TestIntegrationUnsupportedHeader(t *testing.T) {
}
}

func TestIntegrationQueryContextCancellation(t *testing.T) {
err := RegisterCustomClient("uncompressed", &http.Client{Transport: &http.Transport{DisableCompression: true}})
if err != nil {
func TestIntegrationQueryContext(t *testing.T) {
tests := []struct {
name string
timeout time.Duration
expectedErrMsg string
}{
{
name: "Context Cancellation",
timeout: 0,
expectedErrMsg: "canceled",
},
{
name: "Context Deadline Exceeded",
timeout: 3 * time.Second,
expectedErrMsg: "context deadline exceeded",
},
}

if err := RegisterCustomClient("uncompressed", &http.Client{Transport: &http.Transport{DisableCompression: true}}); err != nil {
t.Fatal(err)
}
dsn := *integrationServerFlag
dsn += "?catalog=tpch&schema=sf100&source=cancel-test&custom_client=uncompressed"

dsn := *integrationServerFlag + "?catalog=tpch&schema=sf100&source=cancel-test&custom_client=uncompressed"
db := integrationOpen(t, dsn)
defer db.Close()

ctx, cancel := context.WithCancel(context.Background())
errCh := make(chan error, 3)
done := make(chan struct{})
longQuery := "SELECT COUNT(*) FROM lineitem"
go func() {
// query will complete in ~7s unless cancelled
rows, err := db.QueryContext(ctx, longQuery)
if err != nil {
errCh <- err
return
}
rows.Next()
if err = rows.Err(); err != nil {
errCh <- err
return
}
close(done)
}()

// poll system.runtime.queries and wait for query to start working
var queryID string
pollCtx, pollCancel := context.WithTimeout(context.Background(), 1*time.Second)
defer pollCancel()
for {
row := db.QueryRowContext(pollCtx, "SELECT query_id FROM system.runtime.queries WHERE state = 'RUNNING' AND source = 'cancel-test' AND query = ?", longQuery)
err := row.Scan(&queryID)
if err == nil {
break
}
if err != sql.ErrNoRows {
t.Fatal("failed to read query id", err)
}
if err = contextSleep(pollCtx, 100*time.Millisecond); err != nil {
t.Fatal("query did not start in 1 second")
}
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var ctx context.Context
var cancel context.CancelFunc

cancel()
if tt.timeout == 0 {
ctx, cancel = context.WithCancel(context.Background())
} else {
ctx, cancel = context.WithTimeout(context.Background(), tt.timeout)
}
defer cancel()

select {
case <-done:
t.Fatal("unexpected query with cancelled context succeeded")
break
case err = <-errCh:
if !strings.Contains(err.Error(), "canceled") {
t.Fatal("expected err to be canceled but got:", err)
}
}
errCh := make(chan error, 1)
done := make(chan struct{})
longQuery := "SELECT COUNT(*) FROM lineitem"

// poll system.runtime.queries and wait for query to be cancelled
pollCtx, pollCancel = context.WithTimeout(context.Background(), 1*time.Second)
defer pollCancel()
for {
row := db.QueryRowContext(pollCtx, "SELECT state, error_code FROM system.runtime.queries WHERE query_id = ?", queryID)
var state string
var code *string
err := row.Scan(&state, &code)
if err != nil {
t.Fatal("failed to read query id", err)
}
if state == "FAILED" && code != nil && *code == "USER_CANCELED" {
break
}
if err = contextSleep(pollCtx, 100*time.Millisecond); err != nil {
t.Fatal("query was not cancelled in 1 second; state, code, err are:", state, code, err)
}
go func() {
// query will complete in ~7s unless cancelled
rows, err := db.QueryContext(ctx, longQuery)
if err != nil {
errCh <- err
return
}
defer rows.Close()

rows.Next()
if err = rows.Err(); err != nil {
errCh <- err
return
}
close(done)
}()

// Poll system.runtime.queries to get the query ID
var queryID string
pollCtx, pollCancel := context.WithTimeout(context.Background(), 1*time.Second)
defer pollCancel()

for {
row := db.QueryRowContext(pollCtx, "SELECT query_id FROM system.runtime.queries WHERE state = 'RUNNING' AND source = 'cancel-test' AND query = ?", longQuery)
err := row.Scan(&queryID)
if err == nil {
break
}
if err != sql.ErrNoRows {
t.Fatal("failed to read query ID:", err)
}
if err = contextSleep(pollCtx, 100*time.Millisecond); err != nil {
t.Fatal("query did not start in 1 second")
}
}

if tt.timeout == 0 {
cancel()
}

// Wait for the query to be canceled or completed
select {
case <-done:
t.Fatal("unexpected query succeeded despite cancellation or deadline")
case err := <-errCh:
if !strings.Contains(err.Error(), tt.expectedErrMsg) {
t.Fatalf("expected error containing %q, but got: %v", tt.expectedErrMsg, err)
}
}

// Poll system.runtime.queries to verify the query was canceled
pollCtx, pollCancel = context.WithTimeout(context.Background(), 2*time.Second)
defer pollCancel()

for {
row := db.QueryRowContext(pollCtx, "SELECT state, error_code FROM system.runtime.queries WHERE query_id = ?", queryID)
var state string
var code *string
err := row.Scan(&state, &code)
if err != nil {
t.Fatal("failed to read query state:", err)
}
if state == "FAILED" && code != nil && *code == "USER_CANCELED" {
return
}
if err = contextSleep(pollCtx, 100*time.Millisecond); err != nil {
t.Fatalf("query was not canceled in 2 seconds; state: %s, code: %v, err: %v", state, code, err)
}
}
})
}
}

Expand Down
2 changes: 1 addition & 1 deletion trino/trino.go
Original file line number Diff line number Diff line change
Expand Up @@ -1342,7 +1342,7 @@ func (qr *driverRows) fetch() error {
// Channel was closed, which means the statement
// or rows were closed.
err = io.EOF
} else if err == context.Canceled {
} else if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
qr.Close()
}
qr.err = err
Expand Down
Loading