Skip to content

Commit 628ac6c

Browse files
committed
two phase adaption
Signed-off-by: zhichao-aws <[email protected]>
1 parent e5f724a commit 628ac6c

File tree

2 files changed

+16
-10
lines changed

2 files changed

+16
-10
lines changed

src/main/java/org/opensearch/neuralsearch/query/NeuralQueryBuilder.java

-1
Original file line numberDiff line numberDiff line change
@@ -295,7 +295,6 @@ protected void doWriteTo(StreamOutput out) throws IOException {
295295
out.writeVInt(this.k);
296296
}
297297
out.writeOptionalNamedWriteable(this.filter);
298-
out.writeOptionalNamedWriteable(this.filter);
299298
if (isClusterOnOrAfterMinReqVersionForRadialSearch()) {
300299
out.writeOptionalFloat(this.maxDistance);
301300
out.writeOptionalFloat(this.minScore);

src/main/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilder.java

+16-9
Original file line numberDiff line numberDiff line change
@@ -94,11 +94,6 @@ public class NeuralSparseQueryBuilder extends AbstractQueryBuilder<NeuralSparseQ
9494
// it means it's origin NeuralSparseQueryBuilder and should split the low score tokens form itself then put it into
9595
// twoPhaseSharedQueryToken.
9696
private Map<String, Float> twoPhaseSharedQueryToken;
97-
// A parameter with a default value 0F,
98-
// 1. If the query request are using neural_sparse_two_phase_processor and be collected,
99-
// It's value will be the ratio of processor.
100-
// 2. If it's the sub query only build for two-phase, the value will be set to -1 * ratio of processor.
101-
// Then in the DoToQuery, we can use this to determine which type are this queryBuilder.
10297
private NeuralSparseQueryTwoPhaseInfo neuralSparseQueryTwoPhaseInfo = new NeuralSparseQueryTwoPhaseInfo();
10398

10499
private static final Version MINIMAL_SUPPORTED_VERSION_DEFAULT_MODEL_ID = Version.V_2_13_0;
@@ -167,7 +162,7 @@ protected void doWriteTo(StreamOutput out) throws IOException {
167162
}
168163

169164
/**
170-
* Copy this QueryBuilder for two phase rescorer, set the copy one's twoPhasePruneRatio to -1.
165+
* Copy this QueryBuilder for two phase rescorer.
171166
* @param pruneRatio the parameter of the NeuralSparseTwoPhaseProcessor, control how to split the queryTokens to two phase.
172167
* @return A copy NeuralSparseQueryBuilder for twoPhase, it will be added to the rescorer.
173168
*/
@@ -184,7 +179,7 @@ public NeuralSparseQueryBuilder getCopyNeuralSparseQueryBuilderForTwoPhase(float
184179
.analyzer(this.analyzer)
185180
.maxTokenScore(this.maxTokenScore)
186181
.neuralSparseQueryTwoPhaseInfo(
187-
new NeuralSparseQueryTwoPhaseInfo(NeuralSparseQueryTwoPhaseInfo.TwoPhaseStatus.CHILD, pruneRatio * -1f, pruneType)
182+
new NeuralSparseQueryTwoPhaseInfo(NeuralSparseQueryTwoPhaseInfo.TwoPhaseStatus.CHILD, pruneRatio, pruneType)
188183
);
189184
if (Objects.nonNull(this.queryTokensSupplier)) {
190185
Map<String, Float> tokens = queryTokensSupplier.get();
@@ -398,7 +393,7 @@ private BiConsumer<Client, ActionListener<?>> getModelInferenceAsync(SetOnce<Map
398393
}
399394

400395
private Map<String, Float> getQueryTokens(QueryShardContext context) {
401-
if (Objects.nonNull(queryTokensSupplier)) {
396+
if (Objects.nonNull(queryTokensSupplier) && !queryTokensSupplier.get().isEmpty()) {
402397
return queryTokensSupplier.get();
403398
} else if (Objects.nonNull(analyzer)) {
404399
Map<String, Float> queryTokens = new HashMap<>();
@@ -417,7 +412,19 @@ private Map<String, Float> getQueryTokens(QueryShardContext context) {
417412
} catch (IOException e) {
418413
throw new OpenSearchException("failed to analyze query text. ", e);
419414
}
420-
return queryTokens;
415+
return switch (neuralSparseQueryTwoPhaseInfo.getStatus()) {
416+
case NeuralSparseQueryTwoPhaseInfo.TwoPhaseStatus.CHILD -> PruneUtils.splitSparseVector(
417+
neuralSparseQueryTwoPhaseInfo.getTwoPhasePruneType(),
418+
neuralSparseQueryTwoPhaseInfo.getTwoPhasePruneRatio(),
419+
queryTokens
420+
).v2();
421+
case NeuralSparseQueryTwoPhaseInfo.TwoPhaseStatus.PARENT -> PruneUtils.splitSparseVector(
422+
neuralSparseQueryTwoPhaseInfo.getTwoPhasePruneType(),
423+
neuralSparseQueryTwoPhaseInfo.getTwoPhasePruneRatio(),
424+
queryTokens
425+
).v1();
426+
default -> queryTokens;
427+
};
421428
}
422429
throw new IllegalArgumentException("Query tokens cannot be null.");
423430
}

0 commit comments

Comments
 (0)