diff --git a/org.springframework.integration.adapter/src/main/java/org/springframework/integration/adapter/AbstractRemotingOutboundGateway.java b/org.springframework.integration.adapter/src/main/java/org/springframework/integration/adapter/AbstractRemotingOutboundGateway.java index 7affe87234..98641cbe75 100644 --- a/org.springframework.integration.adapter/src/main/java/org/springframework/integration/adapter/AbstractRemotingOutboundGateway.java +++ b/org.springframework.integration.adapter/src/main/java/org/springframework/integration/adapter/AbstractRemotingOutboundGateway.java @@ -20,6 +20,7 @@ import java.io.Serializable; import org.springframework.integration.channel.MessageChannel; import org.springframework.integration.endpoint.AbstractReplyProducingMessageConsumer; +import org.springframework.integration.endpoint.ReplyHolder; import org.springframework.integration.message.Message; import org.springframework.integration.message.MessageHandlingException; import org.springframework.remoting.RemoteAccessException; @@ -49,10 +50,13 @@ public abstract class AbstractRemotingOutboundGateway extends AbstractReplyProdu protected abstract MessageHandler createHandlerProxy(String url); - public final Message handle(Message message) { + public final void handle(Message message, ReplyHolder replyHolder) { this.verifySerializability(message); try { - return this.handlerProxy.handle(message); + Message reply = this.handlerProxy.handle(message); + if (reply != null) { + replyHolder.set(reply); + } } catch (RemoteAccessException e) { throw new MessageHandlingException(message, "unable to handle message remotely", e); diff --git a/org.springframework.integration.rmi/src/test/java/org/springframework/integration/rmi/RmiOutboundGatewayTests.java b/org.springframework.integration.rmi/src/test/java/org/springframework/integration/rmi/RmiOutboundGatewayTests.java index ea5c3c5f45..828bcd1090 100644 --- a/org.springframework.integration.rmi/src/test/java/org/springframework/integration/rmi/RmiOutboundGatewayTests.java +++ b/org.springframework.integration.rmi/src/test/java/org/springframework/integration/rmi/RmiOutboundGatewayTests.java @@ -26,6 +26,7 @@ import org.junit.Before; import org.junit.Test; import org.springframework.integration.adapter.MessageHandler; +import org.springframework.integration.channel.QueueChannel; import org.springframework.integration.message.GenericMessage; import org.springframework.integration.message.Message; import org.springframework.integration.message.MessageBuilder; @@ -42,6 +43,13 @@ public class RmiOutboundGatewayTests { private final RmiOutboundGateway gateway = new RmiOutboundGateway("rmi://localhost:1099/testRemoteHandler"); + private final QueueChannel output = new QueueChannel(1); + + + @Before + public void initializeGateway() { + this.gateway.setOutputChannel(this.output); + } @Before public void createExporter() throws RemoteException { @@ -55,7 +63,8 @@ public class RmiOutboundGatewayTests { @Test public void serializablePayload() throws RemoteException { - Message replyMessage = gateway.handle(new StringMessage("test")); + gateway.onMessage(new StringMessage("test")); + Message replyMessage = output.receive(0); assertNotNull(replyMessage); assertEquals("TEST", replyMessage.getPayload()); } @@ -64,7 +73,8 @@ public class RmiOutboundGatewayTests { public void serializableAttribute() throws RemoteException { Message requestMessage = MessageBuilder.withPayload("test") .setHeader("testAttribute", "foo").build(); - Message replyMessage = gateway.handle(requestMessage); + gateway.onMessage(requestMessage); + Message replyMessage = output.receive(0); assertNotNull(replyMessage); assertEquals("foo", replyMessage.getHeaders().get("testAttribute")); } @@ -73,14 +83,14 @@ public class RmiOutboundGatewayTests { public void nonSerializablePayload() throws RemoteException { NonSerializableTestObject payload = new NonSerializableTestObject(); Message requestMessage = new GenericMessage(payload); - gateway.handle(requestMessage); + gateway.onMessage(requestMessage); } @Test(expected = MessageHandlingException.class) public void nonSerializableAttribute() throws RemoteException { Message requestMessage = MessageBuilder.withPayload("test") .setHeader("testAttribute", new NonSerializableTestObject()).build(); - gateway.handle(requestMessage); + gateway.onMessage(requestMessage); } @Test @@ -88,7 +98,7 @@ public class RmiOutboundGatewayTests { RmiOutboundGateway gateway = new RmiOutboundGateway("rmi://localhost:1099/noSuchService"); boolean exceptionThrown = false; try { - gateway.handle(new StringMessage("test")); + gateway.onMessage(new StringMessage("test")); } catch (MessageHandlingException e) { assertEquals(RemoteLookupFailureException.class, e.getCause().getClass()); @@ -102,7 +112,7 @@ public class RmiOutboundGatewayTests { RmiOutboundGateway gateway = new RmiOutboundGateway("rmi://noSuchHost:1099/testRemoteHandler"); boolean exceptionThrown = false; try { - gateway.handle(new StringMessage("test")); + gateway.onMessage(new StringMessage("test")); } catch (MessageHandlingException e) { assertEquals(RemoteLookupFailureException.class, e.getCause().getClass()); @@ -116,7 +126,7 @@ public class RmiOutboundGatewayTests { RmiOutboundGateway gateway = new RmiOutboundGateway("invalid"); boolean exceptionThrown = false; try { - gateway.handle(new StringMessage("test")); + gateway.onMessage(new StringMessage("test")); } catch (MessageHandlingException e) { assertEquals(RemoteLookupFailureException.class, e.getCause().getClass()); diff --git a/org.springframework.integration.ws/src/main/java/org/springframework/integration/ws/AbstractWebServiceOutboundGateway.java b/org.springframework.integration.ws/src/main/java/org/springframework/integration/ws/AbstractWebServiceOutboundGateway.java index 0173fd28fd..c635846c07 100644 --- a/org.springframework.integration.ws/src/main/java/org/springframework/integration/ws/AbstractWebServiceOutboundGateway.java +++ b/org.springframework.integration.ws/src/main/java/org/springframework/integration/ws/AbstractWebServiceOutboundGateway.java @@ -21,7 +21,7 @@ import java.net.URI; import org.springframework.integration.channel.MessageChannel; import org.springframework.integration.endpoint.AbstractReplyProducingMessageConsumer; -import org.springframework.integration.message.GenericMessage; +import org.springframework.integration.endpoint.ReplyHolder; import org.springframework.integration.message.Message; import org.springframework.util.Assert; import org.springframework.ws.WebServiceMessage; @@ -85,9 +85,11 @@ public abstract class AbstractWebServiceOutboundGateway extends AbstractReplyPro } @Override - public final Message handle(Message message) { + public final void handle(Message message, ReplyHolder replyHolder) { Object responsePayload = this.doHandle(message.getPayload(), this.getRequestCallback(message)); - return responsePayload != null ? new GenericMessage(responsePayload, message.getHeaders()) : null; + if (responsePayload != null) { + replyHolder.set(responsePayload); + } } protected abstract Object doHandle(Object requestPayload, WebServiceMessageCallback requestCallback); diff --git a/org.springframework.integration/src/main/java/org/springframework/integration/aggregator/AbstractMessageBarrierConsumer.java b/org.springframework.integration/src/main/java/org/springframework/integration/aggregator/AbstractMessageBarrierConsumer.java index e851db75e5..261b005f73 100644 --- a/org.springframework.integration/src/main/java/org/springframework/integration/aggregator/AbstractMessageBarrierConsumer.java +++ b/org.springframework.integration/src/main/java/org/springframework/integration/aggregator/AbstractMessageBarrierConsumer.java @@ -31,6 +31,7 @@ import org.apache.commons.logging.LogFactory; import org.springframework.beans.factory.InitializingBean; import org.springframework.integration.channel.MessageChannel; import org.springframework.integration.endpoint.AbstractReplyProducingMessageConsumer; +import org.springframework.integration.endpoint.ReplyHolder; import org.springframework.integration.message.Message; import org.springframework.integration.message.MessageConsumer; import org.springframework.integration.message.MessageHandlingException; @@ -166,7 +167,7 @@ public abstract class AbstractMessageBarrierConsumer extends AbstractReplyProduc } @Override - protected final Message handle(Message message) { + protected final void handle(Message message, ReplyHolder replyHolder) { if (!this.initialized) { this.afterPropertiesSet(); } @@ -181,7 +182,7 @@ public abstract class AbstractMessageBarrierConsumer extends AbstractReplyProduc + correlationId + "' has already completed or timed out."); } this.sendToDiscardChannelIfAvailable(message); - return null; + return; } MessageBarrier barrier = barriers.putIfAbsent(correlationId, createMessageBarrier()); if (barrier == null) { @@ -189,17 +190,17 @@ public abstract class AbstractMessageBarrierConsumer extends AbstractReplyProduc } List> releasedMessages = barrier.addAndRelease(message); if (CollectionUtils.isEmpty(releasedMessages)) { - return null; + return; } if (isBarrierRemovable(correlationId, releasedMessages)) { this.removeBarrier(correlationId); } Message[] processedMessages = this.processReleasedMessages(correlationId, releasedMessages); if (ObjectUtils.isEmpty(processedMessages)) { - return null; + return; } this.afterRelease(correlationId, releasedMessages); - return null; + return; } private void afterRelease(Object correlationId, List> releasedMessages) { diff --git a/org.springframework.integration/src/main/java/org/springframework/integration/endpoint/AbstractReplyProducingMessageConsumer.java b/org.springframework.integration/src/main/java/org/springframework/integration/endpoint/AbstractReplyProducingMessageConsumer.java index 505e4b45b3..6d10b07bb1 100644 --- a/org.springframework.integration/src/main/java/org/springframework/integration/endpoint/AbstractReplyProducingMessageConsumer.java +++ b/org.springframework.integration/src/main/java/org/springframework/integration/endpoint/AbstractReplyProducingMessageConsumer.java @@ -16,16 +16,12 @@ package org.springframework.integration.endpoint; -import java.util.ArrayList; -import java.util.List; - import org.springframework.beans.factory.BeanFactory; import org.springframework.beans.factory.BeanFactoryAware; import org.springframework.integration.channel.BeanFactoryChannelResolver; import org.springframework.integration.channel.ChannelResolver; import org.springframework.integration.channel.MessageChannel; import org.springframework.integration.channel.MessageChannelTemplate; -import org.springframework.integration.message.CompositeMessage; import org.springframework.integration.message.Message; import org.springframework.integration.message.MessageBuilder; import org.springframework.integration.message.MessageHandlingException; @@ -101,36 +97,31 @@ public abstract class AbstractReplyProducingMessageConsumer extends AbstractMess if (!this.supports(message)) { throw new MessageRejectedException(message, "unsupported message"); } - Object result = this.handle(message); - if (result == null) { + ReplyHolder replyHolder = new ReplyHolder(); + this.handle(message, replyHolder); + if (replyHolder.isEmpty()) { if (this.requiresReply) { throw new MessageHandlingException(message, "consumer '" + this + "' requires a reply, but no reply was received"); } return; } - Message reply = null; - if (result instanceof Message && result.equals(message)) { - // we simply pass along an unaltered request Message - reply = (Message) result; + Object targetChannelValue = replyHolder.getTargetChannel(); + MessageChannel replyChannel = null; + if (targetChannelValue == null) { + replyChannel = this.resolveReplyChannel(message); } - else { - reply = buildReplyMessage(result, message.getHeaders()); + else if (targetChannelValue instanceof String) { + replyChannel = this.channelResolver.resolveChannelName((String) targetChannelValue); } - MessageChannel replyChannel = this.resolveReplyChannel(message); - if (reply instanceof CompositeMessage && this.shouldSplitComposite()) { - boolean sentAtLeastOne = false; - for (Message nextReply : (CompositeMessage) reply) { - boolean sent = this.sendReplyMessage(nextReply, replyChannel); - sentAtLeastOne = (sentAtLeastOne || sent); - } - } - else { - this.sendReplyMessage(reply, replyChannel); + MessageHeaders requestHeaders = message.getHeaders(); + for (MessageBuilder builder : replyHolder.builders()) { + builder.copyHeadersIfAbsent(requestHeaders); + this.sendReplyMessage(builder.build(), replyChannel); } } - protected abstract Object handle(Message message); + protected abstract void handle(Message message, ReplyHolder replyHolder); protected boolean supports(Message message) { if (this.selector != null && !this.selector.accept(message)) { @@ -142,38 +133,10 @@ public abstract class AbstractReplyProducingMessageConsumer extends AbstractMess return true; } - protected boolean shouldSplitComposite() { - return false; - } - protected boolean sendReplyMessage(Message replyMessage, MessageChannel replyChannel) { return this.channelTemplate.send(replyMessage, replyChannel); } - private Message buildReplyMessage(Object result, MessageHeaders requestHeaders) { - MessageBuilder builder = null; - if (result instanceof MessageBuilder) { - builder = (MessageBuilder) result; - } - else if (result instanceof CompositeMessage) { - List> messages = ((CompositeMessage) result).getPayload(); - List> replies = new ArrayList>(); - for (Message message : messages) { - replies.add(this.buildReplyMessage(message, requestHeaders)); - } - return new CompositeMessage(replies); - } - else if (result instanceof Message) { - builder = MessageBuilder.fromMessage((Message) result); - } - else { - builder = MessageBuilder.withPayload(result); - } - return builder.copyHeadersIfAbsent(requestHeaders) - .setHeaderIfAbsent(MessageHeaders.CORRELATION_ID, requestHeaders.getId()) - .build(); - } - private MessageChannel resolveReplyChannel(Message requestMessage) { MessageChannel replyChannel = this.getOutputChannel(); if (replyChannel == null) { diff --git a/org.springframework.integration/src/main/java/org/springframework/integration/endpoint/ReplyHolder.java b/org.springframework.integration/src/main/java/org/springframework/integration/endpoint/ReplyHolder.java new file mode 100644 index 0000000000..41aaf54d67 --- /dev/null +++ b/org.springframework.integration/src/main/java/org/springframework/integration/endpoint/ReplyHolder.java @@ -0,0 +1,85 @@ +/* + * Copyright 2002-2008 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.integration.endpoint; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + +import org.springframework.integration.channel.MessageChannel; +import org.springframework.integration.message.Message; +import org.springframework.integration.message.MessageBuilder; + +/** + * @author Mark Fisher + */ +public class ReplyHolder { + + private final List> builders = new ArrayList>(); + + private volatile Object targetChannel; + + + public MessageBuilder set(Object replyObject) { + return this.createAndAddBuilder(replyObject, true); + } + + public MessageBuilder add(Object replyObject) { + return this.createAndAddBuilder(replyObject, false); + } + + public void setTargetChannel(MessageChannel targetChannel) { + this.targetChannel = targetChannel; + } + + public void setTargetChannelName(String targetChannelName) { + this.targetChannel = targetChannelName; + } + + protected Object getTargetChannel() { + return this.targetChannel; + } + + public boolean isEmpty() { + return this.builders.isEmpty(); + } + + public List> builders() { + return Collections.unmodifiableList(this.builders); + } + + private MessageBuilder createAndAddBuilder(Object replyObject, boolean clearExistingValues) { + MessageBuilder builder = null; + if (replyObject instanceof MessageBuilder) { + builder = (MessageBuilder) replyObject; + } + else if (replyObject instanceof Message) { + builder = MessageBuilder.fromMessage((Message) replyObject); + } + else { + builder = MessageBuilder.withPayload(replyObject); + } + synchronized (this.builders) { + if (clearExistingValues) { + this.builders.clear(); + } + this.builders.add(builder); + } + return builder; + } + +} diff --git a/org.springframework.integration/src/main/java/org/springframework/integration/endpoint/ServiceActivatorEndpoint.java b/org.springframework.integration/src/main/java/org/springframework/integration/endpoint/ServiceActivatorEndpoint.java index d845c28921..d7f19c30e8 100644 --- a/org.springframework.integration/src/main/java/org/springframework/integration/endpoint/ServiceActivatorEndpoint.java +++ b/org.springframework.integration/src/main/java/org/springframework/integration/endpoint/ServiceActivatorEndpoint.java @@ -24,8 +24,8 @@ import org.springframework.integration.message.Message; import org.springframework.integration.message.MessageHandlingException; import org.springframework.integration.message.MessageMappingMethodInvoker; import org.springframework.integration.util.DefaultMethodResolver; -import org.springframework.integration.util.MethodResolver; import org.springframework.integration.util.MethodInvoker; +import org.springframework.integration.util.MethodResolver; import org.springframework.util.Assert; /** @@ -63,9 +63,12 @@ public class ServiceActivatorEndpoint extends AbstractReplyProducingMessageConsu } @Override - protected Object handle(Message message) { + protected void handle(Message message, ReplyHolder replyHolder) { try { - return this.invoker.invokeMethod(message); + Object result = this.invoker.invokeMethod(message); + if (result != null) { + replyHolder.set(result); + } } catch (Exception e) { if (e instanceof RuntimeException) { diff --git a/org.springframework.integration/src/main/java/org/springframework/integration/filter/MessageFilter.java b/org.springframework.integration/src/main/java/org/springframework/integration/filter/MessageFilter.java index 3b5988b868..c34765456b 100644 --- a/org.springframework.integration/src/main/java/org/springframework/integration/filter/MessageFilter.java +++ b/org.springframework.integration/src/main/java/org/springframework/integration/filter/MessageFilter.java @@ -17,6 +17,7 @@ package org.springframework.integration.filter; import org.springframework.integration.endpoint.AbstractReplyProducingMessageConsumer; +import org.springframework.integration.endpoint.ReplyHolder; import org.springframework.integration.message.Message; import org.springframework.integration.message.selector.MessageSelector; import org.springframework.util.Assert; @@ -40,11 +41,10 @@ public class MessageFilter extends AbstractReplyProducingMessageConsumer { @Override - protected Message handle(Message message) { + protected void handle(Message message, ReplyHolder replyHolder) { if (this.selector.accept(message)) { - return message; + replyHolder.set(message); } - return null; } } diff --git a/org.springframework.integration/src/main/java/org/springframework/integration/gateway/AbstractMessagingGateway.java b/org.springframework.integration/src/main/java/org/springframework/integration/gateway/AbstractMessagingGateway.java index c5a94749ed..56b59b0590 100644 --- a/org.springframework.integration/src/main/java/org/springframework/integration/gateway/AbstractMessagingGateway.java +++ b/org.springframework.integration/src/main/java/org/springframework/integration/gateway/AbstractMessagingGateway.java @@ -26,6 +26,7 @@ import org.springframework.integration.endpoint.AbstractReplyProducingMessageCon import org.springframework.integration.endpoint.MessageEndpoint; import org.springframework.integration.endpoint.MessagingGateway; import org.springframework.integration.endpoint.PollingConsumerEndpoint; +import org.springframework.integration.endpoint.ReplyHolder; import org.springframework.integration.endpoint.SubscribingConsumerEndpoint; import org.springframework.integration.message.Message; import org.springframework.integration.message.MessageConsumer; @@ -154,8 +155,8 @@ public abstract class AbstractMessagingGateway implements MessagingGateway, Mess MessageEndpoint correlator = null; MessageConsumer consumer = new AbstractReplyProducingMessageConsumer() { @Override - protected Object handle(Message message) { - return message; + protected void handle(Message message, ReplyHolder replyHolder) { + replyHolder.set(message); } }; if (this.replyChannel instanceof SubscribableChannel) { diff --git a/org.springframework.integration/src/main/java/org/springframework/integration/message/CompositeMessage.java b/org.springframework.integration/src/main/java/org/springframework/integration/message/CompositeMessage.java deleted file mode 100644 index 685b55e976..0000000000 --- a/org.springframework.integration/src/main/java/org/springframework/integration/message/CompositeMessage.java +++ /dev/null @@ -1,42 +0,0 @@ -/* - * Copyright 2002-2008 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.springframework.integration.message; - -import java.util.Arrays; -import java.util.Collections; -import java.util.Iterator; -import java.util.List; - -/** - * @author Mark Fisher - */ -public class CompositeMessage extends GenericMessage>> implements Iterable> { - - public CompositeMessage(Message[] messages) { - this(Arrays.asList(messages)); - } - - public CompositeMessage(List> messages) { - super(Collections.unmodifiableList(messages)); - } - - - public Iterator> iterator() { - return this.getPayload().iterator(); - } - -} diff --git a/org.springframework.integration/src/main/java/org/springframework/integration/message/MessageBuilder.java b/org.springframework.integration/src/main/java/org/springframework/integration/message/MessageBuilder.java index 727cb2c930..85e8de589e 100644 --- a/org.springframework.integration/src/main/java/org/springframework/integration/message/MessageBuilder.java +++ b/org.springframework.integration/src/main/java/org/springframework/integration/message/MessageBuilder.java @@ -35,13 +35,21 @@ public final class MessageBuilder { private final Map headers = new HashMap(); + private final Message originalMessage; + + private volatile boolean modified; + /** * Private constructor to be invoked from the static factory methods only. */ - private MessageBuilder(T payload) { + private MessageBuilder(T payload, Message originalMessage) { Assert.notNull(payload, "payload must not be null"); this.payload = payload; + this.originalMessage = originalMessage; + if (originalMessage != null) { + this.headers.putAll(originalMessage.getHeaders()); + } } @@ -54,8 +62,8 @@ public final class MessageBuilder { * will be copied */ public static MessageBuilder fromMessage(Message message) { - MessageBuilder builder = new MessageBuilder(message.getPayload()); - builder.headers.putAll(message.getHeaders()); + Assert.notNull(message, "message must not be null"); + MessageBuilder builder = new MessageBuilder(message.getPayload(), message); return builder; } @@ -65,7 +73,7 @@ public final class MessageBuilder { * @param payload the payload for the new message */ public static MessageBuilder withPayload(T payload) { - MessageBuilder builder = new MessageBuilder(payload); + MessageBuilder builder = new MessageBuilder(payload, null); return builder; } @@ -76,6 +84,7 @@ public final class MessageBuilder { */ public MessageBuilder setHeader(String headerName, Object headerValue) { if (StringUtils.hasLength(headerName) && !(this.isReadOnly(headerName))) { + this.modified = true; if (headerValue == null) { this.headers.remove(headerName); } @@ -102,6 +111,7 @@ public final class MessageBuilder { */ public MessageBuilder removeHeader(String headerName) { if (StringUtils.hasLength(headerName)) { + this.modified = true; this.headers.remove(headerName); } return this; @@ -174,6 +184,9 @@ public final class MessageBuilder { } public Message build() { + if (!this.modified && this.originalMessage != null) { + return this.originalMessage; + } return new GenericMessage(this.payload, this.headers); } diff --git a/org.springframework.integration/src/main/java/org/springframework/integration/splitter/AbstractMessageSplitter.java b/org.springframework.integration/src/main/java/org/springframework/integration/splitter/AbstractMessageSplitter.java index 1e74b00ec4..d7239fce0d 100644 --- a/org.springframework.integration/src/main/java/org/springframework/integration/splitter/AbstractMessageSplitter.java +++ b/org.springframework.integration/src/main/java/org/springframework/integration/splitter/AbstractMessageSplitter.java @@ -16,15 +16,11 @@ package org.springframework.integration.splitter; -import java.util.ArrayList; import java.util.Collection; -import java.util.List; import org.springframework.integration.endpoint.AbstractReplyProducingMessageConsumer; -import org.springframework.integration.message.CompositeMessage; +import org.springframework.integration.endpoint.ReplyHolder; import org.springframework.integration.message.Message; -import org.springframework.integration.message.MessageBuilder; -import org.springframework.integration.message.MessageHeaders; /** * Base class for Message-splitting consumers. @@ -34,24 +30,18 @@ import org.springframework.integration.message.MessageHeaders; public abstract class AbstractMessageSplitter extends AbstractReplyProducingMessageConsumer { @Override - protected final boolean shouldSplitComposite() { - return true; - } - - @Override - protected final Message handle(Message message) { + protected final void handle(Message message, ReplyHolder replyHolder) { Object result = this.splitMessage(message); if (result == null) { - return null; + return; } - MessageHeaders requestHeaders = message.getHeaders(); - List> results = new ArrayList>(); + Object correlationId = message.getHeaders().getId(); if (result instanceof Collection) { Collection items = (Collection) result; int sequenceNumber = 0; int sequenceSize = items.size(); for (Object item : items) { - results.add(this.createSplitMessage(item, requestHeaders, ++sequenceNumber, sequenceSize)); + this.addReply(replyHolder, item, correlationId, ++sequenceNumber, sequenceSize); } } else if (result.getClass().isArray()) { @@ -59,16 +49,18 @@ public abstract class AbstractMessageSplitter extends AbstractReplyProducingMess int sequenceNumber = 0; int sequenceSize = items.length; for (Object item : items) { - results.add(this.createSplitMessage(item, requestHeaders, ++sequenceNumber, sequenceSize)); + this.addReply(replyHolder, item, correlationId, ++sequenceNumber, sequenceSize); } } else { - results.add(this.createSplitMessage(result, requestHeaders, 1, 1)); + this.addReply(replyHolder, result, correlationId, 1, 1); } - if (results.isEmpty()) { - return null; - } - return new CompositeMessage(results); + } + + private void addReply(ReplyHolder replyHolder, Object item, Object correlationId, int sequenceNumber, int sequenceSize) { + replyHolder.add(item).setCorrelationId(correlationId) + .setSequenceNumber(sequenceNumber) + .setSequenceSize(sequenceSize); } /** @@ -81,20 +73,4 @@ public abstract class AbstractMessageSplitter extends AbstractReplyProducingMess */ protected abstract Object splitMessage(Message message); - - private Message createSplitMessage(Object item, MessageHeaders requestHeaders, int sequenceNumber, int sequenceSize) { - if (item instanceof Message) { - return setSplitMessageHeaders(MessageBuilder.fromMessage((Message) item), - requestHeaders.getId(), sequenceNumber, sequenceSize); - } - return setSplitMessageHeaders(MessageBuilder.withPayload(item), - requestHeaders.getId(), sequenceNumber, sequenceSize); - } - - private Message setSplitMessageHeaders(MessageBuilder builder, Object requestMessageId, int sequenceNumber, int sequenceSize) { - return builder.setCorrelationId(requestMessageId) - .setSequenceNumber(sequenceNumber) - .setSequenceSize(sequenceSize).build(); - } - } diff --git a/org.springframework.integration/src/main/java/org/springframework/integration/transformer/MessageTransformingConsumer.java b/org.springframework.integration/src/main/java/org/springframework/integration/transformer/MessageTransformingConsumer.java index 355a80b104..934a1e466c 100644 --- a/org.springframework.integration/src/main/java/org/springframework/integration/transformer/MessageTransformingConsumer.java +++ b/org.springframework.integration/src/main/java/org/springframework/integration/transformer/MessageTransformingConsumer.java @@ -17,6 +17,7 @@ package org.springframework.integration.transformer; import org.springframework.integration.endpoint.AbstractReplyProducingMessageConsumer; +import org.springframework.integration.endpoint.ReplyHolder; import org.springframework.integration.message.Message; import org.springframework.util.Assert; @@ -43,8 +44,11 @@ public class MessageTransformingConsumer extends AbstractReplyProducingMessageCo @Override - protected Message handle(Message message) { - return transformer.transform(message); + protected void handle(Message message, ReplyHolder replyHolder) { + Message result = transformer.transform(message); + if (result != null) { + replyHolder.set(result); + } } } diff --git a/org.springframework.integration/src/test/java/org/springframework/integration/aggregator/AggregatorEndpointTests.java b/org.springframework.integration/src/test/java/org/springframework/integration/aggregator/AggregatorEndpointTests.java index 74add7b598..73431a5fe9 100644 --- a/org.springframework.integration/src/test/java/org/springframework/integration/aggregator/AggregatorEndpointTests.java +++ b/org.springframework.integration/src/test/java/org/springframework/integration/aggregator/AggregatorEndpointTests.java @@ -153,9 +153,9 @@ public class AggregatorEndpointTests { QueueChannel replyChannel = new QueueChannel(); QueueChannel discardChannel = new QueueChannel(); this.aggregator.setDiscardChannel(discardChannel); - this.aggregator.handle(createMessage("test-1a", 1, 1, 1, replyChannel)); + this.aggregator.onMessage(createMessage("test-1a", 1, 1, 1, replyChannel)); assertEquals("test-1a", replyChannel.receive(100).getPayload()); - this.aggregator.handle(createMessage("test-1b", 1, 1, 1, replyChannel)); + this.aggregator.onMessage(createMessage("test-1b", 1, 1, 1, replyChannel)); assertEquals("test-1b", discardChannel.receive(100).getPayload()); } @@ -165,13 +165,13 @@ public class AggregatorEndpointTests { QueueChannel discardChannel = new QueueChannel(); this.aggregator.setTrackedCorrelationIdCapacity(3); this.aggregator.setDiscardChannel(discardChannel); - this.aggregator.handle(createMessage("test-1a", 1, 1, 1, replyChannel)); + this.aggregator.onMessage(createMessage("test-1a", 1, 1, 1, replyChannel)); assertEquals("test-1a", replyChannel.receive(100).getPayload()); - this.aggregator.handle(createMessage("test-2", 2, 1, 1, replyChannel)); + this.aggregator.onMessage(createMessage("test-2", 2, 1, 1, replyChannel)); assertEquals("test-2", replyChannel.receive(100).getPayload()); - this.aggregator.handle(createMessage("test-3", 3, 1, 1, replyChannel)); + this.aggregator.onMessage(createMessage("test-3", 3, 1, 1, replyChannel)); assertEquals("test-3", replyChannel.receive(100).getPayload()); - this.aggregator.handle(createMessage("test-1b", 1, 1, 1, replyChannel)); + this.aggregator.onMessage(createMessage("test-1b", 1, 1, 1, replyChannel)); assertEquals("test-1b", discardChannel.receive(100).getPayload()); } @@ -181,23 +181,23 @@ public class AggregatorEndpointTests { QueueChannel discardChannel = new QueueChannel(); this.aggregator.setTrackedCorrelationIdCapacity(3); this.aggregator.setDiscardChannel(discardChannel); - this.aggregator.handle(createMessage("test-1a", 1, 1, 1, replyChannel)); + this.aggregator.onMessage(createMessage("test-1a", 1, 1, 1, replyChannel)); assertEquals("test-1a", replyChannel.receive(100).getPayload()); - this.aggregator.handle(createMessage("test-2", 2, 1, 1, replyChannel)); + this.aggregator.onMessage(createMessage("test-2", 2, 1, 1, replyChannel)); assertEquals("test-2", replyChannel.receive(100).getPayload()); - this.aggregator.handle(createMessage("test-3", 3, 1, 1, replyChannel)); + this.aggregator.onMessage(createMessage("test-3", 3, 1, 1, replyChannel)); assertEquals("test-3", replyChannel.receive(100).getPayload()); - this.aggregator.handle(createMessage("test-4", 4, 1, 1, replyChannel)); + this.aggregator.onMessage(createMessage("test-4", 4, 1, 1, replyChannel)); assertEquals("test-4", replyChannel.receive(100).getPayload()); - this.aggregator.handle(createMessage("test-1b", 1, 1, 1, replyChannel)); + this.aggregator.onMessage(createMessage("test-1b", 1, 1, 1, replyChannel)); assertEquals("test-1b", replyChannel.receive(100).getPayload()); assertNull(discardChannel.receive(0)); } - @Test(expected=MessageHandlingException.class) + @Test(expected = MessageHandlingException.class) public void testExceptionThrownIfNoCorrelationId() throws InterruptedException { Message message = createMessage("123", null, 2, 1, new QueueChannel()); - this.aggregator.handle(message); + this.aggregator.onMessage(message); } @Test @@ -311,16 +311,7 @@ public class AggregatorEndpointTests { public void run() { try { - Message result = this.aggregator.handle(message); - if (result != null) { - Object returnAddress = message.getHeaders().getReturnAddress(); - if (returnAddress instanceof MessageChannel) { - ((MessageChannel) returnAddress).send(result); - } - else { - throw new IllegalStateException("'returnAddress' was not a MessageChannel instance"); - } - } + this.aggregator.onMessage(message); } catch (Exception e) { this.exception = e; diff --git a/org.springframework.integration/src/test/java/org/springframework/integration/aggregator/ResequencerTests.java b/org.springframework.integration/src/test/java/org/springframework/integration/aggregator/ResequencerTests.java index d6c8fb29e4..8c3dd048fc 100644 --- a/org.springframework.integration/src/test/java/org/springframework/integration/aggregator/ResequencerTests.java +++ b/org.springframework.integration/src/test/java/org/springframework/integration/aggregator/ResequencerTests.java @@ -57,9 +57,9 @@ public class ResequencerTests { Message message1 = createMessage("123", "ABC", 3, 3, replyChannel); Message message2 = createMessage("456", "ABC", 3, 1, replyChannel); Message message3 = createMessage("789", "ABC", 3, 2, replyChannel); - this.resequencer.handle(message1); - this.resequencer.handle(message3); - this.resequencer.handle(message2); + this.resequencer.onMessage(message1); + this.resequencer.onMessage(message3); + this.resequencer.onMessage(message2); Message reply1 = replyChannel.receive(0); Message reply2 = replyChannel.receive(0); Message reply3 = replyChannel.receive(0); @@ -79,9 +79,9 @@ public class ResequencerTests { Message message2 = createMessage("456", "ABC", 4, 1, replyChannel); Message message3 = createMessage("789", "ABC", 4, 4, replyChannel); Message message4 = createMessage("XYZ", "ABC", 4, 3, replyChannel); - this.resequencer.handle(message1); - this.resequencer.handle(message2); - this.resequencer.handle(message3); + this.resequencer.onMessage(message1); + this.resequencer.onMessage(message2); + this.resequencer.onMessage(message3); Message reply1 = replyChannel.receive(0); Message reply2 = replyChannel.receive(0); Message reply3 = replyChannel.receive(0); @@ -92,7 +92,7 @@ public class ResequencerTests { assertEquals(new Integer(2), reply2.getHeaders().getSequenceNumber()); assertNull(reply3); // when sending the last message, the whole sequence must have been sent - this.resequencer.handle(message4); + this.resequencer.onMessage(message4); reply3 = replyChannel.receive(0); Message reply4 = replyChannel.receive(0); assertNotNull(reply3); @@ -110,9 +110,9 @@ public class ResequencerTests { Message message2 = createMessage("456", "ABC", 4, 1, replyChannel); Message message3 = createMessage("789", "ABC", 4, 4, replyChannel); Message message4 = createMessage("XYZ", "ABC", 4, 3, replyChannel); - this.resequencer.handle(message1); - this.resequencer.handle(message2); - this.resequencer.handle(message3); + this.resequencer.onMessage(message1); + this.resequencer.onMessage(message2); + this.resequencer.onMessage(message3); Message reply1 = replyChannel.receive(0); Message reply2 = replyChannel.receive(0); Message reply3 = replyChannel.receive(0); @@ -121,7 +121,7 @@ public class ResequencerTests { assertNull(reply2); assertNull(reply3); // after sending the last message, the whole sequence should have been sent - this.resequencer.handle(message4); + this.resequencer.onMessage(message4); reply1 = replyChannel.receive(0); reply2 = replyChannel.receive(0); reply3 = replyChannel.receive(0); diff --git a/org.springframework.integration/src/test/java/org/springframework/integration/bus/DefaultMessageBusTests.java b/org.springframework.integration/src/test/java/org/springframework/integration/bus/DefaultMessageBusTests.java index b3be17e722..1d70353b3b 100644 --- a/org.springframework.integration/src/test/java/org/springframework/integration/bus/DefaultMessageBusTests.java +++ b/org.springframework.integration/src/test/java/org/springframework/integration/bus/DefaultMessageBusTests.java @@ -37,6 +37,7 @@ import org.springframework.integration.channel.QueueChannel; import org.springframework.integration.config.xml.MessageBusParser; import org.springframework.integration.endpoint.AbstractReplyProducingMessageConsumer; import org.springframework.integration.endpoint.PollingConsumerEndpoint; +import org.springframework.integration.endpoint.ReplyHolder; import org.springframework.integration.endpoint.SourcePollingChannelAdapter; import org.springframework.integration.endpoint.SubscribingConsumerEndpoint; import org.springframework.integration.message.ErrorMessage; @@ -66,8 +67,8 @@ public class DefaultMessageBusTests { .setReturnAddress("targetChannel").build(); sourceChannel.send(message); AbstractReplyProducingMessageConsumer consumer = new AbstractReplyProducingMessageConsumer() { - public Message handle(Message message) { - return message; + public void handle(Message message, ReplyHolder replyHolder) { + replyHolder.set(message); } }; consumer.setBeanFactory(context); @@ -124,13 +125,13 @@ public class DefaultMessageBusTests { QueueChannel outputChannel1 = new QueueChannel(); QueueChannel outputChannel2 = new QueueChannel(); AbstractReplyProducingMessageConsumer consumer1 = new AbstractReplyProducingMessageConsumer() { - public Message handle(Message message) { - return MessageBuilder.fromMessage(message).build(); + public void handle(Message message, ReplyHolder replyHolder) { + replyHolder.set(message); } }; AbstractReplyProducingMessageConsumer consumer2 = new AbstractReplyProducingMessageConsumer() { - public Message handle(Message message) { - return MessageBuilder.fromMessage(message).build(); + public void handle(Message message, ReplyHolder replyHolder) { + replyHolder.set(message); } }; inputChannel.setBeanName("input"); @@ -166,17 +167,15 @@ public class DefaultMessageBusTests { QueueChannel outputChannel2 = new QueueChannel(); final CountDownLatch latch = new CountDownLatch(2); AbstractReplyProducingMessageConsumer consumer1 = new AbstractReplyProducingMessageConsumer() { - public Message handle(Message message) { - Message reply = MessageBuilder.fromMessage(message).build(); + public void handle(Message message, ReplyHolder replyHolder) { + replyHolder.set(message); latch.countDown(); - return reply; } }; AbstractReplyProducingMessageConsumer consumer2 = new AbstractReplyProducingMessageConsumer() { - public Message handle(Message message) { - Message reply = MessageBuilder.fromMessage(message).build(); + public void handle(Message message, ReplyHolder replyHolder) { + replyHolder.set(message); latch.countDown(); - return reply; } }; inputChannel.setBeanName("input"); @@ -246,9 +245,8 @@ public class DefaultMessageBusTests { context.getBeanFactory().registerSingleton(DefaultMessageBus.ERROR_CHANNEL_BEAN_NAME, errorChannel); final CountDownLatch latch = new CountDownLatch(1); AbstractReplyProducingMessageConsumer consumer = new AbstractReplyProducingMessageConsumer() { - public Message handle(Message message) { + public void handle(Message message, ReplyHolder replyHolder) { latch.countDown(); - return null; } }; PollingConsumerEndpoint endpoint = new PollingConsumerEndpoint(consumer, errorChannel); diff --git a/org.springframework.integration/src/test/java/org/springframework/integration/bus/DirectChannelSubscriptionTests.java b/org.springframework.integration/src/test/java/org/springframework/integration/bus/DirectChannelSubscriptionTests.java index 4ad235134a..9351f5bafd 100644 --- a/org.springframework.integration/src/test/java/org/springframework/integration/bus/DirectChannelSubscriptionTests.java +++ b/org.springframework.integration/src/test/java/org/springframework/integration/bus/DirectChannelSubscriptionTests.java @@ -30,6 +30,7 @@ import org.springframework.integration.channel.ThreadLocalChannel; import org.springframework.integration.config.annotation.MessagingAnnotationPostProcessor; import org.springframework.integration.config.xml.MessageBusParser; import org.springframework.integration.endpoint.AbstractReplyProducingMessageConsumer; +import org.springframework.integration.endpoint.ReplyHolder; import org.springframework.integration.endpoint.ServiceActivatorEndpoint; import org.springframework.integration.endpoint.SubscribingConsumerEndpoint; import org.springframework.integration.message.Message; @@ -95,7 +96,7 @@ public class DirectChannelSubscriptionTests { @Test(expected = MessagingException.class) public void exceptionThrownFromRegisteredEndpoint() { AbstractReplyProducingMessageConsumer consumer = new AbstractReplyProducingMessageConsumer() { - public Message handle(Message message) { + public void handle(Message message, ReplyHolder replyHolder) { throw new RuntimeException("intentional test failure"); } }; diff --git a/org.springframework.integration/src/test/java/org/springframework/integration/channel/MessageChannelTemplateTests.java b/org.springframework.integration/src/test/java/org/springframework/integration/channel/MessageChannelTemplateTests.java index 8353cb397f..22dfd8814f 100644 --- a/org.springframework.integration/src/test/java/org/springframework/integration/channel/MessageChannelTemplateTests.java +++ b/org.springframework.integration/src/test/java/org/springframework/integration/channel/MessageChannelTemplateTests.java @@ -33,6 +33,7 @@ import org.springframework.context.support.GenericApplicationContext; import org.springframework.integration.bus.DefaultMessageBus; import org.springframework.integration.endpoint.AbstractReplyProducingMessageConsumer; import org.springframework.integration.endpoint.PollingConsumerEndpoint; +import org.springframework.integration.endpoint.ReplyHolder; import org.springframework.integration.message.Message; import org.springframework.integration.message.MessageBuilder; import org.springframework.integration.message.StringMessage; @@ -51,9 +52,9 @@ public class MessageChannelTemplateTests { this.requestChannel = new QueueChannel(); this.requestChannel.setBeanName("requestChannel"); AbstractReplyProducingMessageConsumer consumer = new AbstractReplyProducingMessageConsumer() { - public Message handle(Message message) { - return new StringMessage(message.getPayload().toString().toUpperCase()); - } + public void handle(Message message, ReplyHolder replyHolder) { + replyHolder.set(message.getPayload().toString().toUpperCase()); + } }; PollingConsumerEndpoint endpoint = new PollingConsumerEndpoint(consumer, requestChannel); endpoint.afterPropertiesSet(); diff --git a/org.springframework.integration/src/test/java/org/springframework/integration/endpoint/CorrelationIdTests.java b/org.springframework.integration/src/test/java/org/springframework/integration/endpoint/CorrelationIdTests.java index 6e21538f93..fc8762d862 100644 --- a/org.springframework.integration/src/test/java/org/springframework/integration/endpoint/CorrelationIdTests.java +++ b/org.springframework.integration/src/test/java/org/springframework/integration/endpoint/CorrelationIdTests.java @@ -50,20 +50,6 @@ public class CorrelationIdTests { assertEquals(correlationId, reply.getHeaders().getCorrelationId()); } - @Test - public void testCorrelationIdCopiedFromMessageIdByDefault() { - Message message = MessageBuilder.withPayload("test").build(); - DirectChannel inputChannel = new DirectChannel(); - QueueChannel outputChannel = new QueueChannel(1); - ServiceActivatorEndpoint serviceActivator = new ServiceActivatorEndpoint(new TestBean(), "upperCase"); - serviceActivator.setOutputChannel(outputChannel); - SubscribingConsumerEndpoint endpoint = new SubscribingConsumerEndpoint(serviceActivator, inputChannel); - endpoint.start(); - assertTrue(inputChannel.send(message)); - Message reply = outputChannel.receive(0); - assertEquals(message.getHeaders().getId(), reply.getHeaders().getCorrelationId()); - } - @Test public void testCorrelationIdCopiedFromMessageCorrelationIdIfAvailable() { Message message = MessageBuilder.withPayload("test") diff --git a/org.springframework.integration/src/test/java/org/springframework/integration/endpoint/ServiceActivatorEndpointTests.java b/org.springframework.integration/src/test/java/org/springframework/integration/endpoint/ServiceActivatorEndpointTests.java index 709693935b..ad1447f406 100644 --- a/org.springframework.integration/src/test/java/org/springframework/integration/endpoint/ServiceActivatorEndpointTests.java +++ b/org.springframework.integration/src/test/java/org/springframework/integration/endpoint/ServiceActivatorEndpointTests.java @@ -300,22 +300,6 @@ public class ServiceActivatorEndpointTests { assertNull(reply.getHeaders().getCorrelationId()); } - @Test - public void correlationIdSetForReplyMessage() { - QueueChannel replyChannel = new QueueChannel(1); - ServiceActivatorEndpoint endpoint = new ServiceActivatorEndpoint(new Object() { - @SuppressWarnings("unused") - public Message handle(Message message) { - return MessageBuilder.fromMessage(message).build(); - } - }, "handle"); - Message message = MessageBuilder.withPayload("test") - .setReturnAddress(replyChannel).build(); - endpoint.onMessage(message); - Message reply = replyChannel.receive(500); - assertEquals(message.getHeaders().getId(), reply.getHeaders().getCorrelationId()); - } - @Test public void correlationIdSetByHandlerTakesPrecedence() { QueueChannel replyChannel = new QueueChannel(1); diff --git a/org.springframework.integration/src/test/java/org/springframework/integration/filter/MessageFilterTests.java b/org.springframework.integration/src/test/java/org/springframework/integration/filter/MessageFilterTests.java index 9268430d04..d03a57b061 100644 --- a/org.springframework.integration/src/test/java/org/springframework/integration/filter/MessageFilterTests.java +++ b/org.springframework.integration/src/test/java/org/springframework/integration/filter/MessageFilterTests.java @@ -44,7 +44,10 @@ public class MessageFilterTests { } }); Message message = new StringMessage("test"); - assertEquals(message, filter.handle(message)); + QueueChannel output = new QueueChannel(); + filter.setOutputChannel(output); + filter.onMessage(message); + assertEquals(message, output.receive(0)); } @Test @@ -54,7 +57,10 @@ public class MessageFilterTests { return false; } }); - assertNull(filter.handle(new StringMessage("test"))); + QueueChannel output = new QueueChannel(); + filter.setOutputChannel(output); + filter.onMessage(new StringMessage("test")); + assertNull(output.receive(0)); } @Test diff --git a/org.springframework.integration/src/test/java/org/springframework/integration/splitter/DefaultSplitterTests.java b/org.springframework.integration/src/test/java/org/springframework/integration/splitter/DefaultSplitterTests.java index 9149f40878..1910c7381c 100644 --- a/org.springframework.integration/src/test/java/org/springframework/integration/splitter/DefaultSplitterTests.java +++ b/org.springframework.integration/src/test/java/org/springframework/integration/splitter/DefaultSplitterTests.java @@ -18,13 +18,16 @@ package org.springframework.integration.splitter; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; import java.util.Arrays; import java.util.List; import org.junit.Test; +import org.springframework.integration.channel.DirectChannel; import org.springframework.integration.channel.QueueChannel; +import org.springframework.integration.endpoint.SubscribingConsumerEndpoint; import org.springframework.integration.message.Message; import org.springframework.integration.message.MessageBuilder; @@ -75,4 +78,18 @@ public class DefaultSplitterTests { assertEquals("z", reply3.getPayload()); } + @Test + public void correlationIdCopiedFromMessageId() { + Message message = MessageBuilder.withPayload("test").build(); + DirectChannel inputChannel = new DirectChannel(); + QueueChannel outputChannel = new QueueChannel(1); + DefaultMessageSplitter splitter = new DefaultMessageSplitter(); + splitter.setOutputChannel(outputChannel); + SubscribingConsumerEndpoint endpoint = new SubscribingConsumerEndpoint(splitter, inputChannel); + endpoint.start(); + assertTrue(inputChannel.send(message)); + Message reply = outputChannel.receive(0); + assertEquals(message.getHeaders().getId(), reply.getHeaders().getCorrelationId()); + } + }