Skip to content

Add Custom Auth Provider with support for gRPC, plus tests and exception #5578

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: common-server-builder-and-auth-module
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
package org.opensearch.dataprepper;

import com.google.protobuf.Any;
import com.linecorp.armeria.common.RequestContext;
import com.linecorp.armeria.common.annotation.Nullable;
import com.linecorp.armeria.common.grpc.GoogleGrpcExceptionHandlerFunction;
import com.linecorp.armeria.server.RequestTimeoutException;
import io.grpc.Metadata;
import io.grpc.Status;
import io.grpc.StatusRuntimeException;
import io.micrometer.core.instrument.Counter;

import org.opensearch.dataprepper.exceptions.BadRequestException;
import org.opensearch.dataprepper.exceptions.BufferWriteException;
import org.opensearch.dataprepper.exceptions.RequestCancelledException;
import org.opensearch.dataprepper.metrics.PluginMetrics;
import org.opensearch.dataprepper.model.buffer.SizeOverflowException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.time.Duration;
import java.util.concurrent.TimeoutException;

public class CustomAuthenticationExceptionHandler implements GoogleGrpcExceptionHandlerFunction {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need an additional exception handler and where is this class actually used? It contains a lot of duplicated code, too.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the feedback! That exception handler was only used in one test. Since it's no longer needed, I’ve safely removed it to avoid duplication and keep the codebase clean.

private static final Logger LOG = LoggerFactory.getLogger(CustomAuthenticationExceptionHandler.class);
private static final String TIMEOUT_MESSAGE = "Request timed out. Check buffer availability or processing delays.";

public static final String REQUEST_TIMEOUTS = "customAuthRequestTimeouts";
public static final String BAD_REQUESTS = "customAuthBadRequests";
public static final String REQUESTS_TOO_LARGE = "customAuthRequestsTooLarge";
public static final String INTERNAL_SERVER_ERROR = "customAuthInternalServerError";

private final Counter requestTimeoutsCounter;
private final Counter badRequestsCounter;
private final Counter requestsTooLargeCounter;
private final Counter internalServerErrorCounter;
private final GrpcRetryInfoCalculator retryInfoCalculator;

public CustomAuthenticationExceptionHandler(final PluginMetrics pluginMetrics,
final Duration retryInfoMinDelay,
final Duration retryInfoMaxDelay) {
this.requestTimeoutsCounter = pluginMetrics.counter(REQUEST_TIMEOUTS);
this.badRequestsCounter = pluginMetrics.counter(BAD_REQUESTS);
this.requestsTooLargeCounter = pluginMetrics.counter(REQUESTS_TOO_LARGE);
this.internalServerErrorCounter = pluginMetrics.counter(INTERNAL_SERVER_ERROR);
this.retryInfoCalculator = new GrpcRetryInfoCalculator(retryInfoMinDelay, retryInfoMaxDelay);
}

@Override
public com.google.rpc.@Nullable Status applyStatusProto(RequestContext ctx, Throwable throwable, Metadata metadata) {
final Throwable actualCause = (throwable instanceof BufferWriteException)
? throwable.getCause() : throwable;
return handleException(actualCause);
}

private com.google.rpc.Status handleException(Throwable e) {
final String msg = e.getMessage();
if (e instanceof RequestTimeoutException || e instanceof TimeoutException) {
requestTimeoutsCounter.increment();
return buildStatus(e, Status.Code.RESOURCE_EXHAUSTED);
} else if (e instanceof SizeOverflowException) {
requestsTooLargeCounter.increment();
return buildStatus(e, Status.Code.RESOURCE_EXHAUSTED);
} else if (e instanceof BadRequestException) {
badRequestsCounter.increment();
return buildStatus(e, Status.Code.INVALID_ARGUMENT);
} else if ((e instanceof StatusRuntimeException) &&
(msg.contains("Invalid protobuf byte sequence") || msg.contains("Can't decode compressed frame"))) {
badRequestsCounter.increment();
return buildStatus(e, Status.Code.INVALID_ARGUMENT);
} else if (e instanceof RequestCancelledException) {
requestTimeoutsCounter.increment();
return buildStatus(e, Status.Code.CANCELLED);
}

internalServerErrorCounter.increment();
LOG.error("CustomAuth gRPC handler caught unexpected exception", e);
return buildStatus(e, Status.Code.INTERNAL);
}

private com.google.rpc.Status buildStatus(Throwable e, Status.Code code) {
com.google.rpc.Status.Builder builder = com.google.rpc.Status.newBuilder()
.setCode(code.value());

if (e instanceof RequestTimeoutException) {
builder.setMessage(TIMEOUT_MESSAGE);
} else {
builder.setMessage(e.getMessage() != null ? e.getMessage() : code.name());
}

if (code == Status.Code.RESOURCE_EXHAUSTED) {
builder.addDetails(Any.pack(retryInfoCalculator.createRetryInfo()));
}

return builder.build();
}
}

Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.dataprepper.armeria.authentication;

