Skip to content

Commit b55227f

Browse files
authored
Add support for boost params + mult boosts (#247)
1 parent dbb9637 commit b55227f

File tree

1 file changed

+86
-2
lines changed

1 file changed

+86
-2
lines changed

montysolr/src/main/java/org/apache/solr/search/AqpAdsabsQParser.java

Lines changed: 86 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,10 @@
11
package org.apache.solr.search;
22

3+
import org.apache.lucene.queries.function.FunctionQuery;
4+
import org.apache.lucene.queries.function.FunctionScoreQuery;
5+
import org.apache.lucene.queries.function.ValueSource;
6+
import org.apache.lucene.queries.function.valuesource.ProductFloatFunction;
7+
import org.apache.lucene.queries.function.valuesource.QueryValueSource;
38
import org.apache.lucene.queryparser.flexible.aqp.AqpAdsabsQueryTreeBuilder;
49
import org.apache.lucene.queryparser.flexible.aqp.AqpQueryParser;
510
import org.apache.lucene.queryparser.flexible.aqp.builders.AqpAdsabsFunctionProvider;
@@ -15,13 +20,16 @@
1520
import org.apache.lucene.queryparser.flexible.standard.config.PointsConfig;
1621
import org.apache.lucene.queryparser.flexible.standard.config.StandardQueryConfigHandler;
1722
import org.apache.lucene.queryparser.flexible.standard.config.StandardQueryConfigHandler.Operator;
18-
import org.apache.lucene.search.Query;
23+
import org.apache.lucene.search.*;
1924
import org.apache.solr.common.SolrException;
2025
import org.apache.solr.common.params.CommonParams;
26+
import org.apache.solr.common.params.DisMaxParams;
2127
import org.apache.solr.common.params.SolrParams;
28+
import org.apache.solr.common.util.NamedList;
2229
import org.apache.solr.request.SolrQueryRequest;
2330
import org.apache.solr.schema.IndexSchema;
2431
import org.apache.solr.util.DateMathParser;
32+
import org.apache.solr.util.SolrPluginUtils;
2533
import org.slf4j.Logger;
2634
import org.slf4j.LoggerFactory;
2735

@@ -48,6 +56,10 @@ public class AqpAdsabsQParser extends QParser {
4856

4957
private final AqpQueryParser qParser;
5058

59+
private String[] boostParams;
60+
private String[] boostFuncs;
61+
private String[] multBoosts;
62+
5163
public AqpAdsabsQParser(AqpQueryParser parser, String qstr, SolrParams localParams,
5264
SolrParams params, SolrQueryRequest req, SolrParserConfigParams defaultConfig)
5365
throws QueryNodeParseException {
@@ -71,6 +83,8 @@ public AqpAdsabsQParser(AqpQueryParser parser, String qstr, SolrParams localPara
7183

7284

7385
Map<String, String> namedParams = config.get(AqpStandardQueryConfigHandler.ConfigurationKeys.NAMED_PARAMETER);
86+
SolrParams solrParams = SolrParams.wrapDefaults(localParams,
87+
SolrParams.wrapDefaults(params, new NamedList<>(namedParams).toSolrParams()));
7488

7589
// get the parameters from the parser configuration (and pass them on)
7690
for (Entry<String, String> par : defaultConfig.params.entrySet()) {
@@ -235,6 +249,12 @@ public AqpAdsabsQParser(AqpQueryParser parser, String qstr, SolrParams localPara
235249
params.getBool("aqp.allow.leading_wildcard", false));
236250
}
237251

252+
// Boost factors
253+
boostParams = solrParams.getParams(DisMaxParams.BQ);
254+
255+
boostFuncs = solrParams.getParams(DisMaxParams.BF);
256+
257+
multBoosts = solrParams.getParams(AqpExtendedDismaxQParser.DMP.MULT_BOOST);
238258
}
239259

240260
public class NumberDateFormat extends NumberFormat {
@@ -338,7 +358,31 @@ public Query parse() throws SyntaxError {
338358
// return qParser.parse(getString() + config.get(AqpAdsabsQueryConfigHandler.ConfigurationKeys.DUMMY_VALUE), null);
339359
//}
340360

341-
return qParser.parse(getString(), null);
361+
Query userQuery = qParser.parse(getString(), null);
362+
Query topQuery = userQuery;
363+
364+
if (boostParams != null || boostFuncs != null) {
365+
BooleanQuery.Builder builder = new BooleanQuery.Builder();
366+
builder.add(userQuery, BooleanClause.Occur.MUST);
367+
368+
for (Query q : getBoostQueries()) {
369+
builder.add(q, BooleanClause.Occur.SHOULD);
370+
}
371+
372+
topQuery = builder.build();
373+
}
374+
375+
if (multBoosts != null) {
376+
List<ValueSource> boosts = getMultiplicativeBoosts();
377+
DoubleValuesSource multiplicativeBoostSource =
378+
boosts.size() > 1
379+
? new ProductFloatFunction(boosts.toArray(new ValueSource[0])).asDoubleValuesSource()
380+
: boosts.get(0).asDoubleValuesSource();
381+
382+
topQuery = FunctionScoreQuery.boostByValue(topQuery, multiplicativeBoostSource);
383+
}
384+
385+
return topQuery;
342386
} catch (QueryNodeException e) {
343387
throw new SyntaxError(e);
344388
} catch (SolrException e1) {
@@ -349,4 +393,44 @@ public Query parse() throws SyntaxError {
349393
public AqpQueryParser getParser() {
350394
return qParser;
351395
}
396+
397+
/**
398+
* Parses all multiplicative boosts
399+
*/
400+
protected List<ValueSource> getMultiplicativeBoosts() throws SyntaxError {
401+
List<ValueSource> boosts = new ArrayList<>();
402+
403+
for (String boostStr : multBoosts) {
404+
if (boostStr == null || boostStr.isBlank()) continue;
405+
406+
Query boost = subQuery(boostStr, FunctionQParserPlugin.NAME).getQuery();
407+
408+
ValueSource vs;
409+
if (boost instanceof FunctionQuery) {
410+
vs = ((FunctionQuery) boost).getValueSource();
411+
} else {
412+
vs = new QueryValueSource(boost, 1.0f);
413+
}
414+
415+
boosts.add(vs);
416+
}
417+
418+
return boosts;
419+
}
420+
421+
/**
422+
* Parses all additive boost fields
423+
*/
424+
protected List<Query> getBoostQueries() throws SyntaxError {
425+
List<Query> boostQueries = new LinkedList<>();
426+
427+
for (String boostParam : boostParams) {
428+
if (boostParam == null || boostParam.trim().isBlank()) continue;
429+
430+
Query q = subQuery(boostParam, null).getQuery();
431+
boostQueries.add(q);
432+
}
433+
434+
return boostQueries;
435+
}
352436
}

0 commit comments

Comments
 (0)