Skip to content

Commit 95063e7

Browse files
fix: Use gRPC DnsNameResolver for periodic DNS re-resolution
Arrow Flight's default FlightClient.Builder uses NettyChannelBuilder.forAddress(SocketAddress), which calls Location.toSocketAddress() -> new InetSocketAddress(host, port). This eagerly resolves DNS once at construction time and registers a DirectAddressNameResolverProvider that never re-resolves.

For long-lived clients connecting to load-balanced endpoints (e.g. AWS ALBs) where backend IPs can change, this causes the gRPC channel to get stuck on stale IPs indefinitely. If the old IP is recycled to a different service, the client sees TLS certificate mismatches and cannot recover without being fully reconstructed.

This change builds the gRPC ManagedChannel directly using NettyChannelBuilder.forTarget("dns:///host:port") instead of going through Arrow's FlightClient.Builder. The "dns:///" target scheme activates gRPC's DnsNameResolver, which re-resolves the hostname whenever the channel calls its refresh() method (for example after a transient connection failure), reusing cached results for at most the cache TTL (30s by default). The FlightClient is then created via FlightGrpcUtils.createFlightClient() with the custom channel.

Fixes: spicehq/customer-summation#7
1 parent 063a963 commit 95063e7

2 files changed

Lines changed: 47 additions & 11 deletions

File tree

src/main/java/ai/spice/SpiceClient.java

Lines changed: 42 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,8 @@ of this software and associated documentation files (the "Software"), to deal
4949
import org.apache.arrow.adbc.driver.flightsql.FlightSqlDriver;
5050
import org.apache.arrow.flight.CallStatus;
5151
import org.apache.arrow.flight.FlightClient;
52-
import org.apache.arrow.flight.FlightClient.Builder;
52+
import org.apache.arrow.flight.FlightClientMiddleware;
53+
import org.apache.arrow.flight.FlightGrpcUtils;
5354
import org.apache.arrow.flight.FlightStream;
5455
import org.apache.arrow.flight.Location;
5556
import org.apache.arrow.flight.Ticket;
@@ -59,6 +60,10 @@ of this software and associated documentation files (the "Software"), to deal
5960
import org.apache.arrow.flight.grpc.CredentialCallOption;
6061
import org.apache.arrow.flight.FlightInfo;
6162
import org.apache.arrow.flight.FlightRuntimeException;
63+
64+
import io.grpc.ManagedChannel;
65+
import io.grpc.netty.GrpcSslContexts;
66+
import io.grpc.netty.NettyChannelBuilder;
6267
import org.apache.arrow.memory.BufferAllocator;
6368
import org.apache.arrow.memory.RootAllocator;
6469
import org.apache.arrow.vector.BigIntVector;
@@ -209,12 +214,40 @@ public SpiceClient(String appId, String apiKey, URI flightAddress, URI httpAddre
209214
? Long.MAX_VALUE
210215
: memoryLimitMB * BYTES_PER_MB;
211216
this.allocator = new RootAllocator(memoryLimitBytes);
212-
Builder builder = FlightClient.builder(allocator, new Location(this.flightAddress));
217+
218+
// Build a gRPC channel using forTarget() with the "dns:///" scheme so that
219+
// gRPC's DnsNameResolver re-resolves the hostname on channel refresh. This is critical
220+
// for long-lived clients connecting to load-balanced endpoints (e.g. AWS ALBs)
221+
// where backend IPs can change. Arrow Flight's default FlightClient.Builder uses
222+
// NettyChannelBuilder.forAddress(SocketAddress), which resolves DNS exactly once
223+
// at construction time and never re-resolves, causing clients to get stuck on
224+
// stale IPs.
225+
boolean useTls = this.flightAddress.getScheme().equals("grpc+tls");
226+
String host = this.flightAddress.getHost();
227+
int port = this.flightAddress.getPort();
228+
String target = "dns:///" + host + ":" + port;
229+
230+
NettyChannelBuilder channelBuilder = NettyChannelBuilder.forTarget(target);
231+
if (useTls) {
232+
try {
233+
channelBuilder.useTransportSecurity()
234+
.sslContext(GrpcSslContexts.forClient().build());
235+
} catch (Exception e) {
236+
throw new RuntimeException("Failed to configure TLS for Flight client", e);
237+
}
238+
} else {
239+
channelBuilder.usePlaintext();
240+
}
241+
channelBuilder
242+
.maxInboundMessageSize(Integer.MAX_VALUE)
243+
.maxInboundMetadataSize(Integer.MAX_VALUE);
244+
ManagedChannel channel = channelBuilder.build();
213245

214246
if (Strings.isNullOrEmpty(apiKey)) {
215-
this.flightClient = new FlightSqlClient(builder.build());
247+
FlightClient client = FlightGrpcUtils.createFlightClient(allocator, channel);
248+
this.flightClient = new FlightSqlClient(client);
216249
initRetryers();
217-
logger.debug("SpiceClient initialized (unauthenticated) - flightAddress={}", this.flightAddress);
250+
logger.debug("SpiceClient initialized (unauthenticated) - flightAddress={}, target={}", this.flightAddress, target);
218251
return;
219252
}
220253

@@ -236,15 +269,18 @@ public SpiceClient(String appId, String apiKey, URI flightAddress, URI httpAddre
236269
// factories instead
237270
final HeaderAuthMiddlewareFactory combinedFactory = new HeaderAuthMiddlewareFactory(authFactory, headers);
238271

239-
final FlightClient client = builder.intercept(combinedFactory).build();
272+
List<FlightClientMiddleware.Factory> middleware = new ArrayList<>();
273+
middleware.add(combinedFactory);
274+
275+
final FlightClient client = FlightGrpcUtils.createFlightClient(allocator, channel, middleware);
240276
client.handshake(new CredentialCallOption(new BasicAuthCredentialWriter(this.appId, this.apiKey)));
241277
this.authCallOptions = authFactory.getCredentialCallOption();
242278
this.flightClient = new FlightSqlClient(client);
243279

244280
// Initialize cached retryers (immutable, built once)
245281
initRetryers();
246282

247-
logger.debug("SpiceClient initialized (authenticated) - flightAddress={}, appId={}", this.flightAddress, this.appId);
283+
logger.debug("SpiceClient initialized (authenticated) - flightAddress={}, appId={}, target={}", this.flightAddress, this.appId, target);
248284
}
249285

250286
/**

src/test/java/ai/spice/ParameterizedQueryTest.java

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ public void testParameterizedQuerySpiceOSS() throws Exception {
102102
}
103103
} catch (ExecutionException e) {
104104
// Local Spice runtime might not be running, skip test
105-
if (e.getMessage().contains("UNAVAILABLE") || e.getMessage().contains("Connection refused") || e.getMessage().contains("not found")) {
105+
if (e.getMessage().contains("UNAVAILABLE") || e.getMessage().contains("Connection refused") || e.getMessage().contains("not found") || e.getMessage().contains("io exception")) {
106106
return;
107107
}
108108
throw e;
@@ -129,7 +129,7 @@ public void testMultipleParameters() throws Exception {
129129
}
130130
} catch (ExecutionException e) {
131131
// Local Spice runtime might not be running, skip test
132-
if (e.getMessage().contains("UNAVAILABLE") || e.getMessage().contains("Connection refused") || e.getMessage().contains("not found")) {
132+
if (e.getMessage().contains("UNAVAILABLE") || e.getMessage().contains("Connection refused") || e.getMessage().contains("not found") || e.getMessage().contains("io exception")) {
133133
return;
134134
}
135135
throw e;
@@ -157,7 +157,7 @@ public void testStringParameter() throws Exception {
157157
}
158158
} catch (ExecutionException e) {
159159
// Local Spice runtime might not be running, skip test
160-
if (e.getMessage().contains("UNAVAILABLE") || e.getMessage().contains("Connection refused") || e.getMessage().contains("not found")) {
160+
if (e.getMessage().contains("UNAVAILABLE") || e.getMessage().contains("Connection refused") || e.getMessage().contains("not found") || e.getMessage().contains("io exception")) {
161161
return;
162162
}
163163
throw e;
@@ -184,7 +184,7 @@ public void testExplicitParamTypes() throws Exception {
184184
}
185185
} catch (ExecutionException e) {
186186
// Local Spice runtime might not be running, skip test
187-
if (e.getMessage().contains("UNAVAILABLE") || e.getMessage().contains("Connection refused") || e.getMessage().contains("not found")) {
187+
if (e.getMessage().contains("UNAVAILABLE") || e.getMessage().contains("Connection refused") || e.getMessage().contains("not found") || e.getMessage().contains("io exception")) {
188188
return;
189189
}
190190
throw e;
@@ -212,7 +212,7 @@ public void testMixedParameterTypes() throws Exception {
212212
}
213213
} catch (ExecutionException e) {
214214
// Local Spice runtime might not be running, skip test
215-
if (e.getMessage().contains("UNAVAILABLE") || e.getMessage().contains("Connection refused") || e.getMessage().contains("not found")) {
215+
if (e.getMessage().contains("UNAVAILABLE") || e.getMessage().contains("Connection refused") || e.getMessage().contains("not found") || e.getMessage().contains("io exception")) {
216216
return;
217217
}
218218
throw e;

0 commit comments

Comments
 (0)