[ML] Integrate OpenAi Chat Completion in SageMaker #127767

Open · wants to merge 5 commits into main
5 changes: 5 additions & 0 deletions docs/changelog/127767.yaml
@@ -0,0 +1,5 @@
pr: 127767
summary: Integrate `OpenAi` Chat Completion in `SageMaker`
area: Machine Learning
type: enhancement
issues: []
InferenceGetServicesIT.java
@@ -123,7 +123,7 @@ public void testGetServicesWithRerankTaskType() throws IOException {

public void testGetServicesWithCompletionTaskType() throws IOException {
List<Object> services = getServices(TaskType.COMPLETION);
- assertThat(services.size(), equalTo(10));
+ assertThat(services.size(), equalTo(11));

var providers = providers(services);

@@ -140,19 +140,23 @@ public void testGetServicesWithCompletionTaskType() throws IOException {
"deepseek",
"googleaistudio",
"openai",
"streaming_completion_test_service"
"streaming_completion_test_service",
"sagemaker"
).toArray()
)
);
}

public void testGetServicesWithChatCompletionTaskType() throws IOException {
List<Object> services = getServices(TaskType.CHAT_COMPLETION);
- assertThat(services.size(), equalTo(4));
+ assertThat(services.size(), equalTo(5));

var providers = providers(services);

- assertThat(providers, containsInAnyOrder(List.of("deepseek", "elastic", "openai", "streaming_completion_test_service").toArray()));
+ assertThat(
+     providers,
+     containsInAnyOrder(List.of("deepseek", "elastic", "openai", "sagemaker", "streaming_completion_test_service").toArray())
+ );
}

public void testGetServicesWithSparseEmbeddingTaskType() throws IOException {
DelegatingProcessor.java
@@ -12,13 +12,12 @@
import org.elasticsearch.xcontent.XContentParserConfiguration;
import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEvent;

- import java.io.IOException;
import java.util.ArrayDeque;
import java.util.Deque;
- import java.util.Iterator;
import java.util.concurrent.Flow;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicLong;
+ import java.util.stream.Stream;

/**
* Processor that delegates the {@link java.util.concurrent.Flow.Subscription} to the upstream {@link java.util.concurrent.Flow.Publisher}
@@ -34,19 +33,13 @@ public abstract class DelegatingProcessor<T, R> implements Flow.Processor<T, R>
public static <ParsedChunk> Deque<ParsedChunk> parseEvent(
Deque<ServerSentEvent> item,
ParseChunkFunction<ParsedChunk> parseFunction,
- XContentParserConfiguration parserConfig,
- Logger logger
- ) throws Exception {
+ XContentParserConfiguration parserConfig
+ ) {
var results = new ArrayDeque<ParsedChunk>(item.size());
for (ServerSentEvent event : item) {
if (event.hasData()) {
- try {
-     var delta = parseFunction.apply(parserConfig, event);
-     delta.forEachRemaining(results::offer);
- } catch (Exception e) {
-     logger.warn("Failed to parse event from inference provider: {}", event);
-     throw e;
- }
+ var delta = parseFunction.apply(parserConfig, event);
+ delta.forEach(results::offer);
}
}

@@ -55,7 +48,7 @@ public static <ParsedChunk> Deque<ParsedChunk> parseEvent(

@FunctionalInterface
public interface ParseChunkFunction<ParsedChunk> {
- Iterator<ParsedChunk> apply(XContentParserConfiguration parserConfig, ServerSentEvent event) throws IOException;
+ Stream<ParsedChunk> apply(XContentParserConfiguration parserConfig, ServerSentEvent event);
}

@Override
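The net effect in DelegatingProcessor: ParseChunkFunction is now Stream-valued with no checked exception, and parseEvent drops its Logger parameter, leaving error wrapping to the individual parsers. A minimal sketch of the new contract, assuming the shapes shown above (EchoParser is a hypothetical name, not part of this PR):

```java
// Sketch only; assumes the same package as DelegatingProcessor.
import java.util.Deque;
import java.util.stream.Stream;

import org.elasticsearch.xcontent.XContentParserConfiguration;
import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEvent;

class EchoParser {
    // Satisfies the new ParseChunkFunction<String>: Stream-valued, no checked exception.
    static Stream<String> parse(XContentParserConfiguration config, ServerSentEvent event) {
        return event.hasData() ? Stream.of(event.data()) : Stream.empty();
    }

    static Deque<String> drain(Deque<ServerSentEvent> events) {
        // parseEvent no longer takes a Logger and no longer declares an exception.
        return DelegatingProcessor.parseEvent(events, EchoParser::parse, XContentParserConfiguration.EMPTY);
    }
}
```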
UnifiedChatCompletionRequestEntity.java
@@ -45,10 +45,12 @@ public class UnifiedChatCompletionRequestEntity implements ToXContentFragment {
private final boolean stream;

public UnifiedChatCompletionRequestEntity(UnifiedChatInput unifiedChatInput) {
- Objects.requireNonNull(unifiedChatInput);
+ this(Objects.requireNonNull(unifiedChatInput).getRequest(), Objects.requireNonNull(unifiedChatInput).stream());
+ }

- this.unifiedRequest = unifiedChatInput.getRequest();
- this.stream = unifiedChatInput.stream();
+ public UnifiedChatCompletionRequestEntity(UnifiedCompletionRequest unifiedRequest, boolean stream) {
+     this.unifiedRequest = Objects.requireNonNull(unifiedRequest);
+     this.stream = stream;
}

@Override
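The original constructor now delegates via this(...), and the new (UnifiedCompletionRequest, boolean) overload gives callers that never build a UnifiedChatInput, presumably the SageMaker payload, a direct path. Illustrative call shapes (variable names are not from this PR):

```java
// Existing path: wrap a UnifiedChatInput; it delegates to the new overload internally.
var entityFromInput = new UnifiedChatCompletionRequestEntity(unifiedChatInput);

// New overload: raw request plus an explicit stream flag.
var entityDirect = new UnifiedChatCompletionRequestEntity(unifiedCompletionRequest, true);
```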
OpenAiStreamingProcessor.java
@@ -9,8 +9,10 @@

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
+ import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
import org.elasticsearch.inference.InferenceServiceResults;
+ import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xcontent.XContentFactory;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xcontent.XContentParserConfiguration;
@@ -20,11 +22,10 @@
import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEvent;

import java.io.IOException;
- import java.util.Collections;
import java.util.Deque;
- import java.util.Iterator;
import java.util.Objects;
import java.util.function.Predicate;
+ import java.util.stream.Stream;

import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.elasticsearch.common.xcontent.XContentParserUtils.parseList;
@@ -113,7 +114,7 @@ public class OpenAiStreamingProcessor extends DelegatingProcessor<Deque<ServerSe
@Override
protected void next(Deque<ServerSentEvent> item) throws Exception {
var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE);
- var results = parseEvent(item, OpenAiStreamingProcessor::parse, parserConfig, log);
+ var results = parseEvent(item, OpenAiStreamingProcessor::parse, parserConfig);

if (results.isEmpty()) {
upstream().request(1);
@@ -122,10 +123,9 @@ protected void next(Deque<ServerSentEvent> item) throws Exception {
}
}

- private static Iterator<StreamingChatCompletionResults.Result> parse(XContentParserConfiguration parserConfig, ServerSentEvent event)
-     throws IOException {
+ public static Stream<StreamingChatCompletionResults.Result> parse(XContentParserConfiguration parserConfig, ServerSentEvent event) {
if (DONE_MESSAGE.equalsIgnoreCase(event.data())) {
- return Collections.emptyIterator();
+ return Stream.empty();
}

try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, event.data())) {
@@ -167,11 +167,14 @@ private static Iterator<StreamingChatCompletionResults.Result> parse(XContentPar

consumeUntilObjectEnd(parser); // end choices
return ""; // stopped
- }).stream()
-     .filter(Objects::nonNull)
-     .filter(Predicate.not(String::isEmpty))
-     .map(StreamingChatCompletionResults.Result::new)
-     .iterator();
+ }).stream().filter(Objects::nonNull).filter(Predicate.not(String::isEmpty)).map(StreamingChatCompletionResults.Result::new);
+ } catch (IOException e) {
+     throw new ElasticsearchStatusException(
+         "Failed to parse event from inference provider: {}",
+         RestStatus.INTERNAL_SERVER_ERROR,
+         e,
+         event
+     );
}
}
}
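parse is now public, static, and Stream-valued, and a parse failure is wrapped at the source as an ElasticsearchStatusException (500) rather than logged and rethrown inside parseEvent. A hedged sketch of reuse through the shared helper; events stands for a Deque<ServerSentEvent> already read off the wire:

```java
// Sketch: reusing the now-public OpenAI SSE parser outside this processor,
// e.g. from a SageMaker streaming payload.
var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE);
Deque<StreamingChatCompletionResults.Result> results = DelegatingProcessor.parseEvent(
    events,                          // Deque<ServerSentEvent>, assumed in scope
    OpenAiStreamingProcessor::parse, // matches the new ParseChunkFunction shape
    parserConfig
);
```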
OpenAiUnifiedChatCompletionResponseHandler.java
@@ -40,8 +40,7 @@ public OpenAiUnifiedChatCompletionResponseHandler(String requestType, ResponsePa
@Override
public InferenceServiceResults parseResult(Request request, Flow.Publisher<HttpResult> flow) {
var serverSentEventProcessor = new ServerSentEventProcessor(new ServerSentEventParser());
- var openAiProcessor = new OpenAiUnifiedStreamingProcessor((m, e) -> buildMidStreamError(request, m, e));
-
+ var openAiProcessor = new OpenAiUnifiedStreamingProcessor((m, e) -> buildMidStreamError(request.getInferenceEntityId(), m, e));
flow.subscribe(serverSentEventProcessor);
serverSentEventProcessor.subscribe(openAiProcessor);
return new StreamingUnifiedChatCompletionResults(openAiProcessor);
@@ -67,15 +66,15 @@ protected Exception buildError(String message, Request request, HttpResult resul
}
}

- private static Exception buildMidStreamError(Request request, String message, Exception e) {
+ public static UnifiedChatCompletionException buildMidStreamError(String inferenceEntityId, String message, Exception e) {
var errorResponse = OpenAiErrorResponse.fromString(message);
if (errorResponse instanceof OpenAiErrorResponse oer) {
return new UnifiedChatCompletionException(
RestStatus.INTERNAL_SERVER_ERROR,
format(
"%s for request from inference entity id [%s]. Error message: [%s]",
SERVER_ERROR_OBJECT,
- request.getInferenceEntityId(),
+ inferenceEntityId,
errorResponse.getErrorMessage()
),
oer.type(),
@@ -87,7 +86,7 @@ private static Exception buildMidStreamError(Request request, String message, Ex
} else {
return new UnifiedChatCompletionException(
RestStatus.INTERNAL_SERVER_ERROR,
format("%s for request from inference entity id [%s]", SERVER_ERROR_OBJECT, request.getInferenceEntityId()),
format("%s for request from inference entity id [%s]", SERVER_ERROR_OBJECT, inferenceEntityId),
errorResponse != null ? errorResponse.getClass().getSimpleName() : "unknown",
"stream_error"
);
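Keying buildMidStreamError on the inference entity id string instead of the whole Request, and widening it to public with the concrete UnifiedChatCompletionException return type, lets handlers that never construct a Request produce the same mid-stream error. A sketch with illustrative arguments:

```java
// Hypothetical reuse from another handler; only the entity id is needed now.
UnifiedChatCompletionException error = OpenAiUnifiedChatCompletionResponseHandler.buildMidStreamError(
    "my-sagemaker-endpoint", // inference entity id (illustrative)
    rawEventData,            // the raw event payload that failed to parse
    cause                    // the underlying exception
);
```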
OpenAiUnifiedStreamingProcessor.java
@@ -22,11 +22,10 @@

import java.io.IOException;
import java.util.ArrayDeque;
- import java.util.Collections;
import java.util.Deque;
- import java.util.Iterator;
import java.util.List;
import java.util.function.BiFunction;
+ import java.util.stream.Stream;

import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.elasticsearch.xpack.inference.external.response.XContentUtils.moveToFirstToken;
@@ -75,7 +74,7 @@ protected void next(Deque<ServerSentEvent> item) throws Exception {
} else if (event.hasData()) {
try {
var delta = parse(parserConfig, event);
- delta.forEachRemaining(results::offer);
+ delta.forEach(results::offer);
} catch (Exception e) {
logger.warn("Failed to parse event from inference provider: {}", event);
throw errorParser.apply(event.data(), e);
@@ -90,12 +89,12 @@ protected void next(Deque<ServerSentEvent> item) throws Exception {
}
}

- private static Iterator<StreamingUnifiedChatCompletionResults.ChatCompletionChunk> parse(
+ public static Stream<StreamingUnifiedChatCompletionResults.ChatCompletionChunk> parse(
XContentParserConfiguration parserConfig,
ServerSentEvent event
) throws IOException {
if (DONE_MESSAGE.equalsIgnoreCase(event.data())) {
- return Collections.emptyIterator();
+ return Stream.empty();
}

try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, event.data())) {
@@ -106,7 +105,7 @@ private static Iterator<StreamingUnifiedChatCompletionResults.ChatCompletionChun

StreamingUnifiedChatCompletionResults.ChatCompletionChunk chunk = ChatCompletionChunkParser.parse(jsonParser);

- return Collections.singleton(chunk).iterator();
+ return Stream.of(chunk);
}
}

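As with the completion processor above, parse turns public, static, and Stream-valued so a single SSE event can be converted into ChatCompletionChunks from outside the class; it still declares IOException, so callers handle or propagate that. Minimal sketch, where event is assumed to carry one OpenAI-style JSON chunk:

```java
// Sketch: one ServerSentEvent in, a Stream of chat-completion chunks out.
// (The enclosing method is assumed to declare IOException; results is a Deque.)
Stream<StreamingUnifiedChatCompletionResults.ChatCompletionChunk> chunks =
    OpenAiUnifiedStreamingProcessor.parse(XContentParserConfiguration.EMPTY, event);
chunks.forEach(results::offer); // e.g. buffer for downstream subscribers
```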
OpenAiChatCompletionResponseEntity.java
@@ -67,7 +67,11 @@ public class OpenAiChatCompletionResponseEntity {
*/

public static ChatCompletionResults fromResponse(Request request, HttpResult response) throws IOException {
- try (var p = XContentFactory.xContent(XContentType.JSON).createParser(XContentParserConfiguration.EMPTY, response.body())) {
+ return fromResponse(response.body());
+ }
+
+ public static ChatCompletionResults fromResponse(byte[] response) throws IOException {
+     try (var p = XContentFactory.xContent(XContentType.JSON).createParser(XContentParserConfiguration.EMPTY, response)) {
return CompletionResult.PARSER.apply(p, null).toChatCompletionResults();
}
}
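The new byte[] overload separates parsing from the HTTP layer, so a body that never came through an HttpResult (the SageMaker SDK hands back raw bytes) can reuse the same parser. A sketch; the JSON is illustrative and trimmed, and the fields actually required are whatever CompletionResult.PARSER expects:

```java
// Sketch: parse an OpenAI-format completion body directly from bytes.
byte[] body = "{\"choices\":[{\"message\":{\"content\":\"hello\"},\"finish_reason\":\"stop\"}]}"
    .getBytes(java.nio.charset.StandardCharsets.UTF_8);
ChatCompletionResults results = OpenAiChatCompletionResponseEntity.fromResponse(body);
```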
SageMakerService.java
@@ -47,6 +47,7 @@
public class SageMakerService implements InferenceService {
public static final String NAME = "sagemaker";
private static final int DEFAULT_BATCH_SIZE = 256;
+ private static final TimeValue DEFAULT_TIMEOUT = TimeValue.THIRTY_SECONDS;
private final SageMakerModelBuilder modelBuilder;
private final SageMakerClient client;
private final SageMakerSchemas schemas;
@@ -128,7 +129,7 @@ public void infer(
boolean stream,
Map<String, Object> taskSettings,
InputType inputType,
- TimeValue timeout,
+ @Nullable TimeValue timeout,
ActionListener<InferenceServiceResults> listener
) {
if (model instanceof SageMakerModel == false) {
@@ -148,7 +149,7 @@
client.invokeStream(
regionAndSecrets,
request,
- timeout,
+ timeout != null ? timeout : DEFAULT_TIMEOUT,
ActionListener.wrap(
response -> listener.onResponse(schema.streamResponse(sageMakerModel, response)),
e -> listener.onFailure(schema.error(sageMakerModel, e))
@@ -160,7 +161,7 @@
client.invoke(
regionAndSecrets,
request,
- timeout,
+ timeout != null ? timeout : DEFAULT_TIMEOUT,
ActionListener.wrap(
response -> listener.onResponse(schema.response(sageMakerModel, response, threadPool.getThreadContext())),
e -> listener.onFailure(schema.error(sageMakerModel, e))
@@ -201,7 +202,7 @@ private static ElasticsearchStatusException internalFailure(Model model, Excepti
public void unifiedCompletionInfer(
Model model,
UnifiedCompletionRequest request,
- TimeValue timeout,
+ @Nullable TimeValue timeout,
ActionListener<InferenceServiceResults> listener
) {
if (model instanceof SageMakerModel == false) {
@@ -217,7 +218,7 @@
client.invokeStream(
regionAndSecrets,
sagemakerRequest,
- timeout,
+ timeout != null ? timeout : DEFAULT_TIMEOUT,
ActionListener.wrap(
response -> listener.onResponse(schema.chatCompletionStreamResponse(sageMakerModel, response)),
e -> listener.onFailure(schema.chatCompletionError(sageMakerModel, e))
@@ -235,7 +236,7 @@ public void chunkedInfer(
List<ChunkInferenceInput> input,
Map<String, Object> taskSettings,
InputType inputType,
- TimeValue timeout,
+ @Nullable TimeValue timeout,
ActionListener<List<ChunkedInference>> listener
) {
if (model instanceof SageMakerModel == false) {
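With timeout now @Nullable, every invoke/invokeStream call site in this file substitutes DEFAULT_TIMEOUT (thirty seconds) when the caller passes null, which the unified chat-completion path may do. The pattern in isolation, written as a hypothetical helper rather than code from this PR:

```java
// Hypothetical helper equivalent to the inline null checks in this diff.
private static final TimeValue DEFAULT_TIMEOUT = TimeValue.THIRTY_SECONDS;

static TimeValue effectiveTimeout(@Nullable TimeValue timeout) {
    return timeout != null ? timeout : DEFAULT_TIMEOUT;
}
```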
SageMakerSchemas.java
@@ -12,10 +12,12 @@
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xpack.inference.services.sagemaker.model.SageMakerModel;
+ import org.elasticsearch.xpack.inference.services.sagemaker.schema.openai.OpenAiCompletionPayload;
import org.elasticsearch.xpack.inference.services.sagemaker.schema.openai.OpenAiTextEmbeddingPayload;

import java.util.Arrays;
import java.util.EnumSet;
+ import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
@@ -39,7 +41,7 @@ public class SageMakerSchemas {
/*
* Add new model API to the register call.
*/
- schemas = register(new OpenAiTextEmbeddingPayload());
+ schemas = register(new OpenAiTextEmbeddingPayload(), new OpenAiCompletionPayload());

streamSchemas = schemas.entrySet()
.stream()
@@ -88,7 +90,16 @@ public static List<NamedWriteableRegistry.Entry> namedWriteables() {
)
),
schemas.values().stream().flatMap(SageMakerSchema::namedWriteables)
- ).toList();
+ )
+     // Dedupe based on Entry name: Payloads may declare the same Entry, but the Registry does not handle duplicates
+     .collect(
+         () -> new HashMap<String, NamedWriteableRegistry.Entry>(),
+         (map, entry) -> map.putIfAbsent(entry.name, entry),
+         Map::putAll
+     )
+     .values()
+     .stream()
+     .toList();
}

public SageMakerSchema schemaFor(SageMakerModel model) throws ElasticsearchStatusException {
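The three-argument collect in namedWriteables() dedupes entries by Entry.name, keeping the first occurrence: the OpenAI embedding and completion payloads may declare the same entry, and per the inline comment the registry does not handle duplicates. The idiom in isolation, as a hypothetical helper:

```java
// Dedupe a stream of registry entries by name; the first entry per name wins.
static List<NamedWriteableRegistry.Entry> dedupeByName(Stream<NamedWriteableRegistry.Entry> entries) {
    return entries.collect(
        HashMap<String, NamedWriteableRegistry.Entry>::new,
        (map, entry) -> map.putIfAbsent(entry.name, entry), // keep the first occurrence
        Map::putAll                                         // combiner for parallel streams
    ).values().stream().toList();
}
```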