Description
As far as I understand, configured ServletRequestFilters are only registered correctly when using embedded Tomcat.
As a solution, I implemented my own filter, which works similarly to the Spring Security FilterChain by delegating requests to the list of ServletRequestFilters registered in the application context.
To access the registered ServletRequestFilter instances, they are passed to the delegating filter after the application context refresh (because my own filter does not have access to the list of ServletRequestFilters at creation time).
In addition, two things are currently only possible via workarounds:
- Determining whether the respective ServletRequestFilter has processed the request (by checking for the fixed HTTP status code 429)
- Determining whether the respective ServletRequestFilter has recognized a rate limit (by checking for the X-Rate-Limit-Remaining HTTP response header)
This solution should also work with embedded Tomcat (not verified). To get rid of the workarounds, the following enhancements would be helpful:
- ServletRequestFilter should expose its filter configuration via a public method.
- ServletRequestFilter should expose its `shouldNotFilter(HttpServletRequest request)` method as a public method.
- ServletRequestFilter should expose its outcome (maybe as a request attribute?).
I would be very happy if these suggestions, or parts of them, could be taken into account in a future version.
Here are the implementations:
Spring Boot Configuration
import com.giffing.bucket4j.spring.boot.starter.filter.servlet.ServletRequestFilter;
import jakarta.servlet.DispatcherType;
import jakarta.servlet.Filter;
import lombok.extern.log4j.Log4j2;
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
import org.springframework.boot.web.servlet.FilterRegistrationBean;
import org.springframework.context.ApplicationContext;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.context.event.ContextRefreshedEvent;
import org.springframework.context.event.EventListener;
import org.springframework.core.Ordered;
import java.util.LinkedList;
import java.util.List;
/**
 * Registers the delegating {@link RateLimitFilter} as a servlet filter and, once the application
 * context has been refreshed, hands it every {@link ServletRequestFilter} bean found in that
 * context.
 *
 * <p>Only active when the {@code bucket4j.enabled} property is set.
 */
@Configuration
@Log4j2
@ConditionalOnProperty(name = "bucket4j.enabled")
public class RateLimitConfiguration {

    /** The delegating filter created by {@link #registrationRateLimitFilter()}. */
    private RateLimitFilter rateLimitFilter;

    /**
     * Creates the delegating filter and registers it under the name {@code rateLimitFilter} for
     * {@code REQUEST} dispatches, ordered at {@code Ordered.HIGHEST_PRECEDENCE + 10}.
     *
     * @return the registration bean wrapping the delegating filter
     */
    @Bean
    public FilterRegistrationBean<RateLimitFilter> registrationRateLimitFilter() {
        RateLimitFilter delegatingFilter = new RateLimitFilter();
        this.rateLimitFilter = delegatingFilter;

        FilterRegistrationBean<RateLimitFilter> registration = new FilterRegistrationBean<>();
        registration.setName("rateLimitFilter");
        registration.setFilter(delegatingFilter);
        registration.setDispatcherTypes(DispatcherType.REQUEST);
        registration.setOrder(Ordered.HIGHEST_PRECEDENCE + 10);
        return registration;
    }

    /**
     * Looks up every {@link Filter} bean in the refreshed context, keeps those that are
     * {@link ServletRequestFilter}s and passes them to the delegating filter. This is done after
     * the refresh because the delegating filter is created before it can see those beans.
     *
     * <p>NOTE(review): refresh events of child contexts are also delivered to this listener, and
     * {@code getBeanNamesForType} only searches the event's own context — confirm the delegating
     * filter tolerates being handed a (possibly empty) list more than once.
     *
     * @param event the refresh event whose application context is searched
     */
    @EventListener
    public void handleContextRefreshedEvent(ContextRefreshedEvent event) {
        ApplicationContext context = event.getApplicationContext();
        List<ServletRequestFilter> discovered = new LinkedList<>();
        for (String name : context.getBeanNamesForType(Filter.class)) {
            if (context.getBean(name, Filter.class) instanceof ServletRequestFilter requestFilter) {
                discovered.add(requestFilter);
            }
        }
        this.rateLimitFilter.setServletRequestFilters(discovered);
    }
}
Filter
import com.giffing.bucket4j.spring.boot.starter.filter.servlet.ServletRequestFilter;
import jakarta.servlet.Filter;
import jakarta.servlet.FilterChain;
import jakarta.servlet.ServletException;
import jakarta.servlet.ServletRequest;
import jakarta.servlet.ServletResponse;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import lombok.extern.log4j.Log4j2;
import org.springframework.web.filter.OncePerRequestFilter;
import java.io.IOException;
import java.util.Comparator;
import java.util.LinkedList;
import java.util.List;
/**
 * Servlet filter that delegates each request to an ordered list of {@link ServletRequestFilter}s,
 * similar to how Spring Security delegates to its security filter chain.
 *
 * <p>The delegates are supplied after construction via {@link #setServletRequestFilters(List)};
 * until then every request is passed straight on to the container's filter chain.
 *
 * <p>Thread-safety: the delegate list is published as an immutable, sorted snapshot through a
 * {@code volatile} field. Request threads therefore always iterate a complete list, even if
 * {@link #setServletRequestFilters(List)} runs while requests are already being served (an
 * embedded container may be started before {@code ContextRefreshedEvent} is published).
 */
