Skip to content

Commit 5474bf2

Browse files
jchrysmirromutth
andauthored
Fixed NPE with SslMode.TUNNEL Usage (#225)
Motivation: A NPE was identified when utilizing `SslMode.TUNNEL`. The issue arises when `ConnectionContext#isMariaDb` is invoked from `SslBridgeHandler#isTls13Enabled`, leading to an NPE due to the `ConnectionContext` not being initialized at that time. Modification: Do not invoke `ConnectionContext#isMariaDb` when it is not initialized. Result: This change addresses the NPE issue, ensuring stability when `SslMode.TUNNEL` is selected. It resolves the problem reported in GoogleCloudPlatform/cloud-sql-jdbc-socket-factory#1828 --------- Signed-off-by: jchrys <[email protected]> Co-authored-by: Mirro Mutth <[email protected]>
1 parent 6982acc commit 5474bf2

File tree

4 files changed

+299
-2
lines changed

4 files changed

+299
-2
lines changed

pom.xml

+12
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@
7979
<mbr.version>0.3.0.RELEASE</mbr.version>
8080
<jsr305.version>3.0.2</jsr305.version>
8181
<java-annotations.version>24.1.0</java-annotations.version>
82+
<bouncy-castle.version>1.77</bouncy-castle.version>
8283
</properties>
8384

8485
<dependencyManagement>
@@ -117,6 +118,12 @@
117118
<version>${java-annotations.version}</version>
118119
<scope>provided</scope>
119120
</dependency>
121+
<dependency>
122+
<groupId>org.bouncycastle</groupId>
123+
<artifactId>bcpkix-jdk18on</artifactId>
124+
<version>${bouncy-castle.version}</version>
125+
<scope>test</scope>
126+
</dependency>
120127
</dependencies>
121128
</dependencyManagement>
122129

@@ -240,6 +247,11 @@
240247
<artifactId>jackson-annotations</artifactId>
241248
<scope>test</scope>
242249
</dependency>
250+
<dependency>
251+
<groupId>org.bouncycastle</groupId>
252+
<artifactId>bcpkix-jdk18on</artifactId>
253+
<scope>test</scope>
254+
</dependency>
243255
</dependencies>
244256

245257
<build>

src/main/java/io/asyncer/r2dbc/mysql/ConnectionContext.java

+1
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ public final class ConnectionContext implements CodecContext {
5757
*/
5858
private volatile short serverStatuses = ServerStatuses.AUTO_COMMIT;
5959

60+
@Nullable
6061
private volatile Capability capability = null;
6162

6263
ConnectionContext(ZeroDateOption zeroDateOption, @Nullable Path localInfilePath,

src/main/java/io/asyncer/r2dbc/mysql/client/SslBridgeHandler.java

+4-2
Original file line numberDiff line numberDiff line change
@@ -220,8 +220,10 @@ static MySqlSslContextSpec forClient(MySqlSslConfiguration ssl, ConnectionContex
220220
.applicationProtocolConfig(null);
221221
String[] tlsProtocols = ssl.getTlsVersion();
222222

223-
if (tlsProtocols.length > 0) {
224-
builder.protocols(tlsProtocols);
223+
if (tlsProtocols.length > 0 || ssl.getSslMode() == SslMode.TUNNEL) {
224+
if (tlsProtocols.length > 0) {
225+
builder.protocols(tlsProtocols);
226+
}
225227
} else if (isTls13Enabled(context)) {
226228
builder.protocols(TLS_PROTOCOLS);
227229
} else {
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,282 @@
1+
/*
2+
* Copyright 2024 asyncer.io projects
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package io.asyncer.r2dbc.mysql;
18+
19+
20+
import io.asyncer.r2dbc.mysql.constant.SslMode;
21+
import io.netty.bootstrap.Bootstrap;
22+
import io.netty.bootstrap.ServerBootstrap;
23+
import io.netty.buffer.Unpooled;
24+
import io.netty.channel.Channel;
25+
import io.netty.channel.ChannelFuture;
26+
import io.netty.channel.ChannelFutureListener;
27+
import io.netty.channel.ChannelHandlerContext;
28+
import io.netty.channel.ChannelInboundHandlerAdapter;
29+
import io.netty.channel.ChannelInitializer;
30+
import io.netty.channel.ChannelOption;
31+
import io.netty.channel.nio.NioEventLoopGroup;
32+
import io.netty.channel.socket.SocketChannel;
33+
import io.netty.channel.socket.nio.NioServerSocketChannel;
34+
import io.netty.handler.ssl.SslContext;
35+
import io.netty.handler.ssl.SslContextBuilder;
36+
import io.netty.handler.ssl.util.SelfSignedCertificate;
37+
import org.junit.jupiter.api.AfterEach;
38+
import org.junit.jupiter.api.BeforeEach;
39+
import org.junit.jupiter.api.Test;
40+
41+
import javax.net.ssl.SSLException;
42+
import java.net.InetSocketAddress;
43+
import java.security.cert.CertificateException;
44+
import java.time.Duration;
45+
46+
import static org.assertj.core.api.Assertions.assertThat;
47+
48+
public class SslTunnelIntegrationTest {
49+
50+
private SelfSignedCertificate server;
51+
52+
private SelfSignedCertificate client;
53+
54+
private SslTunnelServer sslTunnelServer;
55+
56+
@BeforeEach
57+
void setUp() throws CertificateException, SSLException, InterruptedException {
58+
server = new SelfSignedCertificate();
59+
client = new SelfSignedCertificate();
60+
final SslContext sslContext = SslContextBuilder.forServer(server.key(), server.cert()).build();
61+
sslTunnelServer = new SslTunnelServer("localhost", 3306, sslContext);
62+
sslTunnelServer.setUp();
63+
}
64+
65+
@AfterEach
66+
void tearDown() throws InterruptedException {
67+
server.delete();
68+
client.delete();
69+
sslTunnelServer.tearDown();
70+
}
71+
72+
@Test
73+
void sslTunnelConnectionTest() {
74+
final String password = System.getProperty("test.mysql.password");
75+
assertThat(password).withFailMessage("Property test.mysql.password must exists and not be empty")
76+
.isNotNull()
77+
.isNotEmpty();
78+
79+
final MySqlConnectionConfiguration configuration = MySqlConnectionConfiguration
80+
.builder()
81+
.host("localhost")
82+
.port(sslTunnelServer.getLocalPort())
83+
.connectTimeout(Duration.ofSeconds(3))
84+
.user("root")
85+
.password(password)
86+
.database("r2dbc")
87+
.createDatabaseIfNotExist(true)
88+
.sslMode(SslMode.TUNNEL)
89+
.sslKey(client.privateKey().getAbsolutePath())
90+
.sslCert(client.certificate().getAbsolutePath())
91+
.sslCa(server.certificate().getAbsolutePath())
92+
.build();
93+
94+
final MySqlConnectionFactory connectionFactory = MySqlConnectionFactory.from(configuration);
95+
96+
final MySqlConnection connection = connectionFactory.create().block();
97+
assert null != connection;
98+
connection.createStatement("SELECT 3").execute()
99+
.flatMap(it -> it.map((row, rowMetadata) -> row.get(0, Long.class)))
100+
.doOnNext(it -> assertThat(it).isEqualTo(3L))
101+
.blockLast();
102+
103+
connection.close().block();
104+
}
105+
106+
private static class SslTunnelServer {
107+
108+
private final String remoteHost;
109+
110+
private final int remotePort;
111+
112+
private final SslContext sslContext;
113+
114+
private volatile ChannelFuture channelFuture;
115+
116+
117+
private SslTunnelServer(String remoteHost, int remotePort, SslContext sslContext) {
118+
this.remoteHost = remoteHost;
119+
this.remotePort = remotePort;
120+
this.sslContext = sslContext;
121+
}
122+
123+
void setUp() throws InterruptedException {
124+
// Configure the server.
125+
ServerBootstrap b = new ServerBootstrap();
126+
b.localAddress(0)
127+
.group(new NioEventLoopGroup())
128+
.channel(NioServerSocketChannel.class)
129+
.childHandler(new ProxyInitializer(remoteHost, remotePort, sslContext))
130+
.childOption(ChannelOption.AUTO_READ, false);
131+
132+
// Start the server.
133+
channelFuture = b.bind().sync();
134+
}
135+
136+
void tearDown() throws InterruptedException {
137+
channelFuture.channel().close().sync();
138+
}
139+
140+
int getLocalPort() {
141+
return ((InetSocketAddress) channelFuture.channel().localAddress()).getPort();
142+
}
143+
144+
}
145+
146+
147+
private static class ProxyInitializer extends ChannelInitializer<SocketChannel> {
148+
149+
private final String remoteHost;
150+
151+
private final int remotePort;
152+
153+
private final SslContext sslContext;
154+
155+
ProxyInitializer(String remoteHost, int remotePort, SslContext sslContext) {
156+
this.remoteHost = remoteHost;
157+
this.remotePort = remotePort;
158+
this.sslContext = sslContext;
159+
}
160+
161+
@Override
162+
public void initChannel(SocketChannel ch) {
163+
ch.pipeline().addLast(sslContext.newHandler(ch.alloc()));
164+
ch.pipeline().addLast(new ProxyFrontendHandler(remoteHost, remotePort));
165+
}
166+
}
167+
168+
private static class ProxyFrontendHandler extends ChannelInboundHandlerAdapter {
169+
170+
private final String remoteHost;
171+
private final int remotePort;
172+
173+
// As we use inboundChannel.eventLoop() when building the Bootstrap this does not need to be volatile as
174+
// the outboundChannel will use the same EventLoop (and therefore Thread) as the inboundChannel.
175+
private Channel outboundChannel;
176+
177+
private ProxyFrontendHandler(String remoteHost, int remotePort) {
178+
this.remoteHost = remoteHost;
179+
this.remotePort = remotePort;
180+
}
181+
182+
@Override
183+
public void channelActive(ChannelHandlerContext ctx) {
184+
final Channel inboundChannel = ctx.channel();
185+
186+
// Start the connection attempt.
187+
Bootstrap b = new Bootstrap();
188+
b.group(inboundChannel.eventLoop())
189+
.channel(ctx.channel().getClass())
190+
.handler(new ProxyBackendHandler(inboundChannel))
191+
.option(ChannelOption.AUTO_READ, false);
192+
ChannelFuture f = b.connect(remoteHost, remotePort);
193+
outboundChannel = f.channel();
194+
f.addListener((ChannelFutureListener) future -> {
195+
if (future.isSuccess()) {
196+
// connection complete start to read first data
197+
inboundChannel.read();
198+
} else {
199+
// Close the connection if the connection attempt has failed.
200+
inboundChannel.close();
201+
}
202+
});
203+
}
204+
205+
@Override
206+
public void channelRead(final ChannelHandlerContext ctx, Object msg) {
207+
if (outboundChannel.isActive()) {
208+
outboundChannel.writeAndFlush(msg).addListener((ChannelFutureListener) future -> {
209+
if (future.isSuccess()) {
210+
// was able to flush out data, start to read the next chunk
211+
ctx.channel().read();
212+
} else {
213+
future.channel().close();
214+
}
215+
});
216+
}
217+
}
218+
219+
@Override
220+
public void channelInactive(ChannelHandlerContext ctx) {
221+
if (outboundChannel != null) {
222+
closeOnFlush(outboundChannel);
223+
}
224+
}
225+
226+
@Override
227+
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
228+
cause.printStackTrace();
229+
closeOnFlush(ctx.channel());
230+
}
231+
232+
/**
233+
* Closes the specified channel after all queued write requests are flushed.
234+
*/
235+
static void closeOnFlush(Channel ch) {
236+
if (ch.isActive()) {
237+
ch.writeAndFlush(Unpooled.EMPTY_BUFFER).addListener(ChannelFutureListener.CLOSE);
238+
}
239+
}
240+
}
241+
242+
private static class ProxyBackendHandler extends ChannelInboundHandlerAdapter {
243+
244+
private final Channel inboundChannel;
245+
246+
private ProxyBackendHandler(Channel inboundChannel) {
247+
this.inboundChannel = inboundChannel;
248+
}
249+
250+
@Override
251+
public void channelActive(ChannelHandlerContext ctx) {
252+
if (!inboundChannel.isActive()) {
253+
ProxyFrontendHandler.closeOnFlush(ctx.channel());
254+
} else {
255+
ctx.read();
256+
}
257+
}
258+
259+
@Override
260+
public void channelRead(final ChannelHandlerContext ctx, Object msg) {
261+
inboundChannel.writeAndFlush(msg).addListener((ChannelFutureListener) future -> {
262+
if (future.isSuccess()) {
263+
ctx.channel().read();
264+
} else {
265+
future.channel().close();
266+
}
267+
});
268+
}
269+
270+
@Override
271+
public void channelInactive(ChannelHandlerContext ctx) {
272+
ProxyFrontendHandler.closeOnFlush(inboundChannel);
273+
}
274+
275+
@Override
276+
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
277+
cause.printStackTrace();
278+
ProxyFrontendHandler.closeOnFlush(ctx.channel());
279+
}
280+
}
281+
282+
}

0 commit comments

Comments
 (0)