Commit 13e040c0 authored by Dave Syer's avatar Dave Syer

Add ErrorWrapperEmbeddedServletContainerFactory for error pages in WARs

Error pages are a feature of the servlet spec but there is no Java API for
registering them in the spec. This filter works around that by accepting error page
registrations from Spring Boot's EmbeddedServletContainerCustomizer (any beans
of that type in the context will be applied to this container).

In addition the ErrorController interface was enhanced to provide callers
the option to suppress logging.

Fixes gh-410
parent 3f125fb8
......@@ -51,7 +51,7 @@ public class ManagementErrorEndpoint implements MvcEndpoint {
@ResponseBody
public Map<String, Object> invoke() {
RequestAttributes attributes = RequestContextHolder.currentRequestAttributes();
return this.controller.extract(attributes, false);
return this.controller.extract(attributes, false, true);
}
@Override
......
......@@ -36,6 +36,7 @@ import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.boot.actuate.web.BasicErrorController;
import org.springframework.core.Ordered;
import org.springframework.web.context.request.ServletRequestAttributes;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
......@@ -159,7 +160,8 @@ public class WebRequestTraceFilter implements Filter, Ordered {
.getAttribute("javax.servlet.error.exception");
if (error != null) {
if (this.errorController != null) {
trace.put("error", this.errorController.error(request));
trace.put("error", this.errorController.extract(
new ServletRequestAttributes(request), true, false));
}
}
return trace;
......
......@@ -72,11 +72,13 @@ public class BasicErrorController implements ErrorController {
public Map<String, Object> error(HttpServletRequest request) {
ServletRequestAttributes attributes = new ServletRequestAttributes(request);
String trace = request.getParameter("trace");
return extract(attributes, trace != null && !"false".equals(trace.toLowerCase()));
return extract(attributes, trace != null && !"false".equals(trace.toLowerCase()),
true);
}
@Override
public Map<String, Object> extract(RequestAttributes attributes, boolean trace) {
public Map<String, Object> extract(RequestAttributes attributes, boolean trace,
boolean log) {
Map<String, Object> map = new LinkedHashMap<String, Object>();
map.put("timestamp", new Date());
try {
......@@ -105,7 +107,9 @@ public class BasicErrorController implements ErrorController {
stackTrace.flush();
map.put("trace", stackTrace.toString());
}
this.logger.error(error);
if (log) {
this.logger.error(error);
}
}
else {
Object message = attributes.getAttribute("javax.servlet.error.message",
......@@ -117,7 +121,9 @@ public class BasicErrorController implements ErrorController {
catch (Exception ex) {
map.put(ERROR_KEY, ex.getClass().getName());
map.put("message", ex.getMessage());
this.logger.error(ex);
if (log) {
this.logger.error(ex);
}
return map;
}
}
......
......@@ -38,8 +38,10 @@ public interface ErrorController {
* Extract a useful model of the error from the request attributes.
* @param attributes the request attributes
* @param trace flag to indicate that stack trace information should be included
* @param log flag to indicate that an error should be logged
* @return a model containing error messages and codes etc.
*/
public Map<String, Object> extract(RequestAttributes attributes, boolean trace);
public Map<String, Object> extract(RequestAttributes attributes, boolean trace,
boolean log);
}
/*
* Copyright 2012-2013 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.boot.context.web;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import javax.servlet.Filter;
import javax.servlet.FilterChain;
import javax.servlet.FilterConfig;
import javax.servlet.ServletException;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import javax.servlet.http.HttpServletResponseWrapper;
import org.springframework.boot.context.embedded.AbstractEmbeddedServletContainerFactory;
import org.springframework.boot.context.embedded.EmbeddedServletContainer;
import org.springframework.boot.context.embedded.EmbeddedServletContainerCustomizer;
import org.springframework.boot.context.embedded.EmbeddedServletContainerException;
import org.springframework.boot.context.embedded.EmbeddedServletContainerFactory;
import org.springframework.boot.context.embedded.ErrorPage;
import org.springframework.boot.context.embedded.ServletContextInitializer;
import org.springframework.core.Ordered;
import org.springframework.core.annotation.Order;
import org.springframework.stereotype.Component;
/**
* A special {@link EmbeddedServletContainerFactory} for non-embedded applications (i.e.
* deployed WAR files). It registers error pages and handles application errors by
* filtering requests and forwarding to the error pages instead of letting the container
* handle them. Error pages are a feature of the servlet spec but there is no Java API for
* registering them in the spec. This filter works around that by accepting error page
* registrations from Spring Boot's {@link EmbeddedServletContainerCustomizer} (any beans
* of that type in the context will be applied to this container).
*
* @author Dave Syer
*
*/
@Component
@Order(Ordered.HIGHEST_PRECEDENCE)
public class ErrorWrapperEmbeddedServletContainerFactory extends
AbstractEmbeddedServletContainerFactory implements Filter {
private String global;
private Map<Integer, String> statuses = new HashMap<Integer, String>();
private Map<Class<? extends Throwable>, String> exceptions = new HashMap<Class<? extends Throwable>, String>();
@Override
public void init(FilterConfig filterConfig) throws ServletException {
}
@Override
public void doFilter(ServletRequest request, ServletResponse response,
FilterChain chain) throws IOException, ServletException {
String errorPath;
ErrorWrapperResponse wrapped = new ErrorWrapperResponse(
(HttpServletResponse) response);
try {
chain.doFilter(request, wrapped);
int status = wrapped.getStatus();
if (status >= 400) {
errorPath = this.statuses.containsKey(status) ? this.statuses.get(status)
: this.global;
if (errorPath != null) {
request.setAttribute("javax.servlet.error.status_code", status);
request.setAttribute("javax.servlet.error.message",
wrapped.getMessage());
((HttpServletRequest) request).getRequestDispatcher(errorPath)
.forward(request, response);
}
else {
((HttpServletResponse) response).sendError(status,
wrapped.getMessage());
}
}
}
catch (Throwable e) {
Class<? extends Throwable> cls = e.getClass();
errorPath = this.exceptions.containsKey(cls) ? this.exceptions.get(cls)
: this.global;
if (errorPath != null) {
request.setAttribute("javax.servlet.error.status_code", 500);
request.setAttribute("javax.servlet.error.exception", e);
request.setAttribute("javax.servlet.error.message", e.getMessage());
wrapped.sendError(500, e.getMessage());
((HttpServletRequest) request).getRequestDispatcher(errorPath).forward(
request, response);
}
else {
rethrow(e);
}
}
}
private void rethrow(Throwable e) throws IOException, ServletException {
if (e instanceof RuntimeException) {
throw (RuntimeException) e;
}
if (e instanceof Error) {
throw (Error) e;
}
if (e instanceof IOException) {
throw (IOException) e;
}
if (e instanceof ServletException) {
throw (ServletException) e;
}
throw new IllegalStateException("Unidentified Exception", e);
}
@Override
public EmbeddedServletContainer getEmbeddedServletContainer(
ServletContextInitializer... initializers) {
return new EmbeddedServletContainer() {
@Override
public void start() throws EmbeddedServletContainerException {
}
@Override
public void stop() throws EmbeddedServletContainerException {
}
@Override
public int getPort() {
return -1;
}
};
}
@Override
public void addErrorPages(ErrorPage... errorPages) {
for (ErrorPage errorPage : errorPages) {
if (errorPage.isGlobal()) {
this.global = errorPage.getPath();
}
else if (errorPage.getStatus() != null) {
this.statuses.put(errorPage.getStatus().value(), errorPage.getPath());
}
else {
this.exceptions.put(errorPage.getException(), errorPage.getPath());
}
}
}
@Override
public void destroy() {
}
private static class ErrorWrapperResponse extends HttpServletResponseWrapper {
private int status;
private String message;
public ErrorWrapperResponse(HttpServletResponse response) {
super(response);
}
@Override
public void sendError(int status) throws IOException {
sendError(status, null);
}
@Override
public void sendError(int status, String message) throws IOException {
this.status = status;
this.message = message;
}
@Override
public int getStatus() {
return this.status;
}
public String getMessage() {
return this.message;
}
}
}
......@@ -84,6 +84,8 @@ public abstract class SpringBootServletInitializer implements WebApplicationInit
servletContext));
application.contextClass(AnnotationConfigEmbeddedWebApplicationContext.class);
application = configure(application);
// Ensure error pages ar registered
application.sources(ErrorWrapperEmbeddedServletContainerFactory.class);
return (WebApplicationContext) application.run();
}
......
/*
* Copyright 2012-2013 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.boot.context.web;
import java.io.IOException;
import javax.servlet.ServletException;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletResponse;
import javax.servlet.http.HttpServletResponseWrapper;
import org.junit.Test;
import org.springframework.boot.context.embedded.ErrorPage;
import org.springframework.http.HttpStatus;
import org.springframework.mock.web.MockFilterChain;
import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockHttpServletResponse;
import static org.junit.Assert.assertEquals;
/**
* @author Dave Syer
*/
public class ErrorWrapperEmbeddedServletContainerFactoryTests {
private ErrorWrapperEmbeddedServletContainerFactory filter = new ErrorWrapperEmbeddedServletContainerFactory();
private MockHttpServletRequest request = new MockHttpServletRequest();
private MockHttpServletResponse response = new MockHttpServletResponse();
private MockFilterChain chain = new MockFilterChain();
@Test
public void notAnError() throws Exception {
this.filter.doFilter(this.request, this.response, this.chain);
assertEquals(this.request, this.chain.getRequest());
assertEquals(this.response,
((HttpServletResponseWrapper) this.chain.getResponse()).getResponse());
}
@Test
public void globalError() throws Exception {
this.filter.addErrorPages(new ErrorPage("/error"));
this.chain = new MockFilterChain() {
@Override
public void doFilter(ServletRequest request, ServletResponse response)
throws IOException, ServletException {
((HttpServletResponse) response).sendError(400, "BAD");
super.doFilter(request, response);
}
};
this.filter.doFilter(this.request, this.response, this.chain);
assertEquals(400,
((HttpServletResponseWrapper) this.chain.getResponse()).getStatus());
assertEquals(400, this.request.getAttribute("javax.servlet.error.status_code"));
assertEquals("BAD", this.request.getAttribute("javax.servlet.error.message"));
}
@Test
public void statusError() throws Exception {
this.filter.addErrorPages(new ErrorPage(HttpStatus.BAD_REQUEST, "/400"));
this.chain = new MockFilterChain() {
@Override
public void doFilter(ServletRequest request, ServletResponse response)
throws IOException, ServletException {
((HttpServletResponse) response).sendError(400, "BAD");
super.doFilter(request, response);
}
};
this.filter.doFilter(this.request, this.response, this.chain);
assertEquals(400,
((HttpServletResponseWrapper) this.chain.getResponse()).getStatus());
assertEquals(400, this.request.getAttribute("javax.servlet.error.status_code"));
assertEquals("BAD", this.request.getAttribute("javax.servlet.error.message"));
}
@Test
public void exceptionError() throws Exception {
this.filter.addErrorPages(new ErrorPage(RuntimeException.class, "/500"));
this.chain = new MockFilterChain() {
@Override
public void doFilter(ServletRequest request, ServletResponse response)
throws IOException, ServletException {
super.doFilter(request, response);
throw new RuntimeException("BAD");
}
};
this.filter.doFilter(this.request, this.response, this.chain);
assertEquals(500,
((HttpServletResponseWrapper) this.chain.getResponse()).getStatus());
assertEquals(500, this.request.getAttribute("javax.servlet.error.status_code"));
assertEquals("BAD", this.request.getAttribute("javax.servlet.error.message"));
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment