Commit e7d4a28
Support configurable chunking in semantic_text fields (#121041)
* test
* Revert "test"
  This reverts commit 9f4e2ad.
* Refactor InferenceService to allow passing in chunking settings
* Add chunking config to inference field metadata and store in semantic_text field
* Fix test compilation errors
* Hacking around trying to get ingest to work
* Debugging
* [CI] Auto commit changes from spotless
* POC works and update TODO to fix this
* [CI] Auto commit changes from spotless
* Refactor chunking settings from model settings to field inference request
* A bit of cleanup
* Revert a bunch of changes to try to narrow down what broke CI
* test
* Revert "test"
  This reverts commit 9f4e2ad.
* Fix InferenceFieldMetadataTest
* [CI] Auto commit changes from spotless
* Add chunking settings back in
* Update builder to use new map
* Fix compilation errors after merge
* Debugging tests
* debugging
* Cleanup
* Add yaml test
* Update tests
* Add chunking to test inference service
* Trying to get tests to work
* Shard bulk inference test never specifies chunking settings
* Fix test
* Always process batches in order
* Fix chunking in test inference service and yaml tests
* [CI] Auto commit changes from spotless
* Refactor - remove convenience method with default chunking settings
* Fix ShardBulkInferenceActionFilterTests
* Fix ElasticsearchInternalServiceTests
* Fix SemanticTextFieldMapperTests
* [CI] Auto commit changes from spotless
* Fix test data to fit within bounds
* Add additional yaml test cases
* Playing with xcontent parsing
* A little cleanup
* Update docs/changelog/121041.yaml
* Fix failures introduced by merge
* [CI] Auto commit changes from spotless
* Address PR feedback
* [CI] Auto commit changes from spotless
* Fix predicate in updated test
* Better handling of null/empty ChunkingSettings
* Update parsing settings
* Fix errors post merge
* PR feedback
* [CI] Auto commit changes from spotless
* PR feedback and fix XContent parsing for SemanticTextField
* Remove chunking settings check to use what's passed in from sender service
* Fix some tests
* Cleanup
* Test failure whack-a-mole
* Cleanup
* Refactor to handle memory optimized bulk shard inference actions - this is ugly but at least it compiles
* [CI] Auto commit changes from spotless
* Minor cleanup
* A bit more cleanup
* Spotless
* Revert change
* Update chunking setting update logic
* Go back to serializing maps
* Revert change to model settings - source still errors on missing model_id
* Fix updating chunking settings
* Look up model if null
* Fix test
* Work around #125723 in semantic text field serialization
* Add BWC tests
* Add chunking_settings to docs
* Refactor/rename
* Address minor PR feedback
* Add test case for null update
* PR feedback - adjust refactor of chunked inputs
* Refactored AbstractTestInferenceService to return offsets instead of just Strings
* [CI] Auto commit changes from spotless
* Fix tests where chunk output was of size 3
* Update mappings per PR feedback
* PR Feedback
* Fix problems related to merge
* PR optimization
* Fix test
* Delete extra file

---------

Co-authored-by: elasticsearchmachine <[email protected]>
1 parent bcdf51a commit e7d4a28

File tree

108 files changed: +2193 -430 lines


docs/changelog/121041.yaml (+5)

@@ -0,0 +1,5 @@
+pr: 121041
+summary: Support configurable chunking in `semantic_text` fields
+area: Relevance
+type: enhancement
+issues: []

docs/reference/elasticsearch/mapping-reference/semantic-text.md (+119 -41)

(Large diff not rendered by default.)
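Since the docs diff itself is hidden, here is a hedged sketch of the mapping shape the new option enables, built with XContentBuilder (assumes the elasticsearch artifacts on the classpath; the field and endpoint names are illustrative, while the chunking_settings keys mirror the word_boundary settings exercised in this commit's tests):

import org.elasticsearch.common.Strings;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentFactory;

// Builds a semantic_text mapping carrying an explicit chunking_settings block.
public class SemanticTextMappingSketch {
    public static void main(String[] args) throws Exception {
        XContentBuilder mapping = XContentFactory.jsonBuilder()
            .startObject()
                .startObject("properties")
                    .startObject("content")                          // hypothetical field name
                        .field("type", "semantic_text")
                        .field("inference_id", "my-inference-endpoint") // hypothetical endpoint
                        .startObject("chunking_settings")             // new in this commit
                            .field("strategy", "word_boundary")
                            .field("max_chunk_size", 100)
                            .field("overlap", 10)
                        .endObject()
                    .endObject()
                .endObject()
            .endObject();
        System.out.println(Strings.toString(mapping));
    }
}

Omitting chunking_settings keeps today's behavior: the field falls back to the chunking configured on the inference endpoint's model.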

server/src/main/java/org/elasticsearch/TransportVersions.java (+1)

@@ -213,6 +213,7 @@ static TransportVersion def(int id) {
     public static final TransportVersion REMOTE_EXCEPTION = def(9_044_0_00);
     public static final TransportVersion ESQL_REMOVE_AGGREGATE_TYPE = def(9_045_0_00);
     public static final TransportVersion ADD_PROJECT_ID_TO_DSL_ERROR_INFO = def(9_046_0_00);
+    public static final TransportVersion SEMANTIC_TEXT_CHUNKING_CONFIG = def(9_047_0_00);

     /*
      * STOP! READ THIS FIRST! No, really,

server/src/main/java/org/elasticsearch/cluster/metadata/InferenceFieldMetadata.java (+46 -6)

@@ -12,6 +12,7 @@
 import org.elasticsearch.TransportVersions;
 import org.elasticsearch.cluster.Diff;
 import org.elasticsearch.cluster.SimpleDiffable;
+import org.elasticsearch.common.Strings;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.xcontent.ToXContentFragment;
@@ -22,8 +23,11 @@
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.List;
+import java.util.Map;
 import java.util.Objects;

+import static org.elasticsearch.TransportVersions.SEMANTIC_TEXT_CHUNKING_CONFIG;
+
 /**
  * Contains inference field data for fields.
  * As inference is done in the coordinator node to avoid re-doing it at shard / replica level, the coordinator needs to check for the need
@@ -35,21 +39,30 @@ public final class InferenceFieldMetadata implements SimpleDiffable<InferenceFie
     private static final String INFERENCE_ID_FIELD = "inference_id";
     private static final String SEARCH_INFERENCE_ID_FIELD = "search_inference_id";
     private static final String SOURCE_FIELDS_FIELD = "source_fields";
+    static final String CHUNKING_SETTINGS_FIELD = "chunking_settings";

     private final String name;
     private final String inferenceId;
     private final String searchInferenceId;
     private final String[] sourceFields;
+    private final Map<String, Object> chunkingSettings;

-    public InferenceFieldMetadata(String name, String inferenceId, String[] sourceFields) {
-        this(name, inferenceId, inferenceId, sourceFields);
+    public InferenceFieldMetadata(String name, String inferenceId, String[] sourceFields, Map<String, Object> chunkingSettings) {
+        this(name, inferenceId, inferenceId, sourceFields, chunkingSettings);
     }

-    public InferenceFieldMetadata(String name, String inferenceId, String searchInferenceId, String[] sourceFields) {
+    public InferenceFieldMetadata(
+        String name,
+        String inferenceId,
+        String searchInferenceId,
+        String[] sourceFields,
+        Map<String, Object> chunkingSettings
+    ) {
         this.name = Objects.requireNonNull(name);
         this.inferenceId = Objects.requireNonNull(inferenceId);
         this.searchInferenceId = Objects.requireNonNull(searchInferenceId);
         this.sourceFields = Objects.requireNonNull(sourceFields);
+        this.chunkingSettings = chunkingSettings != null ? Map.copyOf(chunkingSettings) : null;
     }

     public InferenceFieldMetadata(StreamInput input) throws IOException {
@@ -61,6 +74,11 @@ public InferenceFieldMetadata(StreamInput input) throws IOException {
             this.searchInferenceId = this.inferenceId;
         }
         this.sourceFields = input.readStringArray();
+        if (input.getTransportVersion().onOrAfter(SEMANTIC_TEXT_CHUNKING_CONFIG)) {
+            this.chunkingSettings = input.readGenericMap();
+        } else {
+            this.chunkingSettings = null;
+        }
     }

     @Override
@@ -71,6 +89,9 @@ public void writeTo(StreamOutput out) throws IOException {
             out.writeString(searchInferenceId);
         }
         out.writeStringArray(sourceFields);
+        if (out.getTransportVersion().onOrAfter(SEMANTIC_TEXT_CHUNKING_CONFIG)) {
+            out.writeGenericMap(chunkingSettings);
+        }
     }

     @Override
@@ -81,16 +102,22 @@ public boolean equals(Object o) {
         return Objects.equals(name, that.name)
             && Objects.equals(inferenceId, that.inferenceId)
             && Objects.equals(searchInferenceId, that.searchInferenceId)
-            && Arrays.equals(sourceFields, that.sourceFields);
+            && Arrays.equals(sourceFields, that.sourceFields)
+            && Objects.equals(chunkingSettings, that.chunkingSettings);
     }

     @Override
     public int hashCode() {
-        int result = Objects.hash(name, inferenceId, searchInferenceId);
+        int result = Objects.hash(name, inferenceId, searchInferenceId, chunkingSettings);
         result = 31 * result + Arrays.hashCode(sourceFields);
         return result;
     }

+    @Override
+    public String toString() {
+        return Strings.toString(this);
+    }
+
     public String getName() {
         return name;
     }
@@ -107,6 +134,10 @@ public String[] getSourceFields() {
         return sourceFields;
     }

+    public Map<String, Object> getChunkingSettings() {
+        return chunkingSettings;
+    }
+
     public static Diff<InferenceFieldMetadata> readDiffFrom(StreamInput in) throws IOException {
         return SimpleDiffable.readDiffFrom(InferenceFieldMetadata::new, in);
     }
@@ -119,6 +150,11 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
             builder.field(SEARCH_INFERENCE_ID_FIELD, searchInferenceId);
         }
         builder.array(SOURCE_FIELDS_FIELD, sourceFields);
+        if (chunkingSettings != null) {
+            builder.startObject(CHUNKING_SETTINGS_FIELD);
+            builder.mapContents(chunkingSettings);
+            builder.endObject();
+        }
         return builder.endObject();
     }

@@ -131,6 +167,7 @@ public static InferenceFieldMetadata fromXContent(XContentParser parser) throws
         String currentFieldName = null;
         String inferenceId = null;
         String searchInferenceId = null;
+        Map<String, Object> chunkingSettings = null;
         List<String> inputFields = new ArrayList<>();
         while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) {
             if (token == XContentParser.Token.FIELD_NAME) {
@@ -151,6 +188,8 @@ public static InferenceFieldMetadata fromXContent(XContentParser parser) throws
                     }
                 }
             }
+            } else if (CHUNKING_SETTINGS_FIELD.equals(currentFieldName)) {
+                chunkingSettings = parser.map();
             } else {
                 parser.skipChildren();
             }
@@ -159,7 +198,8 @@ public static InferenceFieldMetadata fromXContent(XContentParser parser) throws
             name,
             inferenceId,
             searchInferenceId == null ? inferenceId : searchInferenceId,
-            inputFields.toArray(String[]::new)
+            inputFields.toArray(String[]::new),
+            chunkingSettings
         );
     }
 }
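A hedged usage sketch, assuming the elasticsearch server artifact on the classpath: constructing the metadata with an explicit chunking map via the new four-argument constructor. The field and endpoint names are illustrative; the map keys match the word_boundary settings used in this commit's tests, and a null map would fall back to the model's defaults.

import java.util.Map;
import org.elasticsearch.cluster.metadata.InferenceFieldMetadata;

public class InferenceFieldMetadataExample {
    public static void main(String[] args) {
        Map<String, Object> chunking = Map.of(
            "strategy", "word_boundary",
            "max_chunk_size", 50,
            "overlap", 10
        );
        // Four-arg constructor: search_inference_id defaults to inference_id.
        InferenceFieldMetadata metadata = new InferenceFieldMetadata(
            "body_semantic",         // field name (hypothetical)
            "my-elser-endpoint",     // inference_id (hypothetical)
            new String[] { "body" }, // source fields feeding the inference
            chunking                 // null here would defer to the model's chunking
        );
        System.out.println(metadata.getChunkingSettings());
    }
}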
server/src/main/java/org/elasticsearch/inference/ChunkInferenceInput.java (new file, +25)

@@ -0,0 +1,25 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the "Elastic License
+ * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
+ * Public License v 1"; you may not use this file except in compliance with, at
+ * your election, the "Elastic License 2.0", the "GNU Affero General Public
+ * License v3.0 only", or the "Server Side Public License, v 1".
+ */
+
+package org.elasticsearch.inference;
+
+import org.elasticsearch.core.Nullable;
+
+import java.util.List;
+
+public record ChunkInferenceInput(String input, @Nullable ChunkingSettings chunkingSettings) {
+
+    public ChunkInferenceInput(String input) {
+        this(input, null);
+    }
+
+    public static List<String> inputs(List<ChunkInferenceInput> chunkInferenceInputs) {
+        return chunkInferenceInputs.stream().map(ChunkInferenceInput::input).toList();
+    }
+}
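A standalone sketch of how the new record is used (the ChunkingSettings interface is stubbed locally so the example runs on its own): each input string optionally carries its own chunking settings, and the static inputs() helper recovers the raw strings for services that do not need them.

import java.util.List;

public class ChunkInferenceInputDemo {
    interface ChunkingSettings {} // stand-in for org.elasticsearch.inference.ChunkingSettings

    // Mirrors the record added in this commit.
    record ChunkInferenceInput(String input, ChunkingSettings chunkingSettings) {
        ChunkInferenceInput(String input) {
            this(input, null);
        }

        static List<String> inputs(List<ChunkInferenceInput> in) {
            return in.stream().map(ChunkInferenceInput::input).toList();
        }
    }

    public static void main(String[] args) {
        var inputs = List.of(
            new ChunkInferenceInput("first long passage"),       // no settings: model defaults apply
            new ChunkInferenceInput("second long passage", null) // explicit null behaves the same
        );
        System.out.println(ChunkInferenceInput.inputs(inputs)); // [first long passage, second long passage]
    }
}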

server/src/main/java/org/elasticsearch/inference/ChunkingSettings.java (+4)

@@ -12,6 +12,10 @@
 import org.elasticsearch.common.io.stream.VersionedNamedWriteable;
 import org.elasticsearch.xcontent.ToXContentObject;

+import java.util.Map;
+
 public interface ChunkingSettings extends ToXContentObject, VersionedNamedWriteable {
     ChunkingStrategy getChunkingStrategy();
+
+    Map<String, Object> asMap();
 }
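The new asMap() hook lets a ChunkingSettings implementation hand back the same generic map that InferenceFieldMetadata stores and serializes. A minimal standalone sketch of that contract, with the interface stubbed locally (the real one also extends ToXContentObject and VersionedNamedWriteable) and key names taken from the word_boundary settings in this commit's tests; the record here is an illustration, not the production class:

import java.util.Map;

public class AsMapSketch {
    interface ChunkingSettings {
        Map<String, Object> asMap();
    }

    record WordBoundarySettings(int maxChunkSize, int overlap) implements ChunkingSettings {
        @Override
        public Map<String, Object> asMap() {
            // Same generic-map shape that ends up in the index metadata.
            return Map.of("strategy", "word_boundary", "max_chunk_size", maxChunkSize, "overlap", overlap);
        }
    }

    public static void main(String[] args) {
        System.out.println(new WordBoundarySettings(50, 10).asMap());
    }
}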

server/src/main/java/org/elasticsearch/inference/InferenceService.java (+8 -8)

@@ -133,18 +133,18 @@ void unifiedCompletionInfer(
     /**
      * Chunk long text.
      *
-     * @param model The model
-     * @param query Inference query, mainly for re-ranking
-     * @param input Inference input
-     * @param taskSettings Settings in the request to override the model's defaults
-     * @param inputType For search, ingest etc
-     * @param timeout The timeout for the request
-     * @param listener Chunked Inference result listener
+     * @param model        The model
+     * @param query        Inference query, mainly for re-ranking
+     * @param input        Inference input
+     * @param taskSettings Settings in the request to override the model's defaults
+     * @param inputType    For search, ingest etc
+     * @param timeout      The timeout for the request
+     * @param listener     Chunked Inference result listener
     */
    void chunkedInfer(
        Model model,
        @Nullable String query,
-       List<String> input,
+       List<ChunkInferenceInput> input,
        Map<String, Object> taskSettings,
        InputType inputType,
        TimeValue timeout,
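The only functional change here is the input type: chunkedInfer now accepts List<ChunkInferenceInput> rather than List<String>, so each input can carry its own chunking settings. A minimal sketch of the call-site migration this implies, with the record stubbed locally; wrapping with the single-argument constructor (null settings) preserves the old behavior:

import java.util.List;

public class ChunkedInferMigration {
    // Local stand-in for org.elasticsearch.inference.ChunkInferenceInput.
    record ChunkInferenceInput(String input, Object chunkingSettings) {
        ChunkInferenceInput(String input) {
            this(input, null);
        }
    }

    public static void main(String[] args) {
        List<String> legacyInputs = List.of("passage one", "passage two");
        // Old callers passed legacyInputs directly; now each string is wrapped.
        List<ChunkInferenceInput> wrapped = legacyInputs.stream()
            .map(ChunkInferenceInput::new) // null settings: service defaults apply
            .toList();
        System.out.println(wrapped);
    }
}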

server/src/test/java/org/elasticsearch/cluster/metadata/IndexMetadataTests.java (+2 -1)

@@ -727,7 +727,8 @@ private static InferenceFieldMetadata randomInferenceFieldMetadata(String name)
             name,
             randomIdentifier(),
             randomIdentifier(),
-            randomSet(1, 5, ESTestCase::randomIdentifier).toArray(String[]::new)
+            randomSet(1, 5, ESTestCase::randomIdentifier).toArray(String[]::new),
+            InferenceFieldMetadataTests.generateRandomChunkingSettings()
         );
     }

server/src/test/java/org/elasticsearch/cluster/metadata/InferenceFieldMetadataTests.java (+46 -10)

@@ -15,8 +15,10 @@
 import org.elasticsearch.xcontent.XContentParser;

 import java.io.IOException;
+import java.util.Map;
 import java.util.function.Predicate;

+import static org.elasticsearch.cluster.metadata.InferenceFieldMetadata.CHUNKING_SETTINGS_FIELD;
 import static org.hamcrest.Matchers.equalTo;

 public class InferenceFieldMetadataTests extends AbstractXContentTestCase<InferenceFieldMetadata> {
@@ -37,11 +39,6 @@ protected InferenceFieldMetadata createTestInstance() {
         return createTestItem();
     }

-    @Override
-    protected Predicate<String> getRandomFieldsExcludeFilter() {
-        return p -> p.equals(""); // do not add elements at the top-level as any element at this level is parsed as a new inference field
-    }
-
     @Override
     protected InferenceFieldMetadata doParseInstance(XContentParser parser) throws IOException {
         if (parser.nextToken() == XContentParser.Token.START_OBJECT) {
@@ -58,18 +55,57 @@ protected boolean supportsUnknownFields() {
         return true;
     }

+    @Override
+    protected Predicate<String> getRandomFieldsExcludeFilter() {
+        // do not add elements at the top-level as any element at this level is parsed as a new inference field,
+        // and do not add additional elements to chunking maps as they will fail parsing with extra data
+        return field -> field.equals("") || field.contains(CHUNKING_SETTINGS_FIELD);
+    }
+
     private static InferenceFieldMetadata createTestItem() {
         String name = randomAlphaOfLengthBetween(3, 10);
         String inferenceId = randomIdentifier();
         String searchInferenceId = randomIdentifier();
         String[] inputFields = generateRandomStringArray(5, 10, false, false);
-        return new InferenceFieldMetadata(name, inferenceId, searchInferenceId, inputFields);
+        Map<String, Object> chunkingSettings = generateRandomChunkingSettings();
+        return new InferenceFieldMetadata(name, inferenceId, searchInferenceId, inputFields, chunkingSettings);
+    }
+
+    public static Map<String, Object> generateRandomChunkingSettings() {
+        if (randomBoolean()) {
+            return null; // Defaults to model chunking settings
+        }
+        return randomBoolean() ? generateRandomWordBoundaryChunkingSettings() : generateRandomSentenceBoundaryChunkingSettings();
+    }
+
+    private static Map<String, Object> generateRandomWordBoundaryChunkingSettings() {
+        return Map.of("strategy", "word_boundary", "max_chunk_size", randomIntBetween(20, 100), "overlap", randomIntBetween(1, 50));
+    }
+
+    private static Map<String, Object> generateRandomSentenceBoundaryChunkingSettings() {
+        return Map.of(
+            "strategy",
+            "sentence_boundary",
+            "max_chunk_size",
+            randomIntBetween(20, 100),
+            "sentence_overlap",
+            randomIntBetween(0, 1)
+        );
     }

     public void testNullCtorArgsThrowException() {
-        assertThrows(NullPointerException.class, () -> new InferenceFieldMetadata(null, "inferenceId", "searchInferenceId", new String[0]));
-        assertThrows(NullPointerException.class, () -> new InferenceFieldMetadata("name", null, "searchInferenceId", new String[0]));
-        assertThrows(NullPointerException.class, () -> new InferenceFieldMetadata("name", "inferenceId", null, new String[0]));
-        assertThrows(NullPointerException.class, () -> new InferenceFieldMetadata("name", "inferenceId", "searchInferenceId", null));
+        assertThrows(
+            NullPointerException.class,
+            () -> new InferenceFieldMetadata(null, "inferenceId", "searchInferenceId", new String[0], Map.of())
+        );
+        assertThrows(
+            NullPointerException.class,
+            () -> new InferenceFieldMetadata("name", null, "searchInferenceId", new String[0], Map.of())
+        );
+        assertThrows(NullPointerException.class, () -> new InferenceFieldMetadata("name", "inferenceId", null, new String[0], Map.of()));
+        assertThrows(
+            NullPointerException.class,
+            () -> new InferenceFieldMetadata("name", "inferenceId", "searchInferenceId", null, Map.of())
+        );
     }
 }

server/src/test/java/org/elasticsearch/index/mapper/MappingLookupInferenceFieldMapperTests.java (+8 -1)

@@ -11,6 +11,7 @@

 import org.apache.lucene.search.Query;
 import org.elasticsearch.cluster.metadata.InferenceFieldMetadata;
+import org.elasticsearch.cluster.metadata.InferenceFieldMetadataTests;
 import org.elasticsearch.index.query.SearchExecutionContext;
 import org.elasticsearch.plugins.MapperPlugin;
 import org.elasticsearch.plugins.Plugin;
@@ -102,7 +103,13 @@ private static class TestInferenceFieldMapper extends FieldMapper implements Inf

     @Override
     public InferenceFieldMetadata getMetadata(Set<String> sourcePaths) {
-        return new InferenceFieldMetadata(fullPath(), INFERENCE_ID, SEARCH_INFERENCE_ID, sourcePaths.toArray(new String[0]));
+        return new InferenceFieldMetadata(
+            fullPath(),
+            INFERENCE_ID,
+            SEARCH_INFERENCE_ID,
+            sourcePaths.toArray(new String[0]),
+            InferenceFieldMetadataTests.generateRandomChunkingSettings()
+        );
     }

     @Override
