Skip to content

Commit 0a862af

Browse files
authored
Fix secrets rotation handling in Rds source (#5465)
* Fix secrets rotation handling Signed-off-by: Hai Yan <[email protected]> * Update message types Signed-off-by: Hai Yan <[email protected]> --------- Signed-off-by: Hai Yan <[email protected]>
1 parent b0ebd1b commit 0a862af

File tree

9 files changed

+83
-38
lines changed

9 files changed

+83
-38
lines changed

Diff for: data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/model/MessageType.java

+15-14
Original file line numberDiff line numberDiff line change
@@ -10,28 +10,29 @@
1010

1111
package org.opensearch.dataprepper.plugins.source.rds.model;
1212

13+
import java.util.Arrays;
1314
import java.util.Map;
15+
import java.util.stream.Collectors;
1416

1517
public enum MessageType {
1618
BEGIN('B'),
17-
RELATION('R'),
18-
INSERT('I'),
19-
UPDATE('U'),
20-
DELETE('D'),
2119
COMMIT('C'),
22-
TYPE('Y');
20+
DELETE('D'),
21+
INSERT('I'),
22+
MESSAGE('M'),
23+
ORIGIN('O'),
24+
RELATION('R'),
25+
TRUNCATE('T'),
26+
TYPE('Y'),
27+
UPDATE('U');
2328

2429
private final char value;
2530

26-
private static final Map<Character, MessageType> MESSAGE_TYPE_MAP = Map.of(
27-
BEGIN.getValue(), BEGIN,
28-
RELATION.getValue(), RELATION,
29-
INSERT.getValue(), INSERT,
30-
UPDATE.getValue(), UPDATE,
31-
DELETE.getValue(), DELETE,
32-
COMMIT.getValue(), COMMIT,
33-
TYPE.getValue(), TYPE
34-
);
31+
private static final Map<Character, MessageType> MESSAGE_TYPE_MAP = Arrays.stream(values())
32+
.collect(Collectors.toMap(
33+
messageType -> messageType.value,
34+
messageType -> messageType
35+
));
3536

3637
MessageType(char value) {
3738
this.value = value;

Diff for: data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/stream/LogicalReplicationClient.java

+17-7
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ public class LogicalReplicationClient implements ReplicationLogClient {
3535
private LogSequenceNumber startLsn;
3636
private LogicalReplicationEventProcessor eventProcessor;
3737

38+
private PGReplicationStream stream = null;
3839
private volatile boolean disconnectRequested = false;
3940

4041
public LogicalReplicationClient(final ConnectionManager connectionManager,
@@ -48,7 +49,6 @@ public LogicalReplicationClient(final ConnectionManager connectionManager,
4849
@Override
4950
public void connect() {
5051
LOG.debug("Start connecting logical replication stream. ");
51-
PGReplicationStream stream;
5252
try (Connection conn = connectionManager.getConnection()) {
5353
PGConnection pgConnection = conn.unwrap(PGConnection.class);
5454

@@ -85,20 +85,17 @@ public void connect() {
8585
stream.setAppliedLSN(lsn);
8686
} catch (Exception e) {
8787
LOG.error("Exception while processing Postgres replication stream. ", e);
88-
stream.close();
89-
LOG.debug("Replication stream closed.");
88+
closeStream();
9089
throw e;
9190
}
9291
}
9392
}
9493

95-
stream.close();
96-
LOG.debug("Replication stream closed.");
94+
closeStream();
9795

9896
disconnectRequested = false;
99-
LOG.debug("Replication stream closed successfully.");
10097
} catch (Exception e) {
101-
LOG.error("Exception while creating Postgres replication stream. ", e);
98+
LOG.error("Exception while creating or processing Postgres replication stream. ", e);
10299
throw new RuntimeException(e);
103100
}
104101
}
@@ -108,6 +105,8 @@ public void disconnect() {
108105
disconnectRequested = true;
109106
LOG.debug("Requested to disconnect logical replication stream.");
110107

108+
closeStream();
109+
111110
if (eventProcessor != null) {
112111
eventProcessor.stopCheckpointManager();
113112
LOG.debug("Stopped checkpoint manager.");
@@ -121,4 +120,15 @@ public void setEventProcessor(LogicalReplicationEventProcessor eventProcessor) {
121120
public void setStartLsn(LogSequenceNumber startLsn) {
122121
this.startLsn = startLsn;
123122
}
123+
124+
private void closeStream() {
125+
if (stream != null && !stream.isClosed()) {
126+
try {
127+
stream.close();
128+
LOG.debug("Replication stream closed.");
129+
} catch (Exception e) {
130+
LOG.error("Exception while closing replication stream. ", e);
131+
}
132+
}
133+
}
124134
}

Diff for: data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/stream/LogicalReplicationEventProcessor.java

+11-2
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,16 @@ public void process(ByteBuffer msg) {
163163
// If it's a RELATION, update table metadata map
164164
// If it's INSERT/UPDATE/DELETE, prepare events
165165
// If it's a COMMIT, convert all prepared events and send to buffer
166-
MessageType messageType = MessageType.from((char) msg.get());
166+
MessageType messageType;
167+
char typeChar = '\0';
168+
try {
169+
typeChar = (char) msg.get();
170+
messageType = MessageType.from(typeChar);
171+
} catch (IllegalArgumentException e) {
172+
LOG.warn("Unknown message type {} received from stream. Skipping.", typeChar);
173+
return;
174+
}
175+
167176
switch (messageType) {
168177
case BEGIN:
169178
handleMessageWithRetries(msg, this::processBeginMessage, messageType);
@@ -187,7 +196,7 @@ public void process(ByteBuffer msg) {
187196
handleMessageWithRetries(msg, this::processTypeMessage, messageType);
188197
break;
189198
default:
190-
throw new IllegalArgumentException("Replication message type [" + messageType + "] is not supported. ");
199+
LOG.debug("Replication message type '{}' is not supported. Skipping.", messageType);
191200
}
192201
}
193202

Diff for: data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/stream/ReplicationLogClientFactory.java

+5-4
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,9 @@
2626

2727
public class ReplicationLogClientFactory {
2828

29-
private final RdsSourceConfig sourceConfig;
3029
private final RdsClient rdsClient;
3130
private final DbMetadata dbMetadata;
31+
private RdsSourceConfig sourceConfig;
3232
private String username;
3333
private String password;
3434
private SSLMode sslMode = SSLMode.REQUIRED;
@@ -82,9 +82,10 @@ public void setSSLMode(SSLMode sslMode) {
8282
this.sslMode = sslMode;
8383
}
8484

85-
public void setCredentials(String username, String password) {
86-
this.username = username;
87-
this.password = password;
85+
public void updateCredentials(RdsSourceConfig sourceConfig) {
86+
this.sourceConfig = sourceConfig;
87+
this.username = sourceConfig.getAuthenticationConfig().getUsername();
88+
this.password = sourceConfig.getAuthenticationConfig().getPassword();
8889
}
8990

9091
private String getDatabaseName(List<String> tableNames) {

Diff for: data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/stream/StreamWorker.java

+9-5
Original file line numberDiff line numberDiff line change
@@ -66,11 +66,15 @@ public void processStream(final StreamPartition streamPartition) {
6666
sourceCoordinator.giveUpPartition(streamPartition);
6767
throw new RuntimeException(e);
6868
} finally {
69-
try {
70-
replicationLogClient.disconnect();
71-
} catch (Exception e) {
72-
LOG.error("Binary log client failed to disconnect.", e);
73-
}
69+
shutdown();
70+
}
71+
}
72+
73+
public void shutdown() {
74+
try {
75+
replicationLogClient.disconnect();
76+
} catch (Exception e) {
77+
LOG.error("Replication log client failed to disconnect.", e);
7478
}
7579
}
7680

Diff for: data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/stream/StreamWorkerTaskRefresher.java

+4-3
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ public class StreamWorkerTaskRefresher implements PluginConfigObserver<RdsSource
4848

4949
private ExecutorService executorService;
5050
private RdsSourceConfig currentSourceConfig;
51+
private StreamWorker streamWorker;
5152

5253
public StreamWorkerTaskRefresher(final EnhancedSourceCoordinator sourceCoordinator,
5354
final StreamPartition streamPartition,
@@ -96,10 +97,10 @@ public void update(RdsSourceConfig sourceConfig) {
9697
LOG.info("Database credentials were updated. Refreshing stream worker...");
9798
credentialsChangeCounter.increment();
9899
try {
100+
streamWorker.shutdown();
99101
executorService.shutdownNow();
100102
executorService = executorServiceSupplier.get();
101-
replicationLogClientFactory.setCredentials(
102-
sourceConfig.getAuthenticationConfig().getUsername(), sourceConfig.getAuthenticationConfig().getPassword());
103+
replicationLogClientFactory.updateCredentials(sourceConfig);
103104

104105
refreshTask(sourceConfig);
105106

@@ -132,7 +133,7 @@ private void refreshTask(RdsSourceConfig sourceConfig) {
132133
streamPartition, sourceConfig, buffer, s3Prefix, pluginMetrics, logicalReplicationClient,
133134
streamCheckpointer, acknowledgementSetManager));
134135
}
135-
final StreamWorker streamWorker = StreamWorker.create(sourceCoordinator, replicationLogClient, pluginMetrics);
136+
streamWorker = StreamWorker.create(sourceCoordinator, replicationLogClient, pluginMetrics);
136137
executorService.submit(() -> streamWorker.processStream(streamPartition));
137138
}
138139

Diff for: data-prepper-plugins/rds-source/src/test/java/org/opensearch/dataprepper/plugins/source/rds/stream/LogicalReplicationClientTest.java

+3
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,7 @@ void test_disconnect() throws SQLException, InterruptedException {
119119
when(logicalStreamBuilder.start()).thenReturn(stream);
120120
when(stream.readPending()).thenReturn(message).thenReturn(null);
121121
when(stream.getLastReceiveLSN()).thenReturn(lsn);
122+
when(stream.isClosed()).thenReturn(false, true);
122123

123124
final ExecutorService executorService = Executors.newSingleThreadExecutor();
124125
executorService.submit(() -> logicalReplicationClient.connect());
@@ -155,6 +156,7 @@ void test_connect_disconnect_cycles() throws SQLException, InterruptedException
155156
when(logicalStreamBuilder.start()).thenReturn(stream);
156157
when(stream.readPending()).thenReturn(message).thenReturn(null);
157158
when(stream.getLastReceiveLSN()).thenReturn(lsn);
159+
when(stream.isClosed()).thenReturn(false, true);
158160

159161
// First connect
160162
final ExecutorService executorService = Executors.newSingleThreadExecutor();
@@ -174,6 +176,7 @@ void test_connect_disconnect_cycles() throws SQLException, InterruptedException
174176

175177
// Second connect
176178
when(stream.readPending()).thenReturn(message).thenReturn(null);
179+
when(stream.isClosed()).thenReturn(false, true);
177180
executorService.submit(() -> logicalReplicationClient.connect());
178181
await().atMost(Duration.ofSeconds(1))
179182
.untilAsserted(() -> verify(eventProcessor, times(2)).process(message));

Diff for: data-prepper-plugins/rds-source/src/test/java/org/opensearch/dataprepper/plugins/source/rds/stream/LogicalReplicationEventProcessorTest.java

+13-3
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
import org.junit.jupiter.api.BeforeEach;
1515
import org.junit.jupiter.api.Test;
1616
import org.junit.jupiter.api.extension.ExtendWith;
17+
import org.junit.jupiter.params.ParameterizedTest;
18+
import org.junit.jupiter.params.provider.ValueSource;
1719
import org.mockito.Answers;
1820
import org.mockito.Mock;
1921
import org.mockito.junit.jupiter.MockitoExtension;
@@ -30,7 +32,7 @@
3032
import java.util.Random;
3133
import java.util.UUID;
3234

33-
import static org.junit.jupiter.api.Assertions.assertThrows;
35+
import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
3436
import static org.mockito.ArgumentMatchers.anyString;
3537
import static org.mockito.Mockito.doNothing;
3638
import static org.mockito.Mockito.spy;
@@ -149,13 +151,21 @@ void test_correct_process_method_invoked_for_type_message() {
149151
verify(objectUnderTest).processTypeMessage(message);
150152
}
151153

154+
@ParameterizedTest
155+
@ValueSource(chars = {'M', 'O', 'T'})
156+
void test_unsupported_message_type_not_throw_exception(char typeChar) {
157+
setMessageType(MessageType.from(typeChar));
158+
159+
assertDoesNotThrow(() -> objectUnderTest.process(message));
160+
}
161+
152162
@Test
153-
void test_unsupported_message_type_throws_exception() {
163+
void test_unknown_message_type_not_throw_exception() {
154164
message = ByteBuffer.allocate(1);
155165
message.put((byte) 'A');
156166
message.flip();
157167

158-
assertThrows(IllegalArgumentException.class, () -> objectUnderTest.process(message));
168+
assertDoesNotThrow(() -> objectUnderTest.process(message));
159169
}
160170

161171
@Test

Diff for: data-prepper-plugins/rds-source/src/test/java/org/opensearch/dataprepper/plugins/source/rds/stream/StreamWorkerTest.java

+6
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,12 @@ void test_processStream_without_current_binlog_coordinates() throws IOException
8989
verify(binlogClientWrapper).connect();
9090
}
9191

92+
@Test
93+
void test_shutdown() throws IOException {
94+
streamWorker.shutdown();
95+
verify(binlogClientWrapper).disconnect();
96+
}
97+
9298
private StreamWorker createObjectUnderTest() {
9399
return new StreamWorker(sourceCoordinator, binlogClientWrapper, pluginMetrics);
94100
}

0 commit comments

Comments
 (0)