Skip to content

Commit add837a

Browse files
shlok7296 authored and
julianhyde committed
[CALCITE-4233] In Elasticsearch adapter, support generating disjunction max (dis_max) queries (shlok7296)
close #2218
1 parent e7c579f commit add837a

File tree

3 files changed

+94
-1
lines changed

3 files changed

+94
-1
lines changed

elasticsearch/src/main/java/org/apache/calcite/adapter/elasticsearch/ElasticsearchFilter.java

+16-1
Original file line numberDiff line numberDiff line change
@@ -23,14 +23,17 @@
2323
import org.apache.calcite.rel.RelNode;
2424
import org.apache.calcite.rel.core.Filter;
2525
import org.apache.calcite.rel.metadata.RelMetadataQuery;
26+
import org.apache.calcite.rex.RexCall;
2627
import org.apache.calcite.rex.RexNode;
28+
import org.apache.calcite.sql.SqlKind;
2729

2830
import com.fasterxml.jackson.core.JsonGenerator;
2931
import com.fasterxml.jackson.databind.ObjectMapper;
3032

3133
import java.io.IOException;
3234
import java.io.StringWriter;
3335
import java.io.UncheckedIOException;
36+
import java.util.Iterator;
3437
import java.util.Objects;
3538

3639
/**
@@ -82,7 +85,19 @@ String translateMatch(RexNode condition) throws IOException,
8285

8386
StringWriter writer = new StringWriter();
8487
JsonGenerator generator = mapper.getFactory().createGenerator(writer);
85-
QueryBuilders.constantScoreQuery(PredicateAnalyzer.analyze(condition)).writeJson(generator);
88+
boolean disMax = condition.isA(SqlKind.OR);
89+
Iterator<RexNode> operands = ((RexCall) condition).getOperands().iterator();
90+
while (operands.hasNext() && !disMax) {
91+
if (operands.next().isA(SqlKind.OR)) {
92+
disMax = true;
93+
break;
94+
}
95+
}
96+
if (disMax) {
97+
QueryBuilders.disMaxQueryBuilder(PredicateAnalyzer.analyze(condition)).writeJson(generator);
98+
} else {
99+
QueryBuilders.constantScoreQuery(PredicateAnalyzer.analyze(condition)).writeJson(generator);
100+
}
86101
generator.flush();
87102
generator.close();
88103
return "{\"query\" : " + writer.toString() + "}";

elasticsearch/src/main/java/org/apache/calcite/adapter/elasticsearch/QueryBuilders.java

+37
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,16 @@ static ConstantScoreQueryBuilder constantScoreQuery(QueryBuilder queryBuilder) {
186186
return new ConstantScoreQueryBuilder(queryBuilder);
187187
}
188188

189+
/**
190+
* Creates a disjunction-max ({@code dis_max}) query whose {@code queries} array contains
191+
* only the given query.
192+
*
193+
* @param queryBuilder The query to wrap in a dis_max query; must not be null
194+
*/
195+
static DisMaxQueryBuilder disMaxQueryBuilder(QueryBuilder queryBuilder) {
196+
return new DisMaxQueryBuilder(queryBuilder);
197+
}
198+
189199
/**
190200
* A filter to filter only documents where a field exists in them.
191201
*
@@ -540,6 +550,33 @@ private ConstantScoreQueryBuilder(final QueryBuilder builder) {
540550
}
541551
}
542552

553+
/**
554+
* A query that wraps a single query in a disjunction-max ({@code dis_max}) query; it is
555+
* written as a "dis_max" object whose "queries" array holds only the wrapped query.
556+
*/
557+
static class DisMaxQueryBuilder extends QueryBuilder {
558+
559+
private final QueryBuilder builder;
560+
561+
private DisMaxQueryBuilder(final QueryBuilder builder) {
562+
// Rejects a null wrapped query up front with a NullPointerException.
this.builder = Objects.requireNonNull(builder, "builder");
563+
}
564+
565+
@Override void writeJson(final JsonGenerator generator) throws IOException {
566+
generator.writeStartObject();
567+
generator.writeFieldName("dis_max");
568+
generator.writeStartObject();
569+
generator.writeFieldName("queries");
570+
generator.writeStartArray();
571+
// The wrapped query is the one and only element of the "queries" array.
builder.writeJson(generator);
572+
generator.writeEndArray();
573+
generator.writeEndObject();
574+
generator.writeEndObject();
575+
}
576+
}
577+
578+
579+
543580
/**
544581
* A query that matches on all documents.
545582
* <pre>

elasticsearch/src/test/java/org/apache/calcite/adapter/elasticsearch/ElasticSearchAdapterTest.java

+41
Original file line numberDiff line numberDiff line change
@@ -295,6 +295,19 @@ private static Consumer<ResultSet> sortedResultSetChecker(String column,
295295
.query("select _MAP['state'] from elastic.zips order by _MAP['city']")
296296
.returnsCount(ZIPS_SIZE);
297297

298+
CalciteAssert.that()
299+
.with(newConnectionFactory())
300+
.query("select * from elastic.zips where _MAP['state'] = 'NY' or "
301+
+ "_MAP['city'] = 'BROOKLYN'"
302+
+ " order by _MAP['city']")
303+
.queryContains(
304+
ElasticsearchChecker.elasticsearchChecker(
305+
"query:{'dis_max':{'queries':[{'bool':{'should':"
306+
+ "[{'term':{'state':'NY'}},{'term':"
307+
+ "{'city':'BROOKLYN'}}]}}]}},'sort':[{'city':'asc'}]",
308+
String.format(Locale.ROOT, "size:%s",
309+
ElasticsearchTransport.DEFAULT_FETCH_SIZE)));
310+
298311
CalciteAssert.that()
299312
.with(newConnectionFactory())
300313
.query("select _MAP['city'] from elastic.zips where _MAP['state'] = 'NY' "
@@ -421,6 +434,34 @@ private static Consumer<ResultSet> sortedResultSetChecker(String column,
421434
.explainContains(explain);
422435
}
423436

437+
/**
 * Checks that a filter whose top-level condition is an OR is translated into a
 * {@code dis_max} query wrapping a {@code bool}/{@code should} clause, and that the
 * plan keeps the expected Elasticsearch sort, project, filter and table scan nodes.
 */
