2121import java .util .ArrayList ;
2222import java .util .List ;
2323import java .util .function .BinaryOperator ;
24+ import java .util .stream .Collectors ;
2425
2526import org .apache .calcite .rex .RexCall ;
2627import org .apache .calcite .rex .RexInputRef ;
3031import org .apache .calcite .sql .fun .SqlStdOperatorTable ;
3132
3233import org .apache .wayang .core .function .FunctionDescriptor ;
34+ import org .apache .wayang .core .function .FunctionDescriptor .SerializableFunction ;
3335import org .apache .wayang .basic .data .Record ;
3436
35-
3637public class ProjectMapFuncImpl implements
3738 FunctionDescriptor .SerializableFunction <Record , Record > {
38- private final List <RexNode > projects ;
39+ final List <SerializableFunction < Record , Object >> projections ;
3940
4041 public ProjectMapFuncImpl (final List <RexNode > projects ) {
41- this .projects = projects ;
42- }
43-
44- @ Override
45- public Record apply (final Record record ) {
46-
47- final List <Object > projectedRecord = new ArrayList <>();
48- for (int i = 0 ; i < projects .size (); i ++) {
49- final RexNode exp = projects .get (i );
42+ this .projections = projects .stream ().map (exp -> {
5043 if (exp instanceof RexInputRef ) {
51- projectedRecord .add (record .getField (((RexInputRef ) exp ).getIndex ()));
44+ final int key = ((RexInputRef ) exp ).getIndex ();
45+ return (SerializableFunction <Record , Object >) record -> record .getField (key );
5246 } else if (exp instanceof RexLiteral ) {
53- final RexLiteral literal = (RexLiteral ) exp ;
54- projectedRecord . add ( literal . getValue ()) ;
47+ final Object literalValue = (( RexLiteral ) exp ). getValue () ;
48+ return ( SerializableFunction < Record , Object >) record -> literalValue ;
5549 } else if (exp instanceof RexCall ) {
56- projectedRecord .add (evaluateRexCall (record , (RexCall ) exp ));
50+ return (SerializableFunction <Record , Object >) record -> evaluateRexCall (record , (RexCall ) exp );
51+ } else {
52+ throw new UnsupportedOperationException ("Could not resolve record for exp: " + exp );
5753 }
58- }
59- return new Record (projectedRecord .toArray (new Object [0 ]));
54+ }).collect (Collectors .toList ());
6055 }
6156
62- public static Object evaluateRexCall (final Record record , final RexCall rexCall ) {
57+ public Object evaluateRexCall (final Record record , final RexCall rexCall ) {
6358 if (rexCall == null ) {
6459 return null ;
6560 }
@@ -70,7 +65,7 @@ public static Object evaluateRexCall(final Record record, final RexCall rexCall)
7065
7166 if (operator == SqlStdOperatorTable .PLUS ) {
7267 // Handle addition
73- return evaluateNaryOperation (record , operands , Double :: sum );
68+ return evaluateNaryOperation (record , operands , ( a , b ) -> a + b );
7469 } else if (operator == SqlStdOperatorTable .MINUS ) {
7570 // Handle subtraction
7671 return evaluateNaryOperation (record , operands , (a , b ) -> a - b );
@@ -85,7 +80,7 @@ public static Object evaluateRexCall(final Record record, final RexCall rexCall)
8580 }
8681 }
8782
88- public static Object evaluateNaryOperation (final Record record , final List <RexNode > operands ,
83+ public Object evaluateNaryOperation (final Record record , final List <RexNode > operands ,
8984 final BinaryOperator <Double > operation ) {
9085 if (operands .isEmpty ()) {
9186 return null ;
@@ -110,7 +105,7 @@ public static Object evaluateNaryOperation(final Record record, final List<RexNo
110105 return result ;
111106 }
112107
113- public static Object evaluateRexNode (final Record record , final RexNode rexNode ) {
108+ public Object evaluateRexNode (final Record record , final RexNode rexNode ) {
114109 if (rexNode instanceof RexCall ) {
115110 // Recursively evaluate a RexCall
116111 return evaluateRexCall (record , (RexCall ) rexNode );
@@ -124,4 +119,9 @@ public static Object evaluateRexNode(final Record record, final RexNode rexNode)
124119 return null ; // Unsupported or unknown expression
125120 }
126121 }
122+
123+ @ Override
124+ public Record apply (final Record record ) {
125+ return new Record (projections .stream ().map (func -> func .apply (record )).toArray ());
126+ }
127127}
0 commit comments