diff --git a/embedded/sql/engine_test.go b/embedded/sql/engine_test.go index 0459162faa..5601146fa7 100644 --- a/embedded/sql/engine_test.go +++ b/embedded/sql/engine_test.go @@ -9513,6 +9513,51 @@ func TestFunctions(t *testing.T) { ) require.NoError(t, err) + t.Run("coalesce", func(t *testing.T) { + type testCase struct { + query string + expectedValues string + err error + } + + cases := []testCase{ + { + query: "SELECT COALESCE (NULL)", + expectedValues: "NULL", + }, + { + query: "SELECT COALESCE (NULL, NULL)", + expectedValues: "NULL", + }, + { + query: "SELECT COALESCE(NULL, 1, 1.5, 3)", + expectedValues: "1", + }, + { + query: "SELECT COALESCE('one', 'two', 'three')", + expectedValues: "'one'", + }, + { + query: "SELECT COALESCE(1, 'test')", + err: ErrInvalidTypes, + }, + } + + for _, tc := range cases { + if tc.err != nil { + _, err := engine.queryAll(context.Background(), nil, tc.query, nil) + require.ErrorIs(t, err, tc.err) + continue + } + + assertQueryShouldProduceResults( + t, + engine, + tc.query, + fmt.Sprintf("SELECT * FROM (VALUES (%s))", tc.expectedValues)) + } + }) + t.Run("timestamp functions", func(t *testing.T) { _, err := engine.queryAll(context.Background(), nil, "SELECT NOW(1) FROM mytable", nil) require.ErrorIs(t, err, ErrIllegalArguments) diff --git a/embedded/sql/functions.go b/embedded/sql/functions.go index a93d00f80e..ad2019cfc2 100644 --- a/embedded/sql/functions.go +++ b/embedded/sql/functions.go @@ -25,6 +25,7 @@ import ( ) const ( + CoalesceFnCall string = "COALESCE" LengthFnCall string = "LENGTH" SubstringFnCall string = "SUBSTRING" ConcatFnCall string = "CONCAT" @@ -47,6 +48,7 @@ const ( ) var builtinFunctions = map[string]Function{ + CoalesceFnCall: &CoalesceFn{}, LengthFnCall: &LengthFn{}, SubstringFnCall: &SubstringFn{}, ConcatFnCall: &ConcatFn{}, @@ -67,6 +69,37 @@ type Function interface { Apply(tx *SQLTx, params []TypedValue) (TypedValue, error) } +type CoalesceFn struct{} + +func (f *CoalesceFn) InferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) { + return AnyType, nil +} + +func (f *CoalesceFn) RequiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error { + return nil +} + +func (f *CoalesceFn) Apply(tx *SQLTx, params []TypedValue) (TypedValue, error) { + t := AnyType + + for _, p := range params { + if !p.IsNull() { + if t == AnyType { + t = p.Type() + } else if p.Type() != t && !(IsNumericType(t) && IsNumericType(p.Type())) { + return nil, fmt.Errorf("coalesce: %w", ErrInvalidTypes) + } + } + } + + for _, p := range params { + if !p.IsNull() { + return p, nil + } + } + return NewNull(t), nil +} + // ------------------------------------- // String Functions // -------------------------------------