Skip to content

Commit f7eb311

Browse files
committed
Merge branch 'master' into fargate
2 parents baeba6c + 410278f commit f7eb311

File tree

9 files changed

+1773
-766
lines changed

9 files changed

+1773
-766
lines changed

embedded/sql/engine_test.go

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3202,6 +3202,69 @@ func TestQuery(t *testing.T) {
32023202
})
32033203
}
32043204

3205+
func TestExtractFromTimestamp(t *testing.T) {
3206+
st, err := store.Open(t.TempDir(), store.DefaultOptions().WithMultiIndexing(true))
3207+
require.NoError(t, err)
3208+
defer closeStore(t, st)
3209+
3210+
engine, err := NewEngine(st, DefaultOptions().WithPrefix(sqlPrefix))
3211+
require.NoError(t, err)
3212+
3213+
t.Run("extract from constant expressions", func(t *testing.T) {
3214+
assertQueryShouldProduceResults(
3215+
t,
3216+
engine,
3217+
`SELECT
3218+
EXTRACT(YEAR FROM '2020-01-15'),
3219+
EXTRACT(MONTH FROM '2020-01-15'),
3220+
EXTRACT(DAY FROM '2020-01-15'::TIMESTAMP),
3221+
EXTRACT(HOUR FROM '2020-01-15 12:30:24'),
3222+
EXTRACT(MINUTE FROM '2020-01-15 12:30:24'),
3223+
EXTRACT(SECOND FROM '2020-01-15 12:30:24'::TIMESTAMP)
3224+
`,
3225+
`SELECT * FROM (
3226+
VALUES (2020, 01, 15, 12, 30, 24)
3227+
)`,
3228+
)
3229+
})
3230+
3231+
t.Run("extract from table", func(t *testing.T) {
3232+
_, _, err := engine.Exec(
3233+
context.Background(),
3234+
nil,
3235+
`CREATE TABLE events(ts TIMESTAMP PRIMARY KEY);
3236+
3237+
INSERT INTO events(ts) VALUES
3238+
('2021-07-04 14:45:30'::TIMESTAMP),
3239+
('1999-12-31 23:59:59'::TIMESTAMP);
3240+
`,
3241+
nil,
3242+
)
3243+
require.NoError(t, err)
3244+
3245+
assertQueryShouldProduceResults(
3246+
t,
3247+
engine,
3248+
`SELECT
3249+
EXTRACT(YEAR FROM ts),
3250+
EXTRACT(MONTH FROM ts),
3251+
EXTRACT(DAY FROM ts),
3252+
EXTRACT(HOUR FROM ts),
3253+
EXTRACT(MINUTE FROM ts),
3254+
EXTRACT(SECOND FROM ts)
3255+
FROM events
3256+
ORDER BY ts
3257+
`,
3258+
`SELECT * FROM (
3259+
VALUES
3260+
(1999, 12, 31, 23, 59, 59),
3261+
(2021, 07, 04, 14, 45, 30)
3262+
)`,
3263+
)
3264+
})
3265+
3266+
}
3267+
32053268
func TestJSON(t *testing.T) {
32063269
opts := store.DefaultOptions().WithMultiIndexing(true)
32073270
opts.WithIndexOptions(opts.IndexOpts.WithMaxActiveSnapshots(1))
@@ -9450,6 +9513,51 @@ func TestFunctions(t *testing.T) {
94509513
)
94519514
require.NoError(t, err)
94529515

9516+
t.Run("coalesce", func(t *testing.T) {
9517+
type testCase struct {
9518+
query string
9519+
expectedValues string
9520+
err error
9521+
}
9522+
9523+
cases := []testCase{
9524+
{
9525+
query: "SELECT COALESCE (NULL)",
9526+
expectedValues: "NULL",
9527+
},
9528+
{
9529+
query: "SELECT COALESCE (NULL, NULL)",
9530+
expectedValues: "NULL",
9531+
},
9532+
{
9533+
query: "SELECT COALESCE(NULL, 1, 1.5, 3)",
9534+
expectedValues: "1",
9535+
},
9536+
{
9537+
query: "SELECT COALESCE('one', 'two', 'three')",
9538+
expectedValues: "'one'",
9539+
},
9540+
{
9541+
query: "SELECT COALESCE(1, 'test')",
9542+
err: ErrInvalidTypes,
9543+
},
9544+
}
9545+
9546+
for _, tc := range cases {
9547+
if tc.err != nil {
9548+
_, err := engine.queryAll(context.Background(), nil, tc.query, nil)
9549+
require.ErrorIs(t, err, tc.err)
9550+
continue
9551+
}
9552+
9553+
assertQueryShouldProduceResults(
9554+
t,
9555+
engine,
9556+
tc.query,
9557+
fmt.Sprintf("SELECT * FROM (VALUES (%s))", tc.expectedValues))
9558+
}
9559+
})
9560+
94539561
t.Run("timestamp functions", func(t *testing.T) {
94549562
_, err := engine.queryAll(context.Background(), nil, "SELECT NOW(1) FROM mytable", nil)
94559563
require.ErrorIs(t, err, ErrIllegalArguments)

embedded/sql/functions.go

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ import (
2525
)
2626

2727
const (
28+
CoalesceFnCall string = "COALESCE"
2829
LengthFnCall string = "LENGTH"
2930
SubstringFnCall string = "SUBSTRING"
3031
ConcatFnCall string = "CONCAT"
@@ -47,6 +48,7 @@ const (
4748
)
4849

4950
var builtinFunctions = map[string]Function{
51+
CoalesceFnCall: &CoalesceFn{},
5052
LengthFnCall: &LengthFn{},
5153
SubstringFnCall: &SubstringFn{},
5254
ConcatFnCall: &ConcatFn{},
@@ -67,6 +69,37 @@ type Function interface {
6769
Apply(tx *SQLTx, params []TypedValue) (TypedValue, error)
6870
}
6971

72+
type CoalesceFn struct{}
73+
74+
func (f *CoalesceFn) InferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) {
75+
return AnyType, nil
76+
}
77+
78+
func (f *CoalesceFn) RequiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error {
79+
return nil
80+
}
81+
82+
func (f *CoalesceFn) Apply(tx *SQLTx, params []TypedValue) (TypedValue, error) {
83+
t := AnyType
84+
85+
for _, p := range params {
86+
if !p.IsNull() {
87+
if t == AnyType {
88+
t = p.Type()
89+
} else if p.Type() != t && !(IsNumericType(t) && IsNumericType(p.Type())) {
90+
return nil, fmt.Errorf("coalesce: %w", ErrInvalidTypes)
91+
}
92+
}
93+
}
94+
95+
for _, p := range params {
96+
if !p.IsNull() {
97+
return p, nil
98+
}
99+
}
100+
return NewNull(t), nil
101+
}
102+
70103
// -------------------------------------
71104
// String Functions
72105
// -------------------------------------

embedded/sql/parser.go

Lines changed: 56 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ import (
2828

2929
//go:generate go run golang.org/x/tools/cmd/goyacc -l -o sql_parser.go sql_grammar.y
3030

31-
var reservedWords = map[string]int{
31+
var keywords = map[string]int{
3232
"CREATE": CREATE,
3333
"DROP": DROP,
3434
"USE": USE,
@@ -118,6 +118,21 @@ var reservedWords = map[string]int{
118118
"THEN": THEN,
119119
"ELSE": ELSE,
120120
"END": END,
121+
"EXTRACT": EXTRACT,
122+
"INTEGER": INTEGER_TYPE,
123+
"BOOLEAN": BOOLEAN_TYPE,
124+
"VARCHAR": VARCHAR_TYPE,
125+
"TIMESTAMP": TIMESTAMP_TYPE,
126+
"FLOAT": FLOAT_TYPE,
127+
"BLOB": BLOB_TYPE,
128+
"UUID": UUID_TYPE,
129+
"JSON": JSON_TYPE,
130+
"YEAR": YEAR,
131+
"MONTH": MONTH,
132+
"DAY": DAY,
133+
"HOUR": HOUR,
134+
"MINUTE": MINUTE,
135+
"SECOND": SECOND,
121136
}
122137

123138
var joinTypes = map[string]JoinType{
@@ -126,17 +141,6 @@ var joinTypes = map[string]JoinType{
126141
"RIGHT": RightJoin,
127142
}
128143

129-
var types = map[string]SQLValueType{
130-
"INTEGER": IntegerType,
131-
"BOOLEAN": BooleanType,
132-
"VARCHAR": VarcharType,
133-
"UUID": UUIDType,
134-
"BLOB": BLOBType,
135-
"TIMESTAMP": TimestampType,
136-
"FLOAT": Float64Type,
137-
"JSON": JSONType,
138-
}
139-
140144
var aggregateFns = map[string]AggregateFn{
141145
"COUNT": COUNT,
142146
"SUM": SUM,
@@ -321,7 +325,7 @@ func (l *lexer) Lex(lval *yySymType) int {
321325
}
322326

323327
lval.blob = val
324-
return BLOB
328+
return BLOB_LIT
325329
}
326330

327331
if isLetter(ch) {
@@ -334,16 +338,10 @@ func (l *lexer) Lex(lval *yySymType) int {
334338
w := fmt.Sprintf("%c%s", ch, tail)
335339
tid := strings.ToUpper(w)
336340

337-
sqlType, ok := types[tid]
338-
if ok {
339-
lval.sqlType = sqlType
340-
return TYPE
341-
}
342-
343341
val, ok := boolValues[tid]
344342
if ok {
345343
lval.boolean = val
346-
return BOOLEAN
344+
return BOOLEAN_LIT
347345
}
348346

349347
afn, ok := aggregateFns[tid]
@@ -358,13 +356,13 @@ func (l *lexer) Lex(lval *yySymType) int {
358356
return JOINTYPE
359357
}
360358

361-
tkn, ok := reservedWords[tid]
359+
tkn, ok := keywords[tid]
362360
if ok {
361+
lval.keyword = w
363362
return tkn
364363
}
365364

366365
lval.id = strings.ToLower(w)
367-
368366
return IDENTIFIER
369367
}
370368

@@ -409,7 +407,7 @@ func (l *lexer) Lex(lval *yySymType) int {
409407
}
410408

411409
lval.float = val
412-
return FLOAT
410+
return FLOAT_LIT
413411
}
414412

415413
val, err := strconv.ParseUint(fmt.Sprintf("%c%s", ch, tail), 10, 64)
@@ -419,7 +417,7 @@ func (l *lexer) Lex(lval *yySymType) int {
419417
}
420418

421419
lval.integer = val
422-
return INTEGER
420+
return INTEGER_LIT
423421
}
424422

425423
if isComparison(ch) {
@@ -452,7 +450,7 @@ func (l *lexer) Lex(lval *yySymType) int {
452450
}
453451

454452
lval.str = tail
455-
return VARCHAR
453+
return VARCHAR_LIT
456454
}
457455

458456
if ch == ':' {
@@ -566,7 +564,7 @@ func (l *lexer) Lex(lval *yySymType) int {
566564
return ERROR
567565
}
568566
lval.float = val
569-
return FLOAT
567+
return FLOAT_LIT
570568
}
571569
return DOT
572570
}
@@ -681,3 +679,35 @@ func isDoubleQuote(ch byte) bool {
681679
func isDot(ch byte) bool {
682680
return ch == '.'
683681
}
682+
683+
func newCreateTableStmt(
684+
name string,
685+
elems []TableElem,
686+
ifNotExists bool,
687+
) *CreateTableStmt {
688+
colsSpecs := make([]*ColSpec, 0, 5)
689+
var checks []CheckConstraint
690+
691+
var pk PrimaryKeyConstraint
692+
for _, e := range elems {
693+
switch c := e.(type) {
694+
case *ColSpec:
695+
colsSpecs = append(colsSpecs, c)
696+
case PrimaryKeyConstraint:
697+
pk = c
698+
case CheckConstraint:
699+
if checks == nil {
700+
checks = make([]CheckConstraint, 0, 5)
701+
}
702+
checks = append(checks, c)
703+
}
704+
}
705+
706+
return &CreateTableStmt{
707+
ifNotExists: ifNotExists,
708+
table: name,
709+
colsSpec: colsSpecs,
710+
pkColNames: pk,
711+
checks: checks,
712+
}
713+
}

0 commit comments

Comments
 (0)