From 2f5be5ff1b510906e113212bd10905ffd5232779 Mon Sep 17 00:00:00 2001 From: "joao.folgado" Date: Tue, 1 Apr 2025 21:30:05 +0100 Subject: [PATCH] make sure when query is executed with a deadline, underlying trino query is canceled --- trino/integration_test.go | 172 +++++++++++++++++++++++--------------- trino/trino.go | 2 +- 2 files changed, 105 insertions(+), 69 deletions(-) diff --git a/trino/integration_test.go b/trino/integration_test.go index 6426e74..dc402a5 100644 --- a/trino/integration_test.go +++ b/trino/integration_test.go @@ -1008,82 +1008,118 @@ func TestIntegrationUnsupportedHeader(t *testing.T) { } } -func TestIntegrationQueryContextCancellation(t *testing.T) { - err := RegisterCustomClient("uncompressed", &http.Client{Transport: &http.Transport{DisableCompression: true}}) - if err != nil { +func TestIntegrationQueryContext(t *testing.T) { + tests := []struct { + name string + timeout time.Duration + expectedErrMsg string + }{ + { + name: "Context Cancellation", + timeout: 0, + expectedErrMsg: "canceled", + }, + { + name: "Context Deadline Exceeded", + timeout: 3 * time.Second, + expectedErrMsg: "context deadline exceeded", + }, + } + + if err := RegisterCustomClient("uncompressed", &http.Client{Transport: &http.Transport{DisableCompression: true}}); err != nil { t.Fatal(err) } - dsn := *integrationServerFlag - dsn += "?catalog=tpch&schema=sf100&source=cancel-test&custom_client=uncompressed" + + dsn := *integrationServerFlag + "?catalog=tpch&schema=sf100&source=cancel-test&custom_client=uncompressed" db := integrationOpen(t, dsn) defer db.Close() - ctx, cancel := context.WithCancel(context.Background()) - errCh := make(chan error, 3) - done := make(chan struct{}) - longQuery := "SELECT COUNT(*) FROM lineitem" - go func() { - // query will complete in ~7s unless cancelled - rows, err := db.QueryContext(ctx, longQuery) - if err != nil { - errCh <- err - return - } - rows.Next() - if err = rows.Err(); err != nil { - errCh <- err - return - } - close(done) - }() - - // poll system.runtime.queries and wait for query to start working - var queryID string - pollCtx, pollCancel := context.WithTimeout(context.Background(), 1*time.Second) - defer pollCancel() - for { - row := db.QueryRowContext(pollCtx, "SELECT query_id FROM system.runtime.queries WHERE state = 'RUNNING' AND source = 'cancel-test' AND query = ?", longQuery) - err := row.Scan(&queryID) - if err == nil { - break - } - if err != sql.ErrNoRows { - t.Fatal("failed to read query id", err) - } - if err = contextSleep(pollCtx, 100*time.Millisecond); err != nil { - t.Fatal("query did not start in 1 second") - } - } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var ctx context.Context + var cancel context.CancelFunc - cancel() + if tt.timeout == 0 { + ctx, cancel = context.WithCancel(context.Background()) + } else { + ctx, cancel = context.WithTimeout(context.Background(), tt.timeout) + } + defer cancel() - select { - case <-done: - t.Fatal("unexpected query with cancelled context succeeded") - break - case err = <-errCh: - if !strings.Contains(err.Error(), "canceled") { - t.Fatal("expected err to be canceled but got:", err) - } - } + errCh := make(chan error, 1) + done := make(chan struct{}) + longQuery := "SELECT COUNT(*) FROM lineitem" - // poll system.runtime.queries and wait for query to be cancelled - pollCtx, pollCancel = context.WithTimeout(context.Background(), 1*time.Second) - defer pollCancel() - for { - row := db.QueryRowContext(pollCtx, "SELECT state, error_code FROM system.runtime.queries WHERE query_id = ?", queryID) - var state string - var code *string - err := row.Scan(&state, &code) - if err != nil { - t.Fatal("failed to read query id", err) - } - if state == "FAILED" && code != nil && *code == "USER_CANCELED" { - break - } - if err = contextSleep(pollCtx, 100*time.Millisecond); err != nil { - t.Fatal("query was not cancelled in 1 second; state, code, err are:", state, code, err) - } + go func() { + // query will complete in ~7s unless cancelled + rows, err := db.QueryContext(ctx, longQuery) + if err != nil { + errCh <- err + return + } + defer rows.Close() + + rows.Next() + if err = rows.Err(); err != nil { + errCh <- err + return + } + close(done) + }() + + // Poll system.runtime.queries to get the query ID + var queryID string + pollCtx, pollCancel := context.WithTimeout(context.Background(), 1*time.Second) + defer pollCancel() + + for { + row := db.QueryRowContext(pollCtx, "SELECT query_id FROM system.runtime.queries WHERE state = 'RUNNING' AND source = 'cancel-test' AND query = ?", longQuery) + err := row.Scan(&queryID) + if err == nil { + break + } + if err != sql.ErrNoRows { + t.Fatal("failed to read query ID:", err) + } + if err = contextSleep(pollCtx, 100*time.Millisecond); err != nil { + t.Fatal("query did not start in 1 second") + } + } + + if tt.timeout == 0 { + cancel() + } + + // Wait for the query to be canceled or completed + select { + case <-done: + t.Fatal("unexpected query succeeded despite cancellation or deadline") + case err := <-errCh: + if !strings.Contains(err.Error(), tt.expectedErrMsg) { + t.Fatalf("expected error containing %q, but got: %v", tt.expectedErrMsg, err) + } + } + + // Poll system.runtime.queries to verify the query was canceled + pollCtx, pollCancel = context.WithTimeout(context.Background(), 2*time.Second) + defer pollCancel() + + for { + row := db.QueryRowContext(pollCtx, "SELECT state, error_code FROM system.runtime.queries WHERE query_id = ?", queryID) + var state string + var code *string + err := row.Scan(&state, &code) + if err != nil { + t.Fatal("failed to read query state:", err) + } + if state == "FAILED" && code != nil && *code == "USER_CANCELED" { + return + } + if err = contextSleep(pollCtx, 100*time.Millisecond); err != nil { + t.Fatalf("query was not canceled in 2 seconds; state: %s, code: %v, err: %v", state, code, err) + } + } + }) } } diff --git a/trino/trino.go b/trino/trino.go index e12257c..d7b0c7e 100644 --- a/trino/trino.go +++ b/trino/trino.go @@ -1342,7 +1342,7 @@ func (qr *driverRows) fetch() error { // Channel was closed, which means the statement // or rows were closed. err = io.EOF - } else if err == context.Canceled { + } else if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { qr.Close() } qr.err = err