Skip to content

Commit 32ce783

Browse files
authored
feat: unify column data retrieval for aggregate test cases (#123)
1 parent b0cb727 commit 32ce783

File tree

2 files changed

+114
-0
lines changed

2 files changed

+114
-0
lines changed

testcases/parser/nodes.go

+24
Original file line numberDiff line numberDiff line change
@@ -391,6 +391,30 @@ func (tc *TestCase) GetAggregateFunctionInvocation(reg *expr.ExtensionRegistry,
391391
return nil, fmt.Errorf("%w: no matching function found or %s", substraitgo.ErrNotFound, id)
392392
}
393393

394+
func (tc *TestCase) GetAggregateColumnsData() ([][]expr.Literal, error) {
395+
if tc.FuncType != AggregateFuncType {
396+
return nil, fmt.Errorf("expected function type %v, but got %v", AggregateFuncType, tc.FuncType)
397+
}
398+
399+
if len(tc.Columns) > 0 {
400+
return tc.Columns, nil
401+
}
402+
403+
columns := make([][]expr.Literal, len(tc.AggregateArgs))
404+
405+
for colIdx, arg := range tc.AggregateArgs {
406+
values, ok := arg.Argument.Value.(*expr.NestedLiteral[expr.ListLiteralValue])
407+
if !ok {
408+
return nil, fmt.Errorf("column %d: expected NestedLiteral[ListLiteralValue], but got %T", colIdx, arg.Argument.Value)
409+
}
410+
411+
columns[colIdx] = make([]expr.Literal, len(values.Value))
412+
copy(columns[colIdx], values.Value)
413+
}
414+
415+
return columns, nil
416+
}
417+
394418
type TestGroup struct {
395419
Description string
396420
TestCases []*TestCase

testcases/parser/parse_test.go

+90
Original file line numberDiff line numberDiff line change
@@ -542,6 +542,96 @@ func TestParseAggregateFuncWithVariousTypes(t *testing.T) {
542542
}
543543
}
544544

545+
func TestParseAggregateFuncAllFormats(t *testing.T) {
546+
header := makeAggregateTestHeader("v1.0", "/extensions/functions_arithmetic.yaml")
547+
header += "# basic\n"
548+
549+
tests := []struct {
550+
testCaseStr string
551+
wantData [][]expr.Literal
552+
}{
553+
{"avg((1,2,3)::i64) = 2::fp64", [][]expr.Literal{newInt64Values(1, 2, 3)}},
554+
{"((1), (2), (3)) avg(col0::i64) = 2::fp64", [][]expr.Literal{newInt64Values(1, 2, 3)}},
555+
{"DEFINE t1(i64) = ((1), (2), (3))\navg(t1.col0) = 2::fp64", [][]expr.Literal{newInt64Values(1, 2, 3)}},
556+
557+
// tests with empty input data
558+
{"avg(()::i64) = 2::fp64", [][]expr.Literal{{}}},
559+
{"DEFINE t1(i64) = ()\navg(t1.col0) = 2::fp64", [][]expr.Literal{{}}},
560+
561+
//tests with multiple columns
562+
{"((20, 20), (-3, -3), (1, 1), (10,10), (5,5)) corr(col0::fp32, col1::fp32?) = 1::fp64?", [][]expr.Literal{newFloat32Values(false, 20, -3, 1, 10, 5), newFloat32Values(true, 20, -3, 1, 10, 5)}},
563+
{"DEFINE t1(fp32, fp32?) = ((20, 20), (-3, -3), (1, 1), (10,10), (5,5))\ncorr(t1.col0, t1.col1) = 1::fp64?", [][]expr.Literal{newFloat32Values(false, 20, -3, 1, 10, 5), newFloat32Values(true, 20, -3, 1, 10, 5)}},
564+
}
565+
for _, test := range tests {
566+
t.Run(test.testCaseStr, func(t *testing.T) {
567+
testFile, err := ParseTestCasesFromString(header + test.testCaseStr)
568+
require.NoError(t, err)
569+
require.NotNil(t, testFile)
570+
assert.Len(t, testFile.TestCases, 1)
571+
tc := testFile.TestCases[0]
572+
assert.Contains(t, test.testCaseStr, tc.FuncName)
573+
assert.Equal(t, tc.GroupDesc, "basic")
574+
assert.Equal(t, tc.BaseURI, "/extensions/functions_arithmetic.yaml")
575+
assert.Len(t, tc.Args, 0)
576+
577+
// check that the types are correct
578+
argTypes := tc.GetArgTypes()
579+
assert.Len(t, argTypes, len(test.wantData))
580+
if len(test.wantData[0]) > 0 {
581+
for i, argType := range argTypes {
582+
assert.Equal(t, argType, test.wantData[i][0].GetType())
583+
}
584+
} else {
585+
// check that the type is correct for empty input data
586+
assert.Equal(t, &types.Int64Type{Nullability: types.NullabilityRequired}, argTypes[0])
587+
}
588+
589+
assert.Equal(t, AggregateFuncType, tc.FuncType)
590+
_, err = tc.GetScalarFunctionInvocation(nil, nil)
591+
require.Error(t, err)
592+
593+
reg := expr.NewEmptyExtensionRegistry(extensions.GetDefaultCollectionWithNoError())
594+
testGetFunctionInvocation(t, tc, &reg, nil)
595+
data, err := tc.GetAggregateColumnsData()
596+
require.NoError(t, err)
597+
598+
// check that the data is correct
599+
assert.Len(t, data, len(test.wantData))
600+
assert.Equal(t, test.wantData, data)
601+
})
602+
}
603+
}
604+
605+
func TestBadInputsToGetAggregateColumnsData(t *testing.T) {
606+
tests := []struct {
607+
name string
608+
testCase *TestCase
609+
expectedError error
610+
}{
611+
{
612+
name: "invalid function type",
613+
testCase: &TestCase{FuncType: ScalarFuncType},
614+
expectedError: fmt.Errorf("expected function type %v, but got %v", AggregateFuncType, ScalarFuncType),
615+
},
616+
{
617+
name: "invalid argument type",
618+
testCase: &TestCase{
619+
FuncType: AggregateFuncType,
620+
AggregateArgs: []*AggregateArgument{{Argument: &CaseLiteral{Value: expr.NewNullLiteral(&types.Float32Type{})}}},
621+
},
622+
expectedError: fmt.Errorf("column 0: expected NestedLiteral[ListLiteralValue], but got %T", expr.NewNullLiteral(&types.Float32Type{})),
623+
},
624+
}
625+
626+
for _, tt := range tests {
627+
t.Run(tt.name, func(t *testing.T) {
628+
_, err := tt.testCase.GetAggregateColumnsData()
629+
assert.Error(t, err)
630+
assert.Equal(t, tt.expectedError.Error(), err.Error())
631+
})
632+
}
633+
}
634+
545635
func TestParseAggregateFuncWithMixedArgs(t *testing.T) {
546636
header := makeAggregateTestHeader("v1.0", "/extensions/functions_arithmetic.yaml")
547637
tests := `# basic

0 commit comments

Comments
 (0)