@@ -542,6 +542,96 @@ func TestParseAggregateFuncWithVariousTypes(t *testing.T) {
542
542
}
543
543
}
544
544
545
+ func TestParseAggregateFuncAllFormats (t * testing.T ) {
546
+ header := makeAggregateTestHeader ("v1.0" , "/extensions/functions_arithmetic.yaml" )
547
+ header += "# basic\n "
548
+
549
+ tests := []struct {
550
+ testCaseStr string
551
+ wantData [][]expr.Literal
552
+ }{
553
+ {"avg((1,2,3)::i64) = 2::fp64" , [][]expr.Literal {newInt64Values (1 , 2 , 3 )}},
554
+ {"((1), (2), (3)) avg(col0::i64) = 2::fp64" , [][]expr.Literal {newInt64Values (1 , 2 , 3 )}},
555
+ {"DEFINE t1(i64) = ((1), (2), (3))\n avg(t1.col0) = 2::fp64" , [][]expr.Literal {newInt64Values (1 , 2 , 3 )}},
556
+
557
+ // tests with empty input data
558
+ {"avg(()::i64) = 2::fp64" , [][]expr.Literal {{}}},
559
+ {"DEFINE t1(i64) = ()\n avg(t1.col0) = 2::fp64" , [][]expr.Literal {{}}},
560
+
561
+ //tests with multiple columns
562
+ {"((20, 20), (-3, -3), (1, 1), (10,10), (5,5)) corr(col0::fp32, col1::fp32?) = 1::fp64?" , [][]expr.Literal {newFloat32Values (false , 20 , - 3 , 1 , 10 , 5 ), newFloat32Values (true , 20 , - 3 , 1 , 10 , 5 )}},
563
+ {"DEFINE t1(fp32, fp32?) = ((20, 20), (-3, -3), (1, 1), (10,10), (5,5))\n corr(t1.col0, t1.col1) = 1::fp64?" , [][]expr.Literal {newFloat32Values (false , 20 , - 3 , 1 , 10 , 5 ), newFloat32Values (true , 20 , - 3 , 1 , 10 , 5 )}},
564
+ }
565
+ for _ , test := range tests {
566
+ t .Run (test .testCaseStr , func (t * testing.T ) {
567
+ testFile , err := ParseTestCasesFromString (header + test .testCaseStr )
568
+ require .NoError (t , err )
569
+ require .NotNil (t , testFile )
570
+ assert .Len (t , testFile .TestCases , 1 )
571
+ tc := testFile .TestCases [0 ]
572
+ assert .Contains (t , test .testCaseStr , tc .FuncName )
573
+ assert .Equal (t , tc .GroupDesc , "basic" )
574
+ assert .Equal (t , tc .BaseURI , "/extensions/functions_arithmetic.yaml" )
575
+ assert .Len (t , tc .Args , 0 )
576
+
577
+ // check that the types are correct
578
+ argTypes := tc .GetArgTypes ()
579
+ assert .Len (t , argTypes , len (test .wantData ))
580
+ if len (test .wantData [0 ]) > 0 {
581
+ for i , argType := range argTypes {
582
+ assert .Equal (t , argType , test .wantData [i ][0 ].GetType ())
583
+ }
584
+ } else {
585
+ // check that the type is correct for empty input data
586
+ assert .Equal (t , & types.Int64Type {Nullability : types .NullabilityRequired }, argTypes [0 ])
587
+ }
588
+
589
+ assert .Equal (t , AggregateFuncType , tc .FuncType )
590
+ _ , err = tc .GetScalarFunctionInvocation (nil , nil )
591
+ require .Error (t , err )
592
+
593
+ reg := expr .NewEmptyExtensionRegistry (extensions .GetDefaultCollectionWithNoError ())
594
+ testGetFunctionInvocation (t , tc , & reg , nil )
595
+ data , err := tc .GetAggregateColumnsData ()
596
+ require .NoError (t , err )
597
+
598
+ // check that the data is correct
599
+ assert .Len (t , data , len (test .wantData ))
600
+ assert .Equal (t , test .wantData , data )
601
+ })
602
+ }
603
+ }
604
+
605
+ func TestBadInputsToGetAggregateColumnsData (t * testing.T ) {
606
+ tests := []struct {
607
+ name string
608
+ testCase * TestCase
609
+ expectedError error
610
+ }{
611
+ {
612
+ name : "invalid function type" ,
613
+ testCase : & TestCase {FuncType : ScalarFuncType },
614
+ expectedError : fmt .Errorf ("expected function type %v, but got %v" , AggregateFuncType , ScalarFuncType ),
615
+ },
616
+ {
617
+ name : "invalid argument type" ,
618
+ testCase : & TestCase {
619
+ FuncType : AggregateFuncType ,
620
+ AggregateArgs : []* AggregateArgument {{Argument : & CaseLiteral {Value : expr .NewNullLiteral (& types.Float32Type {})}}},
621
+ },
622
+ expectedError : fmt .Errorf ("column 0: expected NestedLiteral[ListLiteralValue], but got %T" , expr .NewNullLiteral (& types.Float32Type {})),
623
+ },
624
+ }
625
+
626
+ for _ , tt := range tests {
627
+ t .Run (tt .name , func (t * testing.T ) {
628
+ _ , err := tt .testCase .GetAggregateColumnsData ()
629
+ assert .Error (t , err )
630
+ assert .Equal (t , tt .expectedError .Error (), err .Error ())
631
+ })
632
+ }
633
+ }
634
+
545
635
func TestParseAggregateFuncWithMixedArgs (t * testing.T ) {
546
636
header := makeAggregateTestHeader ("v1.0" , "/extensions/functions_arithmetic.yaml" )
547
637
tests := `# basic
0 commit comments