|
18 | 18 |
|
19 | 19 | package org.apache.wayang.api.sql.calcite.converter.functions; |
20 | 20 |
|
21 | | -import java.util.ArrayList; |
| 21 | +import java.io.Serializable; |
22 | 22 | import java.util.List; |
23 | | -import java.util.function.BinaryOperator; |
24 | 23 | import java.util.stream.Collectors; |
25 | 24 |
|
26 | 25 | import org.apache.calcite.rex.RexCall; |
27 | 26 | import org.apache.calcite.rex.RexInputRef; |
28 | 27 | import org.apache.calcite.rex.RexLiteral; |
29 | 28 | import org.apache.calcite.rex.RexNode; |
30 | | -import org.apache.calcite.sql.SqlOperator; |
31 | | -import org.apache.calcite.sql.fun.SqlStdOperatorTable; |
| 29 | +import org.apache.calcite.sql.SqlKind; |
32 | 30 |
|
33 | 31 | import org.apache.wayang.core.function.FunctionDescriptor; |
34 | | -import org.apache.wayang.core.function.FunctionDescriptor.SerializableFunction; |
| 32 | +import org.apache.wayang.core.function.FunctionDescriptor.SerializableBiFunction; |
35 | 33 | import org.apache.wayang.basic.data.Record; |
36 | 34 |
|
37 | 35 | public class ProjectMapFuncImpl implements |
38 | 36 | FunctionDescriptor.SerializableFunction<Record, Record> { |
39 | | - final List<SerializableFunction<Record, Object>> projections; |
40 | 37 |
|
41 | | - public ProjectMapFuncImpl(final List<RexNode> projects) { |
42 | | - this.projections = projects.stream().map(exp -> { |
43 | | - if (exp instanceof RexInputRef) { |
44 | | - final int key = ((RexInputRef) exp).getIndex(); |
45 | | - return (SerializableFunction<Record, Object>) record -> record.getField(key); |
46 | | - } else if (exp instanceof RexLiteral) { |
47 | | - final Object literalValue = ((RexLiteral) exp).getValue(); |
48 | | - return (SerializableFunction<Record, Object>) record -> literalValue; |
49 | | - } else if (exp instanceof RexCall) { |
50 | | - return (SerializableFunction<Record, Object>) record -> evaluateRexCall(record, (RexCall) exp); |
| 38 | + /** |
| 39 | + * Serializable representation of {@link RexNode} |
| 40 | + */ |
| 41 | + abstract class Node implements Serializable { |
| 42 | + public Node transform(final RexNode node) { |
| 43 | + if (node instanceof RexCall) { |
| 44 | + return new Call((RexCall) node); |
| 45 | + } else if (node instanceof RexInputRef) { |
| 46 | + return new InputRef((RexInputRef) node); |
| 47 | + } else if (node instanceof RexLiteral) { |
| 48 | + return new Literal((RexLiteral) node); |
51 | 49 | } else { |
52 | | - throw new UnsupportedOperationException("Could not resolve record for exp: " + exp); |
| 50 | + throw new UnsupportedOperationException("RexNode not supported in projection function: " + node); |
53 | 51 | } |
54 | | - }).collect(Collectors.toList()); |
| 52 | + } |
| 53 | + |
| 54 | + abstract Object evaluate(Record record); |
55 | 55 | } |
56 | 56 |
|
57 | | - public Object evaluateRexCall(final Record record, final RexCall rexCall) { |
58 | | - if (rexCall == null) { |
59 | | - return null; |
| 57 | + /** |
| 58 | + * Serializable representation of {@link RexCall} |
| 59 | + */ |
| 60 | + class Call extends Node { |
| 61 | + final SerializableBiFunction<Number, Number, Number> operation; |
| 62 | + final List<Node> children; |
| 63 | + |
| 64 | + public Call(final RexCall call) { |
| 65 | + this.operation = deriveOperation(call.getOperator().getKind()); |
| 66 | + this.children = call.getOperands().stream() |
| 67 | + .map(op -> this.transform(op)) |
| 68 | + .collect(Collectors.toList()); |
60 | 69 | } |
61 | 70 |
|
62 | | - // Get the operator and operands |
63 | | - final SqlOperator operator = rexCall.getOperator(); |
64 | | - final List<RexNode> operands = rexCall.getOperands(); |
65 | | - |
66 | | - if (operator == SqlStdOperatorTable.PLUS) { |
67 | | - // Handle addition |
68 | | - return evaluateNaryOperation(record, operands, (a, b) -> a + b); |
69 | | - } else if (operator == SqlStdOperatorTable.MINUS) { |
70 | | - // Handle subtraction |
71 | | - return evaluateNaryOperation(record, operands, (a, b) -> a - b); |
72 | | - } else if (operator == SqlStdOperatorTable.MULTIPLY) { |
73 | | - // Handle multiplication |
74 | | - return evaluateNaryOperation(record, operands, (a, b) -> a * b); |
75 | | - } else if (operator == SqlStdOperatorTable.DIVIDE) { |
76 | | - // Handle division |
77 | | - return evaluateNaryOperation(record, operands, (a, b) -> a / b); |
78 | | - } else { |
79 | | - return null; |
| 71 | + public Object evaluate(final Record record) { |
| 72 | + assert (children.size() == 2) : "Project func call should only have two children"; |
| 73 | + return operation.apply((Number) children.get(0).evaluate(record), |
| 74 | + (Number) children.get(1).evaluate(record)); |
80 | 75 | } |
81 | | - } |
82 | 76 |
|
83 | | - public Object evaluateNaryOperation(final Record record, final List<RexNode> operands, |
84 | | - final BinaryOperator<Double> operation) { |
85 | | - if (operands.isEmpty()) { |
86 | | - return null; |
| 77 | + /** |
| 78 | + * Derives the java operator for a given {@link SqlKind}, and turns it into a serializable function |
| 79 | + * @param kind {@link SqlKind} from {@link RexCall} SqlOperator |
| 80 | + * @return a serializable function of +, -, * or / |
| 81 | + * @throws UnsupportedOperationException on unrecognized {@link SqlKind} |
| 82 | + */ |
| 83 | + static SerializableBiFunction<Number, Number, Number> deriveOperation(final SqlKind kind) { |
| 84 | + return (a, b) -> { |
| 85 | + final double l = a.doubleValue(); |
| 86 | + final double r = b.doubleValue(); |
| 87 | + switch (kind) { |
| 88 | + case PLUS: |
| 89 | + return l + r; |
| 90 | + case MINUS: |
| 91 | + return l - r; |
| 92 | + case TIMES: |
| 93 | + return l * r; |
| 94 | + case DIVIDE: |
| 95 | + return l / r; |
| 96 | + default: |
| 97 | + throw new UnsupportedOperationException( |
| 98 | + "Operation not supported in projection function RexCall: " + kind); |
| 99 | + } |
| 100 | + }; |
87 | 101 | } |
| 102 | + } |
88 | 103 |
|
89 | | - final List<Double> values = new ArrayList<>(); |
| 104 | + /** |
| 105 | + * Serializable representation of {@link RexLiteral} |
| 106 | + */ |
| 107 | + class Literal extends Node { |
| 108 | + final Comparable<?> value; |
90 | 109 |
|
91 | | - for (int i = 0; i < operands.size(); i++) { |
92 | | - final Number val = (Number) evaluateRexNode(record, operands.get(i)); |
93 | | - if (val == null) { |
94 | | - return null; |
95 | | - } |
96 | | - values.add(val.doubleValue()); |
| 110 | + Literal(final RexLiteral literal) { |
| 111 | + this.value = literal.getValueAs(Double.class); |
97 | 112 | } |
98 | 113 |
|
99 | | - Object result = values.get(0); |
100 | | - // Perform the operation with the remaining operands |
101 | | - for (int i = 1; i < operands.size(); i++) { |
102 | | - result = operation.apply((double) result, values.get(i)); |
| 114 | + @Override |
| 115 | + public Object evaluate(final Record record) { |
| 116 | + return value; |
103 | 117 | } |
104 | | - |
105 | | - return result; |
106 | 118 | } |
107 | 119 |
|
108 | | - public Object evaluateRexNode(final Record record, final RexNode rexNode) { |
109 | | - if (rexNode instanceof RexCall) { |
110 | | - // Recursively evaluate a RexCall |
111 | | - return evaluateRexCall(record, (RexCall) rexNode); |
112 | | - } else if (rexNode instanceof RexLiteral) { |
113 | | - // Handle literals (e.g., numbers) |
114 | | - final RexLiteral literal = (RexLiteral) rexNode; |
115 | | - return literal.getValue(); |
116 | | - } else if (rexNode instanceof RexInputRef) { |
117 | | - return record.getField(((RexInputRef) rexNode).getIndex()); |
118 | | - } else { |
119 | | - return null; // Unsupported or unknown expression |
| 120 | + /** |
| 121 | + * Serializable representation of {@link InputRef} |
| 122 | + */ |
| 123 | + class InputRef extends Node { |
| 124 | + final int key; |
| 125 | + |
| 126 | + InputRef(final RexInputRef inputRef) { |
| 127 | + this.key = inputRef.getIndex(); |
| 128 | + } |
| 129 | + |
| 130 | + @Override |
| 131 | + public Object evaluate(final Record record) { |
| 132 | + return record.getField(key); |
120 | 133 | } |
121 | 134 | } |
122 | 135 |
|
| 136 | + /** |
| 137 | + * AST of the {@link RexCall} arithmetic, composed into serializable nodes; {@link Call}, {@link InputRef}, {@link Literal} |
| 138 | + */ |
| 139 | + final List<Node> projectionSyntaxTrees; |
| 140 | + |
| 141 | + public ProjectMapFuncImpl(final List<RexNode> projects) { |
| 142 | + this.projectionSyntaxTrees = projects.stream() |
| 143 | + .map(projection -> { |
| 144 | + if (projection instanceof RexCall) { |
| 145 | + return new Call((RexCall) projection); |
| 146 | + } else if (projection instanceof RexLiteral) { |
| 147 | + return new Literal((RexLiteral) projection); |
| 148 | + } else if (projection instanceof RexInputRef) { |
| 149 | + return new InputRef((RexInputRef) projection); |
| 150 | + } else { |
| 151 | + throw new UnsupportedOperationException("RexNode not supported in projection: " + projection); |
| 152 | + } |
| 153 | + }) |
| 154 | + .collect(Collectors.toList()); |
| 155 | + } |
| 156 | + |
123 | 157 | @Override |
124 | 158 | public Record apply(final Record record) { |
125 | | - return new Record(projections.stream().map(func -> func.apply(record)).toArray()); |
| 159 | + return new Record(projectionSyntaxTrees.stream().map(call -> call.evaluate(record)).toArray()); |
126 | 160 | } |
127 | 161 | } |
0 commit comments