-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathquery.go
More file actions
93 lines (81 loc) · 2.49 KB
/
query.go
File metadata and controls
93 lines (81 loc) · 2.49 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
package gospice
import (
"context"
"fmt"
"strings"
"github.com/apache/arrow-go/v18/arrow/array"
"github.com/apache/arrow-go/v18/arrow/flight"
"github.com/cenkalti/backoff/v4"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)
// Sql executes a SQL query against Spice.ai and returns an Apache Arrow RecordReader
// over the result set.
//
// Failures reported as gRPC status errors with code Unavailable, Unknown,
// DeadlineExceeded, Aborted or Internal, or whose message contains
// "malformed header: missing HTTP content-type", are retried using the client's
// backoff policy, at most the client's configured number of retries. Any other error
// is returned immediately. Retrying also stops as soon as ctx is canceled or its
// deadline has passed; the error from the last attempt is returned in that case so
// callers can still inspect its gRPC status.
//
// The caller owns the returned reader and should call its Release method when done.
// For more information on Apache Arrow RecordReader visit
// https://pkg.go.dev/github.com/apache/arrow-go/v18/arrow/array#RecordReader
func (c *SpiceClient) Sql(ctx context.Context, sql string) (array.RecordReader, error) {
	var rdr array.RecordReader
	err := backoff.Retry(func() error {
		r, err := queryInternal(ctx, c.flightClient, c.appId, c.apiKey, sql)
		if err == nil {
			rdr = r
			return nil
		}
		// Once the caller's context is done, every further attempt fails the same
		// way (often as a "retryable" DeadlineExceeded), so stop here instead of
		// sleeping through the remaining backoff intervals.
		if ctx.Err() != nil {
			return backoff.Permanent(err)
		}
		st, ok := status.FromError(err)
		if !ok {
			// Not a gRPC status error (e.g. a local validation error): retrying
			// cannot help.
			return backoff.Permanent(err)
		}
		switch st.Code() {
		case codes.Unavailable, codes.Unknown, codes.DeadlineExceeded, codes.Aborted, codes.Internal:
			return err
		}
		// Transient transport-level failure surfaced under another status code.
		if strings.Contains(err.Error(), "malformed header: missing HTTP content-type") {
			return err
		}
		return backoff.Permanent(err)
	}, backoff.WithMaxRetries(c.backoffPolicy, uint64(c.maxRetries)))
	if err != nil {
		return nil, err
	}
	return rdr, nil
}
// Query executes a SQL query against Spice.ai by delegating to Sql with the same
// arguments. It is kept for backward compatibility with v7.
//
// Deprecated: Use Sql instead.
func (c *SpiceClient) Query(ctx context.Context, sql string) (array.RecordReader, error) {
	return c.Sql(ctx, sql)
}
// queryInternal performs one (non-retried) Flight query. It authenticates with a
// basic token when both appID and apiKey are non-empty, requests FlightInfo for the
// SQL command, and opens a DoGet stream on the first returned endpoint's ticket.
//
// Errors from the Flight client are returned unwrapped so that callers can classify
// them with status.FromError. Errors originating here (nil client, no endpoints) are
// plain errors, not gRPC status errors.
func queryInternal(ctx context.Context, client flight.Client, appID string, apiKey string, sql string) (array.RecordReader, error) {
	if client == nil {
		return nil, fmt.Errorf("flight client is not initialized")
	}

	// Only authenticate if credentials are provided.
	queryCtx := ctx
	if appID != "" && apiKey != "" {
		authCtx, err := client.AuthenticateBasicToken(ctx, appID, apiKey)
		if err != nil {
			return nil, err
		}
		queryCtx = authCtx
	}

	fd := &flight.FlightDescriptor{
		Type: flight.DescriptorCMD,
		Cmd:  []byte(sql),
	}
	info, err := client.GetFlightInfo(queryCtx, fd)
	if err != nil {
		return nil, err
	}

	// The protobuf getters are nil-safe; without this guard a response with no
	// endpoints would panic with an index out of range.
	endpoints := info.GetEndpoint()
	if len(endpoints) == 0 || endpoints[0] == nil {
		return nil, fmt.Errorf("flight info for query has no endpoint to fetch results from")
	}
	// NOTE(review): only the first endpoint is read. If the server ever splits a
	// result across several endpoints, the remaining data is not fetched — confirm
	// Spice always returns a single endpoint.
	stream, err := client.DoGet(queryCtx, endpoints[0].GetTicket())
	if err != nil {
		return nil, err
	}

	rdr, err := flight.NewRecordReader(stream)
	if err != nil {
		// Best-effort cleanup: CloseSend half-closes the client side of the stream.
		// The original reader error is returned unwrapped unless closing also fails.
		if closeErr := stream.CloseSend(); closeErr != nil {
			return nil, fmt.Errorf("error creating record reader: %w (failed to close stream: %v)", err, closeErr)
		}
		return nil, err
	}
	return rdr, nil
}