Skip to content

Commit e5f724a

Browse files
committed
two phase adaption
Signed-off-by: zhichao-aws <[email protected]>
1 parent 1c852b6 commit e5f724a

File tree

6 files changed

+141
-50
lines changed

6 files changed

+141
-50
lines changed

src/main/java/org/opensearch/neuralsearch/query/NeuralQueryBuilder.java

+1
Original file line numberDiff line numberDiff line change
@@ -295,6 +295,7 @@ protected void doWriteTo(StreamOutput out) throws IOException {
295295
out.writeVInt(this.k);
296296
}
297297
out.writeOptionalNamedWriteable(this.filter);
298+
out.writeOptionalNamedWriteable(this.filter);
298299
if (isClusterOnOrAfterMinReqVersionForRadialSearch()) {
299300
out.writeOptionalFloat(this.maxDistance);
300301
out.writeOptionalFloat(this.minScore);

src/main/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilder.java

+47-37
Original file line numberDiff line numberDiff line change
@@ -99,8 +99,7 @@ public class NeuralSparseQueryBuilder extends AbstractQueryBuilder<NeuralSparseQ
9999
// It's value will be the ratio of processor.
100100
// 2. If it's the sub query only build for two-phase, the value will be set to -1 * ratio of processor.
101101
// Then in the DoToQuery, we can use this to determine which type are this queryBuilder.
102-
private float twoPhasePruneRatio = 0F;
103-
private PruneType twoPhasePruneType = PruneType.NONE;
102+
private NeuralSparseQueryTwoPhaseInfo neuralSparseQueryTwoPhaseInfo = new NeuralSparseQueryTwoPhaseInfo();
104103

105104
private static final Version MINIMAL_SUPPORTED_VERSION_DEFAULT_MODEL_ID = Version.V_2_13_0;
106105
private static final Version MINIMAL_SUPPORTED_VERSION_ANALYZER = Version.V_3_0_0;
@@ -124,14 +123,15 @@ public NeuralSparseQueryBuilder(StreamInput in) throws IOException {
124123
} else {
125124
this.modelId = in.readString();
126125
}
127-
if (isClusterOnOrAfterMinReqVersionForAnalyzer()) {
128-
this.analyzer = in.readOptionalString();
129-
}
130126
this.maxTokenScore = in.readOptionalFloat();
131127
if (in.readBoolean()) {
132128
Map<String, Float> queryTokens = in.readMap(StreamInput::readString, StreamInput::readFloat);
133129
this.queryTokensSupplier = () -> queryTokens;
134130
}
131+
if (isClusterOnOrAfterMinReqVersionForAnalyzer()) {
132+
this.analyzer = in.readOptionalString();
133+
this.neuralSparseQueryTwoPhaseInfo = new NeuralSparseQueryTwoPhaseInfo(in);
134+
}
135135
// to be backward compatible with previous version, we need to use writeString/readString API instead of optionalString API
136136
// after supporting query by tokens, queryText and modelId can be null. here we write an empty String instead
137137
if (StringUtils.EMPTY.equals(this.queryText)) {
@@ -142,21 +142,50 @@ public NeuralSparseQueryBuilder(StreamInput in) throws IOException {
142142
}
143143
}
144144

145+
@Override
146+
protected void doWriteTo(StreamOutput out) throws IOException {
147+
out.writeString(this.fieldName);
148+
// to be backward compatible with previous version, we need to use writeString/readString API instead of optionalString API
149+
// after supporting query by tokens, queryText and modelId can be null. here we write an empty String instead
150+
out.writeString(StringUtils.defaultString(this.queryText, StringUtils.EMPTY));
151+
if (isClusterOnOrAfterMinReqVersionForDefaultModelIdSupport()) {
152+
out.writeOptionalString(this.modelId);
153+
} else {
154+
out.writeString(StringUtils.defaultString(this.modelId, StringUtils.EMPTY));
155+
}
156+
out.writeOptionalFloat(maxTokenScore);
157+
if (!Objects.isNull(this.queryTokensSupplier) && !Objects.isNull(this.queryTokensSupplier.get())) {
158+
out.writeBoolean(true);
159+
out.writeMap(this.queryTokensSupplier.get(), StreamOutput::writeString, StreamOutput::writeFloat);
160+
} else {
161+
out.writeBoolean(false);
162+
}
163+
if (isClusterOnOrAfterMinReqVersionForAnalyzer()) {
164+
out.writeOptionalString(this.analyzer);
165+
this.neuralSparseQueryTwoPhaseInfo.writeTo(out);
166+
}
167+
}
168+
145169
/**
146170
* Copy this QueryBuilder for two phase rescorer, set the copy one's twoPhasePruneRatio to -1.
147171
* @param pruneRatio the parameter of the NeuralSparseTwoPhaseProcessor, control how to split the queryTokens to two phase.
148172
* @return A copy NeuralSparseQueryBuilder for twoPhase, it will be added to the rescorer.
149173
*/
150174
public NeuralSparseQueryBuilder getCopyNeuralSparseQueryBuilderForTwoPhase(float pruneRatio, PruneType pruneType) {
151-
this.twoPhasePruneRatio(pruneRatio);
152-
this.twoPhasePruneType(pruneType);
175+
this.neuralSparseQueryTwoPhaseInfo = new NeuralSparseQueryTwoPhaseInfo(
176+
NeuralSparseQueryTwoPhaseInfo.TwoPhaseStatus.PARENT,
177+
pruneRatio,
178+
pruneType
179+
);
153180
NeuralSparseQueryBuilder copy = new NeuralSparseQueryBuilder().fieldName(this.fieldName)
154181
.queryName(this.queryName)
155182
.queryText(this.queryText)
156183
.modelId(this.modelId)
157184
.analyzer(this.analyzer)
158185
.maxTokenScore(this.maxTokenScore)
159-
.twoPhasePruneRatio(-1f * pruneRatio);
186+
.neuralSparseQueryTwoPhaseInfo(
187+
new NeuralSparseQueryTwoPhaseInfo(NeuralSparseQueryTwoPhaseInfo.TwoPhaseStatus.CHILD, pruneRatio * -1f, pruneType)
188+
);
160189
if (Objects.nonNull(this.queryTokensSupplier)) {
161190
Map<String, Float> tokens = queryTokensSupplier.get();
162191
// Splitting tokens based on a threshold value: tokens greater than the threshold are stored in v1,
@@ -171,29 +200,6 @@ public NeuralSparseQueryBuilder getCopyNeuralSparseQueryBuilderForTwoPhase(float
171200
return copy;
172201
}
173202

174-
@Override
175-
protected void doWriteTo(StreamOutput out) throws IOException {
176-
out.writeString(this.fieldName);
177-
// to be backward compatible with previous version, we need to use writeString/readString API instead of optionalString API
178-
// after supporting query by tokens, queryText and modelId can be null. here we write an empty String instead
179-
out.writeString(StringUtils.defaultString(this.queryText, StringUtils.EMPTY));
180-
if (isClusterOnOrAfterMinReqVersionForDefaultModelIdSupport()) {
181-
out.writeOptionalString(this.modelId);
182-
} else {
183-
out.writeString(StringUtils.defaultString(this.modelId, StringUtils.EMPTY));
184-
}
185-
if (isClusterOnOrAfterMinReqVersionForAnalyzer()) {
186-
out.writeOptionalString(this.analyzer);
187-
}
188-
out.writeOptionalFloat(maxTokenScore);
189-
if (!Objects.isNull(this.queryTokensSupplier) && !Objects.isNull(this.queryTokensSupplier.get())) {
190-
out.writeBoolean(true);
191-
out.writeMap(this.queryTokensSupplier.get(), StreamOutput::writeString, StreamOutput::writeFloat);
192-
} else {
193-
out.writeBoolean(false);
194-
}
195-
}
196-
197203
@Override
198204
protected void doXContent(XContentBuilder xContentBuilder, Params params) throws IOException {
199205
xContentBuilder.startObject(NAME);
@@ -362,7 +368,7 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) {
362368
.maxTokenScore(maxTokenScore)
363369
.queryTokensSupplier(queryTokensSetOnce::get)
364370
.twoPhaseSharedQueryToken(twoPhaseSharedQueryToken)
365-
.twoPhasePruneRatio(twoPhasePruneRatio);
371+
.neuralSparseQueryTwoPhaseInfo(neuralSparseQueryTwoPhaseInfo);
366372
}
367373

368374
private BiConsumer<Client, ActionListener<?>> getModelInferenceAsync(SetOnce<Map<String, Float>> setOnce) {
@@ -377,8 +383,8 @@ private BiConsumer<Client, ActionListener<?>> getModelInferenceAsync(SetOnce<Map
377383
Map<String, Float> queryTokens = TokenWeightUtil.fetchListOfTokenWeightMap(mapResultList).get(0);
378384
if (Objects.nonNull(twoPhaseSharedQueryToken)) {
379385
Tuple<Map<String, Float>, Map<String, Float>> splitQueryTokens = PruneUtils.splitSparseVector(
380-
twoPhasePruneType,
381-
twoPhasePruneRatio,
386+
neuralSparseQueryTwoPhaseInfo.getTwoPhasePruneType(),
387+
neuralSparseQueryTwoPhaseInfo.getTwoPhasePruneRatio(),
382388
queryTokens
383389
);
384390
setOnce.set(splitQueryTokens.v1());
@@ -404,7 +410,7 @@ private Map<String, Float> getQueryTokens(QueryShardContext context) {
404410

405411
while (stream.incrementToken()) {
406412
String token = term.toString();
407-
Float weight = Objects.isNull(payload.getPayload()) ? 1 : HFModelTokenizer.bytesToFloat(payload.getPayload().bytes);
413+
Float weight = Objects.isNull(payload.getPayload()) ? 1.0f : HFModelTokenizer.bytesToFloat(payload.getPayload().bytes);
408414
queryTokens.put(token, weight);
409415
}
410416
stream.end();
@@ -466,7 +472,9 @@ protected boolean doEquals(NeuralSparseQueryBuilder obj) {
466472
.append(queryText, obj.queryText)
467473
.append(modelId, obj.modelId)
468474
.append(maxTokenScore, obj.maxTokenScore)
469-
.append(twoPhasePruneRatio, obj.twoPhasePruneRatio)
475+
.append(neuralSparseQueryTwoPhaseInfo.getTwoPhasePruneType(), obj.neuralSparseQueryTwoPhaseInfo.getTwoPhasePruneType())
476+
.append(neuralSparseQueryTwoPhaseInfo.getTwoPhasePruneRatio(), obj.neuralSparseQueryTwoPhaseInfo.getTwoPhasePruneRatio())
477+
.append(neuralSparseQueryTwoPhaseInfo.getStatus().getValue(), obj.neuralSparseQueryTwoPhaseInfo.getStatus().getValue())
470478
.append(twoPhaseSharedQueryToken, obj.twoPhaseSharedQueryToken);
471479
if (Objects.nonNull(queryTokensSupplier)) {
472480
equalsBuilder.append(queryTokensSupplier.get(), obj.queryTokensSupplier.get());
@@ -480,7 +488,9 @@ protected int doHashCode() {
480488
.append(queryText)
481489
.append(modelId)
482490
.append(maxTokenScore)
483-
.append(twoPhasePruneRatio)
491+
.append(neuralSparseQueryTwoPhaseInfo.getTwoPhasePruneType())
492+
.append(neuralSparseQueryTwoPhaseInfo.getTwoPhasePruneRatio())
493+
.append(neuralSparseQueryTwoPhaseInfo.getStatus().getValue())
484494
.append(twoPhaseSharedQueryToken);
485495
if (Objects.nonNull(queryTokensSupplier)) {
486496
builder.append(queryTokensSupplier.get());
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
package org.opensearch.neuralsearch.query;
6+
7+
import lombok.Getter;
8+
import lombok.Setter;
9+
import org.opensearch.core.common.io.stream.StreamInput;
10+
import org.opensearch.core.common.io.stream.StreamOutput;
11+
import org.opensearch.core.common.io.stream.Writeable;
12+
import org.opensearch.neuralsearch.util.prune.PruneType;
13+
14+
import java.io.IOException;
15+
import java.util.Arrays;
16+
import java.util.Locale;
17+
import java.util.Map;
18+
import java.util.function.Function;
19+
import java.util.stream.Collectors;
20+
21+
@Getter
22+
@Setter
23+
public class NeuralSparseQueryTwoPhaseInfo implements Writeable {
24+
private TwoPhaseStatus status = TwoPhaseStatus.NOT_ENABLED;
25+
private float twoPhasePruneRatio = 0F;
26+
private PruneType twoPhasePruneType = PruneType.NONE;
27+
28+
NeuralSparseQueryTwoPhaseInfo() {}
29+
30+
NeuralSparseQueryTwoPhaseInfo(TwoPhaseStatus status, float twoPhasePruneRatio, PruneType twoPhasePruneType) {
31+
this.status = status;
32+
this.twoPhasePruneRatio = twoPhasePruneRatio;
33+
this.twoPhasePruneType = twoPhasePruneType;
34+
}
35+
36+
NeuralSparseQueryTwoPhaseInfo(StreamInput in) throws IOException {
37+
this.status = TwoPhaseStatus.fromInt(in.readInt());
38+
this.twoPhasePruneRatio = in.readFloat();
39+
this.twoPhasePruneType = PruneType.fromString(in.readString());
40+
}
41+
42+
@Override
43+
public void writeTo(StreamOutput out) throws IOException {
44+
out.writeInt(status.getValue());
45+
out.writeFloat(twoPhasePruneRatio);
46+
out.writeString(twoPhasePruneType.getValue());
47+
}
48+
49+
public enum TwoPhaseStatus {
50+
NOT_ENABLED(0),
51+
PARENT(1),
52+
CHILD(2);
53+
54+
private static final Map<Integer, TwoPhaseStatus> VALUE_MAP = Arrays.stream(values())
55+
.collect(Collectors.toUnmodifiableMap(status -> status.value, Function.identity()));
56+
private final int value;
57+
58+
TwoPhaseStatus(int value) {
59+
this.value = value;
60+
}
61+
62+
public int getValue() {
63+
return value;
64+
}
65+
66+
public static TwoPhaseStatus fromInt(final int value) {
67+
TwoPhaseStatus status = VALUE_MAP.get(value);
68+
if (status == null) {
69+
throw new IllegalArgumentException(String.format(Locale.ROOT, "Invalid two phase status value: %d", value));
70+
}
71+
return status;
72+
}
73+
}
74+
}

src/main/java/org/opensearch/neuralsearch/util/prune/PruneType.java

+10-5
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,11 @@
66

77
import org.apache.commons.lang.StringUtils;
88

9+
import java.util.Arrays;
910
import java.util.Locale;
11+
import java.util.Map;
12+
import java.util.function.Function;
13+
import java.util.stream.Collectors;
1014

1115
/**
1216
* Enum representing different types of prune methods for sparse vectors
@@ -19,6 +23,8 @@ public enum PruneType {
1923
ABS_VALUE("abs_value");
2024

2125
private final String value;
26+
private static final Map<String, PruneType> VALUE_MAP = Arrays.stream(values())
27+
.collect(Collectors.toUnmodifiableMap(status -> status.value, Function.identity()));
2228

2329
PruneType(String value) {
2430
this.value = value;
@@ -37,11 +43,10 @@ public String getValue() {
3743
*/
3844
public static PruneType fromString(final String value) {
3945
if (StringUtils.isEmpty(value)) return NONE;
40-
for (PruneType type : PruneType.values()) {
41-
if (type.value.equals(value)) {
42-
return type;
43-
}
46+
PruneType type = VALUE_MAP.get(value);
47+
if (type == null) {
48+
throw new IllegalArgumentException(String.format(Locale.ROOT, "Unknown prune type: %s", value));
4449
}
45-
throw new IllegalArgumentException(String.format(Locale.ROOT, "Unknown prune type: %s", value));
50+
return type;
4651
}
4752
}

src/test/java/org/opensearch/neuralsearch/processor/NeuralSparseTwoPhaseProcessorTests.java

+6-6
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ public void testProcessRequest_whenTwoPhaseEnabled_thenSuccess() throws Exceptio
8282
NeuralSparseTwoPhaseProcessor processor = createTestProcessor(factory, 0.5f, true, 4.0f, 10000);
8383
processor.processRequest(searchRequest);
8484
NeuralSparseQueryBuilder queryBuilder = (NeuralSparseQueryBuilder) searchRequest.source().query();
85-
assertEquals(queryBuilder.twoPhasePruneRatio(), 0.5f, 1e-3);
85+
assertEquals(queryBuilder.neuralSparseQueryTwoPhaseInfo().getTwoPhasePruneRatio(), 0.5f, 1e-3);
8686
assertNotNull(searchRequest.source().rescores());
8787
}
8888

@@ -94,8 +94,8 @@ public void testProcessRequest_whenUseCustomPruneType_thenSuccess() throws Excep
9494
NeuralSparseTwoPhaseProcessor processor = createTestProcessor(factory, 0.5f, "alpha_mass", true, 4.0f, 10000);
9595
processor.processRequest(searchRequest);
9696
NeuralSparseQueryBuilder queryBuilder = (NeuralSparseQueryBuilder) searchRequest.source().query();
97-
assertEquals(queryBuilder.twoPhasePruneRatio(), 0.5f, 1e-3);
98-
assertEquals(queryBuilder.twoPhasePruneType(), PruneType.ALPHA_MASS);
97+
assertEquals(queryBuilder.neuralSparseQueryTwoPhaseInfo().getTwoPhasePruneRatio(), 0.5f, 1e-3);
98+
assertEquals(queryBuilder.neuralSparseQueryTwoPhaseInfo().getTwoPhasePruneType(), PruneType.ALPHA_MASS);
9999
assertNotNull(searchRequest.source().rescores());
100100
}
101101

@@ -110,7 +110,7 @@ public void testProcessRequest_whenTwoPhaseEnabledAndNestedBoolean_thenSuccess()
110110
processor.processRequest(searchRequest);
111111
BoolQueryBuilder queryBuilder = (BoolQueryBuilder) searchRequest.source().query();
112112
NeuralSparseQueryBuilder neuralSparseQueryBuilder = (NeuralSparseQueryBuilder) queryBuilder.should().get(0);
113-
assertEquals(neuralSparseQueryBuilder.twoPhasePruneRatio(), 0.5f, 1e-3);
113+
assertEquals(neuralSparseQueryBuilder.neuralSparseQueryTwoPhaseInfo().getTwoPhasePruneRatio(), 0.5f, 1e-3);
114114
assertNotNull(searchRequest.source().rescores());
115115
}
116116

@@ -125,7 +125,7 @@ public void testProcessRequestWithRescorer_whenTwoPhaseEnabled_thenSuccess() thr
125125
NeuralSparseTwoPhaseProcessor processor = createTestProcessor(factory, 0.5f, true, 4.0f, 10000);
126126
processor.processRequest(searchRequest);
127127
NeuralSparseQueryBuilder queryBuilder = (NeuralSparseQueryBuilder) searchRequest.source().query();
128-
assertEquals(queryBuilder.twoPhasePruneRatio(), 0.5f, 1e-3);
128+
assertEquals(queryBuilder.neuralSparseQueryTwoPhaseInfo().getTwoPhasePruneRatio(), 0.5f, 1e-3);
129129
assertNotNull(searchRequest.source().rescores());
130130
}
131131

@@ -137,7 +137,7 @@ public void testProcessRequest_whenTwoPhaseDisabled_thenSuccess() throws Excepti
137137
NeuralSparseTwoPhaseProcessor processor = createTestProcessor(factory, 0.5f, false, 4.0f, 10000);
138138
processor.processRequest(searchRequest);
139139
NeuralSparseQueryBuilder queryBuilder = (NeuralSparseQueryBuilder) searchRequest.source().query();
140-
assertEquals(queryBuilder.twoPhasePruneRatio(), 0f, 1e-3);
140+
assertEquals(queryBuilder.neuralSparseQueryTwoPhaseInfo().getTwoPhasePruneRatio(), 0f, 1e-3);
141141
assertNull(searchRequest.source().rescores());
142142
}
143143

src/test/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilderTests.java

+3-2
Original file line numberDiff line numberDiff line change
@@ -659,8 +659,9 @@ public void testRewrite_whenQueryTokensSupplierNull_andPruneSet_thenSuceessPrune
659659
.queryText(QUERY_TEXT)
660660
.modelId(MODEL_ID)
661661
.twoPhaseSharedQueryToken(Map.of())
662-
.twoPhasePruneRatio(3.0f)
663-
.twoPhasePruneType(PruneType.ABS_VALUE);
662+
.neuralSparseQueryTwoPhaseInfo(
663+
new NeuralSparseQueryTwoPhaseInfo(NeuralSparseQueryTwoPhaseInfo.TwoPhaseStatus.PARENT, 3f, PruneType.ABS_VALUE)
664+
);
664665
Map<String, Float> expectedMap = Map.of("1", 1f, "2", 5f);
665666
MLCommonsClientAccessor mlCommonsClientAccessor = mock(MLCommonsClientAccessor.class);
666667
doAnswer(invocation -> {

0 commit comments

Comments
 (0)