@@ -11,22 +11,67 @@ import (
1111
1212 uuid "github.com/satori/go.uuid"
1313
14- "github.com/aws/aws-sdk-go/aws"
15- "github.com/aws/aws-sdk-go/aws/session"
16- "github.com/aws/aws-sdk-go/service/athena"
17- "github.com/aws/aws-sdk-go/service/athena/athenaiface"
14+ "github.com/aws/aws-sdk-go-v2/aws"
15+ "github.com/aws/aws-sdk-go-v2/service/athena"
16+ "github.com/aws/aws-sdk-go-v2/service/athena/types"
1817)
1918
19+ // Query type patterns
20+ var (
21+ ddlQueryPattern = regexp .MustCompile (`(?i)^(ALTER|CREATE|DESCRIBE|DROP|MSCK|SHOW)` )
22+ selectQueryPattern = regexp .MustCompile (`(?i)^SELECT` )
23+ ctasQueryPattern = regexp .MustCompile (`(?i)^CREATE.+AS\s+SELECT` )
24+ )
25+
26+ // queryType represents the type of SQL query
27+ type queryType int
28+
29+ const (
30+ queryTypeUnknown queryType = iota
31+ queryTypeDDL
32+ queryTypeSelect
33+ queryTypeCTAS
34+ )
35+
36+ // getQueryType determines the type of the query
37+ func getQueryType (query string ) queryType {
38+ switch {
39+ case ddlQueryPattern .MatchString (query ):
40+ return queryTypeDDL
41+ case ctasQueryPattern .MatchString (query ):
42+ return queryTypeCTAS
43+ case selectQueryPattern .MatchString (query ):
44+ return queryTypeSelect
45+ default :
46+ return queryTypeUnknown
47+ }
48+ }
49+
50+ // isDDLQuery determines if the query is a DDL statement
51+ func isDDLQuery (query string ) bool {
52+ return getQueryType (query ) == queryTypeDDL
53+ }
54+
55+ // isSelectQuery determines if the query is a SELECT statement
56+ func isSelectQuery (query string ) bool {
57+ return getQueryType (query ) == queryTypeSelect
58+ }
59+
60+ // isCTASQuery determines if the query is a CREATE TABLE AS SELECT statement
61+ func isCTASQuery (query string ) bool {
62+ return getQueryType (query ) == queryTypeCTAS
63+ }
64+
2065type conn struct {
21- athena athenaiface. AthenaAPI
66+ athena * athena. Client
2267 db string
2368 OutputLocation string
2469 workgroup string
2570
2671 pollFrequency time.Duration
2772
2873 resultMode ResultMode
29- session * session. Session
74+ config aws. Config
3075 timeout uint
3176 catalog string
3277}
@@ -54,6 +99,9 @@ func (c *conn) runQuery(ctx context.Context, query string) (driver.Rows, error)
5499 isSelect := isSelectQuery (query )
55100 resultMode := c .resultMode
56101 if rmode , ok := getResultMode (ctx ); ok {
102+ if ! isValidResultMode (rmode ) {
103+ return nil , ErrInvalidResultMode
104+ }
57105 resultMode = rmode
58106 }
59107 if ! isSelect {
@@ -91,7 +139,7 @@ func (c *conn) runQuery(ctx context.Context, query string) (driver.Rows, error)
91139 afterDownload = c .dropCTASTable (ctx , ctasTable )
92140 }
93141
94- queryID , err := c .startQuery (query )
142+ queryID , err := c .startQuery (ctx , query )
95143 if err != nil {
96144 return nil , err
97145 }
@@ -105,7 +153,7 @@ func (c *conn) runQuery(ctx context.Context, query string) (driver.Rows, error)
105153 QueryID : queryID ,
106154 SkipHeader : ! isDDLQuery (query ),
107155 ResultMode : resultMode ,
108- Session : c . session ,
156+ Config : c . config ,
109157 OutputLocation : c .OutputLocation ,
110158 Timeout : timeout ,
111159 AfterDownload : afterDownload ,
@@ -119,7 +167,7 @@ func (c *conn) dropCTASTable(ctx context.Context, table string) func() error {
119167 return func () error {
120168 query := fmt .Sprintf ("DROP TABLE %s" , table )
121169
122- queryID , err := c .startQuery (query )
170+ queryID , err := c .startQuery (ctx , query )
123171 if err != nil {
124172 return err
125173 }
@@ -129,13 +177,13 @@ func (c *conn) dropCTASTable(ctx context.Context, table string) func() error {
129177}
130178
131179// startQuery starts an Athena query and returns its ID.
132- func (c * conn ) startQuery (query string ) (string , error ) {
133- resp , err := c .athena .StartQueryExecution (& athena.StartQueryExecutionInput {
180+ func (c * conn ) startQuery (ctx context. Context , query string ) (string , error ) {
181+ resp , err := c .athena .StartQueryExecution (ctx , & athena.StartQueryExecutionInput {
134182 QueryString : aws .String (query ),
135- QueryExecutionContext : & athena .QueryExecutionContext {
183+ QueryExecutionContext : & types .QueryExecutionContext {
136184 Database : aws .String (c .db ),
137185 },
138- ResultConfiguration : & athena .ResultConfiguration {
186+ ResultConfiguration : & types .ResultConfiguration {
139187 OutputLocation : aws .String (c .OutputLocation ),
140188 },
141189 WorkGroup : aws .String (c .workgroup ),
@@ -150,28 +198,28 @@ func (c *conn) startQuery(query string) (string, error) {
150198// waitOnQuery blocks until a query finishes, returning an error if it failed.
151199func (c * conn ) waitOnQuery (ctx context.Context , queryID string ) error {
152200 for {
153- statusResp , err := c .athena .GetQueryExecutionWithContext (ctx , & athena.GetQueryExecutionInput {
201+ statusResp , err := c .athena .GetQueryExecution (ctx , & athena.GetQueryExecutionInput {
154202 QueryExecutionId : aws .String (queryID ),
155203 })
156204 if err != nil {
157205 return err
158206 }
159207
160- switch * statusResp .QueryExecution .Status .State {
161- case athena .QueryExecutionStateCancelled :
208+ switch statusResp .QueryExecution .Status .State {
209+ case types .QueryExecutionStateCancelled :
162210 return context .Canceled
163- case athena .QueryExecutionStateFailed :
211+ case types .QueryExecutionStateFailed :
164212 reason := * statusResp .QueryExecution .Status .StateChangeReason
165213 return errors .New (reason )
166- case athena .QueryExecutionStateSucceeded :
214+ case types .QueryExecutionStateSucceeded :
167215 return nil
168- case athena .QueryExecutionStateQueued :
169- case athena .QueryExecutionStateRunning :
216+ case types .QueryExecutionStateQueued :
217+ case types .QueryExecutionStateRunning :
170218 }
171219
172220 select {
173221 case <- ctx .Done ():
174- c .athena .StopQueryExecution (& athena.StopQueryExecutionInput {
222+ c .athena .StopQueryExecution (ctx , & athena.StopQueryExecutionInput {
175223 QueryExecutionId : aws .String (queryID ),
176224 })
177225
@@ -229,7 +277,7 @@ func (c *conn) prepareContext(ctx context.Context, query string) (driver.Stmt, e
229277 prepareKey := fmt .Sprintf ("tmp_prepare_%v" , strings .Replace (uuid .NewV4 ().String (), "-" , "" , - 1 ))
230278 newQuery := fmt .Sprintf ("PREPARE %s FROM %s" , prepareKey , query )
231279
232- queryID , err := c .startQuery (newQuery )
280+ queryID , err := c .startQuery (ctx , newQuery )
233281 if err != nil {
234282 return nil , err
235283 }
@@ -273,22 +321,16 @@ func (c *conn) Exec(query string, args []driver.Value) (driver.Result, error) {
273321var _ driver.Queryer = (* conn )(nil )
274322var _ driver.Execer = (* conn )(nil )
275323
276- // supported DDL statements by Athena
277- // https://docs.aws.amazon.com/athena/latest/ug/language-reference.html
278- var ddlQueryRegex = regexp .MustCompile (`(?i)^(ALTER|CREATE|DESCRIBE|DROP|MSCK|SHOW)` )
279-
280- func isDDLQuery (query string ) bool {
281- return ddlQueryRegex .Match ([]byte (query ))
282- }
283-
284- func isSelectQuery (query string ) bool {
285- return regexp .MustCompile (`(?i)^SELECT` ).Match ([]byte (query ))
286- }
287-
288- func isCTASQuery (query string ) bool {
289- return regexp .MustCompile (`(?i)^CREATE.+AS\s+SELECT` ).Match ([]byte (query ))
290- }
291-
292324func isCreatingCTASTable (isSelect bool , resultMode ResultMode ) bool {
293325 return isSelect && resultMode == ResultModeGzipDL
294326}
327+
328+ // isValidResultMode checks if the given result mode is valid
329+ func isValidResultMode (mode ResultMode ) bool {
330+ switch mode {
331+ case ResultModeAPI , ResultModeDL , ResultModeGzipDL :
332+ return true
333+ default :
334+ return false
335+ }
336+ }
0 commit comments