Commit 4a33ab55 authored by Dave Syer's avatar Dave Syer

Make sure ErrorPageFilter is only applied once per request

Fixes gh-1257
parent 0c52817c
...@@ -38,6 +38,7 @@ import org.springframework.boot.context.embedded.ErrorPage; ...@@ -38,6 +38,7 @@ import org.springframework.boot.context.embedded.ErrorPage;
import org.springframework.core.Ordered; import org.springframework.core.Ordered;
import org.springframework.core.annotation.Order; import org.springframework.core.annotation.Order;
import org.springframework.stereotype.Component; import org.springframework.stereotype.Component;
import org.springframework.web.filter.OncePerRequestFilter;
/** /**
* A special {@link AbstractConfigurableEmbeddedServletContainer} for non-embedded * A special {@link AbstractConfigurableEmbeddedServletContainer} for non-embedded
...@@ -77,20 +78,27 @@ class ErrorPageFilter extends AbstractConfigurableEmbeddedServletContainer imple ...@@ -77,20 +78,27 @@ class ErrorPageFilter extends AbstractConfigurableEmbeddedServletContainer imple
private final Map<Class<?>, Class<?>> subtypes = new HashMap<Class<?>, Class<?>>(); private final Map<Class<?>, Class<?>> subtypes = new HashMap<Class<?>, Class<?>>();
private final OncePerRequestFilter delegate = new OncePerRequestFilter(
) {
@Override
protected void doFilterInternal(HttpServletRequest request,
HttpServletResponse response, FilterChain chain)
throws ServletException, IOException {
ErrorPageFilter.this.doFilter(request, response, chain);
}
};
@Override @Override
public void init(FilterConfig filterConfig) throws ServletException { public void init(FilterConfig filterConfig) throws ServletException {
delegate.init(filterConfig);
} }
@Override @Override
public void doFilter(ServletRequest request, ServletResponse response, public void doFilter(ServletRequest request, ServletResponse response,
FilterChain chain) throws IOException, ServletException { FilterChain chain) throws IOException, ServletException {
if (request instanceof HttpServletRequest delegate.doFilter(request, response, chain);
&& response instanceof HttpServletResponse) {
doFilter((HttpServletRequest) request, (HttpServletResponse) response, chain);
}
else {
chain.doFilter(request, response);
}
} }
private void doFilter(HttpServletRequest request, HttpServletResponse response, private void doFilter(HttpServletRequest request, HttpServletResponse response,
......
...@@ -16,6 +16,11 @@ ...@@ -16,6 +16,11 @@
package org.springframework.boot.context.web; package org.springframework.boot.context.web;
import static org.hamcrest.Matchers.equalTo;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertThat;
import static org.junit.Assert.assertTrue;
import java.io.IOException; import java.io.IOException;
import javax.servlet.RequestDispatcher; import javax.servlet.RequestDispatcher;
...@@ -29,13 +34,10 @@ import org.junit.Test; ...@@ -29,13 +34,10 @@ import org.junit.Test;
import org.springframework.boot.context.embedded.ErrorPage; import org.springframework.boot.context.embedded.ErrorPage;
import org.springframework.http.HttpStatus; import org.springframework.http.HttpStatus;
import org.springframework.mock.web.MockFilterChain; import org.springframework.mock.web.MockFilterChain;
import org.springframework.mock.web.MockFilterConfig;
import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockHttpServletResponse; import org.springframework.mock.web.MockHttpServletResponse;
import static org.hamcrest.Matchers.equalTo;
import static org.junit.Assert.assertThat;
import static org.junit.Assert.assertTrue;
/** /**
* Tests for {@link ErrorPageFilter}. * Tests for {@link ErrorPageFilter}.
* *
...@@ -97,6 +99,21 @@ public class ErrorPageFilterTests { ...@@ -97,6 +99,21 @@ public class ErrorPageFilterTests {
equalTo(400)); equalTo(400));
} }
@Test
public void oncePerRequest() throws Exception {
this.chain = new MockFilterChain() {
@Override
public void doFilter(ServletRequest request, ServletResponse response)
throws IOException, ServletException {
((HttpServletResponse) response).sendError(400, "BAD");
assertNotNull(request.getAttribute("FILTER.FILTERED"));
super.doFilter(request, response);
}
};
filter.init(new MockFilterConfig("FILTER"));
this.filter.doFilter(this.request, this.response, this.chain);
}
@Test @Test
public void globalError() throws Exception { public void globalError() throws Exception {
this.filter.addErrorPages(new ErrorPage("/error")); this.filter.addErrorPages(new ErrorPage("/error"));
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment