Commit c2f4d027 authored by Andy Wilkinson's avatar Andy Wilkinson

Merge branch '1.4.x' into 1.5.x

parents 64b65e1f 9247cd9c
/*
* Copyright 2012-2016 the original author or authors.
* Copyright 2012-2017 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
......@@ -181,6 +181,8 @@ public class ErrorPageFilter implements Filter, ErrorPageRegistry {
response.reset();
response.sendError(500, ex.getMessage());
request.getRequestDispatcher(path).forward(request, response);
request.removeAttribute(ERROR_EXCEPTION);
request.removeAttribute(ERROR_EXCEPTION_TYPE);
}
/**
......
/*
* Copyright 2012-2016 the original author or authors.
* Copyright 2012-2017 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
......@@ -17,6 +17,9 @@
package org.springframework.boot.web.support;
import java.io.IOException;
import java.util.Enumeration;
import java.util.HashMap;
import java.util.Map;
import javax.servlet.RequestDispatcher;
import javax.servlet.ServletException;
......@@ -35,6 +38,7 @@ import org.springframework.mock.web.MockFilterChain;
import org.springframework.mock.web.MockFilterConfig;
import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockHttpServletResponse;
import org.springframework.mock.web.MockRequestDispatcher;
import org.springframework.web.context.request.async.DeferredResult;
import org.springframework.web.context.request.async.StandardServletAsyncWebRequest;
import org.springframework.web.context.request.async.WebAsyncManager;
......@@ -57,8 +61,7 @@ public class ErrorPageFilterTests {
private ErrorPageFilter filter = new ErrorPageFilter();
private MockHttpServletRequest request = new MockHttpServletRequest("GET",
"/test/path");
private DispatchRecordingMockHttpServletRequest request = new DispatchRecordingMockHttpServletRequest();
private MockHttpServletResponse response = new MockHttpServletResponse();
......@@ -261,8 +264,14 @@ public class ErrorPageFilterTests {
.isEqualTo(500);
assertThat(this.request.getAttribute(RequestDispatcher.ERROR_MESSAGE))
.isEqualTo("BAD");
assertThat(this.request.getAttribute(RequestDispatcher.ERROR_EXCEPTION_TYPE))
Map<String, Object> requestAttributes = getAttributesForDispatch("/500");
assertThat(requestAttributes.get(RequestDispatcher.ERROR_EXCEPTION_TYPE))
.isEqualTo(RuntimeException.class);
assertThat(requestAttributes.get(RequestDispatcher.ERROR_EXCEPTION))
.isInstanceOf(RuntimeException.class);
assertThat(this.request.getAttribute(RequestDispatcher.ERROR_EXCEPTION_TYPE))
.isNull();
assertThat(this.request.getAttribute(RequestDispatcher.ERROR_EXCEPTION)).isNull();
assertThat(this.request.getAttribute(RequestDispatcher.ERROR_REQUEST_URI))
.isEqualTo("/test/path");
assertThat(this.response.isCommitted()).isTrue();
......@@ -318,8 +327,14 @@ public class ErrorPageFilterTests {
.isEqualTo(500);
assertThat(this.request.getAttribute(RequestDispatcher.ERROR_MESSAGE))
.isEqualTo("BAD");
assertThat(this.request.getAttribute(RequestDispatcher.ERROR_EXCEPTION_TYPE))
Map<String, Object> requestAttributes = getAttributesForDispatch("/500");
assertThat(requestAttributes.get(RequestDispatcher.ERROR_EXCEPTION_TYPE))
.isEqualTo(IllegalStateException.class);
assertThat(requestAttributes.get(RequestDispatcher.ERROR_EXCEPTION))
.isInstanceOf(IllegalStateException.class);
assertThat(this.request.getAttribute(RequestDispatcher.ERROR_EXCEPTION_TYPE))
.isNull();
assertThat(this.request.getAttribute(RequestDispatcher.ERROR_EXCEPTION)).isNull();
assertThat(this.request.getAttribute(RequestDispatcher.ERROR_REQUEST_URI))
.isEqualTo("/test/path");
assertThat(this.response.isCommitted()).isTrue();
......@@ -492,8 +507,14 @@ public class ErrorPageFilterTests {
.isEqualTo(500);
assertThat(this.request.getAttribute(RequestDispatcher.ERROR_MESSAGE))
.isEqualTo("BAD");
assertThat(this.request.getAttribute(RequestDispatcher.ERROR_EXCEPTION_TYPE))
Map<String, Object> requestAttributes = getAttributesForDispatch("/500");
assertThat(requestAttributes.get(RequestDispatcher.ERROR_EXCEPTION_TYPE))
.isEqualTo(RuntimeException.class);
assertThat(requestAttributes.get(RequestDispatcher.ERROR_EXCEPTION))
.isInstanceOf(RuntimeException.class);
assertThat(this.request.getAttribute(RequestDispatcher.ERROR_EXCEPTION_TYPE))
.isNull();
assertThat(this.request.getAttribute(RequestDispatcher.ERROR_EXCEPTION)).isNull();
assertThat(this.request.getAttribute(RequestDispatcher.ERROR_REQUEST_URI))
.isEqualTo("/test/path");
assertThat(this.response.isCommitted()).isTrue();
......@@ -510,4 +531,60 @@ public class ErrorPageFilterTests {
asyncManager.startDeferredResultProcessing(result);
}
private Map<String, Object> getAttributesForDispatch(String path) {
return this.request.getDispatcher(path).getRequestAttributes();
}
private static final class DispatchRecordingMockHttpServletRequest
extends MockHttpServletRequest {
private final Map<String, AttributeCapturingRequestDispatcher> dispatchers = new HashMap<String, AttributeCapturingRequestDispatcher>();
private DispatchRecordingMockHttpServletRequest() {
super("GET", "/test/path");
}
@Override
public RequestDispatcher getRequestDispatcher(String path) {
AttributeCapturingRequestDispatcher dispatcher = new AttributeCapturingRequestDispatcher(
path);
this.dispatchers.put(path, dispatcher);
return dispatcher;
}
private AttributeCapturingRequestDispatcher getDispatcher(String path) {
return this.dispatchers.get(path);
}
private static final class AttributeCapturingRequestDispatcher
extends MockRequestDispatcher {
private final Map<String, Object> requestAttributes = new HashMap<String, Object>();
private AttributeCapturingRequestDispatcher(String resource) {
super(resource);
}
@Override
public void forward(ServletRequest request, ServletResponse response) {
captureAttributes(request);
super.forward(request, response);
}
private void captureAttributes(ServletRequest request) {
Enumeration<String> names = request.getAttributeNames();
while (names.hasMoreElements()) {
String name = names.nextElement();
this.requestAttributes.put(name, request.getAttribute(name));
}
}
private Map<String, Object> getRequestAttributes() {
return this.requestAttributes;
}
}
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment