Skip to content

feat: [ML] Support binary embeddings from Amazon Bedrock Titan (#125378) #126540

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,9 @@ public class AmazonBedrockConstants {
public static final String ACCESS_KEY_FIELD = "access_key";
public static final String SECRET_KEY_FIELD = "secret_key";
public static final String REGION_FIELD = "region";
public static final String MODEL_FIELD = "model";
public static final String MODEL_FIELD = "model_id";
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changing the value of this field will break existing configurations.

public static final String PROVIDER_FIELD = "provider";
public static final String EMBEDDING_TYPE_FIELD = "embedding_type";

public static final String TEMPERATURE_FIELD = "temperature";
public static final String TOP_P_FIELD = "top_p";
Expand All @@ -26,4 +27,5 @@ public class AmazonBedrockConstants {

public static final int DEFAULT_MAX_CHUNK_SIZE = 2048;

private AmazonBedrockConstants() {}
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,10 @@
import java.util.Map;
import java.util.Objects;

import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalEnum;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredEnum;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredString;
import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockConstants.EMBEDDING_TYPE_FIELD;
import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockConstants.MODEL_FIELD;
import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockConstants.PROVIDER_FIELD;
import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockConstants.REGION_FIELD;
Expand All @@ -35,10 +37,31 @@ public abstract class AmazonBedrockServiceSettings extends FilteredXContentObjec

protected static final String AMAZON_BEDROCK_BASE_NAME = "amazon_bedrock";

public enum AmazonBedrockEmbeddingType {
FLOAT,
BINARY;

public static AmazonBedrockEmbeddingType fromString(String value) {
return switch (value.toLowerCase()) {
case "float" -> FLOAT;
case "binary" -> BINARY;
default -> throw new IllegalArgumentException("unknown value for embedding type: " + value);
};
}

@Override
public String toString() {
return name().toLowerCase();
}
}

protected static final AmazonBedrockEmbeddingType DEFAULT_EMBEDDING_TYPE = AmazonBedrockEmbeddingType.FLOAT;

protected final String region;
protected final String model;
protected final AmazonBedrockProvider provider;
protected final RateLimitSettings rateLimitSettings;
protected final AmazonBedrockEmbeddingType embeddingType;

// the default requests per minute are defined as per-model in the "Runtime quotas" on AWS
// see: https://docs.aws.amazon.com/bedrock/latest/userguide/quotas.html
Expand Down Expand Up @@ -69,34 +92,50 @@ protected static AmazonBedrockServiceSettings.BaseAmazonBedrockCommonSettings fr
AMAZON_BEDROCK_BASE_NAME,
context
);
AmazonBedrockEmbeddingType embeddingType = extractOptionalEnum(
map,
EMBEDDING_TYPE_FIELD,
ModelConfigurations.SERVICE_SETTINGS,
AmazonBedrockEmbeddingType::fromString,
EnumSet.allOf(AmazonBedrockEmbeddingType.class),
validationException
).orElse(DEFAULT_EMBEDDING_TYPE);

return new BaseAmazonBedrockCommonSettings(region, model, provider, rateLimitSettings);
return new BaseAmazonBedrockCommonSettings(region, model, provider, rateLimitSettings, embeddingType);
}

protected record BaseAmazonBedrockCommonSettings(
String region,
String model,
AmazonBedrockProvider provider,
@Nullable RateLimitSettings rateLimitSettings
@Nullable RateLimitSettings rateLimitSettings,
AmazonBedrockEmbeddingType embeddingType
) {}

protected AmazonBedrockServiceSettings(StreamInput in) throws IOException {
this.region = in.readString();
this.model = in.readString();
this.provider = in.readEnum(AmazonBedrockProvider.class);
this.rateLimitSettings = new RateLimitSettings(in);
if (in.getTransportVersion().onOrAfter(TransportVersions.V_9_0_0)) { // Version set for BWC
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this.embeddingType = in.readEnum(AmazonBedrockEmbeddingType.class);
} else {
this.embeddingType = DEFAULT_EMBEDDING_TYPE;
}
}

protected AmazonBedrockServiceSettings(
String region,
String model,
AmazonBedrockProvider provider,
@Nullable RateLimitSettings rateLimitSettings
@Nullable RateLimitSettings rateLimitSettings,
AmazonBedrockEmbeddingType embeddingType
) {
this.region = Objects.requireNonNull(region);
this.model = Objects.requireNonNull(model);
this.provider = Objects.requireNonNull(provider);
this.rateLimitSettings = Objects.requireNonNullElse(rateLimitSettings, DEFAULT_RATE_LIMIT_SETTINGS);
this.embeddingType = Objects.requireNonNullElse(embeddingType, DEFAULT_EMBEDDING_TYPE);
}

@Override
Expand All @@ -121,12 +160,19 @@ public RateLimitSettings rateLimitSettings() {
return rateLimitSettings;
}

public AmazonBedrockEmbeddingType embeddingType() {
return embeddingType;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeString(region);
out.writeString(model);
out.writeEnum(provider);
rateLimitSettings.writeTo(out);
if (out.getTransportVersion().onOrAfter(TransportVersions.V_9_0_0)) { // Version set for BWC
out.writeEnum(embeddingType);
}
}

public void addBaseXContent(XContentBuilder builder, Params params) throws IOException {
Expand All @@ -137,6 +183,9 @@ protected void addXContentFragmentOfExposedFields(XContentBuilder builder, Param
builder.field(REGION_FIELD, region);
builder.field(MODEL_FIELD, model);
builder.field(PROVIDER_FIELD, provider.name());
if (embeddingType != DEFAULT_EMBEDDING_TYPE) {
builder.field(EMBEDDING_TYPE_FIELD, embeddingType.toString());
}
Comment on lines +186 to +188
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if (embeddingType != DEFAULT_EMBEDDING_TYPE) {
builder.field(EMBEDDING_TYPE_FIELD, embeddingType.toString());
}
builder.field(EMBEDDING_TYPE_FIELD, embeddingType.toString());

We prefer show the default values explicitly rather than hiding them.

rateLimitSettings.toXContent(builder, params);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ public static ToXContent createEntity(
if (truncatedInput.size() > 1) {
throw new ElasticsearchException("[input] cannot contain more than one string");
}
return new AmazonBedrockTitanEmbeddingsRequestEntity(truncatedInput.get(0));
return new AmazonBedrockTitanEmbeddingsRequestEntity(truncatedInput.get(0), serviceSettings.embeddingType());
}
case COHERE -> {
return new AmazonBedrockCohereEmbeddingsRequestEntity(truncatedInput, inputType, model.getTaskSettings());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,22 +9,31 @@

import org.elasticsearch.xcontent.ToXContentObject;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockServiceSettings.AmazonBedrockEmbeddingType;

import java.io.IOException;
import java.util.Objects;

public record AmazonBedrockTitanEmbeddingsRequestEntity(String inputText) implements ToXContentObject {
import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockConstants.EMBEDDING_TYPE_FIELD;

public record AmazonBedrockTitanEmbeddingsRequestEntity(String inputText, AmazonBedrockEmbeddingType embeddingType)
implements
ToXContentObject {

private static final String INPUT_TEXT_FIELD = "inputText";

public AmazonBedrockTitanEmbeddingsRequestEntity {
Objects.requireNonNull(inputText);
Objects.requireNonNull(embeddingType);
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(INPUT_TEXT_FIELD, inputText);
if (embeddingType == AmazonBedrockEmbeddingType.BINARY) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Always be explicit rather than depending on default values set in a 3rd party service. If Bedrock changes the default embedding then it will break any integrations that rely on the default value being float. Also of a new value is added to the AmazonBedrockEmbeddingType enum it won't be set in this request.

The G1 models do not support binary embeddings only the V2 models

https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-titan-embed-text.html

When an user creates

builder.field(EMBEDDING_TYPE_FIELD, embeddingType.toString());
}
builder.endObject();
return builder;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,18 @@
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xcontent.XContentParserConfiguration;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
import org.elasticsearch.xpack.core.inference.results.TextEmbeddingBytesResults;
import org.elasticsearch.xpack.core.inference.results.TextEmbeddingResults;
import org.elasticsearch.xpack.inference.external.response.XContentUtils;
import org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockProvider;
import org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockServiceSettings.AmazonBedrockEmbeddingType;
import org.elasticsearch.xpack.inference.services.amazonbedrock.request.AmazonBedrockRequest;
import org.elasticsearch.xpack.inference.services.amazonbedrock.request.embeddings.AmazonBedrockEmbeddingsRequest;
import org.elasticsearch.xpack.inference.services.amazonbedrock.response.AmazonBedrockResponse;

import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.Base64;
import java.util.List;

import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken;
Expand All @@ -42,13 +45,13 @@ public AmazonBedrockEmbeddingsResponse(InvokeModelResponse invokeModelResult) {
@Override
public InferenceServiceResults accept(AmazonBedrockRequest request) {
if (request instanceof AmazonBedrockEmbeddingsRequest asEmbeddingsRequest) {
return fromResponse(result, asEmbeddingsRequest.provider());
return fromResponse(result, asEmbeddingsRequest);
}

throw new ElasticsearchException("unexpected request type [" + request.getClass() + "]");
}

public static TextEmbeddingFloatResults fromResponse(InvokeModelResponse response, AmazonBedrockProvider provider) {
public static TextEmbeddingResults fromResponse(InvokeModelResponse response, AmazonBedrockEmbeddingsRequest request) {
var charset = StandardCharsets.UTF_8;
var bodyText = String.valueOf(charset.decode(response.body().asByteBuffer()));

Expand All @@ -61,28 +64,33 @@ public static TextEmbeddingFloatResults fromResponse(InvokeModelResponse respons
XContentParser.Token token = jsonParser.currentToken();
ensureExpectedToken(XContentParser.Token.START_OBJECT, token, jsonParser);

var embeddingList = parseEmbeddings(jsonParser, provider);

return new TextEmbeddingFloatResults(embeddingList);
var embeddingType = request.getServiceSettings().embeddingType();
if (embeddingType == AmazonBedrockEmbeddingType.BINARY) {
var embeddingList = parseBinaryEmbeddings(jsonParser, request.provider());
return new TextEmbeddingBytesResults(embeddingList);
} else {
var embeddingList = parseFloatEmbeddings(jsonParser, request.provider());
return new TextEmbeddingFloatResults(embeddingList);
}
} catch (IOException e) {
throw new ElasticsearchException(e);
}
}

private static List<TextEmbeddingFloatResults.Embedding> parseEmbeddings(XContentParser jsonParser, AmazonBedrockProvider provider)
private static List<TextEmbeddingResults.InferredValue> parseFloatEmbeddings(XContentParser jsonParser, AmazonBedrockProvider provider)
throws IOException {
switch (provider) {
case AMAZONTITAN -> {
return parseTitanEmbeddings(jsonParser);
return parseTitanFloatEmbeddings(jsonParser);
}
case COHERE -> {
return parseCohereEmbeddings(jsonParser);
return parseCohereFloatEmbeddings(jsonParser);
}
default -> throw new IOException("Unsupported provider [" + provider + "]");
}
}

private static List<TextEmbeddingFloatResults.Embedding> parseTitanEmbeddings(XContentParser parser) throws IOException {
private static List<TextEmbeddingResults.InferredValue> parseTitanFloatEmbeddings(XContentParser parser) throws IOException {
/*
Titan response:
{
Expand All @@ -92,11 +100,11 @@ private static List<TextEmbeddingFloatResults.Embedding> parseTitanEmbeddings(XC
*/
positionParserAtTokenAfterField(parser, "embedding", FAILED_TO_FIND_FIELD_TEMPLATE);
List<Float> embeddingValuesList = parseList(parser, XContentUtils::parseFloat);
var embeddingValues = TextEmbeddingFloatResults.Embedding.of(embeddingValuesList);
TextEmbeddingResults.InferredValue embeddingValues = TextEmbeddingFloatResults.Embedding.of(embeddingValuesList);
return List.of(embeddingValues);
}

private static List<TextEmbeddingFloatResults.Embedding> parseCohereEmbeddings(XContentParser parser) throws IOException {
private static List<TextEmbeddingResults.InferredValue> parseCohereFloatEmbeddings(XContentParser parser) throws IOException {
/*
Cohere response:
{
Expand All @@ -111,17 +119,43 @@ private static List<TextEmbeddingFloatResults.Embedding> parseCohereEmbeddings(X
*/
positionParserAtTokenAfterField(parser, "embeddings", FAILED_TO_FIND_FIELD_TEMPLATE);

List<TextEmbeddingFloatResults.Embedding> embeddingList = parseList(
List<TextEmbeddingResults.InferredValue> embeddingList = parseList(
parser,
AmazonBedrockEmbeddingsResponse::parseCohereEmbeddingsListItem
AmazonBedrockEmbeddingsResponse::parseCohereFloatEmbeddingsListItem
);

return embeddingList;
}

private static TextEmbeddingFloatResults.Embedding parseCohereEmbeddingsListItem(XContentParser parser) throws IOException {
private static TextEmbeddingResults.InferredValue parseCohereFloatEmbeddingsListItem(XContentParser parser) throws IOException {
List<Float> embeddingValuesList = parseList(parser, XContentUtils::parseFloat);
return TextEmbeddingFloatResults.Embedding.of(embeddingValuesList);
}

private static List<TextEmbeddingResults.InferredValue> parseBinaryEmbeddings(XContentParser jsonParser, AmazonBedrockProvider provider)
throws IOException {
switch (provider) {
case AMAZONTITAN -> {
return parseTitanBinaryEmbeddings(jsonParser);
}
default -> throw new IOException("Binary embeddings not supported for provider [" + provider + "]");
}
}

private static List<TextEmbeddingResults.InferredValue> parseTitanBinaryEmbeddings(XContentParser parser) throws IOException {
/*
Titan Binary response (structure assumed based on float version):
{
"embedding": "<base64-encoded-binary-data>",
"inputTextTokenCount": int
}
*/
positionParserAtTokenAfterField(parser, "embedding", FAILED_TO_FIND_FIELD_TEMPLATE);
String base64Embedding = parser.text();
byte[] embeddingBytes = Base64.getDecoder().decode(base64Embedding);

TextEmbeddingResults.InferredValue embeddingValue = TextEmbeddingBytesResults.Embedding.of(embeddingBytes);
return List.of(embeddingValue);
}

}
Loading