@Test public void testDismaxQuery() {
438+
final String sql = "select * from zips\n"
439+
+ "where state = 'CA' or pop >= 94000\n"
440+
+ "order by state, pop";
441+
final String explain = "PLAN=ElasticsearchToEnumerableConverter\n"
442+
+ " ElasticsearchSort(sort0=[$4], sort1=[$3], dir0=[ASC], dir1=[ASC])\n"
443+
+ " ElasticsearchProject(city=[CAST(ITEM($0, 'city')):VARCHAR(20)], longitude=[CAST(ITEM(ITEM($0, 'loc'), 0)):FLOAT], latitude=[CAST(ITEM(ITEM($0, 'loc'), 1)):FLOAT], pop=[CAST(ITEM($0, 'pop')):INTEGER], state=[CAST(ITEM($0, 'state')):VARCHAR(2)], id=[CAST(ITEM($0, 'id')):VARCHAR(5)])\n"
444+
+ " ElasticsearchFilter(condition=[OR(=(CAST(ITEM($0, 'state')):VARCHAR(2), 'CA'), >=(CAST(ITEM($0, 'pop')):INTEGER, 94000))])\n"
445+
+ " ElasticsearchTableScan(table=[[elastic, zips]])\n\n";
446+
calciteAssert()
447+
.query(sql)
448+
.queryContains(
449+
ElasticsearchChecker.elasticsearchChecker("'query' : "
450+
+ "{'dis_max':{'queries':[{bool:"
451+
+ "{should:[{term:{state:'CA'}},"
452+
+ "{range:{pop:{gte:94000}}}]}}]}}",
453+
"'script_fields': {longitude:{script:'params._source.loc[0]'}, "
454+
+ "latitude:{script:'params._source.loc[1]'}, "
455+
+ "city:{script: 'params._source.city'}, "
456+
+ "pop:{script: 'params._source.pop'}, "
457+
+ "state:{script: 'params._source.state'}, "
458+
+ "id:{script: 'params._source.id'}}",
459+
"sort: [ {state: 'asc'}, {pop: 'asc'}]",
460+
String.format(Locale.ROOT, "size:%s",
461+
ElasticsearchTransport.DEFAULT_FETCH_SIZE)))
462+
.explainContains(explain);
463+
}
464+
424465
@Test void testFilterSortDesc() {
425466
final String sql = "select * from zips\n"
426467
+ "where pop BETWEEN 95000 AND 100000\n"

0 commit comments

Comments
 (0)