Skip to content

Commit 0cc2548

Browse files
committed
Add support for DataSourceScriptDatabaseInitializer interception
Intercept Spring Boot's DataSourceScriptDatabaseInitializer via BeanPostProcessor to register SQL init scripts as DatabasePreparer, enabling optimized template-based database initialization and parallel prefetching using ThreadLocal DataSource approach.
1 parent d684bfb commit 0cc2548

12 files changed

Lines changed: 676 additions & 0 deletions

File tree

embedded-database-spring-test/src/main/java/io/zonky/test/db/config/EmbeddedDatabaseAutoConfiguration.java

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
import io.zonky.test.db.flyway.FlywayDatabaseExtension;
2020
import io.zonky.test.db.flyway.FlywayPropertiesPostProcessor;
21+
import io.zonky.test.db.init.DataSourceScriptDatabaseExtension;
2122
import io.zonky.test.db.init.EmbeddedDatabaseInitializer;
2223
import io.zonky.test.db.init.ScriptDatabasePreparer;
2324
import io.zonky.test.db.liquibase.LiquibaseDatabaseExtension;
@@ -289,6 +290,14 @@ public BeanPostProcessor liquibasePropertiesPostProcessor() {
289290
return new LiquibasePropertiesPostProcessor();
290291
}
291292

