diff --git a/org.springframework.integration/src/main/java/org/springframework/integration/channel/MessageChannelTemplate.java b/org.springframework.integration/src/main/java/org/springframework/integration/channel/MessageChannelTemplate.java index dd0bbe797d..fe83399f46 100644 --- a/org.springframework.integration/src/main/java/org/springframework/integration/channel/MessageChannelTemplate.java +++ b/org.springframework.integration/src/main/java/org/springframework/integration/channel/MessageChannelTemplate.java @@ -49,6 +49,8 @@ public class MessageChannelTemplate implements InitializingBean { protected final Log logger = LogFactory.getLog(this.getClass()); + private volatile MessageChannel defaultChannel; + private volatile long sendTimeout = -1; private volatile long receiveTimeout = -1; @@ -70,6 +72,29 @@ public class MessageChannelTemplate implements InitializingBean { private final Object initializationMonitor = new Object(); + /** + * Create a MessageChannelTemplate with no default channel. Note, that one + * may be provided by invoking {@link #setDefaultChannel(MessageChannel)}. + */ + public MessageChannelTemplate() { + } + + /** + * Create a MessageChannelTemplate with the given default channel. + */ + public MessageChannelTemplate(MessageChannel defaultChannel) { + this.defaultChannel = defaultChannel; + } + + + /** + * Specify the default MessageChannel to use when invoking the send and/or + * receive methods that do not expect a channel parameter. + */ + public void setDefaultChannel(MessageChannel defaultChannel) { + this.defaultChannel = defaultChannel; + } + /** * Specify the timeout value to use for send operations. * @@ -137,6 +162,10 @@ public class MessageChannelTemplate implements InitializingBean { } } + public boolean send(final Message message) { + return this.send(message, this.getRequiredDefaultChannel()); + } + public boolean send(final Message message, final MessageChannel channel) { TransactionTemplate txTemplate = this.getTransactionTemplate(); if (txTemplate != null) { @@ -149,6 +178,13 @@ public class MessageChannelTemplate implements InitializingBean { return this.doSend(message, channel); } + public Message receive() { + MessageChannel channel = this.getRequiredDefaultChannel(); + Assert.state(channel instanceof PollableChannel, + "The 'defaultChannel' must be a PollableChannel for receive operations."); + return this.receive((PollableChannel) channel); + } + public Message receive(final PollableChannel channel) { TransactionTemplate txTemplate = this.getTransactionTemplate(); if (txTemplate != null) { @@ -161,6 +197,10 @@ public class MessageChannelTemplate implements InitializingBean { return this.doReceive(channel); } + public Message sendAndReceive(final Message request) { + return this.sendAndReceive(request, this.getRequiredDefaultChannel()); + } + public Message sendAndReceive(final Message request, final MessageChannel channel) { TransactionTemplate txTemplate = this.getTransactionTemplate(); if (txTemplate != null) { @@ -206,6 +246,13 @@ public class MessageChannelTemplate implements InitializingBean { return this.doReceive(returnAddress); } + private MessageChannel getRequiredDefaultChannel() { + Assert.state(this.defaultChannel != null, + "No 'defaultChannel' specified for MessageChannelTemplate. " + + "Unable to invoke methods without a channel argument."); + return this.defaultChannel; + } + @SuppressWarnings("unchecked") private static class TemporaryReturnAddress implements PollableChannel { diff --git a/org.springframework.integration/src/test/java/org/springframework/integration/channel/MessageChannelTemplateTests.java b/org.springframework.integration/src/test/java/org/springframework/integration/channel/MessageChannelTemplateTests.java index 8a90fde55c..24de10a9b3 100644 --- a/org.springframework.integration/src/test/java/org/springframework/integration/channel/MessageChannelTemplateTests.java +++ b/org.springframework.integration/src/test/java/org/springframework/integration/channel/MessageChannelTemplateTests.java @@ -17,6 +17,8 @@ package org.springframework.integration.channel; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; import static org.junit.Assert.assertTrue; import java.util.ArrayList; @@ -64,14 +66,140 @@ public class MessageChannelTemplateTests { @Test - public void testSendAndReceive() { + public void send() { + MessageChannelTemplate template = new MessageChannelTemplate(); + QueueChannel channel = new QueueChannel(); + template.send(new StringMessage("test"), channel); + Message reply = channel.receive(0); + assertNotNull(reply); + assertEquals("test", reply.getPayload()); + } + + @Test + public void sendWithDefaultChannelProvidedBySetter() { + QueueChannel channel = new QueueChannel(); + MessageChannelTemplate template = new MessageChannelTemplate(); + template.setDefaultChannel(channel); + template.send(new StringMessage("test")); + Message reply = channel.receive(0); + assertNotNull(reply); + assertEquals("test", reply.getPayload()); + } + + @Test + public void sendWithDefaultChannelProvidedByConstructor() { + QueueChannel channel = new QueueChannel(); + MessageChannelTemplate template = new MessageChannelTemplate(channel); + template.send(new StringMessage("test")); + Message reply = channel.receive(0); + assertNotNull(reply); + assertEquals("test", reply.getPayload()); + } + + @Test + public void sendWithExplicitChannelTakesPrecedenceOverDefault() { + QueueChannel explicitChannel = new QueueChannel(); + QueueChannel defaultChannel = new QueueChannel(); + MessageChannelTemplate template = new MessageChannelTemplate(defaultChannel); + template.send(new StringMessage("test"), explicitChannel); + Message reply = explicitChannel.receive(0); + assertNotNull(reply); + assertEquals("test", reply.getPayload()); + assertNull(defaultChannel.receive(0)); + } + + @Test(expected = IllegalStateException.class) + public void sendWithoutChannelArgFailsIfNoDefaultAvailable() { + MessageChannelTemplate template = new MessageChannelTemplate(); + template.send(new StringMessage("test")); + } + + @Test + public void receive() { + QueueChannel channel = new QueueChannel(); + channel.send(new StringMessage("test")); + MessageChannelTemplate template = new MessageChannelTemplate(); + Message reply = template.receive(channel); + assertEquals("test", reply.getPayload()); + } + + @Test + public void receiveWithDefaultChannelProvidedBySetter() { + QueueChannel channel = new QueueChannel(); + channel.send(new StringMessage("test")); + MessageChannelTemplate template = new MessageChannelTemplate(); + template.setDefaultChannel(channel); + Message reply = template.receive(); + assertEquals("test", reply.getPayload()); + } + + @Test + public void receiveWithDefaultChannelProvidedByConstructor() { + QueueChannel channel = new QueueChannel(); + channel.send(new StringMessage("test")); + MessageChannelTemplate template = new MessageChannelTemplate(channel); + Message reply = template.receive(); + assertEquals("test", reply.getPayload()); + } + + @Test + public void receiveWithExplicitChannelTakesPrecedenceOverDefault() { + QueueChannel explicitChannel = new QueueChannel(); + QueueChannel defaultChannel = new QueueChannel(); + explicitChannel.send(new StringMessage("test")); + MessageChannelTemplate template = new MessageChannelTemplate(defaultChannel); + template.setReceiveTimeout(0); + Message reply = template.receive(explicitChannel); + assertEquals("test", reply.getPayload()); + assertNull(template.receive()); + } + + @Test(expected = IllegalStateException.class) + public void receiveWithoutChannelArgFailsIfNoDefaultAvailable() { + MessageChannelTemplate template = new MessageChannelTemplate(); + template.receive(); + } + + @Test(expected = IllegalStateException.class) + public void receiveWithNonPollableDefaultFails() { + DirectChannel channel = new DirectChannel(); + MessageChannelTemplate template = new MessageChannelTemplate(channel); + template.receive(); + } + + @Test + public void sendAndReceive() { MessageChannelTemplate template = new MessageChannelTemplate(); Message reply = template.sendAndReceive(new StringMessage("test"), this.requestChannel); assertEquals("TEST", reply.getPayload()); } @Test - public void testSendWithReturnAddress() throws InterruptedException { + public void sendAndReceiveWithDefaultChannel() { + MessageChannelTemplate template = new MessageChannelTemplate(); + template.setDefaultChannel(this.requestChannel); + Message reply = template.sendAndReceive(new StringMessage("test")); + assertEquals("TEST", reply.getPayload()); + } + + @Test + public void sendAndReceiveWithExplicitChannelTakesPrecedenceOverDefault() { + QueueChannel defaultChannel = new QueueChannel(); + MessageChannelTemplate template = new MessageChannelTemplate(defaultChannel); + Message message = new StringMessage("test"); + Message reply = template.sendAndReceive(message, this.requestChannel); + assertEquals("TEST", reply.getPayload()); + assertNull(defaultChannel.receive(0)); + } + + @Test(expected = IllegalStateException.class) + public void sendAndReceiveWithoutChannelArgFailsIfNoDefaultAvailable() { + MessageChannelTemplate template = new MessageChannelTemplate(); + template.sendAndReceive(new StringMessage("test")); + } + + @Test + public void sendWithReturnAddress() throws InterruptedException { final List replies = new ArrayList(3); final CountDownLatch latch = new CountDownLatch(3); MessageChannel replyChannel = new MessageChannel() {