diff --git a/org.springframework.integration/src/main/java/org/springframework/integration/splitter/AbstractMessageSplitter.java b/org.springframework.integration/src/main/java/org/springframework/integration/splitter/AbstractMessageSplitter.java index d07fbb33b7..22736f8d62 100644 --- a/org.springframework.integration/src/main/java/org/springframework/integration/splitter/AbstractMessageSplitter.java +++ b/org.springframework.integration/src/main/java/org/springframework/integration/splitter/AbstractMessageSplitter.java @@ -35,7 +35,8 @@ public abstract class AbstractMessageSplitter extends AbstractReplyProducingMess if (result == null) { return; } - Object correlationId = message.getHeaders().getId(); + Object correlationId = (message.getHeaders().getCorrelationId() != null) ? + message.getHeaders().getCorrelationId(): message.getHeaders().getId(); if (result instanceof Collection) { Collection items = (Collection) result; int sequenceNumber = 0; diff --git a/org.springframework.integration/src/test/java/org/springframework/integration/endpoint/CorrelationIdTests.java b/org.springframework.integration/src/test/java/org/springframework/integration/endpoint/CorrelationIdTests.java index d4d3e2b9fb..0f1cafa65b 100644 --- a/org.springframework.integration/src/test/java/org/springframework/integration/endpoint/CorrelationIdTests.java +++ b/org.springframework.integration/src/test/java/org/springframework/integration/endpoint/CorrelationIdTests.java @@ -98,7 +98,7 @@ public class CorrelationIdTests { } @Test - public void testCorrelationIdWithSplitter() throws Exception { + public void testCorrelationIdWithSplitterWhenNotValueSetOnIncomingMessage() throws Exception { Message message = new StringMessage("test1,test2"); QueueChannel testChannel = new QueueChannel(); MethodInvokingSplitter splitter = new MethodInvokingSplitter( @@ -111,6 +111,21 @@ public class CorrelationIdTests { assertEquals(message.getHeaders().getId(), reply2.getHeaders().getCorrelationId()); } + @Test + public void testCorrelationIdWithSplitterWhenValueSetOnIncomingMessage() throws Exception { + + final String correlationIdForTest = "#FOR_TEST#"; + Message message = MessageBuilder.withPayload("test1,test2").setCorrelationId(correlationIdForTest).build(); + QueueChannel testChannel = new QueueChannel(); + MethodInvokingSplitter splitter = new MethodInvokingSplitter( + new TestBean(), TestBean.class.getMethod("split", String.class)); + splitter.setOutputChannel(testChannel); + splitter.handleMessage(message); + Message reply1 = testChannel.receive(100); + Message reply2 = testChannel.receive(100); + assertEquals(correlationIdForTest, reply1.getHeaders().getCorrelationId()); + assertEquals(correlationIdForTest, reply2.getHeaders().getCorrelationId()); + } private static class TestBean {