Skip to content

Commit

Permalink
Add global error handler for transactional event listener
Browse files Browse the repository at this point in the history
Signed-off-by: YIHYUN HA <[email protected]>
  • Loading branch information
hyh1016 committed Feb 16, 2025
1 parent 21604d1 commit 56b9013
Show file tree
Hide file tree
Showing 5 changed files with 192 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import org.springframework.core.type.AnnotationMetadata;
import org.springframework.transaction.TransactionManager;
import org.springframework.transaction.config.TransactionManagementConfigUtils;
import org.springframework.transaction.config.GlobalTransactionalEventErrorHandler;
import org.springframework.transaction.event.TransactionalEventListenerFactory;
import org.springframework.transaction.interceptor.RollbackRuleAttribute;
import org.springframework.transaction.interceptor.TransactionAttributeSource;
Expand Down Expand Up @@ -93,8 +94,11 @@ public TransactionAttributeSource transactionAttributeSource() {

@Bean(name = TransactionManagementConfigUtils.TRANSACTIONAL_EVENT_LISTENER_FACTORY_BEAN_NAME)
@Role(BeanDefinition.ROLE_INFRASTRUCTURE)
public static TransactionalEventListenerFactory transactionalEventListenerFactory() {
return new RestrictedTransactionalEventListenerFactory();
public static TransactionalEventListenerFactory transactionalEventListenerFactory(@Nullable GlobalTransactionalEventErrorHandler errorHandler) {
if (errorHandler == null) {
return new RestrictedTransactionalEventListenerFactory();
}
return new RestrictedTransactionalEventListenerFactory(errorHandler);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

import org.springframework.context.ApplicationListener;
import org.springframework.core.annotation.AnnotatedElementUtils;
import org.springframework.transaction.config.GlobalTransactionalEventErrorHandler;
import org.springframework.transaction.event.TransactionalEventListenerFactory;

/**
Expand All @@ -35,6 +36,14 @@
*/
public class RestrictedTransactionalEventListenerFactory extends TransactionalEventListenerFactory {

public RestrictedTransactionalEventListenerFactory() {
super();
}

public RestrictedTransactionalEventListenerFactory(GlobalTransactionalEventErrorHandler errorHandler) {
super(errorHandler);
}

@Override
public ApplicationListener<?> createApplicationListener(String beanName, Class<?> type, Method method) {
Transactional txAnn = AnnotatedElementUtils.findMergedAnnotation(method, Transactional.class);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
package org.springframework.transaction.config;

import org.jspecify.annotations.Nullable;
import org.springframework.context.ApplicationEvent;
import org.springframework.transaction.event.TransactionalApplicationListener;

public abstract class GlobalTransactionalEventErrorHandler implements TransactionalApplicationListener.SynchronizationCallback {

public abstract void handle(ApplicationEvent event, @Nullable Throwable ex);

@Override
public void postProcessEvent(ApplicationEvent event, @Nullable Throwable ex) {
if (ex != null) {
handle(event, ex);
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,12 @@

import java.lang.reflect.Method;

import org.jspecify.annotations.Nullable;
import org.springframework.context.ApplicationListener;
import org.springframework.context.event.EventListenerFactory;
import org.springframework.core.Ordered;
import org.springframework.core.annotation.AnnotatedElementUtils;
import org.springframework.transaction.config.GlobalTransactionalEventErrorHandler;

/**
* {@link EventListenerFactory} implementation that handles {@link TransactionalEventListener}
Expand All @@ -35,6 +37,13 @@ public class TransactionalEventListenerFactory implements EventListenerFactory,

private int order = 50;

private @Nullable GlobalTransactionalEventErrorHandler errorHandler;

public TransactionalEventListenerFactory() { }

public TransactionalEventListenerFactory(GlobalTransactionalEventErrorHandler errorHandler) {
this.errorHandler = errorHandler;
}

public void setOrder(int order) {
this.order = order;
Expand All @@ -53,7 +62,14 @@ public boolean supportsMethod(Method method) {

@Override
public ApplicationListener<?> createApplicationListener(String beanName, Class<?> type, Method method) {
return new TransactionalApplicationListenerMethodAdapter(beanName, type, method);
if (errorHandler == null) {
return new TransactionalApplicationListenerMethodAdapter(beanName, type, method);
}
else {
TransactionalApplicationListenerMethodAdapter listener = new TransactionalApplicationListenerMethodAdapter(beanName, type, method);
listener.addCallback(errorHandler);
return listener;
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,12 @@
import java.util.List;
import java.util.Map;

import org.jspecify.annotations.Nullable;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.Test;

import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.ApplicationEvent;
import org.springframework.context.ApplicationEventPublisher;
import org.springframework.context.ConfigurableApplicationContext;
import org.springframework.context.annotation.AnnotationConfigApplicationContext;
Expand All @@ -43,6 +45,7 @@
import org.springframework.transaction.annotation.EnableTransactionManagement;
import org.springframework.transaction.annotation.Propagation;
import org.springframework.transaction.annotation.Transactional;
import org.springframework.transaction.config.GlobalTransactionalEventErrorHandler;
import org.springframework.transaction.support.TransactionSynchronization;
import org.springframework.transaction.support.TransactionSynchronizationManager;
import org.springframework.transaction.support.TransactionTemplate;
Expand Down Expand Up @@ -99,12 +102,12 @@ void immediately() {
void immediatelyImpactsCurrentTransaction() {
load(ImmediateTestListener.class, BeforeCommitTestListener.class);
assertThatIllegalStateException().isThrownBy(() ->
this.transactionTemplate.execute(status -> {
getContext().publishEvent("FAIL");
throw new AssertionError("Should have thrown an exception at this point");
}))
.withMessageContaining("Test exception")
.withMessageContaining(EventCollector.IMMEDIATELY);
this.transactionTemplate.execute(status -> {
getContext().publishEvent("FAIL");
throw new AssertionError("Should have thrown an exception at this point");
}))
.withMessageContaining("Test exception")
.withMessageContaining(EventCollector.IMMEDIATELY);

getEventCollector().assertEvents(EventCollector.IMMEDIATELY, "FAIL");
getEventCollector().assertTotalEventsCount(1);
Expand Down Expand Up @@ -369,6 +372,45 @@ void conditionFoundOnMetaAnnotation() {
getEventCollector().assertNoEventReceived();
}

@Test
void afterCommitThrowException() {
doLoad(HandlerConfiguration.class, AfterCommitErrorHandlerTestListener.class);
this.transactionTemplate.execute(status -> {
getContext().publishEvent("test");
getEventCollector().assertNoEventReceived();
return null;
});
getEventCollector().assertEvents(EventCollector.AFTER_COMMIT, "test");
getEventCollector().assertEvents(EventCollector.HANDLE_ERROR, "HANDLE_ERROR");
getEventCollector().assertTotalEventsCount(2);
}

@Test
void afterRollbackThrowException() {
doLoad(HandlerConfiguration.class, AfterRollbackErrorHandlerTestListener.class);
this.transactionTemplate.execute(status -> {
getContext().publishEvent("test");
getEventCollector().assertNoEventReceived();
status.setRollbackOnly();
return null;
});
getEventCollector().assertEvents(EventCollector.AFTER_ROLLBACK, "test");
getEventCollector().assertEvents(EventCollector.HANDLE_ERROR, "HANDLE_ERROR");
getEventCollector().assertTotalEventsCount(2);
}

@Test
void afterCompletionThrowException() {
doLoad(HandlerConfiguration.class, AfterCompletionErrorHandlerTestListener.class);
this.transactionTemplate.execute(status -> {
getContext().publishEvent("test");
getEventCollector().assertNoEventReceived();
return null;
});
getEventCollector().assertEvents(EventCollector.AFTER_COMPLETION, "test");
getEventCollector().assertEvents(EventCollector.HANDLE_ERROR, "HANDLE_ERROR");
getEventCollector().assertTotalEventsCount(2);
}

protected EventCollector getEventCollector() {
return this.eventCollector;
Expand Down Expand Up @@ -442,6 +484,36 @@ public TransactionTemplate transactionTemplate() {
}
}

@Configuration
@EnableTransactionManagement
static class HandlerConfiguration {

@Bean
public EventCollector eventCollector() {
return new EventCollector();
}

@Bean
public TestBean testBean(ApplicationEventPublisher eventPublisher) {
return new TestBean(eventPublisher);
}

@Bean
public CallCountingTransactionManager transactionManager() {
return new CallCountingTransactionManager();
}

@Bean
public TransactionTemplate transactionTemplate() {
return new TransactionTemplate(transactionManager());
}

@Bean
public AfterRollbackErrorHandler errorHandler(ApplicationEventPublisher eventPublisher) {
return new AfterRollbackErrorHandler(eventPublisher);
}
}


@Configuration
static class MulticasterWithCustomExecutor {
Expand All @@ -467,7 +539,9 @@ static class EventCollector {

public static final String AFTER_ROLLBACK = "AFTER_ROLLBACK";

public static final String[] ALL_PHASES = {IMMEDIATELY, BEFORE_COMMIT, AFTER_COMMIT, AFTER_ROLLBACK};
public static final String HANDLE_ERROR = "HANDLE_ERROR";

public static final String[] ALL_PHASES = {IMMEDIATELY, BEFORE_COMMIT, AFTER_COMMIT, AFTER_ROLLBACK, HANDLE_ERROR};

private final MultiValueMap<String, Object> events = new LinkedMultiValueMap<>();

Expand All @@ -486,7 +560,7 @@ public void assertNoEventReceived(String... phases) {
for (String phase : phases) {
List<Object> eventsForPhase = getEvents(phase);
assertThat(eventsForPhase.size()).as("Expected no events for phase '" + phase + "' " +
"but got " + eventsForPhase + ":").isEqualTo(0);
"but got " + eventsForPhase + ":").isEqualTo(0);
}
}

Expand All @@ -504,7 +578,7 @@ public void assertTotalEventsCount(int number) {
size += entry.getValue().size();
}
assertThat(size).as("Wrong number of total events (" + this.events.size() + ") " +
"registered phase(s)").isEqualTo(number);
"registered phase(s)").isEqualTo(number);
}
}

Expand Down Expand Up @@ -677,6 +751,51 @@ public void handleAfterCommit(String data) {
}


@Component
static class AfterCommitErrorHandlerTestListener extends BaseTransactionalTestListener {

@TransactionalEventListener(phase = AFTER_COMMIT, condition = "!'HANDLE_ERROR'.equals(#data)")
public void handleBeforeCommit(String data) {
handleEvent(EventCollector.AFTER_COMMIT, data);
throw new IllegalStateException("test");
}

@EventListener(condition = "'HANDLE_ERROR'.equals(#data)")
public void handleImmediately(String data) {
handleEvent(EventCollector.HANDLE_ERROR, data);
}
}

@Component
static class AfterRollbackErrorHandlerTestListener extends BaseTransactionalTestListener {

@TransactionalEventListener(phase = AFTER_ROLLBACK, condition = "!'HANDLE_ERROR'.equals(#data)")
public void handleBeforeCommit(String data) {
handleEvent(EventCollector.AFTER_ROLLBACK, data);
throw new IllegalStateException("test");
}

@EventListener(condition = "'HANDLE_ERROR'.equals(#data)")
public void handleImmediately(String data) {
handleEvent(EventCollector.HANDLE_ERROR, data);
}
}

@Component
static class AfterCompletionErrorHandlerTestListener extends BaseTransactionalTestListener {

@TransactionalEventListener(phase = AFTER_COMPLETION, condition = "!'HANDLE_ERROR'.equals(#data)")
public void handleBeforeCommit(String data) {
handleEvent(EventCollector.AFTER_COMPLETION, data);
throw new IllegalStateException("test");
}

@EventListener(condition = "'HANDLE_ERROR'.equals(#data)")
public void handleImmediately(String data) {
handleEvent(EventCollector.HANDLE_ERROR, data);
}
}

static class EventTransactionSynchronization implements TransactionSynchronization {

private final int order;
Expand All @@ -691,4 +810,18 @@ public int getOrder() {
}
}

static class AfterRollbackErrorHandler extends GlobalTransactionalEventErrorHandler {

private final ApplicationEventPublisher eventPublisher;

AfterRollbackErrorHandler(ApplicationEventPublisher eventPublisher) {
this.eventPublisher = eventPublisher;
}

@Override
public void handle(ApplicationEvent event, @Nullable Throwable ex) {
eventPublisher.publishEvent("HANDLE_ERROR");
}
}

}

0 comments on commit 56b9013

Please sign in to comment.