Skip to content

Add support for mTLS authentication in Arrow Flight client #25179

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,9 @@ public enum ArrowErrorCode
ARROW_INTERNAL_ERROR(1, INTERNAL_ERROR),
ARROW_FLIGHT_CLIENT_ERROR(2, EXTERNAL),
ARROW_FLIGHT_METADATA_ERROR(3, EXTERNAL),
ARROW_FLIGHT_TYPE_ERROR(4, EXTERNAL);
ARROW_FLIGHT_TYPE_ERROR(4, EXTERNAL),
ARROW_FLIGHT_INVALID_KEY_ERROR(5, INTERNAL_ERROR),
ARROW_FLIGHT_INVALID_CERT_ERROR(6, INTERNAL_ERROR);

private final ErrorCode errorCode;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ public class ArrowFlightConfig
private String server;
private boolean verifyServer = true;
private String flightServerSSLCertificate;
private String flightClientSSLCertificate;
private String flightClientSSLKey;
private boolean arrowFlightServerSslEnabled;
private Integer arrowFlightPort;

Expand Down Expand Up @@ -82,4 +84,38 @@ public ArrowFlightConfig setArrowFlightServerSslEnabled(boolean arrowFlightServe
this.arrowFlightServerSslEnabled = arrowFlightServerSslEnabled;
return this;
}

public String getFlightClientSSLCertificate()
{
return flightClientSSLCertificate;
}

/***
* Set the client SSL certificate used for mTLS authentication with Flight server.
* @param flightClientSSLCertificate path to the certificate file
* @return Returns this config instance
*/
@Config("arrow-flight.client-ssl-certificate")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it will be good to add a comment here saying this is needed for mTLS auth.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added comments.

public ArrowFlightConfig setFlightClientSSLCertificate(String flightClientSSLCertificate)
{
this.flightClientSSLCertificate = flightClientSSLCertificate;
return this;
}

public String getFlightClientSSLKey()
{
return flightClientSSLKey;
}