@Log4j2
public class RateLimitFilter extends OncePerRequestFilter {

    /** Immutable snapshot of the delegates sorted by order; replaced wholesale, never mutated. */
    private volatile List<ServletRequestFilter> servletRequestFilters = List.of();

    /**
     * Runs the request through the current delegates and then the container's chain, or directly
     * through the container's chain when no delegates have been registered yet.
     */
    @Override
    protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response,
            FilterChain filterChain) throws ServletException, IOException {
        // Read the volatile field once so this request works on a single consistent snapshot.
        List<ServletRequestFilter> delegates = this.servletRequestFilters;
        if (delegates.isEmpty()) {
            filterChain.doFilter(request, response);
            return;
        }
        new VirtualFilterChain(filterChain, delegates).doFilter(request, response);
    }

    /**
     * Adds the given delegates to the ones already registered, skipping instances that are already
     * present, and re-sorts the combined list by {@link ServletRequestFilter#getOrder()} (stable,
     * so equal orders keep their registration order).
     *
     * <p>Calling this repeatedly (e.g. once per context refresh event) does not register the same
     * filter instance twice.
     *
     * @param servletRequestFilters the delegates to add; must not be {@code null} and must not
     *     contain {@code null} elements
     * @throws NullPointerException if the list or one of its elements is {@code null}
     */
    public synchronized void setServletRequestFilters(List<ServletRequestFilter> servletRequestFilters) {
        List<ServletRequestFilter> merged = new LinkedList<>(this.servletRequestFilters);
        for (ServletRequestFilter candidate : servletRequestFilters) {
            if (!merged.contains(candidate)) {
                merged.add(candidate);
            }
        }
        merged.sort(Comparator.comparingInt(ServletRequestFilter::getOrder));
        // Publish a random-access, immutable snapshot; readers never see a partially built list.
        this.servletRequestFilters = List.copyOf(merged);
    }

    /**
     * Per-request {@link FilterChain} that walks the delegate list and then continues with the
     * container's original chain. Instances are stateful and must not be shared between requests.
     *
     * <p>Because {@link ServletRequestFilter} does not expose whether it handled a request, the
     * outcome is inferred from the response before each hop:
     * <ul>
     *   <li>status 429 (Too Many Requests): a delegate rejected the request, so nothing further is
     *       invoked;</li>
     *   <li>an {@code X-Rate-Limit-Remaining} header: a delegate matched and let the request
     *       through, so the remaining delegates are skipped and the original chain continues.</li>
     * </ul>
     */
    private static class VirtualFilterChain implements FilterChain {

        private final FilterChain originalChain;
        private final List<? extends Filter> filters;
        private int currentPosition = 0;

        public VirtualFilterChain(FilterChain chain, List<? extends Filter> filters) {
            this.originalChain = chain;
            this.filters = filters;
        }

        @Override
        public void doFilter(final ServletRequest request, final ServletResponse response)
                throws IOException, ServletException {
            if (response instanceof HttpServletResponse httpServletResponse) {
                // A delegate answered "too many requests": stop the chain here.
                if (httpServletResponse.getStatus() == 429) {
                    return;
                }
                // A delegate matched and let the request pass: skip the remaining delegates.
                if (httpServletResponse.containsHeader("X-Rate-Limit-Remaining")) {
                    this.currentPosition = this.filters.size();
                }
            }
            if (this.currentPosition == this.filters.size()) {
                this.originalChain.doFilter(request, response);
            } else {
                this.currentPosition++;
                Filter nextFilter = this.filters.get(this.currentPosition - 1);
                nextFilter.doFilter(request, response, this);
            }
        }
    }
}