Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplify the logic of script sort parallel collection #124639

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
import org.apache.lucene.index.SortedSetDocValues;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.FieldComparator;
import org.apache.lucene.search.LeafFieldComparator;
import org.apache.lucene.search.Pruning;
import org.apache.lucene.search.Scorable;
import org.apache.lucene.search.SortField;
Expand Down Expand Up @@ -70,6 +69,20 @@ protected SortedBinaryDocValues getValues(LeafReaderContext context) throws IOEx

/**
 * Receives the {@link Scorable} together with the leaf context it belongs to.
 * This implementation does nothing; the method is protected and non-final, so subclasses can override it.
 */
protected void setScorer(LeafReaderContext context, Scorable scorer) {}

/**
 * Reduces the supplied multi-valued {@code values} to a single {@link BinaryDocValues} view using the configured sort mode.
 * <p>
 * Without a nested context the sort mode is applied to {@code values} and {@code missingBytes} only. With a nested
 * context, the root-doc bit set, the inner-doc iterator and the child limit are passed to the sort mode as well; the
 * limit is taken from the nested sort when one is present and is {@link Integer#MAX_VALUE} otherwise.
 *
 * @param context      the leaf whose nested root/inner docs are resolved
 * @param missingBytes the value handed to the sort mode for documents without a value
 * @param values       the doc values to select from
 * @return the selected per-document values
 * @throws IOException if resolving the nested docs or selecting the values fails
 */
protected BinaryDocValues getBinaryDocValues(LeafReaderContext context, BytesRef missingBytes, SortedBinaryDocValues values)
    throws IOException {
    if (nested == null) {
        // no nested scope: plain selection over the document's own values
        return sortMode.select(values, missingBytes);
    }
    final BitSet parents = nested.rootDocs(context);
    final DocIdSetIterator children = nested.innerDocs(context);
    final int childLimit;
    if (nested.getNestedSort() != null) {
        childLimit = nested.getNestedSort().getMaxChildren();
    } else {
        childLimit = Integer.MAX_VALUE;
    }
    return sortMode.select(values, missingBytes, parents, children, childLimit);
}

@Override
public FieldComparator<?> newComparator(String fieldname, int numHits, Pruning enableSkipping, boolean reversed) {
assert indexFieldData == null || fieldname.equals(indexFieldData.getFieldName());
Expand Down Expand Up @@ -102,61 +115,22 @@ protected SortedDocValues getSortedDocValues(LeafReaderContext context, String f

};
}
return newComparatorWithoutOrdinal(fieldname, numHits, enableSkipping, reversed, missingBytes, sortMissingLast);
}

/**
 * Builds a {@link FieldComparator.TermValComparator} from {@code numHits} and {@code sortMissingLast}, passing
 * {@code null} as the field name, and overrides how that comparator obtains its per-leaf doc values.
 * {@code fieldname}, {@code enableSkipping} and {@code reversed} are accepted but not read in this body.
 *
 * NOTE(review): this span reads like a rendered diff with removed and added lines interleaved. The
 * {@code getLeafComparator} override below ends with a second {@code return} that yields {@code BinaryDocValues},
 * which cannot follow the preceding {@code return new LeafFieldComparator() {...};}. Presumably the lines from the
 * {@code getValues(context)} local through the {@code LeafFieldComparator} wrapper are the old implementation and
 * the trailing {@code return} is the new body of {@code getBinaryDocValues} — TODO confirm against the real file.
 */
