Commit 13c01376 authored by Andy Wilkinson's avatar Andy Wilkinson

Merge branch '1.3.x'

parents b9d7a396 92100291
...@@ -29,6 +29,7 @@ import javax.servlet.FilterChain; ...@@ -29,6 +29,7 @@ import javax.servlet.FilterChain;
import javax.servlet.ServletException; import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse; import javax.servlet.http.HttpServletResponse;
import javax.servlet.http.HttpServletResponseWrapper;
import javax.servlet.http.HttpSession; import javax.servlet.http.HttpSession;
import org.apache.commons.logging.Log; import org.apache.commons.logging.Log;
...@@ -37,6 +38,7 @@ import org.apache.commons.logging.LogFactory; ...@@ -37,6 +38,7 @@ import org.apache.commons.logging.LogFactory;
import org.springframework.boot.actuate.trace.TraceProperties.Include; import org.springframework.boot.actuate.trace.TraceProperties.Include;
import org.springframework.boot.autoconfigure.web.ErrorAttributes; import org.springframework.boot.autoconfigure.web.ErrorAttributes;
import org.springframework.core.Ordered; import org.springframework.core.Ordered;
import org.springframework.http.HttpStatus;
import org.springframework.web.context.request.ServletRequestAttributes; import org.springframework.web.context.request.ServletRequestAttributes;
import org.springframework.web.filter.OncePerRequestFilter; import org.springframework.web.filter.OncePerRequestFilter;
...@@ -97,11 +99,14 @@ public class WebRequestTraceFilter extends OncePerRequestFilter implements Order ...@@ -97,11 +99,14 @@ public class WebRequestTraceFilter extends OncePerRequestFilter implements Order
throws ServletException, IOException { throws ServletException, IOException {
Map<String, Object> trace = getTrace(request); Map<String, Object> trace = getTrace(request);
logTrace(request, trace); logTrace(request, trace);
int status = HttpStatus.INTERNAL_SERVER_ERROR.value();
try { try {
filterChain.doFilter(request, response); filterChain.doFilter(request, response);
status = response.getStatus();
} }
finally { finally {
enhanceTrace(trace, response); enhanceTrace(trace, status == response.getStatus() ? response
: new CustomStatusResponseWrapper(response, status));
this.repository.add(trace); this.repository.add(trace);
} }
} }
...@@ -203,4 +208,21 @@ public class WebRequestTraceFilter extends OncePerRequestFilter implements Order ...@@ -203,4 +208,21 @@ public class WebRequestTraceFilter extends OncePerRequestFilter implements Order
this.errorAttributes = errorAttributes; this.errorAttributes = errorAttributes;
} }
private static final class CustomStatusResponseWrapper
extends HttpServletResponseWrapper {
private final int status;
private CustomStatusResponseWrapper(HttpServletResponse response, int status) {
super(response);
this.status = status;
}
@Override
public int getStatus() {
return this.status;
}
}
} }
...@@ -38,6 +38,7 @@ import org.springframework.mock.web.MockHttpServletRequest; ...@@ -38,6 +38,7 @@ import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockHttpServletResponse; import org.springframework.mock.web.MockHttpServletResponse;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.Assert.fail;
import static org.mockito.Mockito.spy; import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.times; import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
...@@ -183,4 +184,32 @@ public class WebRequestTraceFilterTests { ...@@ -183,4 +184,32 @@ public class WebRequestTraceFilterTests {
assertThat(map.get("message").toString()).isEqualTo("Foo"); assertThat(map.get("message").toString()).isEqualTo("Foo");
} }
@Test
@SuppressWarnings("unchecked")
public void filterHas500ResponseStatusWhenExceptionIsThrown()
throws ServletException, IOException {
MockHttpServletRequest request = new MockHttpServletRequest("GET", "/foo");
MockHttpServletResponse response = new MockHttpServletResponse();
try {
this.filter.doFilterInternal(request, response, new FilterChain() {
@Override
public void doFilter(ServletRequest request, ServletResponse response)
throws IOException, ServletException {
throw new RuntimeException();
}
});
fail("Exception was swallowed");
}
catch (RuntimeException ex) {
Map<String, Object> headers = (Map<String, Object>) this.repository.findAll()
.iterator().next().getInfo().get("headers");
Map<String, Object> responseHeaders = (Map<String, Object>) headers
.get("response");
assertThat((String) responseHeaders.get("status")).isEqualTo("500");
}
}
} }
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment