Skip to content

Commit 91573ae

Browse files
authored
Merge pull request #544 from mspruc/main
Multi conditional joins for sql-api
2 parents ee1a7da + 489da98 commit 91573ae

File tree

5 files changed

+270
-1
lines changed

5 files changed

+270
-1
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,170 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing, software
13+
* distributed under the License is distributed on an "AS IS" BASIS,
14+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
* See the License for the specific language governing permissions and
16+
* limitations under the License.
17+
*/
18+
19+
package org.apache.wayang.api.sql.calcite.converter;
20+
21+
import java.io.Serializable;
22+
import java.util.List;
23+
import java.util.stream.Collectors;
24+
import org.apache.calcite.rel.core.Join;
25+
import org.apache.calcite.rex.RexCall;
26+
import org.apache.calcite.rex.RexInputRef;
27+
import org.apache.calcite.rex.RexNode;
28+
import org.apache.wayang.api.sql.calcite.converter.functions.MultiConditionJoinFuncImpl;
29+
import org.apache.wayang.api.sql.calcite.converter.functions.MultiConditionJoinKeyExtractor;
30+
import org.apache.wayang.api.sql.calcite.rel.WayangJoin;
31+
import org.apache.wayang.basic.data.Record;
32+
import org.apache.wayang.basic.data.Tuple2;
33+
import org.apache.wayang.basic.operators.JoinOperator;
34+
import org.apache.wayang.basic.operators.MapOperator;
35+
import org.apache.wayang.core.function.TransformationDescriptor;
36+
import org.apache.wayang.core.function.FunctionDescriptor.SerializableFunction;
37+
import org.apache.wayang.core.plan.wayangplan.Operator;
38+
import org.apache.wayang.core.util.ReflectionUtils;
39+
40+
public class WayangMultiConditionJoinVisitor extends WayangRelNodeVisitor<WayangJoin> implements Serializable {
41+
42+
/**
43+
* Visitor that visits join statements that has multiple conditions like:
44+
* AND(=($1,$2),=($2,$3))
45+
* Note that this doesnt support nway joins or multijoins.
46+
*
47+
* @param wayangRelConverter
48+
*/
49+
WayangMultiConditionJoinVisitor(final WayangRelConverter wayangRelConverter) {
50+
super(wayangRelConverter);
51+
}
52+
53+
@Override
54+
Operator visit(WayangJoin wayangRelNode) {
55+
final Operator childOpLeft = wayangRelConverter.convert(wayangRelNode.getInput(0));
56+
final Operator childOpRight = wayangRelConverter.convert(wayangRelNode.getInput(1));
57+
final RexNode condition = ((Join) wayangRelNode).getCondition();
58+
final RexCall call = (RexCall) condition;
59+
60+
//
61+
final List<RexCall> subConditions = call.operands.stream()
62+
.map(RexCall.class::cast)
63+
.collect(Collectors.toList());
64+
65+
// calcite generates the RexInputRef indexes via looking at the union
66+
// field list of the left and right input of a join.
67+
// since the left input is always the first in this joined field list
68+
// we can eagerly get the fields in the left input
69+
final List<RexInputRef> leftTableInputRefs = subConditions.stream()
70+
.map(sub -> sub.getOperands().stream()
71+
.map(RexInputRef.class::cast)
72+
.min((left, right) -> Integer.compare(left.getIndex(), right.getIndex()))
73+
.get())
74+
.collect(Collectors.toList());
75+
76+
final Integer[] leftTableKeyIndexes = leftTableInputRefs.stream()
77+
.map(RexInputRef::getIndex)
78+
.toArray(Integer[]::new);
79+
80+
// for the right table input refs, the indexes are offset by the amount of rows
81+
// in the left
82+
// input to the join
83+
final List<RexInputRef> rightTableInputRefs = subConditions.stream()
84+
.map(sub -> sub.getOperands().stream()
85+
.map(RexInputRef.class::cast)
86+
.max((left, right) -> Integer.compare(left.getIndex(), right.getIndex()))
87+
.get())
88+
.collect(Collectors.toList());
89+
90+
final Integer[] rightTableKeyIndexes = rightTableInputRefs.stream()
91+
.map(RexInputRef::getIndex)
92+
.map(key -> key - wayangRelNode.getLeft().getRowType().getFieldCount()) // apply offset
93+
.toArray(Integer[]::new);
94+
95+
/*
96+
final List<RelDataTypeField> leftFields = Arrays.stream(leftTableKeyIndexes)
97+
.map(key -> wayangRelNode.getLeft().getRowType().getFieldList().get(key))
98+
.collect(Collectors.toList());
99+
100+
final List<RelDataTypeField> rightFields = Arrays.stream(rightTableKeyIndexes)
101+
.map(key -> wayangRelNode.getRight().getRowType().getFieldList().get(key))
102+
.collect(Collectors.toList());
103+
104+
final String joiningTableName = childOpLeft instanceof WayangTableScan ? childOpLeft.getName() : childOpRight.getName();
105+
*/
106+
107+
// if join is joining the LHS of a join condition "JOIN left ON left = right"
108+
// then we pick the first case, otherwise the 2nd "JOIN right ON left = right"
109+
final JoinOperator<Record, Record, Record> join = this.getJoinOperator(
110+
leftTableKeyIndexes,
111+
rightTableKeyIndexes,
112+
wayangRelNode,
113+
"",
114+
"",
115+
"",
116+
"");
117+
118+
childOpLeft.connectTo(0, join, 0);
119+
childOpRight.connectTo(0, join, 1);
120+
121+
// Join returns Tuple2 - map to a Record
122+
final SerializableFunction<Tuple2<Record, Record>, Record> mp = new MultiConditionJoinFuncImpl();
123+
124+
final MapOperator<Tuple2<Record, Record>, Record> mapOperator = new MapOperator<Tuple2<Record, Record>, Record>(
125+
mp,
126+
ReflectionUtils.specify(Tuple2.class),
127+
Record.class);
128+
129+
join.connectTo(0, mapOperator, 0);
130+
131+
return mapOperator;
132+
}
133+
134+
/**
135+
* This method handles the {@link JoinOperator} creation, used in conjunction
136+
* with:
137+
* {@link #determineKeyExtractionDirection(Integer, Integer, WayangJoin)}
138+
*
139+
* @param wayangRelNode
140+
* @param leftKeyIndex
141+
* @param rightKeyIndex
142+
* @return a {@link JoinOperator} with {@link KeyExtractors} set
143+
*/
144+
protected JoinOperator<Record, Record, Record> getJoinOperator(final Integer[] leftKeyIndexes,
145+
final Integer[] rightKeyIndexes,
146+
final WayangJoin wayangRelNode, final String leftTableName, final String leftFieldNames,
147+
final String rightTableName, final String rightFieldNames) {
148+
// TODO: needs withSqlImplementation() for sql support
149+
150+
if (wayangRelNode.getInputs().size() != 2)
151+
throw new UnsupportedOperationException("Join had an unexpected amount of inputs, found: "
152+
+ wayangRelNode.getInputs().size() + ", expected: 2");
153+
154+
final TransformationDescriptor<Record, Record> leftProjectionDescriptor = new TransformationDescriptor<Record, Record>(
155+
new MultiConditionJoinKeyExtractor(leftKeyIndexes),
156+
Record.class, Record.class);
157+
// .withSqlImplementation(""," ")
158+
159+
final TransformationDescriptor<Record, Record> rightProjectionDescriptor = new TransformationDescriptor<Record, Record>(
160+
new MultiConditionJoinKeyExtractor(rightKeyIndexes),
161+
Record.class, Record.class);
162+
// .withSqlImplementation(""," ")
163+
164+
final JoinOperator<Record, Record, Record> join = new JoinOperator<>(
165+
leftProjectionDescriptor,
166+
rightProjectionDescriptor);
167+
168+
return join;
169+
}
170+
}

