Skip to content

Commit 02f3c0f

Browse files
authored
fix: QuoteRef reserved keywords on table update sql tasks (#522)
* QuoteRef reserved keywords on table update sql tasks * fix structure
1 parent cd7338a commit 02f3c0f

3 files changed

Lines changed: 102 additions & 7 deletions

File tree

ext/store/maxcompute/sanitizer.go

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
package maxcompute
2+
3+
import "strings"
4+
5+
var reservedKeywords = []string{
6+
"add", "after", "all", "alter", "analyze", "and", "archive", "array", "as", "asc",
7+
"before", "between", "bigint", "binary", "blob", "boolean", "both", "decimal",
8+
"bucket", "buckets", "by", "cascade", "case", "cast", "cfile", "change", "cluster",
9+
"clustered", "clusterstatus", "collection", "column", "columns", "comment", "compute",
10+
"concatenate", "continue", "create", "cross", "current", "cursor", "data", "database",
11+
"databases", "date", "datetime", "dbproperties", "deferred", "delete", "delimited",
12+
"desc", "describe", "directory", "disable", "distinct", "distribute", "double", "drop",
13+
"else", "enable", "end", "except", "escaped", "exclusive", "exists", "explain", "export",
14+
"extended", "external", "false", "fetch", "fields", "fileformat", "first", "float",
15+
"following", "format", "formatted", "from", "full", "function", "functions", "grant",
16+
"group", "having", "hold_ddltime", "idxproperties", "if", "import", "in", "index",
17+
"indexes", "inpath", "inputdriver", "inputformat", "insert", "int", "intersect", "into",
18+
"is", "items", "join", "keys", "lateral", "left", "lifecycle", "like", "limit", "lines",
19+
"load", "local", "location", "lock", "locks", "long", "map", "mapjoin", "materialized",
20+
"minus", "msck", "not", "no_drop", "null", "of", "offline", "offset", "on", "option",
21+
"or", "order", "out", "outer", "outputdriver", "outputformat", "over", "overwrite",
22+
"partition", "partitioned", "partitionproperties", "partitions", "percent", "plus",
23+
"preceding", "preserve", "procedure", "purge", "range", "rcfile", "read", "readonly",
24+
"reads", "rebuild", "recordreader", "recordwriter", "reduce", "regexp", "rename",
25+
"repair", "replace", "restrict", "revoke", "right", "rlike", "row", "rows", "schema",
26+
"schemas", "select", "semi", "sequencefile", "serde", "serdeproperties", "set", "shared",
27+
"show", "show_database", "smallint", "sort", "sorted", "ssl", "statistics", "status",
28+
"stored", "streamtable", "string", "struct", "table", "tables", "tablesample",
29+
"tblproperties", "temporary", "terminated", "textfile", "then", "timestamp", "tinyint",
30+
"to", "touch", "transform", "trigger", "true", "type", "unarchive", "unbounded", "undo",
31+
"union", "uniontype", "uniquejoin", "unlock", "unsigned", "update", "use", "using",
32+
"utc", "utc_timestamp", "view", "when", "where", "while", "div",
33+
}
34+
35+
var reservedKeywordsMap map[string]struct{}
36+
37+
//nolint:gochecknoinits
38+
func init() {
39+
reservedKeywordsMap = make(map[string]struct{}, len(reservedKeywords))
40+
for _, kw := range reservedKeywords {
41+
reservedKeywordsMap[kw] = struct{}{}
42+
}
43+
}
44+
45+
func QuoteIdentifier(identifier string) string {
46+
if _, ok := reservedKeywordsMap[strings.ToLower(identifier)]; ok {
47+
return "`" + identifier + "`"
48+
}
49+
return identifier
50+
}
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
package maxcompute_test
2+
3+
import (
4+
"testing"
5+
6+
"github.com/stretchr/testify/assert"
7+
8+
"github.com/goto/optimus/ext/store/maxcompute"
9+
)
10+
11+
func TestSanitizer(t *testing.T) {
12+
t.Run("returns quoted identifier for reserved keywords", func(t *testing.T) {
13+
testCases := []struct {
14+
input string
15+
expected string
16+
}{
17+
{"select", "`select`"},
18+
{"from", "`from`"},
19+
{"case", "`case`"},
20+
{"table", "`table`"},
21+
{"SELECT", "`SELECT`"},
22+
{"From", "`From`"},
23+
}
24+
25+
for _, tc := range testCases {
26+
result := maxcompute.QuoteIdentifier(tc.input)
27+
assert.Equal(t, tc.expected, result)
28+
}
29+
})
30+
31+
t.Run("returns identifier unchanged when not a reserved keyword", func(t *testing.T) {
32+
testCases := []struct {
33+
input string
34+
expected string
35+
}{
36+
{"customer_name", "customer_name"},
37+
{"other", "other"},
38+
}
39+
40+
for _, tc := range testCases {
41+
result := maxcompute.QuoteIdentifier(tc.input)
42+
assert.Equal(t, tc.expected, result)
43+
}
44+
})
45+
}

ext/store/maxcompute/table.go

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -160,14 +160,14 @@ func populateColumns(t *Table, schemaBuilder *tableschema.SchemaBuilder) error {
160160
func generateUpdateQuery(incoming, existing tableschema.TableSchema, schemaName string) ([]string, error) {
161161
var sqlTasks []string
162162
if incoming.Comment != existing.Comment {
163-
sqlTasks = append(sqlTasks, fmt.Sprintf("alter table %s.%s set comment %s;", schemaName, existing.TableName, common.QuoteString(incoming.Comment)))
163+
sqlTasks = append(sqlTasks, fmt.Sprintf("alter table %s.%s set comment %s;", QuoteIdentifier(schemaName), QuoteIdentifier(existing.TableName), common.QuoteString(incoming.Comment)))
164164
}
165165

166166
if incoming.Lifecycle != existing.Lifecycle {
167167
if incoming.Lifecycle <= 0 && existing.Lifecycle >= 0 {
168-
sqlTasks = append(sqlTasks, fmt.Sprintf("alter table %s.%s disable lifecycle;", schemaName, existing.TableName))
168+
sqlTasks = append(sqlTasks, fmt.Sprintf("alter table %s.%s disable lifecycle;", QuoteIdentifier(schemaName), QuoteIdentifier(existing.TableName)))
169169
} else if incoming.Lifecycle > 0 {
170-
sqlTasks = append(sqlTasks, fmt.Sprintf("alter table %s.%s set lifecycle %d;", schemaName, existing.TableName, incoming.Lifecycle))
170+
sqlTasks = append(sqlTasks, fmt.Sprintf("alter table %s.%s set lifecycle %d;", QuoteIdentifier(schemaName), QuoteIdentifier(existing.TableName), incoming.Lifecycle))
171171
}
172172
}
173173

@@ -259,7 +259,7 @@ func getNormalColumnDifferences(tableName, schemaName string, incoming []ColumnR
259259
if incomingColumnRecord.columnValue.NotNull {
260260
return fmt.Errorf("unable to add new required column")
261261
}
262-
segment := fmt.Sprintf("if not exists %s %s", incomingColumnRecord.columnStructure, incomingColumnRecord.columnValue.Type.Name())
262+
segment := fmt.Sprintf("if not exists %s %s", QuoteIdentifier(incomingColumnRecord.columnStructure), incomingColumnRecord.columnValue.Type.Name())
263263
if incomingColumnRecord.columnValue.Comment != "" {
264264
segment += fmt.Sprintf(" comment %s", common.QuoteString(incomingColumnRecord.columnValue.Comment))
265265
}
@@ -270,7 +270,7 @@ func getNormalColumnDifferences(tableName, schemaName string, incoming []ColumnR
270270
if !columnFound.NotNull && incomingColumnRecord.columnValue.NotNull {
271271
return fmt.Errorf("unable to modify column mode from nullable to required")
272272
} else if columnFound.NotNull && !incomingColumnRecord.columnValue.NotNull {
273-
*sqlTasks = append(*sqlTasks, fmt.Sprintf("alter table %s.%s change column %s null;", schemaName, tableName, columnFound.Name))
273+
*sqlTasks = append(*sqlTasks, fmt.Sprintf("alter table %s.%s change column %s null;", QuoteIdentifier(schemaName), QuoteIdentifier(tableName), QuoteIdentifier(columnFound.Name)))
274274
}
275275

276276
if columnFound.Type.ID() != incomingColumnRecord.columnValue.Type.ID() {
@@ -279,7 +279,7 @@ func getNormalColumnDifferences(tableName, schemaName string, incoming []ColumnR
279279

280280
if incomingColumnRecord.columnValue.Comment != columnFound.Comment {
281281
*sqlTasks = append(*sqlTasks, fmt.Sprintf("alter table %s.%s change column %s %s %s comment %s;",
282-
schemaName, tableName, columnFound.Name, incomingColumnRecord.columnValue.Name, columnFound.Type, common.QuoteString(incomingColumnRecord.columnValue.Comment)))
282+
QuoteIdentifier(schemaName), QuoteIdentifier(tableName), QuoteIdentifier(columnFound.Name), QuoteIdentifier(incomingColumnRecord.columnValue.Name), columnFound.Type, common.QuoteString(incomingColumnRecord.columnValue.Comment)))
283283
}
284284
delete(existing, incomingColumnRecord.columnStructure)
285285
}
@@ -292,7 +292,7 @@ func getNormalColumnDifferences(tableName, schemaName string, incoming []ColumnR
292292

293293
if len(columnAddition) > 0 {
294294
for _, segment := range columnAddition {
295-
addColumnQuery := fmt.Sprintf("alter table %s.%s add column ", schemaName, tableName) + segment + ";"
295+
addColumnQuery := fmt.Sprintf("alter table %s.%s add column ", QuoteIdentifier(schemaName), QuoteIdentifier(tableName)) + segment + ";"
296296
*sqlTasks = append(*sqlTasks, addColumnQuery)
297297
}
298298
}

0 commit comments

Comments
 (0)