Skip to content

Commit cfdf79f

Browse files
committed
add ml processor for offline batch inference
Signed-off-by: Xun Zhang <[email protected]>
1 parent 7c72188 commit cfdf79f

File tree

20 files changed

+1227
-1
lines changed

20 files changed

+1227
-1
lines changed

Diff for: data-prepper-plugins/aws-lambda/build.gradle

-1
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@ dependencies {
2727
testCompileOnly 'org.projectlombok:lombok:1.18.20'
2828
testAnnotationProcessor 'org.projectlombok:lombok:1.18.20'
2929
testImplementation 'com.fasterxml.jackson.datatype:jackson-datatype-jsr310'
30-
testImplementation project(':data-prepper-test-common')
3130
testImplementation testLibs.slf4j.simple
3231
testImplementation 'org.mockito:mockito-core:4.6.1'
3332
testImplementation 'org.junit.jupiter:junit-jupiter-api:5.8.2'

Diff for: data-prepper-plugins/ml-processor/README.md

+38
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
2+
# ml Processor
3+
4+
This plugin enables you to send data from your Data Prepper pipeline directly to ml-commons for machine learning related activities.
5+
6+
## Usage
7+
```aidl
8+
lambda-pipeline:
9+
...
10+
processor:
11+
- ml:
12+
host: "https://search-xunzh-ml-tests-ihx7htldf7nvo2gdg25m6ehthq.us-east-1.es.amazonaws.com"
13+
aws_sigv4: true
14+
action_type: "batch_predict"
15+
service_name: "bedrock"
16+
model_id: "6ifdTZUBEBlFHJzvGSxO"
17+
output_path: "s3://offlinebatch/bedrock-multisource/output-multisource/"
18+
aws:
19+
region: "us-east-1"
20+
ml_when: /bucket == "offlinebatch"
21+
22+
```
23+
`model_id` as the model id that is registered in the OpenSearch ml-commons plugin.
24+
`service_name` as the remote AI service platform to process then batch job.
25+
`output_path` as the batch job output location of the S3 Uri
26+
27+
# Metrics
28+
29+
### Counter
30+
- `mlProcessorSuccessRequests`: measures total number of requests received and processed successfully by ml-processor.
31+
- `mlProcessorFailedRequests`: measures total number of requests failed by ml-processor.
32+
- `numberOfBatchJobsCreationSucceeded`: measures total number of batch jobs successfully created (200 response status code) by OpenSearch ml-commons API.
33+
- `numberOfBatchJobsCreationFailed`: measures total number of batch jobs failed in creation by OpenSearch ml-commons API.
34+
35+
## Developer Guide
36+
37+
The integration tests for this plugin do not run as part of the Data Prepper build.
38+
The following command runs the integration tests:

Diff for: data-prepper-plugins/ml-processor/build.gradle

+57
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
dependencies {
7+
implementation project(path: ':data-prepper-plugins:common')
8+
implementation project(':data-prepper-plugins:aws-plugin-api')
9+
implementation 'software.amazon.awssdk:sdk-core'
10+
implementation 'software.amazon.awssdk:sts'
11+
implementation 'io.micrometer:micrometer-core'
12+
implementation 'org.json:json'
13+
implementation 'com.fasterxml.jackson.datatype:jackson-datatype-jsr310'
14+
implementation 'org.projectlombok:lombok:1.18.22'
15+
implementation 'software.amazon.awssdk:s3'
16+
compileOnly 'org.projectlombok:lombok:1.18.20'
17+
annotationProcessor 'org.projectlombok:lombok:1.18.20'
18+
testCompileOnly 'org.projectlombok:lombok:1.18.20'
19+
testAnnotationProcessor 'org.projectlombok:lombok:1.18.20'
20+
testImplementation 'org.junit.jupiter:junit-jupiter-api:5.8.2'
21+
testRuntimeOnly 'org.junit.jupiter:junit-jupiter-engine:5.8.2'
22+
}
23+
24+
test {
25+
useJUnitPlatform()
26+
}
27+
28+
sourceSets {
29+
integrationTest {
30+
java {
31+
compileClasspath += main.output + test.output
32+
runtimeClasspath += main.output + test.output
33+
srcDir file('src/integrationTest/java')
34+
}
35+
resources.srcDir file('src/integrationTest/resources')
36+
}
37+
}
38+
39+
configurations {
40+
integrationTestImplementation.extendsFrom testImplementation
41+
integrationTestRuntime.extendsFrom testRuntime
42+
}
43+
44+
task integrationTest(type: Test) {
45+
group = 'verification'
46+
testClassesDirs = sourceSets.integrationTest.output.classesDirs
47+
48+
useJUnitPlatform()
49+
50+
classpath = sourceSets.integrationTest.runtimeClasspath
51+
52+
systemProperty 'log4j.configurationFile', 'src/test/resources/log4j2.properties'
53+
54+
filter {
55+
includeTestsMatching '*IT'
56+
}
57+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.dataprepper.plugins.ml.processor;
7+
8+
import io.micrometer.core.instrument.Counter;
9+
import org.opensearch.dataprepper.aws.api.AwsCredentialsSupplier;
10+
import org.opensearch.dataprepper.expression.ExpressionEvaluator;
11+
import org.opensearch.dataprepper.metrics.PluginMetrics;
12+
import org.opensearch.dataprepper.model.annotations.DataPrepperPlugin;
13+
import org.opensearch.dataprepper.model.annotations.DataPrepperPluginConstructor;
14+
import org.opensearch.dataprepper.model.annotations.Experimental;
15+
import org.opensearch.dataprepper.model.event.Event;
16+
import org.opensearch.dataprepper.model.processor.AbstractProcessor;
17+
import org.opensearch.dataprepper.model.processor.Processor;
18+
import org.opensearch.dataprepper.model.record.Record;
19+
import org.opensearch.dataprepper.plugins.ml.processor.common.MLBatchJobCreator;
20+
import org.opensearch.dataprepper.plugins.ml.processor.common.MLBatchJobCreatorFactory;
21+
import org.opensearch.dataprepper.plugins.ml.processor.configuration.ServiceName;
22+
import org.slf4j.Logger;
23+
import org.slf4j.LoggerFactory;
24+
25+
import java.util.Collection;
26+
import java.util.List;
27+
import java.util.stream.Collectors;
28+
29+
import static org.opensearch.dataprepper.logging.DataPrepperMarkers.NOISY;
30+
31+
@Experimental
32+
@DataPrepperPlugin(name = "ml", pluginType = Processor.class, pluginConfigurationType = MLProcessorConfig.class)
33+
public class MLProcessor extends AbstractProcessor<Record<Event>, Record<Event>> {
34+
public static final Logger LOG = LoggerFactory.getLogger(MLProcessor.class);
35+
public static final String NUMBER_OF_ML_PROCESSOR_SUCCESS = "mlProcessorSuccessfullyCreated";
36+
public static final String NUMBER_OF_ML_PROCESSOR_FAILED = "mlProcessorFailedToCreated";
37+
38+
private final String whenCondition;
39+
private MLBatchJobCreator mlBatchJobCreator;
40+
private final Counter numberOfMLProcessorSuccessCounter;
41+
private final Counter numberOfMLProcessorFailedCounter;
42+
private final ExpressionEvaluator expressionEvaluator;
43+
44+
@DataPrepperPluginConstructor
45+
public MLProcessor(final MLProcessorConfig mlProcessorConfig, final PluginMetrics pluginMetrics, final AwsCredentialsSupplier awsCredentialsSupplier, final ExpressionEvaluator expressionEvaluator) {
46+
super(pluginMetrics);
47+
this.whenCondition = mlProcessorConfig.getWhenCondition();
48+
ServiceName serviceName = mlProcessorConfig.getServiceName();
49+
this.numberOfMLProcessorSuccessCounter = pluginMetrics.counter(
50+
NUMBER_OF_ML_PROCESSOR_SUCCESS);
51+
this.numberOfMLProcessorFailedCounter = pluginMetrics.counter(
52+
NUMBER_OF_ML_PROCESSOR_FAILED);
53+
this.expressionEvaluator = expressionEvaluator;
54+
55+
// Use factory to get the appropriate job creator
56+
mlBatchJobCreator = MLBatchJobCreatorFactory.getJobCreator(serviceName, mlProcessorConfig, awsCredentialsSupplier, pluginMetrics);
57+
}
58+
59+
@Override
60+
public Collection<Record<Event>> doExecute(Collection<Record<Event>> records) {
61+
// reads from input - S3 input
62+
if (records.size() == 0)
63+
return records;
64+
65+
List<Record<Event>> recordsToMlCommons = records.stream()
66+
.filter(record -> whenCondition == null || expressionEvaluator.evaluateConditional(whenCondition, record.getData()))
67+
.collect(Collectors.toList());
68+
69+
if (recordsToMlCommons.isEmpty()) {
70+
return records;
71+
}
72+
73+
try {
74+
mlBatchJobCreator.createMLBatchJob(recordsToMlCommons);
75+
numberOfMLProcessorSuccessCounter.increment();
76+
} catch (Exception e) {
77+
LOG.error(NOISY, e.getMessage(), e);
78+
numberOfMLProcessorFailedCounter.increment();
79+
}
80+
return records;
81+
}
82+
83+
@Override
84+
public void prepareForShutdown() {
85+
}
86+
87+
@Override
88+
public boolean isReadyForShutdown() {
89+
return true;
90+
}
91+
92+
@Override
93+
public void shutdown() {
94+
}
95+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.dataprepper.plugins.ml.processor;
7+
8+
import com.fasterxml.jackson.annotation.JsonClassDescription;
9+
import com.fasterxml.jackson.annotation.JsonProperty;
10+
import com.fasterxml.jackson.annotation.JsonPropertyDescription;
11+
import com.fasterxml.jackson.annotation.JsonPropertyOrder;
12+
import jakarta.validation.Valid;
13+
import jakarta.validation.constraints.NotNull;
14+
import lombok.Getter;
15+
import org.opensearch.dataprepper.model.annotations.ExampleValues;
16+
import org.opensearch.dataprepper.plugins.ml.processor.configuration.ActionType;
17+
import org.opensearch.dataprepper.plugins.ml.processor.configuration.AwsAuthenticationOptions;
18+
import org.opensearch.dataprepper.plugins.ml.processor.configuration.ServiceName;
19+
20+
@Getter
21+
@JsonPropertyOrder
22+
@JsonClassDescription("The <code>ml</code> processor enables invocation of the ml-commons plugin in OpenSearch service within your pipeline in order to process events. " +
23+
"It supports both synchronous and asynchronous invocations based on your use case.")
24+
public class MLProcessorConfig {
25+
26+
@JsonProperty("aws")
27+
@NotNull
28+
@Valid
29+
private AwsAuthenticationOptions awsAuthenticationOptions;
30+
31+
@JsonPropertyDescription("action type defines the way we want to invoke ml-commons in the predict API")
32+
@JsonProperty("action_type")
33+
private ActionType actionType = ActionType.BATCH_PREDICT;
34+
35+
@JsonPropertyDescription("AI service hosting the remote model for ML Commons predictions")
36+
@JsonProperty("service_name")
37+
private ServiceName serviceName = ServiceName.SAGEMAKER;
38+
39+
@JsonPropertyDescription("defines the OpenSearch host url to be invoked")
40+
@JsonProperty("host")
41+
private String hostUrl;
42+
43+
@JsonPropertyDescription("defines the model id to be invoked in ml-commons")
44+
@JsonProperty("model_id")
45+
private String modelId;
46+
47+
@JsonPropertyDescription("defines the S3 location to write the offline model responses to")
48+
@JsonProperty("output_path")
49+
private String outputPath;
50+
51+
@JsonProperty("aws_sigv4")
52+
private boolean awsSigv4;
53+
54+
@JsonPropertyDescription("Defines a condition for event to use this processor.")
55+
@ExampleValues({
56+
@ExampleValues.Example(value = "/some_key == null", description = "The processor will only run on events where this condition evaluates to true.")
57+
})
58+
@JsonProperty("ml_when")
59+
private String whenCondition;
60+
61+
public ActionType getActionType() {
62+
return actionType;
63+
}
64+
65+
public String getModelId() { return modelId; }
66+
67+
public String getHostUrl() { return hostUrl; }
68+
69+
public String getWhenCondition() {
70+
return whenCondition;
71+
}
72+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.dataprepper.plugins.ml.processor.client;
7+
8+
import org.opensearch.dataprepper.aws.api.AwsCredentialsOptions;
9+
import org.opensearch.dataprepper.aws.api.AwsCredentialsSupplier;
10+
import org.opensearch.dataprepper.plugins.ml.processor.MLProcessorConfig;
11+
import org.opensearch.dataprepper.plugins.ml.processor.configuration.AwsAuthenticationOptions;
12+
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
13+
import software.amazon.awssdk.core.client.config.ClientOverrideConfiguration;
14+
import software.amazon.awssdk.regions.Region;
15+
import software.amazon.awssdk.services.s3.S3Client;
16+
17+
public class S3ClientFactory {
18+
19+
public static S3Client createS3Client(final MLProcessorConfig mlProcessorConfig, final AwsCredentialsSupplier awsCredentialsSupplier) {
20+
final AwsCredentialsOptions awsCredentialsOptions = convertToCredentialsOptions(
21+
mlProcessorConfig.getAwsAuthenticationOptions());
22+
final Region region = mlProcessorConfig.getAwsAuthenticationOptions().getAwsRegion();
23+
final AwsCredentialsProvider awsCredentialsProvider = awsCredentialsSupplier.getProvider(
24+
awsCredentialsOptions);
25+
26+
return S3Client.builder()
27+
.region(region)
28+
.credentialsProvider(awsCredentialsProvider)
29+
.overrideConfiguration(ClientOverrideConfiguration.builder()
30+
.retryPolicy(retryPolicy -> retryPolicy.numRetries(5).build())
31+
.build())
32+
.build();
33+
}
34+
35+
public static AwsCredentialsOptions convertToCredentialsOptions(
36+
final AwsAuthenticationOptions awsAuthenticationOptions) {
37+
if (awsAuthenticationOptions == null || awsAuthenticationOptions.getAwsStsRoleArn() == null) {
38+
return AwsCredentialsOptions.defaultOptionsWithDefaultCredentialsProvider();
39+
}
40+
return AwsCredentialsOptions.builder()
41+
.withRegion(awsAuthenticationOptions.getAwsRegion())
42+
.withStsRoleArn(awsAuthenticationOptions.getAwsStsRoleArn())
43+
.withStsExternalId(awsAuthenticationOptions.getAwsStsExternalId())
44+
.withStsHeaderOverrides(awsAuthenticationOptions.getAwsStsHeaderOverrides())
45+
.build();
46+
}
47+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.dataprepper.plugins.ml.processor.common;
7+
8+
import com.fasterxml.jackson.databind.ObjectMapper;
9+
import org.opensearch.dataprepper.aws.api.AwsCredentialsSupplier;
10+
import io.micrometer.core.instrument.Counter;
11+
import org.opensearch.dataprepper.metrics.PluginMetrics;
12+
import org.opensearch.dataprepper.model.event.Event;
13+
import org.opensearch.dataprepper.model.record.Record;
14+
import org.opensearch.dataprepper.plugins.ml.processor.MLProcessorConfig;
15+
16+
import java.util.Collection;
17+
18+
public abstract class AbstractBatchJobCreator implements MLBatchJobCreator {
19+
public static final String NUMBER_OF_SUCCESSFUL_BATCH_JOBS_CREATION = "numberOfBatchJobsCreationSucceeded";
20+
public static final String NUMBER_OF_FAILED_BATCH_JOBS_CREATION = "numberOfBatchJobsCreationFailed";
21+
22+
protected static final ObjectMapper OBJECT_MAPPER = new ObjectMapper();
23+
protected final MLProcessorConfig mlProcessorConfig;
24+
protected final AwsCredentialsSupplier awsCredentialsSupplier;
25+
protected final Counter numberOfBatchJobsSuccessCounter;
26+
protected final Counter numberOfBatchJobsFailedCounter;
27+
28+
// Constructor
29+
public AbstractBatchJobCreator(MLProcessorConfig mlProcessorConfig,
30+
AwsCredentialsSupplier awsCredentialsSupplier,
31+
final PluginMetrics pluginMetrics) {
32+
this.mlProcessorConfig = mlProcessorConfig;
33+
this.awsCredentialsSupplier = awsCredentialsSupplier;
34+
this.numberOfBatchJobsSuccessCounter = pluginMetrics.counter(NUMBER_OF_SUCCESSFUL_BATCH_JOBS_CREATION);
35+
this.numberOfBatchJobsFailedCounter = pluginMetrics.counter(NUMBER_OF_FAILED_BATCH_JOBS_CREATION);
36+
}
37+
38+
// Add common logic here that both subclasses can share
39+
public void incrementSuccessCounter() {
40+
numberOfBatchJobsSuccessCounter.increment();
41+
}
42+
43+
public void incrementFailureCounter() {
44+
numberOfBatchJobsFailedCounter.increment();
45+
}
46+
47+
// Abstract methods for batch job creation, specific to the implementations
48+
public abstract void createMLBatchJob(Collection<Record<Event>> records);
49+
50+
}

0 commit comments

Comments
 (0)