@@ -36,43 +36,74 @@ type FunctionVariant interface {
3636	MatchAt (typ  types.Type , pos  int ) (bool , error )
3737}
3838
39- func  EvaluateTypeExpression (nullHandling  NullabilityHandling , expr  parser.TypeExpression , paramTypeList  ArgumentList , actualTypes  []types.Type ) (types.Type , error ) {
39+ func  validateType (arg  Argument , actual  types.Type , idx  int , nullHandling  NullabilityHandling ) (bool , error ) {
40+ 	allNonNull  :=  true 
41+ 	switch  p  :=  arg .(type ) {
42+ 	case  EnumArg :
43+ 		if  actual  !=  nil  {
44+ 			return  allNonNull , fmt .Errorf ("%w: arg #%d (%s) should be an enum" ,
45+ 				substraitgo .ErrInvalidType , idx , p .Name )
46+ 		}
47+ 	case  ValueArg :
48+ 		if  actual  ==  nil  {
49+ 			return  allNonNull , fmt .Errorf ("%w: arg #%d should be of type %s" ,
50+ 				substraitgo .ErrInvalidType , idx , p .toTypeString ())
51+ 		}
52+ 
53+ 		isNullable  :=  actual .GetNullability () !=  types .NullabilityRequired 
54+ 		if  isNullable  {
55+ 			allNonNull  =  false 
56+ 		}
57+ 
58+ 		if  nullHandling  ==  DiscreteNullability  {
59+ 			if  t , ok  :=  p .Value .Expr .(* parser.Type ); ok  {
60+ 				if  isNullable  !=  t .Optional () {
61+ 					return  allNonNull , fmt .Errorf ("%w: discrete nullability did not match for arg #%d" ,
62+ 						substraitgo .ErrInvalidType , idx )
63+ 				}
64+ 			} else  {
65+ 				return  allNonNull , substraitgo .ErrNotImplemented 
66+ 			}
67+ 		}
68+ 	case  TypeArg :
69+ 		return  allNonNull , substraitgo .ErrNotImplemented 
70+ 	}
71+ 
72+ 	return  allNonNull , nil 
73+ }
74+ 
75+ func  EvaluateTypeExpression (nullHandling  NullabilityHandling , expr  parser.TypeExpression , paramTypeList  ArgumentList , variadic  * VariadicBehavior , actualTypes  []types.Type ) (types.Type , error ) {
4076	if  len (paramTypeList ) !=  len (actualTypes ) {
41- 		return  nil , fmt .Errorf ("%w: mismatch in number of arguments provided. got %d, expected %d" ,
42- 			substraitgo .ErrInvalidExpr , len (actualTypes ), len (paramTypeList ))
77+ 		if  variadic  ==  nil  {
78+ 			return  nil , fmt .Errorf ("%w: mismatch in number of arguments provided. got %d, expected %d" ,
79+ 				substraitgo .ErrInvalidExpr , len (actualTypes ), len (paramTypeList ))
80+ 		}
81+ 
82+ 		if  ! variadic .IsValidArgumentCount (len (actualTypes ) -  len (paramTypeList ) -  1 ) {
83+ 			return  nil , fmt .Errorf ("%w: mismatch in number of arguments provided, invalid number of variadic params. got %d total" ,
84+ 				substraitgo .ErrInvalidExpr , len (actualTypes ))
85+ 		}
4386	}
4487
4588	allNonNull  :=  true 
4689	for  i , p  :=  range  paramTypeList  {
47- 		switch  p  :=  p .(type ) {
48- 		case  EnumArg :
49- 			if  actualTypes [i ] !=  nil  {
50- 				return  nil , fmt .Errorf ("%w: arg #%d (%s) should be an enum" ,
51- 					substraitgo .ErrInvalidType , i , p .Name )
52- 			}
53- 		case  ValueArg :
54- 			if  actualTypes [i ] ==  nil  {
55- 				return  nil , fmt .Errorf ("%w: arg #%d should be of type %s" ,
56- 					substraitgo .ErrInvalidType , i , p .toTypeString ())
57- 			}
58- 
59- 			isNullable  :=  actualTypes [i ].GetNullability () !=  types .NullabilityRequired 
60- 			if  isNullable  {
61- 				allNonNull  =  false 
62- 			}
90+ 		nonNull , err  :=  validateType (p , actualTypes [i ], i , nullHandling )
91+ 		if  err  !=  nil  {
92+ 			return  nil , err 
93+ 		}
94+ 		allNonNull  =  allNonNull  &&  nonNull 
95+ 	}
6396
64- 			if  nullHandling  ==  DiscreteNullability  {
65- 				if  t , ok  :=  p .Value .Expr .(* parser.Type ); ok  {
66- 					if  isNullable  !=  t .Optional () {
67- 						return  nil , fmt .Errorf ("%w: discrete nullability did not match for arg #%d" ,
68- 							substraitgo .ErrInvalidType , i )
69- 					}
70- 				} else  {
71- 					return  nil , substraitgo .ErrNotImplemented 
72- 				}
97+ 	// validate varidic argument consistency 
98+ 	if  variadic  !=  nil  &&  len (actualTypes ) >  len (paramTypeList ) &&  variadic .ParameterConsistency  ==  ConsistentParams  {
99+ 		nparams  :=  len (paramTypeList )
100+ 		lastParam  :=  paramTypeList [nparams - 1 ]
101+ 		for  i , actual  :=  range  actualTypes [nparams :] {
102+ 			nonNull , err  :=  validateType (lastParam , actual , nparams + i , nullHandling )
103+ 			if  err  !=  nil  {
104+ 				return  nil , err 
73105			}
74- 		case  TypeArg :
75- 			return  nil , substraitgo .ErrNotImplemented 
106+ 			allNonNull  =  allNonNull  &&  nonNull 
76107		}
77108	}
78109
@@ -267,7 +298,7 @@ func (s *ScalarFunctionVariant) SessionDependent() bool           { return s.imp
267298func  (s  * ScalarFunctionVariant ) Nullability () NullabilityHandling  { return  s .impl .Nullability  }
268299func  (s  * ScalarFunctionVariant ) URI () string                       { return  s .uri  }
269300func  (s  * ScalarFunctionVariant ) ResolveType (argumentTypes  []types.Type ) (types.Type , error ) {
270- 	return  EvaluateTypeExpression (s .impl .Nullability , s .impl .Return , s .impl .Args , argumentTypes )
301+ 	return  EvaluateTypeExpression (s .impl .Nullability , s .impl .Return , s .impl .Args , s . impl . Variadic ,  argumentTypes )
271302}
272303func  (s  * ScalarFunctionVariant ) CompoundName () string  {
273304	return  s .name  +  ":"  +  s .impl .signatureKey ()
@@ -375,7 +406,7 @@ func (s *AggregateFunctionVariant) SessionDependent() bool           { return s.
375406func  (s  * AggregateFunctionVariant ) Nullability () NullabilityHandling  { return  s .impl .Nullability  }
376407func  (s  * AggregateFunctionVariant ) URI () string                       { return  s .uri  }
377408func  (s  * AggregateFunctionVariant ) ResolveType (argumentTypes  []types.Type ) (types.Type , error ) {
378- 	return  EvaluateTypeExpression (s .impl .Nullability , s .impl .Return , s .impl .Args , argumentTypes )
409+ 	return  EvaluateTypeExpression (s .impl .Nullability , s .impl .Return , s .impl .Args , s . impl . Variadic ,  argumentTypes )
379410}
380411func  (s  * AggregateFunctionVariant ) CompoundName () string  {
381412	return  s .name  +  ":"  +  s .impl .signatureKey ()
@@ -488,7 +519,7 @@ func (s *WindowFunctionVariant) SessionDependent() bool           { return s.imp
488519func  (s  * WindowFunctionVariant ) Nullability () NullabilityHandling  { return  s .impl .Nullability  }
489520func  (s  * WindowFunctionVariant ) URI () string                       { return  s .uri  }
490521func  (s  * WindowFunctionVariant ) ResolveType (argumentTypes  []types.Type ) (types.Type , error ) {
491- 	return  EvaluateTypeExpression (s .impl .Nullability , s .impl .Return , s .impl .Args , argumentTypes )
522+ 	return  EvaluateTypeExpression (s .impl .Nullability , s .impl .Return , s .impl .Args , s . impl . Variadic ,  argumentTypes )
492523}
493524func  (s  * WindowFunctionVariant ) CompoundName () string  {
494525	return  s .name  +  ":"  +  s .impl .signatureKey ()
0 commit comments