import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty;

public class CustomAuthenticationConfig {
private final String customToken;

@JsonCreator
public CustomAuthenticationConfig(
@JsonProperty("custom_token") String customToken) {
this.customToken = customToken;
}

public String customToken() {
return customToken;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.dataprepper.armeria.authentication;

import com.linecorp.armeria.server.HttpService;
import io.grpc.ServerInterceptor;

import java.util.Optional;
import java.util.function.Function;

public interface CustomAuthenticationProvider {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe this is intended to be a test class, right?

Move this into the src/test/ directory and rename it to something like TestCustomAuthenticationProvider.

Do the same with the other classes you created in src/main.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Quick question: implementing CustomGrpcAuthenticationProvider is to allow Data Prepper to support token-based custom authentication in production, as a simulation of SigV4. If we more classes into src/test. can we still able to load it during integration tests or real pipelines?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The difference between src/main and src/test is what is on the classpath during production and during integration tests.

The Data Prepper plugin framework just needs the plugin to be in the Java classpath and in the specified Java package. Moving these to src/test will allow you to use them in integration testing. But, you cannot run them in a real Data Prepper pipeline. I think this is what we generally want. Now, there could be value in running them for testing purposes. But, I'd think we want to isolate them somehow before doing that. We could do that by using a dedicated Java package.

I'd say for now we just support running these in integration tests.


String UNAUTHENTICATED_PLUGIN_NAME = "unauthenticated";


ServerInterceptor getAuthenticationInterceptor();

default Optional<Function<? super HttpService, ? extends HttpService>> getHttpAuthenticationService() {
return Optional.empty();
}
}

Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.dataprepper.plugins;

import com.linecorp.armeria.common.HttpResponse;
import com.linecorp.armeria.common.HttpStatus;
import com.linecorp.armeria.common.MediaType;
import com.linecorp.armeria.server.HttpService;
import io.grpc.Metadata;
import io.grpc.ServerCall;
import io.grpc.ServerCallHandler;
import io.grpc.ServerInterceptor;
import io.grpc.Status;
import org.opensearch.dataprepper.armeria.authentication.CustomAuthenticationConfig;
import org.opensearch.dataprepper.armeria.authentication.GrpcAuthenticationProvider;
import org.opensearch.dataprepper.model.annotations.DataPrepperPlugin;
import org.opensearch.dataprepper.model.annotations.DataPrepperPluginConstructor;

import java.util.Optional;
import java.util.function.Function;

@DataPrepperPlugin(
name = "custom_auth",
pluginType = GrpcAuthenticationProvider.class,
pluginConfigurationType = CustomAuthenticationConfig.class
)
public class CustomGrpcAuthenticationProvider implements GrpcAuthenticationProvider {
private final String token;
private static final String AUTH_HEADER = "authentication";
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you make the header name configurable?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Made the header name configurable via pipeline configuration



@DataPrepperPluginConstructor
public CustomGrpcAuthenticationProvider(final CustomAuthenticationConfig config) {
this.token = config.customToken();
}

@Override
public ServerInterceptor getAuthenticationInterceptor() {
return new ServerInterceptor() {
@Override
public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(
ServerCall<ReqT, RespT> call,
Metadata headers,
ServerCallHandler<ReqT, RespT> next) {

String auth = headers.get(Metadata.Key.of("authentication", Metadata.ASCII_STRING_MARSHALLER));
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't this at least use the constant AUTH_HEADER?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You are right. I have made the change.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not in this PR, yet.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Made the change


if (auth == null || !auth.equals(token)) {
call.close(Status.UNAUTHENTICATED.withDescription("Invalid token"), new Metadata());
return new ServerCall.Listener<>() {};
}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What if I have a different kind of token, a JWT for example?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the suggestion! For now, this implementation is meant for simple token validation. JWT support can be considered in future iterations.


return next.startCall(call, headers);
}
};
}

@Override
public Optional<Function<? super HttpService, ? extends HttpService>> getHttpAuthenticationService() {
return Optional.of(delegate -> (ctx, req) -> {
final String auth = req.headers().get(AUTH_HEADER);
if (auth == null || !auth.equals(token)) {
return HttpResponse.of(
HttpStatus.UNAUTHORIZED,
MediaType.PLAIN_TEXT_UTF_8,
"Unauthorized: Invalid or missing token"
);
}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Isn't this essentially the same code as ll. 49-54. Why are there differences? Can't this be a single implementation?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for pointing out. These two similar logic: one for gRPC headers using Armeria -> Metadata; another one for HTTP headers using Armeria's HttpRequest. Both perform Header extraction, Token validation, and Returning unauthorized response if token mismatches. But the APIs are different.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for pointing this out. I recommend extracting the condition (auth == null || !auth.equals(token)) into a private method isValid(auth) or similar. After all, this is the main logic of the class and is currently duplicated.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! I’ve refactored the duplicated logic into a private isValid(auth) method as suggested.

return delegate.serve(ctx, req);
});
}
}


Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.dataprepper.plugins;

import io.grpc.ServerInterceptor;
import io.grpc.ServerCall;
import io.grpc.ServerCallHandler;
import io.grpc.Metadata;
import org.opensearch.dataprepper.armeria.authentication.GrpcAuthenticationProvider;
import org.opensearch.dataprepper.model.annotations.DataPrepperPlugin;


/**
* Plugin that allows unauthenticated gRPC access.
*/
@DataPrepperPlugin(
name = GrpcAuthenticationProvider.UNAUTHENTICATED_PLUGIN_NAME,
pluginType = GrpcAuthenticationProvider.class
)
public class UnauthenticatedCustomGrpcAuthenticationProvider implements GrpcAuthenticationProvider {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We already have support for this here:

https://github.com/opensearch-project/data-prepper/blob/main/data-prepper-plugins/armeria-common/src/main/java/org/opensearch/dataprepper/plugins/UnauthenticatedGrpcAuthenticationProvider.java

Is there anything different about this plugin from that one?

Alternatively, should this be moved into src/test similar to the other plugins?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You're right. The purpose of writing UnauthenticatedCustomGrpcAuthenticationProvider was to simulate the behavior of UnauthenticatedGrpcAuthenticationProvider, specifically for testing custom authentication scenarios.

It overlaps with the existing plugin and isn't needed in production code.

Moving this version and its tests into src/test.


@Override
public ServerInterceptor getAuthenticationInterceptor() {
return new ServerInterceptor() {
@Override
public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(
ServerCall<ReqT, RespT> call,
Metadata headers,
ServerCallHandler<ReqT, RespT> next) {
// No authentication is performed; allow the request to continue
return next.startCall(call, headers);
}
};
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
package org.opensearch.dataprepper;

import com.google.protobuf.Any;
import com.google.rpc.RetryInfo;
import com.linecorp.armeria.common.RequestContext;
import io.grpc.Metadata;
import io.grpc.Status;
import io.micrometer.core.instrument.Counter;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import org.opensearch.dataprepper.exceptions.BadRequestException;
import org.opensearch.dataprepper.exceptions.BufferWriteException;
import org.opensearch.dataprepper.exceptions.RequestCancelledException;
import org.opensearch.dataprepper.metrics.PluginMetrics;
import org.opensearch.dataprepper.model.buffer.SizeOverflowException;

import java.io.IOException;
import java.time.Duration;
import java.util.Optional;
import java.util.UUID;
import java.util.concurrent.TimeoutException;

import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.equalTo;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

@ExtendWith(MockitoExtension.class)
public class CustomAuthenticationExceptionHandlerTest {
@Mock
private PluginMetrics pluginMetrics;

@Mock
private Counter requestTimeoutsCounter;

@Mock
private Counter badRequestsCounter;

@Mock
private Counter requestsTooLargeCounter;

@Mock
private Counter internalServerErrorCounter;

@Mock
private RequestContext requestContext;

@Mock
private Metadata metadata;

private CustomAuthenticationExceptionHandler handler;

@BeforeEach
public void setUp() {
when(pluginMetrics.counter(CustomAuthenticationExceptionHandler.REQUEST_TIMEOUTS)).thenReturn(requestTimeoutsCounter);
when(pluginMetrics.counter(CustomAuthenticationExceptionHandler.BAD_REQUESTS)).thenReturn(badRequestsCounter);
when(pluginMetrics.counter(CustomAuthenticationExceptionHandler.REQUESTS_TOO_LARGE)).thenReturn(requestsTooLargeCounter);
when(pluginMetrics.counter(CustomAuthenticationExceptionHandler.INTERNAL_SERVER_ERROR)).thenReturn(internalServerErrorCounter);

handler = new CustomAuthenticationExceptionHandler(pluginMetrics, Duration.ofMillis(100), Duration.ofSeconds(2));
}

@Test
public void testBadRequestExceptionHandling() {
final String message = UUID.randomUUID().toString();
BadRequestException exception = new BadRequestException(message, new IOException());

com.google.rpc.Status status = handler.applyStatusProto(requestContext, exception, metadata);

assertThat(status.getCode(), equalTo(Status.Code.INVALID_ARGUMENT.value()));
assertThat(status.getMessage(), equalTo(message));
verify(badRequestsCounter).increment();
}

@Test
public void testTimeoutExceptionHandling() {
TimeoutException timeout = new TimeoutException();
BufferWriteException bufferWriteException = new BufferWriteException("timeout", timeout);

com.google.rpc.Status status = handler.applyStatusProto(requestContext, bufferWriteException, metadata);

assertThat(status.getCode(), equalTo(Status.Code.RESOURCE_EXHAUSTED.value()));
verify(requestTimeoutsCounter).increment();
Optional<Any> retryInfo = status.getDetailsList().stream().filter(d -> d.is(RetryInfo.class)).findFirst();
assertTrue(retryInfo.isPresent());
}

@Test
public void testSizeOverflowExceptionHandling() {
SizeOverflowException overflow = new SizeOverflowException("Overflow");
BufferWriteException bufferWriteException = new BufferWriteException("overflow", overflow);

com.google.rpc.Status status = handler.applyStatusProto(requestContext, bufferWriteException, metadata);

assertThat(status.getCode(), equalTo(Status.Code.RESOURCE_EXHAUSTED.value()));
verify(requestsTooLargeCounter).increment();
}

@Test
public void testCancelledRequestHandling() {
String message = UUID.randomUUID().toString();
RequestCancelledException exception = new RequestCancelledException(message);

com.google.rpc.Status status = handler.applyStatusProto(requestContext, exception, metadata);

assertThat(status.getCode(), equalTo(Status.Code.CANCELLED.value()));
assertThat(status.getMessage(), equalTo(message));
verify(requestTimeoutsCounter).increment();
}

@Test
public void testInternalExceptionHandling() {
String message = UUID.randomUUID().toString();
RuntimeException exception = new RuntimeException(message);

com.google.rpc.Status status = handler.applyStatusProto(requestContext, exception, metadata);

assertThat(status.getCode(), equalTo(Status.Code.INTERNAL.value()));
assertThat(status.getMessage(), equalTo(message));
verify(internalServerErrorCounter).increment();
}
}
Loading
Loading