[ML] Integrate OpenAi Chat Completion in SageMaker #127767

Merged: 10 commits, May 27, 2025
5 changes: 5 additions & 0 deletions docs/changelog/127767.yaml
@@ -0,0 +1,5 @@
pr: 127767
summary: Integrate `OpenAi` Chat Completion in `SageMaker`
area: Machine Learning
type: enhancement
issues: []
2 changes: 2 additions & 0 deletions server/src/main/java/org/elasticsearch/TransportVersions.java
@@ -180,6 +180,7 @@ static TransportVersion def(int id) {
public static final TransportVersion ESQL_HASH_OPERATOR_STATUS_OUTPUT_TIME_8_19 = def(8_841_0_34);
public static final TransportVersion RERANKER_FAILURES_ALLOWED_8_19 = def(8_841_0_35);
public static final TransportVersion ML_INFERENCE_HUGGING_FACE_RERANK_ADDED_8_19 = def(8_841_0_36);
public static final TransportVersion ML_INFERENCE_SAGEMAKER_CHAT_COMPLETION_8_19 = def(8_841_0_37);
public static final TransportVersion V_9_0_0 = def(9_000_0_09);
public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_1 = def(9_000_0_10);
public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_2 = def(9_000_0_11);
@@ -264,6 +265,7 @@ static TransportVersion def(int id) {
public static final TransportVersion NODES_STATS_SUPPORTS_MULTI_PROJECT = def(9_079_0_00);
public static final TransportVersion ML_INFERENCE_HUGGING_FACE_RERANK_ADDED = def(9_080_0_00);
public static final TransportVersion SETTINGS_IN_DATA_STREAMS_DRY_RUN = def(9_081_0_00);
public static final TransportVersion ML_INFERENCE_SAGEMAKER_CHAT_COMPLETION = def(9_082_0_00);
/*
* STOP! READ THIS FIRST! No, really,
* ____ _____ ___ ____ _ ____ _____ _ ____ _____ _ _ ___ ____ _____ ___ ____ ____ _____ _
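The constant is declared twice: 8_841_0_37 in the 8.19 backport block and 9_082_0_00 on main. A minimal sketch of how serialization code typically gates on such a constant; the writeTo context and the field name are hypothetical, not part of this diff:

    @Override
    public void writeTo(StreamOutput out) throws IOException {
        // Only send the new SageMaker chat-completion state to nodes that understand it;
        // older nodes never see the new wire format. "chatCompletionField" is a stand-in.
        if (out.getTransportVersion().onOrAfter(TransportVersions.ML_INFERENCE_SAGEMAKER_CHAT_COMPLETION)) {
            out.writeOptionalString(chatCompletionField);
        }
    }
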
@@ -124,7 +124,7 @@ public void testGetServicesWithRerankTaskType() throws IOException {

public void testGetServicesWithCompletionTaskType() throws IOException {
List<Object> services = getServices(TaskType.COMPLETION);
assertThat(services.size(), equalTo(11));
assertThat(services.size(), equalTo(12));

var providers = providers(services);

@@ -142,21 +142,24 @@ public void testGetServicesWithCompletionTaskType() throws IOException {
"googleaistudio",
"openai",
"streaming_completion_test_service",
"hugging_face"
"hugging_face",
"sagemaker"
).toArray()
)
);
}

public void testGetServicesWithChatCompletionTaskType() throws IOException {
List<Object> services = getServices(TaskType.CHAT_COMPLETION);
assertThat(services.size(), equalTo(5));
assertThat(services.size(), equalTo(6));

var providers = providers(services);

assertThat(
providers,
containsInAnyOrder(List.of("deepseek", "elastic", "openai", "streaming_completion_test_service", "hugging_face").toArray())
containsInAnyOrder(
List.of("deepseek", "elastic", "openai", "streaming_completion_test_service", "hugging_face", "sagemaker").toArray()
)
);
}

@@ -12,13 +12,12 @@
import org.elasticsearch.xcontent.XContentParserConfiguration;
import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEvent;

import java.io.IOException;
import java.util.ArrayDeque;
import java.util.Deque;
import java.util.Iterator;
import java.util.concurrent.Flow;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicLong;
import java.util.stream.Stream;

/**
* Processor that delegates the {@link java.util.concurrent.Flow.Subscription} to the upstream {@link java.util.concurrent.Flow.Publisher}
@@ -34,19 +33,13 @@ public abstract class DelegatingProcessor<T, R> implements Flow.Processor<T, R>
public static <ParsedChunk> Deque<ParsedChunk> parseEvent(
Deque<ServerSentEvent> item,
ParseChunkFunction<ParsedChunk> parseFunction,
XContentParserConfiguration parserConfig,
Logger logger
) throws Exception {
XContentParserConfiguration parserConfig
) {
var results = new ArrayDeque<ParsedChunk>(item.size());
for (ServerSentEvent event : item) {
if (event.hasData()) {
try {
var delta = parseFunction.apply(parserConfig, event);
delta.forEachRemaining(results::offer);
} catch (Exception e) {
logger.warn("Failed to parse event from inference provider: {}", event);
throw e;
}
var delta = parseFunction.apply(parserConfig, event);
delta.forEach(results::offer);
}
}

@@ -55,7 +48,7 @@ public static <ParsedChunk> Deque<ParsedChunk> parseEvent(

@FunctionalInterface
public interface ParseChunkFunction<ParsedChunk> {
Iterator<ParsedChunk> apply(XContentParserConfiguration parserConfig, ServerSentEvent event) throws IOException;
Stream<ParsedChunk> apply(XContentParserConfiguration parserConfig, ServerSentEvent event);
}

@Override
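With ParseChunkFunction now returning a Stream and no longer throwing, parseEvent drops its logger parameter and its try/catch; error handling moves to the processors that own the events. A hedged sketch of a caller under the new signature (MyChunk and MyParser are placeholders, and the upstream()/downstream() accessors are assumed from DelegatingProcessor):

    @Override
    protected void next(Deque<ServerSentEvent> item) {
        var parserConfig = XContentParserConfiguration.EMPTY;
        // Parse every data-bearing event in the batch into zero or more chunks.
        Deque<MyChunk> results = parseEvent(item, MyParser::parse, parserConfig);
        if (results.isEmpty()) {
            upstream().request(1); // nothing emitted, ask upstream for the next batch
        } else {
            downstream().onNext(results);
        }
    }
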
@@ -45,10 +45,12 @@ public class UnifiedChatCompletionRequestEntity implements ToXContentFragment {
private final boolean stream;

public UnifiedChatCompletionRequestEntity(UnifiedChatInput unifiedChatInput) {
Objects.requireNonNull(unifiedChatInput);
this(Objects.requireNonNull(unifiedChatInput).getRequest(), Objects.requireNonNull(unifiedChatInput).stream());
}

this.unifiedRequest = unifiedChatInput.getRequest();
this.stream = unifiedChatInput.stream();
public UnifiedChatCompletionRequestEntity(UnifiedCompletionRequest unifiedRequest, boolean stream) {
this.unifiedRequest = Objects.requireNonNull(unifiedRequest);
this.stream = stream;
}

@Override
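Splitting the constructor lets callers that already hold a raw UnifiedCompletionRequest build the entity directly, while the one-arg form delegates. A hedged usage sketch (variable names assumed):

    // Existing callers are unchanged:
    var fromInput = new UnifiedChatCompletionRequestEntity(unifiedChatInput);

    // New callers, e.g. a SageMaker payload that already holds the request,
    // can construct the entity directly and force streaming on:
    var fromRequest = new UnifiedChatCompletionRequestEntity(unifiedCompletionRequest, true);
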
@@ -9,8 +9,10 @@

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xcontent.XContentFactory;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xcontent.XContentParserConfiguration;
@@ -20,11 +22,10 @@
import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEvent;

import java.io.IOException;
import java.util.Collections;
import java.util.Deque;
import java.util.Iterator;
import java.util.Objects;
import java.util.function.Predicate;
import java.util.stream.Stream;

import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.elasticsearch.common.xcontent.XContentParserUtils.parseList;
@@ -113,7 +114,7 @@ public class OpenAiStreamingProcessor extends DelegatingProcessor<Deque<ServerSe
@Override
protected void next(Deque<ServerSentEvent> item) throws Exception {
var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE);
var results = parseEvent(item, OpenAiStreamingProcessor::parse, parserConfig, log);
var results = parseEvent(item, OpenAiStreamingProcessor::parse, parserConfig);

if (results.isEmpty()) {
upstream().request(1);
@@ -122,10 +123,9 @@ protected void next(Deque<ServerSentEvent> item) throws Exception {
}
}

private static Iterator<StreamingChatCompletionResults.Result> parse(XContentParserConfiguration parserConfig, ServerSentEvent event)
throws IOException {
public static Stream<StreamingChatCompletionResults.Result> parse(XContentParserConfiguration parserConfig, ServerSentEvent event) {
if (DONE_MESSAGE.equalsIgnoreCase(event.data())) {
return Collections.emptyIterator();
return Stream.empty();
}

try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, event.data())) {
@@ -167,11 +167,14 @@

consumeUntilObjectEnd(parser); // end choices
return ""; // stopped
}).stream()
.filter(Objects::nonNull)
.filter(Predicate.not(String::isEmpty))
.map(StreamingChatCompletionResults.Result::new)
.iterator();
}).stream().filter(Objects::nonNull).filter(Predicate.not(String::isEmpty)).map(StreamingChatCompletionResults.Result::new);
} catch (IOException e) {
throw new ElasticsearchStatusException(
"Failed to parse event from inference provider: {}",
RestStatus.INTERNAL_SERVER_ERROR,
e,
event
);
}
}
}

Contributor: We've talked about switching to use 502s, do you think that'd be appropriate here?

Member (Author): I don't think so? Because this IOException is an error with our parsing logic, which may or may not mean there is something wrong with their response. It could be that we're out of date.
@@ -50,7 +50,6 @@ public OpenAiUnifiedChatCompletionResponseHandler(
public InferenceServiceResults parseResult(Request request, Flow.Publisher<HttpResult> flow) {
var serverSentEventProcessor = new ServerSentEventProcessor(new ServerSentEventParser());
var openAiProcessor = new OpenAiUnifiedStreamingProcessor((m, e) -> buildMidStreamError(request, m, e));

flow.subscribe(serverSentEventProcessor);
serverSentEventProcessor.subscribe(openAiProcessor);
return new StreamingUnifiedChatCompletionResults(openAiProcessor);
@@ -81,14 +80,18 @@ protected static String createErrorType(ErrorResponse errorResponse) {
}

protected Exception buildMidStreamError(Request request, String message, Exception e) {
return buildMidStreamError(request.getInferenceEntityId(), message, e);
}

public static UnifiedChatCompletionException buildMidStreamError(String inferenceEntityId, String message, Exception e) {
var errorResponse = OpenAiErrorResponse.fromString(message);
if (errorResponse instanceof OpenAiErrorResponse oer) {
return new UnifiedChatCompletionException(
RestStatus.INTERNAL_SERVER_ERROR,
format(
"%s for request from inference entity id [%s]. Error message: [%s]",
SERVER_ERROR_OBJECT,
request.getInferenceEntityId(),
inferenceEntityId,
errorResponse.getErrorMessage()
),
oer.type(),
@@ -100,7 +103,7 @@ protected Exception buildMidStreamError(Request request, String message, Exception e) {
} else {
return new UnifiedChatCompletionException(
RestStatus.INTERNAL_SERVER_ERROR,
format("%s for request from inference entity id [%s]", SERVER_ERROR_OBJECT, request.getInferenceEntityId()),
format("%s for request from inference entity id [%s]", SERVER_ERROR_OBJECT, inferenceEntityId),
createErrorType(errorResponse),
"stream_error"
);
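Extracting a static overload keyed on the inference entity id lets handlers that have no Request object, such as a SageMaker stream handler, produce the same UnifiedChatCompletionException. A hedged sketch of that reuse (the surrounding variables are assumed):

    // Hypothetical mid-stream failure path in another handler:
    UnifiedChatCompletionException error = OpenAiUnifiedChatCompletionResponseHandler.buildMidStreamError(
        inferenceEntityId, // assumed to be available from the model
        event.data(),      // the raw payload that failed mid-stream
        cause
    );
    listener.onFailure(error);
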
@@ -22,11 +22,10 @@

import java.io.IOException;
import java.util.ArrayDeque;
import java.util.Collections;
import java.util.Deque;
import java.util.Iterator;
import java.util.List;
import java.util.function.BiFunction;
import java.util.stream.Stream;

import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.elasticsearch.xpack.inference.external.response.XContentUtils.moveToFirstToken;
@@ -75,7 +74,7 @@ protected void next(Deque<ServerSentEvent> item) throws Exception {
} else if (event.hasData()) {
try {
var delta = parse(parserConfig, event);
delta.forEachRemaining(results::offer);
delta.forEach(results::offer);
} catch (Exception e) {
logger.warn("Failed to parse event from inference provider: {}", event);
throw errorParser.apply(event.data(), e);
@@ -90,12 +89,12 @@
}
}

private static Iterator<StreamingUnifiedChatCompletionResults.ChatCompletionChunk> parse(
public static Stream<StreamingUnifiedChatCompletionResults.ChatCompletionChunk> parse(
XContentParserConfiguration parserConfig,
ServerSentEvent event
) throws IOException {
if (DONE_MESSAGE.equalsIgnoreCase(event.data())) {
return Collections.emptyIterator();
return Stream.empty();
}

try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, event.data())) {
@@ -106,7 +105,7 @@ private static Iterator<StreamingUnifiedChatCompletionResults.ChatCompletionChunk> parse(

StreamingUnifiedChatCompletionResults.ChatCompletionChunk chunk = ChatCompletionChunkParser.parse(jsonParser);

return Collections.singleton(chunk).iterator();
return Stream.of(chunk);
}
}

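Making parse public and Stream-based lets other streaming processors reuse the OpenAI chunk parser one event at a time. A hedged sketch (the surrounding processor and its results deque are assumed):

    // Hypothetical reuse from another processor's next() implementation:
    Stream<StreamingUnifiedChatCompletionResults.ChatCompletionChunk> chunks =
        OpenAiUnifiedStreamingProcessor.parse(parserConfig, event);
    chunks.forEach(results::offer);
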
@@ -67,7 +67,11 @@ public class OpenAiChatCompletionResponseEntity {
*/

public static ChatCompletionResults fromResponse(Request request, HttpResult response) throws IOException {
try (var p = XContentFactory.xContent(XContentType.JSON).createParser(XContentParserConfiguration.EMPTY, response.body())) {
return fromResponse(response.body());
}

public static ChatCompletionResults fromResponse(byte[] response) throws IOException {
try (var p = XContentFactory.xContent(XContentType.JSON).createParser(XContentParserConfiguration.EMPTY, response)) {
return CompletionResult.PARSER.apply(p, null).toChatCompletionResults();
}
}
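The byte[] overload decouples parsing from HttpResult, which matters for SageMaker, where the response body arrives as SDK bytes rather than an Apache HTTP response. A hedged sketch (the SDK response variable is assumed):

    // Hypothetical SageMaker path: the AWS SDK hands back raw bytes, not an HttpResult.
    byte[] body = invokeEndpointResponse.body().asByteArray();
    ChatCompletionResults results = OpenAiChatCompletionResponseEntity.fromResponse(body);
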
@@ -47,6 +47,7 @@
public class SageMakerService implements InferenceService {
public static final String NAME = "sagemaker";
private static final int DEFAULT_BATCH_SIZE = 256;
private static final TimeValue DEFAULT_TIMEOUT = TimeValue.THIRTY_SECONDS;
private final SageMakerModelBuilder modelBuilder;
private final SageMakerClient client;
private final SageMakerSchemas schemas;
@@ -128,7 +129,7 @@ public void infer(
boolean stream,
Map<String, Object> taskSettings,
InputType inputType,
TimeValue timeout,
@Nullable TimeValue timeout,
Contributor: Hmm, I thought the timeout is defaulted in the InferenceAction. Can it be null here?

Member (Author): Yeah, I believe I was hitting an issue when I was using curl; I think it can be null through this path:

Contributor: We should be defaulting it there too, I think:

    static TimeValue parseTimeout(RestRequest restRequest) {
        return restRequest.paramAsTime(InferenceAction.Request.TIMEOUT.getPreferredName(), InferenceAction.Request.DEFAULT_TIMEOUT);
    }

I think we should consider it a bug if it's null once it gets to the infer() calls. We should make sure it's defaulted prior to those calls.

ActionListener<InferenceServiceResults> listener
) {
if (model instanceof SageMakerModel == false) {
@@ -148,7 +149,7 @@
client.invokeStream(
regionAndSecrets,
request,
timeout,
timeout != null ? timeout : DEFAULT_TIMEOUT,
ActionListener.wrap(
response -> listener.onResponse(schema.streamResponse(sageMakerModel, response)),
e -> listener.onFailure(schema.error(sageMakerModel, e))
@@ -160,7 +161,7 @@
client.invoke(
regionAndSecrets,
request,
timeout,
timeout != null ? timeout : DEFAULT_TIMEOUT,
ActionListener.wrap(
response -> listener.onResponse(schema.response(sageMakerModel, response, threadPool.getThreadContext())),
e -> listener.onFailure(schema.error(sageMakerModel, e))
@@ -201,7 +202,7 @@ private static ElasticsearchStatusException internalFailure(Model model, Exception e) {
public void unifiedCompletionInfer(
Model model,
UnifiedCompletionRequest request,
TimeValue timeout,
@Nullable TimeValue timeout,
ActionListener<InferenceServiceResults> listener
) {
if (model instanceof SageMakerModel == false) {
@@ -217,7 +218,7 @@
client.invokeStream(
regionAndSecrets,
sagemakerRequest,
timeout,
timeout != null ? timeout : DEFAULT_TIMEOUT,
ActionListener.wrap(
response -> listener.onResponse(schema.chatCompletionStreamResponse(sageMakerModel, response)),
e -> listener.onFailure(schema.chatCompletionError(sageMakerModel, e))
@@ -235,7 +236,7 @@ public void chunkedInfer(
List<ChunkInferenceInput> input,
Map<String, Object> taskSettings,
InputType inputType,
TimeValue timeout,
@Nullable TimeValue timeout,
ActionListener<List<ChunkedInference>> listener
) {
if (model instanceof SageMakerModel == false) {
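Every call site now falls back to DEFAULT_TIMEOUT (thirty seconds) when the caller passes a null timeout. If the null checks keep multiplying, one hedged way to centralize the defaulting (helper name hypothetical, not in the diff):

    // Hypothetical helper: one place to resolve a nullable timeout.
    private static TimeValue timeoutOrDefault(@Nullable TimeValue timeout) {
        return timeout != null ? timeout : DEFAULT_TIMEOUT;
    }

    // Call sites then shrink to, e.g.:
    client.invoke(regionAndSecrets, request, timeoutOrDefault(timeout), wrappedListener);
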
@@ -12,10 +12,12 @@
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xpack.inference.services.sagemaker.model.SageMakerModel;
import org.elasticsearch.xpack.inference.services.sagemaker.schema.openai.OpenAiCompletionPayload;
import org.elasticsearch.xpack.inference.services.sagemaker.schema.openai.OpenAiTextEmbeddingPayload;

import java.util.Arrays;
import java.util.EnumSet;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
@@ -39,7 +41,7 @@ public class SageMakerSchemas {
/*
* Add new model API to the register call.
*/
schemas = register(new OpenAiTextEmbeddingPayload());
schemas = register(new OpenAiTextEmbeddingPayload(), new OpenAiCompletionPayload());

streamSchemas = schemas.entrySet()
.stream()
@@ -88,7 +90,16 @@ public static List<NamedWriteableRegistry.Entry> namedWriteables() {
)
),
schemas.values().stream().flatMap(SageMakerSchema::namedWriteables)
).toList();
)
// Dedupe based on Entry name; Payloads may declare the same Entry, but the Registry does not handle duplicates
.collect(
() -> new HashMap<String, NamedWriteableRegistry.Entry>(),
(map, entry) -> map.putIfAbsent(entry.name, entry),
Map::putAll
)
.values()
.stream()
.toList();
}

public SageMakerSchema schemaFor(SageMakerModel model) throws ElasticsearchStatusException {
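The three-argument collect builds a HashMap keyed on entry.name and keeps the first registration via putIfAbsent, so two payloads may declare the same named writeable without tripping the registry's duplicate check. A hedged equivalent with Collectors.toMap, assuming the same first-one-wins policy (stream variable assumed):

    // Sketch only: dedupe the concatenated entries by name, keeping the first entry seen.
    List<NamedWriteableRegistry.Entry> deduped = entries
        .collect(Collectors.toMap(e -> e.name, e -> e, (first, second) -> first))
        .values()
        .stream()
        .toList();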