Commit 9da14fe6 authored by Andy Wilkinson's avatar Andy Wilkinson

Merge branch '1.2.x'

parents d3d713d0 6fd30424
...@@ -39,6 +39,7 @@ import org.springframework.core.Ordered; ...@@ -39,6 +39,7 @@ import org.springframework.core.Ordered;
import org.springframework.core.annotation.Order; import org.springframework.core.annotation.Order;
import org.springframework.stereotype.Component; import org.springframework.stereotype.Component;
import org.springframework.web.filter.OncePerRequestFilter; import org.springframework.web.filter.OncePerRequestFilter;
import org.springframework.web.util.NestedServletException;
/** /**
* A special {@link AbstractConfigurableEmbeddedServletContainer} for non-embedded * A special {@link AbstractConfigurableEmbeddedServletContainer} for non-embedded
...@@ -69,6 +70,8 @@ public class ErrorPageFilter extends AbstractConfigurableEmbeddedServletContaine ...@@ -69,6 +70,8 @@ public class ErrorPageFilter extends AbstractConfigurableEmbeddedServletContaine
private static final String ERROR_MESSAGE = "javax.servlet.error.message"; private static final String ERROR_MESSAGE = "javax.servlet.error.message";
public static final String ERROR_REQUEST_URI = "javax.servlet.error.request_uri";
private static final String ERROR_STATUS_CODE = "javax.servlet.error.status_code"; private static final String ERROR_STATUS_CODE = "javax.servlet.error.status_code";
private String global; private String global;
...@@ -121,7 +124,11 @@ public class ErrorPageFilter extends AbstractConfigurableEmbeddedServletContaine ...@@ -121,7 +124,11 @@ public class ErrorPageFilter extends AbstractConfigurableEmbeddedServletContaine
} }
} }
catch (Throwable ex) { catch (Throwable ex) {
handleException(request, response, wrapped, ex); Throwable exceptionToHandle = ex;
if (ex instanceof NestedServletException) {
exceptionToHandle = ((NestedServletException) ex).getRootCause();
}
handleException(request, response, wrapped, exceptionToHandle);
response.flushBuffer(); response.flushBuffer();
} }
} }
...@@ -225,9 +232,10 @@ public class ErrorPageFilter extends AbstractConfigurableEmbeddedServletContaine ...@@ -225,9 +232,10 @@ public class ErrorPageFilter extends AbstractConfigurableEmbeddedServletContaine
return this.global; return this.global;
} }
private void setErrorAttributes(ServletRequest request, int status, String message) { private void setErrorAttributes(HttpServletRequest request, int status, String message) {
request.setAttribute(ERROR_STATUS_CODE, status); request.setAttribute(ERROR_STATUS_CODE, status);
request.setAttribute(ERROR_MESSAGE, message); request.setAttribute(ERROR_MESSAGE, message);
request.setAttribute(ERROR_REQUEST_URI, request.getRequestURI());
} }
private void rethrow(Throwable ex) throws IOException, ServletException { private void rethrow(Throwable ex) throws IOException, ServletException {
......
...@@ -38,6 +38,7 @@ import org.springframework.web.context.request.async.DeferredResult; ...@@ -38,6 +38,7 @@ import org.springframework.web.context.request.async.DeferredResult;
import org.springframework.web.context.request.async.StandardServletAsyncWebRequest; import org.springframework.web.context.request.async.StandardServletAsyncWebRequest;
import org.springframework.web.context.request.async.WebAsyncManager; import org.springframework.web.context.request.async.WebAsyncManager;
import org.springframework.web.context.request.async.WebAsyncUtils; import org.springframework.web.context.request.async.WebAsyncUtils;
import org.springframework.web.util.NestedServletException;
import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.equalTo;
...@@ -62,7 +63,8 @@ public class ErrorPageFilterTests { ...@@ -62,7 +63,8 @@ public class ErrorPageFilterTests {
private ErrorPageFilter filter = new ErrorPageFilter(); private ErrorPageFilter filter = new ErrorPageFilter();
private MockHttpServletRequest request = new MockHttpServletRequest(); private MockHttpServletRequest request = new MockHttpServletRequest("GET",
"/test/path");
private MockHttpServletResponse response = new MockHttpServletResponse(); private MockHttpServletResponse response = new MockHttpServletResponse();
...@@ -199,6 +201,9 @@ public class ErrorPageFilterTests { ...@@ -199,6 +201,9 @@ public class ErrorPageFilterTests {
equalTo((Object) 400)); equalTo((Object) 400));
assertThat(this.request.getAttribute(RequestDispatcher.ERROR_MESSAGE), assertThat(this.request.getAttribute(RequestDispatcher.ERROR_MESSAGE),
equalTo((Object) "BAD")); equalTo((Object) "BAD"));
assertThat(this.request.getAttribute(RequestDispatcher.ERROR_REQUEST_URI),
equalTo((Object) "/test/path"));
assertTrue(this.response.isCommitted()); assertTrue(this.response.isCommitted());
assertThat(this.response.getForwardedUrl(), equalTo("/error")); assertThat(this.response.getForwardedUrl(), equalTo("/error"));
} }
...@@ -221,6 +226,8 @@ public class ErrorPageFilterTests { ...@@ -221,6 +226,8 @@ public class ErrorPageFilterTests {
equalTo((Object) 400)); equalTo((Object) 400));
assertThat(this.request.getAttribute(RequestDispatcher.ERROR_MESSAGE), assertThat(this.request.getAttribute(RequestDispatcher.ERROR_MESSAGE),
equalTo((Object) "BAD")); equalTo((Object) "BAD"));
assertThat(this.request.getAttribute(RequestDispatcher.ERROR_REQUEST_URI),
equalTo((Object) "/test/path"));
assertTrue(this.response.isCommitted()); assertTrue(this.response.isCommitted());
assertThat(this.response.getForwardedUrl(), equalTo("/400")); assertThat(this.response.getForwardedUrl(), equalTo("/400"));
} }
...@@ -264,6 +271,8 @@ public class ErrorPageFilterTests { ...@@ -264,6 +271,8 @@ public class ErrorPageFilterTests {
equalTo((Object) "BAD")); equalTo((Object) "BAD"));
assertThat(this.request.getAttribute(RequestDispatcher.ERROR_EXCEPTION_TYPE), assertThat(this.request.getAttribute(RequestDispatcher.ERROR_EXCEPTION_TYPE),
equalTo((Object) RuntimeException.class.getName())); equalTo((Object) RuntimeException.class.getName()));
assertThat(this.request.getAttribute(RequestDispatcher.ERROR_REQUEST_URI),
equalTo((Object) "/test/path"));
assertTrue(this.response.isCommitted()); assertTrue(this.response.isCommitted());
assertThat(this.response.getForwardedUrl(), equalTo("/500")); assertThat(this.response.getForwardedUrl(), equalTo("/500"));
} }
...@@ -319,6 +328,8 @@ public class ErrorPageFilterTests { ...@@ -319,6 +328,8 @@ public class ErrorPageFilterTests {
equalTo((Object) "BAD")); equalTo((Object) "BAD"));
assertThat(this.request.getAttribute(RequestDispatcher.ERROR_EXCEPTION_TYPE), assertThat(this.request.getAttribute(RequestDispatcher.ERROR_EXCEPTION_TYPE),
equalTo((Object) IllegalStateException.class.getName())); equalTo((Object) IllegalStateException.class.getName()));
assertThat(this.request.getAttribute(RequestDispatcher.ERROR_REQUEST_URI),
equalTo((Object) "/test/path"));
assertTrue(this.response.isCommitted()); assertTrue(this.response.isCommitted());
} }
...@@ -465,6 +476,32 @@ public class ErrorPageFilterTests { ...@@ -465,6 +476,32 @@ public class ErrorPageFilterTests {
assertThat(this.output.toString(), containsString("request [/test/alpha]")); assertThat(this.output.toString(), containsString("request [/test/alpha]"));
} }
@Test
public void nestedServletExceptionIsUnwrapped() throws Exception {
this.filter.addErrorPages(new ErrorPage(RuntimeException.class, "/500"));
this.chain = new MockFilterChain() {
@Override
public void doFilter(ServletRequest request, ServletResponse response)
throws IOException, ServletException {
super.doFilter(request, response);
throw new NestedServletException("Wrapper", new RuntimeException("BAD"));
}
};
this.filter.doFilter(this.request, this.response, this.chain);
assertThat(((HttpServletResponseWrapper) this.chain.getResponse()).getStatus(),
equalTo(500));
assertThat(this.request.getAttribute(RequestDispatcher.ERROR_STATUS_CODE),
equalTo((Object) 500));
assertThat(this.request.getAttribute(RequestDispatcher.ERROR_MESSAGE),
equalTo((Object) "BAD"));
assertThat(this.request.getAttribute(RequestDispatcher.ERROR_EXCEPTION_TYPE),
equalTo((Object) RuntimeException.class.getName()));
assertThat(this.request.getAttribute(RequestDispatcher.ERROR_REQUEST_URI),
equalTo((Object) "/test/path"));
assertTrue(this.response.isCommitted());
assertThat(this.response.getForwardedUrl(), equalTo("/500"));
}
private void setUpAsyncDispatch() throws Exception { private void setUpAsyncDispatch() throws Exception {
this.request.setAsyncSupported(true); this.request.setAsyncSupported(true);
this.request.setAsyncStarted(true); this.request.setAsyncStarted(true);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment