Skip to content

Commit bfde3ee

Browse files
authored
Merge pull request #68 from HanzhouTang/dsql
Fix stale cached PreparedStatements and duplicate batch inserts
2 parents 7904809 + f33b5b2 commit bfde3ee

File tree

9 files changed

+167
-124
lines changed

9 files changed

+167
-124
lines changed

src/main/java/com/oltpbenchmark/benchmarks/tpcc/custom/auroradsql/BatchProcessor.java

Lines changed: 26 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -51,58 +51,54 @@ public BatchProcessor(int batchSize, BiConsumer<PreparedStatement, T> statementS
5151
}
5252

5353
/**
54-
* Adds an item to the batch. Automatically flushes if batch size is reached.
54+
* Adds an item to the batch. Does NOT auto-flush. Caller is responsible for calling flush() when
55+
* ready.
5556
*
5657
* @param item The item to add
57-
* @param statement The prepared statement to use for execution
58-
* @throws SQLException if database operation fails
5958
*/
60-
public void add(T item, PreparedStatement statement) throws SQLException {
59+
public void add(T item) {
6160
batch.add(item);
62-
63-
if (batch.size() >= batchSize) {
64-
flush(statement);
65-
}
6661
}
6762

6863
/**
69-
* Flushes any remaining items in the batch.
64+
* Flushes items in the batch, processing in chunks of batchSize. This handles cases where the
65+
* batch may have accumulated more items than batchSize. Only removes items from the batch after
66+
* successful execution.
7067
*
7168
* @param statement The prepared statement to use for execution
7269
* @throws SQLException if database operation fails
7370
*/
7471
public void flush(PreparedStatement statement) throws SQLException {
75-
if (batch.isEmpty()) {
76-
return;
77-
}
78-
try {
79-
executeBatch(statement);
80-
} catch (SQLException e) {
81-
throw e;
82-
} catch (Exception e) {
83-
throw new SQLException("Failed to execute batch", e);
84-
}
85-
}
72+
while (!batch.isEmpty()) {
73+
int itemsToFlush = Math.min(batch.size(), batchSize);
74+
List<T> currentBatch = batch.subList(0, itemsToFlush);
8675

87-
/** Executes the current batch. */
88-
private void executeBatch(PreparedStatement statement) throws SQLException {
89-
for (T item : batch) {
90-
statementSetter.accept(statement, item);
91-
statement.addBatch();
92-
}
76+
// Execute this chunk
77+
for (T item : currentBatch) {
78+
statementSetter.accept(statement, item);
79+
statement.addBatch();
80+
}
9381

94-
statement.executeBatch();
95-
statement.clearBatch();
82+
statement.executeBatch();
83+
statement.clearBatch();
9684

97-
log.debug("Executed batch of {} items", batch.size());
98-
batch.clear();
85+
log.debug("Executed batch of {} items", itemsToFlush);
86+
87+
// Only remove if successful
88+
currentBatch.clear(); // Removes from original batch
89+
}
9990
}
10091

10192
/** Returns the current number of items in the batch. */
10293
public int size() {
10394
return batch.size();
10495
}
10596

97+
/** Returns true if the batch has reached the configured batch size and should be flushed. */
98+
public boolean shouldFlush() {
99+
return batch.size() >= batchSize;
100+
}
101+
106102
/** Clears the batch without executing. */
107103
public void clear() {
108104
batch.clear();

src/main/java/com/oltpbenchmark/benchmarks/tpcc/custom/auroradsql/ConnectionManager.java

Lines changed: 57 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -38,12 +38,12 @@ public class ConnectionManager {
3838

3939
private final BenchmarkModule benchmark;
4040
private final ConcurrentMap<String, Pair<Connection, Long>> connections;
41-
private final ConcurrentMap<String, PreparedStatement> statements;
41+
private final ConcurrentMap<String, ConcurrentMap<String, PreparedStatement>> statementsByThread;
4242

4343
public ConnectionManager(BenchmarkModule benchmark) {
4444
this.benchmark = benchmark;
4545
this.connections = new ConcurrentHashMap<>();
46-
this.statements = new ConcurrentHashMap<>();
46+
this.statementsByThread = new ConcurrentHashMap<>();
4747
}
4848

4949
/**
@@ -69,56 +69,75 @@ public void createConnection(String threadName) throws SQLException {
6969

7070
/** Refreshes the connection for a thread, closing the old one if it exists. */
7171
public void refreshConnection(String threadName) throws SQLException {
72+
// Clear all statements for this thread FIRST
73+
ConcurrentMap<String, PreparedStatement> threadStatements =
74+
statementsByThread.remove(threadName);
75+
76+
if (threadStatements != null) {
77+
threadStatements.forEach(
78+
(tableName, stmt) -> {
79+
try {
80+
stmt.close();
81+
} catch (SQLException e) {
82+
log.error("Failed to close statement for {}/{}", threadName, tableName, e);
83+
}
84+
});
85+
}
86+
87+
// Then refresh connection
7288
closeConnectionForThread(threadName);
7389
createConnection(threadName);
7490
}
7591

7692
/** Gets or creates a prepared statement for the given thread and table. */
7793
public PreparedStatement getPreparedStatement(String threadName, String tableName, String sql)
7894
throws SQLException {
79-
String key = generateStatementKey(threadName, tableName);
80-
PreparedStatement stmt = statements.get(key);
95+
ConcurrentMap<String, PreparedStatement> threadStatements =
96+
statementsByThread.computeIfAbsent(threadName, k -> new ConcurrentHashMap<>());
97+
98+
PreparedStatement stmt = threadStatements.get(tableName);
8199

82100
if (stmt == null || stmt.isClosed()) {
83101
Connection conn = getConnection(threadName);
84102
stmt = conn.prepareStatement(sql);
85-
statements.put(key, stmt);
103+
threadStatements.put(tableName, stmt);
86104
}
87105

88106
return stmt;
89107
}
90108

91109
/** Closes a specific prepared statement. */
92110
public void closePreparedStatement(String threadName, String tableName) {
93-
String key = generateStatementKey(threadName, tableName);
94-
PreparedStatement stmt = statements.remove(key);
95-
96-
if (stmt != null) {
97-
try {
98-
stmt.close();
99-
} catch (SQLException e) {
100-
log.error("Failed to close PreparedStatement for {}", key, e);
111+
ConcurrentMap<String, PreparedStatement> threadStatements = statementsByThread.get(threadName);
112+
113+
if (threadStatements != null) {
114+
PreparedStatement stmt = threadStatements.remove(tableName);
115+
if (stmt != null) {
116+
try {
117+
stmt.close();
118+
} catch (SQLException e) {
119+
log.error("Failed to close PreparedStatement for {}/{}", threadName, tableName, e);
120+
}
101121
}
102122
}
103123
}
104124

105125
/** Closes all resources for a specific thread. */
106126
public void closeResourcesForThread(String threadName) {
107127
// Close all statements for this thread
108-
statements
109-
.entrySet()
110-
.removeIf(
111-
entry -> {
112-
if (entry.getKey().startsWith(threadName)) {
113-
try {
114-
entry.getValue().close();
115-
} catch (SQLException e) {
116-
log.error("Failed to close statement: {}", entry.getKey(), e);
117-
}
118-
return true;
119-
}
120-
return false;
121-
});
128+
ConcurrentMap<String, PreparedStatement> threadStatements =
129+
statementsByThread.remove(threadName);
130+
131+
if (threadStatements != null) {
132+
threadStatements.forEach(
133+
(tableName, stmt) -> {
134+
try {
135+
stmt.close();
136+
} catch (SQLException e) {
137+
log.error("Failed to close statement for {}/{}", threadName, tableName, e);
138+
}
139+
});
140+
}
122141

123142
// Close connection
124143
closeConnectionForThread(threadName);
@@ -127,15 +146,18 @@ public void closeResourcesForThread(String threadName) {
127146
/** Closes all connections and statements. */
128147
public void closeAll() {
129148
// Close all statements
130-
statements.forEach(
131-
(key, stmt) -> {
132-
try {
133-
stmt.close();
134-
} catch (SQLException e) {
135-
log.error("Failed to close statement: {}", key, e);
136-
}
149+
statementsByThread.forEach(
150+
(threadName, threadStatements) -> {
151+
threadStatements.forEach(
152+
(tableName, stmt) -> {
153+
try {
154+
stmt.close();
155+
} catch (SQLException e) {
156+
log.error("Failed to close statement for {}/{}", threadName, tableName, e);
157+
}
158+
});
137159
});
138-
statements.clear();
160+
statementsByThread.clear();
139161

140162
// Close all connections
141163
connections.forEach(
@@ -167,8 +189,4 @@ private void closeConnectionForThread(String threadName) {
167189
private boolean isConnectionExpired(Pair<Connection, Long> connectionPair) {
168190
return System.currentTimeMillis() - connectionPair.second >= SESSION_DURATION;
169191
}
170-
171-
private String generateStatementKey(String threadName, String tableName) {
172-
return threadName + "_" + tableName;
173-
}
174192
}

src/main/java/com/oltpbenchmark/benchmarks/tpcc/custom/auroradsql/loaders/CustomerTableLoader.java

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -70,14 +70,18 @@ private void loadCustomers(
7070
for (int d = 1; d <= districtsPerWarehouse; d++) {
7171
for (int c = 1; c <= customersPerDistrict; c++) {
7272
Customer customer = generateCustomer(warehouseId, d, c);
73-
74-
executeWithRetry(
75-
() -> {
76-
PreparedStatement stmt = getInsertStatement(threadName);
77-
batchProcessor.add(customer, stmt);
78-
},
79-
threadName,
80-
"Insert Customer");
73+
batchProcessor.add(customer);
74+
75+
// Flush when batch is full
76+
if (batchProcessor.shouldFlush()) {
77+
executeWithRetry(
78+
() -> {
79+
PreparedStatement stmt = getInsertStatement(threadName);
80+
batchProcessor.flush(stmt);
81+
},
82+
threadName,
83+
"Flush customers batch");
84+
}
8185
}
8286
}
8387

src/main/java/com/oltpbenchmark/benchmarks/tpcc/custom/auroradsql/loaders/HistoryTableLoader.java

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -72,14 +72,18 @@ private void loadCustomerHistory(
7272
for (int d = 1; d <= districtsPerWarehouse; d++) {
7373
for (int c = 1; c <= customersPerDistrict; c++) {
7474
History history = generateHistory(warehouseId, d, c);
75-
76-
executeWithRetry(
77-
() -> {
78-
PreparedStatement stmt = getInsertStatement(threadName);
79-
batchProcessor.add(history, stmt);
80-
},
81-
threadName,
82-
"Insert History");
75+
batchProcessor.add(history);
76+
77+
// Flush when batch is full
78+
if (batchProcessor.shouldFlush()) {
79+
executeWithRetry(
80+
() -> {
81+
PreparedStatement stmt = getInsertStatement(threadName);
82+
batchProcessor.flush(stmt);
83+
},
84+
threadName,
85+
"Flush history batch");
86+
}
8387
}
8488
}
8589

src/main/java/com/oltpbenchmark/benchmarks/tpcc/custom/auroradsql/loaders/ItemTableLoader.java

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -61,14 +61,18 @@ private void loadItems(String threadName, int itemCount) throws SQLException {
6161

6262
for (int i = 1; i <= itemCount; i++) {
6363
Item item = generateItem(i);
64-
65-
executeWithRetry(
66-
() -> {
67-
PreparedStatement stmt = getInsertStatement(threadName);
68-
batchProcessor.add(item, stmt);
69-
},
70-
threadName,
71-
"Insert Item");
64+
batchProcessor.add(item);
65+
66+
// Flush when batch is full
67+
if (batchProcessor.shouldFlush()) {
68+
executeWithRetry(
69+
() -> {
70+
PreparedStatement stmt = getInsertStatement(threadName);
71+
batchProcessor.flush(stmt);
72+
},
73+
threadName,
74+
"Flush items batch");
75+
}
7276
}
7377

7478
// Flush any remaining items

src/main/java/com/oltpbenchmark/benchmarks/tpcc/custom/auroradsql/loaders/NewOrderTableLoader.java

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -85,13 +85,18 @@ private void loadNewOrders(
8585
newOrder.no_d_id = d;
8686
newOrder.no_o_id = c;
8787

88-
executeWithRetry(
89-
() -> {
90-
PreparedStatement stmt = getInsertStatement(threadName);
91-
batchProcessor.add(newOrder, stmt);
92-
},
93-
threadName,
94-
"Insert New Order");
88+
batchProcessor.add(newOrder);
89+
90+
// Flush when batch is full
91+
if (batchProcessor.shouldFlush()) {
92+
executeWithRetry(
93+
() -> {
94+
PreparedStatement stmt = getInsertStatement(threadName);
95+
batchProcessor.flush(stmt);
96+
},
97+
threadName,
98+
"Flush new orders batch");
99+
}
95100
}
96101
}
97102
}

src/main/java/com/oltpbenchmark/benchmarks/tpcc/custom/auroradsql/loaders/OrderLineTableLoader.java

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -83,14 +83,18 @@ private void loadOrderLines(
8383

8484
for (int l = 1; l <= orderLineCount; l++) {
8585
OrderLine orderLine = generateOrderLine(warehouseId, d, c, l);
86-
87-
executeWithRetry(
88-
() -> {
89-
PreparedStatement stmt = getInsertStatement(threadName);
90-
batchProcessor.add(orderLine, stmt);
91-
},
92-
threadName,
93-
"Insert Order Line");
86+
batchProcessor.add(orderLine);
87+
88+
// Flush when batch is full
89+
if (batchProcessor.shouldFlush()) {
90+
executeWithRetry(
91+
() -> {
92+
PreparedStatement stmt = getInsertStatement(threadName);
93+
batchProcessor.flush(stmt);
94+
},
95+
threadName,
96+
"Flush order lines batch");
97+
}
9498
}
9599
}
96100
}

src/main/java/com/oltpbenchmark/benchmarks/tpcc/custom/auroradsql/loaders/OrderTableLoader.java

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -78,14 +78,18 @@ private void loadOpenOrders(
7878

7979
for (int c = 1; c <= customersPerDistrict; c++) {
8080
Oorder order = generateOrder(warehouseId, d, c, c_ids[c - 1]);
81-
82-
executeWithRetry(
83-
() -> {
84-
PreparedStatement stmt = getInsertStatement(threadName);
85-
batchProcessor.add(order, stmt);
86-
},
87-
threadName,
88-
"Insert Order");
81+
batchProcessor.add(order);
82+
83+
// Flush when batch is full
84+
if (batchProcessor.shouldFlush()) {
85+
executeWithRetry(
86+
() -> {
87+
PreparedStatement stmt = getInsertStatement(threadName);
88+
batchProcessor.flush(stmt);
89+
},
90+
threadName,
91+
"Flush orders batch");
92+
}
8993
}
9094
}
9195

0 commit comments

Comments
 (0)