Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 45 additions & 0 deletions embedded/sql/engine_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9513,6 +9513,51 @@ func TestFunctions(t *testing.T) {
)
require.NoError(t, err)

t.Run("coalesce", func(t *testing.T) {
type testCase struct {
query string
expectedValues string
err error
}

cases := []testCase{
{
query: "SELECT COALESCE (NULL)",
expectedValues: "NULL",
},
{
query: "SELECT COALESCE (NULL, NULL)",
expectedValues: "NULL",
},
{
query: "SELECT COALESCE(NULL, 1, 1.5, 3)",
expectedValues: "1",
},
{
query: "SELECT COALESCE('one', 'two', 'three')",
expectedValues: "'one'",
},
{
query: "SELECT COALESCE(1, 'test')",
err: ErrInvalidTypes,
},
}

for _, tc := range cases {
if tc.err != nil {
_, err := engine.queryAll(context.Background(), nil, tc.query, nil)
require.ErrorIs(t, err, tc.err)
continue
}

assertQueryShouldProduceResults(
t,
engine,
tc.query,
fmt.Sprintf("SELECT * FROM (VALUES (%s))", tc.expectedValues))
}
})

t.Run("timestamp functions", func(t *testing.T) {
_, err := engine.queryAll(context.Background(), nil, "SELECT NOW(1) FROM mytable", nil)
require.ErrorIs(t, err, ErrIllegalArguments)
Expand Down
33 changes: 33 additions & 0 deletions embedded/sql/functions.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (
)

const (
CoalesceFnCall string = "COALESCE"
LengthFnCall string = "LENGTH"
SubstringFnCall string = "SUBSTRING"
ConcatFnCall string = "CONCAT"
Expand All @@ -47,6 +48,7 @@ const (
)

var builtinFunctions = map[string]Function{
CoalesceFnCall: &CoalesceFn{},
LengthFnCall: &LengthFn{},
SubstringFnCall: &SubstringFn{},
ConcatFnCall: &ConcatFn{},
Expand All @@ -67,6 +69,37 @@ type Function interface {
Apply(tx *SQLTx, params []TypedValue) (TypedValue, error)
}

type CoalesceFn struct{}

func (f *CoalesceFn) InferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) {
return AnyType, nil
}

func (f *CoalesceFn) RequiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error {
return nil
}

func (f *CoalesceFn) Apply(tx *SQLTx, params []TypedValue) (TypedValue, error) {
t := AnyType

for _, p := range params {
if !p.IsNull() {
if t == AnyType {
t = p.Type()
} else if p.Type() != t && !(IsNumericType(t) && IsNumericType(p.Type())) {
return nil, fmt.Errorf("coalesce: %w", ErrInvalidTypes)
}
}
}

for _, p := range params {
if !p.IsNull() {
return p, nil
}
}
return NewNull(t), nil
}

// -------------------------------------
// String Functions
// -------------------------------------
Expand Down
Loading