/***
* Set the client SSL key used for mTLS authentication with Flight server
* @param flightClientSSLKey path to the key file
* @return Returns this config instance
*/
@Config("arrow-flight.client-ssl-key")
public ArrowFlightConfig setFlightClientSSLKey(String flightClientSSLKey)
{
this.flightClientSSLKey = flightClientSSLKey;
return this;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
*/
package com.facebook.plugin.arrow;

import com.facebook.airlift.log.Logger;
import com.facebook.presto.spi.ConnectorSession;
import com.facebook.presto.spi.SchemaTableName;
import org.apache.arrow.flight.CallOption;
Expand All @@ -30,11 +31,15 @@
import java.net.URISyntaxException;
import java.nio.ByteBuffer;
import java.nio.file.Paths;
import java.security.InvalidKeyException;
import java.security.cert.CertificateException;
import java.util.List;
import java.util.Optional;

import static com.facebook.plugin.arrow.ArrowErrorCode.ARROW_FLIGHT_CLIENT_ERROR;
import static com.facebook.plugin.arrow.ArrowErrorCode.ARROW_FLIGHT_INFO_ERROR;
import static com.facebook.plugin.arrow.ArrowErrorCode.ARROW_FLIGHT_INVALID_CERT_ERROR;
import static com.facebook.plugin.arrow.ArrowErrorCode.ARROW_FLIGHT_INVALID_KEY_ERROR;
import static com.facebook.plugin.arrow.ArrowErrorCode.ARROW_FLIGHT_METADATA_ERROR;
import static java.nio.file.Files.newInputStream;
import static java.util.Objects.requireNonNull;
Expand All @@ -43,6 +48,7 @@ public abstract class BaseArrowFlightClientHandler
{
private final ArrowFlightConfig config;
private final BufferAllocator allocator;
private static final Logger logger = Logger.get(BaseArrowFlightClientHandler.class);

public BaseArrowFlightClientHandler(BufferAllocator allocator, ArrowFlightConfig config)
{
Expand All @@ -64,24 +70,61 @@ protected FlightClient createFlightClient()

protected FlightClient createFlightClient(Location location)
{
Optional<InputStream> trustedCertificate = Optional.empty();
Optional<InputStream> clientCertificate = Optional.empty();
Optional<InputStream> clientKey = Optional.empty();
try {
Optional<InputStream> trustedCertificate = Optional.empty();
FlightClient.Builder flightClientBuilder = FlightClient.builder(allocator, location);
flightClientBuilder.verifyServer(config.getVerifyServer());
if (config.getFlightServerSSLCertificate() != null) {
trustedCertificate = Optional.of(newInputStream(Paths.get(config.getFlightServerSSLCertificate())));
flightClientBuilder.trustedCertificates(trustedCertificate.get()).useTls();
}

FlightClient flightClient = flightClientBuilder.build();
if (trustedCertificate.isPresent()) {
trustedCertificate.get().close();
if (config.getFlightClientSSLCertificate() != null && config.getFlightClientSSLKey() != null) {
clientCertificate = Optional.of(newInputStream(Paths.get(config.getFlightClientSSLCertificate())));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

trying to understand it a bit better, what happens if the certificate and key are invalid? Lets use a try-catch block here? And maybe add a test case covering this scenario?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If the cert is invalid, executing a query will give the user an error that the cert is invalid. Added a test case that covers this scenario.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Lets add a try-catch here as well and modify the error message in the test case accordingly. Thank you for adding test case for this though.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The invalid cert exception is thrown only at line 84 FlightClient flightClient = flightClientBuilder.build(); and we might get exception due to other reasons as well from the build method. So adding a try...catch here will not help in modifying the error message.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, I actually foresee improperly configured client cert/key as a very probable source of error, and hence wanted to cover the scenario with a proper user facing message. Anyways I leave the final decision to you.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Refactored the code to catch errors due to invalid cert or key file. Rethrowing a Presto Exception with a custom message for those scenarios.

clientKey = Optional.of(newInputStream(Paths.get(config.getFlightClientSSLKey())));
flightClientBuilder.clientCertificate(clientCertificate.get(), clientKey.get()).useTls();
}

return flightClient;
return flightClientBuilder.build();
}
catch (Exception e) {
throw new ArrowException(ARROW_FLIGHT_CLIENT_ERROR, "Error creating flight client: " + e.getMessage(), e);
Optional<Throwable> cause = Optional.ofNullable(e.getCause());
if (cause.filter(c -> c instanceof InvalidKeyException).isPresent()) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if (cause.filter(c -> c instanceof InvalidKeyException).isPresent()) {
if (e instanceOf InvalidKeyException) {

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can be simplified like this.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

e will be instance of IllegalArgumentException. Inner exception e.getCause if not null, will be an instance of InvalidKeyException

throw new ArrowException(ARROW_FLIGHT_INVALID_KEY_ERROR, "Error creating flight client, invalid key file: " + e.getMessage(), e);
}
else if (cause.filter(c -> c instanceof CertificateException).isPresent()) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto

throw new ArrowException(ARROW_FLIGHT_INVALID_CERT_ERROR, "Error creating flight client, invalid certificate file: " + e.getMessage(), e);
}
else {
throw new ArrowException(ARROW_FLIGHT_CLIENT_ERROR, "Error creating flight client: " + e.getMessage(), e);
}
}
finally {
if (trustedCertificate.isPresent()) {
try {
trustedCertificate.get().close();
}
catch (IOException e) {
logger.error("Error closing input stream for server certificate", e);
}
}
if (clientCertificate.isPresent()) {
try {
clientCertificate.get().close();
}
catch (IOException e) {
logger.error("Error closing input stream for client certificate", e);
}
}
if (clientKey.isPresent()) {
try {
clientKey.get().close();
}
catch (IOException e) {
logger.error("Error closing input stream for client key", e);
}
}
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.plugin.arrow;

import com.facebook.airlift.log.Logger;
import com.facebook.plugin.arrow.testingServer.TestingArrowProducer;
import com.facebook.presto.testing.QueryRunner;
import com.facebook.presto.tests.AbstractTestQueryFramework;
import com.facebook.presto.tests.DistributedQueryRunner;
import com.google.common.collect.ImmutableMap;
import org.apache.arrow.flight.FlightServer;
import org.apache.arrow.flight.Location;
import org.apache.arrow.memory.RootAllocator;
import org.testng.annotations.AfterClass;
import org.testng.annotations.BeforeClass;

import java.io.File;
import java.io.IOException;
import java.util.Map;
import java.util.Optional;

public abstract class AbstractArrowFlightMTLSTestFramework
extends AbstractTestQueryFramework
{
private static final Logger logger = Logger.get(AbstractArrowFlightMTLSTestFramework.class);
private final int serverPort;
private RootAllocator allocator;
private FlightServer server;
private DistributedQueryRunner arrowFlightQueryRunner;

public AbstractArrowFlightMTLSTestFramework()
throws IOException
{
this.serverPort = ArrowFlightQueryRunner.findUnusedPort();
}

@BeforeClass
void setup()
throws Exception
{
arrowFlightQueryRunner = getDistributedQueryRunner();
File certChainFile = new File("src/test/resources/mtls/server.crt");
File privateKeyFile = new File("src/test/resources/mtls/server.key");
File caCertFile = new File("src/test/resources/mtls/ca.crt");

allocator = new RootAllocator(Long.MAX_VALUE);

Location location = Location.forGrpcTls("localhost", serverPort);
server = FlightServer.builder(allocator, location, new TestingArrowProducer(allocator))
.useTls(certChainFile, privateKeyFile)
.useMTlsClientVerification(caCertFile)
.build();

server.start();
logger.info("Server listening on port %s", server.getPort());
}

@AfterClass(alwaysRun = true)
void tearDown()
throws InterruptedException
{
arrowFlightQueryRunner.close();
server.close();
allocator.close();
}

@Override
protected QueryRunner createQueryRunner()
throws Exception
{
return ArrowFlightQueryRunner.createQueryRunner(ImmutableMap.of(), getCatalogProperties(), ImmutableMap.of(), Optional.empty());
}

abstract Map<String, String> getCatalogProperties();

int getServerPort()
{
return serverPort;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -64,10 +64,15 @@ public static DistributedQueryRunner createQueryRunner(
Optional<BiFunction<Integer, URI, Process>> externalWorkerLauncher)
throws Exception
{
return createQueryRunner(extraProperties, ImmutableMap.of("arrow-flight.server.port", String.valueOf(flightServerPort)), coordinatorProperties, externalWorkerLauncher);
ImmutableMap.Builder<String, String> catalogProperties = ImmutableMap.<String, String>builder()
.put("arrow-flight.server.port", String.valueOf(flightServerPort))
.put("arrow-flight.server", "localhost")
.put("arrow-flight.server-ssl-enabled", "true")
.put("arrow-flight.server-ssl-certificate", "src/test/resources/server.crt");
return createQueryRunner(extraProperties, catalogProperties.build(), coordinatorProperties, externalWorkerLauncher);
}

private static DistributedQueryRunner createQueryRunner(
protected static DistributedQueryRunner createQueryRunner(
Map<String, String> extraProperties,
Map<String, String> catalogProperties,
Map<String, String> coordinatorProperties,
Expand All @@ -92,13 +97,7 @@ private static DistributedQueryRunner createQueryRunner(
boolean nativeExecution = externalWorkerLauncher.isPresent();
queryRunner.installPlugin(new TestingArrowFlightPlugin(nativeExecution));

ImmutableMap.Builder<String, String> properties = ImmutableMap.<String, String>builder()
.putAll(catalogProperties)
.put("arrow-flight.server", "localhost")
.put("arrow-flight.server-ssl-enabled", "true")
.put("arrow-flight.server-ssl-certificate", "src/test/resources/server.crt");

queryRunner.createCatalog(ARROW_FLIGHT_CATALOG, ARROW_FLIGHT_CONNECTOR, properties.build());
queryRunner.createCatalog(ARROW_FLIGHT_CATALOG, ARROW_FLIGHT_CONNECTOR, catalogProperties);

return queryRunner;
}
Expand Down Expand Up @@ -140,8 +139,8 @@ public static void main(String[] args)
log.info("Server listening on port " + server.getPort());

DistributedQueryRunner queryRunner = createQueryRunner(
9443,
ImmutableMap.of("http-server.http.port", "8080"),
ImmutableMap.of("arrow-flight.server.port", String.valueOf(9443)),
ImmutableMap.of(),
Optional.empty());
Thread.sleep(10);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.plugin.arrow;

import com.facebook.airlift.log.Logger;
import com.google.common.collect.ImmutableMap;
import org.testng.annotations.Test;

import java.io.IOException;
import java.util.Map;

public class TestArrowFlightMTLS
extends AbstractArrowFlightMTLSTestFramework
{
private static final Logger logger = Logger.get(TestArrowFlightMTLS.class);

public TestArrowFlightMTLS()
throws IOException
{
super();
}

@Override
Map<String, String> getCatalogProperties()
{
ImmutableMap.Builder<String, String> catalogProperties = ImmutableMap.<String, String>builder()
.put("arrow-flight.server.port", String.valueOf(getServerPort()))
.put("arrow-flight.server", "localhost")
.put("arrow-flight.server-ssl-enabled", "true")
.put("arrow-flight.server-ssl-certificate", "src/test/resources/mtls/server.crt")
.put("arrow-flight.client-ssl-certificate", "src/test/resources/mtls/client.crt")
.put("arrow-flight.client-ssl-key", "src/test/resources/mtls/client.key");
return catalogProperties.build();
}

@Test
public void testMtls()
{
assertQuery("SELECT COUNT(*) FROM orders");
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.plugin.arrow;

import com.facebook.airlift.log.Logger;
import com.google.common.collect.ImmutableMap;
import org.testng.annotations.Test;

import java.io.IOException;
import java.util.Map;

public class TestArrowFlightMTLSFails
extends AbstractArrowFlightMTLSTestFramework
{
private static final Logger logger = Logger.get(TestArrowFlightMTLSFails.class);

public TestArrowFlightMTLSFails()
throws IOException
{
super();
}

@Override
Map<String, String> getCatalogProperties()
{
ImmutableMap.Builder<String, String> catalogProperties = ImmutableMap.<String, String>builder()
.put("arrow-flight.server.port", String.valueOf(getServerPort()))
.put("arrow-flight.server", "localhost")
.put("arrow-flight.server-ssl-enabled", "true")
.put("arrow-flight.server-ssl-certificate", "src/test/resources/mtls/server.crt");
return catalogProperties.build();
}

@Test
public void testMtlsFailure()
{
assertQueryFails("SELECT COUNT(*) FROM orders", "ssl exception");
}
}
Loading
Loading