Skip to content

Commit

Permalink
Add global error handler for transactional event listener
Browse files Browse the repository at this point in the history
  • Loading branch information
hyh1016 committed Feb 16, 2025
1 parent fbed55e commit a15f5f6
Show file tree
Hide file tree
Showing 5 changed files with 166 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
import org.springframework.core.type.AnnotationMetadata;
import org.springframework.transaction.TransactionManager;
import org.springframework.transaction.config.TransactionManagementConfigUtils;
import org.springframework.transaction.config.TransactionalEventErrorHandler;
import org.springframework.transaction.config.GlobalTransactionalEventErrorHandler;
import org.springframework.transaction.event.TransactionalEventListenerFactory;
import org.springframework.transaction.interceptor.RollbackRuleAttribute;
import org.springframework.transaction.interceptor.TransactionAttributeSource;
Expand Down Expand Up @@ -94,7 +94,10 @@ public TransactionAttributeSource transactionAttributeSource() {

@Bean(name = TransactionManagementConfigUtils.TRANSACTIONAL_EVENT_LISTENER_FACTORY_BEAN_NAME)
@Role(BeanDefinition.ROLE_INFRASTRUCTURE)
public static TransactionalEventListenerFactory transactionalEventListenerFactory(@Nullable TransactionalEventErrorHandler errorHandler) {
public static TransactionalEventListenerFactory transactionalEventListenerFactory(@Nullable GlobalTransactionalEventErrorHandler errorHandler) {
if (errorHandler == null) {
return new RestrictedTransactionalEventListenerFactory();
}
return new RestrictedTransactionalEventListenerFactory(errorHandler);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,9 @@

import java.lang.reflect.Method;

import org.jspecify.annotations.Nullable;
import org.springframework.context.ApplicationListener;
import org.springframework.core.annotation.AnnotatedElementUtils;
import org.springframework.transaction.config.TransactionalEventErrorHandler;
import org.springframework.transaction.config.GlobalTransactionalEventErrorHandler;
import org.springframework.transaction.event.TransactionalEventListenerFactory;

/**
Expand All @@ -37,7 +36,11 @@
*/
public class RestrictedTransactionalEventListenerFactory extends TransactionalEventListenerFactory {

public RestrictedTransactionalEventListenerFactory(@Nullable TransactionalEventErrorHandler errorHandler) {
public RestrictedTransactionalEventListenerFactory() {
super();
}

public RestrictedTransactionalEventListenerFactory(GlobalTransactionalEventErrorHandler errorHandler) {
super(errorHandler);
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
package org.springframework.transaction.config;

import org.jspecify.annotations.Nullable;
import org.springframework.context.ApplicationEvent;
import org.springframework.transaction.event.TransactionalApplicationListener;

public abstract class GlobalTransactionalEventErrorHandler implements TransactionalApplicationListener.SynchronizationCallback {

public abstract void handle(ApplicationEvent event, @Nullable Throwable ex);

@Override
public void postProcessEvent(ApplicationEvent event, @Nullable Throwable ex) {
if (ex != null) {
handle(event, ex);
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
import org.springframework.context.event.EventListenerFactory;
import org.springframework.core.Ordered;
import org.springframework.core.annotation.AnnotatedElementUtils;
import org.springframework.transaction.config.TransactionalEventErrorHandler;
import org.springframework.transaction.config.GlobalTransactionalEventErrorHandler;

/**
* {@link EventListenerFactory} implementation that handles {@link TransactionalEventListener}
Expand All @@ -37,11 +37,11 @@ public class TransactionalEventListenerFactory implements EventListenerFactory,

private int order = 50;

private @Nullable TransactionalEventErrorHandler errorHandler;
private @Nullable GlobalTransactionalEventErrorHandler errorHandler;

public TransactionalEventListenerFactory() { }

public TransactionalEventListenerFactory(@Nullable TransactionalEventErrorHandler errorHandler) {
public TransactionalEventListenerFactory(GlobalTransactionalEventErrorHandler errorHandler) {
this.errorHandler = errorHandler;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,12 @@
import java.util.List;
import java.util.Map;

import org.jspecify.annotations.Nullable;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.Test;

import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.ApplicationEvent;
import org.springframework.context.ApplicationEventPublisher;
import org.springframework.context.ConfigurableApplicationContext;
import org.springframework.context.annotation.AnnotationConfigApplicationContext;
Expand All @@ -43,6 +45,7 @@
import org.springframework.transaction.annotation.EnableTransactionManagement;
import org.springframework.transaction.annotation.Propagation;
import org.springframework.transaction.annotation.Transactional;
import org.springframework.transaction.config.GlobalTransactionalEventErrorHandler;
import org.springframework.transaction.support.TransactionSynchronization;
import org.springframework.transaction.support.TransactionSynchronizationManager;
import org.springframework.transaction.support.TransactionTemplate;
Expand Down Expand Up @@ -369,6 +372,45 @@ void conditionFoundOnMetaAnnotation() {
getEventCollector().assertNoEventReceived();
}

@Test
void afterCommitThrowException() {
doLoad(HandlerConfiguration.class, AfterCommitErrorHandlerTestListener.class);
this.transactionTemplate.execute(status -> {
getContext().publishEvent("test");
getEventCollector().assertNoEventReceived();
return null;
});
getEventCollector().assertEvents(EventCollector.AFTER_COMMIT, "test");
getEventCollector().assertEvents(EventCollector.HANDLE_ERROR, "HANDLE_ERROR");
getEventCollector().assertTotalEventsCount(2);
}

@Test
void afterRollbackThrowException() {
doLoad(HandlerConfiguration.class, AfterRollbackErrorHandlerTestListener.class);
this.transactionTemplate.execute(status -> {
getContext().publishEvent("test");
getEventCollector().assertNoEventReceived();
status.setRollbackOnly();
return null;
});
getEventCollector().assertEvents(EventCollector.AFTER_ROLLBACK, "test");
getEventCollector().assertEvents(EventCollector.HANDLE_ERROR, "HANDLE_ERROR");
getEventCollector().assertTotalEventsCount(2);
}

@Test
void afterCompletionThrowException() {
doLoad(HandlerConfiguration.class, AfterCompletionErrorHandlerTestListener.class);
this.transactionTemplate.execute(status -> {
getContext().publishEvent("test");
getEventCollector().assertNoEventReceived();
return null;
});
getEventCollector().assertEvents(EventCollector.AFTER_COMPLETION, "test");
getEventCollector().assertEvents(EventCollector.HANDLE_ERROR, "HANDLE_ERROR");
getEventCollector().assertTotalEventsCount(2);
}

protected EventCollector getEventCollector() {
return this.eventCollector;
Expand Down Expand Up @@ -442,6 +484,36 @@ public TransactionTemplate transactionTemplate() {
}
}

@Configuration
@EnableTransactionManagement
static class HandlerConfiguration {

@Bean
public EventCollector eventCollector() {
return new EventCollector();
}

@Bean
public TestBean testBean(ApplicationEventPublisher eventPublisher) {
return new TestBean(eventPublisher);
}

@Bean
public CallCountingTransactionManager transactionManager() {
return new CallCountingTransactionManager();
}

@Bean
public TransactionTemplate transactionTemplate() {
return new TransactionTemplate(transactionManager());
}

@Bean
public AfterRollbackErrorHandler errorHandler(ApplicationEventPublisher eventPublisher) {
return new AfterRollbackErrorHandler(eventPublisher);
}
}


@Configuration
static class MulticasterWithCustomExecutor {
Expand All @@ -467,7 +539,9 @@ static class EventCollector {

public static final String AFTER_ROLLBACK = "AFTER_ROLLBACK";

public static final String[] ALL_PHASES = {IMMEDIATELY, BEFORE_COMMIT, AFTER_COMMIT, AFTER_ROLLBACK};
public static final String HANDLE_ERROR = "HANDLE_ERROR";

public static final String[] ALL_PHASES = {IMMEDIATELY, BEFORE_COMMIT, AFTER_COMMIT, AFTER_ROLLBACK, HANDLE_ERROR};

private final MultiValueMap<String, Object> events = new LinkedMultiValueMap<>();

Expand Down Expand Up @@ -677,6 +751,51 @@ public void handleAfterCommit(String data) {
}


@Component
static class AfterCommitErrorHandlerTestListener extends BaseTransactionalTestListener {

@TransactionalEventListener(phase = AFTER_COMMIT, condition = "!'HANDLE_ERROR'.equals(#data)")
public void handleBeforeCommit(String data) {
handleEvent(EventCollector.AFTER_COMMIT, data);
throw new IllegalStateException("test");
}

@EventListener(condition = "'HANDLE_ERROR'.equals(#data)")
public void handleImmediately(String data) {
handleEvent(EventCollector.HANDLE_ERROR, data);
}
}

@Component
static class AfterRollbackErrorHandlerTestListener extends BaseTransactionalTestListener {

@TransactionalEventListener(phase = AFTER_ROLLBACK, condition = "!'HANDLE_ERROR'.equals(#data)")
public void handleBeforeCommit(String data) {
handleEvent(EventCollector.AFTER_ROLLBACK, data);
throw new IllegalStateException("test");
}

@EventListener(condition = "'HANDLE_ERROR'.equals(#data)")
public void handleImmediately(String data) {
handleEvent(EventCollector.HANDLE_ERROR, data);
}
}

@Component
static class AfterCompletionErrorHandlerTestListener extends BaseTransactionalTestListener {

@TransactionalEventListener(phase = AFTER_COMPLETION, condition = "!'HANDLE_ERROR'.equals(#data)")
public void handleBeforeCommit(String data) {
handleEvent(EventCollector.AFTER_COMPLETION, data);
throw new IllegalStateException("test");
}

@EventListener(condition = "'HANDLE_ERROR'.equals(#data)")
public void handleImmediately(String data) {
handleEvent(EventCollector.HANDLE_ERROR, data);
}
}

static class EventTransactionSynchronization implements TransactionSynchronization {

private final int order;
Expand All @@ -691,4 +810,18 @@ public int getOrder() {
}
}

static class AfterRollbackErrorHandler extends GlobalTransactionalEventErrorHandler {

private final ApplicationEventPublisher eventPublisher;

AfterRollbackErrorHandler(ApplicationEventPublisher eventPublisher) {
this.eventPublisher = eventPublisher;
}

@Override
public void handle(ApplicationEvent event, @Nullable Throwable ex) {
eventPublisher.publishEvent("HANDLE_ERROR");
}
}

}

0 comments on commit a15f5f6

Please sign in to comment.