diff --git a/org.springframework.integration/src/main/java/org/springframework/integration/aggregator/AbstractMessageBarrierEndpoint.java b/org.springframework.integration/src/main/java/org/springframework/integration/aggregator/AbstractMessageBarrierEndpoint.java index d7211fa66b..4f38f627bc 100644 --- a/org.springframework.integration/src/main/java/org/springframework/integration/aggregator/AbstractMessageBarrierEndpoint.java +++ b/org.springframework.integration/src/main/java/org/springframework/integration/aggregator/AbstractMessageBarrierEndpoint.java @@ -29,7 +29,6 @@ import java.util.concurrent.TimeUnit; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; -import org.springframework.beans.factory.InitializingBean; import org.springframework.integration.channel.BlockingChannel; import org.springframework.integration.channel.MessageChannel; import org.springframework.integration.endpoint.AbstractInOutEndpoint; @@ -61,7 +60,7 @@ import org.springframework.util.ObjectUtils; * @author Mark Fisher * @author Marius Bogoevici */ -public abstract class AbstractMessageBarrierEndpoint extends AbstractInOutEndpoint implements InitializingBean { +public abstract class AbstractMessageBarrierEndpoint extends AbstractInOutEndpoint { public final static long DEFAULT_SEND_TIMEOUT = 1000; @@ -151,7 +150,8 @@ public abstract class AbstractMessageBarrierEndpoint extends AbstractInOutEndpoi /** * Initialize this endpoint. */ - public void afterPropertiesSet() { + @Override + protected void initialize() { this.trackedCorrelationIds = new ArrayBlockingQueue(this.trackedCorrelationIdCapacity); this.executor.scheduleWithFixedDelay(new ReaperTask(), this.reaperInterval, this.reaperInterval, TimeUnit.MILLISECONDS); diff --git a/org.springframework.integration/src/main/java/org/springframework/integration/endpoint/AbstractEndpoint.java b/org.springframework.integration/src/main/java/org/springframework/integration/endpoint/AbstractEndpoint.java index eef6815491..2b0773173a 100644 --- a/org.springframework.integration/src/main/java/org/springframework/integration/endpoint/AbstractEndpoint.java +++ b/org.springframework.integration/src/main/java/org/springframework/integration/endpoint/AbstractEndpoint.java @@ -20,6 +20,8 @@ import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.springframework.beans.factory.BeanNameAware; +import org.springframework.beans.factory.InitializingBean; +import org.springframework.integration.ConfigurationException; import org.springframework.integration.channel.ChannelRegistry; import org.springframework.integration.channel.ChannelRegistryAware; import org.springframework.integration.message.Message; @@ -27,6 +29,7 @@ import org.springframework.integration.message.MessageExchangeTemplate; import org.springframework.integration.message.MessageHandlingException; import org.springframework.integration.message.MessageSource; import org.springframework.integration.message.MessagingException; +import org.springframework.integration.message.SubscribableSource; import org.springframework.integration.util.ErrorHandler; /** @@ -34,7 +37,7 @@ import org.springframework.integration.util.ErrorHandler; * * @author Mark Fisher */ -public abstract class AbstractEndpoint implements MessageEndpoint, ChannelRegistryAware, BeanNameAware { +public abstract class AbstractEndpoint implements MessageEndpoint, ChannelRegistryAware, BeanNameAware, InitializingBean { protected final Log logger = LogFactory.getLog(this.getClass()); @@ -90,6 +93,27 @@ public abstract class AbstractEndpoint implements MessageEndpoint, ChannelRegist this.errorHandler = errorHandler; } + public final void afterPropertiesSet() { + if (this.source != null && (this.source instanceof SubscribableSource)) { + ((SubscribableSource) this.source).subscribe(this); + } + try { + this.initialize(); + } + catch (Exception e) { + if (e instanceof RuntimeException) { + throw (RuntimeException) e; + } + throw new ConfigurationException("failed to initialize endpoint '" + this.getName() + "'", e); + } + } + + /** + * Subclasses may override this method for custom initialization requirements. + */ + protected void initialize() throws Exception { + } + public final boolean send(Message message) { if (message == null || message.getPayload() == null) { throw new IllegalArgumentException("Message and its payload must not be null"); diff --git a/org.springframework.integration/src/main/java/org/springframework/integration/endpoint/AbstractInOutEndpoint.java b/org.springframework.integration/src/main/java/org/springframework/integration/endpoint/AbstractInOutEndpoint.java index b0e2672c7d..9937bcd052 100644 --- a/org.springframework.integration/src/main/java/org/springframework/integration/endpoint/AbstractInOutEndpoint.java +++ b/org.springframework.integration/src/main/java/org/springframework/integration/endpoint/AbstractInOutEndpoint.java @@ -91,7 +91,14 @@ public abstract class AbstractInOutEndpoint extends AbstractEndpoint { } return true; } - Message reply = buildReplyMessage(result, message.getHeaders()); + Message reply = null; + if (result instanceof Message && result.equals(message)) { + // we simply pass along an unaltered request Message + reply = (Message) result; + } + else { + reply = buildReplyMessage(result, message.getHeaders()); + } MessageChannel replyChannel = this.resolveReplyChannel(message); if (reply instanceof CompositeMessage && this.shouldSplitComposite()) { boolean sentAtLeastOne = false; diff --git a/org.springframework.integration/src/main/java/org/springframework/integration/endpoint/ServiceActivatorEndpoint.java b/org.springframework.integration/src/main/java/org/springframework/integration/endpoint/ServiceActivatorEndpoint.java index a49fe5c13a..80a92e5d49 100644 --- a/org.springframework.integration/src/main/java/org/springframework/integration/endpoint/ServiceActivatorEndpoint.java +++ b/org.springframework.integration/src/main/java/org/springframework/integration/endpoint/ServiceActivatorEndpoint.java @@ -26,7 +26,7 @@ import org.springframework.util.Assert; /** * @author Mark Fisher */ -public class ServiceActivatorEndpoint extends AbstractInOutEndpoint implements InitializingBean { +public class ServiceActivatorEndpoint extends AbstractInOutEndpoint { public static final String DEFAULT_LISTENER_METHOD = "handle"; @@ -48,7 +48,8 @@ public class ServiceActivatorEndpoint extends AbstractInOutEndpoint implements I } - public void afterPropertiesSet() throws Exception { + @Override + protected void initialize() throws Exception { if (this.invoker instanceof InitializingBean) { ((InitializingBean) this.invoker).afterPropertiesSet(); } diff --git a/org.springframework.integration/src/main/java/org/springframework/integration/gateway/SimpleMessagingGateway.java b/org.springframework.integration/src/main/java/org/springframework/integration/gateway/SimpleMessagingGateway.java index 96f0780ec9..6f772eb486 100644 --- a/org.springframework.integration/src/main/java/org/springframework/integration/gateway/SimpleMessagingGateway.java +++ b/org.springframework.integration/src/main/java/org/springframework/integration/gateway/SimpleMessagingGateway.java @@ -206,6 +206,7 @@ public class SimpleMessagingGateway extends MessagingGatewaySupport implements M ReplyMessageCorrelator correlator = new ReplyMessageCorrelator(this.replyMapCapacity); correlator.setBeanName("internal.correlator." + this); correlator.setSource(this.replyChannel); + correlator.afterPropertiesSet(); this.endpointRegistry.registerEndpoint(correlator); this.replyMessageCorrelator = correlator; } diff --git a/org.springframework.integration/src/main/java/org/springframework/integration/handler/MessageFilter.java b/org.springframework.integration/src/main/java/org/springframework/integration/handler/MessageFilter.java index dbf40cdaa0..60e8ceb069 100644 --- a/org.springframework.integration/src/main/java/org/springframework/integration/handler/MessageFilter.java +++ b/org.springframework.integration/src/main/java/org/springframework/integration/handler/MessageFilter.java @@ -16,25 +16,30 @@ package org.springframework.integration.handler; +import org.springframework.integration.endpoint.AbstractInOutEndpoint; import org.springframework.integration.message.Message; import org.springframework.integration.message.selector.MessageSelector; +import org.springframework.util.Assert; /** - * Handler for deciding whether to pass a message. Implements - * {@link MessageHandler} and simply delegates to a {@link MessageSelector}. + * Message Endpoint that decides whether to pass a message along to its + * output channel. Delegates to a {@link MessageSelector}. * * @author Mark Fisher */ -public class MessageFilter implements MessageHandler { +public class MessageFilter extends AbstractInOutEndpoint { private MessageSelector selector; public MessageFilter(MessageSelector selector) { + Assert.notNull(selector, "selector must not be null"); this.selector = selector; } - public Message handle(Message message) { + + @Override + protected Message handle(Message message) { if (this.selector.accept(message)) { return message; } diff --git a/org.springframework.integration/src/test/java/org/springframework/integration/endpoint/ServiceActivatorEndpointTests.java b/org.springframework.integration/src/test/java/org/springframework/integration/endpoint/ServiceActivatorEndpointTests.java index c83a47fdc9..deb42b58c3 100644 --- a/org.springframework.integration/src/test/java/org/springframework/integration/endpoint/ServiceActivatorEndpointTests.java +++ b/org.springframework.integration/src/test/java/org/springframework/integration/endpoint/ServiceActivatorEndpointTests.java @@ -287,7 +287,7 @@ public class ServiceActivatorEndpointTests { } @Test - public void correlationId() { + public void correlationIdNotSetIfMessageIsReturnedUnaltered() { QueueChannel replyChannel = new QueueChannel(1); ServiceActivatorEndpoint endpoint = new ServiceActivatorEndpoint(new Object() { @SuppressWarnings("unused") @@ -299,6 +299,22 @@ public class ServiceActivatorEndpointTests { .setReturnAddress(replyChannel).build(); endpoint.send(message); Message reply = replyChannel.receive(500); + assertNull(reply.getHeaders().getCorrelationId()); + } + + @Test + public void correlationIdSetForReplyMessage() { + QueueChannel replyChannel = new QueueChannel(1); + ServiceActivatorEndpoint endpoint = new ServiceActivatorEndpoint(new Object() { + @SuppressWarnings("unused") + public Message handle(Message message) { + return MessageBuilder.fromMessage(message).build(); + } + }, "handle"); + Message message = MessageBuilder.fromPayload("test") + .setReturnAddress(replyChannel).build(); + endpoint.send(message); + Message reply = replyChannel.receive(500); assertEquals(message.getHeaders().getId(), reply.getHeaders().getCorrelationId()); } diff --git a/org.springframework.integration/src/test/java/org/springframework/integration/gateway/GatewayProxyFactoryBeanTests.java b/org.springframework.integration/src/test/java/org/springframework/integration/gateway/GatewayProxyFactoryBeanTests.java index c9ae8b6b8a..310cd5123b 100644 --- a/org.springframework.integration/src/test/java/org/springframework/integration/gateway/GatewayProxyFactoryBeanTests.java +++ b/org.springframework.integration/src/test/java/org/springframework/integration/gateway/GatewayProxyFactoryBeanTests.java @@ -124,7 +124,7 @@ public class GatewayProxyFactoryBeanTests { public void testMultipleMessagesWithResponseCorrelator() throws InterruptedException { ClassPathXmlApplicationContext context = new ClassPathXmlApplicationContext( "gatewayWithResponseCorrelator.xml", GatewayProxyFactoryBeanTests.class); - int numRequests = 25; + int numRequests = 5; final TestService service = (TestService) context.getBean("proxy"); final String[] results = new String[numRequests]; final CountDownLatch latch = new CountDownLatch(numRequests); diff --git a/org.springframework.integration/src/test/java/org/springframework/integration/handler/CorrelationIdTests.java b/org.springframework.integration/src/test/java/org/springframework/integration/handler/CorrelationIdTests.java index 98a88b1a24..1545d86646 100644 --- a/org.springframework.integration/src/test/java/org/springframework/integration/handler/CorrelationIdTests.java +++ b/org.springframework.integration/src/test/java/org/springframework/integration/handler/CorrelationIdTests.java @@ -21,7 +21,9 @@ import static org.junit.Assert.assertTrue; import org.junit.Test; +import org.springframework.integration.channel.DirectChannel; import org.springframework.integration.channel.QueueChannel; +import org.springframework.integration.endpoint.ServiceActivatorEndpoint; import org.springframework.integration.message.Message; import org.springframework.integration.message.MessageBuilder; import org.springframework.integration.message.StringMessage; @@ -39,22 +41,28 @@ public class CorrelationIdTests { Object correlationId = "123-ABC"; Message message = MessageBuilder.fromPayload("test") .setCorrelationId(correlationId).build(); - DefaultMessageHandler handler = new DefaultMessageHandler(); - handler.setObject(new TestBean()); - handler.setMethodName("upperCase"); - handler.afterPropertiesSet(); - Message reply = handler.handle(message); + DirectChannel inputChannel = new DirectChannel(); + QueueChannel outputChannel = new QueueChannel(1); + ServiceActivatorEndpoint endpoint = new ServiceActivatorEndpoint(new TestBean(), "upperCase"); + endpoint.setSource(inputChannel); + endpoint.setOutputChannel(outputChannel); + endpoint.afterPropertiesSet(); + assertTrue(inputChannel.send(message)); + Message reply = outputChannel.receive(0); assertEquals(correlationId, reply.getHeaders().getCorrelationId()); } @Test public void testCorrelationIdCopiedFromMessageIdByDefault() { Message message = MessageBuilder.fromPayload("test").build(); - DefaultMessageHandler handler = new DefaultMessageHandler(); - handler.setObject(new TestBean()); - handler.setMethodName("upperCase"); - handler.afterPropertiesSet(); - Message reply = handler.handle(message); + DirectChannel inputChannel = new DirectChannel(); + QueueChannel outputChannel = new QueueChannel(1); + ServiceActivatorEndpoint endpoint = new ServiceActivatorEndpoint(new TestBean(), "upperCase"); + endpoint.setSource(inputChannel); + endpoint.setOutputChannel(outputChannel); + endpoint.afterPropertiesSet(); + assertTrue(inputChannel.send(message)); + Message reply = outputChannel.receive(0); assertEquals(message.getHeaders().getId(), reply.getHeaders().getCorrelationId()); } @@ -62,11 +70,14 @@ public class CorrelationIdTests { public void testCorrelationIdCopiedFromMessageCorrelationIdIfAvailable() { Message message = MessageBuilder.fromPayload("test") .setCorrelationId("correlationId").build(); - DefaultMessageHandler handler = new DefaultMessageHandler(); - handler.setObject(new TestBean()); - handler.setMethodName("upperCase"); - handler.afterPropertiesSet(); - Message reply = handler.handle(message); + DirectChannel inputChannel = new DirectChannel(); + QueueChannel outputChannel = new QueueChannel(1); + ServiceActivatorEndpoint endpoint = new ServiceActivatorEndpoint(new TestBean(), "upperCase"); + endpoint.setSource(inputChannel); + endpoint.setOutputChannel(outputChannel); + endpoint.afterPropertiesSet(); + assertTrue(inputChannel.send(message)); + Message reply = outputChannel.receive(0); assertEquals(message.getHeaders().getCorrelationId(), reply.getHeaders().getCorrelationId()); assertTrue(message.getHeaders().getCorrelationId().equals(reply.getHeaders().getCorrelationId())); } @@ -76,22 +87,28 @@ public class CorrelationIdTests { Object correlationId = "123-ABC"; Message message = MessageBuilder.fromPayload("test") .setCorrelationId(correlationId).build(); - DefaultMessageHandler handler = new DefaultMessageHandler(); - handler.setObject(new TestBean()); - handler.setMethodName("createMessage"); - handler.afterPropertiesSet(); - Message reply = handler.handle(message); + DirectChannel inputChannel = new DirectChannel(); + QueueChannel outputChannel = new QueueChannel(1); + ServiceActivatorEndpoint endpoint = new ServiceActivatorEndpoint(new TestBean(), "createMessage"); + endpoint.setSource(inputChannel); + endpoint.setOutputChannel(outputChannel); + endpoint.afterPropertiesSet(); + assertTrue(inputChannel.send(message)); + Message reply = outputChannel.receive(0); assertEquals("456-XYZ", reply.getHeaders().getCorrelationId()); } @Test public void testCorrelationNotCopiedFromRequestMessgeIdIfAlreadySetByHandler() throws Exception { Message message = new StringMessage("test"); - DefaultMessageHandler handler = new DefaultMessageHandler(); - handler.setObject(new TestBean()); - handler.setMethodName("createMessage"); - handler.afterPropertiesSet(); - Message reply = handler.handle(message); + DirectChannel inputChannel = new DirectChannel(); + QueueChannel outputChannel = new QueueChannel(1); + ServiceActivatorEndpoint endpoint = new ServiceActivatorEndpoint(new TestBean(), "createMessage"); + endpoint.setSource(inputChannel); + endpoint.setOutputChannel(outputChannel); + endpoint.afterPropertiesSet(); + assertTrue(inputChannel.send(message)); + Message reply = outputChannel.receive(0); assertEquals("456-XYZ", reply.getHeaders().getCorrelationId()); } diff --git a/org.springframework.integration/src/test/java/org/springframework/integration/handler/MessageFilterTests.java b/org.springframework.integration/src/test/java/org/springframework/integration/handler/MessageFilterTests.java index 030af6cb0d..ab9112e4fb 100644 --- a/org.springframework.integration/src/test/java/org/springframework/integration/handler/MessageFilterTests.java +++ b/org.springframework.integration/src/test/java/org/springframework/integration/handler/MessageFilterTests.java @@ -17,10 +17,14 @@ package org.springframework.integration.handler; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; import org.junit.Test; +import org.springframework.integration.channel.DirectChannel; +import org.springframework.integration.channel.QueueChannel; import org.springframework.integration.message.Message; import org.springframework.integration.message.StringMessage; import org.springframework.integration.message.selector.MessageSelector; @@ -31,7 +35,7 @@ import org.springframework.integration.message.selector.MessageSelector; public class MessageFilterTests { @Test - public void testFilterAcceptsMessage() { + public void filterAcceptsMessage() { MessageFilter filter = new MessageFilter(new MessageSelector() { public boolean accept(Message message) { return true; @@ -42,7 +46,7 @@ public class MessageFilterTests { } @Test - public void testFilterRejectsMessage() { + public void filterRejectsMessage() { MessageFilter filter = new MessageFilter(new MessageSelector() { public boolean accept(Message message) { return false; @@ -51,4 +55,40 @@ public class MessageFilterTests { assertNull(filter.handle(new StringMessage("test"))); } + @Test + public void filterAcceptsWithChannels() { + DirectChannel inputChannel = new DirectChannel(); + QueueChannel outputChannel = new QueueChannel(); + MessageFilter filter = new MessageFilter(new MessageSelector() { + public boolean accept(Message message) { + return true; + } + }); + filter.setSource(inputChannel); + filter.setOutputChannel(outputChannel); + filter.afterPropertiesSet(); + Message message = new StringMessage("test"); + assertTrue(inputChannel.send(message)); + Message reply = outputChannel.receive(0); + assertNotNull(reply); + assertEquals(message, reply); + } + + @Test + public void filterRejectsWithChannels() { + DirectChannel inputChannel = new DirectChannel(); + QueueChannel outputChannel = new QueueChannel(); + MessageFilter filter = new MessageFilter(new MessageSelector() { + public boolean accept(Message message) { + return false; + } + }); + filter.setSource(inputChannel); + filter.setOutputChannel(outputChannel); + filter.afterPropertiesSet(); + Message message = new StringMessage("test"); + assertTrue(inputChannel.send(message)); + assertNull(outputChannel.receive(0)); + } + }