@@ -99,8 +99,7 @@ public class NeuralSparseQueryBuilder extends AbstractQueryBuilder<NeuralSparseQ
99
99
// Its value will be the ratio of the processor.
100
100
// 2. If it's the sub-query built only for two-phase, the value will be set to -1 * ratio of the processor.
101
101
// Then in doToQuery, we can use this to determine which type this queryBuilder is.
102
- private float twoPhasePruneRatio = 0F ;
103
- private PruneType twoPhasePruneType = PruneType .NONE ;
102
+ private NeuralSparseQueryTwoPhaseInfo neuralSparseQueryTwoPhaseInfo = new NeuralSparseQueryTwoPhaseInfo ();
104
103
105
104
private static final Version MINIMAL_SUPPORTED_VERSION_DEFAULT_MODEL_ID = Version .V_2_13_0 ;
106
105
private static final Version MINIMAL_SUPPORTED_VERSION_ANALYZER = Version .V_3_0_0 ;
@@ -124,14 +123,15 @@ public NeuralSparseQueryBuilder(StreamInput in) throws IOException {
124
123
} else {
125
124
this .modelId = in .readString ();
126
125
}
127
- if (isClusterOnOrAfterMinReqVersionForAnalyzer ()) {
128
- this .analyzer = in .readOptionalString ();
129
- }
130
126
this .maxTokenScore = in .readOptionalFloat ();
131
127
if (in .readBoolean ()) {
132
128
Map <String , Float > queryTokens = in .readMap (StreamInput ::readString , StreamInput ::readFloat );
133
129
this .queryTokensSupplier = () -> queryTokens ;
134
130
}
131
+ if (isClusterOnOrAfterMinReqVersionForAnalyzer ()) {
132
+ this .analyzer = in .readOptionalString ();
133
+ this .neuralSparseQueryTwoPhaseInfo = new NeuralSparseQueryTwoPhaseInfo (in );
134
+ }
135
135
// to be backward compatible with previous version, we need to use writeString/readString API instead of optionalString API
136
136
// after supporting query by tokens, queryText and modelId can be null. here we write an empty String instead
137
137
if (StringUtils .EMPTY .equals (this .queryText )) {
@@ -142,21 +142,50 @@ public NeuralSparseQueryBuilder(StreamInput in) throws IOException {
142
142
}
143
143
}
144
144
145
// Serializes this builder to the given stream. Fields are written in this order: fieldName, queryText,
// modelId, maxTokenScore, a presence flag plus the query-token map, and finally (version-gated) the
// analyzer followed by the two-phase info. The StreamInput constructor reads modelId, maxTokenScore,
// the token flag/map and the analyzer/two-phase pair in this same order.
+ @ Override
146
+ protected void doWriteTo (StreamOutput out ) throws IOException {
147
+ out .writeString (this .fieldName );
148
+ // to be backward compatible with previous version, we need to use writeString/readString API instead of optionalString API
149
+ // after supporting query by tokens, queryText and modelId can be null. here we write an empty String instead
150
+ out .writeString (StringUtils .defaultString (this .queryText , StringUtils .EMPTY ));
// modelId goes out as an optional string only when the default-model-id version check passes;
// otherwise a null modelId is replaced by an empty string and written with writeString.
151
+ if (isClusterOnOrAfterMinReqVersionForDefaultModelIdSupport ()) {
152
+ out .writeOptionalString (this .modelId );
153
+ } else {
154
+ out .writeString (StringUtils .defaultString (this .modelId , StringUtils .EMPTY ));
155
+ }
156
+ out .writeOptionalFloat (maxTokenScore );
// A boolean flag tells the reader whether a token map follows; the map is written only when both
// the supplier and the value it supplies are non-null.
157
+ if (!Objects .isNull (this .queryTokensSupplier ) && !Objects .isNull (this .queryTokensSupplier .get ())) {
158
+ out .writeBoolean (true );
159
+ out .writeMap (this .queryTokensSupplier .get (), StreamOutput ::writeString , StreamOutput ::writeFloat );
160
+ } else {
161
+ out .writeBoolean (false );
162
+ }
// The analyzer and the two-phase info are written last, and only when the analyzer version check
// passes; the StreamInput constructor gates its reads on the same check.
163
+ if (isClusterOnOrAfterMinReqVersionForAnalyzer ()) {
164
+ out .writeOptionalString (this .analyzer );
165
+ this .neuralSparseQueryTwoPhaseInfo .writeTo (out );
166
+ }
167
+ }
168
+
145
169
/**
146
170
* Copy this QueryBuilder for the two-phase rescorer; the copy's twoPhasePruneRatio is set to -1 * pruneRatio.
147
171
* @param pruneRatio the parameter of the NeuralSparseTwoPhaseProcessor; controls how the queryTokens are split between the two phases.
148
172
* @return A copy of this NeuralSparseQueryBuilder for two-phase; it will be added to the rescorer.
149
173
*/
150
174
public NeuralSparseQueryBuilder getCopyNeuralSparseQueryBuilderForTwoPhase (float pruneRatio , PruneType pruneType ) {
151
- this .twoPhasePruneRatio (pruneRatio );
152
- this .twoPhasePruneType (pruneType );
175
+ this .neuralSparseQueryTwoPhaseInfo = new NeuralSparseQueryTwoPhaseInfo (
176
+ NeuralSparseQueryTwoPhaseInfo .TwoPhaseStatus .PARENT ,
177
+ pruneRatio ,
178
+ pruneType
179
+ );
153
180
NeuralSparseQueryBuilder copy = new NeuralSparseQueryBuilder ().fieldName (this .fieldName )
154
181
.queryName (this .queryName )
155
182
.queryText (this .queryText )
156
183
.modelId (this .modelId )
157
184
.analyzer (this .analyzer )
158
185
.maxTokenScore (this .maxTokenScore )
159
- .twoPhasePruneRatio (-1f * pruneRatio );
186
+ .neuralSparseQueryTwoPhaseInfo (
187
+ new NeuralSparseQueryTwoPhaseInfo (NeuralSparseQueryTwoPhaseInfo .TwoPhaseStatus .CHILD , pruneRatio * -1f , pruneType )
188
+ );
160
189
if (Objects .nonNull (this .queryTokensSupplier )) {
161
190
Map <String , Float > tokens = queryTokensSupplier .get ();
162
191
// Splitting tokens based on a threshold value: tokens greater than the threshold are stored in v1,
@@ -171,29 +200,6 @@ public NeuralSparseQueryBuilder getCopyNeuralSparseQueryBuilderForTwoPhase(float
171
200
return copy ;
172
201
}
173
202
174
- @ Override
175
- protected void doWriteTo (StreamOutput out ) throws IOException {
176
- out .writeString (this .fieldName );
177
- // to be backward compatible with previous version, we need to use writeString/readString API instead of optionalString API
178
- // after supporting query by tokens, queryText and modelId can be null. here we write an empty String instead
179
- out .writeString (StringUtils .defaultString (this .queryText , StringUtils .EMPTY ));
180
- if (isClusterOnOrAfterMinReqVersionForDefaultModelIdSupport ()) {
181
- out .writeOptionalString (this .modelId );
182
- } else {
183
- out .writeString (StringUtils .defaultString (this .modelId , StringUtils .EMPTY ));
184
- }
185
- if (isClusterOnOrAfterMinReqVersionForAnalyzer ()) {
186
- out .writeOptionalString (this .analyzer );
187
- }
188
- out .writeOptionalFloat (maxTokenScore );
189
- if (!Objects .isNull (this .queryTokensSupplier ) && !Objects .isNull (this .queryTokensSupplier .get ())) {
190
- out .writeBoolean (true );
191
- out .writeMap (this .queryTokensSupplier .get (), StreamOutput ::writeString , StreamOutput ::writeFloat );
192
- } else {
193
- out .writeBoolean (false );
194
- }
195
- }
196
-
197
203
@ Override
198
204
protected void doXContent (XContentBuilder xContentBuilder , Params params ) throws IOException {
199
205
xContentBuilder .startObject (NAME );
@@ -362,7 +368,7 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) {
362
368
.maxTokenScore (maxTokenScore )
363
369
.queryTokensSupplier (queryTokensSetOnce ::get )
364
370
.twoPhaseSharedQueryToken (twoPhaseSharedQueryToken )
365
- .twoPhasePruneRatio ( twoPhasePruneRatio );
371
+ .neuralSparseQueryTwoPhaseInfo ( neuralSparseQueryTwoPhaseInfo );
366
372
}
367
373
368
374
private BiConsumer <Client , ActionListener <?>> getModelInferenceAsync (SetOnce <Map <String , Float >> setOnce ) {
@@ -377,8 +383,8 @@ private BiConsumer<Client, ActionListener<?>> getModelInferenceAsync(SetOnce<Map
377
383
Map <String , Float > queryTokens = TokenWeightUtil .fetchListOfTokenWeightMap (mapResultList ).get (0 );
378
384
if (Objects .nonNull (twoPhaseSharedQueryToken )) {
379
385
Tuple <Map <String , Float >, Map <String , Float >> splitQueryTokens = PruneUtils .splitSparseVector (
380
- twoPhasePruneType ,
381
- twoPhasePruneRatio ,
386
+ neuralSparseQueryTwoPhaseInfo . getTwoPhasePruneType () ,
387
+ neuralSparseQueryTwoPhaseInfo . getTwoPhasePruneRatio () ,
382
388
queryTokens
383
389
);
384
390
setOnce .set (splitQueryTokens .v1 ());
@@ -404,7 +410,7 @@ private Map<String, Float> getQueryTokens(QueryShardContext context) {
404
410
405
411
while (stream .incrementToken ()) {
406
412
String token = term .toString ();
407
- Float weight = Objects .isNull (payload .getPayload ()) ? 1 : HFModelTokenizer .bytesToFloat (payload .getPayload ().bytes );
413
+ Float weight = Objects .isNull (payload .getPayload ()) ? 1.0f : HFModelTokenizer .bytesToFloat (payload .getPayload ().bytes );
408
414
queryTokens .put (token , weight );
409
415
}
410
416
stream .end ();
@@ -466,7 +472,9 @@ protected boolean doEquals(NeuralSparseQueryBuilder obj) {
466
472
.append (queryText , obj .queryText )
467
473
.append (modelId , obj .modelId )
468
474
.append (maxTokenScore , obj .maxTokenScore )
469
- .append (twoPhasePruneRatio , obj .twoPhasePruneRatio )
475
+ .append (neuralSparseQueryTwoPhaseInfo .getTwoPhasePruneType (), obj .neuralSparseQueryTwoPhaseInfo .getTwoPhasePruneType ())
476
+ .append (neuralSparseQueryTwoPhaseInfo .getTwoPhasePruneRatio (), obj .neuralSparseQueryTwoPhaseInfo .getTwoPhasePruneRatio ())
477
+ .append (neuralSparseQueryTwoPhaseInfo .getStatus ().getValue (), obj .neuralSparseQueryTwoPhaseInfo .getStatus ().getValue ())
470
478
.append (twoPhaseSharedQueryToken , obj .twoPhaseSharedQueryToken );
471
479
if (Objects .nonNull (queryTokensSupplier )) {
472
480
equalsBuilder .append (queryTokensSupplier .get (), obj .queryTokensSupplier .get ());
@@ -480,7 +488,9 @@ protected int doHashCode() {
480
488
.append (queryText )
481
489
.append (modelId )
482
490
.append (maxTokenScore )
483
- .append (twoPhasePruneRatio )
491
+ .append (neuralSparseQueryTwoPhaseInfo .getTwoPhasePruneType ())
492
+ .append (neuralSparseQueryTwoPhaseInfo .getTwoPhasePruneRatio ())
493
+ .append (neuralSparseQueryTwoPhaseInfo .getStatus ().getValue ())
484
494
.append (twoPhaseSharedQueryToken );
485
495
if (Objects .nonNull (queryTokensSupplier )) {
486
496
builder .append (queryTokensSupplier .get ());
0 commit comments