293+
@Bean
294+
@Role(BeanDefinition.ROLE_INFRASTRUCTURE)
295+
@ConditionalOnClass(name = "org.springframework.boot.jdbc.init.DataSourceScriptDatabaseInitializer")
296+
@ConditionalOnMissingBean(name = "dataSourceScriptDatabaseExtension")
297+
public DataSourceScriptDatabaseExtension dataSourceScriptDatabaseExtension() {
298+
return new DataSourceScriptDatabaseExtension();
299+
}
300+
292301
@Bean
293302
@Role(BeanDefinition.ROLE_INFRASTRUCTURE)
294303
@ConditionalOnMissingBean(name = "embeddedDatabaseInitializer")
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
/*
2+
* Copyright 2025 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package io.zonky.test.db.init;
18+
19+
import io.zonky.test.db.context.DatabaseContext;
20+
import io.zonky.test.db.util.AopProxyUtils;
21+
import io.zonky.test.db.util.ReflectionUtils;
22+
import org.springframework.beans.factory.config.BeanPostProcessor;
23+
import org.springframework.boot.jdbc.init.DataSourceScriptDatabaseInitializer;
24+
import org.springframework.core.Ordered;
25+
26+
import javax.sql.DataSource;
27+
28+
public class DataSourceScriptDatabaseExtension implements BeanPostProcessor, Ordered {
29+
30+
@Override
31+
public int getOrder() {
32+
return Ordered.HIGHEST_PRECEDENCE + 1;
33+
}
34+
35+
@Override
36+
public Object postProcessBeforeInitialization(Object bean, String beanName) {
37+
if (bean instanceof DataSourceScriptDatabaseInitializer) {
38+
DataSourceScriptDatabaseInitializer initializer = (DataSourceScriptDatabaseInitializer) bean;
39+
DataSource dataSource = ReflectionUtils.getField(initializer, "dataSource");
40+
DatabaseContext context = AopProxyUtils.getDatabaseContext(dataSource);
41+
42+
if (context != null) {
43+
context.apply(new DataSourceScriptDatabasePreparer(initializer));
44+
return new SuppressedDataSourceInitializer(initializer);
45+
}
46+
}
47+
48+
return bean;
49+
}
50+
51+
@Override
52+
public Object postProcessAfterInitialization(Object bean, String beanName) {
53+
return bean;
54+
}
55+
56+
public static class SuppressedDataSourceInitializer {
57+
58+
private final DataSourceScriptDatabaseInitializer originalInitializer;
59+
60+
public SuppressedDataSourceInitializer(DataSourceScriptDatabaseInitializer originalInitializer) {
61+
this.originalInitializer = originalInitializer;
62+
}
63+
64+
public DataSourceScriptDatabaseInitializer getOriginalInitializer() {
65+
return originalInitializer;
66+
}
67+
68+
@Override
69+
public String toString() {
70+
return "SuppressedDataSourceInitializer{originalType=" + originalInitializer.getClass().getName() + "}";
71+
}
72+
}
73+
}
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
/*
2+
* Copyright 2025 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package io.zonky.test.db.init;
18+
19+
import com.cedarsoftware.util.DeepEquals;
20+
import io.zonky.test.db.preparer.DatabasePreparer;
21+
import io.zonky.test.db.util.ReflectionUtils;
22+
import org.springframework.boot.jdbc.init.DataSourceScriptDatabaseInitializer;
23+
import org.springframework.boot.sql.init.DatabaseInitializationSettings;
24+
25+
import javax.sql.DataSource;
26+
27+
public class DataSourceScriptDatabasePreparer implements DatabasePreparer {
28+
29+
private final DataSourceScriptDatabaseInitializer initializer;
30+
private final ThreadLocalDataSource threadLocalDataSource;
31+
private final DatabaseInitializationSettings settings;
32+
33+
public DataSourceScriptDatabasePreparer(DataSourceScriptDatabaseInitializer initializer) {
34+
this.initializer = initializer;
35+
this.settings = ReflectionUtils.getField(initializer, "settings");
36+
this.threadLocalDataSource = new ThreadLocalDataSource();
37+
ReflectionUtils.setField(initializer, "dataSource", threadLocalDataSource);
38+
}
39+
40+
@Override
41+
public long estimatedDuration() {
42+
return 10;
43+
}
44+
45+
@Override
46+
public void prepare(DataSource dataSource) {
47+
threadLocalDataSource.set(dataSource);
48+
try {
49+
initializer.initializeDatabase();
50+
} finally {
51+
threadLocalDataSource.clear();
52+
}
53+
}
54+
55+
@Override
56+
public boolean equals(Object o) {
57+
if (this == o) return true;
58+
if (o == null || getClass() != o.getClass()) return false;
59+
DataSourceScriptDatabasePreparer that = (DataSourceScriptDatabasePreparer) o;
60+
return initializer.getClass() == that.initializer.getClass()
61+
&& DeepEquals.deepEquals(settings, that.settings);
62+
}
63+
64+
@Override
65+
public int hashCode() {
66+
int result = initializer.getClass().hashCode();
67+
result = 31 * result + DeepEquals.deepHashCode(settings);
68+
return result;
69+
}
70+
}
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
/*
2+
* Copyright 2025 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package io.zonky.test.db.init;
18+
19+
import javax.sql.DataSource;
20+
import java.io.PrintWriter;
21+
import java.sql.Connection;
22+
import java.sql.SQLException;
23+
import java.sql.SQLFeatureNotSupportedException;
24+
import java.util.logging.Logger;
25+
26+
class ThreadLocalDataSource implements DataSource {
27+
28+
private final ThreadLocal<DataSource> current = new ThreadLocal<>();
29+
30+
void set(DataSource dataSource) {
31+
current.set(dataSource);
32+
}
33+
34+
void clear() {
35+
current.remove();
36+
}
37+
38+
@Override
39+
public Connection getConnection() throws SQLException {
40+
return current.get().getConnection();
41+
}
42+
43+
@Override
44+
public Connection getConnection(String username, String password) throws SQLException {
45+
return current.get().getConnection(username, password);
46+
}
47+
48+
@Override
49+
public PrintWriter getLogWriter() throws SQLException {
50+
return current.get().getLogWriter();
51+
}
52+
53+
@Override
54+
public void setLogWriter(PrintWriter out) throws SQLException {
55+
current.get().setLogWriter(out);
56+
}
57+
58+
@Override
59+
public void setLoginTimeout(int seconds) throws SQLException {
60+
current.get().setLoginTimeout(seconds);
61+
}
62+
63+
@Override
64+
public int getLoginTimeout() throws SQLException {
65+
return current.get().getLoginTimeout();
66+
}
67+
68+
@Override
69+
public Logger getParentLogger() throws SQLFeatureNotSupportedException {
70+
return current.get().getParentLogger();
71+
}
72+
73+
@Override
74+
public <T> T unwrap(Class<T> iface) throws SQLException {
75+
return current.get().unwrap(iface);
76+
}
77+
78+
@Override
79+
public boolean isWrapperFor(Class<?> iface) throws SQLException {
80+
return current.get().isWrapperFor(iface);
81+
}
82+
}
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
/*
2+
* Copyright 2025 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package io.zonky.test.db;
18+
19+
import io.zonky.test.category.FlywayTestSuite;
20+
import io.zonky.test.support.ConditionalTestRule;
21+
import io.zonky.test.support.TestAssumptions;
22+
import org.junit.ClassRule;
23+
import org.junit.Test;
24+
import org.junit.experimental.categories.Category;
25+
import org.junit.runner.RunWith;
26+
import org.springframework.beans.factory.annotation.Autowired;
27+
import org.springframework.boot.test.autoconfigure.jdbc.JdbcTest;
28+
import org.springframework.context.annotation.Bean;
29+
import org.springframework.context.annotation.Configuration;
30+
import org.springframework.jdbc.core.JdbcTemplate;
31+
import org.springframework.test.context.TestPropertySource;
32+
import org.springframework.test.context.junit4.SpringRunner;
33+
34+
import javax.sql.DataSource;
35+
import java.util.List;
36+
import java.util.Map;
37+
38+
import static io.zonky.test.db.AutoConfigureEmbeddedDatabase.DatabaseType.POSTGRES;
39+
import static org.assertj.core.api.Assertions.assertThat;
40+
import static org.assertj.core.api.Assertions.entry;
41+
42+
@RunWith(SpringRunner.class)
43+
@Category(FlywayTestSuite.class)
44+
@AutoConfigureEmbeddedDatabase(type = POSTGRES)
45+
@JdbcTest
46+
@TestPropertySource(properties = {
47+
"spring.sql.init.mode=always",
48+
"spring.sql.init.schema-locations=" +
49+
"classpath:/db/schema/init-schema.sql," +
50+
"classpath:/db/migration/V0001_1__create_person_table.sql," +
51+
"classpath:/db/migration/V0002_1__rename_surname_column.sql",
52+
53+
"flyway.enabled=false",
54+
"spring.flyway.enabled=false",
55+
"liquibase.enabled=false",
56+
"spring.liquibase.enabled=false"
57+
})
58+
public class SpringBootSqlInitIntegrationTest {
59+
60+
@ClassRule
61+
public static ConditionalTestRule conditionalTestRule = new ConditionalTestRule(() -> {
62+
TestAssumptions.assumeSpringBootSupportsJdbcTestAnnotation();
63+
TestAssumptions.assumeSpringBootSupportsSqlInit();
64+
});
65+
66+
private static final String SQL_SELECT_PERSONS = "select * from test.person";
67+
68+
@Configuration
69+
static class Config {
70+
71+
@Bean
72+
public JdbcTemplate jdbcTemplate(DataSource dataSource) {
73+
return new JdbcTemplate(dataSource);
74+
}
75+
}
76+
77+
@Autowired
78+
private JdbcTemplate jdbcTemplate;
79+
80+
@Test
81+
public void testSqlInitScripts() {
82+
assertThat(jdbcTemplate).isNotNull();
83+
84+
List<Map<String, Object>> persons = jdbcTemplate.queryForList(SQL_SELECT_PERSONS);
85+
assertThat(persons).isNotNull().hasSize(1);
86+
87+
Map<String, Object> person = persons.get(0);
88+
assertThat(person).containsExactly(
89+
entry("id", 1L),
90+
entry("first_name", "Dave"),
91+
entry("last_name", "Syer"));
92+
}
93+
}

0 commit comments

Comments
 (0)