@@ -542,6 +542,96 @@ func TestParseAggregateFuncWithVariousTypes(t *testing.T) {
542542 }
543543}
544544
545+ func TestParseAggregateFuncAllFormats (t * testing.T ) {
546+ header := makeAggregateTestHeader ("v1.0" , "/extensions/functions_arithmetic.yaml" )
547+ header += "# basic\n "
548+
549+ tests := []struct {
550+ testCaseStr string
551+ wantData [][]expr.Literal
552+ }{
553+ {"avg((1,2,3)::i64) = 2::fp64" , [][]expr.Literal {newInt64Values (1 , 2 , 3 )}},
554+ {"((1), (2), (3)) avg(col0::i64) = 2::fp64" , [][]expr.Literal {newInt64Values (1 , 2 , 3 )}},
555+ {"DEFINE t1(i64) = ((1), (2), (3))\n avg(t1.col0) = 2::fp64" , [][]expr.Literal {newInt64Values (1 , 2 , 3 )}},
556+
557+ // tests with empty input data
558+ {"avg(()::i64) = 2::fp64" , [][]expr.Literal {{}}},
559+ {"DEFINE t1(i64) = ()\n avg(t1.col0) = 2::fp64" , [][]expr.Literal {{}}},
560+
561+ //tests with multiple columns
562+ {"((20, 20), (-3, -3), (1, 1), (10,10), (5,5)) corr(col0::fp32, col1::fp32?) = 1::fp64?" , [][]expr.Literal {newFloat32Values (false , 20 , - 3 , 1 , 10 , 5 ), newFloat32Values (true , 20 , - 3 , 1 , 10 , 5 )}},
563+ {"DEFINE t1(fp32, fp32?) = ((20, 20), (-3, -3), (1, 1), (10,10), (5,5))\n corr(t1.col0, t1.col1) = 1::fp64?" , [][]expr.Literal {newFloat32Values (false , 20 , - 3 , 1 , 10 , 5 ), newFloat32Values (true , 20 , - 3 , 1 , 10 , 5 )}},
564+ }
565+ for _ , test := range tests {
566+ t .Run (test .testCaseStr , func (t * testing.T ) {
567+ testFile , err := ParseTestCasesFromString (header + test .testCaseStr )
568+ require .NoError (t , err )
569+ require .NotNil (t , testFile )
570+ assert .Len (t , testFile .TestCases , 1 )
571+ tc := testFile .TestCases [0 ]
572+ assert .Contains (t , test .testCaseStr , tc .FuncName )
573+ assert .Equal (t , tc .GroupDesc , "basic" )
574+ assert .Equal (t , tc .BaseURI , "/extensions/functions_arithmetic.yaml" )
575+ assert .Len (t , tc .Args , 0 )
576+
577+ // check that the types are correct
578+ argTypes := tc .GetArgTypes ()
579+ assert .Len (t , argTypes , len (test .wantData ))
580+ if len (test .wantData [0 ]) > 0 {
581+ for i , argType := range argTypes {
582+ assert .Equal (t , argType , test .wantData [i ][0 ].GetType ())
583+ }
584+ } else {
585+ // check that the type is correct for empty input data
586+ assert .Equal (t , & types.Int64Type {Nullability : types .NullabilityRequired }, argTypes [0 ])
587+ }
588+
589+ assert .Equal (t , AggregateFuncType , tc .FuncType )
590+ _ , err = tc .GetScalarFunctionInvocation (nil , nil )
591+ require .Error (t , err )
592+
593+ reg := expr .NewEmptyExtensionRegistry (extensions .GetDefaultCollectionWithNoError ())
594+ testGetFunctionInvocation (t , tc , & reg , nil )
595+ data , err := tc .GetAggregateColumnsData ()
596+ require .NoError (t , err )
597+
598+ // check that the data is correct
599+ assert .Len (t , data , len (test .wantData ))
600+ assert .Equal (t , test .wantData , data )
601+ })
602+ }
603+ }
604+
605+ func TestBadInputsToGetAggregateColumnsData (t * testing.T ) {
606+ tests := []struct {
607+ name string
608+ testCase * TestCase
609+ expectedError error
610+ }{
611+ {
612+ name : "invalid function type" ,
613+ testCase : & TestCase {FuncType : ScalarFuncType },
614+ expectedError : fmt .Errorf ("expected function type %v, but got %v" , AggregateFuncType , ScalarFuncType ),
615+ },
616+ {
617+ name : "invalid argument type" ,
618+ testCase : & TestCase {
619+ FuncType : AggregateFuncType ,
620+ AggregateArgs : []* AggregateArgument {{Argument : & CaseLiteral {Value : expr .NewNullLiteral (& types.Float32Type {})}}},
621+ },
622+ expectedError : fmt .Errorf ("column 0: expected NestedLiteral[ListLiteralValue], but got %T" , expr .NewNullLiteral (& types.Float32Type {})),
623+ },
624+ }
625+
626+ for _ , tt := range tests {
627+ t .Run (tt .name , func (t * testing.T ) {
628+ _ , err := tt .testCase .GetAggregateColumnsData ()
629+ assert .Error (t , err )
630+ assert .Equal (t , tt .expectedError .Error (), err .Error ())
631+ })
632+ }
633+ }
634+
545635func TestParseAggregateFuncWithMixedArgs (t * testing.T ) {
546636 header := makeAggregateTestHeader ("v1.0" , "/extensions/functions_arithmetic.yaml" )
547637 tests := `# basic
0 commit comments