Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import org.opensearch.migrations.bulkload.common.http.ConnectionContext;
import org.opensearch.migrations.bulkload.common.http.GzipPayloadRequestTransformer;
import org.opensearch.migrations.bulkload.common.http.HttpResponse;
import org.opensearch.migrations.bulkload.common.http.TlsCredentialsProvider;
import org.opensearch.migrations.bulkload.netty.ReadMeteringHandler;
import org.opensearch.migrations.bulkload.netty.WriteMeteringHandler;
import org.opensearch.migrations.bulkload.tracing.IRfsContexts;
Expand Down Expand Up @@ -75,22 +76,14 @@ public RestClient(ConnectionContext connectionContext, int maxConnections) {

protected RestClient(ConnectionContext connectionContext, HttpClient httpClient) {
this.connectionContext = connectionContext;
TlsCredentialsProvider tlsCredentialsProvider = connectionContext.getTlsCredentialsProvider();

SslProvider sslProvider;
if (connectionContext.isInsecure()) {
try {
SslContext sslContext = SslContextBuilder.forClient()
.trustManager(InsecureTrustManagerFactory.INSTANCE)
.build();
sslProvider = SslProvider.builder().sslContext(sslContext).handlerConfigurator(sslHandler -> {
SSLEngine engine = sslHandler.engine();
SSLParameters sslParameters = engine.getSSLParameters();
sslParameters.setEndpointIdentificationAlgorithm(null);
engine.setSSLParameters(sslParameters);
}).build();
} catch (SSLException e) {
throw new IllegalStateException("Unable to construct SslProvider", e);
}

if (tlsCredentialsProvider != null) {
sslProvider = getSslProvider(tlsCredentialsProvider);
} else if (connectionContext.isInsecure()) {
sslProvider = getInsecureSslProvider();
} else {
sslProvider = SslProvider.defaultClientProvider();
}
Expand Down Expand Up @@ -273,4 +266,55 @@ private static void addNewHandler(ChannelPipeline p, String name, ChannelHandler
};
};
}

private SslProvider getSslProvider(TlsCredentialsProvider tlsCredentialsProvider) {
try {
SslContextBuilder builder = SslContextBuilder.forClient();

if (tlsCredentialsProvider.hasCACredentials()) {
builder.trustManager(tlsCredentialsProvider.getCaCertInputStream());
}

if (tlsCredentialsProvider.hasClientCredentials()) {
builder.keyManager(
tlsCredentialsProvider.getClientCertInputStream(),
tlsCredentialsProvider.getClientCertKeyInputStream()
);
}

SslContext sslContext = builder.build();

return SslProvider.builder()
.sslContext(sslContext)
.handlerConfigurator(sslHandler -> {
SSLEngine engine = sslHandler.engine();
SSLParameters sslParameters = engine.getSSLParameters();
engine.setSSLParameters(sslParameters);
})
.build();

} catch (SSLException e) {
throw new IllegalStateException("Unable to construct custom SslProvider", e);
}
}

private SslProvider getInsecureSslProvider() {
try {
SslContext sslContext = SslContextBuilder.forClient()
.trustManager(InsecureTrustManagerFactory.INSTANCE)
.build();

return SslProvider.builder()
.sslContext(sslContext)
.handlerConfigurator(sslHandler -> {
SSLEngine engine = sslHandler.engine();
SSLParameters sslParameters = engine.getSSLParameters();
sslParameters.setEndpointIdentificationAlgorithm(null);
engine.setSSLParameters(sslParameters);
})
.build();
} catch (SSLException e) {
throw new IllegalStateException("Unable to construct SslProvider", e);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@

import java.net.URI;
import java.net.URISyntaxException;
import java.nio.file.Path;
import java.time.Clock;

import com.beust.jcommander.Parameter;
import com.beust.jcommander.ParametersDelegate;
import com.beust.jcommander.converters.PathConverter;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.ToString;
Expand All @@ -30,6 +32,8 @@ public enum Protocol {
private final boolean compressionSupported;
private final boolean awsSpecificAuthentication;

private TlsCredentialsProvider tlsCredentialsProvider;

private ConnectionContext(IParams params) {
assert params.getHost() != null : "host is null";

Expand Down Expand Up @@ -79,6 +83,23 @@ else if (sigv4Enabled) {
requestTransformer = new NoAuthTransformer();
}
compressionSupported = params.isCompressionEnabled();

validateClientCertPairPresence(params);

if (isTlsCredentialsEnabled(params)) {
tlsCredentialsProvider = new FileTlsCredentialsProvider(
params.getCaCert(),
params.getClientCert(),
params.getClientCertKey());
}
}

/**
* Sets the TLS credentials provider.
* NOTE: This method is only intended for testing purposes.
*/
public void setTlsCredentialsProvider(TlsCredentialsProvider tlsCredentialsProvider) {
this.tlsCredentialsProvider = tlsCredentialsProvider;
}

public interface IParams {
Expand All @@ -92,6 +113,12 @@ public interface IParams {

String getAwsServiceSigningName();

Path getCaCert();

Path getClientCert();

Path getClientCertKey();

boolean isCompressionEnabled();

boolean isInsecure();
Expand Down Expand Up @@ -121,6 +148,27 @@ public static class TargetArgs implements IParams {
required = false)
public String password = null;

@Parameter(
names = {"--target-cacert", "--targetCaCert" },
description = "Optional. The target CA certificate",
required = false,
converter = PathConverter.class)
public Path caCert = null;

@Parameter(
names = {"--target-client-cert", "--targetClientCert" },
description = "Optional. The target client TLS certificate",
required = false,
converter = PathConverter.class)
public Path clientCert = null;

@Parameter(
names = {"--target-client-cert-key", "--targetClientCertKey" },
description = "Optional. The target client TLS certificate key",
required = false,
converter = PathConverter.class)
public Path clientCertKey = null;

@Parameter(
names = {"--target-aws-region", "--targetAwsRegion" },
description = "Optional. The target aws region. Required only if sigv4 auth is used",
Expand Down Expand Up @@ -178,6 +226,27 @@ public static class SourceArgs implements IParams {
required = false)
public String password = null;

@Parameter(
names = {"--source-cacert", "--sourceCaCert" },
description = "Optional. The source CA certificate",
required = false,
converter = PathConverter.class)
public Path caCert = null;

@Parameter(
names = {"--source-client-cert", "--sourceClientCert" },
description = "Optional. The source client TLS certificate",
required = false,
converter = PathConverter.class)
public Path clientCert = null;

@Parameter(
names = {"--source-client-cert-key", "--sourceClientCertKey" },
description = "Optional. The source client TLS certificate key",
required = false,
converter = PathConverter.class)
public Path clientCertKey = null;

@Parameter(
names = {"--source-aws-region", "--sourceAwsRegion" },
description = "Optional. The source aws region, e.g. 'us-east-1'. Required if sigv4 auth is used",
Expand All @@ -203,4 +272,15 @@ public boolean isCompressionEnabled() {
return false;
}
}

private static void validateClientCertPairPresence(IParams params) {
if ((params.getClientCert() != null) ^ (params.getClientCertKey() != null)) {
throw new IllegalArgumentException(
"Both clientCert and clientCertKey must be provided together, or neither.");
}
}

private static boolean isTlsCredentialsEnabled(IParams params) {
return (params.getCaCert() != null) || (params.getClientCert() != null && params.getClientCertKey() != null);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
package org.opensearch.migrations.bulkload.common.http;

import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.InputStream;
import java.nio.file.Path;

public class FileTlsCredentialsProvider implements TlsCredentialsProvider {
private final Path caCert;
private final Path clientCert;
private final Path clientCertKey;

public FileTlsCredentialsProvider(Path caCert, Path clientCert, Path clientCertKey) {
this.caCert = caCert;
this.clientCert = clientCert;
this.clientCertKey = clientCertKey;
}

public InputStream getCaCertInputStream() {
return openStream(caCert);
}

public InputStream getClientCertInputStream() {
return openStream(clientCert);
}

public InputStream getClientCertKeyInputStream() {
return openStream(clientCertKey);
}

public boolean hasClientCredentials() {
return clientCert != null && clientCertKey != null;
}

public boolean hasCACredentials() {
return caCert != null;
}

private InputStream openStream(Path path) {
try {
return new FileInputStream(path.toFile());
} catch (FileNotFoundException e) {
throw new TlsCredentialLoadingException("Failed to load " + path, e);
}
}

public static class TlsCredentialLoadingException extends RuntimeException {
public TlsCredentialLoadingException(String message, Throwable cause) {
super(message, cause);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package org.opensearch.migrations.bulkload.common.http;

import java.io.InputStream;

public interface TlsCredentialsProvider {
InputStream getCaCertInputStream();
InputStream getClientCertInputStream();
InputStream getClientCertKeyInputStream();
boolean hasClientCredentials();
boolean hasCACredentials();
}
Loading
Loading