|
18 | 18 |
|
19 | 19 | package org.apache.wayang.api.sql.calcite.converter; |
20 | 20 |
|
| 21 | +import java.util.List; |
| 22 | +import java.util.stream.Collectors; |
| 23 | + |
21 | 24 | import org.apache.calcite.rel.core.Join; |
22 | 25 | import org.apache.calcite.rex.RexCall; |
23 | 26 | import org.apache.calcite.rex.RexInputRef; |
24 | 27 | import org.apache.calcite.rex.RexNode; |
25 | | -import org.apache.calcite.rex.RexVisitorImpl; |
26 | 28 | import org.apache.calcite.sql.SqlKind; |
| 29 | + |
| 30 | +import org.apache.wayang.api.sql.calcite.converter.functions.JoinFlattenResult; |
| 31 | +import org.apache.wayang.api.sql.calcite.converter.functions.JoinKeyExtractor; |
27 | 32 | import org.apache.wayang.api.sql.calcite.rel.WayangJoin; |
| 33 | + |
28 | 34 | import org.apache.wayang.basic.data.Record; |
29 | 35 | import org.apache.wayang.basic.data.Tuple2; |
30 | 36 | import org.apache.wayang.basic.operators.JoinOperator; |
31 | 37 | import org.apache.wayang.basic.operators.MapOperator; |
32 | | -import org.apache.wayang.core.function.FunctionDescriptor; |
33 | 38 | import org.apache.wayang.core.function.TransformationDescriptor; |
34 | 39 | import org.apache.wayang.core.plan.wayangplan.Operator; |
| 40 | +import org.apache.wayang.core.util.ReflectionUtils; |
35 | 41 |
|
36 | 42 | public class WayangJoinVisitor extends WayangRelNodeVisitor<WayangJoin> { |
37 | 43 |
|
38 | | - WayangJoinVisitor(WayangRelConverter wayangRelConverter) { |
| 44 | + WayangJoinVisitor(final WayangRelConverter wayangRelConverter) { |
39 | 45 | super(wayangRelConverter); |
40 | 46 | } |
41 | 47 |
|
42 | 48 | @Override |
43 | | - Operator visit(WayangJoin wayangRelNode) { |
44 | | - Operator childOpLeft = wayangRelConverter.convert(wayangRelNode.getInput(0)); |
45 | | - Operator childOpRight = wayangRelConverter.convert(wayangRelNode.getInput(1)); |
| 49 | + Operator visit(final WayangJoin wayangRelNode) { |
| 50 | + final Operator childOpLeft = wayangRelConverter.convert(wayangRelNode.getInput(0)); |
| 51 | + final Operator childOpRight = wayangRelConverter.convert(wayangRelNode.getInput(1)); |
| 52 | + |
| 53 | + final RexNode condition = ((Join) wayangRelNode).getCondition(); |
| 54 | + final RexCall call = (RexCall) condition; |
46 | 55 |
|
47 | | - RexNode condition = ((Join) wayangRelNode).getCondition(); |
| 56 | + final List<Integer> keys = call.getOperands().stream() |
| 57 | + .map(RexInputRef.class::cast) |
| 58 | + .map(RexInputRef::getIndex) |
| 59 | + .collect(Collectors.toList()); |
48 | 60 |
|
| 61 | + assert (keys.size() == 2) : "Amount of keys found in join was not 2, got: " + keys.size(); |
| 62 | + |
49 | 63 | if (!condition.isA(SqlKind.EQUALS)) { |
50 | 64 | throw new UnsupportedOperationException("Only equality joins supported"); |
51 | 65 | } |
52 | 66 |
|
53 | | - //offset of the index in the right child |
54 | | - int offset = wayangRelNode.getInput(0).getRowType().getFieldCount(); |
| 67 | + // offset of the index in the right child |
| 68 | + final int offset = wayangRelNode.getInput(0).getRowType().getFieldCount(); |
55 | 69 |
|
56 | | - int leftKeyIndex = condition.accept(new KeyIndex(false, Child.LEFT)); |
57 | | - int rightKeyIndex = condition.accept(new KeyIndex(false, Child.RIGHT)) - offset; |
| 70 | + final int leftKeyIndex = keys.get(0); |
| 71 | + final int rightKeyIndex = keys.get(1) - offset; |
58 | 72 |
|
59 | | - JoinOperator<Record, Record, Object> join = new JoinOperator<>( |
60 | | - new TransformationDescriptor<>(new KeyExtractor(leftKeyIndex), Record.class, Object.class), |
61 | | - new TransformationDescriptor<>(new KeyExtractor(rightKeyIndex), Record.class, Object.class) |
62 | | - ); |
| 73 | + final JoinOperator<Record, Record, Object> join = new JoinOperator<>( |
| 74 | + new TransformationDescriptor<>(new JoinKeyExtractor(leftKeyIndex), Record.class, Object.class), |
| 75 | + new TransformationDescriptor<>(new JoinKeyExtractor(rightKeyIndex), Record.class, Object.class)); |
63 | 76 |
|
64 | | - //call connectTo on both operators (left and right) |
| 77 | + // call connectTo on both operators (left and right) |
65 | 78 | childOpLeft.connectTo(0, join, 0); |
66 | 79 | childOpRight.connectTo(0, join, 1); |
67 | 80 |
|
68 | 81 | // Join returns Tuple2 - map to a Record |
69 | | - MapOperator<Tuple2, Record> mapOperator = new MapOperator( |
70 | | - new MapFunctionImpl(), |
71 | | - Tuple2.class, |
72 | | - Record.class |
73 | | - ); |
| 82 | + final MapOperator<Tuple2<Record, Record>, Record> mapOperator = new MapOperator<Tuple2<Record, Record>, Record>( |
| 83 | + new JoinFlattenResult(), |
| 84 | + ReflectionUtils.specify(Tuple2.class), |
| 85 | + Record.class); |
| 86 | + |
74 | 87 | join.connectTo(0, mapOperator, 0); |
75 | 88 |
|
76 | 89 | return mapOperator; |
77 | 90 | } |
78 | | - |
79 | | - /** |
80 | | - * Extracts key index from the call |
81 | | - */ |
82 | | - private class KeyIndex extends RexVisitorImpl<Integer> { |
83 | | - final Child child; |
84 | | - |
85 | | - protected KeyIndex(boolean deep, Child child) { |
86 | | - super(deep); |
87 | | - this.child = child; |
88 | | - } |
89 | | - |
90 | | - @Override |
91 | | - public Integer visitCall(RexCall call) { |
92 | | - RexNode operand = call.getOperands().get(child.ordinal()); |
93 | | - if (!(operand instanceof RexInputRef)) { |
94 | | - throw new UnsupportedOperationException("Unsupported operation"); |
95 | | - } |
96 | | - RexInputRef rexInputRef = (RexInputRef) operand; |
97 | | - return rexInputRef.getIndex(); |
98 | | - } |
99 | | - } |
100 | | - |
101 | | - /** |
102 | | - * Extracts the key |
103 | | - */ |
104 | | - private class KeyExtractor implements FunctionDescriptor.SerializableFunction<Record, Object> { |
105 | | - private final int index; |
106 | | - |
107 | | - public KeyExtractor(int index) { |
108 | | - this.index = index; |
109 | | - } |
110 | | - |
111 | | - public Object apply(final Record record) { |
112 | | - return record.getField(index); |
113 | | - } |
114 | | - } |
115 | | - |
116 | | - /** |
117 | | - * Flattens Tuple2<Record, Record> to Record |
118 | | - */ |
119 | | - private class MapFunctionImpl implements FunctionDescriptor.SerializableFunction<Tuple2<Record, Record>, Record> { |
120 | | - public MapFunctionImpl() { |
121 | | - super(); |
122 | | - } |
123 | | - |
124 | | - @Override |
125 | | - public Record apply(final Tuple2<Record, Record> tuple2) { |
126 | | - int length1 = tuple2.getField0().size(); |
127 | | - int length2 = tuple2.getField1().size(); |
128 | | - |
129 | | - int totalLength = length1 + length2; |
130 | | - |
131 | | - Object[] fields = new Object[totalLength]; |
132 | | - |
133 | | - for (int i = 0; i < length1; i++) { |
134 | | - fields[i] = tuple2.getField0().getField(i); |
135 | | - } |
136 | | - for (int j = length1; j < totalLength; j++) { |
137 | | - fields[j] = tuple2.getField1().getField(j - length1); |
138 | | - } |
139 | | - return new Record(fields); |
140 | | - |
141 | | - } |
142 | | - } |
143 | | - |
144 | | - // Helpers |
145 | | - private enum Child { |
146 | | - LEFT, RIGHT |
147 | | - } |
148 | 91 | } |
0 commit comments