Skip to content

Commit 500bc3d

Browse files
feat: add mTLS client certificate support
1 parent 1380a8f commit 500bc3d

2 files changed

Lines changed: 160 additions & 7 deletions

File tree

src/main/java/ai/spice/SpiceClient.java

Lines changed: 112 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -137,10 +137,8 @@ public class SpiceClient implements AutoCloseable {
137137
private static final int MAX_INBOUND_MESSAGE_SIZE = Integer.MAX_VALUE;
138138
private static final int MAX_INBOUND_METADATA_SIZE = Integer.MAX_VALUE;
139139

140-
// Cached HttpClient for refresh operations (thread-safe, connection pooling)
141-
private static final HttpClient HTTP_CLIENT = HttpClient.newBuilder()
142-
.connectTimeout(Duration.ofSeconds(15))
143-
.build();
140+
// HttpClient for refresh operations (thread-safe, connection pooling)
141+
private final HttpClient httpClient;
144142

145143
// Pre-computed parameter field names to avoid string concatenation in hot path
146144
private static final String[] PARAM_NAMES = new String[64];
@@ -156,6 +154,9 @@ public class SpiceClient implements AutoCloseable {
156154
private URI flightAddress;
157155
private URI httpAddress;
158156
private int maxRetries;
157+
private String tlsClientCertFile;
158+
private String tlsClientKeyFile;
159+
private String tlsRootCertFile;
159160
private FlightSqlClient flightClient;
160161
private CredentialCallOption authCallOptions = null;
161162
private BufferAllocator allocator;
@@ -200,11 +201,24 @@ public static SpiceClientBuilder builder() throws URISyntaxException {
200201
*/
201202
public SpiceClient(String appId, String apiKey, URI flightAddress, URI httpAddress, int maxRetries,
202203
String userAgent, long memoryLimitMB) {
204+
this(appId, apiKey, flightAddress, httpAddress, maxRetries, userAgent, memoryLimitMB, null, null);
205+
}
206+
207+
public SpiceClient(String appId, String apiKey, URI flightAddress, URI httpAddress, int maxRetries,
208+
String userAgent, long memoryLimitMB, String tlsClientCertFile, String tlsClientKeyFile) {
209+
this(appId, apiKey, flightAddress, httpAddress, maxRetries, userAgent, memoryLimitMB, tlsClientCertFile, tlsClientKeyFile, null);
210+
}
211+
212+
public SpiceClient(String appId, String apiKey, URI flightAddress, URI httpAddress, int maxRetries,
213+
String userAgent, long memoryLimitMB, String tlsClientCertFile, String tlsClientKeyFile, String tlsRootCertFile) {
203214
this.appId = appId;
204215
this.apiKey = apiKey;
205216
this.maxRetries = maxRetries;
206217
this.httpAddress = httpAddress;
207218
this.userAgent = userAgent;
219+
this.tlsClientCertFile = tlsClientCertFile;
220+
this.tlsClientKeyFile = tlsClientKeyFile;
221+
this.tlsRootCertFile = tlsRootCertFile;
208222

209223
// Arrow Flight requires URI to be grpc protocol, convert http/https for
210224
// convinience
@@ -223,6 +237,19 @@ public SpiceClient(String appId, String apiKey, URI flightAddress, URI httpAddre
223237
: memoryLimitMB * BYTES_PER_MB;
224238
this.allocator = new RootAllocator(memoryLimitBytes);
225239

240+
// Build the HTTP client with optional mTLS support
241+
HttpClient.Builder httpBuilder = HttpClient.newBuilder()
242+
.connectTimeout(Duration.ofSeconds(15));
243+
if (this.tlsRootCertFile != null || (this.tlsClientCertFile != null && this.tlsClientKeyFile != null)) {
244+
try {
245+
javax.net.ssl.SSLContext sslContext = buildSslContext();
246+
httpBuilder.sslContext(sslContext);
247+
} catch (Exception e) {
248+
throw new RuntimeException("Failed to configure TLS for HTTP client", e);
249+
}
250+
}
251+
this.httpClient = httpBuilder.build();
252+
226253
try {
227254
// Build the Flight client (channel + auth handshake)
228255
buildFlightClient();
@@ -274,8 +301,17 @@ private synchronized void buildFlightClient() {
274301
NettyChannelBuilder channelBuilder = NettyChannelBuilder.forTarget(target);
275302
if (useTls) {
276303
try {
304+
var sslContextBuilder = GrpcSslContexts.forClient();
305+
if (this.tlsClientCertFile != null && this.tlsClientKeyFile != null) {
306+
sslContextBuilder.keyManager(
307+
new java.io.File(this.tlsClientCertFile),
308+
new java.io.File(this.tlsClientKeyFile));
309+
}
310+
if (this.tlsRootCertFile != null) {
311+
sslContextBuilder.trustManager(new java.io.File(this.tlsRootCertFile));
312+
}
277313
channelBuilder.useTransportSecurity()
278-
.sslContext(GrpcSslContexts.forClient().build());
314+
.sslContext(sslContextBuilder.build());
279315
} catch (Exception e) {
280316
throw new RuntimeException("Failed to configure TLS for Flight client", e);
281317
}
@@ -404,6 +440,76 @@ public synchronized void reset() {
404440

405441
/**
406442
* Initializes the cached retryer instances.
443+
/**
444+
* Builds an SSLContext configured with the custom CA and/or client certificate
445+
* for the JDK HTTP client.
446+
*/
447+
private javax.net.ssl.SSLContext buildSslContext() throws Exception {
448+
// Ensure BouncyCastle provider is registered for PEM private key parsing
449+
if (java.security.Security.getProvider("BC") == null) {
450+
java.security.Security.addProvider(new org.bouncycastle.jce.provider.BouncyCastleProvider());
451+
}
452+
453+
javax.net.ssl.KeyManager[] keyManagers = null;
454+
javax.net.ssl.TrustManager[] trustManagers = null;
455+
456+
if (this.tlsClientCertFile != null && this.tlsClientKeyFile != null) {
457+
// Load the client certificate
458+
java.security.cert.CertificateFactory cf = java.security.cert.CertificateFactory.getInstance("X.509");
459+
java.security.cert.Certificate clientCert;
460+
try (java.io.FileInputStream fis = new java.io.FileInputStream(this.tlsClientCertFile)) {
461+
clientCert = cf.generateCertificate(fis);
462+
}
463+
464+
// Parse the PEM private key using BouncyCastle
465+
java.security.PrivateKey privateKey;
466+
try (java.io.FileReader keyReader = new java.io.FileReader(this.tlsClientKeyFile);
467+
org.bouncycastle.openssl.PEMParser pemParser = new org.bouncycastle.openssl.PEMParser(keyReader)) {
468+
Object parsed = pemParser.readObject();
469+
org.bouncycastle.openssl.jcajce.JcaPEMKeyConverter converter =
470+
new org.bouncycastle.openssl.jcajce.JcaPEMKeyConverter().setProvider("BC");
471+
if (parsed instanceof org.bouncycastle.asn1.pkcs.PrivateKeyInfo) {
472+
privateKey = converter.getPrivateKey((org.bouncycastle.asn1.pkcs.PrivateKeyInfo) parsed);
473+
} else if (parsed instanceof org.bouncycastle.openssl.PEMKeyPair) {
474+
privateKey = converter.getPrivateKey(((org.bouncycastle.openssl.PEMKeyPair) parsed).getPrivateKeyInfo());
475+
} else {
476+
throw new IllegalArgumentException("Unsupported PEM key format in " + this.tlsClientKeyFile);
477+
}
478+
}
479+
480+
// Build a KeyStore with the client identity
481+
java.security.KeyStore keyStore = java.security.KeyStore.getInstance("PKCS12");
482+
keyStore.load(null, null);
483+
keyStore.setKeyEntry("client", privateKey, new char[0],
484+
new java.security.cert.Certificate[]{clientCert});
485+
javax.net.ssl.KeyManagerFactory kmf = javax.net.ssl.KeyManagerFactory.getInstance(
486+
javax.net.ssl.KeyManagerFactory.getDefaultAlgorithm());
487+
kmf.init(keyStore, new char[0]);
488+
keyManagers = kmf.getKeyManagers();
489+
}
490+
491+
if (this.tlsRootCertFile != null) {
492+
java.security.cert.CertificateFactory cf = java.security.cert.CertificateFactory.getInstance("X.509");
493+
java.security.KeyStore trustStore = java.security.KeyStore.getInstance(java.security.KeyStore.getDefaultType());
494+
trustStore.load(null, null);
495+
try (java.io.FileInputStream fis = new java.io.FileInputStream(this.tlsRootCertFile)) {
496+
int i = 0;
497+
for (java.security.cert.Certificate cert : cf.generateCertificates(fis)) {
498+
trustStore.setCertificateEntry("custom-ca-" + i++, cert);
499+
}
500+
}
501+
javax.net.ssl.TrustManagerFactory tmf = javax.net.ssl.TrustManagerFactory.getInstance(
502+
javax.net.ssl.TrustManagerFactory.getDefaultAlgorithm());
503+
tmf.init(trustStore);
504+
trustManagers = tmf.getTrustManagers();
505+
}
506+
507+
javax.net.ssl.SSLContext sslContext = javax.net.ssl.SSLContext.getInstance("TLS");
508+
sslContext.init(keyManagers, trustManagers, null);
509+
return sslContext;
510+
}
511+
512+
/**
407513
* Called from constructor and must be called after maxRetries is set.
408514
*/
409515
private void initRetryers() {
@@ -1017,7 +1123,7 @@ public void refreshDataset(String dataset, RefreshOptions refreshOptions) throws
10171123
}
10181124

10191125
HttpRequest request = builder.build();
1020-
HttpResponse<String> response = HTTP_CLIENT.send(request, HttpResponse.BodyHandlers.ofString());
1126+
HttpResponse<String> response = httpClient.send(request, HttpResponse.BodyHandlers.ofString());
10211127

10221128
if (response.statusCode() != 201) {
10231129
logger.error("Dataset refresh failed - dataset={}, statusCode={}, response={}", dataset, response.statusCode(), response.body());

src/main/java/ai/spice/SpiceClientBuilder.java

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,9 @@ public class SpiceClientBuilder {
3939
private URI httpAddress;
4040
private int maxRetries = 3;
4141
private long memoryLimitMB = Long.MAX_VALUE; // Default is all available memory.
42+
private String tlsClientCertFile;
43+
private String tlsClientKeyFile;
44+
private String tlsRootCertFile;
4245

4346
/**
4447
* Constructs a new SpiceClientBuilder instance
@@ -167,12 +170,56 @@ public SpiceClientBuilder withArrowMemoryLimitMB(long memoryLimitMB) {
167170
return this;
168171
}
169172

173+
/**
174+
* Sets the path to a PEM-encoded client certificate file for mTLS.
175+
* Must be used together with {@link #withTlsClientKeyFile(String)}.
176+
*
177+
* @param certFile Path to the client certificate PEM file
178+
* @return The current instance of SpiceClientBuilder for method chaining.
179+
*/
180+
public SpiceClientBuilder withTlsClientCertFile(String certFile) {
181+
this.tlsClientCertFile = certFile;
182+
return this;
183+
}
184+
185+
/**
186+
* Sets the path to a PEM-encoded client private key file for mTLS.
187+
* Must be used together with {@link #withTlsClientCertFile(String)}.
188+
*
189+
* @param keyFile Path to the client private key PEM file
190+
* @return The current instance of SpiceClientBuilder for method chaining.
191+
*/
192+
public SpiceClientBuilder withTlsClientKeyFile(String keyFile) {
193+
this.tlsClientKeyFile = keyFile;
194+
return this;
195+
}
196+
197+
/**
198+
* Sets the path to a PEM-encoded CA certificate file for server verification.
199+
* When set, this CA is used instead of the system trust store.
200+
*
201+
* @param caFile Path to the CA certificate PEM file
202+
* @return The current instance of SpiceClientBuilder for method chaining.
203+
*/
204+
public SpiceClientBuilder withTlsRootCertFile(String caFile) {
205+
this.tlsRootCertFile = caFile;
206+
return this;
207+
}
208+
170209
/**
171210
* Creates SpiceClient with provided parameters.
172211
*
173212
* @return The SpiceClient instance
174213
*/
175214
public SpiceClient build() {
176-
return new SpiceClient(appId, apiKey, flightAddress, httpAddress, maxRetries, userAgent, memoryLimitMB);
215+
// Validate that client cert and key are either both set or both unset
216+
boolean hasCert = tlsClientCertFile != null && !tlsClientCertFile.isBlank();
217+
boolean hasKey = tlsClientKeyFile != null && !tlsClientKeyFile.isBlank();
218+
if (hasCert != hasKey) {
219+
throw new IllegalArgumentException(
220+
"Both tlsClientCertFile and tlsClientKeyFile must be provided together for mTLS. "
221+
+ (hasCert ? "tlsClientKeyFile is missing." : "tlsClientCertFile is missing."));
222+
}
223+
return new SpiceClient(appId, apiKey, flightAddress, httpAddress, maxRetries, userAgent, memoryLimitMB, tlsClientCertFile, tlsClientKeyFile, tlsRootCertFile);
177224
}
178225
}

0 commit comments

Comments
 (0)