Skip to content

Commit 485f6dc

Browse files
authored
fix: handle user-defined types in ResolveType (#147)
ResolveType needs to consume an extension.Set in order to handle user-defines types when they are present as the return type of a function BREAKING CHANGE: ResolveType now consumes a extensions.Set BREAKING CHANGE: EvaluateTypeExpression now consumes a URI and a extensions.Set
1 parent 3b86a0e commit 485f6dc

File tree

6 files changed

+262
-19
lines changed

6 files changed

+262
-19
lines changed

expr/builder_test.go

Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,17 @@
33
package expr_test
44

55
import (
6+
"strings"
67
"testing"
78

89
"github.com/stretchr/testify/assert"
910
"github.com/stretchr/testify/require"
1011
"github.com/substrait-io/substrait-go/v4/expr"
1112
"github.com/substrait-io/substrait-go/v4/extensions"
13+
"github.com/substrait-io/substrait-go/v4/plan"
1214
"github.com/substrait-io/substrait-go/v4/types"
1315
"github.com/substrait-io/substrait-protobuf/go/substraitpb"
16+
"google.golang.org/protobuf/types/known/anypb"
1417
)
1518

1619
func TestExprBuilder(t *testing.T) {
@@ -112,6 +115,140 @@ func TestExprBuilder(t *testing.T) {
112115
}
113116
}
114117

118+
func TestCustomTypesInFunctionOutput(t *testing.T) {
119+
custom := `%YAML 1.2
120+
---
121+
types:
122+
- name: custom_type1
123+
- name: custom_type2
124+
- name: custom_type3
125+
- name: custom_type4
126+
127+
scalar_functions:
128+
- name: custom_function
129+
description: "custom function that takes in and returns custom types"
130+
impls:
131+
- args:
132+
- name: arg1
133+
value: u!custom_type2
134+
return: u!custom_type1
135+
136+
aggregate_functions:
137+
- name: "custom_aggr"
138+
description: "custom aggregator that takes in and returns custom types"
139+
impls:
140+
- args:
141+
- name: arg1
142+
value: u!custom_type2
143+
return: u!custom_type3
144+
145+
window_functions:
146+
- name: "custom_window"
147+
description: "custom window function that takes in and returns custom types"
148+
impls:
149+
- args:
150+
- name: arg1
151+
value: u!custom_type2
152+
return: u!custom_type1
153+
`
154+
155+
customReader := strings.NewReader(custom)
156+
collection := extensions.Collection{}
157+
err := collection.Load("custom", customReader)
158+
require.NoError(t, err)
159+
160+
planBuilder := plan.NewBuilder(&collection)
161+
162+
customType1 := planBuilder.UserDefinedType("custom", "custom_type1")
163+
customType2 := planBuilder.UserDefinedType("custom", "custom_type2")
164+
customType3 := planBuilder.UserDefinedType("custom", "custom_type3")
165+
166+
anyVal, err := anypb.New(expr.NewPrimitiveLiteral("foo", false).ToProto())
167+
require.NoError(t, err)
168+
169+
customLiteral := planBuilder.GetExprBuilder().Literal(&expr.ProtoLiteral{
170+
Type: &customType2,
171+
Value: anyVal,
172+
})
173+
174+
// check scalar function
175+
scalar, err := planBuilder.GetExprBuilder().ScalarFunc(extensions.ID{
176+
URI: "custom",
177+
Name: "custom_function",
178+
}).Args(
179+
customLiteral,
180+
).BuildExpr()
181+
require.NoError(t, err)
182+
scalarProto := scalar.ToProto()
183+
184+
fnCall := scalarProto.GetScalarFunction()
185+
require.Len(t, fnCall.Arguments, 1)
186+
require.Equal(t, customType2.TypeReference, fnCall.Arguments[0].GetValue().GetLiteral().GetUserDefined().TypeReference)
187+
require.Equal(t, customType1.TypeReference, fnCall.OutputType.GetUserDefined().TypeReference)
188+
189+
// check aggregate function
190+
aggr, err := planBuilder.GetExprBuilder().AggFunc(extensions.ID{
191+
URI: "custom",
192+
Name: "custom_aggr",
193+
}).Args(
194+
customLiteral,
195+
).Build()
196+
require.NoError(t, err)
197+
aggrProto := aggr.ToProto()
198+
199+
require.Len(t, aggrProto.Arguments, 1)
200+
require.Equal(t, customType2.TypeReference, aggrProto.Arguments[0].GetValue().GetLiteral().GetUserDefined().TypeReference)
201+
require.Equal(t, customType3.TypeReference, aggrProto.OutputType.GetUserDefined().TypeReference)
202+
203+
// check window function
204+
window, err := planBuilder.GetExprBuilder().WindowFunc(extensions.ID{
205+
URI: "custom",
206+
Name: "custom_window",
207+
}).Args(
208+
customLiteral,
209+
).Phase(types.AggPhaseInitialToResult).Build()
210+
require.NoError(t, err)
211+
windowProto := window.ToProto()
212+
213+
windowFnCall := windowProto.GetWindowFunction()
214+
require.Len(t, windowFnCall.Arguments, 1)
215+
require.Equal(t, customType2.TypeReference, windowFnCall.Arguments[0].GetValue().GetLiteral().GetUserDefined().TypeReference)
216+
require.Equal(t, customType1.TypeReference, windowFnCall.OutputType.GetUserDefined().TypeReference)
217+
218+
// build a full plan and check that user defined types are registered in the extensions
219+
table, err := planBuilder.VirtualTable([]string{"col_a", "col_b"}, []expr.Literal{expr.NewPrimitiveLiteral(int64(2), false), expr.NewPrimitiveLiteral(int64(3), false)})
220+
require.NoError(t, err)
221+
222+
aggregated, err := planBuilder.GetRelBuilder().AggregateRel(table, []plan.AggRelMeasure{planBuilder.Measure(aggr, window)}).Build()
223+
require.NoError(t, err)
224+
225+
project, err := planBuilder.Project(aggregated, scalar)
226+
require.NoError(t, err)
227+
228+
p, err := planBuilder.Plan(project, []string{"output1", "output2"})
229+
require.NoError(t, err)
230+
231+
pp, err := p.ToProto()
232+
require.NoError(t, err)
233+
234+
// custom_type1 is referenced as an argument and return type, so should be registered in the extensions
235+
// custom_type2 is referenced as an argument and return type, so should be registered in the extensions
236+
// custom_type3 is only referenced as a return type, but should still be registered in the extensions
237+
// custom_type4 is not referenced in the plan at all, so not be registerd in the extensions
238+
typeExtensionsFound := []string{}
239+
for _, ext := range pp.Extensions {
240+
typeExt := ext.GetExtensionType()
241+
if typeExt == nil {
242+
continue
243+
}
244+
typeExtensionsFound = append(typeExtensionsFound, typeExt.GetName())
245+
}
246+
require.Equal(t, 3, len(typeExtensionsFound))
247+
require.Contains(t, typeExtensionsFound, "custom_type1")
248+
require.Contains(t, typeExtensionsFound, "custom_type2")
249+
require.Contains(t, typeExtensionsFound, "custom_type3")
250+
}
251+
115252
func TestBoundFromProto(t *testing.T) {
116253
for _, tc := range []struct {
117254
name string

expr/functions.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,7 @@ func NewCustomScalarFunc(
206206

207207
type variant interface {
208208
*extensions.ScalarFunctionVariant | *extensions.AggregateFunctionVariant | *extensions.WindowFunctionVariant
209-
ResolveType([]types.Type) (types.Type, error)
209+
ResolveType([]types.Type, extensions.Set) (types.Type, error)
210210
}
211211

212212
func resolveVariant[T variant](
@@ -253,7 +253,7 @@ func resolveVariant[T variant](
253253
}
254254
}
255255

256-
outType, err := decl.ResolveType(argTypes)
256+
outType, err := decl.ResolveType(argTypes, reg.Set)
257257
if err != nil {
258258
return nil, nil, err
259259
}

extensions/extension_mgr_test.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ func TestLoadExtensionCollection(t *testing.T) {
115115
}}, add.Options())
116116

117117
i8Req := &types.Int8Type{Nullability: types.NullabilityRequired}
118-
ty, err := add.ResolveType([]types.Type{i8Req, i8Req})
118+
ty, err := add.ResolveType([]types.Type{i8Req, i8Req}, extensions.NewSet())
119119
assert.NoError(t, err)
120120
assert.Equal(t, i8Req, ty)
121121
})
@@ -133,7 +133,7 @@ func TestLoadExtensionCollection(t *testing.T) {
133133
assert.Equal(t, "subtract:i16_i16", sub.CompoundName())
134134

135135
i16Req := &types.Int16Type{Nullability: types.NullabilityRequired}
136-
ty, err := sub.ResolveType([]types.Type{i16Req, i16Req})
136+
ty, err := sub.ResolveType([]types.Type{i16Req, i16Req}, extensions.NewSet())
137137
assert.NoError(t, err)
138138
assert.Equal(t, i16Req, ty)
139139
})
@@ -422,10 +422,10 @@ func TestAggregateToWindow(t *testing.T) {
422422
// Test type resolution with the same arguments
423423
// Use a concrete type for testing resolution
424424
i32Type := &types.Int32Type{Nullability: types.NullabilityRequired}
425-
aggType, err := aggFunc.ResolveType([]types.Type{i32Type})
425+
aggType, err := aggFunc.ResolveType([]types.Type{i32Type}, extensions.NewSet())
426426
require.NoError(t, err)
427427

428-
winType, err := winFunc.ResolveType([]types.Type{i32Type})
428+
winType, err := winFunc.ResolveType([]types.Type{i32Type}, extensions.NewSet())
429429
require.NoError(t, err)
430430

431431
assert.Equal(t, aggType, winType)

extensions/variants.go

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ type FunctionVariant interface {
1919
Args() FuncParameterList
2020
Options() map[string]Option
2121
URI() string
22-
ResolveType(argTypes []types.Type) (types.Type, error)
22+
ResolveType(argTypes []types.Type, registry Set) (types.Type, error)
2323
Variadic() *VariadicBehavior
2424
// Match this function matches input arguments against this functions parameter list
2525
// returns (true, nil) if all input argument can type replace the function definition argument
@@ -74,10 +74,12 @@ func validateType(funcParameter FuncParameter, actual types.Type, idx int, nullH
7474

7575
// EvaluateTypeExpression evaluates the function return type given the input argumentTypes
7676
//
77+
// uri: the uri of the extension that defines the function. for functions that return user defined types, we assume the uri of the return type is the same as the uri of the function.
7778
// funcParameters: the function parameters as defined in the function signature in the extension
7879
// argumentTypes: the actual argument types provided to the function
79-
func EvaluateTypeExpression(nullHandling NullabilityHandling, returnTypeExpr types.FuncDefArgType,
80-
funcParameters FuncParameterList, variadic *VariadicBehavior, argumentTypes []types.Type) (types.Type, error) {
80+
// registry: the Set of extensions to look up/add user defined types to
81+
func EvaluateTypeExpression(uri string, nullHandling NullabilityHandling, returnTypeExpr types.FuncDefArgType,
82+
funcParameters FuncParameterList, variadic *VariadicBehavior, argumentTypes []types.Type, registry Set) (types.Type, error) {
8183
if variadic != nil {
8284
numVariadicArgs := len(argumentTypes) - (len(funcParameters) - 1)
8385
if numVariadicArgs < 0 {
@@ -136,6 +138,11 @@ func EvaluateTypeExpression(nullHandling NullabilityHandling, returnTypeExpr typ
136138
return nil, err
137139
}
138140

141+
if udt, ok := outType.(*types.UserDefinedType); ok {
142+
name := strings.TrimPrefix(returnTypeExpr.ShortString(), "u!") // short string contains the u! prefix, but type definitions in the extensions don't
143+
udt.TypeReference = registry.GetTypeAnchor(ID{Name: name, URI: uri})
144+
}
145+
139146
if nullHandling == MirrorNullability || nullHandling == "" {
140147
if allNonNull {
141148
return outType.WithNullability(types.NullabilityRequired), nil
@@ -328,8 +335,8 @@ func (s *ScalarFunctionVariant) Deterministic() bool { return s.imp
328335
func (s *ScalarFunctionVariant) SessionDependent() bool { return s.impl.SessionDependent }
329336
func (s *ScalarFunctionVariant) Nullability() NullabilityHandling { return s.impl.Nullability }
330337
func (s *ScalarFunctionVariant) URI() string { return s.uri }
331-
func (s *ScalarFunctionVariant) ResolveType(argumentTypes []types.Type) (types.Type, error) {
332-
return EvaluateTypeExpression(s.impl.Nullability, s.impl.Return.ValueType, s.impl.Args, s.impl.Variadic, argumentTypes)
338+
func (s *ScalarFunctionVariant) ResolveType(argumentTypes []types.Type, registry Set) (types.Type, error) {
339+
return EvaluateTypeExpression(s.uri, s.impl.Nullability, s.impl.Return.ValueType, s.impl.Args, s.impl.Variadic, argumentTypes, registry)
333340
}
334341
func (s *ScalarFunctionVariant) CompoundName() string {
335342
return s.name + ":" + s.impl.signatureKey()
@@ -440,8 +447,8 @@ func (s *AggregateFunctionVariant) Deterministic() bool { return s.
440447
func (s *AggregateFunctionVariant) SessionDependent() bool { return s.impl.SessionDependent }
441448
func (s *AggregateFunctionVariant) Nullability() NullabilityHandling { return s.impl.Nullability }
442449
func (s *AggregateFunctionVariant) URI() string { return s.uri }
443-
func (s *AggregateFunctionVariant) ResolveType(argumentTypes []types.Type) (types.Type, error) {
444-
return EvaluateTypeExpression(s.impl.Nullability, s.impl.Return.ValueType, s.impl.Args, s.impl.Variadic, argumentTypes)
450+
func (s *AggregateFunctionVariant) ResolveType(argumentTypes []types.Type, registry Set) (types.Type, error) {
451+
return EvaluateTypeExpression(s.uri, s.impl.Nullability, s.impl.Return.ValueType, s.impl.Args, s.impl.Variadic, argumentTypes, registry)
445452
}
446453
func (s *AggregateFunctionVariant) CompoundName() string {
447454
return s.name + ":" + s.impl.signatureKey()
@@ -560,8 +567,8 @@ func (s *WindowFunctionVariant) Deterministic() bool { return s.imp
560567
func (s *WindowFunctionVariant) SessionDependent() bool { return s.impl.SessionDependent }
561568
func (s *WindowFunctionVariant) Nullability() NullabilityHandling { return s.impl.Nullability }
562569
func (s *WindowFunctionVariant) URI() string { return s.uri }
563-
func (s *WindowFunctionVariant) ResolveType(argumentTypes []types.Type) (types.Type, error) {
564-
return EvaluateTypeExpression(s.impl.Nullability, s.impl.Return.ValueType, s.impl.Args, s.impl.Variadic, argumentTypes)
570+
func (s *WindowFunctionVariant) ResolveType(argumentTypes []types.Type, registry Set) (types.Type, error) {
571+
return EvaluateTypeExpression(s.uri, s.impl.Nullability, s.impl.Return.ValueType, s.impl.Args, s.impl.Variadic, argumentTypes, registry)
565572
}
566573
func (s *WindowFunctionVariant) CompoundName() string {
567574
return s.name + ":" + s.impl.signatureKey()

0 commit comments

Comments
 (0)