[ML] Integrate OpenAi Chat Completion in SageMaker #127767
@@ -0,0 +1,5 @@
+pr: 127767
+summary: Integrate `OpenAi` Chat Completion in `SageMaker`
+area: Machine Learning
+type: enhancement
+issues: []
@@ -47,6 +47,7 @@
 public class SageMakerService implements InferenceService {
     public static final String NAME = "sagemaker";
     private static final int DEFAULT_BATCH_SIZE = 256;
+    private static final TimeValue DEFAULT_TIMEOUT = TimeValue.THIRTY_SECONDS;
     private final SageMakerModelBuilder modelBuilder;
     private final SageMakerClient client;
     private final SageMakerSchemas schemas;
@@ -128,7 +129,7 @@ public void infer(
         boolean stream,
         Map<String, Object> taskSettings,
         InputType inputType,
-        TimeValue timeout,
+        @Nullable TimeValue timeout,
         ActionListener<InferenceServiceResults> listener
     ) {
         if (model instanceof SageMakerModel == false) {

Review thread on the `@Nullable TimeValue timeout` change:

Review comment:
    Hmm, I thought the timeout is defaulted in the referenced code (Line 182 in a4a2714). Can it be null here?

Reply:
    Yeah, I believe I was hitting an issue when I was using the referenced code (Line 49 in a4a2714).

Reply:
    We should be defaulting it there too, I think. I think we should consider it a bug if it's null once it gets to the
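The thread above is about where the null check should live. As a minimal sketch of the pattern being discussed (the class and method names here are illustrative, not code from this PR), the timeout could be resolved once, so that layers further down the call chain never have to re-check for null:

```java
import org.elasticsearch.core.TimeValue;

// Hypothetical helper, not part of this PR: resolve an optional per-request timeout once,
// so downstream code such as the service layer can assume it is always non-null.
final class InferenceTimeouts {
    static final TimeValue DEFAULT_TIMEOUT = TimeValue.THIRTY_SECONDS;

    private InferenceTimeouts() {}

    /** Returns the caller-supplied timeout, or the service default when none was provided. */
    static TimeValue resolve(TimeValue requested) {
        return requested != null ? requested : DEFAULT_TIMEOUT;
    }
}
```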
@@ -148,7 +149,7 @@ public void infer(
            client.invokeStream(
                regionAndSecrets,
                request,
-               timeout,
+               timeout != null ? timeout : DEFAULT_TIMEOUT,
                ActionListener.wrap(
                    response -> listener.onResponse(schema.streamResponse(sageMakerModel, response)),
                    e -> listener.onFailure(schema.error(sageMakerModel, e))
@@ -160,7 +161,7 @@ public void infer(
            client.invoke(
                regionAndSecrets,
                request,
-               timeout,
+               timeout != null ? timeout : DEFAULT_TIMEOUT,
                ActionListener.wrap(
                    response -> listener.onResponse(schema.response(sageMakerModel, response, threadPool.getThreadContext())),
                    e -> listener.onFailure(schema.error(sageMakerModel, e))
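For context on the `ActionListener.wrap` calls above: the first argument handles a successful response and the second converts a failure before it reaches the caller. A small, self-contained sketch of that pattern (the class and method names are illustrative, not from this PR):

```java
import org.elasticsearch.action.ActionListener;

// Illustrative only: adapt a listener expecting an Integer into one accepting a String,
// mapping both the success path and the failure path, much like the service does via its schema.
class ListenerWrapSketch {
    static ActionListener<String> adapt(ActionListener<Integer> downstream) {
        return ActionListener.wrap(
            response -> downstream.onResponse(response.length()),               // success path
            e -> downstream.onFailure(new RuntimeException("invoke failed", e)) // failure path
        );
    }
}
```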
@@ -201,7 +202,7 @@ private static ElasticsearchStatusException internalFailure(Model model, Exception
     public void unifiedCompletionInfer(
         Model model,
         UnifiedCompletionRequest request,
-        TimeValue timeout,
+        @Nullable TimeValue timeout,
         ActionListener<InferenceServiceResults> listener
     ) {
         if (model instanceof SageMakerModel == false) {
@@ -217,7 +218,7 @@ public void unifiedCompletionInfer(
            client.invokeStream(
                regionAndSecrets,
                sagemakerRequest,
-               timeout,
+               timeout != null ? timeout : DEFAULT_TIMEOUT,
                ActionListener.wrap(
                    response -> listener.onResponse(schema.chatCompletionStreamResponse(sageMakerModel, response)),
                    e -> listener.onFailure(schema.chatCompletionError(sageMakerModel, e))
@@ -235,7 +236,7 @@ public void chunkedInfer(
         List<ChunkInferenceInput> input,
         Map<String, Object> taskSettings,
         InputType inputType,
-        TimeValue timeout,
+        @Nullable TimeValue timeout,
         ActionListener<List<ChunkedInference>> listener
     ) {
         if (model instanceof SageMakerModel == false) {
Review comment:
    We've talked about switching to use 502s. Do you think that'd be appropriate here?

Reply:
    I don't think so: this IOException is an error with our parsing logic, which may or may not mean there is something wrong with their response. It could be that we're out of date.
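To make the distinction in this thread concrete, here is a rough sketch (not code from this PR; the class, method, and message strings are illustrative) of how a failure that clearly comes from the upstream endpoint could surface as a 502, while a parsing failure on our side stays a 500:

```java
import java.io.IOException;

import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.rest.RestStatus;

// Illustrative sketch: classify failures by where they originated.
class ErrorStatusSketch {
    static ElasticsearchStatusException toStatusException(Exception e, String inferenceEntityId) {
        if (e instanceof IOException) {
            // Our response-parsing logic failed; the upstream response may be fine and our
            // parser may simply be out of date, so report an internal (500) error.
            return new ElasticsearchStatusException(
                "Failed to parse SageMaker response for [{}]",
                RestStatus.INTERNAL_SERVER_ERROR,
                e,
                inferenceEntityId
            );
        }
        // The remote endpoint itself failed; a gateway-style 502 may be the better fit there.
        return new ElasticsearchStatusException(
            "SageMaker request for [{}] failed",
            RestStatus.BAD_GATEWAY,
            e,
            inferenceEntityId
        );
    }
}
```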