-
Notifications
You must be signed in to change notification settings - Fork 225
Add Custom Auth Provider with support for gRPC, plus tests and exception #5578
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: common-server-builder-and-auth-module
Are you sure you want to change the base?
Changes from 1 commit
e918e31
be19db6
741d81a
57232dd
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,98 @@ | ||
package org.opensearch.dataprepper; | ||
|
||
import com.google.protobuf.Any; | ||
import com.linecorp.armeria.common.RequestContext; | ||
import com.linecorp.armeria.common.annotation.Nullable; | ||
import com.linecorp.armeria.common.grpc.GoogleGrpcExceptionHandlerFunction; | ||
import com.linecorp.armeria.server.RequestTimeoutException; | ||
import io.grpc.Metadata; | ||
import io.grpc.Status; | ||
import io.grpc.StatusRuntimeException; | ||
import io.micrometer.core.instrument.Counter; | ||
|
||
import org.opensearch.dataprepper.exceptions.BadRequestException; | ||
import org.opensearch.dataprepper.exceptions.BufferWriteException; | ||
import org.opensearch.dataprepper.exceptions.RequestCancelledException; | ||
import org.opensearch.dataprepper.metrics.PluginMetrics; | ||
import org.opensearch.dataprepper.model.buffer.SizeOverflowException; | ||
import org.slf4j.Logger; | ||
import org.slf4j.LoggerFactory; | ||
|
||
import java.time.Duration; | ||
import java.util.concurrent.TimeoutException; | ||
|
||
public class CustomAuthenticationExceptionHandler implements GoogleGrpcExceptionHandlerFunction { | ||
private static final Logger LOG = LoggerFactory.getLogger(CustomAuthenticationExceptionHandler.class); | ||
private static final String TIMEOUT_MESSAGE = "Request timed out. Check buffer availability or processing delays."; | ||
|
||
public static final String REQUEST_TIMEOUTS = "customAuthRequestTimeouts"; | ||
public static final String BAD_REQUESTS = "customAuthBadRequests"; | ||
public static final String REQUESTS_TOO_LARGE = "customAuthRequestsTooLarge"; | ||
public static final String INTERNAL_SERVER_ERROR = "customAuthInternalServerError"; | ||
|
||
private final Counter requestTimeoutsCounter; | ||
private final Counter badRequestsCounter; | ||
private final Counter requestsTooLargeCounter; | ||
private final Counter internalServerErrorCounter; | ||
private final GrpcRetryInfoCalculator retryInfoCalculator; | ||
|
||
public CustomAuthenticationExceptionHandler(final PluginMetrics pluginMetrics, | ||
final Duration retryInfoMinDelay, | ||
final Duration retryInfoMaxDelay) { | ||
this.requestTimeoutsCounter = pluginMetrics.counter(REQUEST_TIMEOUTS); | ||
this.badRequestsCounter = pluginMetrics.counter(BAD_REQUESTS); | ||
this.requestsTooLargeCounter = pluginMetrics.counter(REQUESTS_TOO_LARGE); | ||
this.internalServerErrorCounter = pluginMetrics.counter(INTERNAL_SERVER_ERROR); | ||
this.retryInfoCalculator = new GrpcRetryInfoCalculator(retryInfoMinDelay, retryInfoMaxDelay); | ||
} | ||
|
||
@Override | ||
public com.google.rpc.@Nullable Status applyStatusProto(RequestContext ctx, Throwable throwable, Metadata metadata) { | ||
final Throwable actualCause = (throwable instanceof BufferWriteException) | ||
? throwable.getCause() : throwable; | ||
return handleException(actualCause); | ||
} | ||
|
||
private com.google.rpc.Status handleException(Throwable e) { | ||
final String msg = e.getMessage(); | ||
if (e instanceof RequestTimeoutException || e instanceof TimeoutException) { | ||
requestTimeoutsCounter.increment(); | ||
return buildStatus(e, Status.Code.RESOURCE_EXHAUSTED); | ||
} else if (e instanceof SizeOverflowException) { | ||
requestsTooLargeCounter.increment(); | ||
return buildStatus(e, Status.Code.RESOURCE_EXHAUSTED); | ||
} else if (e instanceof BadRequestException) { | ||
badRequestsCounter.increment(); | ||
return buildStatus(e, Status.Code.INVALID_ARGUMENT); | ||
} else if ((e instanceof StatusRuntimeException) && | ||
(msg.contains("Invalid protobuf byte sequence") || msg.contains("Can't decode compressed frame"))) { | ||
badRequestsCounter.increment(); | ||
return buildStatus(e, Status.Code.INVALID_ARGUMENT); | ||
} else if (e instanceof RequestCancelledException) { | ||
requestTimeoutsCounter.increment(); | ||
return buildStatus(e, Status.Code.CANCELLED); | ||
} | ||
|
||
internalServerErrorCounter.increment(); | ||
LOG.error("CustomAuth gRPC handler caught unexpected exception", e); | ||
return buildStatus(e, Status.Code.INTERNAL); | ||
} | ||
|
||
private com.google.rpc.Status buildStatus(Throwable e, Status.Code code) { | ||
com.google.rpc.Status.Builder builder = com.google.rpc.Status.newBuilder() | ||
.setCode(code.value()); | ||
|
||
if (e instanceof RequestTimeoutException) { | ||
builder.setMessage(TIMEOUT_MESSAGE); | ||
} else { | ||
builder.setMessage(e.getMessage() != null ? e.getMessage() : code.name()); | ||
} | ||
|
||
if (code == Status.Code.RESOURCE_EXHAUSTED) { | ||
builder.addDetails(Any.pack(retryInfoCalculator.createRetryInfo())); | ||
} | ||
|
||
return builder.build(); | ||
} | ||
} | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.dataprepper.armeria.authentication; | ||
|
||
import com.fasterxml.jackson.annotation.JsonCreator; | ||
import com.fasterxml.jackson.annotation.JsonProperty; | ||
|
||
public class CustomAuthenticationConfig { | ||
private final String customToken; | ||
|
||
@JsonCreator | ||
public CustomAuthenticationConfig( | ||
@JsonProperty("custom_token") String customToken) { | ||
this.customToken = customToken; | ||
} | ||
|
||
public String customToken() { | ||
return customToken; | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.dataprepper.armeria.authentication; | ||
|
||
import com.linecorp.armeria.server.HttpService; | ||
import io.grpc.ServerInterceptor; | ||
|
||
import java.util.Optional; | ||
import java.util.function.Function; | ||
|
||
public interface CustomAuthenticationProvider { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I believe this is intended to be a test class, right? Move this into the Do the same with the other classes you created in There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Quick question: implementing CustomGrpcAuthenticationProvider is to allow Data Prepper to support token-based custom authentication in production, as a simulation of SigV4. If we more classes into src/test. can we still able to load it during integration tests or real pipelines? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The difference between The Data Prepper plugin framework just needs the plugin to be in the Java classpath and in the specified Java package. Moving these to I'd say for now we just support running these in integration tests. |
||
|
||
String UNAUTHENTICATED_PLUGIN_NAME = "unauthenticated"; | ||
|
||
|
||
ServerInterceptor getAuthenticationInterceptor(); | ||
|
||
default Optional<Function<? super HttpService, ? extends HttpService>> getHttpAuthenticationService() { | ||
return Optional.empty(); | ||
} | ||
} | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.dataprepper.plugins; | ||
|
||
import com.linecorp.armeria.common.HttpResponse; | ||
import com.linecorp.armeria.common.HttpStatus; | ||
import com.linecorp.armeria.common.MediaType; | ||
import com.linecorp.armeria.server.HttpService; | ||
import io.grpc.Metadata; | ||
import io.grpc.ServerCall; | ||
import io.grpc.ServerCallHandler; | ||
import io.grpc.ServerInterceptor; | ||
import io.grpc.Status; | ||
import org.opensearch.dataprepper.armeria.authentication.CustomAuthenticationConfig; | ||
import org.opensearch.dataprepper.armeria.authentication.GrpcAuthenticationProvider; | ||
import org.opensearch.dataprepper.model.annotations.DataPrepperPlugin; | ||
import org.opensearch.dataprepper.model.annotations.DataPrepperPluginConstructor; | ||
|
||
import java.util.Optional; | ||
import java.util.function.Function; | ||
|
||
@DataPrepperPlugin( | ||
name = "custom_auth", | ||
pluginType = GrpcAuthenticationProvider.class, | ||
pluginConfigurationType = CustomAuthenticationConfig.class | ||
) | ||
public class CustomGrpcAuthenticationProvider implements GrpcAuthenticationProvider { | ||
private final String token; | ||
private static final String AUTH_HEADER = "authentication"; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you make the header name configurable? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Made the header name configurable via pipeline configuration |
||
|
||
|
||
@DataPrepperPluginConstructor | ||
public CustomGrpcAuthenticationProvider(final CustomAuthenticationConfig config) { | ||
this.token = config.customToken(); | ||
} | ||
|
||
@Override | ||
public ServerInterceptor getAuthenticationInterceptor() { | ||
return new ServerInterceptor() { | ||
@Override | ||
public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall( | ||
ServerCall<ReqT, RespT> call, | ||
Metadata headers, | ||
ServerCallHandler<ReqT, RespT> next) { | ||
|
||
String auth = headers.get(Metadata.Key.of("authentication", Metadata.ASCII_STRING_MARSHALLER)); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Shouldn't this at least use the constant There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You are right. I have made the change. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not in this PR, yet. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Made the change |
||
|
||
if (auth == null || !auth.equals(token)) { | ||
call.close(Status.UNAUTHENTICATED.withDescription("Invalid token"), new Metadata()); | ||
return new ServerCall.Listener<>() {}; | ||
} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What if I have a different kind of token, a JWT for example? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks for the suggestion! For now, this implementation is meant for simple token validation. JWT support can be considered in future iterations. |
||
|
||
return next.startCall(call, headers); | ||
} | ||
}; | ||
} | ||
|
||
@Override | ||
public Optional<Function<? super HttpService, ? extends HttpService>> getHttpAuthenticationService() { | ||
return Optional.of(delegate -> (ctx, req) -> { | ||
final String auth = req.headers().get(AUTH_HEADER); | ||
if (auth == null || !auth.equals(token)) { | ||
return HttpResponse.of( | ||
HttpStatus.UNAUTHORIZED, | ||
MediaType.PLAIN_TEXT_UTF_8, | ||
"Unauthorized: Invalid or missing token" | ||
); | ||
} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Isn't this essentially the same code as ll. 49-54. Why are there differences? Can't this be a single implementation? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thank you for pointing out. These two similar logic: one for gRPC headers using Armeria -> Metadata; another one for HTTP headers using Armeria's HttpRequest. Both perform Header extraction, Token validation, and Returning unauthorized response if token mismatches. But the APIs are different. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks for pointing this out. I recommend extracting the condition There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks! I’ve refactored the duplicated logic into a private isValid(auth) method as suggested. |
||
return delegate.serve(ctx, req); | ||
}); | ||
} | ||
} | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.dataprepper.plugins; | ||
|
||
import io.grpc.ServerInterceptor; | ||
import io.grpc.ServerCall; | ||
import io.grpc.ServerCallHandler; | ||
import io.grpc.Metadata; | ||
import org.opensearch.dataprepper.armeria.authentication.GrpcAuthenticationProvider; | ||
import org.opensearch.dataprepper.model.annotations.DataPrepperPlugin; | ||
|
||
|
||
/** | ||
* Plugin that allows unauthenticated gRPC access. | ||
*/ | ||
@DataPrepperPlugin( | ||
name = GrpcAuthenticationProvider.UNAUTHENTICATED_PLUGIN_NAME, | ||
pluginType = GrpcAuthenticationProvider.class | ||
) | ||
public class UnauthenticatedCustomGrpcAuthenticationProvider implements GrpcAuthenticationProvider { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We already have support for this here: Is there anything different about this plugin from that one? Alternatively, should this be moved into There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You're right. The purpose of writing It overlaps with the existing plugin and isn't needed in production code. Moving this version and its tests into |
||
|
||
@Override | ||
public ServerInterceptor getAuthenticationInterceptor() { | ||
return new ServerInterceptor() { | ||
@Override | ||
public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall( | ||
ServerCall<ReqT, RespT> call, | ||
Metadata headers, | ||
ServerCallHandler<ReqT, RespT> next) { | ||
// No authentication is performed; allow the request to continue | ||
return next.startCall(call, headers); | ||
} | ||
}; | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,126 @@ | ||
package org.opensearch.dataprepper; | ||
|
||
import com.google.protobuf.Any; | ||
import com.google.rpc.RetryInfo; | ||
import com.linecorp.armeria.common.RequestContext; | ||
import io.grpc.Metadata; | ||
import io.grpc.Status; | ||
import io.micrometer.core.instrument.Counter; | ||
import org.junit.jupiter.api.BeforeEach; | ||
import org.junit.jupiter.api.Test; | ||
import org.junit.jupiter.api.extension.ExtendWith; | ||
import org.mockito.Mock; | ||
import org.mockito.junit.jupiter.MockitoExtension; | ||
import org.opensearch.dataprepper.exceptions.BadRequestException; | ||
import org.opensearch.dataprepper.exceptions.BufferWriteException; | ||
import org.opensearch.dataprepper.exceptions.RequestCancelledException; | ||
import org.opensearch.dataprepper.metrics.PluginMetrics; | ||
import org.opensearch.dataprepper.model.buffer.SizeOverflowException; | ||
|
||
import java.io.IOException; | ||
import java.time.Duration; | ||
import java.util.Optional; | ||
import java.util.UUID; | ||
import java.util.concurrent.TimeoutException; | ||
|
||
import static org.hamcrest.MatcherAssert.assertThat; | ||
import static org.hamcrest.Matchers.equalTo; | ||
import static org.junit.jupiter.api.Assertions.assertTrue; | ||
import static org.mockito.Mockito.verify; | ||
import static org.mockito.Mockito.when; | ||
|
||
@ExtendWith(MockitoExtension.class) | ||
public class CustomAuthenticationExceptionHandlerTest { | ||
@Mock | ||
private PluginMetrics pluginMetrics; | ||
|
||
@Mock | ||
private Counter requestTimeoutsCounter; | ||
|
||
@Mock | ||
private Counter badRequestsCounter; | ||
|
||
@Mock | ||
private Counter requestsTooLargeCounter; | ||
|
||
@Mock | ||
private Counter internalServerErrorCounter; | ||
|
||
@Mock | ||
private RequestContext requestContext; | ||
|
||
@Mock | ||
private Metadata metadata; | ||
|
||
private CustomAuthenticationExceptionHandler handler; | ||
|
||
@BeforeEach | ||
public void setUp() { | ||
when(pluginMetrics.counter(CustomAuthenticationExceptionHandler.REQUEST_TIMEOUTS)).thenReturn(requestTimeoutsCounter); | ||
when(pluginMetrics.counter(CustomAuthenticationExceptionHandler.BAD_REQUESTS)).thenReturn(badRequestsCounter); | ||
when(pluginMetrics.counter(CustomAuthenticationExceptionHandler.REQUESTS_TOO_LARGE)).thenReturn(requestsTooLargeCounter); | ||
when(pluginMetrics.counter(CustomAuthenticationExceptionHandler.INTERNAL_SERVER_ERROR)).thenReturn(internalServerErrorCounter); | ||
|
||
handler = new CustomAuthenticationExceptionHandler(pluginMetrics, Duration.ofMillis(100), Duration.ofSeconds(2)); | ||
} | ||
|
||
@Test | ||
public void testBadRequestExceptionHandling() { | ||
final String message = UUID.randomUUID().toString(); | ||
BadRequestException exception = new BadRequestException(message, new IOException()); | ||
|
||
com.google.rpc.Status status = handler.applyStatusProto(requestContext, exception, metadata); | ||
|
||
assertThat(status.getCode(), equalTo(Status.Code.INVALID_ARGUMENT.value())); | ||
assertThat(status.getMessage(), equalTo(message)); | ||
verify(badRequestsCounter).increment(); | ||
} | ||
|
||
@Test | ||
public void testTimeoutExceptionHandling() { | ||
TimeoutException timeout = new TimeoutException(); | ||
BufferWriteException bufferWriteException = new BufferWriteException("timeout", timeout); | ||
|
||
com.google.rpc.Status status = handler.applyStatusProto(requestContext, bufferWriteException, metadata); | ||
|
||
assertThat(status.getCode(), equalTo(Status.Code.RESOURCE_EXHAUSTED.value())); | ||
verify(requestTimeoutsCounter).increment(); | ||
Optional<Any> retryInfo = status.getDetailsList().stream().filter(d -> d.is(RetryInfo.class)).findFirst(); | ||
assertTrue(retryInfo.isPresent()); | ||
} | ||
|
||
@Test | ||
public void testSizeOverflowExceptionHandling() { | ||
SizeOverflowException overflow = new SizeOverflowException("Overflow"); | ||
BufferWriteException bufferWriteException = new BufferWriteException("overflow", overflow); | ||
|
||
com.google.rpc.Status status = handler.applyStatusProto(requestContext, bufferWriteException, metadata); | ||
|
||
assertThat(status.getCode(), equalTo(Status.Code.RESOURCE_EXHAUSTED.value())); | ||
verify(requestsTooLargeCounter).increment(); | ||
} | ||
|
||
@Test | ||
public void testCancelledRequestHandling() { | ||
String message = UUID.randomUUID().toString(); | ||
RequestCancelledException exception = new RequestCancelledException(message); | ||
|
||
com.google.rpc.Status status = handler.applyStatusProto(requestContext, exception, metadata); | ||
|
||
assertThat(status.getCode(), equalTo(Status.Code.CANCELLED.value())); | ||
assertThat(status.getMessage(), equalTo(message)); | ||
verify(requestTimeoutsCounter).increment(); | ||
} | ||
|
||
@Test | ||
public void testInternalExceptionHandling() { | ||
String message = UUID.randomUUID().toString(); | ||
RuntimeException exception = new RuntimeException(message); | ||
|
||
com.google.rpc.Status status = handler.applyStatusProto(requestContext, exception, metadata); | ||
|
||
assertThat(status.getCode(), equalTo(Status.Code.INTERNAL.value())); | ||
assertThat(status.getMessage(), equalTo(message)); | ||
verify(internalServerErrorCounter).increment(); | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why do we need an additional exception handler and where is this class actually used? It contains a lot of duplicated code, too.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the feedback! That exception handler was only used in one test. Since it's no longer needed, I’ve safely removed it to avoid duplication and keep the codebase clean.