|
3 | 3 | package expr_test |
4 | 4 |
|
5 | 5 | import ( |
| 6 | + "strings" |
6 | 7 | "testing" |
7 | 8 |
|
8 | 9 | "github.com/stretchr/testify/assert" |
9 | 10 | "github.com/stretchr/testify/require" |
10 | 11 | "github.com/substrait-io/substrait-go/v4/expr" |
11 | 12 | "github.com/substrait-io/substrait-go/v4/extensions" |
| 13 | + "github.com/substrait-io/substrait-go/v4/plan" |
12 | 14 | "github.com/substrait-io/substrait-go/v4/types" |
13 | 15 | "github.com/substrait-io/substrait-protobuf/go/substraitpb" |
| 16 | + "google.golang.org/protobuf/types/known/anypb" |
14 | 17 | ) |
15 | 18 |
|
16 | 19 | func TestExprBuilder(t *testing.T) { |
@@ -112,6 +115,140 @@ func TestExprBuilder(t *testing.T) { |
112 | 115 | } |
113 | 116 | } |
114 | 117 |
|
| 118 | +func TestCustomTypesInFunctionOutput(t *testing.T) { |
| 119 | + custom := `%YAML 1.2 |
| 120 | +--- |
| 121 | +types: |
| 122 | + - name: custom_type1 |
| 123 | + - name: custom_type2 |
| 124 | + - name: custom_type3 |
| 125 | + - name: custom_type4 |
| 126 | +
|
| 127 | +scalar_functions: |
| 128 | + - name: custom_function |
| 129 | + description: "custom function that takes in and returns custom types" |
| 130 | + impls: |
| 131 | + - args: |
| 132 | + - name: arg1 |
| 133 | + value: u!custom_type2 |
| 134 | + return: u!custom_type1 |
| 135 | +
|
| 136 | +aggregate_functions: |
| 137 | + - name: "custom_aggr" |
| 138 | + description: "custom aggregator that takes in and returns custom types" |
| 139 | + impls: |
| 140 | + - args: |
| 141 | + - name: arg1 |
| 142 | + value: u!custom_type2 |
| 143 | + return: u!custom_type3 |
| 144 | +
|
| 145 | +window_functions: |
| 146 | + - name: "custom_window" |
| 147 | + description: "custom window function that takes in and returns custom types" |
| 148 | + impls: |
| 149 | + - args: |
| 150 | + - name: arg1 |
| 151 | + value: u!custom_type2 |
| 152 | + return: u!custom_type1 |
| 153 | +` |
| 154 | + |
| 155 | + customReader := strings.NewReader(custom) |
| 156 | + collection := extensions.Collection{} |
| 157 | + err := collection.Load("custom", customReader) |
| 158 | + require.NoError(t, err) |
| 159 | + |
| 160 | + planBuilder := plan.NewBuilder(&collection) |
| 161 | + |
| 162 | + customType1 := planBuilder.UserDefinedType("custom", "custom_type1") |
| 163 | + customType2 := planBuilder.UserDefinedType("custom", "custom_type2") |
| 164 | + customType3 := planBuilder.UserDefinedType("custom", "custom_type3") |
| 165 | + |
| 166 | + anyVal, err := anypb.New(expr.NewPrimitiveLiteral("foo", false).ToProto()) |
| 167 | + require.NoError(t, err) |
| 168 | + |
| 169 | + customLiteral := planBuilder.GetExprBuilder().Literal(&expr.ProtoLiteral{ |
| 170 | + Type: &customType2, |
| 171 | + Value: anyVal, |
| 172 | + }) |
| 173 | + |
| 174 | + // check scalar function |
| 175 | + scalar, err := planBuilder.GetExprBuilder().ScalarFunc(extensions.ID{ |
| 176 | + URI: "custom", |
| 177 | + Name: "custom_function", |
| 178 | + }).Args( |
| 179 | + customLiteral, |
| 180 | + ).BuildExpr() |
| 181 | + require.NoError(t, err) |
| 182 | + scalarProto := scalar.ToProto() |
| 183 | + |
| 184 | + fnCall := scalarProto.GetScalarFunction() |
| 185 | + require.Len(t, fnCall.Arguments, 1) |
| 186 | + require.Equal(t, customType2.TypeReference, fnCall.Arguments[0].GetValue().GetLiteral().GetUserDefined().TypeReference) |
| 187 | + require.Equal(t, customType1.TypeReference, fnCall.OutputType.GetUserDefined().TypeReference) |
| 188 | + |
| 189 | + // check aggregate function |
| 190 | + aggr, err := planBuilder.GetExprBuilder().AggFunc(extensions.ID{ |
| 191 | + URI: "custom", |
| 192 | + Name: "custom_aggr", |
| 193 | + }).Args( |
| 194 | + customLiteral, |
| 195 | + ).Build() |
| 196 | + require.NoError(t, err) |
| 197 | + aggrProto := aggr.ToProto() |
| 198 | + |
| 199 | + require.Len(t, aggrProto.Arguments, 1) |
| 200 | + require.Equal(t, customType2.TypeReference, aggrProto.Arguments[0].GetValue().GetLiteral().GetUserDefined().TypeReference) |
| 201 | + require.Equal(t, customType3.TypeReference, aggrProto.OutputType.GetUserDefined().TypeReference) |
| 202 | + |
| 203 | + // check window function |
| 204 | + window, err := planBuilder.GetExprBuilder().WindowFunc(extensions.ID{ |
| 205 | + URI: "custom", |
| 206 | + Name: "custom_window", |
| 207 | + }).Args( |
| 208 | + customLiteral, |
| 209 | + ).Phase(types.AggPhaseInitialToResult).Build() |
| 210 | + require.NoError(t, err) |
| 211 | + windowProto := window.ToProto() |
| 212 | + |
| 213 | + windowFnCall := windowProto.GetWindowFunction() |
| 214 | + require.Len(t, windowFnCall.Arguments, 1) |
| 215 | + require.Equal(t, customType2.TypeReference, windowFnCall.Arguments[0].GetValue().GetLiteral().GetUserDefined().TypeReference) |
| 216 | + require.Equal(t, customType1.TypeReference, windowFnCall.OutputType.GetUserDefined().TypeReference) |
| 217 | + |
| 218 | + // build a full plan and check that user defined types are registered in the extensions |
| 219 | + table, err := planBuilder.VirtualTable([]string{"col_a", "col_b"}, []expr.Literal{expr.NewPrimitiveLiteral(int64(2), false), expr.NewPrimitiveLiteral(int64(3), false)}) |
| 220 | + require.NoError(t, err) |
| 221 | + |
| 222 | + aggregated, err := planBuilder.GetRelBuilder().AggregateRel(table, []plan.AggRelMeasure{planBuilder.Measure(aggr, window)}).Build() |
| 223 | + require.NoError(t, err) |
| 224 | + |
| 225 | + project, err := planBuilder.Project(aggregated, scalar) |
| 226 | + require.NoError(t, err) |
| 227 | + |
| 228 | + p, err := planBuilder.Plan(project, []string{"output1", "output2"}) |
| 229 | + require.NoError(t, err) |
| 230 | + |
| 231 | + pp, err := p.ToProto() |
| 232 | + require.NoError(t, err) |
| 233 | + |
| 234 | + // custom_type1 is referenced as an argument and return type, so should be registered in the extensions |
| 235 | + // custom_type2 is referenced as an argument and return type, so should be registered in the extensions |
| 236 | + // custom_type3 is only referenced as a return type, but should still be registered in the extensions |
| 237 | + // custom_type4 is not referenced in the plan at all, so not be registerd in the extensions |
| 238 | + typeExtensionsFound := []string{} |
| 239 | + for _, ext := range pp.Extensions { |
| 240 | + typeExt := ext.GetExtensionType() |
| 241 | + if typeExt == nil { |
| 242 | + continue |
| 243 | + } |
| 244 | + typeExtensionsFound = append(typeExtensionsFound, typeExt.GetName()) |
| 245 | + } |
| 246 | + require.Equal(t, 3, len(typeExtensionsFound)) |
| 247 | + require.Contains(t, typeExtensionsFound, "custom_type1") |
| 248 | + require.Contains(t, typeExtensionsFound, "custom_type2") |
| 249 | + require.Contains(t, typeExtensionsFound, "custom_type3") |
| 250 | +} |
| 251 | + |
115 | 252 | func TestBoundFromProto(t *testing.T) { |
116 | 253 | for _, tc := range []struct { |
117 | 254 | name string |
|
0 commit comments