diff --git a/org.springframework.integration/src/main/java/org/springframework/integration/channel/ExecutorChannel.java b/org.springframework.integration/src/main/java/org/springframework/integration/channel/ExecutorChannel.java index 18758284e9..0c8ed430b5 100644 --- a/org.springframework.integration/src/main/java/org/springframework/integration/channel/ExecutorChannel.java +++ b/org.springframework.integration/src/main/java/org/springframework/integration/channel/ExecutorChannel.java @@ -16,10 +16,15 @@ package org.springframework.integration.channel; +import org.springframework.beans.factory.BeanFactory; +import org.springframework.beans.factory.BeanFactoryAware; import org.springframework.core.task.TaskExecutor; import org.springframework.integration.core.MessageChannel; import org.springframework.integration.dispatcher.LoadBalancingStrategy; import org.springframework.integration.dispatcher.UnicastingDispatcher; +import org.springframework.integration.util.ErrorHandler; +import org.springframework.integration.util.ErrorHandlingTaskExecutor; +import org.springframework.util.Assert; /** * An implementation of {@link MessageChannel} that delegates to an instance of @@ -36,26 +41,41 @@ import org.springframework.integration.dispatcher.UnicastingDispatcher; * @author Mark Fisher * @since 1.0.3 */ -public class ExecutorChannel extends AbstractSubscribableChannel { +public class ExecutorChannel extends AbstractSubscribableChannel implements BeanFactoryAware { - private final UnicastingDispatcher dispatcher; + private volatile UnicastingDispatcher dispatcher; + + private volatile TaskExecutor taskExecutor; + + private volatile boolean failover = true; + + private volatile LoadBalancingStrategy loadBalancingStrategy; /** * Create an ExecutorChannel that delegates to the provided * {@link TaskExecutor} when dispatching Messages. + *

+ * The TaskExecutor must not be null. */ public ExecutorChannel(TaskExecutor taskExecutor) { - this.dispatcher = new UnicastingDispatcher(taskExecutor); + this(taskExecutor, null); } /** - * Create an ExecutorChannel with a {@link LoadBalancingStrategy}. The - * strategy must not be null. + * Create an ExecutorChannel with a {@link LoadBalancingStrategy} that + * delegates to the provided {@link TaskExecutor} when dispatching Messages. + *

+ * The TaskExecutor must not be null. */ public ExecutorChannel(TaskExecutor taskExecutor, LoadBalancingStrategy loadBalancingStrategy) { - this(taskExecutor); - this.dispatcher.setLoadBalancingStrategy(loadBalancingStrategy); + Assert.notNull(taskExecutor, "taskExecutor must not be null"); + this.taskExecutor = taskExecutor; + this.dispatcher = new UnicastingDispatcher(taskExecutor); + if (loadBalancingStrategy != null) { + this.loadBalancingStrategy = loadBalancingStrategy; + this.dispatcher.setLoadBalancingStrategy(loadBalancingStrategy); + } } @@ -64,6 +84,7 @@ public class ExecutorChannel extends AbstractSubscribableChannel { * By default, it will. Set this value to 'false' to disable it. */ public void setFailover(boolean failover) { + this.failover = failover; this.dispatcher.setFailover(failover); } @@ -72,4 +93,16 @@ public class ExecutorChannel extends AbstractSubscribableChannel { return this.dispatcher; } + public void setBeanFactory(BeanFactory beanFactory) { + if (!(this.taskExecutor instanceof ErrorHandlingTaskExecutor)) { + ErrorHandler errorHandler = new MessagePublishingErrorHandler(new BeanFactoryChannelResolver(beanFactory)); + this.taskExecutor = new ErrorHandlingTaskExecutor(this.taskExecutor, errorHandler); + } + this.dispatcher = new UnicastingDispatcher(this.taskExecutor); + this.dispatcher.setFailover(this.failover); + if (this.loadBalancingStrategy != null) { + this.dispatcher.setLoadBalancingStrategy(this.loadBalancingStrategy); + } + } + } diff --git a/org.springframework.integration/src/test/java/org/springframework/integration/channel/DispatchingChannelErrorHandlingTests.java b/org.springframework.integration/src/test/java/org/springframework/integration/channel/DispatchingChannelErrorHandlingTests.java new file mode 100644 index 0000000000..84d7c68772 --- /dev/null +++ b/org.springframework.integration/src/test/java/org/springframework/integration/channel/DispatchingChannelErrorHandlingTests.java @@ -0,0 +1,155 @@ +/* + * Copyright 2002-2009 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.integration.channel; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotSame; +import static org.junit.Assert.assertSame; + +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; + +import org.junit.Test; + +import org.springframework.context.support.StaticApplicationContext; +import org.springframework.core.task.SimpleAsyncTaskExecutor; +import org.springframework.core.task.TaskExecutor; +import org.springframework.integration.channel.DirectChannel; +import org.springframework.integration.context.IntegrationContextUtils; +import org.springframework.integration.core.Message; +import org.springframework.integration.core.MessagingException; +import org.springframework.integration.message.MessageBuilder; +import org.springframework.integration.message.MessageHandler; + +/** + * @author Mark Fisher + * @since 1.0.3 + */ +public class DispatchingChannelErrorHandlingTests { + + private final CountDownLatch latch = new CountDownLatch(1); + + + @Test(expected = UnsupportedOperationException.class) + public void handlerThrowsExceptionPublishSubscribeWithoutExecutor() { + PublishSubscribeChannel channel = new PublishSubscribeChannel(); + channel.subscribe(new MessageHandler() { + public void handleMessage(Message message) { + throw new UnsupportedOperationException("intentional test failure"); + } + }); + Message message = MessageBuilder.withPayload("test").build(); + channel.send(message); + } + + @Test + public void handlerThrowsExceptionPublishSubscribeWithExecutor() { + StaticApplicationContext context = new StaticApplicationContext(); + context.registerSingleton( + IntegrationContextUtils.ERROR_CHANNEL_BEAN_NAME, DirectChannel.class); + context.refresh(); + DirectChannel defaultErrorChannel = (DirectChannel) context.getBean( + IntegrationContextUtils.ERROR_CHANNEL_BEAN_NAME); + TaskExecutor executor = new SimpleAsyncTaskExecutor(); + PublishSubscribeChannel channel = new PublishSubscribeChannel(executor); + channel.setBeanFactory(context); + ResultHandler resultHandler = new ResultHandler(); + defaultErrorChannel.subscribe(resultHandler); + channel.subscribe(new MessageHandler() { + public void handleMessage(Message message) { + throw new MessagingException(message, + new UnsupportedOperationException("intentional test failure")); + } + }); + Message message = MessageBuilder.withPayload("test").build(); + channel.send(message); + this.waitForLatch(1000); + Message errorMessage = resultHandler.lastMessage; + assertEquals(MessagingException.class, errorMessage.getPayload().getClass()); + MessagingException exceptionPayload = (MessagingException) errorMessage.getPayload(); + assertEquals(UnsupportedOperationException.class, exceptionPayload.getCause().getClass()); + assertSame(message, exceptionPayload.getFailedMessage()); + assertNotSame(Thread.currentThread(), resultHandler.lastThread); + } + + @Test + public void handlerThrowsExceptionExecutorChannel() { + StaticApplicationContext context = new StaticApplicationContext(); + context.registerSingleton( + IntegrationContextUtils.ERROR_CHANNEL_BEAN_NAME, DirectChannel.class); + context.refresh(); + DirectChannel defaultErrorChannel = (DirectChannel) context.getBean( + IntegrationContextUtils.ERROR_CHANNEL_BEAN_NAME); + TaskExecutor executor = new SimpleAsyncTaskExecutor(); + ExecutorChannel channel = new ExecutorChannel(executor); + channel.setBeanFactory(context); + ResultHandler resultHandler = new ResultHandler(); + defaultErrorChannel.subscribe(resultHandler); + channel.subscribe(new MessageHandler() { + public void handleMessage(Message message) { + throw new MessagingException(message, + new UnsupportedOperationException("intentional test failure")); + } + }); + Message message = MessageBuilder.withPayload("test").build(); + channel.send(message); + this.waitForLatch(1000); + Message errorMessage = resultHandler.lastMessage; + assertEquals(MessagingException.class, errorMessage.getPayload().getClass()); + MessagingException exceptionPayload = (MessagingException) errorMessage.getPayload(); + assertEquals(UnsupportedOperationException.class, exceptionPayload.getCause().getClass()); + assertSame(message, exceptionPayload.getFailedMessage()); + assertNotSame(Thread.currentThread(), resultHandler.lastThread); + } + + + private void waitForLatch(long timeout) { + try { + this.latch.await(timeout, TimeUnit.MILLISECONDS); + if (latch.getCount() != 0) { + throw new TestTimedOutException(); + } + } + catch (InterruptedException e) { + throw new RuntimeException("interrupted while waiting for latch"); + } + } + + + private class ResultHandler implements MessageHandler { + + private volatile Message lastMessage; + + private volatile Thread lastThread; + + public void handleMessage(Message message) { + this.lastMessage = message; + this.lastThread = Thread.currentThread(); + latch.countDown(); + } + } + + + @SuppressWarnings("serial") + private static class TestTimedOutException extends RuntimeException { + + public TestTimedOutException() { + super("timed out while waiting for latch"); + } + } + +} diff --git a/org.springframework.integration/src/test/java/org/springframework/integration/channel/config/DispatchingChannelParserTests.java b/org.springframework.integration/src/test/java/org/springframework/integration/channel/config/DispatchingChannelParserTests.java index 809deb09cc..419aa094e8 100644 --- a/org.springframework.integration/src/test/java/org/springframework/integration/channel/config/DispatchingChannelParserTests.java +++ b/org.springframework.integration/src/test/java/org/springframework/integration/channel/config/DispatchingChannelParserTests.java @@ -36,6 +36,7 @@ import org.springframework.integration.channel.DirectChannel; import org.springframework.integration.channel.ExecutorChannel; import org.springframework.integration.core.MessageChannel; import org.springframework.integration.dispatcher.RoundRobinLoadBalancingStrategy; +import org.springframework.integration.util.ErrorHandlingTaskExecutor; import org.springframework.test.context.ContextConfiguration; import org.springframework.test.context.junit4.SpringJUnit4ClassRunner; @@ -71,7 +72,10 @@ public class DispatchingChannelParserTests { public void taskExecutorOnly() { MessageChannel channel = channels.get("taskExecutorOnly"); assertEquals(ExecutorChannel.class, channel.getClass()); - assertSame(context.getBean("taskExecutor"), getDispatcherProperty("taskExecutor", channel)); + Object executor = getDispatcherProperty("taskExecutor", channel); + assertEquals(ErrorHandlingTaskExecutor.class, executor.getClass()); + assertSame(context.getBean("taskExecutor"), + new DirectFieldAccessor(executor).getPropertyValue("taskExecutor")); assertTrue((Boolean) getDispatcherProperty("failover", channel)); assertEquals(RoundRobinLoadBalancingStrategy.class, getDispatcherProperty("loadBalancingStrategy", channel).getClass()); @@ -109,7 +113,10 @@ public class DispatchingChannelParserTests { assertEquals(ExecutorChannel.class, channel.getClass()); assertTrue((Boolean) getDispatcherProperty("failover", channel)); assertNull(getDispatcherProperty("loadBalancingStrategy", channel)); - assertSame(context.getBean("taskExecutor"), getDispatcherProperty("taskExecutor", channel)); + Object executor = getDispatcherProperty("taskExecutor", channel); + assertEquals(ErrorHandlingTaskExecutor.class, executor.getClass()); + assertSame(context.getBean("taskExecutor"), + new DirectFieldAccessor(executor).getPropertyValue("taskExecutor")); } @Test @@ -119,7 +126,10 @@ public class DispatchingChannelParserTests { assertTrue((Boolean) getDispatcherProperty("failover", channel)); assertEquals(RoundRobinLoadBalancingStrategy.class, getDispatcherProperty("loadBalancingStrategy", channel).getClass()); - assertSame(context.getBean("taskExecutor"), getDispatcherProperty("taskExecutor", channel)); + Object executor = getDispatcherProperty("taskExecutor", channel); + assertEquals(ErrorHandlingTaskExecutor.class, executor.getClass()); + assertSame(context.getBean("taskExecutor"), + new DirectFieldAccessor(executor).getPropertyValue("taskExecutor")); }