Skip to content

Commit 18c1a41

Browse files
authored
fix: allow variadic variants to evaluate (#55)
* fix: allow variadic variants to evaluate * add some tests
1 parent 1a800d9 commit 18c1a41

File tree

2 files changed

+112
-34
lines changed

2 files changed

+112
-34
lines changed

extensions/variants.go

Lines changed: 64 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -36,43 +36,74 @@ type FunctionVariant interface {
3636
MatchAt(typ types.Type, pos int) (bool, error)
3737
}
3838

39-
func EvaluateTypeExpression(nullHandling NullabilityHandling, expr parser.TypeExpression, paramTypeList ArgumentList, actualTypes []types.Type) (types.Type, error) {
39+
func validateType(arg Argument, actual types.Type, idx int, nullHandling NullabilityHandling) (bool, error) {
40+
allNonNull := true
41+
switch p := arg.(type) {
42+
case EnumArg:
43+
if actual != nil {
44+
return allNonNull, fmt.Errorf("%w: arg #%d (%s) should be an enum",
45+
substraitgo.ErrInvalidType, idx, p.Name)
46+
}
47+
case ValueArg:
48+
if actual == nil {
49+
return allNonNull, fmt.Errorf("%w: arg #%d should be of type %s",
50+
substraitgo.ErrInvalidType, idx, p.toTypeString())
51+
}
52+
53+
isNullable := actual.GetNullability() != types.NullabilityRequired
54+
if isNullable {
55+
allNonNull = false
56+
}
57+
58+
if nullHandling == DiscreteNullability {
59+
if t, ok := p.Value.Expr.(*parser.Type); ok {
60+
if isNullable != t.Optional() {
61+
return allNonNull, fmt.Errorf("%w: discrete nullability did not match for arg #%d",
62+
substraitgo.ErrInvalidType, idx)
63+
}
64+
} else {
65+
return allNonNull, substraitgo.ErrNotImplemented
66+
}
67+
}
68+
case TypeArg:
69+
return allNonNull, substraitgo.ErrNotImplemented
70+
}
71+
72+
return allNonNull, nil
73+
}
74+
75+
func EvaluateTypeExpression(nullHandling NullabilityHandling, expr parser.TypeExpression, paramTypeList ArgumentList, variadic *VariadicBehavior, actualTypes []types.Type) (types.Type, error) {
4076
if len(paramTypeList) != len(actualTypes) {
41-
return nil, fmt.Errorf("%w: mismatch in number of arguments provided. got %d, expected %d",
42-
substraitgo.ErrInvalidExpr, len(actualTypes), len(paramTypeList))
77+
if variadic == nil {
78+
return nil, fmt.Errorf("%w: mismatch in number of arguments provided. got %d, expected %d",
79+
substraitgo.ErrInvalidExpr, len(actualTypes), len(paramTypeList))
80+
}
81+
82+
if !variadic.IsValidArgumentCount(len(actualTypes) - len(paramTypeList) - 1) {
83+
return nil, fmt.Errorf("%w: mismatch in number of arguments provided, invalid number of variadic params. got %d total",
84+
substraitgo.ErrInvalidExpr, len(actualTypes))
85+
}
4386
}
4487

4588
allNonNull := true
4689
for i, p := range paramTypeList {
47-
switch p := p.(type) {
48-
case EnumArg:
49-
if actualTypes[i] != nil {
50-
return nil, fmt.Errorf("%w: arg #%d (%s) should be an enum",
51-
substraitgo.ErrInvalidType, i, p.Name)
52-
}
53-
case ValueArg:
54-
if actualTypes[i] == nil {
55-
return nil, fmt.Errorf("%w: arg #%d should be of type %s",
56-
substraitgo.ErrInvalidType, i, p.toTypeString())
57-
}
58-
59-
isNullable := actualTypes[i].GetNullability() != types.NullabilityRequired
60-
if isNullable {
61-
allNonNull = false
62-
}
90+
nonNull, err := validateType(p, actualTypes[i], i, nullHandling)
91+
if err != nil {
92+
return nil, err
93+
}
94+
allNonNull = allNonNull && nonNull
95+
}
6396

64-
if nullHandling == DiscreteNullability {
65-
if t, ok := p.Value.Expr.(*parser.Type); ok {
66-
if isNullable != t.Optional() {
67-
return nil, fmt.Errorf("%w: discrete nullability did not match for arg #%d",
68-
substraitgo.ErrInvalidType, i)
69-
}
70-
} else {
71-
return nil, substraitgo.ErrNotImplemented
72-
}
97+
// validate varidic argument consistency
98+
if variadic != nil && len(actualTypes) > len(paramTypeList) && variadic.ParameterConsistency == ConsistentParams {
99+
nparams := len(paramTypeList)
100+
lastParam := paramTypeList[nparams-1]
101+
for i, actual := range actualTypes[nparams:] {
102+
nonNull, err := validateType(lastParam, actual, nparams+i, nullHandling)
103+
if err != nil {
104+
return nil, err
73105
}
74-
case TypeArg:
75-
return nil, substraitgo.ErrNotImplemented
106+
allNonNull = allNonNull && nonNull
76107
}
77108
}
78109

@@ -267,7 +298,7 @@ func (s *ScalarFunctionVariant) SessionDependent() bool { return s.imp
267298
func (s *ScalarFunctionVariant) Nullability() NullabilityHandling { return s.impl.Nullability }
268299
func (s *ScalarFunctionVariant) URI() string { return s.uri }
269300
func (s *ScalarFunctionVariant) ResolveType(argumentTypes []types.Type) (types.Type, error) {
270-
return EvaluateTypeExpression(s.impl.Nullability, s.impl.Return, s.impl.Args, argumentTypes)
301+
return EvaluateTypeExpression(s.impl.Nullability, s.impl.Return, s.impl.Args, s.impl.Variadic, argumentTypes)
271302
}
272303
func (s *ScalarFunctionVariant) CompoundName() string {
273304
return s.name + ":" + s.impl.signatureKey()
@@ -375,7 +406,7 @@ func (s *AggregateFunctionVariant) SessionDependent() bool { return s.
375406
func (s *AggregateFunctionVariant) Nullability() NullabilityHandling { return s.impl.Nullability }
376407
func (s *AggregateFunctionVariant) URI() string { return s.uri }
377408
func (s *AggregateFunctionVariant) ResolveType(argumentTypes []types.Type) (types.Type, error) {
378-
return EvaluateTypeExpression(s.impl.Nullability, s.impl.Return, s.impl.Args, argumentTypes)
409+
return EvaluateTypeExpression(s.impl.Nullability, s.impl.Return, s.impl.Args, s.impl.Variadic, argumentTypes)
379410
}
380411
func (s *AggregateFunctionVariant) CompoundName() string {
381412
return s.name + ":" + s.impl.signatureKey()
@@ -488,7 +519,7 @@ func (s *WindowFunctionVariant) SessionDependent() bool { return s.imp
488519
func (s *WindowFunctionVariant) Nullability() NullabilityHandling { return s.impl.Nullability }
489520
func (s *WindowFunctionVariant) URI() string { return s.uri }
490521
func (s *WindowFunctionVariant) ResolveType(argumentTypes []types.Type) (types.Type, error) {
491-
return EvaluateTypeExpression(s.impl.Nullability, s.impl.Return, s.impl.Args, argumentTypes)
522+
return EvaluateTypeExpression(s.impl.Nullability, s.impl.Return, s.impl.Args, s.impl.Variadic, argumentTypes)
492523
}
493524
func (s *WindowFunctionVariant) CompoundName() string {
494525
return s.name + ":" + s.impl.signatureKey()

extensions/variants_test.go

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,54 @@ func TestEvaluateTypeExpression(t *testing.T) {
5757

5858
for _, tt := range tests {
5959
t.Run(tt.name, func(t *testing.T) {
60-
result, err := extensions.EvaluateTypeExpression(tt.nulls, tt.ret, tt.extArgs, tt.args)
60+
result, err := extensions.EvaluateTypeExpression(tt.nulls, tt.ret, tt.extArgs, nil, tt.args)
61+
if tt.err == "" {
62+
assert.NoError(t, err)
63+
assert.Truef(t, tt.expected.Equals(result), "expected: %s\ngot: %s", tt.expected, result)
64+
} else {
65+
assert.EqualError(t, err, tt.err)
66+
}
67+
})
68+
}
69+
}
70+
71+
func TestVariantWithVariadic(t *testing.T) {
72+
var (
73+
p, _ = parser.New()
74+
i64Null, _ = p.ParseString("i64?")
75+
i64NonNull, _ = p.ParseString("i64")
76+
// strNull, _ = p.ParseString("string?")
77+
)
78+
79+
tests := []struct {
80+
name string
81+
nulls extensions.NullabilityHandling
82+
ret parser.TypeExpression
83+
extArgs extensions.ArgumentList
84+
args []types.Type
85+
expected types.Type
86+
variadic extensions.VariadicBehavior
87+
err string
88+
}{
89+
{"basic", "", *i64NonNull, extensions.ArgumentList{
90+
extensions.ValueArg{Value: i64Null}},
91+
[]types.Type{&types.Int64Type{Nullability: types.NullabilityNullable},
92+
&types.Int64Type{Nullability: types.NullabilityNullable}},
93+
&types.Int64Type{Nullability: types.NullabilityNullable},
94+
extensions.VariadicBehavior{
95+
Min: 0, ParameterConsistency: extensions.ConsistentParams}, ""},
96+
{"bad arg count", "", *i64NonNull, extensions.ArgumentList{
97+
extensions.ValueArg{Value: i64Null}},
98+
[]types.Type{&types.Int64Type{Nullability: types.NullabilityNullable},
99+
&types.Int64Type{Nullability: types.NullabilityNullable}},
100+
nil, extensions.VariadicBehavior{
101+
Min: 2, ParameterConsistency: extensions.ConsistentParams},
102+
"invalid expression: mismatch in number of arguments provided, invalid number of variadic params. got 2 total"},
103+
}
104+
105+
for _, tt := range tests {
106+
t.Run(tt.name, func(t *testing.T) {
107+
result, err := extensions.EvaluateTypeExpression(tt.nulls, tt.ret, tt.extArgs, &tt.variadic, tt.args)
61108
if tt.err == "" {
62109
assert.NoError(t, err)
63110
assert.Truef(t, tt.expected.Equals(result), "expected: %s\ngot: %s", tt.expected, result)

0 commit comments

Comments
 (0)