@@ -94,11 +94,6 @@ public class NeuralSparseQueryBuilder extends AbstractQueryBuilder<NeuralSparseQ
94
94
// it means it's origin NeuralSparseQueryBuilder and should split the low score tokens form itself then put it into
95
95
// twoPhaseSharedQueryToken.
96
96
private Map <String , Float > twoPhaseSharedQueryToken ;
97
- // A parameter with a default value 0F,
98
- // 1. If the query request are using neural_sparse_two_phase_processor and be collected,
99
- // It's value will be the ratio of processor.
100
- // 2. If it's the sub query only build for two-phase, the value will be set to -1 * ratio of processor.
101
- // Then in the DoToQuery, we can use this to determine which type are this queryBuilder.
102
97
private NeuralSparseQueryTwoPhaseInfo neuralSparseQueryTwoPhaseInfo = new NeuralSparseQueryTwoPhaseInfo ();
103
98
104
99
private static final Version MINIMAL_SUPPORTED_VERSION_DEFAULT_MODEL_ID = Version .V_2_13_0 ;
@@ -167,7 +162,7 @@ protected void doWriteTo(StreamOutput out) throws IOException {
167
162
}
168
163
169
164
/**
170
- * Copy this QueryBuilder for two phase rescorer, set the copy one's twoPhasePruneRatio to -1 .
165
+ * Copy this QueryBuilder for two phase rescorer.
171
166
* @param pruneRatio the parameter of the NeuralSparseTwoPhaseProcessor, control how to split the queryTokens to two phase.
172
167
* @return A copy NeuralSparseQueryBuilder for twoPhase, it will be added to the rescorer.
173
168
*/
@@ -184,7 +179,7 @@ public NeuralSparseQueryBuilder getCopyNeuralSparseQueryBuilderForTwoPhase(float
184
179
.analyzer (this .analyzer )
185
180
.maxTokenScore (this .maxTokenScore )
186
181
.neuralSparseQueryTwoPhaseInfo (
187
- new NeuralSparseQueryTwoPhaseInfo (NeuralSparseQueryTwoPhaseInfo .TwoPhaseStatus .CHILD , pruneRatio * - 1f , pruneType )
182
+ new NeuralSparseQueryTwoPhaseInfo (NeuralSparseQueryTwoPhaseInfo .TwoPhaseStatus .CHILD , pruneRatio , pruneType )
188
183
);
189
184
if (Objects .nonNull (this .queryTokensSupplier )) {
190
185
Map <String , Float > tokens = queryTokensSupplier .get ();
@@ -398,7 +393,7 @@ private BiConsumer<Client, ActionListener<?>> getModelInferenceAsync(SetOnce<Map
398
393
}
399
394
400
395
private Map <String , Float > getQueryTokens (QueryShardContext context ) {
401
- if (Objects .nonNull (queryTokensSupplier )) {
396
+ if (Objects .nonNull (queryTokensSupplier ) && ! queryTokensSupplier . get (). isEmpty () ) {
402
397
return queryTokensSupplier .get ();
403
398
} else if (Objects .nonNull (analyzer )) {
404
399
Map <String , Float > queryTokens = new HashMap <>();
@@ -417,7 +412,19 @@ private Map<String, Float> getQueryTokens(QueryShardContext context) {
417
412
} catch (IOException e ) {
418
413
throw new OpenSearchException ("failed to analyze query text. " , e );
419
414
}
420
- return queryTokens ;
415
+ return switch (neuralSparseQueryTwoPhaseInfo .getStatus ()) {
416
+ case NeuralSparseQueryTwoPhaseInfo .TwoPhaseStatus .CHILD -> PruneUtils .splitSparseVector (
417
+ neuralSparseQueryTwoPhaseInfo .getTwoPhasePruneType (),
418
+ neuralSparseQueryTwoPhaseInfo .getTwoPhasePruneRatio (),
419
+ queryTokens
420
+ ).v2 ();
421
+ case NeuralSparseQueryTwoPhaseInfo .TwoPhaseStatus .PARENT -> PruneUtils .splitSparseVector (
422
+ neuralSparseQueryTwoPhaseInfo .getTwoPhasePruneType (),
423
+ neuralSparseQueryTwoPhaseInfo .getTwoPhasePruneRatio (),
424
+ queryTokens
425
+ ).v1 ();
426
+ default -> queryTokens ;
427
+ };
421
428
}
422
429
throw new IllegalArgumentException ("Query tokens cannot be null." );
423
430
}
0 commit comments