diff --git a/spring-test/src/main/java/org/springframework/mock/web/MockAsyncContext.java b/spring-test/src/main/java/org/springframework/mock/web/MockAsyncContext.java index c36c67e9af..ed63a6902d 100644 --- a/spring-test/src/main/java/org/springframework/mock/web/MockAsyncContext.java +++ b/spring-test/src/main/java/org/springframework/mock/web/MockAsyncContext.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2017 the original author or authors. + * Copyright 2002-2018 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -65,7 +65,14 @@ public class MockAsyncContext implements AsyncContext { public void addDispatchHandler(Runnable handler) { Assert.notNull(handler, "Dispatch handler must not be null"); - this.dispatchHandlers.add(handler); + synchronized (this) { + if (this.dispatchedPath == null) { + this.dispatchHandlers.add(handler); + } + else { + handler.run(); + } + } } @Override @@ -96,9 +103,9 @@ public class MockAsyncContext implements AsyncContext { @Override public void dispatch(@Nullable ServletContext context, String path) { - this.dispatchedPath = path; - for (Runnable r : this.dispatchHandlers) { - r.run(); + synchronized (this) { + this.dispatchedPath = path; + this.dispatchHandlers.forEach(Runnable::run); } } diff --git a/spring-test/src/main/java/org/springframework/test/web/servlet/DefaultMvcResult.java b/spring-test/src/main/java/org/springframework/test/web/servlet/DefaultMvcResult.java index 084f70b71d..5e5ec0a4e5 100644 --- a/spring-test/src/main/java/org/springframework/test/web/servlet/DefaultMvcResult.java +++ b/spring-test/src/main/java/org/springframework/test/web/servlet/DefaultMvcResult.java @@ -16,11 +16,14 @@ package org.springframework.test.web.servlet; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicReference; import org.springframework.lang.Nullable; import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletResponse; +import org.springframework.util.Assert; import org.springframework.web.servlet.FlashMap; import org.springframework.web.servlet.HandlerInterceptor; import org.springframework.web.servlet.ModelAndView; @@ -56,6 +59,9 @@ class DefaultMvcResult implements MvcResult { private final AtomicReference asyncResult = new AtomicReference<>(RESULT_NONE); + @Nullable + private CountDownLatch asyncDispatchLatch; + /** * Create a new instance with the given request and response. @@ -135,27 +141,31 @@ class DefaultMvcResult implements MvcResult { if (this.mockRequest.getAsyncContext() != null) { timeToWait = (timeToWait == -1 ? this.mockRequest.getAsyncContext().getTimeout() : timeToWait); } - - if (timeToWait > 0) { - long endTime = System.currentTimeMillis() + timeToWait; - while (System.currentTimeMillis() < endTime && this.asyncResult.get() == RESULT_NONE) { - try { - Thread.sleep(100); - } - catch (InterruptedException ex) { - Thread.currentThread().interrupt(); - throw new IllegalStateException("Interrupted while waiting for " + - "async result to be set for handler [" + this.handler + "]", ex); - } - } + if (!awaitAsyncDispatch(timeToWait)) { + throw new IllegalStateException("Async result for handler [" + this.handler + "]" + + " was not set during the specified timeToWait=" + timeToWait); } - Object result = this.asyncResult.get(); - if (result == RESULT_NONE) { - throw new IllegalStateException("Async result for handler [" + this.handler + "] " + - "was not set during the specified timeToWait=" + timeToWait); + Assert.state(result != RESULT_NONE, "Async result for handler [" + this.handler + "] was not set"); + return this.asyncResult.get(); + } + + /** + * True if is there a latch was not set, or the latch count reached 0. + */ + private boolean awaitAsyncDispatch(long timeout) { + Assert.state(this.asyncDispatchLatch != null, + "The asynDispatch CountDownLatch was not set by the TestDispatcherServlet.\n"); + try { + return this.asyncDispatchLatch.await(timeout, TimeUnit.MILLISECONDS); } - return result; + catch (InterruptedException e) { + return false; + } + } + + void setAsyncDispatchLatch(CountDownLatch asyncDispatchLatch) { + this.asyncDispatchLatch = asyncDispatchLatch; } } diff --git a/spring-test/src/main/java/org/springframework/test/web/servlet/TestDispatcherServlet.java b/spring-test/src/main/java/org/springframework/test/web/servlet/TestDispatcherServlet.java index 6a858c3500..a85592b2b7 100644 --- a/spring-test/src/main/java/org/springframework/test/web/servlet/TestDispatcherServlet.java +++ b/spring-test/src/main/java/org/springframework/test/web/servlet/TestDispatcherServlet.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2014 the original author or authors. + * Copyright 2002-2018 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -18,18 +18,19 @@ package org.springframework.test.web.servlet; import java.io.IOException; import java.util.concurrent.Callable; +import java.util.concurrent.CountDownLatch; import javax.servlet.ServletException; import javax.servlet.ServletRequest; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; import org.springframework.lang.Nullable; +import org.springframework.mock.web.MockAsyncContext; import org.springframework.web.context.WebApplicationContext; import org.springframework.web.context.request.NativeWebRequest; import org.springframework.web.context.request.async.CallableProcessingInterceptor; import org.springframework.web.context.request.async.DeferredResult; import org.springframework.web.context.request.async.DeferredResultProcessingInterceptor; -import org.springframework.web.context.request.async.WebAsyncManager; import org.springframework.web.context.request.async.WebAsyncUtils; import org.springframework.web.servlet.DispatcherServlet; import org.springframework.web.servlet.HandlerExecutionChain; @@ -63,23 +64,34 @@ final class TestDispatcherServlet extends DispatcherServlet { throws ServletException, IOException { registerAsyncResultInterceptors(request); + super.service(request, response); + + if (request.getAsyncContext() != null) { + CountDownLatch dispatchLatch = new CountDownLatch(1); + ((MockAsyncContext) request.getAsyncContext()).addDispatchHandler(dispatchLatch::countDown); + getMvcResult(request).setAsyncDispatchLatch(dispatchLatch); + } } private void registerAsyncResultInterceptors(final HttpServletRequest request) { - WebAsyncManager asyncManager = WebAsyncUtils.getAsyncManager(request); - asyncManager.registerCallableInterceptor(KEY, new CallableProcessingInterceptor() { - @Override - public void postProcess(NativeWebRequest r, Callable task, Object value) throws Exception { - getMvcResult(request).setAsyncResult(value); - } - }); - asyncManager.registerDeferredResultInterceptor(KEY, new DeferredResultProcessingInterceptor() { - @Override - public void postProcess(NativeWebRequest r, DeferredResult result, Object value) throws Exception { - getMvcResult(request).setAsyncResult(value); - } - }); + + WebAsyncUtils.getAsyncManager(request).registerCallableInterceptor(KEY, + new CallableProcessingInterceptor() { + @Override + public void postProcess(NativeWebRequest r, Callable task, Object value) { + // We got the result, must also wait for the dispatch + getMvcResult(request).setAsyncResult(value); + } + }); + + WebAsyncUtils.getAsyncManager(request).registerDeferredResultInterceptor(KEY, + new DeferredResultProcessingInterceptor() { + @Override + public void postProcess(NativeWebRequest r, DeferredResult result, Object value) { + getMvcResult(request).setAsyncResult(value); + } + }); } protected DefaultMvcResult getMvcResult(ServletRequest request) { diff --git a/spring-test/src/test/java/org/springframework/test/web/servlet/DefaultMvcResultTests.java b/spring-test/src/test/java/org/springframework/test/web/servlet/DefaultMvcResultTests.java index cf17c50329..87f75b1479 100644 --- a/spring-test/src/test/java/org/springframework/test/web/servlet/DefaultMvcResultTests.java +++ b/spring-test/src/test/java/org/springframework/test/web/servlet/DefaultMvcResultTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2014 the original author or authors. + * Copyright 2002-2018 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -15,6 +15,8 @@ */ package org.springframework.test.web.servlet; +import java.util.concurrent.CountDownLatch; + import org.junit.Before; import org.junit.Test; @@ -38,13 +40,14 @@ public class DefaultMvcResultTests { } @Test - public void getAsyncResultSuccess() throws Exception { + public void getAsyncResultSuccess() { this.mvcResult.setAsyncResult("Foo"); - assertEquals("Foo", this.mvcResult.getAsyncResult()); + this.mvcResult.setAsyncDispatchLatch(new CountDownLatch(0)); + this.mvcResult.getAsyncResult(); } @Test(expected = IllegalStateException.class) - public void getAsyncResultFailure() throws Exception { + public void getAsyncResultFailure() { this.mvcResult.getAsyncResult(0); } diff --git a/spring-web/src/test/java/org/springframework/mock/web/test/MockAsyncContext.java b/spring-web/src/test/java/org/springframework/mock/web/test/MockAsyncContext.java index f16e1f6b40..c9cdfb39d8 100644 --- a/spring-web/src/test/java/org/springframework/mock/web/test/MockAsyncContext.java +++ b/spring-web/src/test/java/org/springframework/mock/web/test/MockAsyncContext.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2017 the original author or authors. + * Copyright 2002-2018 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -30,6 +30,7 @@ import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; import org.springframework.beans.BeanUtils; +import org.springframework.lang.Nullable; import org.springframework.util.Assert; import org.springframework.web.util.WebUtils; @@ -43,10 +44,12 @@ public class MockAsyncContext implements AsyncContext { private final HttpServletRequest request; + @Nullable private final HttpServletResponse response; private final List listeners = new ArrayList<>(); + @Nullable private String dispatchedPath; private long timeout = 10 * 1000L; // 10 seconds is Tomcat's default @@ -54,7 +57,7 @@ public class MockAsyncContext implements AsyncContext { private final List dispatchHandlers = new ArrayList<>(); - public MockAsyncContext(ServletRequest request, ServletResponse response) { + public MockAsyncContext(ServletRequest request, @Nullable ServletResponse response) { this.request = (HttpServletRequest) request; this.response = (HttpServletResponse) response; } @@ -62,7 +65,14 @@ public class MockAsyncContext implements AsyncContext { public void addDispatchHandler(Runnable handler) { Assert.notNull(handler, "Dispatch handler must not be null"); - this.dispatchHandlers.add(handler); + synchronized (this) { + if (this.dispatchedPath == null) { + this.dispatchHandlers.add(handler); + } + else { + handler.run(); + } + } } @Override @@ -71,6 +81,7 @@ public class MockAsyncContext implements AsyncContext { } @Override + @Nullable public ServletResponse getResponse() { return this.response; } @@ -91,13 +102,14 @@ public class MockAsyncContext implements AsyncContext { } @Override - public void dispatch(ServletContext context, String path) { - this.dispatchedPath = path; - for (Runnable r : this.dispatchHandlers) { - r.run(); + public void dispatch(@Nullable ServletContext context, String path) { + synchronized (this) { + this.dispatchedPath = path; + this.dispatchHandlers.forEach(Runnable::run); } } + @Nullable public String getDispatchedPath() { return this.dispatchedPath; }