Skip to content

Commit e9075de

Browse files
committed
fix: remove client certificate rotate callback (#64)
It is not necessary to disconnect the MQTT Bridge when the client certificate rotates. This change removes that logic, but still ensures that the new client certificate is used the next time the bridge connects. CA rotation still results in a disconnect, as that indicates that the client should no longer trust the server.
1 parent 12577f9 commit e9075de

4 files changed

Lines changed: 50 additions & 27 deletions

File tree

src/main/java/com/aws/greengrass/mqttbridge/auth/MQTTClientKeyStore.java

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ public class MQTTClientKeyStore {
5656

5757
@FunctionalInterface
5858
public interface UpdateListener {
59-
void onUpdate();
59+
void onCAUpdate();
6060
}
6161

6262
/**
@@ -108,9 +108,8 @@ private KeyPair newRSAKeyPair() throws NoSuchAlgorithmException {
108108

109109
private void updateCert(X509Certificate... certChain) {
110110
try {
111+
LOGGER.atDebug().log("Storing new client certificate to be used on next connect attempt");
111112
keyStore.setKeyEntry(KEY_ALIAS, keyPair.getPrivate(), DEFAULT_KEYSTORE_PASSWORD, certChain);
112-
113-
updateListeners.forEach(UpdateListener::onUpdate); //notify MQTTClient
114113
} catch (KeyStoreException e) {
115114
LOGGER.atError("Unable to store generated cert", e);
116115
}
@@ -139,7 +138,7 @@ public void updateCA(List<String> caCerts) throws IOException, CertificateExcept
139138
keyStore.setCertificateEntry("CA" + i, caCert);
140139
}
141140

142-
updateListeners.forEach(UpdateListener::onUpdate); //notify MQTTClient
141+
updateListeners.forEach(UpdateListener::onCAUpdate); //notify MQTTClient
143142
}
144143

145144
private X509Certificate pemToX509Certificate(String certPem) throws IOException, CertificateException {
@@ -156,7 +155,7 @@ private X509Certificate pemToX509Certificate(String certPem) throws IOException,
156155
* Add listener to listen to KeyStore updates.
157156
* @param listener listener method
158157
*/
159-
public synchronized void listenToUpdates(UpdateListener listener) {
158+
public synchronized void listenToCAUpdates(UpdateListener listener) {
160159
updateListeners.add(listener);
161160
}
162161

@@ -177,7 +176,7 @@ public SSLSocketFactory getSSLSocketFactory() throws KeyStoreException {
177176
SSLContext sc = SSLContext.getInstance("TLS");
178177
sc.init(kmf.getKeyManagers(), tmf.getTrustManagers(), null);
179178
return sc.getSocketFactory();
180-
} catch (NoSuchAlgorithmException | KeyStoreException | UnrecoverableKeyException | KeyManagementException e) {
179+
} catch (NoSuchAlgorithmException | UnrecoverableKeyException | KeyManagementException e) {
181180
throw new KeyStoreException("Unable to create SocketFactory from KeyStore", e);
182181
}
183182
}

src/main/java/com/aws/greengrass/mqttbridge/clients/MQTTClient.java

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,6 @@ public class MQTTClient implements MessageClient {
3838
private static final int MIN_WAIT_RETRY_IN_SECONDS = 1;
3939
private static final int MAX_WAIT_RETRY_IN_SECONDS = 120;
4040

41-
private final MqttConnectOptions connOpts = new MqttConnectOptions();
4241
private Consumer<Message> messageHandler;
4342
private final URI brokerUri;
4443
private final String clientId;
@@ -104,7 +103,7 @@ protected MQTTClient(@NonNull URI brokerUri, @NonNull String clientId, MQTTClien
104103
this.mqttClientInternal = mqttClient;
105104
this.dataStore = new MemoryPersistence();
106105
this.mqttClientKeyStore = mqttClientKeyStore;
107-
this.mqttClientKeyStore.listenToUpdates(this::reset);
106+
this.mqttClientKeyStore.listenToCAUpdates(this::reset);
108107
this.executorService = executorService;
109108
}
110109

@@ -231,19 +230,23 @@ private synchronized void updateSubscriptionsInternal() {
231230
});
232231
}
233232

234-
private synchronized void connectAndSubscribe() throws KeyStoreException {
235-
if (connectFuture != null) {
236-
connectFuture.cancel(true);
237-
}
238-
239-
//TODO: persistent session could be used
233+
private MqttConnectOptions getConnectionOptions() throws KeyStoreException {
234+
MqttConnectOptions connOpts = new MqttConnectOptions();
240235
connOpts.setCleanSession(true);
241236

242237
if ("ssl".equalsIgnoreCase(brokerUri.getScheme())) {
243238
SSLSocketFactory ssf = mqttClientKeyStore.getSSLSocketFactory();
244239
connOpts.setSocketFactory(ssf);
245240
}
246241

242+
return connOpts;
243+
}
244+
245+
private synchronized void connectAndSubscribe() throws KeyStoreException {
246+
if (connectFuture != null) {
247+
connectFuture.cancel(true);
248+
}
249+
247250
LOGGER.atInfo()
248251
.kv(BridgeConfig.KEY_BROKER_URI, brokerUri)
249252
.kv(BridgeConfig.KEY_CLIENT_ID, clientId)
@@ -252,9 +255,9 @@ private synchronized void connectAndSubscribe() throws KeyStoreException {
252255
connectFuture = executorService.submit(this::reconnectAndResubscribe);
253256
}
254257

255-
private synchronized void doConnect() throws MqttException {
258+
private synchronized void doConnect() throws MqttException, KeyStoreException {
256259
if (!mqttClientInternal.isConnected()) {
257-
mqttClientInternal.connect(connOpts);
260+
mqttClientInternal.connect(getConnectionOptions());
258261
LOGGER.atInfo()
259262
.kv(BridgeConfig.KEY_BROKER_URI, brokerUri)
260263
.kv(BridgeConfig.KEY_CLIENT_ID, clientId)
@@ -269,7 +272,7 @@ private void reconnectAndResubscribe() {
269272
try {
270273
// TODO: Clean up this loop
271274
doConnect();
272-
} catch (MqttException e) {
275+
} catch (MqttException | KeyStoreException e) {
273276
LOGGER.atDebug().setCause(e)
274277
.log("Unable to connect. Will be retried after {} seconds", waitBeforeRetry);
275278
try {

src/test/java/com/aws/greengrass/mqttbridge/auth/MQTTClientKeyStoreTest.java

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@
4343

4444
@ExtendWith({MockitoExtension.class, GGExtension.class})
4545
public class MQTTClientKeyStoreTest {
46-
public static final String CERTIFICATE = "-----BEGIN CERTIFICATE-----\r\n"
46+
private static final String CERTIFICATE = "-----BEGIN CERTIFICATE-----\r\n"
4747
+ "MIICujCCAaICCQCQcEEQmGoJqjANBgkqhkiG9w0BAQUFADAfMR0wGwYDVQQDDBRt\r\n"
4848
+ "b3F1ZXR0ZS5lY2xpcHNlLm9yZzAeFw0yMDA3MjExODA2MzdaFw0yMTA3MTYxODA2\r\n"
4949
+ "MzdaMB8xHTAbBgNVBAMMFG1vcXVldHRlLmVjbGlwc2Uub3JnMIIBIjANBgkqhkiG\r\n"
@@ -67,11 +67,9 @@ public class MQTTClientKeyStoreTest {
6767
private CertificateManager mockCertificateManager;
6868

6969
@Test
70-
void GIVEN_MQTTClientKeyStore_WHEN_initialized_THEN_key_and_cert_generated() throws Exception {
70+
void GIVEN_MQTTClientKeyStore_WHEN_initialized_THEN_keyAndCertGenerated() throws Exception {
7171
MQTTClientKeyStore mqttClientKeyStore = new MQTTClientKeyStore(mockCertificateManager);
7272
mqttClientKeyStore.init();
73-
CountDownLatch updateLatch = new CountDownLatch(1);
74-
mqttClientKeyStore.listenToUpdates(updateLatch::countDown);
7573

7674
ArgumentCaptor<Consumer<X509Certificate[]>> cbArgumentCaptor = ArgumentCaptor.forClass(Consumer.class);
7775
verify(mockCertificateManager, times(1))
@@ -84,7 +82,6 @@ void GIVEN_MQTTClientKeyStore_WHEN_initialized_THEN_key_and_cert_generated() thr
8482
X509Certificate certificate = pemToX509Certificate(CERTIFICATE);
8583
X509Certificate[] chain = {certificate, certificate};
8684
certCallback.accept(chain);
87-
assertThat(updateLatch.await(100, TimeUnit.MILLISECONDS), is(true));
8885
assertThat(keyStore.size(), is(1));
8986

9087
PrivateKey privateKey = (PrivateKey) keyStore.getKey(KEY_ALIAS, DEFAULT_KEYSTORE_PASSWORD);
@@ -106,7 +103,7 @@ void GIVEN_MQTTClientKeyStore_WHEN_called_updateCA_THEN_CA_stored() throws Excep
106103
MQTTClientKeyStore mqttClientKeyStore = new MQTTClientKeyStore(mockCertificateManager);
107104
mqttClientKeyStore.init();
108105
CountDownLatch updateLatch = new CountDownLatch(1);
109-
mqttClientKeyStore.listenToUpdates(updateLatch::countDown);
106+
mqttClientKeyStore.listenToCAUpdates(updateLatch::countDown);
110107

111108
KeyStore keyStore = mqttClientKeyStore.getKeyStore();
112109
assertThat(keyStore.size(), is(0));
@@ -122,11 +119,11 @@ void GIVEN_MQTTClientKeyStore_WHEN_called_updateCA_THEN_CA_stored() throws Excep
122119
}
123120

124121
@Test
125-
void GIVEN_MQTTClientKeyStore_WHEN_called_getSSLSocketFactory_THEN_returns_SSLSocketFactory() throws Exception {
122+
void GIVEN_MQTTClientKeyStore_WHEN_getSSLSocketFactory_THEN_returns_SSLSocketFactory() throws Exception {
126123
MQTTClientKeyStore mqttClientKeyStore = new MQTTClientKeyStore(mockCertificateManager);
127124
mqttClientKeyStore.init();
128-
CountDownLatch updateLatch = new CountDownLatch(2);
129-
mqttClientKeyStore.listenToUpdates(updateLatch::countDown);
125+
CountDownLatch updateLatch = new CountDownLatch(1);
126+
mqttClientKeyStore.listenToCAUpdates(updateLatch::countDown);
130127

131128
ArgumentCaptor<Consumer<X509Certificate[]>> cbArgumentCaptor = ArgumentCaptor.forClass(Consumer.class);
132129
verify(mockCertificateManager, times(1))

src/test/java/com/aws/greengrass/mqttbridge/clients/MQTTClientTest.java

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import com.aws.greengrass.mqttbridge.Message;
99
import com.aws.greengrass.mqttbridge.auth.MQTTClientKeyStore;
1010
import com.aws.greengrass.testcommons.testutilities.GGExtension;
11+
import org.eclipse.paho.client.mqttv3.MqttConnectOptions;
1112
import org.eclipse.paho.client.mqttv3.MqttMessage;
1213
import org.junit.jupiter.api.AfterEach;
1314
import org.junit.jupiter.api.BeforeEach;
@@ -206,7 +207,7 @@ void GIVEN_mqttClient_WHEN_connectionLost_THEN_clientReconnectsAndResubscribes()
206207
}
207208

208209
@Test
209-
void GIVEN_mqttClient_WHEN_reset_THEN_connectsWithUpdatedSslContext() throws Exception {
210+
void GIVEN_mqttClient_WHEN_caRotates_THEN_connectsWithUpdatedSslContext() throws Exception {
210211
MQTTClientKeyStore mockKeyStore = mock(MQTTClientKeyStore.class);
211212
MQTTClient mqttClient = new MQTTClient(ENCRYPTED_URI, CLIENT_ID, mockKeyStore, ses, fakeMqttClient);
212213
mqttClient.start();
@@ -223,4 +224,27 @@ void GIVEN_mqttClient_WHEN_reset_THEN_connectsWithUpdatedSslContext() throws Exc
223224
assertThat(fakeMqttClient.getConnectOptions().getSocketFactory(), is(mockSocketFactory));
224225
assertThat(fakeMqttClient.getConnectCount(), is(2));
225226
}
227+
228+
@Test
229+
void GIVEN_mqttClient_WHEN_clientCertRotates_THEN_newCertIsUsedUponSubsequentReconnects() throws Exception {
230+
SSLSocketFactory mockSocketFactory1 = mock(SSLSocketFactory.class);
231+
SSLSocketFactory mockSocketFactory2 = mock(SSLSocketFactory.class);
232+
when(mockMqttClientKeyStore.getSSLSocketFactory()).thenReturn(mockSocketFactory1);
233+
234+
MQTTClient mqttClient = new MQTTClient(ENCRYPTED_URI, CLIENT_ID, mockMqttClientKeyStore, ses, fakeMqttClient);
235+
mqttClient.start();
236+
fakeMqttClient.waitForConnect(1000);
237+
238+
assertThat(fakeMqttClient.isConnected(), is(true));
239+
MqttConnectOptions connectOptions = fakeMqttClient.getConnectOptions();
240+
assertThat(connectOptions.getSocketFactory(), is(mockSocketFactory1));
241+
242+
// Update socket factory and inject a connection loss
243+
when(mockMqttClientKeyStore.getSSLSocketFactory()).thenReturn(mockSocketFactory2);
244+
fakeMqttClient.injectConnectionLoss();
245+
246+
assertThat(fakeMqttClient.isConnected(), is(true));
247+
connectOptions = fakeMqttClient.getConnectOptions();
248+
assertThat(connectOptions.getSocketFactory(), is(mockSocketFactory2));
249+
}
226250
}

0 commit comments

Comments
 (0)