diff --git a/spring-jdbc/src/main/java/org/springframework/jdbc/core/namedparam/NamedParameterJdbcOperations.java b/spring-jdbc/src/main/java/org/springframework/jdbc/core/namedparam/NamedParameterJdbcOperations.java index b335ee3d259a..7e7ce78ca1e1 100644 --- a/spring-jdbc/src/main/java/org/springframework/jdbc/core/namedparam/NamedParameterJdbcOperations.java +++ b/spring-jdbc/src/main/java/org/springframework/jdbc/core/namedparam/NamedParameterJdbcOperations.java @@ -16,6 +16,7 @@ package org.springframework.jdbc.core.namedparam; +import java.util.Collection; import java.util.List; import java.util.Map; import java.util.stream.Stream; @@ -586,4 +587,16 @@ int update(String sql, SqlParameterSource paramSource, KeyHolder generatedKeyHol int[] batchUpdate(String sql, SqlParameterSource[] batchArgs, KeyHolder generatedKeyHolder, String[] keyColumnNames); + /** + * Executes the specified SQL update statement in multiple batches using the provided batch arguments. + * @param sql the SQL statement to execute + * @param batchArgs the collection of {@link SqlParameterSource} containing arguments for the query + * @param batchSize batch size + * @return a two-dimensional array containing results for each batch execution. + * (may also contain special JDBC-defined negative values for affected rows such as + * {@link java.sql.Statement#SUCCESS_NO_INFO}/{@link java.sql.Statement#EXECUTE_FAILED}) + * @throws DataAccessException if there is any problem issuing the update + */ + int[][] batchUpdate(String sql, Collection batchArgs, int batchSize); + } diff --git a/spring-jdbc/src/main/java/org/springframework/jdbc/core/namedparam/NamedParameterJdbcTemplate.java b/spring-jdbc/src/main/java/org/springframework/jdbc/core/namedparam/NamedParameterJdbcTemplate.java index 2a9e94892946..d0878c1a3ac2 100644 --- a/spring-jdbc/src/main/java/org/springframework/jdbc/core/namedparam/NamedParameterJdbcTemplate.java +++ b/spring-jdbc/src/main/java/org/springframework/jdbc/core/namedparam/NamedParameterJdbcTemplate.java @@ -18,6 +18,7 @@ import java.sql.PreparedStatement; import java.sql.SQLException; +import java.util.Collection; import java.util.List; import java.util.Map; import java.util.function.Consumer; @@ -423,6 +424,21 @@ public int getBatchSize() { }, generatedKeyHolder); } + @Override + public int[][] batchUpdate(String sql, Collection batchArgs, int batchSize) { + if (batchArgs.isEmpty() || batchSize <= 0) { + return new int[0][0]; + } + + ParsedSql parsedSql = getParsedSql(sql); + SqlParameterSource sqlParameterSource = batchArgs.iterator().next(); + PreparedStatementCreatorFactory pscf = getPreparedStatementCreatorFactory(parsedSql, sqlParameterSource); + + return getJdbcOperations().batchUpdate(pscf.getSql(), batchArgs, batchSize, (ps, argument) -> { + @Nullable Object[] values = NamedParameterUtils.buildValueArray(parsedSql, argument, null); + pscf.newPreparedStatementSetter(values).setValues(ps); + }); + } /** * Build a {@link PreparedStatementCreator} based on the given SQL and named parameters. diff --git a/spring-jdbc/src/test/java/org/springframework/jdbc/core/namedparam/NamedParameterJdbcTemplateTests.java b/spring-jdbc/src/test/java/org/springframework/jdbc/core/namedparam/NamedParameterJdbcTemplateTests.java index 37e661c8990f..9928b19d1acf 100644 --- a/spring-jdbc/src/test/java/org/springframework/jdbc/core/namedparam/NamedParameterJdbcTemplateTests.java +++ b/spring-jdbc/src/test/java/org/springframework/jdbc/core/namedparam/NamedParameterJdbcTemplateTests.java @@ -53,6 +53,7 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoInteractions; /** * @author Rick Evans @@ -580,4 +581,104 @@ void testBatchUpdateWithSqlParameterSourcePlusTypeInfo() throws Exception { verify(connection, atLeastOnce()).close(); } + + @Test + void testMultipleBatchUpdateWithSqlParameterSource() throws Exception { + List ids = List.of( + new MapSqlParameterSource("id", 100), + new MapSqlParameterSource("id", 200), + new MapSqlParameterSource("id", 300), + new MapSqlParameterSource("id", 400), + new MapSqlParameterSource("id", 500) + ); + + int[] rowsAffected1 = new int[]{1, 2}; + int[] rowsAffected2 = new int[]{3, 4}; + int[] rowsAffected3 = new int[]{5}; + + given(preparedStatement.executeBatch()).willReturn(rowsAffected1, rowsAffected2, rowsAffected3); + given(connection.getMetaData()).willReturn(databaseMetaData); + + namedParameterTemplate = new NamedParameterJdbcTemplate(new JdbcTemplate(dataSource, false)); + + int batchSize = 2; + + int[][] actualRowsAffected = namedParameterTemplate.batchUpdate( + "UPDATE NOSUCHTABLE SET DATE_DISPATCHED = SYSDATE WHERE ID = :id", ids, batchSize); + + assertThat(actualRowsAffected.length).as("executed 3 batches").isEqualTo(3); + assertThat(actualRowsAffected[0]).isEqualTo(rowsAffected1); + assertThat(actualRowsAffected[1]).isEqualTo(rowsAffected2); + assertThat(actualRowsAffected[2]).isEqualTo(rowsAffected3); + + verify(preparedStatement, times(5)).addBatch(); + verify(connection, times(1)).prepareStatement("UPDATE NOSUCHTABLE SET DATE_DISPATCHED = SYSDATE WHERE ID = ?"); + verify(preparedStatement).setObject(1, 100); + verify(preparedStatement).setObject(1, 200); + verify(preparedStatement).setObject(1, 300); + verify(preparedStatement).setObject(1, 400); + verify(preparedStatement).setObject(1, 500); + verify(preparedStatement, atLeastOnce()).close(); + verify(connection, atLeastOnce()).close(); + } + + @Test + void testMultipleBatchUpdateWithEmptySqlParameterSourceArg() { + namedParameterTemplate = new NamedParameterJdbcTemplate(new JdbcTemplate(dataSource, false)); + + int batchSize = 2; + + int[][] actualRowsAffected = namedParameterTemplate.batchUpdate( + "UPDATE NOSUCHTABLE SET DATE_DISPATCHED = SYSDATE WHERE ID = :id", Collections.emptyList(), batchSize); + + assertThat(actualRowsAffected.length).as("executed 0 batches").isZero(); + verifyNoInteractions(preparedStatement); + verifyNoInteractions(connection); + } + + @Test + void testMultipleBatchUpdateWithSqlParameterSourceWithZeroBatchSize() { + namedParameterTemplate = new NamedParameterJdbcTemplate(new JdbcTemplate(dataSource, false)); + + int batchSize = 0; + + int[][] actualRowsAffected = namedParameterTemplate.batchUpdate( + "UPDATE NOSUCHTABLE SET DATE_DISPATCHED = SYSDATE WHERE ID = :id", + List.of(new MapSqlParameterSource("id", 100)), + batchSize); + + assertThat(actualRowsAffected.length).as("executed 0 batches").isZero(); + verifyNoInteractions(preparedStatement); + verifyNoInteractions(connection); + } + + @Test + void testMultipleBatchUpdateWithSqlParameterSourceSmallerThanBatchSize() throws Exception { + List ids = List.of( + new MapSqlParameterSource("id", 100), + new MapSqlParameterSource("id", 200) + ); + + int[] rowsAffected = new int[]{1, 2}; + + given(preparedStatement.executeBatch()).willReturn(rowsAffected); + given(connection.getMetaData()).willReturn(databaseMetaData); + + namedParameterTemplate = new NamedParameterJdbcTemplate(new JdbcTemplate(dataSource, false)); + + int batchSize = 3; + + int[][] actualRowsAffected = namedParameterTemplate.batchUpdate( + "UPDATE NOSUCHTABLE SET DATE_DISPATCHED = SYSDATE WHERE ID = :id", ids, batchSize); + + assertThat(actualRowsAffected.length).as("executed 1 batch").isEqualTo(1); + assertThat(actualRowsAffected[0]).isEqualTo(rowsAffected); + + verify(preparedStatement, times(2)).addBatch(); + verify(connection, times(1)).prepareStatement("UPDATE NOSUCHTABLE SET DATE_DISPATCHED = SYSDATE WHERE ID = ?"); + verify(preparedStatement).setObject(1, 100); + verify(preparedStatement).setObject(1, 200); + verify(preparedStatement, atLeastOnce()).close(); + verify(connection, atLeastOnce()).close(); + } }