diff --git a/org.springframework.integration/src/test/java/org/springframework/integration/endpoint/PollingConsumerEndpointTests.java b/org.springframework.integration/src/test/java/org/springframework/integration/endpoint/PollingConsumerEndpointTests.java index 69a0383293..fb2b4d0ac9 100644 --- a/org.springframework.integration/src/test/java/org/springframework/integration/endpoint/PollingConsumerEndpointTests.java +++ b/org.springframework.integration/src/test/java/org/springframework/integration/endpoint/PollingConsumerEndpointTests.java @@ -22,11 +22,13 @@ import static org.easymock.EasyMock.expectLastCall; import static org.easymock.EasyMock.replay; import static org.easymock.EasyMock.reset; import static org.easymock.EasyMock.verify; +import static org.junit.Assert.assertEquals; import java.util.Date; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; import org.junit.After; import org.junit.Before; @@ -51,7 +53,7 @@ public class PollingConsumerEndpointTests { private PollingConsumerEndpoint endpoint; - private Trigger trigger = new TestTrigger(); + private TestTrigger trigger = new TestTrigger(); private TestConsumer consumer = new TestConsumer(); @@ -68,6 +70,8 @@ public class PollingConsumerEndpointTests { @Before public void init() throws InterruptedException { + consumer.counter.set(0); + trigger.reset(); endpoint = new PollingConsumerEndpoint(consumer, channelMock); endpoint.setTaskScheduler(taskScheduler); taskScheduler.setErrorHandler(errorHandler); @@ -84,37 +88,40 @@ public class PollingConsumerEndpointTests { @Test - public void singleMessage() throws InterruptedException { + public void singleMessage() { expect(channelMock.receive()).andReturn(message); expectLastCall(); replay(channelMock); endpoint.setMaxMessagesPerPoll(1); endpoint.start(); - consumer.await(500); + trigger.await(); endpoint.stop(); + assertEquals(1, consumer.counter.get()); verify(channelMock); } @Test - public void multipleMessages() throws InterruptedException { + public void multipleMessages() { expect(channelMock.receive()).andReturn(message).times(5); replay(channelMock); endpoint.setMaxMessagesPerPoll(5); endpoint.start(); - consumer.await(500); + trigger.await(); endpoint.stop(); + assertEquals(5, consumer.counter.get()); verify(channelMock); } @Test - public void multipleMessages_underrun() throws InterruptedException { + public void multipleMessages_underrun() { expect(channelMock.receive()).andReturn(message).times(5); expect(channelMock.receive()).andReturn(null); replay(channelMock); endpoint.setMaxMessagesPerPoll(6); endpoint.start(); - consumer.await(500); + trigger.await(); endpoint.stop(); + assertEquals(5, consumer.counter.get()); verify(channelMock); } @@ -123,9 +130,10 @@ public class PollingConsumerEndpointTests { expect(channelMock.receive()).andReturn(badMessage); replay(channelMock); endpoint.start(); - consumer.await(500); + trigger.await(); endpoint.stop(); verify(channelMock); + assertEquals(1, consumer.counter.get()); errorHandler.throwLastErrorIfAvailable(); } @@ -135,64 +143,49 @@ public class PollingConsumerEndpointTests { replay(channelMock); endpoint.setMaxMessagesPerPoll(10); endpoint.start(); - consumer.await(500); + trigger.await(); endpoint.stop(); verify(channelMock); + assertEquals(1, consumer.counter.get()); errorHandler.throwLastErrorIfAvailable(); } - @Test(expected = TestTimeoutException.class) - public void blockingSourceTimedOut() throws InterruptedException { + @Test + public void blockingSourceTimedOut() { // we don't need to await the timeout, returning null suffices expect(channelMock.receive(1)).andReturn(null); replay(channelMock); endpoint.setReceiveTimeout(1); endpoint.start(); - try { - consumer.await(500); - } - finally { - endpoint.stop(); - verify(channelMock); - } + trigger.await(); + endpoint.stop(); + assertEquals(0, consumer.counter.get()); + verify(channelMock); } @Test - public void blockingSourceNotTimedOut() throws InterruptedException { + public void blockingSourceNotTimedOut() { expect(channelMock.receive(1)).andReturn(message); expectLastCall(); replay(channelMock); endpoint.setReceiveTimeout(1); endpoint.setMaxMessagesPerPoll(1); endpoint.start(); - consumer.await(500); + trigger.await(); endpoint.stop(); + assertEquals(1, consumer.counter.get()); verify(channelMock); } private static class TestConsumer implements MessageConsumer { - private volatile CountDownLatch latch = new CountDownLatch(1); + private volatile AtomicInteger counter = new AtomicInteger(); public void onMessage(Message message) { - try { - if ("bad".equals(message.getPayload().toString())) { - throw new MessageRejectedException(message, "intentional test failure"); - } - } - finally { - this.latch.countDown(); - } - } - - public void await(long timeout) throws InterruptedException { - this.latch.await(timeout, TimeUnit.MILLISECONDS); - if (this.latch.getCount() == 0) { - this.latch = new CountDownLatch(1); - } - else { - throw new TestTimeoutException(); + this.counter.incrementAndGet(); + if ("bad".equals(message.getPayload().toString())) { + throw new MessageRejectedException(message, "intentional test failure"); } } } @@ -202,12 +195,33 @@ public class PollingConsumerEndpointTests { private final AtomicBoolean hasRun = new AtomicBoolean(); + private volatile CountDownLatch latch = new CountDownLatch(1); + + public Date getNextRunTime(Date lastScheduledRunTime, Date lastCompleteTime) { - if (!hasRun.getAndSet(true)) { + if (!this.hasRun.getAndSet(true)) { return new Date(); } + this.latch.countDown(); return null; } + + public void reset() { + this.latch = new CountDownLatch(1); + this.hasRun.set(false); + } + + public void await() { + try { + this.latch.await(5000, TimeUnit.MILLISECONDS); + if (latch.getCount() != 0) { + throw new RuntimeException("test latch.await() did not count down"); + } + } + catch (InterruptedException e) { + throw new RuntimeException("test latch.await() interrupted"); + } + } } @@ -226,9 +240,4 @@ public class PollingConsumerEndpointTests { } } - - @SuppressWarnings("serial") - private static class TestTimeoutException extends RuntimeException { - } - }