diff --git a/spring-rabbit/src/main/java/org/springframework/amqp/rabbit/support/PublisherCallbackChannelImpl.java b/spring-rabbit/src/main/java/org/springframework/amqp/rabbit/support/PublisherCallbackChannelImpl.java index 193f3727..f7c9ec61 100644 --- a/spring-rabbit/src/main/java/org/springframework/amqp/rabbit/support/PublisherCallbackChannelImpl.java +++ b/spring-rabbit/src/main/java/org/springframework/amqp/rabbit/support/PublisherCallbackChannelImpl.java @@ -541,7 +541,9 @@ public class PublisherCallbackChannelImpl implements PublisherCallbackChannel, C public void addPendingConfirm(Listener listener, long seq, PendingConfirm pendingConfirm) { SortedMap pendingConfirmsForListener = this.pendingConfirms.get(listener); Assert.notNull(pendingConfirmsForListener, "Listener not registered"); - pendingConfirmsForListener.put(seq, pendingConfirm); + synchronized (this.pendingConfirms) { + pendingConfirmsForListener.put(seq, pendingConfirm); + } this.listenerForSeq.put(seq, listener); } diff --git a/spring-rabbit/src/test/java/org/springframework/amqp/rabbit/core/RabbitTemplatePublisherCallbacksIntegrationTests.java b/spring-rabbit/src/test/java/org/springframework/amqp/rabbit/core/RabbitTemplatePublisherCallbacksIntegrationTests.java index a951e5d4..ada18017 100644 --- a/spring-rabbit/src/test/java/org/springframework/amqp/rabbit/core/RabbitTemplatePublisherCallbacksIntegrationTests.java +++ b/spring-rabbit/src/test/java/org/springframework/amqp/rabbit/core/RabbitTemplatePublisherCallbacksIntegrationTests.java @@ -23,6 +23,7 @@ import static org.mockito.Mockito.when; import java.util.ArrayList; import java.util.Collection; +import java.util.ConcurrentModificationException; import java.util.HashSet; import java.util.Iterator; import java.util.List; @@ -67,8 +68,6 @@ public class RabbitTemplatePublisherCallbacksIntegrationTests { private CachingConnectionFactory connectionFactoryWithReturnsEnabled; - private RabbitTemplate template; - private RabbitTemplate templateWithConfirmsEnabled; private RabbitTemplate templateWithReturnsEnabled; @@ -78,7 +77,6 @@ public class RabbitTemplatePublisherCallbacksIntegrationTests { connectionFactory = new CachingConnectionFactory(); connectionFactory.setChannelCacheSize(1); connectionFactory.setPort(BrokerTestUtils.getPort()); - template = new RabbitTemplate(connectionFactory); connectionFactoryWithConfirmsEnabled = new CachingConnectionFactory(); // When using publisher confirms, the cache size needs to be large enough // otherwise channels can be closed before confirms are received. @@ -429,4 +427,77 @@ public class RabbitTemplatePublisherCallbacksIntegrationTests { assertTrue(confirms.contains("ghi2")); assertEquals(3, confirms.size()); } + + /** + * AMQP-262 + * Sets up a situation where we are processing 'multi' acks at the same + * time as adding a new pending ack to the map. Test verifies we don't + * get a {@link ConcurrentModificationException}. + */ + @Test + public void testConcurrentConfirms() throws Exception { + ConnectionFactory mockConnectionFactory = mock(ConnectionFactory.class); + Connection mockConnection = mock(Connection.class); + Channel mockChannel = mock(Channel.class); + when(mockChannel.getNextPublishSeqNo()).thenReturn(1L, 2L, 3L, 4L); + + when(mockConnectionFactory.newConnection((ExecutorService) null)).thenReturn(mockConnection); + when(mockConnection.isOpen()).thenReturn(true); + final PublisherCallbackChannelImpl channel = new PublisherCallbackChannelImpl(mockChannel); + when(mockConnection.createChannel()).thenReturn(channel); + + final RabbitTemplate template = new RabbitTemplate(new SingleConnectionFactory(mockConnectionFactory)); + + final CountDownLatch first2SentOnThread1Latch = new CountDownLatch(1); + final CountDownLatch delayAckProcessingLatch = new CountDownLatch(1); + final CountDownLatch startedProcessingMultiAcksLatch = new CountDownLatch(1); + final CountDownLatch waitForAll3AcksLatch = new CountDownLatch(3); + final CountDownLatch allSentLatch = new CountDownLatch(1); + final AtomicInteger acks = new AtomicInteger(); + template.setConfirmCallback(new ConfirmCallback() { + + public void confirm(CorrelationData correlationData, boolean ack) { + try { + startedProcessingMultiAcksLatch.countDown(); + // delay processing here; ensures thread 2 put would be concurrent + delayAckProcessingLatch.await(2, TimeUnit.SECONDS); + // only delay first time through + delayAckProcessingLatch.countDown(); + waitForAll3AcksLatch.countDown(); + acks.incrementAndGet(); + } + catch (InterruptedException e) { + e.printStackTrace(); + } + } + }); + Executors.newSingleThreadExecutor().execute(new Runnable() { + @Override + public void run() { + template.convertAndSend(ROUTE, (Object) "message", new CorrelationData("abc")); + template.convertAndSend(ROUTE, (Object) "message", new CorrelationData("def")); + first2SentOnThread1Latch.countDown(); + } + }); + Executors.newSingleThreadExecutor().execute(new Runnable() { + @Override + public void run() { + try { + startedProcessingMultiAcksLatch.await(); + template.convertAndSend(ROUTE, (Object) "message", new CorrelationData("ghi")); + allSentLatch.countDown(); + } + catch (InterruptedException e) { + e.printStackTrace(); + } + } + }); + assertTrue(first2SentOnThread1Latch.await(10, TimeUnit.SECONDS)); + // there should be no concurrent execution exception here + channel.handleAck(2, true); + assertTrue(allSentLatch.await(10, TimeUnit.SECONDS)); + channel.handleAck(3, false); + assertTrue(waitForAll3AcksLatch.await(10, TimeUnit.SECONDS)); + assertEquals(3, acks.get()); + } }