Commit 28f8e014 authored by Andy Wilkinson's avatar Andy Wilkinson

Merge branch '1.5.x'

parents beecbe30 c2f4d027
/* /*
* Copyright 2012-2016 the original author or authors. * Copyright 2012-2017 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
...@@ -181,6 +181,8 @@ public class ErrorPageFilter implements Filter, ErrorPageRegistry { ...@@ -181,6 +181,8 @@ public class ErrorPageFilter implements Filter, ErrorPageRegistry {
response.reset(); response.reset();
response.sendError(500, ex.getMessage()); response.sendError(500, ex.getMessage());
request.getRequestDispatcher(path).forward(request, response); request.getRequestDispatcher(path).forward(request, response);
request.removeAttribute(ERROR_EXCEPTION);
request.removeAttribute(ERROR_EXCEPTION_TYPE);
} }
/** /**
......
/* /*
* Copyright 2012-2016 the original author or authors. * Copyright 2012-2017 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
...@@ -17,6 +17,9 @@ ...@@ -17,6 +17,9 @@
package org.springframework.boot.web.support; package org.springframework.boot.web.support;
import java.io.IOException; import java.io.IOException;
import java.util.Enumeration;
import java.util.HashMap;
import java.util.Map;
import javax.servlet.RequestDispatcher; import javax.servlet.RequestDispatcher;
import javax.servlet.ServletException; import javax.servlet.ServletException;
...@@ -35,6 +38,7 @@ import org.springframework.mock.web.MockFilterChain; ...@@ -35,6 +38,7 @@ import org.springframework.mock.web.MockFilterChain;
import org.springframework.mock.web.MockFilterConfig; import org.springframework.mock.web.MockFilterConfig;
import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockHttpServletResponse; import org.springframework.mock.web.MockHttpServletResponse;
import org.springframework.mock.web.MockRequestDispatcher;
import org.springframework.web.context.request.async.DeferredResult; import org.springframework.web.context.request.async.DeferredResult;
import org.springframework.web.context.request.async.StandardServletAsyncWebRequest; import org.springframework.web.context.request.async.StandardServletAsyncWebRequest;
import org.springframework.web.context.request.async.WebAsyncManager; import org.springframework.web.context.request.async.WebAsyncManager;
...@@ -57,8 +61,7 @@ public class ErrorPageFilterTests { ...@@ -57,8 +61,7 @@ public class ErrorPageFilterTests {
private ErrorPageFilter filter = new ErrorPageFilter(); private ErrorPageFilter filter = new ErrorPageFilter();
private MockHttpServletRequest request = new MockHttpServletRequest("GET", private DispatchRecordingMockHttpServletRequest request = new DispatchRecordingMockHttpServletRequest();
"/test/path");
private MockHttpServletResponse response = new MockHttpServletResponse(); private MockHttpServletResponse response = new MockHttpServletResponse();
...@@ -261,8 +264,14 @@ public class ErrorPageFilterTests { ...@@ -261,8 +264,14 @@ public class ErrorPageFilterTests {
.isEqualTo(500); .isEqualTo(500);
assertThat(this.request.getAttribute(RequestDispatcher.ERROR_MESSAGE)) assertThat(this.request.getAttribute(RequestDispatcher.ERROR_MESSAGE))
.isEqualTo("BAD"); .isEqualTo("BAD");
assertThat(this.request.getAttribute(RequestDispatcher.ERROR_EXCEPTION_TYPE)) Map<String, Object> requestAttributes = getAttributesForDispatch("/500");
assertThat(requestAttributes.get(RequestDispatcher.ERROR_EXCEPTION_TYPE))
.isEqualTo(RuntimeException.class); .isEqualTo(RuntimeException.class);
assertThat(requestAttributes.get(RequestDispatcher.ERROR_EXCEPTION))
.isInstanceOf(RuntimeException.class);
assertThat(this.request.getAttribute(RequestDispatcher.ERROR_EXCEPTION_TYPE))
.isNull();
assertThat(this.request.getAttribute(RequestDispatcher.ERROR_EXCEPTION)).isNull();
assertThat(this.request.getAttribute(RequestDispatcher.ERROR_REQUEST_URI)) assertThat(this.request.getAttribute(RequestDispatcher.ERROR_REQUEST_URI))
.isEqualTo("/test/path"); .isEqualTo("/test/path");
assertThat(this.response.isCommitted()).isTrue(); assertThat(this.response.isCommitted()).isTrue();
...@@ -318,8 +327,14 @@ public class ErrorPageFilterTests { ...@@ -318,8 +327,14 @@ public class ErrorPageFilterTests {
.isEqualTo(500); .isEqualTo(500);
assertThat(this.request.getAttribute(RequestDispatcher.ERROR_MESSAGE)) assertThat(this.request.getAttribute(RequestDispatcher.ERROR_MESSAGE))
.isEqualTo("BAD"); .isEqualTo("BAD");
assertThat(this.request.getAttribute(RequestDispatcher.ERROR_EXCEPTION_TYPE)) Map<String, Object> requestAttributes = getAttributesForDispatch("/500");
assertThat(requestAttributes.get(RequestDispatcher.ERROR_EXCEPTION_TYPE))
.isEqualTo(IllegalStateException.class); .isEqualTo(IllegalStateException.class);
assertThat(requestAttributes.get(RequestDispatcher.ERROR_EXCEPTION))
.isInstanceOf(IllegalStateException.class);
assertThat(this.request.getAttribute(RequestDispatcher.ERROR_EXCEPTION_TYPE))
.isNull();
assertThat(this.request.getAttribute(RequestDispatcher.ERROR_EXCEPTION)).isNull();
assertThat(this.request.getAttribute(RequestDispatcher.ERROR_REQUEST_URI)) assertThat(this.request.getAttribute(RequestDispatcher.ERROR_REQUEST_URI))
.isEqualTo("/test/path"); .isEqualTo("/test/path");
assertThat(this.response.isCommitted()).isTrue(); assertThat(this.response.isCommitted()).isTrue();
...@@ -492,8 +507,14 @@ public class ErrorPageFilterTests { ...@@ -492,8 +507,14 @@ public class ErrorPageFilterTests {
.isEqualTo(500); .isEqualTo(500);
assertThat(this.request.getAttribute(RequestDispatcher.ERROR_MESSAGE)) assertThat(this.request.getAttribute(RequestDispatcher.ERROR_MESSAGE))
.isEqualTo("BAD"); .isEqualTo("BAD");
assertThat(this.request.getAttribute(RequestDispatcher.ERROR_EXCEPTION_TYPE)) Map<String, Object> requestAttributes = getAttributesForDispatch("/500");
assertThat(requestAttributes.get(RequestDispatcher.ERROR_EXCEPTION_TYPE))
.isEqualTo(RuntimeException.class); .isEqualTo(RuntimeException.class);
assertThat(requestAttributes.get(RequestDispatcher.ERROR_EXCEPTION))
.isInstanceOf(RuntimeException.class);
assertThat(this.request.getAttribute(RequestDispatcher.ERROR_EXCEPTION_TYPE))
.isNull();
assertThat(this.request.getAttribute(RequestDispatcher.ERROR_EXCEPTION)).isNull();
assertThat(this.request.getAttribute(RequestDispatcher.ERROR_REQUEST_URI)) assertThat(this.request.getAttribute(RequestDispatcher.ERROR_REQUEST_URI))
.isEqualTo("/test/path"); .isEqualTo("/test/path");
assertThat(this.response.isCommitted()).isTrue(); assertThat(this.response.isCommitted()).isTrue();
...@@ -510,4 +531,60 @@ public class ErrorPageFilterTests { ...@@ -510,4 +531,60 @@ public class ErrorPageFilterTests {
asyncManager.startDeferredResultProcessing(result); asyncManager.startDeferredResultProcessing(result);
} }
private Map<String, Object> getAttributesForDispatch(String path) {
return this.request.getDispatcher(path).getRequestAttributes();
}
private static final class DispatchRecordingMockHttpServletRequest
extends MockHttpServletRequest {
private final Map<String, AttributeCapturingRequestDispatcher> dispatchers = new HashMap<String, AttributeCapturingRequestDispatcher>();
private DispatchRecordingMockHttpServletRequest() {
super("GET", "/test/path");
}
@Override
public RequestDispatcher getRequestDispatcher(String path) {
AttributeCapturingRequestDispatcher dispatcher = new AttributeCapturingRequestDispatcher(
path);
this.dispatchers.put(path, dispatcher);
return dispatcher;
}
private AttributeCapturingRequestDispatcher getDispatcher(String path) {
return this.dispatchers.get(path);
}
private static final class AttributeCapturingRequestDispatcher
extends MockRequestDispatcher {
private final Map<String, Object> requestAttributes = new HashMap<String, Object>();
private AttributeCapturingRequestDispatcher(String resource) {
super(resource);
}
@Override
public void forward(ServletRequest request, ServletResponse response) {
captureAttributes(request);
super.forward(request, response);
}
private void captureAttributes(ServletRequest request) {
Enumeration<String> names = request.getAttributeNames();
while (names.hasMoreElements()) {
String name = names.nextElement();
this.requestAttributes.put(name, request.getAttribute(name));
}
}
private Map<String, Object> getRequestAttributes() {
return this.requestAttributes;
}
}
}
} }
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment