diff --git a/org.springframework.integration/src/test/java/org/springframework/integration/handler/MessageHandlerChainTests.java b/org.springframework.integration/src/test/java/org/springframework/integration/handler/MessageHandlerChainTests.java index b0ff9753a3..a569d6fb98 100644 --- a/org.springframework.integration/src/test/java/org/springframework/integration/handler/MessageHandlerChainTests.java +++ b/org.springframework.integration/src/test/java/org/springframework/integration/handler/MessageHandlerChainTests.java @@ -16,8 +16,7 @@ package org.springframework.integration.handler; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotNull; +import static org.easymock.EasyMock.*; import java.util.ArrayList; import java.util.List; @@ -25,98 +24,144 @@ import java.util.List; import org.junit.Test; import org.springframework.beans.factory.support.DefaultListableBeanFactory; -import org.springframework.integration.channel.QueueChannel; import org.springframework.integration.core.Message; +import org.springframework.integration.core.MessageChannel; import org.springframework.integration.message.MessageBuilder; import org.springframework.integration.message.MessageHandler; -import org.springframework.integration.message.StringMessage; /** * @author Mark Fisher + * @author Iwein Fuld */ public class MessageHandlerChainTests { + private MessageChannel outputChannel = createMock(MessageChannel.class); + + private Message message = MessageBuilder.withPayload("foo").build(); + + private MessageHandler handler = createMock(MessageHandler.class); + + private ProducingHandlerStub producer = new ProducingHandlerStub(handler); + + private Object[] allMocks = new Object[] { outputChannel, handler }; + @Test public void chainWithOutputChannel() { - QueueChannel outputChannel = new QueueChannel(); + handler.handleMessage(message); + expectLastCall().times(3); + expect(outputChannel.send(eq(message), anyLong())).andReturn(true); + replay(allMocks); List handlers = new ArrayList(); - handlers.add(createHandler(1)); - handlers.add(createHandler(2)); - handlers.add(createHandler(3)); + handlers.add(producer); + handlers.add(producer); + handlers.add(producer); MessageHandlerChain chain = new MessageHandlerChain(); chain.setBeanName("testChain"); chain.setHandlers(handlers); chain.setOutputChannel(outputChannel); - chain.handleMessage(new StringMessage("test")); - Message reply = outputChannel.receive(0); - assertNotNull(reply); - assertEquals("test123", reply.getPayload()); + chain.handleMessage(message); } @Test(expected = IllegalArgumentException.class) public void chainWithOutputChannelButLastHandlerDoesNotProduceReplies() { - QueueChannel outputChannel = new QueueChannel(); + replay(allMocks); List handlers = new ArrayList(); - handlers.add(createHandler(1)); - handlers.add(createHandler(2)); - handlers.add(new MessageHandler() { - public void handleMessage(Message message) { - } - }); + handlers.add(producer); + handlers.add(producer); + handlers.add(handler); MessageHandlerChain chain = new MessageHandlerChain(); chain.setBeanName("testChain"); chain.setHandlers(handlers); chain.setOutputChannel(outputChannel); - chain.handleMessage(new StringMessage("test")); + chain.handleMessage(message); + } + @Test + public void chainWithoutOutputChannelButLastHandlerDoesNotProduceReplies() { + handler.handleMessage(message); + expectLastCall().times(3); + replay(allMocks); + List handlers = new ArrayList(); + handlers.add(producer); + handlers.add(producer); + handlers.add(handler); + MessageHandlerChain chain = new MessageHandlerChain(); + chain.setBeanName("testChain"); + chain.setHandlers(handlers); + chain.handleMessage(message); } @Test public void chainForwardsToReplyChannel() { - QueueChannel replyChannel = new QueueChannel(); + Message message = MessageBuilder.withPayload("test").setReplyChannel(outputChannel).build(); + handler.handleMessage(message); + expectLastCall().times(3); + //equality is lost when recreating the message + expect(outputChannel.send(isA(Message.class), anyLong())).andReturn(true); + replay(allMocks); List handlers = new ArrayList(); - handlers.add(createHandler(1)); - handlers.add(createHandler(2)); - handlers.add(createHandler(3)); + handlers.add(producer); + handlers.add(producer); + handlers.add(producer); MessageHandlerChain chain = new MessageHandlerChain(); chain.setBeanName("testChain"); chain.setHandlers(handlers); - Message message = MessageBuilder.withPayload("test") - .setReplyChannel(replyChannel).build(); chain.handleMessage(message); - Message reply = replyChannel.receive(0); - assertNotNull(reply); - assertEquals("test123", reply.getPayload()); } @Test public void chainResolvesReplyChannelName() { - QueueChannel replyChannel = new QueueChannel(); + Message message = MessageBuilder.withPayload("test").setReplyChannelName("testChannel").build(); + handler.handleMessage(message); + expectLastCall().times(3); + //equality is lost when recreating the message + expect(outputChannel.send(isA(Message.class), anyLong())).andReturn(true); + replay(allMocks); DefaultListableBeanFactory beanFactory = new DefaultListableBeanFactory(); - beanFactory.registerSingleton("testChannel", replyChannel); + beanFactory.registerSingleton("testChannel", outputChannel); List handlers = new ArrayList(); - handlers.add(createHandler(1)); - handlers.add(createHandler(2)); - handlers.add(createHandler(3)); + handlers.add(producer); + handlers.add(producer); + handlers.add(producer); MessageHandlerChain chain = new MessageHandlerChain(); chain.setBeanName("testChain"); chain.setHandlers(handlers); chain.setBeanFactory(beanFactory); - Message message = MessageBuilder.withPayload("test") - .setReplyChannelName("testChannel").build(); chain.handleMessage(message); - Message reply = replyChannel.receive(0); - assertNotNull(reply); - assertEquals("test123", reply.getPayload()); } + private class ProducingHandlerStub extends AbstractReplyProducingMessageHandler { + + private final MessageHandler messageHandler; + + public ProducingHandlerStub(MessageHandler handler) { + messageHandler = handler; + } + + @Override + protected void handleRequestMessage(Message requestMessage, ReplyMessageHolder replyMessageHolder) { + messageHandler.handleMessage(requestMessage); + replyMessageHolder.add(requestMessage); + } - private static MessageHandler createHandler(final int index) { - return new AbstractReplyProducingMessageHandler() { - @Override - protected void handleRequestMessage(Message requestMessage, ReplyMessageHolder replyMessageHolder) { - replyMessageHolder.set(requestMessage.getPayload().toString() + index); - } - }; } +// private class ProducingHandlerStub implements MessageHandler { +// private MessageChannel output; +// +// private final MessageHandler messageHandler; +// +// public ProducingHandlerStub(MessageHandler handler) { +// messageHandler = handler; +// } +// +// void setOutputChannel(MessageChannel channel) { +// this.output = channel; +// +// } +// +// public void handleMessage(Message message) { +// messageHandler.handleMessage(message); +// output.send(message); +// } +// } }