Added ReplyHolder and modified signature of AbstractReplyProducingMessageConsumer to return void and accept this additional parameter instead.

This commit is contained in:
Mark Fisher
2008-10-13 19:30:23 +00:00
parent 78ae171dd7
commit 6f7bd01c2a
22 changed files with 249 additions and 245 deletions

View File

@@ -153,9 +153,9 @@ public class AggregatorEndpointTests {
QueueChannel replyChannel = new QueueChannel();
QueueChannel discardChannel = new QueueChannel();
this.aggregator.setDiscardChannel(discardChannel);
this.aggregator.handle(createMessage("test-1a", 1, 1, 1, replyChannel));
this.aggregator.onMessage(createMessage("test-1a", 1, 1, 1, replyChannel));
assertEquals("test-1a", replyChannel.receive(100).getPayload());
this.aggregator.handle(createMessage("test-1b", 1, 1, 1, replyChannel));
this.aggregator.onMessage(createMessage("test-1b", 1, 1, 1, replyChannel));
assertEquals("test-1b", discardChannel.receive(100).getPayload());
}
@@ -165,13 +165,13 @@ public class AggregatorEndpointTests {
QueueChannel discardChannel = new QueueChannel();
this.aggregator.setTrackedCorrelationIdCapacity(3);
this.aggregator.setDiscardChannel(discardChannel);
this.aggregator.handle(createMessage("test-1a", 1, 1, 1, replyChannel));
this.aggregator.onMessage(createMessage("test-1a", 1, 1, 1, replyChannel));
assertEquals("test-1a", replyChannel.receive(100).getPayload());
this.aggregator.handle(createMessage("test-2", 2, 1, 1, replyChannel));
this.aggregator.onMessage(createMessage("test-2", 2, 1, 1, replyChannel));
assertEquals("test-2", replyChannel.receive(100).getPayload());
this.aggregator.handle(createMessage("test-3", 3, 1, 1, replyChannel));
this.aggregator.onMessage(createMessage("test-3", 3, 1, 1, replyChannel));
assertEquals("test-3", replyChannel.receive(100).getPayload());
this.aggregator.handle(createMessage("test-1b", 1, 1, 1, replyChannel));
this.aggregator.onMessage(createMessage("test-1b", 1, 1, 1, replyChannel));
assertEquals("test-1b", discardChannel.receive(100).getPayload());
}
@@ -181,23 +181,23 @@ public class AggregatorEndpointTests {
QueueChannel discardChannel = new QueueChannel();
this.aggregator.setTrackedCorrelationIdCapacity(3);
this.aggregator.setDiscardChannel(discardChannel);
this.aggregator.handle(createMessage("test-1a", 1, 1, 1, replyChannel));
this.aggregator.onMessage(createMessage("test-1a", 1, 1, 1, replyChannel));
assertEquals("test-1a", replyChannel.receive(100).getPayload());
this.aggregator.handle(createMessage("test-2", 2, 1, 1, replyChannel));
this.aggregator.onMessage(createMessage("test-2", 2, 1, 1, replyChannel));
assertEquals("test-2", replyChannel.receive(100).getPayload());
this.aggregator.handle(createMessage("test-3", 3, 1, 1, replyChannel));
this.aggregator.onMessage(createMessage("test-3", 3, 1, 1, replyChannel));
assertEquals("test-3", replyChannel.receive(100).getPayload());
this.aggregator.handle(createMessage("test-4", 4, 1, 1, replyChannel));
this.aggregator.onMessage(createMessage("test-4", 4, 1, 1, replyChannel));
assertEquals("test-4", replyChannel.receive(100).getPayload());
this.aggregator.handle(createMessage("test-1b", 1, 1, 1, replyChannel));
this.aggregator.onMessage(createMessage("test-1b", 1, 1, 1, replyChannel));
assertEquals("test-1b", replyChannel.receive(100).getPayload());
assertNull(discardChannel.receive(0));
}
@Test(expected=MessageHandlingException.class)
@Test(expected = MessageHandlingException.class)
public void testExceptionThrownIfNoCorrelationId() throws InterruptedException {
Message<?> message = createMessage("123", null, 2, 1, new QueueChannel());
this.aggregator.handle(message);
this.aggregator.onMessage(message);
}
@Test
@@ -311,16 +311,7 @@ public class AggregatorEndpointTests {
public void run() {
try {
Message<?> result = this.aggregator.handle(message);
if (result != null) {
Object returnAddress = message.getHeaders().getReturnAddress();
if (returnAddress instanceof MessageChannel) {
((MessageChannel) returnAddress).send(result);
}
else {
throw new IllegalStateException("'returnAddress' was not a MessageChannel instance");
}
}
this.aggregator.onMessage(message);
}
catch (Exception e) {
this.exception = e;

View File

@@ -57,9 +57,9 @@ public class ResequencerTests {
Message<?> message1 = createMessage("123", "ABC", 3, 3, replyChannel);
Message<?> message2 = createMessage("456", "ABC", 3, 1, replyChannel);
Message<?> message3 = createMessage("789", "ABC", 3, 2, replyChannel);
this.resequencer.handle(message1);
this.resequencer.handle(message3);
this.resequencer.handle(message2);
this.resequencer.onMessage(message1);
this.resequencer.onMessage(message3);
this.resequencer.onMessage(message2);
Message<?> reply1 = replyChannel.receive(0);
Message<?> reply2 = replyChannel.receive(0);
Message<?> reply3 = replyChannel.receive(0);
@@ -79,9 +79,9 @@ public class ResequencerTests {
Message<?> message2 = createMessage("456", "ABC", 4, 1, replyChannel);
Message<?> message3 = createMessage("789", "ABC", 4, 4, replyChannel);
Message<?> message4 = createMessage("XYZ", "ABC", 4, 3, replyChannel);
this.resequencer.handle(message1);
this.resequencer.handle(message2);
this.resequencer.handle(message3);
this.resequencer.onMessage(message1);
this.resequencer.onMessage(message2);
this.resequencer.onMessage(message3);
Message<?> reply1 = replyChannel.receive(0);
Message<?> reply2 = replyChannel.receive(0);
Message<?> reply3 = replyChannel.receive(0);
@@ -92,7 +92,7 @@ public class ResequencerTests {
assertEquals(new Integer(2), reply2.getHeaders().getSequenceNumber());
assertNull(reply3);
// when sending the last message, the whole sequence must have been sent
this.resequencer.handle(message4);
this.resequencer.onMessage(message4);
reply3 = replyChannel.receive(0);
Message<?> reply4 = replyChannel.receive(0);
assertNotNull(reply3);
@@ -110,9 +110,9 @@ public class ResequencerTests {
Message<?> message2 = createMessage("456", "ABC", 4, 1, replyChannel);
Message<?> message3 = createMessage("789", "ABC", 4, 4, replyChannel);
Message<?> message4 = createMessage("XYZ", "ABC", 4, 3, replyChannel);
this.resequencer.handle(message1);
this.resequencer.handle(message2);
this.resequencer.handle(message3);
this.resequencer.onMessage(message1);
this.resequencer.onMessage(message2);
this.resequencer.onMessage(message3);
Message<?> reply1 = replyChannel.receive(0);
Message<?> reply2 = replyChannel.receive(0);
Message<?> reply3 = replyChannel.receive(0);
@@ -121,7 +121,7 @@ public class ResequencerTests {
assertNull(reply2);
assertNull(reply3);
// after sending the last message, the whole sequence should have been sent
this.resequencer.handle(message4);
this.resequencer.onMessage(message4);
reply1 = replyChannel.receive(0);
reply2 = replyChannel.receive(0);
reply3 = replyChannel.receive(0);

View File

@@ -37,6 +37,7 @@ import org.springframework.integration.channel.QueueChannel;
import org.springframework.integration.config.xml.MessageBusParser;
import org.springframework.integration.endpoint.AbstractReplyProducingMessageConsumer;
import org.springframework.integration.endpoint.PollingConsumerEndpoint;
import org.springframework.integration.endpoint.ReplyHolder;
import org.springframework.integration.endpoint.SourcePollingChannelAdapter;
import org.springframework.integration.endpoint.SubscribingConsumerEndpoint;
import org.springframework.integration.message.ErrorMessage;
@@ -66,8 +67,8 @@ public class DefaultMessageBusTests {
.setReturnAddress("targetChannel").build();
sourceChannel.send(message);
AbstractReplyProducingMessageConsumer consumer = new AbstractReplyProducingMessageConsumer() {
public Message<?> handle(Message<?> message) {
return message;
public void handle(Message<?> message, ReplyHolder replyHolder) {
replyHolder.set(message);
}
};
consumer.setBeanFactory(context);
@@ -124,13 +125,13 @@ public class DefaultMessageBusTests {
QueueChannel outputChannel1 = new QueueChannel();
QueueChannel outputChannel2 = new QueueChannel();
AbstractReplyProducingMessageConsumer consumer1 = new AbstractReplyProducingMessageConsumer() {
public Message<?> handle(Message<?> message) {
return MessageBuilder.fromMessage(message).build();
public void handle(Message<?> message, ReplyHolder replyHolder) {
replyHolder.set(message);
}
};
AbstractReplyProducingMessageConsumer consumer2 = new AbstractReplyProducingMessageConsumer() {
public Message<?> handle(Message<?> message) {
return MessageBuilder.fromMessage(message).build();
public void handle(Message<?> message, ReplyHolder replyHolder) {
replyHolder.set(message);
}
};
inputChannel.setBeanName("input");
@@ -166,17 +167,15 @@ public class DefaultMessageBusTests {
QueueChannel outputChannel2 = new QueueChannel();
final CountDownLatch latch = new CountDownLatch(2);
AbstractReplyProducingMessageConsumer consumer1 = new AbstractReplyProducingMessageConsumer() {
public Message<?> handle(Message<?> message) {
Message<?> reply = MessageBuilder.fromMessage(message).build();
public void handle(Message<?> message, ReplyHolder replyHolder) {
replyHolder.set(message);
latch.countDown();
return reply;
}
};
AbstractReplyProducingMessageConsumer consumer2 = new AbstractReplyProducingMessageConsumer() {
public Message<?> handle(Message<?> message) {
Message<?> reply = MessageBuilder.fromMessage(message).build();
public void handle(Message<?> message, ReplyHolder replyHolder) {
replyHolder.set(message);
latch.countDown();
return reply;
}
};
inputChannel.setBeanName("input");
@@ -246,9 +245,8 @@ public class DefaultMessageBusTests {
context.getBeanFactory().registerSingleton(DefaultMessageBus.ERROR_CHANNEL_BEAN_NAME, errorChannel);
final CountDownLatch latch = new CountDownLatch(1);
AbstractReplyProducingMessageConsumer consumer = new AbstractReplyProducingMessageConsumer() {
public Message<?> handle(Message<?> message) {
public void handle(Message<?> message, ReplyHolder replyHolder) {
latch.countDown();
return null;
}
};
PollingConsumerEndpoint endpoint = new PollingConsumerEndpoint(consumer, errorChannel);

View File

@@ -30,6 +30,7 @@ import org.springframework.integration.channel.ThreadLocalChannel;
import org.springframework.integration.config.annotation.MessagingAnnotationPostProcessor;
import org.springframework.integration.config.xml.MessageBusParser;
import org.springframework.integration.endpoint.AbstractReplyProducingMessageConsumer;
import org.springframework.integration.endpoint.ReplyHolder;
import org.springframework.integration.endpoint.ServiceActivatorEndpoint;
import org.springframework.integration.endpoint.SubscribingConsumerEndpoint;
import org.springframework.integration.message.Message;
@@ -95,7 +96,7 @@ public class DirectChannelSubscriptionTests {
@Test(expected = MessagingException.class)
public void exceptionThrownFromRegisteredEndpoint() {
AbstractReplyProducingMessageConsumer consumer = new AbstractReplyProducingMessageConsumer() {
public Message<?> handle(Message<?> message) {
public void handle(Message<?> message, ReplyHolder replyHolder) {
throw new RuntimeException("intentional test failure");
}
};

View File

@@ -33,6 +33,7 @@ import org.springframework.context.support.GenericApplicationContext;
import org.springframework.integration.bus.DefaultMessageBus;
import org.springframework.integration.endpoint.AbstractReplyProducingMessageConsumer;
import org.springframework.integration.endpoint.PollingConsumerEndpoint;
import org.springframework.integration.endpoint.ReplyHolder;
import org.springframework.integration.message.Message;
import org.springframework.integration.message.MessageBuilder;
import org.springframework.integration.message.StringMessage;
@@ -51,9 +52,9 @@ public class MessageChannelTemplateTests {
this.requestChannel = new QueueChannel();
this.requestChannel.setBeanName("requestChannel");
AbstractReplyProducingMessageConsumer consumer = new AbstractReplyProducingMessageConsumer() {
public Message<?> handle(Message<?> message) {
return new StringMessage(message.getPayload().toString().toUpperCase());
}
public void handle(Message<?> message, ReplyHolder replyHolder) {
replyHolder.set(message.getPayload().toString().toUpperCase());
}
};
PollingConsumerEndpoint endpoint = new PollingConsumerEndpoint(consumer, requestChannel);
endpoint.afterPropertiesSet();

View File

@@ -50,20 +50,6 @@ public class CorrelationIdTests {
assertEquals(correlationId, reply.getHeaders().getCorrelationId());
}
@Test
public void testCorrelationIdCopiedFromMessageIdByDefault() {
Message<String> message = MessageBuilder.withPayload("test").build();
DirectChannel inputChannel = new DirectChannel();
QueueChannel outputChannel = new QueueChannel(1);
ServiceActivatorEndpoint serviceActivator = new ServiceActivatorEndpoint(new TestBean(), "upperCase");
serviceActivator.setOutputChannel(outputChannel);
SubscribingConsumerEndpoint endpoint = new SubscribingConsumerEndpoint(serviceActivator, inputChannel);
endpoint.start();
assertTrue(inputChannel.send(message));
Message<?> reply = outputChannel.receive(0);
assertEquals(message.getHeaders().getId(), reply.getHeaders().getCorrelationId());
}
@Test
public void testCorrelationIdCopiedFromMessageCorrelationIdIfAvailable() {
Message<String> message = MessageBuilder.withPayload("test")

View File

@@ -300,22 +300,6 @@ public class ServiceActivatorEndpointTests {
assertNull(reply.getHeaders().getCorrelationId());
}
@Test
public void correlationIdSetForReplyMessage() {
QueueChannel replyChannel = new QueueChannel(1);
ServiceActivatorEndpoint endpoint = new ServiceActivatorEndpoint(new Object() {
@SuppressWarnings("unused")
public Message<?> handle(Message<?> message) {
return MessageBuilder.fromMessage(message).build();
}
}, "handle");
Message<String> message = MessageBuilder.withPayload("test")
.setReturnAddress(replyChannel).build();
endpoint.onMessage(message);
Message<?> reply = replyChannel.receive(500);
assertEquals(message.getHeaders().getId(), reply.getHeaders().getCorrelationId());
}
@Test
public void correlationIdSetByHandlerTakesPrecedence() {
QueueChannel replyChannel = new QueueChannel(1);

View File

@@ -44,7 +44,10 @@ public class MessageFilterTests {
}
});
Message<?> message = new StringMessage("test");
assertEquals(message, filter.handle(message));
QueueChannel output = new QueueChannel();
filter.setOutputChannel(output);
filter.onMessage(message);
assertEquals(message, output.receive(0));
}
@Test
@@ -54,7 +57,10 @@ public class MessageFilterTests {
return false;
}
});
assertNull(filter.handle(new StringMessage("test")));
QueueChannel output = new QueueChannel();
filter.setOutputChannel(output);
filter.onMessage(new StringMessage("test"));
assertNull(output.receive(0));
}
@Test

View File

@@ -18,13 +18,16 @@ package org.springframework.integration.splitter;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue;
import java.util.Arrays;
import java.util.List;
import org.junit.Test;
import org.springframework.integration.channel.DirectChannel;
import org.springframework.integration.channel.QueueChannel;
import org.springframework.integration.endpoint.SubscribingConsumerEndpoint;
import org.springframework.integration.message.Message;
import org.springframework.integration.message.MessageBuilder;
@@ -75,4 +78,18 @@ public class DefaultSplitterTests {
assertEquals("z", reply3.getPayload());
}
@Test
public void correlationIdCopiedFromMessageId() {
Message<String> message = MessageBuilder.withPayload("test").build();
DirectChannel inputChannel = new DirectChannel();
QueueChannel outputChannel = new QueueChannel(1);
DefaultMessageSplitter splitter = new DefaultMessageSplitter();
splitter.setOutputChannel(outputChannel);
SubscribingConsumerEndpoint endpoint = new SubscribingConsumerEndpoint(splitter, inputChannel);
endpoint.start();
assertTrue(inputChannel.send(message));
Message<?> reply = outputChannel.receive(0);
assertEquals(message.getHeaders().getId(), reply.getHeaders().getCorrelationId());
}
}