diff --git a/org.springframework.integration/src/main/java/org/springframework/integration/aggregator/AbstractMessageAggregator.java b/org.springframework.integration/src/main/java/org/springframework/integration/aggregator/AbstractMessageAggregator.java index f9938a3ae2..7713860534 100644 --- a/org.springframework.integration/src/main/java/org/springframework/integration/aggregator/AbstractMessageAggregator.java +++ b/org.springframework.integration/src/main/java/org/springframework/integration/aggregator/AbstractMessageAggregator.java @@ -62,8 +62,8 @@ public abstract class AbstractMessageAggregator extends } @Override - protected MessageBarrier>, Object> createMessageBarrier() { - return new MessageBarrier>, Object>(new LinkedHashMap>()); + protected MessageBarrier>, Object> createMessageBarrier(Object correlationKey) { + return new MessageBarrier>, Object>(new LinkedHashMap>(), correlationKey); } @Override @@ -75,12 +75,12 @@ public abstract class AbstractMessageAggregator extends } } if (barrier.isComplete()) { - this.removeBarrier(barrier.getCorrelationId()); + this.removeBarrier(barrier.getCorrelationKey()); Message result = this.aggregateMessages(messageList); if (result != null) { if (result.getHeaders().getCorrelationId() == null) { result = MessageBuilder.fromMessage(result) - .setCorrelationId(barrier.getCorrelationId()) + .setCorrelationId(barrier.getCorrelationKey()) .build(); } this.sendReply(result, this.resolveReplyChannelFromMessage(messageList.get(0))); diff --git a/org.springframework.integration/src/main/java/org/springframework/integration/aggregator/AbstractMessageBarrierHandler.java b/org.springframework.integration/src/main/java/org/springframework/integration/aggregator/AbstractMessageBarrierHandler.java index 15076f0357..f6296998b1 100644 --- a/org.springframework.integration/src/main/java/org/springframework/integration/aggregator/AbstractMessageBarrierHandler.java +++ b/org.springframework.integration/src/main/java/org/springframework/integration/aggregator/AbstractMessageBarrierHandler.java @@ -111,8 +111,10 @@ public abstract class AbstractMessageBarrierHandler> private final Object lifecycleMonitor = new Object(); + private CorrelationStrategy correlationStrategy = new HeaderAttributeCorrelationStrategy(MessageHeaders.CORRELATION_ID); - public AbstractMessageBarrierHandler() { + + public AbstractMessageBarrierHandler() { this.channelTemplate.setSendTimeout(DEFAULT_SEND_TIMEOUT); } @@ -181,7 +183,11 @@ public abstract class AbstractMessageBarrierHandler> } } - public final void afterPropertiesSet() { + public void setCorrelationStrategy(CorrelationStrategy correlationStrategy) { + this.correlationStrategy = correlationStrategy; + } + + public final void afterPropertiesSet() { synchronized (this.lifecycleMonitor) { if (!this.initialized) { this.trackedCorrelationIds = new ArrayBlockingQueue(this.trackedCorrelationIdCapacity); @@ -223,20 +229,20 @@ public abstract class AbstractMessageBarrierHandler> if (!this.initialized) { this.afterPropertiesSet(); } - Object correlationId = message.getHeaders().getCorrelationId(); - if (correlationId == null) { + Object correlationKey = this.correlationStrategy.getCorrelationKey(message); + if (correlationKey == null) { throw new MessageHandlingException(message, this.getClass().getSimpleName() - + " requires the 'correlationId' property"); + + " requires the 'correlationKey' property"); } - if (this.trackedCorrelationIds.contains(correlationId)) { + if (this.trackedCorrelationIds.contains(correlationKey)) { if (logger.isDebugEnabled()) { - logger.debug("Handling of Message group with correlationId '" + correlationId + logger.debug("Handling of Message group with correlationKey '" + correlationKey + "' has already completed or timed out."); } this.discardMessage(message); } else { - this.processMessage(message, correlationId); + this.processMessage(message, correlationKey); } } @@ -249,10 +255,10 @@ public abstract class AbstractMessageBarrierHandler> } } - private void processMessage(Message message, Object correlationId) { - MessageBarrier barrier = barriers.putIfAbsent(correlationId, createMessageBarrier()); + private void processMessage(Message message, Object correlationKey) { + MessageBarrier barrier = barriers.putIfAbsent(correlationKey, createMessageBarrier(correlationKey)); if (barrier == null) { - barrier = barriers.get(message.getHeaders().getCorrelationId()); + barrier = barriers.get(correlationKey); } synchronized (barrier) { if (canAddMessage(message, barrier)) { @@ -335,7 +341,7 @@ public abstract class AbstractMessageBarrierHandler> /** * Factory method for creating a MessageBarrier implementation. */ - protected abstract MessageBarrier createMessageBarrier(); + protected abstract MessageBarrier createMessageBarrier(Object correlationKey); /** * A method for processing the information in the message barrier after a message has been added or on pruning. diff --git a/org.springframework.integration/src/main/java/org/springframework/integration/aggregator/CorrelationStrategy.java b/org.springframework.integration/src/main/java/org/springframework/integration/aggregator/CorrelationStrategy.java new file mode 100644 index 0000000000..b6591a5461 --- /dev/null +++ b/org.springframework.integration/src/main/java/org/springframework/integration/aggregator/CorrelationStrategy.java @@ -0,0 +1,31 @@ +/* + * Copyright 2002-2008 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.integration.aggregator; + +import org.springframework.integration.core.Message; + +/** + * Strategy for determining how messages shall be correlated. Implementations + * shall return the correlation key value associated with a particular message. + * + * @author: Marius Bogoevici + */ +public interface CorrelationStrategy { + + Object getCorrelationKey(Message message); + +} diff --git a/org.springframework.integration/src/main/java/org/springframework/integration/aggregator/HeaderAttributeCorrelationStrategy.java b/org.springframework.integration/src/main/java/org/springframework/integration/aggregator/HeaderAttributeCorrelationStrategy.java new file mode 100644 index 0000000000..9c9bb74cd2 --- /dev/null +++ b/org.springframework.integration/src/main/java/org/springframework/integration/aggregator/HeaderAttributeCorrelationStrategy.java @@ -0,0 +1,41 @@ +/* + * Copyright 2002-2008 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.integration.aggregator; + +import org.springframework.integration.core.Message; + +/** + * Default implementation of {@link CorrelationStrategy}. Uses a header + * attribute to determine the correlation key value. + * + * @author: Marius Bogoevici + */ +public class HeaderAttributeCorrelationStrategy implements CorrelationStrategy { + + private String attributeName; + + + public HeaderAttributeCorrelationStrategy(String attributeName) { + this.attributeName = attributeName; + } + + + public Object getCorrelationKey(Message message) { + return message.getHeaders().get(this.attributeName); + } + +} diff --git a/org.springframework.integration/src/main/java/org/springframework/integration/aggregator/MessageBarrier.java b/org.springframework.integration/src/main/java/org/springframework/integration/aggregator/MessageBarrier.java index a497f53d6e..7c3987be68 100644 --- a/org.springframework.integration/src/main/java/org/springframework/integration/aggregator/MessageBarrier.java +++ b/org.springframework.integration/src/main/java/org/springframework/integration/aggregator/MessageBarrier.java @@ -39,18 +39,17 @@ public class MessageBarrier>, K> { private volatile boolean complete = false; + private Object correlationKey; + private final long timestamp = System.currentTimeMillis(); - public MessageBarrier(T messages) { + public MessageBarrier(T messages, Object correlationKey) { this.messages = messages; + this.correlationKey = correlationKey; } - public Object getCorrelationId() { - if (!messages.isEmpty()) { - return messages.values().iterator().next().getHeaders() - .getCorrelationId(); - } - return null; + public Object getCorrelationKey() { + return this.correlationKey; } /** diff --git a/org.springframework.integration/src/main/java/org/springframework/integration/aggregator/Resequencer.java b/org.springframework.integration/src/main/java/org/springframework/integration/aggregator/Resequencer.java index 23723bbac9..d5b4b08e80 100644 --- a/org.springframework.integration/src/main/java/org/springframework/integration/aggregator/Resequencer.java +++ b/org.springframework.integration/src/main/java/org/springframework/integration/aggregator/Resequencer.java @@ -50,9 +50,9 @@ public class Resequencer extends AbstractMessageBarrierHandler>, Integer> createMessageBarrier() { + protected MessageBarrier>, Integer> createMessageBarrier(Object correlationKey) { MessageBarrier>, Integer> messageBarrier - = new MessageBarrier>, Integer>(new TreeMap>()); + = new MessageBarrier>, Integer>(new TreeMap>(), correlationKey); messageBarrier.getMessages().put(0, createFlagMessage(0)); return messageBarrier; } @@ -66,7 +66,7 @@ public class Resequencer extends AbstractMessageBarrierHandler lastMessage = releasedMessages.get(releasedMessages.size()-1); if (lastMessage.getHeaders().getSequenceNumber().equals(lastMessage.getHeaders().getSequenceSize() - 1)) { - this.removeBarrier(barrier.getCorrelationId()); + this.removeBarrier(barrier.getCorrelationKey()); } this.sendReplies(releasedMessages, this.resolveReplyChannelFromMessage(releasedMessages.get(0))); } diff --git a/org.springframework.integration/src/test/java/org/springframework/integration/aggregator/HeaderAttributeCorrelationStrategyTests.java b/org.springframework.integration/src/test/java/org/springframework/integration/aggregator/HeaderAttributeCorrelationStrategyTests.java new file mode 100644 index 0000000000..534b304216 --- /dev/null +++ b/org.springframework.integration/src/test/java/org/springframework/integration/aggregator/HeaderAttributeCorrelationStrategyTests.java @@ -0,0 +1,43 @@ +/* + * Copyright 2002-2008 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.integration.aggregator; + +import static org.junit.Assert.assertEquals; + +import org.springframework.integration.core.Message; +import org.springframework.integration.core.MessageHeaders; +import org.springframework.integration.message.MessageBuilder; + +import org.junit.Test; + +/** + * @author: Marius Bogoevici + */ +public class HeaderAttributeCorrelationStrategyTests { + + @Test + public void testHeaderAttributeCorrelationStrategy() { + String testedHeaderValue = "@!arbitraryTestValue!@"; + String testHeaderName = "header.for.test"; + Message message = MessageBuilder.withPayload("irrelevantData").setHeader(testHeaderName, testedHeaderValue).build(); + HeaderAttributeCorrelationStrategy correlationStrategy = new HeaderAttributeCorrelationStrategy(testHeaderName); + assertEquals(testedHeaderValue, correlationStrategy.getCorrelationKey(message)); + } + + + +} diff --git a/org.springframework.integration/src/test/java/org/springframework/integration/aggregator/MessageBarrierTests.java b/org.springframework.integration/src/test/java/org/springframework/integration/aggregator/MessageBarrierTests.java index e1105551d0..581588894c 100644 --- a/org.springframework.integration/src/test/java/org/springframework/integration/aggregator/MessageBarrierTests.java +++ b/org.springframework.integration/src/test/java/org/springframework/integration/aggregator/MessageBarrierTests.java @@ -32,7 +32,7 @@ public class MessageBarrierTests { @Test public void testMessageRetrieval() { - MessageBarrier barrier = new MessageBarrier(new LinkedHashMap()); + MessageBarrier barrier = new MessageBarrier(new LinkedHashMap(), null); barrier.getMessages().put("1", new StringMessage("test1")); assertEquals(1, barrier.getMessages().size()); barrier.getMessages().put("2", new StringMessage("test2")); @@ -42,7 +42,7 @@ public class MessageBarrierTests { @Test public void testTimestamp() { long before = System.currentTimeMillis(); - MessageBarrier barrier = new MessageBarrier(new LinkedHashMap()); + MessageBarrier barrier = new MessageBarrier(new LinkedHashMap(), null); long timestamp = barrier.getTimestamp(); assertTrue(before <= timestamp); long after = System.currentTimeMillis(); @@ -51,7 +51,7 @@ public class MessageBarrierTests { @Test public void testEmptyMessageList() { - MessageBarrier barrier = new MessageBarrier(new LinkedHashMap()); + MessageBarrier barrier = new MessageBarrier(new LinkedHashMap(), null); assertEquals(0, barrier.getMessages().size()); }