Skip to content

Add Custom Auth Provider with support for gRPC, plus tests and exception #5578

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: common-server-builder-and-auth-module
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.dataprepper.plugins.testcustomauth;

import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty;

public class TestCustomAuthenticationConfig {
private final String customToken;
private final String header;

@JsonCreator
public TestCustomAuthenticationConfig(
@JsonProperty("custom_token") String customToken,
@JsonProperty("header") String header) {
this.customToken = customToken;
this.header = header != null ? header : "authentication";
}

public String customToken() {
return customToken;
}

public String header() {
return header;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.dataprepper.plugins.testcustomauth;

import com.linecorp.armeria.server.HttpService;
import io.grpc.ServerInterceptor;

import java.util.Optional;
import java.util.function.Function;

public interface TestCustomAuthenticationProvider {

String UNAUTHENTICATED_PLUGIN_NAME = "unauthenticated";


ServerInterceptor getAuthenticationInterceptor();

default Optional<Function<? super HttpService, ? extends HttpService>> getHttpAuthenticationService() {
return Optional.empty();
}
}

Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
package org.opensearch.dataprepper.plugins.testcustomauth;

import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;

import static org.mockito.Mockito.when;

@ExtendWith(MockitoExtension.class)
public class TestCustomAuthenticationProviderTest {

private static final String TOKEN = "test-token";
private static final String HEADER = "authentication";

@Mock
private TestCustomAuthenticationConfig config;

private TestCustomGrpcAuthenticationProvider provider;

@BeforeEach
void setUp() {
when(config.customToken()).thenReturn(TOKEN);
when(config.header()).thenReturn(HEADER);

provider = new TestCustomGrpcAuthenticationProvider(config);
}

@Test
void testGetHttpAuthenticationService_shouldReturnValidOptional() {
var optionalService = provider.getHttpAuthenticationService();
Assertions.assertTrue(optionalService.isPresent());
}

@Test
void testGetAuthenticationInterceptor_shouldReturnNonNull() {
var interceptor = provider.getAuthenticationInterceptor();
Assertions.assertNotNull(interceptor);
}
}

Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
package org.opensearch.dataprepper.plugins.testcustomauth;

import com.linecorp.armeria.client.WebClient;
import com.linecorp.armeria.common.AggregatedHttpResponse;
import com.linecorp.armeria.common.HttpData;
import com.linecorp.armeria.common.HttpMethod;
import com.linecorp.armeria.common.HttpRequest;
import com.linecorp.armeria.common.HttpStatus;
import com.linecorp.armeria.common.MediaType;
import com.linecorp.armeria.common.RequestHeaders;
import com.linecorp.armeria.server.ServerBuilder;
import com.linecorp.armeria.server.grpc.GrpcService;
import com.linecorp.armeria.server.grpc.GrpcServiceBuilder;
import com.linecorp.armeria.testing.junit5.server.ServerExtension;
import io.grpc.ServerInterceptors;
import io.grpc.health.v1.HealthCheckRequest;
import io.grpc.health.v1.HealthCheckResponse;
import io.grpc.health.v1.HealthGrpc;
import io.grpc.stub.StreamObserver;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Nested;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;
import org.opensearch.dataprepper.armeria.authentication.GrpcAuthenticationProvider;

import java.nio.charset.Charset;
import java.util.Collections;
import java.util.UUID;

import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.equalTo;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

public class TestCustomBasicAuthenticationProviderTest {
private static final String TOKEN = UUID.randomUUID().toString();
private static final String HEADER_NAME = "x-" + UUID.randomUUID();
private static GrpcAuthenticationProvider grpcAuthenticationProvider;

@RegisterExtension
static ServerExtension server = new ServerExtension() {
@Override
protected void configure(ServerBuilder sb) {
TestCustomAuthenticationConfig config = mock(TestCustomAuthenticationConfig.class);
when(config.customToken()).thenReturn(TOKEN);
when(config.header()).thenReturn(HEADER_NAME);

grpcAuthenticationProvider = new TestCustomGrpcAuthenticationProvider(config);

GrpcServiceBuilder grpcServiceBuilder = GrpcService.builder()
.enableUnframedRequests(true)
.addService(ServerInterceptors.intercept(
new SampleHealthGrpcService(),
Collections.singletonList(grpcAuthenticationProvider.getAuthenticationInterceptor())));

sb.service(grpcServiceBuilder.build());
}
};

private static class SampleHealthGrpcService extends HealthGrpc.HealthImplBase {
@Override
public void check(HealthCheckRequest request, StreamObserver<HealthCheckResponse> responseObserver) {
responseObserver.onNext(
HealthCheckResponse.newBuilder().setStatus(HealthCheckResponse.ServingStatus.SERVING).build());
responseObserver.onCompleted();
}
}

@Nested
class ConstructorTests {
TestCustomAuthenticationConfig config;

@BeforeEach
void setUp() {
config = mock(TestCustomAuthenticationConfig.class);
}

@Test
void constructor_with_null_config_throws() {
assertThrows(NullPointerException.class, () -> new TestCustomGrpcAuthenticationProvider(null));
}
}

@Nested
class WithServer {
@Test
void request_without_token_responds_Unauthorized() {
WebClient client = WebClient.of(server.httpUri());
HttpRequest request = HttpRequest.of(RequestHeaders.builder()
.method(HttpMethod.POST)
.path("/grpc.health.v1.Health/Check")
.contentType(MediaType.JSON_UTF_8)
.build());

final AggregatedHttpResponse httpResponse = client.execute(request).aggregate().join();

assertThat(httpResponse.status(), equalTo(HttpStatus.UNAUTHORIZED));
}

@Test
void request_with_invalid_token_responds_Unauthorized() {
WebClient client = WebClient.builder(server.httpUri())
.addHeader(HEADER_NAME, "invalid-token")
.build();

HttpRequest request = HttpRequest.of(RequestHeaders.builder()
.method(HttpMethod.POST)
.path("/grpc.health.v1.Health/Check")
.contentType(MediaType.JSON_UTF_8)
.build());

final AggregatedHttpResponse httpResponse = client.execute(request).aggregate().join();

assertThat(httpResponse.status(), equalTo(HttpStatus.UNAUTHORIZED));
}

@Test
void request_with_valid_token_responds_OK() {
WebClient client = WebClient.builder(server.httpUri())
.addHeader(HEADER_NAME, TOKEN)
.build();

HttpRequest request = HttpRequest.of(RequestHeaders.builder()
.method(HttpMethod.POST)
.path("/grpc.health.v1.Health/Check")
.contentType(MediaType.JSON_UTF_8)
.build(),
HttpData.of(Charset.defaultCharset(), "{\"healthCheckConfig\":{\"serviceName\": \"test\"} }"));


final AggregatedHttpResponse httpResponse = client.execute(request).aggregate().join();

assertThat(httpResponse.status(), equalTo(HttpStatus.OK));
}
}
}

Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.dataprepper.plugins.testcustomauth;

import com.linecorp.armeria.common.HttpResponse;
import com.linecorp.armeria.common.HttpStatus;
import com.linecorp.armeria.common.MediaType;
import com.linecorp.armeria.server.HttpService;
import io.grpc.Metadata;
import io.grpc.ServerCall;
import io.grpc.ServerCallHandler;
import io.grpc.ServerInterceptor;
import io.grpc.Status;
import org.opensearch.dataprepper.armeria.authentication.GrpcAuthenticationProvider;
import org.opensearch.dataprepper.model.annotations.DataPrepperPlugin;
import org.opensearch.dataprepper.model.annotations.DataPrepperPluginConstructor;

import java.util.Optional;
import java.util.function.Function;

@DataPrepperPlugin(
name = "test_custom_auth",
pluginType = GrpcAuthenticationProvider.class,
pluginConfigurationType = TestCustomAuthenticationConfig.class
)
public class TestCustomGrpcAuthenticationProvider implements GrpcAuthenticationProvider {
private final String token;
private final String header;

@DataPrepperPluginConstructor
public TestCustomGrpcAuthenticationProvider(final TestCustomAuthenticationConfig config) {
this.token = config.customToken();
this.header = config.header();
}

@Override
public ServerInterceptor getAuthenticationInterceptor() {
return new ServerInterceptor() {
@Override
public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(
ServerCall<ReqT, RespT> call,
Metadata headers,
ServerCallHandler<ReqT, RespT> next) {

String auth = headers.get(Metadata.Key.of(header, Metadata.ASCII_STRING_MARSHALLER));

if (!isValid(auth)) {
call.close(Status.UNAUTHENTICATED.withDescription("Invalid token"), new Metadata());
return new ServerCall.Listener<>() {};
}

return next.startCall(call, headers);
}
};
}

@Override
public Optional<Function<? super HttpService, ? extends HttpService>> getHttpAuthenticationService() {
return Optional.of(delegate -> (ctx, req) -> {
final String auth = req.headers().get(header);
if (!isValid(auth)) {
return HttpResponse.of(
HttpStatus.UNAUTHORIZED,
MediaType.PLAIN_TEXT_UTF_8,
"Unauthorized: Invalid or missing token"
);
}
return delegate.serve(ctx, req);
});
}

/**
* Checks if the provided authentication token is valid.
*
* @param authHeader the value of the authentication header
* @return true if valid, false otherwise
*/
private boolean isValid(final String authHeader) {
return authHeader != null && authHeader.equals(token);
}
}


Loading
Loading