Skip to content

Commit 2496ab9

Browse files
committed
add more UTs and address comments
Signed-off-by: Xun Zhang <[email protected]>
1 parent cfdf79f commit 2496ab9

File tree

11 files changed

+433
-19
lines changed

11 files changed

+433
-19
lines changed

data-prepper-plugins/ml-processor/src/main/java/org/opensearch/dataprepper/plugins/ml/processor/MLProcessor.java

-1
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@
2828

2929
import static org.opensearch.dataprepper.logging.DataPrepperMarkers.NOISY;
3030

31-
@Experimental
3231
@DataPrepperPlugin(name = "ml", pluginType = Processor.class, pluginConfigurationType = MLProcessorConfig.class)
3332
public class MLProcessor extends AbstractProcessor<Record<Event>, Record<Event>> {
3433
public static final Logger LOG = LoggerFactory.getLogger(MLProcessor.class);

data-prepper-plugins/ml-processor/src/main/java/org/opensearch/dataprepper/plugins/ml/processor/common/AbstractBatchJobCreator.java

+2
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,10 @@ public AbstractBatchJobCreator(MLProcessorConfig mlProcessorConfig,
3131
final PluginMetrics pluginMetrics) {
3232
this.mlProcessorConfig = mlProcessorConfig;
3333
this.awsCredentialsSupplier = awsCredentialsSupplier;
34+
System.out.println("AbstractBatchJobCreator constructor called");
3435
this.numberOfBatchJobsSuccessCounter = pluginMetrics.counter(NUMBER_OF_SUCCESSFUL_BATCH_JOBS_CREATION);
3536
this.numberOfBatchJobsFailedCounter = pluginMetrics.counter(NUMBER_OF_FAILED_BATCH_JOBS_CREATION);
37+
System.out.println("numberOfBatchJobsSuccessCounter is " + numberOfBatchJobsSuccessCounter.count());
3638
}
3739

3840
// Add common logic here that both subclasses can share

data-prepper-plugins/ml-processor/src/main/java/org/opensearch/dataprepper/plugins/ml/processor/common/SageMakerBatchJobCreator.java

+2-2
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ public void createMLBatchJob(Collection<Record<Event>> records) {
6060
})
6161
.orElse(null); // Use null if no record is found
6262
String commonPrefix = findCommonPrefix(records);
63-
String manifestUrl = generateManifest(records, customerBucket, commonPrefix, mlProcessorConfig);
63+
String manifestUrl = generateManifest(records, customerBucket, commonPrefix);
6464
String payload = createPayloadSageMaker(manifestUrl, mlProcessorConfig);
6565

6666
boolean success = retryWithBackoff(() -> sendRequestToMLCommons(payload, mlProcessorConfig, awsCredentialsSupplier));
@@ -103,7 +103,7 @@ private String findCommonPrefix(String s1, String s2) {
103103
}
104104

105105

