Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
118 changes: 112 additions & 6 deletions src/main/java/ai/spice/SpiceClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -137,10 +137,8 @@ public class SpiceClient implements AutoCloseable {
private static final int MAX_INBOUND_MESSAGE_SIZE = Integer.MAX_VALUE;
private static final int MAX_INBOUND_METADATA_SIZE = Integer.MAX_VALUE;

// Cached HttpClient for refresh operations (thread-safe, connection pooling)
private static final HttpClient HTTP_CLIENT = HttpClient.newBuilder()
.connectTimeout(Duration.ofSeconds(15))
.build();
// HttpClient for refresh operations (thread-safe, connection pooling)
private final HttpClient httpClient;
Comment thread
phillipleblanc marked this conversation as resolved.

// Pre-computed parameter field names to avoid string concatenation in hot path
private static final String[] PARAM_NAMES = new String[64];
Expand All @@ -156,6 +154,9 @@ public class SpiceClient implements AutoCloseable {
private URI flightAddress;
private URI httpAddress;
private int maxRetries;
private String tlsClientCertFile;
private String tlsClientKeyFile;
private String tlsRootCertFile;
private FlightSqlClient flightClient;
private CredentialCallOption authCallOptions = null;
Comment thread
phillipleblanc marked this conversation as resolved.
private BufferAllocator allocator;
Expand Down Expand Up @@ -200,11 +201,24 @@ public static SpiceClientBuilder builder() throws URISyntaxException {
*/
public SpiceClient(String appId, String apiKey, URI flightAddress, URI httpAddress, int maxRetries,
String userAgent, long memoryLimitMB) {
this(appId, apiKey, flightAddress, httpAddress, maxRetries, userAgent, memoryLimitMB, null, null);
}

public SpiceClient(String appId, String apiKey, URI flightAddress, URI httpAddress, int maxRetries,
String userAgent, long memoryLimitMB, String tlsClientCertFile, String tlsClientKeyFile) {
this(appId, apiKey, flightAddress, httpAddress, maxRetries, userAgent, memoryLimitMB, tlsClientCertFile, tlsClientKeyFile, null);
}

public SpiceClient(String appId, String apiKey, URI flightAddress, URI httpAddress, int maxRetries,
String userAgent, long memoryLimitMB, String tlsClientCertFile, String tlsClientKeyFile, String tlsRootCertFile) {
this.appId = appId;
this.apiKey = apiKey;
this.maxRetries = maxRetries;
this.httpAddress = httpAddress;
this.userAgent = userAgent;
this.tlsClientCertFile = tlsClientCertFile;
this.tlsClientKeyFile = tlsClientKeyFile;
Comment thread
phillipleblanc marked this conversation as resolved.
this.tlsRootCertFile = tlsRootCertFile;
Comment thread
phillipleblanc marked this conversation as resolved.

// Arrow Flight requires URI to be grpc protocol, convert http/https for
// convinience
Expand All @@ -223,6 +237,19 @@ public SpiceClient(String appId, String apiKey, URI flightAddress, URI httpAddre
: memoryLimitMB * BYTES_PER_MB;
this.allocator = new RootAllocator(memoryLimitBytes);

// Build the HTTP client with optional mTLS support
HttpClient.Builder httpBuilder = HttpClient.newBuilder()
.connectTimeout(Duration.ofSeconds(15));
if (this.tlsRootCertFile != null || (this.tlsClientCertFile != null && this.tlsClientKeyFile != null)) {
try {
javax.net.ssl.SSLContext sslContext = buildSslContext();
httpBuilder.sslContext(sslContext);
} catch (Exception e) {
throw new RuntimeException("Failed to configure TLS for HTTP client", e);
}
}
this.httpClient = httpBuilder.build();

try {
// Build the Flight client (channel + auth handshake)
buildFlightClient();
Expand Down Expand Up @@ -274,8 +301,17 @@ private synchronized void buildFlightClient() {
NettyChannelBuilder channelBuilder = NettyChannelBuilder.forTarget(target);
if (useTls) {
try {
var sslContextBuilder = GrpcSslContexts.forClient();
if (this.tlsClientCertFile != null && this.tlsClientKeyFile != null) {
sslContextBuilder.keyManager(
new java.io.File(this.tlsClientCertFile),
new java.io.File(this.tlsClientKeyFile));
}
Comment thread
phillipleblanc marked this conversation as resolved.
if (this.tlsRootCertFile != null) {
sslContextBuilder.trustManager(new java.io.File(this.tlsRootCertFile));
}
channelBuilder.useTransportSecurity()
.sslContext(GrpcSslContexts.forClient().build());
.sslContext(sslContextBuilder.build());
} catch (Exception e) {
throw new RuntimeException("Failed to configure TLS for Flight client", e);
}
Expand Down Expand Up @@ -404,6 +440,76 @@ public synchronized void reset() {

/**
* Initializes the cached retryer instances.
/**
* Builds an SSLContext configured with the custom CA and/or client certificate
* for the JDK HTTP client.
*/
Comment thread
Jeadie marked this conversation as resolved.
private javax.net.ssl.SSLContext buildSslContext() throws Exception {
// Ensure BouncyCastle provider is registered for PEM private key parsing
if (java.security.Security.getProvider("BC") == null) {
java.security.Security.addProvider(new org.bouncycastle.jce.provider.BouncyCastleProvider());
}

javax.net.ssl.KeyManager[] keyManagers = null;
Comment thread
phillipleblanc marked this conversation as resolved.
javax.net.ssl.TrustManager[] trustManagers = null;

if (this.tlsClientCertFile != null && this.tlsClientKeyFile != null) {
// Load the client certificate
java.security.cert.CertificateFactory cf = java.security.cert.CertificateFactory.getInstance("X.509");
java.security.cert.Certificate clientCert;
try (java.io.FileInputStream fis = new java.io.FileInputStream(this.tlsClientCertFile)) {
clientCert = cf.generateCertificate(fis);
}

// Parse the PEM private key using BouncyCastle
java.security.PrivateKey privateKey;
try (java.io.FileReader keyReader = new java.io.FileReader(this.tlsClientKeyFile, java.nio.charset.StandardCharsets.UTF_8);
org.bouncycastle.openssl.PEMParser pemParser = new org.bouncycastle.openssl.PEMParser(keyReader)) {
Object parsed = pemParser.readObject();
org.bouncycastle.openssl.jcajce.JcaPEMKeyConverter converter =
new org.bouncycastle.openssl.jcajce.JcaPEMKeyConverter().setProvider("BC");
if (parsed instanceof org.bouncycastle.asn1.pkcs.PrivateKeyInfo) {
privateKey = converter.getPrivateKey((org.bouncycastle.asn1.pkcs.PrivateKeyInfo) parsed);
} else if (parsed instanceof org.bouncycastle.openssl.PEMKeyPair) {
privateKey = converter.getPrivateKey(((org.bouncycastle.openssl.PEMKeyPair) parsed).getPrivateKeyInfo());
} else {
throw new IllegalArgumentException("Unsupported PEM key format in " + this.tlsClientKeyFile);
}
}

// Build a KeyStore with the client identity
java.security.KeyStore keyStore = java.security.KeyStore.getInstance("PKCS12");
keyStore.load(null, null);
keyStore.setKeyEntry("client", privateKey, new char[0],
new java.security.cert.Certificate[]{clientCert});
javax.net.ssl.KeyManagerFactory kmf = javax.net.ssl.KeyManagerFactory.getInstance(
javax.net.ssl.KeyManagerFactory.getDefaultAlgorithm());
kmf.init(keyStore, new char[0]);
keyManagers = kmf.getKeyManagers();
}

if (this.tlsRootCertFile != null) {
java.security.cert.CertificateFactory cf = java.security.cert.CertificateFactory.getInstance("X.509");
java.security.KeyStore trustStore = java.security.KeyStore.getInstance(java.security.KeyStore.getDefaultType());
trustStore.load(null, null);
try (java.io.FileInputStream fis = new java.io.FileInputStream(this.tlsRootCertFile)) {
int i = 0;
for (java.security.cert.Certificate cert : cf.generateCertificates(fis)) {
trustStore.setCertificateEntry("custom-ca-" + i++, cert);
}
}
javax.net.ssl.TrustManagerFactory tmf = javax.net.ssl.TrustManagerFactory.getInstance(
javax.net.ssl.TrustManagerFactory.getDefaultAlgorithm());
tmf.init(trustStore);
trustManagers = tmf.getTrustManagers();
}

javax.net.ssl.SSLContext sslContext = javax.net.ssl.SSLContext.getInstance("TLS");
sslContext.init(keyManagers, trustManagers, null);
return sslContext;
Comment thread
phillipleblanc marked this conversation as resolved.
}

/**
* Called from constructor and must be called after maxRetries is set.
*/
private void initRetryers() {
Expand Down Expand Up @@ -1017,7 +1123,7 @@ public void refreshDataset(String dataset, RefreshOptions refreshOptions) throws
}

HttpRequest request = builder.build();
HttpResponse<String> response = HTTP_CLIENT.send(request, HttpResponse.BodyHandlers.ofString());
HttpResponse<String> response = httpClient.send(request, HttpResponse.BodyHandlers.ofString());

if (response.statusCode() != 201) {
logger.error("Dataset refresh failed - dataset={}, statusCode={}, response={}", dataset, response.statusCode(), response.body());
Expand Down
49 changes: 48 additions & 1 deletion src/main/java/ai/spice/SpiceClientBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@ public class SpiceClientBuilder {
private URI httpAddress;
private int maxRetries = 3;
private long memoryLimitMB = Long.MAX_VALUE; // Default is all available memory.
private String tlsClientCertFile;
private String tlsClientKeyFile;
private String tlsRootCertFile;

/**
* Constructs a new SpiceClientBuilder instance
Expand Down Expand Up @@ -167,12 +170,56 @@ public SpiceClientBuilder withArrowMemoryLimitMB(long memoryLimitMB) {
return this;
}

/**
* Sets the path to a PEM-encoded client certificate file for mTLS.
* Must be used together with {@link #withTlsClientKeyFile(String)}.
*
* @param certFile Path to the client certificate PEM file
* @return The current instance of SpiceClientBuilder for method chaining.
*/
public SpiceClientBuilder withTlsClientCertFile(String certFile) {
this.tlsClientCertFile = certFile;
return this;
}
Comment thread
phillipleblanc marked this conversation as resolved.

/**
* Sets the path to a PEM-encoded client private key file for mTLS.
* Must be used together with {@link #withTlsClientCertFile(String)}.
*
* @param keyFile Path to the client private key PEM file
* @return The current instance of SpiceClientBuilder for method chaining.
*/
public SpiceClientBuilder withTlsClientKeyFile(String keyFile) {
Comment thread
phillipleblanc marked this conversation as resolved.
this.tlsClientKeyFile = keyFile;
return this;
}

/**
* Sets the path to a PEM-encoded CA certificate file for server verification.
* When set, this CA is used instead of the system trust store.
*
* @param caFile Path to the CA certificate PEM file
* @return The current instance of SpiceClientBuilder for method chaining.
*/
public SpiceClientBuilder withTlsRootCertFile(String caFile) {
this.tlsRootCertFile = caFile;
return this;
}

/**
* Creates SpiceClient with provided parameters.
*
* @return The SpiceClient instance
*/
public SpiceClient build() {
Comment thread
phillipleblanc marked this conversation as resolved.
return new SpiceClient(appId, apiKey, flightAddress, httpAddress, maxRetries, userAgent, memoryLimitMB);
// Validate that client cert and key are either both set or both unset
boolean hasCert = tlsClientCertFile != null && !tlsClientCertFile.isBlank();
boolean hasKey = tlsClientKeyFile != null && !tlsClientKeyFile.isBlank();
if (hasCert != hasKey) {
throw new IllegalArgumentException(
"Both tlsClientCertFile and tlsClientKeyFile must be provided together for mTLS. "
+ (hasCert ? "tlsClientKeyFile is missing." : "tlsClientCertFile is missing."));
}
return new SpiceClient(appId, apiKey, flightAddress, httpAddress, maxRetries, userAgent, memoryLimitMB, tlsClientCertFile, tlsClientKeyFile, tlsRootCertFile);
Comment thread
phillipleblanc marked this conversation as resolved.
Comment thread
phillipleblanc marked this conversation as resolved.
Comment thread
phillipleblanc marked this conversation as resolved.
}
Comment thread
phillipleblanc marked this conversation as resolved.
}
Loading