Skip to content

Commit c3aac71

Browse files
authored
Support analyzer-based neural sparse query (#1088)
* merge main; add analyzer impl Signed-off-by: zhichao-aws <[email protected]> * two phase adaption Signed-off-by: zhichao-aws <[email protected]> * two phase adaption Signed-off-by: zhichao-aws <[email protected]> * remove analysis Signed-off-by: zhichao-aws <[email protected]> * lint Signed-off-by: zhichao-aws <[email protected]> * update Signed-off-by: zhichao-aws <[email protected]> * address comments Signed-off-by: zhichao-aws <[email protected]> * tests Signed-off-by: zhichao-aws <[email protected]> * modify plugin security policy Signed-off-by: zhichao-aws <[email protected]> * change log Signed-off-by: zhichao-aws <[email protected]> * address comments Signed-off-by: zhichao-aws <[email protected]> * modify to package-private Signed-off-by: zhichao-aws <[email protected]> --------- Signed-off-by: zhichao-aws <[email protected]>
1 parent 39b4892 commit c3aac71

File tree

9 files changed

+592
-65
lines changed

9 files changed

+592
-65
lines changed

CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
66
## [Unreleased 3.x](https://github.com/opensearch-project/neural-search/compare/main...HEAD)
77

88
### Features
9+
- Implement analyzer based neural sparse query ([#1088](https://github.com/opensearch-project/neural-search/pull/1088))
910
- [Semantic Field] Add semantic field mapper. ([#1225](https://github.com/opensearch-project/neural-search/pull/1225)).
1011

1112
### Enhancements

src/main/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilder.java

+141-48
Large diffs are not rendered by default.

src/main/java/org/opensearch/neuralsearch/query/NeuralSparseQueryTwoPhaseInfo.java

+85
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
package org.opensearch.neuralsearch.query;
6+
7+
import lombok.Getter;
8+
import lombok.Setter;
9+
import org.opensearch.core.common.io.stream.StreamInput;
10+
import org.opensearch.core.common.io.stream.StreamOutput;
11+
import org.opensearch.core.common.io.stream.Writeable;
12+
import org.opensearch.neuralsearch.util.prune.PruneType;
13+
14+
import java.io.IOException;
15+
import java.util.Arrays;
16+
import java.util.Locale;
17+
import java.util.Map;
18+
import java.util.function.Function;
19+
import java.util.stream.Collectors;
20+
21+
/**
22+
* This class encapsulates information related to the two-phase execution process
23+
* for a neural sparse query. It tracks the current processing phase, the ratio
24+
* used for pruning during the two-phase process, and the type of pruning applied.
25+
*/
26+
@Getter
27+
@Setter
28+
public class NeuralSparseQueryTwoPhaseInfo implements Writeable {
29+
private TwoPhaseStatus status = TwoPhaseStatus.NOT_ENABLED;
30+
private float twoPhasePruneRatio = 0F;
31+
private PruneType twoPhasePruneType = PruneType.NONE;
32+
33+
NeuralSparseQueryTwoPhaseInfo() {}
34+
35+
NeuralSparseQueryTwoPhaseInfo(TwoPhaseStatus status, float twoPhasePruneRatio, PruneType twoPhasePruneType) {
36+
this.status = status;
37+
this.twoPhasePruneRatio = twoPhasePruneRatio;
38+
this.twoPhasePruneType = twoPhasePruneType;
39+
}
40+
41+
NeuralSparseQueryTwoPhaseInfo(StreamInput in) throws IOException {
42+
this.status = TwoPhaseStatus.fromInt(in.readInt());
43+
this.twoPhasePruneRatio = in.readFloat();
44+
this.twoPhasePruneType = PruneType.fromString(in.readString());
45+
}
46+
47+
@Override
48+
public void writeTo(StreamOutput out) throws IOException {
49+
out.writeInt(status.getValue());
50+
out.writeFloat(twoPhasePruneRatio);
51+
out.writeString(twoPhasePruneType.getValue());
52+
}
53+
54+
public enum TwoPhaseStatus {
55+
NOT_ENABLED(0),
56+
PHASE_ONE(1),
57+
PHASE_TWO(2);
58+
59+
private static final Map<Integer, TwoPhaseStatus> VALUE_MAP = Arrays.stream(values())
60+
.collect(Collectors.toUnmodifiableMap(status -> status.value, Function.identity()));
61+
private final int value;
62+
63+
TwoPhaseStatus(int value) {
64+
this.value = value;
65+
}
66+
67+
public int getValue() {
68+
return value;
69+
}
70+
71+
/**
72+
* Converts an integer value to the corresponding TwoPhaseStatus.
73+
* @param value the integer value to convert
74+
* @return the corresponding TwoPhaseStatus enum constant
75+
* @throws IllegalArgumentException if the value does not correspond to any known status
76+
*/
77+
public static TwoPhaseStatus fromInt(final int value) {
78+
TwoPhaseStatus status = VALUE_MAP.get(value);
79+
if (status == null) {
80+
throw new IllegalArgumentException(String.format(Locale.ROOT, "Invalid two phase status value: %d", value));
81+
}
82+
return status;
83+
}
84+
}
85+
}

src/main/java/org/opensearch/neuralsearch/util/prune/PruneType.java

+10-5
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,11 @@
66

77
import org.apache.commons.lang.StringUtils;
88

9+
import java.util.Arrays;
910
import java.util.Locale;
11+
import java.util.Map;
12+
import java.util.function.Function;
13+
import java.util.stream.Collectors;
1014

1115
/**
1216
* Enum representing different types of prune methods for sparse vectors
@@ -19,6 +23,8 @@ public enum PruneType {
1923
ABS_VALUE("abs_value");
2024

2125
private final String value;
26+
private static final Map<String, PruneType> VALUE_MAP = Arrays.stream(values())
27+
.collect(Collectors.toUnmodifiableMap(status -> status.value, Function.identity()));
2228

2329
PruneType(String value) {
2430
this.value = value;
@@ -37,11 +43,10 @@ public String getValue() {
3743
*/
3844
public static PruneType fromString(final String value) {
3945
if (StringUtils.isEmpty(value)) return NONE;
40-
for (PruneType type : PruneType.values()) {
41-
if (type.value.equals(value)) {
42-
return type;
43-
}
46+
PruneType type = VALUE_MAP.get(value);
47+
if (type == null) {
48+
throw new IllegalArgumentException(String.format(Locale.ROOT, "Unknown prune type: %s", value));
4449
}
45-
throw new IllegalArgumentException(String.format(Locale.ROOT, "Unknown prune type: %s", value));
50+
return type;
4651
}
4752
}

src/test/java/org/opensearch/neuralsearch/plugin/NeuralSearchTests.java

+6
Original file line numberDiff line numberDiff line change
@@ -29,12 +29,14 @@
2929
import org.opensearch.neuralsearch.processor.NeuralSparseTwoPhaseProcessor;
3030
import org.opensearch.neuralsearch.processor.NormalizationProcessor;
3131
import org.opensearch.neuralsearch.processor.RRFProcessor;
32+
import org.opensearch.neuralsearch.processor.SparseEncodingProcessor;
3233
import org.opensearch.neuralsearch.processor.TextEmbeddingProcessor;
3334
import org.opensearch.neuralsearch.processor.factory.NormalizationProcessorFactory;
3435
import org.opensearch.neuralsearch.processor.factory.RRFProcessorFactory;
3536
import org.opensearch.neuralsearch.processor.rerank.RerankProcessor;
3637
import org.opensearch.neuralsearch.query.HybridQueryBuilder;
3738
import org.opensearch.neuralsearch.query.NeuralQueryBuilder;
39+
import org.opensearch.neuralsearch.query.NeuralSparseQueryBuilder;
3840
import org.opensearch.neuralsearch.query.OpenSearchQueryTestCase;
3941
import org.opensearch.neuralsearch.settings.NeuralSearchSettings;
4042
import org.opensearch.plugins.SearchPipelinePlugin;
@@ -63,6 +65,8 @@ public class NeuralSearchTests extends OpenSearchQueryTestCase {
6365
private ClusterService clusterService;
6466
@Mock
6567
private ThreadPool threadPool;
68+
@Mock
69+
private Environment environment;
6670

6771
@Before
6872
public void setup() {
@@ -112,6 +116,7 @@ public void testQuerySpecs() {
112116
assertFalse(querySpecs.isEmpty());
113117
assertTrue(querySpecs.stream().anyMatch(spec -> NeuralQueryBuilder.NAME.equals(spec.getName().getPreferredName())));
114118
assertTrue(querySpecs.stream().anyMatch(spec -> HybridQueryBuilder.NAME.equals(spec.getName().getPreferredName())));
119+
assertTrue(querySpecs.stream().anyMatch(spec -> NeuralSparseQueryBuilder.NAME.equals(spec.getName().getPreferredName())));
115120
}
116121

117122
public void testProcessors() {
@@ -133,6 +138,7 @@ public void testProcessors() {
133138
Map<String, Processor.Factory> processors = plugin.getProcessors(processorParams);
134139
assertNotNull(processors);
135140
assertNotNull(processors.get(TextEmbeddingProcessor.TYPE));
141+
assertNotNull(processors.get(SparseEncodingProcessor.TYPE));
136142
}
137143

138144
public void testSearchPhaseResultsProcessors() {

src/test/java/org/opensearch/neuralsearch/processor/NeuralSparseTwoPhaseProcessorTests.java

+6-6
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ public void testProcessRequest_whenTwoPhaseEnabled_thenSuccess() throws Exceptio
8282
NeuralSparseTwoPhaseProcessor processor = createTestProcessor(factory, 0.5f, true, 4.0f, 10000);
8383
processor.processRequest(searchRequest);
8484
NeuralSparseQueryBuilder queryBuilder = (NeuralSparseQueryBuilder) searchRequest.source().query();
85-
assertEquals(queryBuilder.twoPhasePruneRatio(), 0.5f, 1e-3);
85+
assertEquals(queryBuilder.neuralSparseQueryTwoPhaseInfo().getTwoPhasePruneRatio(), 0.5f, 1e-3);
8686
assertNotNull(searchRequest.source().rescores());
8787
}
8888

@@ -94,8 +94,8 @@ public void testProcessRequest_whenUseCustomPruneType_thenSuccess() throws Excep
9494
NeuralSparseTwoPhaseProcessor processor = createTestProcessor(factory, 0.5f, "alpha_mass", true, 4.0f, 10000);
9595
processor.processRequest(searchRequest);
9696
NeuralSparseQueryBuilder queryBuilder = (NeuralSparseQueryBuilder) searchRequest.source().query();
97-
assertEquals(queryBuilder.twoPhasePruneRatio(), 0.5f, 1e-3);
98-
assertEquals(queryBuilder.twoPhasePruneType(), PruneType.ALPHA_MASS);
97+
assertEquals(queryBuilder.neuralSparseQueryTwoPhaseInfo().getTwoPhasePruneRatio(), 0.5f, 1e-3);
98+
assertEquals(queryBuilder.neuralSparseQueryTwoPhaseInfo().getTwoPhasePruneType(), PruneType.ALPHA_MASS);
9999
assertNotNull(searchRequest.source().rescores());
100100
}
101101

@@ -110,7 +110,7 @@ public void testProcessRequest_whenTwoPhaseEnabledAndNestedBoolean_thenSuccess()
110110
processor.processRequest(searchRequest);
111111
BoolQueryBuilder queryBuilder = (BoolQueryBuilder) searchRequest.source().query();
112112
NeuralSparseQueryBuilder neuralSparseQueryBuilder = (NeuralSparseQueryBuilder) queryBuilder.should().get(0);
113-
assertEquals(neuralSparseQueryBuilder.twoPhasePruneRatio(), 0.5f, 1e-3);
113+
assertEquals(neuralSparseQueryBuilder.neuralSparseQueryTwoPhaseInfo().getTwoPhasePruneRatio(), 0.5f, 1e-3);
114114
assertNotNull(searchRequest.source().rescores());
115115
}
116116

@@ -125,7 +125,7 @@ public void testProcessRequestWithRescorer_whenTwoPhaseEnabled_thenSuccess() thr
125125
NeuralSparseTwoPhaseProcessor processor = createTestProcessor(factory, 0.5f, true, 4.0f, 10000);
126126
processor.processRequest(searchRequest);
127127
NeuralSparseQueryBuilder queryBuilder = (NeuralSparseQueryBuilder) searchRequest.source().query();
128-
assertEquals(queryBuilder.twoPhasePruneRatio(), 0.5f, 1e-3);
128+
assertEquals(queryBuilder.neuralSparseQueryTwoPhaseInfo().getTwoPhasePruneRatio(), 0.5f, 1e-3);
129129
assertNotNull(searchRequest.source().rescores());
130130
}
131131

@@ -137,7 +137,7 @@ public void testProcessRequest_whenTwoPhaseDisabled_thenSuccess() throws Excepti
137137
NeuralSparseTwoPhaseProcessor processor = createTestProcessor(factory, 0.5f, false, 4.0f, 10000);
138138
processor.processRequest(searchRequest);
139139
NeuralSparseQueryBuilder queryBuilder = (NeuralSparseQueryBuilder) searchRequest.source().query();
140-
assertEquals(queryBuilder.twoPhasePruneRatio(), 0f, 1e-3);
140+
assertEquals(queryBuilder.neuralSparseQueryTwoPhaseInfo().getTwoPhasePruneRatio(), 0f, 1e-3);
141141
assertNull(searchRequest.source().rescores());
142142
}
143143

0 commit comments

Comments
 (0)