106-
private String generateManifest(Collection<Record<Event>> records, String customerBucket, String prefix, MLProcessorConfig mlProcessorConfig) {
106+
private String generateManifest(Collection<Record<Event>> records, String customerBucket, String prefix) {
107107
try {
108108
// Generate timestamp
109109
String timestamp = new SimpleDateFormat("yyyyMMddHHmmss").format(new Date());
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.dataprepper.plugins.ml.processor.util;
7+
8+
import software.amazon.awssdk.http.HttpExecuteRequest;
9+
import software.amazon.awssdk.http.HttpExecuteResponse;
10+
11+
public interface HttpClientExecutor {
12+
HttpExecuteResponse execute(HttpExecuteRequest executeRequest) throws Exception;
13+
}

data-prepper-plugins/ml-processor/src/main/java/org/opensearch/dataprepper/plugins/ml/processor/util/MlCommonRequester.java

+10-13
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
package org.opensearch.dataprepper.plugins.ml.processor.util;
77

8+
import com.google.common.annotations.VisibleForTesting;
89
import org.opensearch.dataprepper.aws.api.AwsCredentialsOptions;
910
import org.opensearch.dataprepper.aws.api.AwsCredentialsSupplier;
1011
import org.opensearch.dataprepper.plugins.ml.processor.MLProcessor;
@@ -14,29 +15,33 @@
1415
import software.amazon.awssdk.auth.credentials.*;
1516
import software.amazon.awssdk.auth.signer.Aws4Signer;
1617
import software.amazon.awssdk.auth.signer.params.Aws4SignerParams;
17-
import software.amazon.awssdk.core.internal.http.loader.DefaultSdkHttpClientBuilder;
1818
import software.amazon.awssdk.core.sync.RequestBody;
1919
import software.amazon.awssdk.http.*;
2020
import software.amazon.awssdk.regions.Region;
21-
import software.amazon.awssdk.utils.AttributeMap;
2221

2322
import java.io.BufferedReader;
2423
import java.io.IOException;
2524
import java.io.InputStreamReader;
2625
import java.net.URI;
2726
import java.nio.charset.StandardCharsets;
28-
import java.time.Duration;
2927
import java.util.stream.Collectors;
3028

3129
import static org.opensearch.dataprepper.plugins.ml.processor.client.S3ClientFactory.convertToCredentialsOptions;
3230

3331
public class MlCommonRequester {
3432
private static final Aws4Signer signer;
3533
private static final Logger LOG = LoggerFactory.getLogger(MLProcessor.class);
34+
private static HttpClientExecutor httpClientExecutor = new SdkHttpClientExecutor();
35+
3636
static {
3737
signer = Aws4Signer.create();
3838
}
3939

40+
@VisibleForTesting
41+
static void setHttpClientExecutor(HttpClientExecutor executor) {
42+
httpClientExecutor = executor;
43+
}
44+
4045
public static void sendRequestToMLCommons(String payload, MLProcessorConfig mlProcessorConfig, final AwsCredentialsSupplier awsCredentialsSupplier) {
4146
String host = mlProcessorConfig.getHostUrl();
4247
String modelId = mlProcessorConfig.getModelId();
@@ -65,20 +70,12 @@ public static void sendRequestToMLCommons(String payload, MLProcessorConfig mlPr
6570
}
6671

6772
private static void executeHttpRequest(HttpExecuteRequest executeRequest) {
68-
AttributeMap attributeMap = AttributeMap.builder()
69-
.put(SdkHttpConfigurationOption.CONNECTION_TIMEOUT, Duration.ofMillis(30000))
70-
.put(SdkHttpConfigurationOption.READ_TIMEOUT, Duration.ofMillis(3000))
71-
.put(SdkHttpConfigurationOption.MAX_CONNECTIONS, 10)
72-
.build();
73-
SdkHttpClient httpClient = new DefaultSdkHttpClientBuilder().buildWithDefaults(attributeMap);
74-
7573
try {
76-
HttpExecuteResponse response = httpClient.prepareRequest(executeRequest).call();
77-
System.out.println("Making HTTP call to ML Commons...");
74+
HttpExecuteResponse response = httpClientExecutor.execute(executeRequest);
7875

7976
handleHttpResponse(response);
8077
} catch (Exception e) { // TODO: catch different exceptions and retry
81-
throw new RuntimeException("Failed to execute request in AWS connector", e);
78+
throw new RuntimeException("Failed to execute HTTP request using the ML Commons model", e);
8279
}
8380
}
8481

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.dataprepper.plugins.ml.processor.util;
7+
8+
import software.amazon.awssdk.core.internal.http.loader.DefaultSdkHttpClientBuilder;
9+
import software.amazon.awssdk.http.HttpExecuteRequest;
10+
import software.amazon.awssdk.http.HttpExecuteResponse;
11+
import software.amazon.awssdk.http.SdkHttpClient;
12+
import software.amazon.awssdk.http.SdkHttpConfigurationOption;
13+
import software.amazon.awssdk.utils.AttributeMap;
14+
15+
import java.time.Duration;
16+
17+
public class SdkHttpClientExecutor implements HttpClientExecutor {
18+
private final SdkHttpClient httpClient;
19+
20+
public SdkHttpClientExecutor() {
21+
AttributeMap attributeMap = AttributeMap.builder()
22+
.put(SdkHttpConfigurationOption.CONNECTION_TIMEOUT, Duration.ofMillis(30000))
23+
.put(SdkHttpConfigurationOption.READ_TIMEOUT, Duration.ofMillis(3000))
24+
.put(SdkHttpConfigurationOption.MAX_CONNECTIONS, 10)
25+
.build();
26+
this.httpClient = new DefaultSdkHttpClientBuilder().buildWithDefaults(attributeMap);
27+
}
28+
29+
@Override
30+
public HttpExecuteResponse execute(HttpExecuteRequest executeRequest) throws Exception {
31+
return httpClient.prepareRequest(executeRequest).call();
32+
}
33+
}

data-prepper-plugins/ml-processor/src/test/java/org/opensearch/dataprepper/plugins/ml/processor/client/S3ClientFactoryTest.java

-2
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,6 @@ public class S3ClientFactoryTest {
4141
@Mock
4242
private AwsCredentialsProvider awsCredentialsProvider;
4343

44-
@InjectMocks
45-
private S3ClientFactory s3ClientFactory;
4644

4745
@BeforeEach
4846
void setUp() {
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.dataprepper.plugins.ml.processor.common;
7+
8+
import io.micrometer.core.instrument.Counter;
9+
import org.junit.jupiter.api.BeforeEach;
10+
import org.junit.jupiter.api.Test;
11+
import org.mockito.Mock;
12+
import org.mockito.MockitoAnnotations;
13+
import org.opensearch.dataprepper.aws.api.AwsCredentialsSupplier;
14+
import org.opensearch.dataprepper.metrics.PluginMetrics;
15+
import org.opensearch.dataprepper.model.event.Event;
16+
import org.opensearch.dataprepper.model.record.Record;
17+
import org.opensearch.dataprepper.plugins.ml.processor.MLProcessorConfig;
18+
import org.opensearch.dataprepper.plugins.ml.processor.util.RetryUtil;
19+
20+
import java.util.Arrays;
21+
22+
import static org.junit.jupiter.api.Assertions.assertTrue;
23+
import static org.mockito.Mockito.*;
24+
import static org.opensearch.dataprepper.plugins.ml.processor.common.AbstractBatchJobCreator.*;
25+
26+
public class BedrockBatchJobCreatorTest {
27+
@Mock
28+
private MLProcessorConfig mlProcessorConfig;
29+
30+
@Mock
31+
private AwsCredentialsSupplier awsCredentialsSupplier;
32+
33+
@Mock
34+
private PluginMetrics pluginMetrics;
35+
36+
private BedrockBatchJobCreator bedrockBatchJobCreator;
37+
private Counter counter;;
38+
39+
@BeforeEach
40+
void setUp() {
41+
MockitoAnnotations.openMocks(this);
42+
when(mlProcessorConfig.getOutputPath()).thenReturn("s3://offlinebatch/output");
43+
counter = new Counter() {
44+
@Override
45+
public void increment(double v) {}
46+
47+
@Override
48+
public double count() {
49+
return 0;
50+
}
51+
52+
@Override
53+
public Id getId() {
54+
return null;
55+
}
56+
};
57+
when(pluginMetrics.counter(NUMBER_OF_SUCCESSFUL_BATCH_JOBS_CREATION)).thenReturn(counter);
58+
when(pluginMetrics.counter(NUMBER_OF_FAILED_BATCH_JOBS_CREATION)).thenReturn(counter);
59+
bedrockBatchJobCreator = spy(new BedrockBatchJobCreator(mlProcessorConfig, awsCredentialsSupplier, pluginMetrics));
60+
}
61+
62+
@Test
63+
void testCreateMLBatchJob_Success() {
64+
Event event = mock(Event.class);
65+
Record<Event> record = new Record<>(event);
66+
67+
when(event.getJsonNode()).thenReturn(OBJECT_MAPPER.createObjectNode()
68+
.put("bucket", "test-bucket")
69+
.put("key", "input.jsonl"));
70+
71+
mockStatic(RetryUtil.class);
72+
when(RetryUtil.retryWithBackoff(any())).thenReturn(true);
73+
74+
bedrockBatchJobCreator.createMLBatchJob(Arrays.asList(record));
75+
76+
verify(bedrockBatchJobCreator, times(1)).incrementSuccessCounter();
77+
}
78+
79+
@Test
80+
void testCreateMLBatchJob_Failure() {
81+
Event event = mock(Event.class);
82+
Record<Event> record = new Record<>(event);
83+
84+
when(event.getJsonNode()).thenReturn(OBJECT_MAPPER.createObjectNode()
85+
.put("bucket", "test-bucket")
86+
.put("key", "input.jsonl"));
87+
88+
mockStatic(RetryUtil.class);
89+
when(RetryUtil.retryWithBackoff(any())).thenReturn(false);
90+
91+
bedrockBatchJobCreator.createMLBatchJob(Arrays.asList(record));
92+
93+
verify(bedrockBatchJobCreator, times(1)).incrementFailureCounter();
94+
}
95+
96+
@Test
97+
void testInterruptedExceptionHandling() throws InterruptedException {
98+
Event event = mock(Event.class);
99+
Record<Event> record = new Record<>(event);
100+
101+
when(event.getJsonNode()).thenReturn(OBJECT_MAPPER.createObjectNode()
102+
.put("bucket", "test-bucket")
103+
.put("key", "input.jsonl"));
104+
105+
mockStatic(RetryUtil.class);
106+
when(RetryUtil.retryWithBackoff(any())).thenReturn(true);
107+
108+
Thread.currentThread().interrupt();
109+
bedrockBatchJobCreator.createMLBatchJob(Arrays.asList(record));
110+
111+
assertTrue(Thread.interrupted()); // Ensure interrupted flag is reset
112+
}
113+
}

0 commit comments

Comments
 (0)