protected FieldComparator<?> newComparatorWithoutOrdinal(
String fieldname,
int numHits,
Pruning enableSkipping,
boolean reversed,
BytesRef missingBytes,
boolean sortMissingLast
) {
return new FieldComparator.TermValComparator(numHits, null, sortMissingLast) {

@Override
protected BinaryDocValues getBinaryDocValues(LeafReaderContext context, String field) throws IOException {
// Inline selection: the same steps as the three-argument getBinaryDocValues helper of the enclosing class.
final SortedBinaryDocValues values = getValues(context);
final BinaryDocValues selectedValues;
if (nested == null) {
selectedValues = sortMode.select(values, missingBytes);
} else {
final BitSet rootDocs = nested.rootDocs(context);
final DocIdSetIterator innerDocs = nested.innerDocs(context);
final int maxChildren = nested.getNestedSort() != null ? nested.getNestedSort().getMaxChildren() : Integer.MAX_VALUE;
selectedValues = sortMode.select(values, missingBytes, rootDocs, innerDocs, maxChildren);
}
return selectedValues;
}

@Override
public LeafFieldComparator getLeafComparator(LeafReaderContext context) throws IOException {
LeafFieldComparator leafComparator = super.getLeafComparator(context);
// TopFieldCollector interacts with inter-segment concurrency by creating a FieldValueHitQueue per slice, each one with a
// specific instance of the FieldComparator. This ensures sequential execution across LeafFieldComparators returned by
// the same parent FieldComparator. That allows for effectively sharing the same instance of leaf comparator, like in this
// case in the Lucene code. That's fine when sorting by field, but not when using script sorting, because we then
// need to set the Scorer on the specific leaf comparator, to make the _score variable available in sort scripts. The
// setScorer call happens concurrently across slices and needs to target the specific leaf context that is being searched.
return new LeafFieldComparator() {
// Every method except setScorer forwards unchanged to the comparator obtained from super.
@Override
public void setBottom(int slot) throws IOException {
leafComparator.setBottom(slot);
}

@Override
public int compareBottom(int doc) throws IOException {
return leafComparator.compareBottom(doc);
}

@Override
public int compareTop(int doc) throws IOException {
return leafComparator.compareTop(doc);
}

@Override
public void copy(int slot, int doc) throws IOException {
leafComparator.copy(slot, doc);
}

@Override
public void setScorer(Scorable scorer) {
// this ensures that the scorer is set for the specific leaf comparator
// corresponding to the leaf context we are scoring
BytesRefFieldComparatorSource.this.setScorer(context, scorer);
}
};
// NOTE(review): unreachable as written, and of the wrong type for getLeafComparator; see the note on the method header.
return BytesRefFieldComparatorSource.this.getBinaryDocValues(context, missingBytes, getValues(context));
}
};
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,11 @@ protected SortedNumericDoubleValues getValues(LeafReaderContext context) throws
}

/**
 * Loads the values for {@code context} through {@code getValues(context)} and delegates to the three-argument
 * overload of this method.
 * <p>
 * The values are fetched exactly once: the previous body also assigned {@code getValues(context)} to a local that
 * was never read, which evaluated the call a second time for no effect.
 *
 * @param context      the leaf to read values from
 * @param missingValue the value forwarded to the three-argument overload
 * @return the per-document values produced by the three-argument overload
 * @throws IOException if loading or selecting the values fails
 */
private NumericDoubleValues getNumericDocValues(LeafReaderContext context, double missingValue) throws IOException {
    return getNumericDocValues(context, missingValue, getValues(context));
}

