Commit 92100291 authored by Andy Wilkinson's avatar Andy Wilkinson

Record trace with response status of 500 following unhandled exception

Previously, if the filter chain threw an unhandled exception,
WebRequestTraceFilter would record a trace with a response status of
200. This occurred because response.getStatus() would return 200 as
the container had not yet caught the exception and mapped it to an
error response.

This commit updates WebRequestTraceFilter to align its behaviour with
MetricsFilter. It now assumes that the response status will be a 500
and only updates that to the status of the response if the call to the
filter chain returns successfully.

To avoid making a breaking change to the signature of the protected
enhanceTrace method, an HttpServletResponseWrapper is used to include
the correct status in the trace.

Closes gh-5331
parent 2e540780
......@@ -29,6 +29,7 @@ import javax.servlet.FilterChain;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import javax.servlet.http.HttpServletResponseWrapper;
import javax.servlet.http.HttpSession;
import org.apache.commons.logging.Log;
......@@ -37,6 +38,7 @@ import org.apache.commons.logging.LogFactory;
import org.springframework.boot.actuate.trace.TraceProperties.Include;
import org.springframework.boot.autoconfigure.web.ErrorAttributes;
import org.springframework.core.Ordered;
import org.springframework.http.HttpStatus;
import org.springframework.web.context.request.ServletRequestAttributes;
import org.springframework.web.filter.OncePerRequestFilter;
......@@ -108,11 +110,14 @@ public class WebRequestTraceFilter extends OncePerRequestFilter implements Order
throws ServletException, IOException {
Map<String, Object> trace = getTrace(request);
logTrace(request, trace);
int status = HttpStatus.INTERNAL_SERVER_ERROR.value();
try {
filterChain.doFilter(request, response);
status = response.getStatus();
}
finally {
enhanceTrace(trace, response);
enhanceTrace(trace, status == response.getStatus() ? response
: new CustomStatusResponseWrapper(response, status));
this.repository.add(trace);
}
}
......@@ -214,4 +219,21 @@ public class WebRequestTraceFilter extends OncePerRequestFilter implements Order
this.errorAttributes = errorAttributes;
}
private static final class CustomStatusResponseWrapper
extends HttpServletResponseWrapper {
private final int status;
private CustomStatusResponseWrapper(HttpServletResponse response, int status) {
super(response);
this.status = status;
}
@Override
public int getStatus() {
return this.status;
}
}
}
......@@ -37,8 +37,12 @@ import org.springframework.boot.autoconfigure.web.DefaultErrorAttributes;
import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockHttpServletResponse;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.is;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertThat;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
......@@ -183,4 +187,32 @@ public class WebRequestTraceFilterTests {
assertEquals("Foo", map.get("message").toString());
}
@Test
@SuppressWarnings("unchecked")
public void filterHas500ResponseStatusWhenExceptionIsThrown()
throws ServletException, IOException {
MockHttpServletRequest request = new MockHttpServletRequest("GET", "/foo");
MockHttpServletResponse response = new MockHttpServletResponse();
try {
this.filter.doFilterInternal(request, response, new FilterChain() {
@Override
public void doFilter(ServletRequest request, ServletResponse response)
throws IOException, ServletException {
throw new RuntimeException();
}
});
fail("Exception was swallowed");
}
catch (RuntimeException ex) {
Map<String, Object> headers = (Map<String, Object>) this.repository.findAll()
.iterator().next().getInfo().get("headers");
Map<String, Object> responseHeaders = (Map<String, Object>) headers
.get("response");
assertThat((String) responseHeaders.get("status"), is(equalTo("500")));
}
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment