diff --git a/org.springframework.integration/src/main/java/org/springframework/integration/gateway/GatewayProxyFactoryBean.java b/org.springframework.integration/src/main/java/org/springframework/integration/gateway/GatewayProxyFactoryBean.java index 224144feba..1dfe9fbfa6 100644 --- a/org.springframework.integration/src/main/java/org/springframework/integration/gateway/GatewayProxyFactoryBean.java +++ b/org.springframework.integration/src/main/java/org/springframework/integration/gateway/GatewayProxyFactoryBean.java @@ -114,7 +114,7 @@ public class GatewayProxyFactoryBean extends SimpleMessagingGateway return invocation.proceed(); } - private Object invokeGatewayMethod(MethodInvocation invocation) throws Throwable { + private Object invokeGatewayMethod(MethodInvocation invocation) throws Exception { if (!this.initialized) { this.afterPropertiesSet(); } diff --git a/org.springframework.integration/src/main/java/org/springframework/integration/gateway/SimpleMessagingGateway.java b/org.springframework.integration/src/main/java/org/springframework/integration/gateway/SimpleMessagingGateway.java index f4a2ea68fc..be6b7c5cd5 100644 --- a/org.springframework.integration/src/main/java/org/springframework/integration/gateway/SimpleMessagingGateway.java +++ b/org.springframework.integration/src/main/java/org/springframework/integration/gateway/SimpleMessagingGateway.java @@ -27,7 +27,6 @@ import org.springframework.integration.handler.ReplyMessageCorrelator; import org.springframework.integration.message.DefaultMessageCreator; import org.springframework.integration.message.DefaultMessageMapper; import org.springframework.integration.message.Message; -import org.springframework.integration.message.MessageBuilder; import org.springframework.integration.message.MessageCreator; import org.springframework.integration.message.MessageDeliveryException; import org.springframework.integration.message.MessageMapper; @@ -48,10 +47,6 @@ public class SimpleMessagingGateway extends MessagingGatewaySupport implements M private volatile PollableChannel replyChannel; - private volatile long replyTimeout = -1; - - private volatile int replyMapCapacity = 100; - private volatile MessageCreator messageCreator = new DefaultMessageCreator(); private volatile MessageMapper messageMapper = new DefaultMessageMapper(); @@ -91,14 +86,6 @@ public class SimpleMessagingGateway extends MessagingGatewaySupport implements M this.replyChannel = replyChannel; } - /** - * Set the max capacity for the map that is used to store replies - * until requested by the correlationId. The default value is 100. - */ - public void setReplyMapCapacity(int replyMapCapacity) { - this.replyMapCapacity = replyMapCapacity; - } - public void setMessageCreator(MessageCreator messageCreator) { Assert.notNull(messageCreator, "messageCreator must not be null"); this.messageCreator = messageCreator; @@ -113,11 +100,6 @@ public class SimpleMessagingGateway extends MessagingGatewaySupport implements M this.endpointRegistry = messageBus; } - public void setReplyTimeout(long replyTimeout) { - this.replyTimeout = replyTimeout; - super.setReplyTimeout(replyTimeout); - } - public void send(Object object) { if (this.requestChannel == null) { throw new IllegalStateException( @@ -167,23 +149,10 @@ public class SimpleMessagingGateway extends MessagingGatewaySupport implements M throw new MessageDeliveryException(message, "No request channel available. Cannot send request message."); } - if (this.replyChannel != null) { - return this.sendAndReceiveWithReplyMessageCorrelator(message); - } - else { - return this.getChannelTemplate().sendAndReceive(message, this.requestChannel); - } - } - - private Message sendAndReceiveWithReplyMessageCorrelator(Message message) { - if (this.replyMessageCorrelator == null) { + if (this.replyChannel != null && this.replyMessageCorrelator == null) { this.registerReplyMessageCorrelator(); } - message = MessageBuilder.fromMessage(message).setReturnAddress(this.replyChannel).build(); - this.send(message); - return (this.replyTimeout >= 0) - ? this.replyMessageCorrelator.getReply(message.getHeaders().getId(), this.replyTimeout) - : this.replyMessageCorrelator.getReply(message.getHeaders().getId()); + return this.getChannelTemplate().sendAndReceive(message, this.requestChannel); } private void registerReplyMessageCorrelator() { @@ -194,7 +163,7 @@ public class SimpleMessagingGateway extends MessagingGatewaySupport implements M if (this.endpointRegistry == null) { throw new ConfigurationException("No EndpointRegistry available. Cannot register ReplyMessageCorrelator."); } - ReplyMessageCorrelator correlator = new ReplyMessageCorrelator(this.replyMapCapacity); + ReplyMessageCorrelator correlator = new ReplyMessageCorrelator(); correlator.setBeanName("internal.correlator." + this); correlator.setInputChannel(this.replyChannel); correlator.afterPropertiesSet(); diff --git a/org.springframework.integration/src/main/java/org/springframework/integration/handler/ReplyMessageCorrelator.java b/org.springframework.integration/src/main/java/org/springframework/integration/handler/ReplyMessageCorrelator.java index 3e4e9113d9..d98b684e4f 100644 --- a/org.springframework.integration/src/main/java/org/springframework/integration/handler/ReplyMessageCorrelator.java +++ b/org.springframework.integration/src/main/java/org/springframework/integration/handler/ReplyMessageCorrelator.java @@ -16,63 +16,45 @@ package org.springframework.integration.handler; +import org.springframework.integration.channel.MessageChannel; import org.springframework.integration.endpoint.AbstractInOutEndpoint; import org.springframework.integration.message.Message; import org.springframework.integration.message.MessageHandlingException; -import org.springframework.integration.message.RetrievalBlockingMessageStore; -import org.springframework.util.Assert; +import org.springframework.integration.message.MessagingException; /** - * A handler for receiving messages from a "reply channel". Any component that - * is expecting a reply message can poll by providing the correlation identifier. + * A handler for receiving messages from a "reply channel". * * @author Mark Fisher */ public class ReplyMessageCorrelator extends AbstractInOutEndpoint { - private volatile long defaultTimeout = 5000; - - private final RetrievalBlockingMessageStore messageStore; - - - public ReplyMessageCorrelator(int capacity) { - this.messageStore = new RetrievalBlockingMessageStore(capacity); - } - - - public void setDefaultTimeout(long defaultTimeout) { - Assert.isTrue(defaultTimeout >= 0, "'defaultTimeout' must not be negative"); - this.defaultTimeout = defaultTimeout; - } - @Override public Message handle(Message message) { - Object correlationId = this.getCorrelationId(message); - if (correlationId == null) { - throw new MessageHandlingException(message, - "unable to handle response, message has no correlationId: " + message); + Object returnAddress = message.getHeaders().getReturnAddress(); + if (returnAddress == null) { + throw new MessageHandlingException(message, + "unable to correlate response, message has no returnAddress"); + } + MessageChannel replyChannel = null; + if (returnAddress instanceof MessageChannel) { + replyChannel = (MessageChannel) returnAddress; + } + else if (returnAddress instanceof String) { + replyChannel = this.getChannelRegistry().lookupChannel((String) returnAddress); + if (replyChannel == null) { + throw new MessagingException(message, + "unable to resolve returnAddress '" + returnAddress + "'"); + } + } + else { + throw new MessagingException(message, + "invalid returnAddress type [" + returnAddress.getClass() + "]"); + } + if (replyChannel != null) { + replyChannel.send(message); } - this.messageStore.put(correlationId, message); return null; } - public Message getReply(Object correlationId) { - return this.getReply(correlationId, this.defaultTimeout); - } - - public Message getReply(Object correlationId, long timeout) { - Assert.notNull(correlationId, "'correlationId' must not be null"); - return this.messageStore.remove(correlationId, timeout); - } - - /** - * Retrieve the correlation identifier from the provided message. - *

- * This method may be overridden by subclasses. The default implementation - * returns the 'correlationId' from the message header. - */ - protected Object getCorrelationId(final Message message) { - return message.getHeaders().getCorrelationId(); - } - } diff --git a/org.springframework.integration/src/test/java/org/springframework/integration/gateway/GatewayProxyFactoryBeanTests.java b/org.springframework.integration/src/test/java/org/springframework/integration/gateway/GatewayProxyFactoryBeanTests.java index 310cd5123b..3f87c23210 100644 --- a/org.springframework.integration/src/test/java/org/springframework/integration/gateway/GatewayProxyFactoryBeanTests.java +++ b/org.springframework.integration/src/test/java/org/springframework/integration/gateway/GatewayProxyFactoryBeanTests.java @@ -124,7 +124,7 @@ public class GatewayProxyFactoryBeanTests { public void testMultipleMessagesWithResponseCorrelator() throws InterruptedException { ClassPathXmlApplicationContext context = new ClassPathXmlApplicationContext( "gatewayWithResponseCorrelator.xml", GatewayProxyFactoryBeanTests.class); - int numRequests = 5; + int numRequests = 500; final TestService service = (TestService) context.getBean("proxy"); final String[] results = new String[numRequests]; final CountDownLatch latch = new CountDownLatch(numRequests); diff --git a/org.springframework.integration/src/test/java/org/springframework/integration/gateway/config/GatewayParserTests.java b/org.springframework.integration/src/test/java/org/springframework/integration/gateway/config/GatewayParserTests.java index 96d0df0a06..456e1eac3e 100644 --- a/org.springframework.integration/src/test/java/org/springframework/integration/gateway/config/GatewayParserTests.java +++ b/org.springframework.integration/src/test/java/org/springframework/integration/gateway/config/GatewayParserTests.java @@ -64,7 +64,7 @@ public class GatewayParserTests { this.startResponder(requestChannel, replyChannel); TestService service = (TestService) context.getBean("requestReply"); String result = service.requestReply("foo"); - assertEquals("foobar", result); + assertEquals("foo", result); } @Test @@ -75,7 +75,7 @@ public class GatewayParserTests { this.startResponder(requestChannel, replyChannel); TestService service = (TestService) context.getBean("requestReplyWithMessageMapper"); String result = service.requestReply("foo"); - assertEquals("foobar.mapped", result); + assertEquals("foo.mapped", result); } @Test @@ -86,7 +86,7 @@ public class GatewayParserTests { this.startResponder(requestChannel, replyChannel); TestService service = (TestService) context.getBean("requestReplyWithMessageCreator"); String result = service.requestReply("foo"); - assertEquals("created.foobar", result); + assertEquals("created.foo", result); } @@ -94,7 +94,7 @@ public class GatewayParserTests { Executors.newSingleThreadExecutor().execute(new Runnable() { public void run() { Message request = requestChannel.receive(); - Message reply = MessageBuilder.fromPayload(request.getPayload() + "bar") + Message reply = MessageBuilder.fromMessage(request) .setCorrelationId(request.getHeaders().getId()).build(); replyChannel.send(reply); } diff --git a/org.springframework.integration/src/test/java/org/springframework/integration/gateway/gatewayWithResponseCorrelator.xml b/org.springframework.integration/src/test/java/org/springframework/integration/gateway/gatewayWithResponseCorrelator.xml index 711661a286..9e2d8a6836 100644 --- a/org.springframework.integration/src/test/java/org/springframework/integration/gateway/gatewayWithResponseCorrelator.xml +++ b/org.springframework.integration/src/test/java/org/springframework/integration/gateway/gatewayWithResponseCorrelator.xml @@ -18,7 +18,7 @@ - + diff --git a/org.springframework.integration/src/test/java/org/springframework/integration/handler/ReplyMessageCorrelatorTests.java b/org.springframework.integration/src/test/java/org/springframework/integration/handler/ReplyMessageCorrelatorTests.java index cee2c74527..02e537eece 100644 --- a/org.springframework.integration/src/test/java/org/springframework/integration/handler/ReplyMessageCorrelatorTests.java +++ b/org.springframework.integration/src/test/java/org/springframework/integration/handler/ReplyMessageCorrelatorTests.java @@ -18,14 +18,9 @@ package org.springframework.integration.handler; import static org.junit.Assert.assertEquals; -import java.util.concurrent.CountDownLatch; -import java.util.concurrent.Executor; -import java.util.concurrent.Executors; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicInteger; - import org.junit.Test; +import org.springframework.integration.channel.QueueChannel; import org.springframework.integration.message.Message; import org.springframework.integration.message.MessageBuilder; @@ -36,47 +31,13 @@ public class ReplyMessageCorrelatorTests { @Test public void testReceiversPrecedeReply() throws InterruptedException { - final ReplyMessageCorrelator correlator = new ReplyMessageCorrelator(10); - final AtomicInteger replyCounter = new AtomicInteger(); - CountDownLatch latch = startReceivers(correlator, replyCounter, 5, 500); + final ReplyMessageCorrelator correlator = new ReplyMessageCorrelator(); + QueueChannel replyChannel = new QueueChannel(); Message message = MessageBuilder.fromPayload("test") - .setCorrelationId("123").build(); + .setCorrelationId("123").setReturnAddress(replyChannel).build(); correlator.handle(message); - latch.await(1000, TimeUnit.MILLISECONDS); - assertEquals(0, latch.getCount()); - assertEquals(1, replyCounter.get()); - } - - @Test - public void testReplyPrecedeReceivers() throws InterruptedException { - final ReplyMessageCorrelator correlator = new ReplyMessageCorrelator(10); - Message message = MessageBuilder.fromPayload("test") - .setCorrelationId("123").build(); - correlator.handle(message); - final AtomicInteger replyCounter = new AtomicInteger(); - CountDownLatch latch = startReceivers(correlator, replyCounter, 5, 50); - latch.await(1000, TimeUnit.MILLISECONDS); - assertEquals(0, latch.getCount()); - assertEquals(1, replyCounter.get()); - } - - - private static CountDownLatch startReceivers(final ReplyMessageCorrelator correlator, - final AtomicInteger replyCounter, int numReceivers, final long timeout) { - final CountDownLatch latch = new CountDownLatch(numReceivers); - Executor executor = Executors.newFixedThreadPool(numReceivers); - for (int i = 0; i < numReceivers; i++) { - executor.execute(new Runnable() { - public void run() { - Message reply = correlator.getReply("123", timeout); - if (reply != null) { - replyCounter.incrementAndGet(); - } - latch.countDown(); - } - }); - } - return latch; + Message reply = replyChannel.receive(0); + assertEquals("test", reply.getPayload()); } }