protected NumericDoubleValues getNumericDocValues(LeafReaderContext context, double missingValue, SortedNumericDoubleValues values)
throws IOException {
if (nested == null) {
return FieldData.replaceMissing(sortMode.select(values), missingValue);
} else {
Expand All @@ -78,6 +82,10 @@ public FieldComparator<?> newComparator(String fieldname, int numHits, Pruning e
assert indexFieldData == null || fieldname.equals(indexFieldData.getFieldName());

final double dMissingValue = (Double) missingObject(missingValue, reversed);
return newComparator(numHits, enableSkipping, reversed, dMissingValue);
}

protected FieldComparator<?> newComparator(int numHits, Pruning enableSkipping, boolean reversed, double dMissingValue) {
// NOTE: it's important to pass null as a missing value in the constructor so that
// the comparator doesn't check docsWithField since we replace missing values in select()
return new DoubleComparator(numHits, null, null, reversed, Pruning.NONE) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the "Elastic License
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
* Public License v 1"; you may not use this file except in compliance with, at
* your election, the "Elastic License 2.0", the "GNU Affero General Public
* License v3.0 only", or the "Server Side Public License, v 1".
*/

package org.elasticsearch.index.fielddata.fieldcomparator;

import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.NumericDocValues;
import org.apache.lucene.search.FieldComparator;
import org.apache.lucene.search.LeafFieldComparator;
import org.apache.lucene.search.Pruning;
import org.apache.lucene.search.Scorable;
import org.apache.lucene.search.comparators.DoubleComparator;
import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.core.CheckedFunction;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.index.fielddata.FieldData;
import org.elasticsearch.index.fielddata.IndexNumericFieldData;
import org.elasticsearch.index.fielddata.NumericDoubleValues;
import org.elasticsearch.index.fielddata.SortedNumericDoubleValues;
import org.elasticsearch.script.NumberSortScript;
import org.elasticsearch.search.DocValueFormat;
import org.elasticsearch.search.MultiValueMode;
import org.elasticsearch.search.sort.BucketedSort;
import org.elasticsearch.search.sort.SortOrder;

import java.io.IOException;

/**
 * A {@link DoubleValuesComparatorSource} whose per-document values are produced by a {@link NumberSortScript}.
 * <p>
 * A script instance is requested from the supplier for each leaf. Both the comparator and the bucketed sort created
 * here evaluate that instance, after positioning it on a document, to obtain the document's sort value.
 */
public class ScriptDoubleValuesComparatorSource extends DoubleValuesComparatorSource {

    private final CheckedFunction<LeafReaderContext, NumberSortScript, IOException> scriptSupplier;

    /**
     * @param scriptSupplier produces the script to evaluate for a given leaf
     * @param indexFieldData passed through to the parent constructor
     * @param missingValue   passed through to the parent constructor, may be {@code null}
     * @param sortMode       passed through to the parent constructor
     * @param nested         passed through to the parent constructor
     */
    public ScriptDoubleValuesComparatorSource(
        CheckedFunction<LeafReaderContext, NumberSortScript, IOException> scriptSupplier,
        IndexNumericFieldData indexFieldData,
        @Nullable Object missingValue,
        MultiValueMode sortMode,
        Nested nested
    ) {
        super(indexFieldData, missingValue, sortMode, nested);
        this.scriptSupplier = scriptSupplier;
    }

    /**
     * Exposes the script as doc values: advancing to a document points the script at it and always reports a value,
     * and reading the value executes the script.
     */
    private SortedNumericDoubleValues getValues(NumberSortScript script) throws IOException {
        final NumericDoubleValues scriptBacked = new NumericDoubleValues() {
            @Override
            public boolean advanceExact(int doc) {
                script.setDocument(doc);
                return true;
            }

            @Override
            public double doubleValue() {
                return script.execute();
            }
        };
        return FieldData.singleton(scriptBacked);
    }

    /**
     * Wraps the script as doc values and hands them to the inherited three-argument selection method.
     */
    private NumericDoubleValues getNumericDocValues(LeafReaderContext context, double missingValue, NumberSortScript script)
        throws IOException {
        final SortedNumericDoubleValues scriptValues = getValues(script);
        return getNumericDocValues(context, missingValue, scriptValues);
    }

    /**
     * Creates a {@link DoubleComparator} whose leaf comparators read their values from a script obtained for that leaf
     * and pass the scorer they receive on to the same script.
     */
    @Override
    protected FieldComparator<?> newComparator(int numHits, Pruning enableSkipping, boolean reversed, double dMissingValue) {
        // NOTE: null must be passed as the comparator's missing value so that it does not check docsWithField;
        // missing values are already replaced in select()
        return new DoubleComparator(numHits, null, null, reversed, Pruning.NONE) {
            @Override
            public LeafFieldComparator getLeafComparator(LeafReaderContext context) throws IOException {
                final NumberSortScript script = scriptSupplier.apply(context);
                return new DoubleLeafComparator(context) {
                    @Override
                    protected NumericDocValues getNumericDocValues(LeafReaderContext context, String field) throws IOException {
                        // qualified call: the unqualified name would resolve to this leaf comparator's own method
                        final NumericDoubleValues selected = ScriptDoubleValuesComparatorSource.this.getNumericDocValues(
                            context,
                            dMissingValue,
                            script
                        );
                        return selected.getRawDoubleValues();
                    }

                    @Override
                    public void setScorer(Scorable scorer) {
                        script.setScorer(scorer);
                    }
                };
            }
        };
    }

    /**
     * Creates a {@link BucketedSort.ForDoubles} whose leaves take each document's value from a script obtained for
     * that leaf. The missing value is resolved once, using descending order as the {@code reversed} flag.
     */
    @Override
    public BucketedSort newBucketedSort(
        BigArrays bigArrays,
        SortOrder sortOrder,
        DocValueFormat format,
        int bucketSize,
        BucketedSort.ExtraData extra
    ) {
        return new BucketedSort.ForDoubles(bigArrays, sortOrder, format, bucketSize, extra) {
            private final double dMissingValue = (Double) missingObject(missingValue, sortOrder == SortOrder.DESC);

            @Override
            public Leaf forLeaf(LeafReaderContext ctx) throws IOException {
                final NumberSortScript script = scriptSupplier.apply(ctx);
                return new Leaf(ctx) {
                    private final NumericDoubleValues scriptValues = getNumericDocValues(ctx, dMissingValue, script);
                    private double current;

                    @Override
                    protected boolean advanceExact(int doc) throws IOException {
                        if (scriptValues.advanceExact(doc) == false) {
                            return false;
                        }
                        // cache the value so docValue() can return it without touching the script again
                        current = scriptValues.doubleValue();
                        return true;
                    }

                    @Override
                    protected double docValue() {
                        return current;
                    }
                };
            }
        };
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the "Elastic License
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
* Public License v 1"; you may not use this file except in compliance with, at
* your election, the "Elastic License 2.0", the "GNU Affero General Public
* License v3.0 only", or the "Server Side Public License, v 1".
*/

package org.elasticsearch.index.fielddata.fieldcomparator;

import org.apache.lucene.index.BinaryDocValues;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.FieldComparator;
import org.apache.lucene.search.Pruning;
import org.apache.lucene.search.Scorable;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.BytesRefBuilder;
import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.core.CheckedFunction;
import org.elasticsearch.index.fielddata.AbstractBinaryDocValues;
import org.elasticsearch.index.fielddata.FieldData;
import org.elasticsearch.index.fielddata.IndexFieldData;
import org.elasticsearch.index.fielddata.SortedBinaryDocValues;
import org.elasticsearch.script.StringSortScript;
import org.elasticsearch.search.DocValueFormat;
import org.elasticsearch.search.MultiValueMode;
import org.elasticsearch.search.sort.BucketedSort;
import org.elasticsearch.search.sort.ScriptSortBuilder;
import org.elasticsearch.search.sort.SortOrder;

import java.io.IOException;

/**
 * A {@link BytesRefFieldComparatorSource} whose per-document values are produced by a {@link StringSortScript}.
 * <p>
 * Only the non-ordinal comparator is customised; bucketed sorting is rejected for string scripts.
 */
public class ScriptStringFieldComparatorSource extends BytesRefFieldComparatorSource {

    final CheckedFunction<LeafReaderContext, StringSortScript, IOException> scriptSupplier;

    /**
     * @param scriptSupplier produces the script to evaluate for a given leaf
     * @param indexFieldData passed through to the parent constructor
     * @param missingValue   passed through to the parent constructor
     * @param sortMode       passed through to the parent constructor
     * @param nested         passed through to the parent constructor
     */
    public ScriptStringFieldComparatorSource(
        CheckedFunction<LeafReaderContext, StringSortScript, IOException> scriptSupplier,
        IndexFieldData<?> indexFieldData,
        Object missingValue,
        MultiValueMode sortMode,
        Nested nested
    ) {
        super(indexFieldData, missingValue, sortMode, nested);
        this.scriptSupplier = scriptSupplier;
    }

    /**
     * Exposes the script as doc values: advancing to a document points the script at it and always reports a value,
     * and reading the value executes the script and copies the resulting characters into a reused builder.
     */
    private SortedBinaryDocValues getValues(StringSortScript script) throws IOException {
        final BinaryDocValues scriptBacked = new AbstractBinaryDocValues() {
            final BytesRefBuilder spare = new BytesRefBuilder();

            @Override
            public boolean advanceExact(int doc) {
                script.setDocument(doc);
                return true;
            }

            @Override
            public BytesRef binaryValue() {
                final String result = script.execute();
                spare.copyChars(result);
                return spare.get();
            }
        };
        return FieldData.singleton(scriptBacked);
    }

    /**
     * Creates a {@link FieldComparator.TermValComparator} that requests a script from the supplier whenever its
     * doc values are requested for a leaf, remembers that script, and forwards the scorer it receives to it.
     */
    @Override
    protected FieldComparator<?> newComparatorWithoutOrdinal(
        String fieldname,
        int numHits,
        Pruning enableSkipping,
        boolean reversed,
        BytesRef missingBytes,
        boolean sortMissingLast
    ) {
        return new FieldComparator.TermValComparator(numHits, null, sortMissingLast) {

            // script of the leaf whose doc values were requested most recently
            StringSortScript leafScript;

            @Override
            protected BinaryDocValues getBinaryDocValues(LeafReaderContext context, String field) throws IOException {
                leafScript = scriptSupplier.apply(context);
                final SortedBinaryDocValues scriptValues = getValues(leafScript);
                return ScriptStringFieldComparatorSource.this.getBinaryDocValues(context, missingBytes, scriptValues);
            }

            @Override
            public void setScorer(Scorable scorer) {
                leafScript.setScorer(scorer);
            }
        };
    }

    /**
     * Always fails: bucketed sorting is not available for string scripts.
     *
     * @throws IllegalArgumentException on every call
     */
    @Override
    public BucketedSort newBucketedSort(
        BigArrays bigArrays,
        SortOrder sortOrder,
        DocValueFormat format,
        int bucketSize,
        BucketedSort.ExtraData extra
    ) {
        final String message = "error building sort for [_script]: "
            + "script sorting only supported on [numeric] scripts but was ["
            + ScriptSortBuilder.ScriptSortType.STRING
            + "]";
        throw new IllegalArgumentException(message);
    }
}
Loading