wayang-api/wayang-api-sql/src/main/java/org/apache/wayang/api/sql/calcite/converter/WayangRelConverter.java

+3
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
package org.apache.wayang.api.sql.calcite.converter;
2121

2222
import org.apache.calcite.rel.RelNode;
23+
import org.apache.calcite.sql.SqlKind;
2324
import org.apache.wayang.api.sql.calcite.rel.*;
2425
import org.apache.wayang.core.api.Configuration;
2526
import org.apache.wayang.core.plan.wayangplan.Operator;
@@ -54,6 +55,8 @@ public Operator convert(final RelNode node) {
5455
return new WayangProjectVisitor(this).visit((WayangProject) node);
5556
} else if (node instanceof WayangFilter) {
5657
return new WayangFilterVisitor(this).visit((WayangFilter) node);
58+
} else if (node instanceof WayangJoin && ((WayangJoin) node).getCondition().isA(SqlKind.AND)) {
59+
return new WayangMultiConditionJoinVisitor(this).visit((WayangJoin) node);
5760
} else if (node instanceof WayangJoin && WayangJoin.class.cast(node).getCondition().isAlwaysTrue()) {
5861
return new WayangCrossJoinVisitor(this).visit((WayangJoin) node);
5962
} else if (node instanceof WayangJoin) {
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing, software
13+
* distributed under the License is distributed on an "AS IS" BASIS,
14+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
* See the License for the specific language governing permissions and
16+
* limitations under the License.
17+
*/
18+
19+
package org.apache.wayang.api.sql.calcite.converter.functions;
20+
21+
import org.apache.wayang.basic.data.Record;
22+
import org.apache.wayang.basic.data.Tuple2;
23+
import org.apache.wayang.core.function.FunctionDescriptor;
24+
25+
/**
26+
* Flattens Tuple2<Record, Record> to Record
27+
*/
28+
public class MultiConditionJoinFuncImpl implements FunctionDescriptor.SerializableFunction<Tuple2<Record, Record>, Record> {
29+
30+
public MultiConditionJoinFuncImpl() {
31+
32+
}
33+
34+
@Override
35+
public Record apply(final Tuple2<Record, Record> tuple2) {
36+
final int length1 = ((Tuple2<Record, Record>) tuple2).getField0().size();
37+
final int length2 = ((Tuple2<Record, Record>) tuple2).getField1().size();
38+
39+
final int totalLength = length1 + length2;
40+
41+
final Object[] fields = new Object[totalLength];
42+
43+
for (int i = 0; i < length1; i++) {
44+
fields[i] = ((Tuple2<Record, Record>) tuple2).getField0().getField(i);
45+
}
46+
for (int j = length1; j < totalLength; j++) {
47+
fields[j] = ((Tuple2<Record, Record>) tuple2).getField1().getField(j - length1);
48+
}
49+
50+
return new Record(fields);
51+
}
52+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing, software
13+
* distributed under the License is distributed on an "AS IS" BASIS,
14+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
* See the License for the specific language governing permissions and
16+
* limitations under the License.
17+
*/
18+
19+
package org.apache.wayang.api.sql.calcite.converter.functions;
20+
21+
import java.util.Arrays;
22+
import java.util.function.Function;
23+
24+
import org.apache.wayang.basic.data.Record;
25+
import org.apache.wayang.core.function.FunctionDescriptor;
26+
27+
public class MultiConditionJoinKeyExtractor implements FunctionDescriptor.SerializableFunction<Record, Record> {
28+
private final Integer[] indexes;
29+
30+
/**
31+
* Extracts a key for a {@link WayangMultiConditionJoinVisitor}.
32+
* is a subtype of {@link Function}, {@link Serializable} (as required by engines which use serialisation i.e. flink/spark)
33+
* Takes an input {@link Record} & {@link Integer} key and maps it to a generic field object T.
34+
* Performs an unchecked cast when applied.
35+
* @param index key
36+
*/
37+
public MultiConditionJoinKeyExtractor(final Integer... indexes) {
38+
this.indexes = indexes;
39+
}
40+
41+
public Record apply(final Record record) {
42+
return new Record(Arrays.stream(indexes).map(record::getField).toArray());
43+
}
44+
}

wayang-api/wayang-api-sql/src/test/java/org/apache/wayang/api/sql/SqlToWayangRelTest.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ public Tuple2<Collection<Record>, WayangPlan> buildCollectorAndWayangPlan(final
111111
return new Tuple2<>(collector, wayangPlan);
112112
}
113113

114-
// @Test
114+
@Test
115115
public void javaMultiConditionJoin() throws Exception {
116116
final SqlContext sqlContext = this.createSqlContext("/data/largeLeftTableIndex.csv");
117117
// SELECT acc.location, count(*) FROM postgres.site

0 commit comments

Comments
 (0)