diff --git a/spring-web/src/main/java/org/springframework/web/context/request/async/CallableInterceptorChain.java b/spring-web/src/main/java/org/springframework/web/context/request/async/CallableInterceptorChain.java index b0984fbee6..19bfc0b37c 100644 --- a/spring-web/src/main/java/org/springframework/web/context/request/async/CallableInterceptorChain.java +++ b/spring-web/src/main/java/org/springframework/web/context/request/async/CallableInterceptorChain.java @@ -18,6 +18,7 @@ package org.springframework.web.context.request.async; import java.util.List; import java.util.concurrent.Callable; +import java.util.concurrent.Future; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; @@ -39,11 +40,19 @@ class CallableInterceptorChain { private int preProcessIndex = -1; + private volatile Future taskFuture; + public CallableInterceptorChain(List interceptors) { this.interceptors = interceptors; } + + public void setTaskFuture(Future taskFuture) { + this.taskFuture = taskFuture; + } + + public void applyBeforeConcurrentHandling(NativeWebRequest request, Callable task) throws Exception { for (CallableProcessingInterceptor interceptor : this.interceptors) { interceptor.beforeConcurrentHandling(request, task); @@ -77,6 +86,7 @@ class CallableInterceptorChain { } public Object triggerAfterTimeout(NativeWebRequest request, Callable task) { + cancelTask(); for (CallableProcessingInterceptor interceptor : this.interceptors) { try { Object result = interceptor.handleTimeout(request, task); @@ -94,6 +104,18 @@ class CallableInterceptorChain { return CallableProcessingInterceptor.RESULT_NONE; } + private void cancelTask() { + Future future = this.taskFuture; + if (future != null) { + try { + future.cancel(true); + } + catch (Throwable ex) { + // Ignore + } + } + } + public void triggerAfterCompletion(NativeWebRequest request, Callable task) { for (int i = this.interceptors.size()-1; i >= 0; i--) { try { diff --git a/spring-web/src/main/java/org/springframework/web/context/request/async/WebAsyncManager.java b/spring-web/src/main/java/org/springframework/web/context/request/async/WebAsyncManager.java index 83913f02b7..597f987180 100644 --- a/spring-web/src/main/java/org/springframework/web/context/request/async/WebAsyncManager.java +++ b/spring-web/src/main/java/org/springframework/web/context/request/async/WebAsyncManager.java @@ -21,6 +21,7 @@ import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import java.util.concurrent.Callable; +import java.util.concurrent.Future; import java.util.concurrent.RejectedExecutionException; import javax.servlet.http.HttpServletRequest; @@ -307,7 +308,7 @@ public final class WebAsyncManager { interceptorChain.applyBeforeConcurrentHandling(this.asyncWebRequest, callable); startAsyncProcessing(processingContext); try { - this.taskExecutor.submit(new Runnable() { + Future future = this.taskExecutor.submit(new Runnable() { @Override public void run() { Object result = null; @@ -324,6 +325,7 @@ public final class WebAsyncManager { setConcurrentResultAndDispatch(result); } }); + interceptorChain.setTaskFuture(future); } catch (RejectedExecutionException ex) { Object result = interceptorChain.applyPostProcess(this.asyncWebRequest, callable, ex); diff --git a/spring-web/src/test/java/org/springframework/web/context/request/async/WebAsyncManagerTimeoutTests.java b/spring-web/src/test/java/org/springframework/web/context/request/async/WebAsyncManagerTimeoutTests.java index d8ecfe834f..268f5fce44 100644 --- a/spring-web/src/test/java/org/springframework/web/context/request/async/WebAsyncManagerTimeoutTests.java +++ b/spring-web/src/test/java/org/springframework/web/context/request/async/WebAsyncManagerTimeoutTests.java @@ -17,6 +17,7 @@ package org.springframework.web.context.request.async; import java.util.concurrent.Callable; +import java.util.concurrent.Future; import javax.servlet.AsyncEvent; import org.junit.Before; @@ -28,9 +29,15 @@ import org.springframework.mock.web.test.MockHttpServletRequest; import org.springframework.mock.web.test.MockHttpServletResponse; import org.springframework.web.context.request.NativeWebRequest; -import static org.junit.Assert.*; -import static org.mockito.BDDMockito.*; -import static org.springframework.web.context.request.async.CallableProcessingInterceptor.*; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; +import static org.mockito.BDDMockito.any; +import static org.mockito.BDDMockito.given; +import static org.mockito.BDDMockito.mock; +import static org.mockito.BDDMockito.verify; +import static org.mockito.BDDMockito.verifyNoMoreInteractions; +import static org.mockito.BDDMockito.when; +import static org.springframework.web.context.request.async.CallableProcessingInterceptor.RESULT_NONE; /** * {@link WebAsyncManager} tests where container-triggered timeout/completion @@ -148,6 +155,27 @@ public class WebAsyncManagerTimeoutTests { verify(interceptor).beforeConcurrentHandling(this.asyncWebRequest, callable); } + @SuppressWarnings("unchecked") + @Test + public void startCallableProcessingTimeoutAndCheckThreadInterrupted() throws Exception { + + StubCallable callable = new StubCallable(); + Future future = mock(Future.class); + + AsyncTaskExecutor executor = mock(AsyncTaskExecutor.class); + when(executor.submit(any(Runnable.class))).thenReturn(future); + + this.asyncManager.setTaskExecutor(executor); + this.asyncManager.startCallableProcessing(callable); + + this.asyncWebRequest.onTimeout(ASYNC_EVENT); + + assertTrue(this.asyncManager.hasConcurrentResult()); + + verify(future).cancel(true); + verifyNoMoreInteractions(future); + } + @Test public void startDeferredResultProcessingTimeoutAndComplete() throws Exception {