Skip to content

Commit 8924385

Browse files
authored
Merge branch 'main' into tvclean0
2 parents 55ad1cd + 8497d3d commit 8924385

File tree

6 files changed

+92
-21
lines changed

6 files changed

+92
-21
lines changed

docs/changelog/139347.yaml

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,5 @@
1+
pr: 139347
2+
summary: "Enable bfloat16 support for semantic text"
3+
area: Mapping
4+
type: enhancement
5+
issues: []

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceFeatures.java

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -22,6 +22,7 @@
2222

2323
import static org.elasticsearch.xpack.core.ml.vectors.TextEmbeddingQueryVectorBuilder.RETRIEVER_RESULT_DIVERSIFICATION_USES_QUERY_VECTOR_BUILDER;
2424
import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper.SEMANTIC_TEXT_AUTO_PREFILTERING;
25+
import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper.SEMANTIC_TEXT_BFLOAT16_SUPPORT;
2526
import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper.SEMANTIC_TEXT_EXCLUDE_SUB_FIELDS_FROM_FIELD_CAPS;
2627
import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper.SEMANTIC_TEXT_INDEX_OPTIONS;
2728
import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper.SEMANTIC_TEXT_INDEX_OPTIONS_WITH_DEFAULTS;
@@ -119,6 +120,7 @@ public Set<NodeFeature> getTestFeatures() {
119120
SEMANTIC_TEXT_HIGHLIGHTER_DISKBBQ_SIMILARITY_SUPPORT,
120121
SEMANTIC_TEXT_HIGHLIGHTER_VECTOR_SIMILARITY_SUPPORT,
121122
SEMANTIC_TEXT_AUTO_PREFILTERING,
123+
SEMANTIC_TEXT_BFLOAT16_SUPPORT,
122124
SemanticQueryBuilder.SEMANTIC_QUERY_MULTIPLE_INFERENCE_IDS,
123125
SemanticQueryBuilder.SEMANTIC_QUERY_FILTER_FIELD_CAPS_FIX,
124126
InterceptedInferenceQueryBuilder.NEW_SEMANTIC_QUERY_INTERCEPTORS,

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java

Lines changed: 1 addition & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -154,6 +154,7 @@ public class SemanticTextFieldMapper extends FieldMapper implements InferenceFie
154154
);
155155
public static final NodeFeature SEMANTIC_TEXT_UPDATABLE_INFERENCE_ID = new NodeFeature("semantic_text.updatable_inference_id");
156156
public static final NodeFeature SEMANTIC_TEXT_AUTO_PREFILTERING = new NodeFeature("semantic_text.auto_prefiltering");
157+
public static final NodeFeature SEMANTIC_TEXT_BFLOAT16_SUPPORT = new NodeFeature("semantic_text.bfloat16_support");
157158

158159
public static final String CONTENT_TYPE = "semantic_text";
159160
public static final String DEFAULT_FALLBACK_ELSER_INFERENCE_ID = DEFAULT_ELSER_ID;
@@ -1423,10 +1424,6 @@ private static void configureDenseVectorMapperBuilder(
14231424
}
14241425
}
14251426

1426-
if (modelSettings.elementType() == DenseVectorFieldMapper.ElementType.BFLOAT16) {
1427-
throw new IllegalArgumentException("semantic_text does not support bfloat16");
1428-
}
1429-
14301427
assert modelSettings.dimensions() != null : "Model settings should have dimensions set by now for text embedding models";
14311428
denseVectorMapperBuilder.dimensions(modelSettings.dimensions());
14321429
denseVectorMapperBuilder.elementType(modelSettings.elementType());

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java

Lines changed: 23 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -1592,21 +1592,31 @@ public void testDenseVectorElementType() throws IOException {
15921592
);
15931593
assertMapperService.accept(byteMapperService, DenseVectorFieldMapper.ElementType.BYTE);
15941594

1595-
var e = expectThrows(
1596-
DocumentParsingException.class,
1597-
() -> mapperServiceForFieldWithModelSettings(
1598-
fieldName,
1599-
inferenceId,
1600-
new MinimalServiceSettings(
1601-
"my-service",
1602-
TaskType.TEXT_EMBEDDING,
1603-
1024,
1604-
SimilarityMeasure.COSINE,
1605-
DenseVectorFieldMapper.ElementType.BFLOAT16
1606-
)
1595+
MapperService bitMapperService = mapperServiceForFieldWithModelSettings(
1596+
fieldName,
1597+
inferenceId,
1598+
new MinimalServiceSettings(
1599+
"my-service",
1600+
TaskType.TEXT_EMBEDDING,
1601+
1024,
1602+
SimilarityMeasure.L2_NORM,
1603+
DenseVectorFieldMapper.ElementType.BIT
1604+
)
1605+
);
1606+
assertMapperService.accept(bitMapperService, DenseVectorFieldMapper.ElementType.BIT);
1607+
1608+
MapperService bfloat16MapperService = mapperServiceForFieldWithModelSettings(
1609+
fieldName,
1610+
inferenceId,
1611+
new MinimalServiceSettings(
1612+
"my-service",
1613+
TaskType.TEXT_EMBEDDING,
1614+
1024,
1615+
SimilarityMeasure.COSINE,
1616+
DenseVectorFieldMapper.ElementType.BFLOAT16
16071617
)
16081618
);
1609-
assertThat(e.getCause(), instanceOf(IllegalArgumentException.class));
1619+
assertMapperService.accept(bfloat16MapperService, DenseVectorFieldMapper.ElementType.BFLOAT16);
16101620
}
16111621

16121622
public void testSettingAndUpdatingChunkingSettings() throws IOException {

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldTests.java

Lines changed: 2 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -190,9 +190,8 @@ public static ChunkedInferenceEmbedding randomChunkedInferenceEmbedding(Model mo
190190
return switch (model.getTaskType()) {
191191
case SPARSE_EMBEDDING -> randomChunkedInferenceEmbeddingSparse(inputs);
192192
case TEXT_EMBEDDING -> switch (model.getServiceSettings().elementType()) {
193-
case FLOAT -> randomChunkedInferenceEmbeddingFloat(model, inputs);
193+
case FLOAT, BFLOAT16 -> randomChunkedInferenceEmbeddingFloat(model, inputs);
194194
case BIT, BYTE -> randomChunkedInferenceEmbeddingByte(model, inputs);
195-
case BFLOAT16 -> throw new AssertionError();
196195
};
197196
default -> throw new AssertionError("invalid task type: " + model.getTaskType().name());
198197
};
@@ -416,9 +415,8 @@ public static ChunkedInference toChunkedResult(
416415
ChunkedInference.TextOffset offset = createOffset(useLegacyFormat, entryChunk, matchedText);
417416
double[] values = parseDenseVector(entryChunk.rawEmbeddings(), embeddingLength, field.contentType());
418417
EmbeddingResults.Embedding<?> embedding = switch (elementType) {
419-
case FLOAT -> new DenseEmbeddingFloatResults.Embedding(FloatConversionUtils.floatArrayOf(values));
418+
case FLOAT, BFLOAT16 -> new DenseEmbeddingFloatResults.Embedding(FloatConversionUtils.floatArrayOf(values));
420419
case BYTE, BIT -> new DenseEmbeddingByteResults.Embedding(byteArrayOf(values));
421-
case BFLOAT16 -> throw new AssertionError();
422420
};
423421
chunks.add(new EmbeddingResults.Chunk(embedding, offset));
424422
}

x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/40_semantic_text_query.yml

Lines changed: 59 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -415,6 +415,65 @@ setup:
415415
- match: { hits.hits.0._id: "doc_1" }
416416
- close_to: { hits.hits.0._score: { value: 1.0, error: 0.0001 } }
417417

418+
---
419+
"Query using a dense embedding model that uses bfloat16 embeddings":
420+
- requires:
421+
cluster_features: [ "semantic_text.bfloat16_support" ]
422+
reason: bfloat16 support for semantic text
423+
- skip:
424+
features: [ "headers", "close_to" ]
425+
426+
- do:
427+
inference.put:
428+
task_type: text_embedding
429+
inference_id: dense-inference-bfloat16-id
430+
body: >
431+
{
432+
"service": "text_embedding_test_service",
433+
"service_settings": {
434+
"model": "my_model",
435+
"dimensions": 10,
436+
"api_key": "abc64",
437+
"similarity": "COSINE",
438+
"element_type": "bfloat16"
439+
},
440+
"task_settings": {
441+
}
442+
}
443+
444+
- do:
445+
indices.create:
446+
index: test-dense-bfloat16-index
447+
body:
448+
mappings:
449+
properties:
450+
inference_field:
451+
type: semantic_text
452+
inference_id: dense-inference-bfloat16-id
453+
454+
- do:
455+
index:
456+
index: test-dense-bfloat16-index
457+
id: doc_1
458+
body:
459+
inference_field: [ "inference test", "another inference test" ]
460+
refresh: true
461+
462+
- do:
463+
headers:
464+
# Force JSON content type so that we use a parser that interprets the floating-point score as a double
465+
Content-Type: application/json
466+
search:
467+
index: test-dense-bfloat16-index
468+
body:
469+
query:
470+
match:
471+
inference_field: "inference test"
472+
473+
- match: { hits.total.value: 1 }
474+
- match: { hits.hits.0._id: "doc_1" }
475+
- close_to: { hits.hits.0._score: { value: 1.0, error: 0.001 } }
476+
418477
---
419478
"Query using a dense embedding model via a search inference ID":
420479
- skip:

0 commit comments

Comments
 (0)