From 613c2e09e70ddbc7fa5f29ce4eb85f0d1f31363f Mon Sep 17 00:00:00 2001 From: Costin Leau Date: Fri, 25 Jan 2013 00:11:47 +0200 Subject: [PATCH] remove listeners programatically in RMLContainer DATAREDIS-107 listeners can be removed not just added programatically fix bug causing double message dispatch for pattern subscriptions introduce more pubsub tests --- .../RedisMessageListenerContainer.java | 235 ++++++++++++++---- .../listener/PubSubResubscribeTests.java | 153 ++++++++++++ .../data/redis/listener/PubSubTests.java | 67 ++++- src/test/resources/log4j.properties | 2 +- 4 files changed, 399 insertions(+), 58 deletions(-) create mode 100644 src/test/java/org/springframework/data/redis/listener/PubSubResubscribeTests.java diff --git a/src/main/java/org/springframework/data/redis/listener/RedisMessageListenerContainer.java b/src/main/java/org/springframework/data/redis/listener/RedisMessageListenerContainer.java index 071167dfb..6eace700d 100644 --- a/src/main/java/org/springframework/data/redis/listener/RedisMessageListenerContainer.java +++ b/src/main/java/org/springframework/data/redis/listener/RedisMessageListenerContainer.java @@ -20,6 +20,7 @@ import java.util.Collection; import java.util.Collections; import java.util.List; import java.util.Map; +import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.CopyOnWriteArraySet; import java.util.concurrent.Executor; @@ -42,6 +43,7 @@ import org.springframework.data.redis.connection.util.ByteArrayWrapper; import org.springframework.data.redis.serializer.RedisSerializer; import org.springframework.data.redis.serializer.StringRedisSerializer; import org.springframework.scheduling.SchedulingAwareRunnable; +import org.springframework.util.Assert; import org.springframework.util.ClassUtils; import org.springframework.util.CollectionUtils; import org.springframework.util.ErrorHandler; @@ -55,7 +57,11 @@ import org.springframework.util.ErrorHandler; * the message dispatch being done through the task executor. * *

- * Note the container uses the connection in a lazy fashion (the connection is used only if at least one listener is configured). + * Note the container uses the connection in a lazy fashion (the connection is used only if at least one listener is configured). + * + *

+ * Adding and removing listeners at the same time has undefined results. It is strongly recommended to synchronize/order these + * methods accordingly. * * @author Costin Leau */ @@ -65,7 +71,6 @@ public class RedisMessageListenerContainer implements InitializingBean, Disposab protected final Log logger = LogFactory.getLog(getClass()); - /** * Default thread name prefix: "RedisListeningContainer-". */ @@ -104,13 +109,15 @@ public class RedisMessageListenerContainer implements InitializingBean, Disposab private final Map> patternMapping = new ConcurrentHashMap>(); // lookup map between channels and listeners private final Map> channelMapping = new ConcurrentHashMap>(); + // lookup map between listeners and channels + private final Map> listenerTopics = new ConcurrentHashMap>(); private final SubscriptionTask subscriptionTask = new SubscriptionTask(); private volatile RedisSerializer serializer = new StringRedisSerializer(); - + public void afterPropertiesSet() { if (taskExecutor == null) { manageExecutor = true; @@ -135,7 +142,7 @@ public class RedisMessageListenerContainer implements InitializingBean, Disposab return new SimpleAsyncTaskExecutor(threadNamePrefix); } - + public void destroy() throws Exception { initialized = false; @@ -152,29 +159,29 @@ public class RedisMessageListenerContainer implements InitializingBean, Disposab } } - + public boolean isAutoStartup() { return true; } - + public void stop(Runnable callback) { stop(); callback.run(); } - + public int getPhase() { // start the latest return Integer.MAX_VALUE; } - + public boolean isRunning() { return running; } - + public void start() { if (!running) { running = true; @@ -196,13 +203,14 @@ public class RedisMessageListenerContainer implements InitializingBean, Disposab } } - + public void stop() { if (isRunning()) { running = false; synchronized (monitor) { boolean shouldWait = listening; subscriptionTask.cancel(); + listening = false; if (shouldWait) { try { monitor.wait(initWait); @@ -300,7 +308,7 @@ public class RedisMessageListenerContainer implements InitializingBean, Disposab this.connectionFactory = connectionFactory; } - + public void setBeanName(String name) { this.beanName = name; } @@ -388,14 +396,56 @@ public class RedisMessageListenerContainer implements InitializingBean, Disposab addMessageListener(listener, Collections.singleton(topic)); } + /** + * Removes a message listener from the given topics. If the container is running, + * the listener stops receiving (matching) messages as soon as possible. + *

+ * Note that this method obeys the Redis (p)unsubscribe semantics - meaning an empty/null collection will remove + * listener from all channels. + * Similarly a null listener will unsubscribe all listeners from the given topic. + * + * @param listener message listener + * @param topics message listener topics + */ + public void removeMessageListener(MessageListener listener, Collection topics) { + removeListener(listener, topics); + } + + /** + * Removes a message listener from the from the given topic. If the container is running, + * the listener stops receiving (matching) messages as soon as possible. + * + *

+ * Note that this method obeys the Redis (p)unsubscribe semantics - meaning an empty/null collection will remove + * listener from all channels. Similarly a null listener will unsubscribe all listeners from the given topic. + * + * @param listener message listener + * @param topic message topic + */ + public void removeMessageListener(MessageListener listener, Topic topic) { + removeMessageListener(listener, Collections.singleton(topic)); + } + + /** + * Removes the given message listener completely (from all topics). If the container is running, + * the listener stops receiving (matching) messages as soon as possible. + * Similarly a null listener will unsubscribe all listeners from the given topic. + * + * @param listener message listener + */ + public void removeMessageListener(MessageListener listener) { + removeMessageListener(listener, Collections. emptySet()); + } + private void initMapping(Map> listeners) { // stop the listener if currently running if (isRunning()) { - stop(); + subscriptionTask.cancel(); } patternMapping.clear(); channelMapping.clear(); + listenerTopics.clear(); if (!CollectionUtils.isEmpty(listeners)) { for (Map.Entry> entry : listeners.entrySet()) { @@ -440,11 +490,22 @@ public class RedisMessageListenerContainer implements InitializingBean, Disposab } private void addListener(MessageListener listener, Collection topics) { + Assert.notNull(listener, "a valid listener is required"); + Assert.notEmpty(topics, "at least one topic is required"); + List channels = new ArrayList(topics.size()); List patterns = new ArrayList(topics.size()); boolean trace = logger.isTraceEnabled(); + // add listener mapping + Set set = listenerTopics.get(listener); + if (set == null) { + set = new CopyOnWriteArraySet(); + listenerTopics.put(listener, set); + } + set.addAll(topics); + for (Topic topic : topics) { ByteArrayWrapper holder = new ByteArrayWrapper(serializer.serialize(topic.getTopic())); @@ -487,6 +548,85 @@ public class RedisMessageListenerContainer implements InitializingBean, Disposab } } + private void removeListener(MessageListener listener, Collection topics) { + boolean trace = logger.isTraceEnabled(); + + // check stop listening case + if (listener == null && CollectionUtils.isEmpty(topics)) { + subscriptionTask.cancel(); + logger.debug("Stopped listening for Redis messages"); + return; + } + + List channelsToRemove = new ArrayList(); + List patternsToRemove = new ArrayList(); + + // check unsubscribe all topics case + if (CollectionUtils.isEmpty(topics)) { + Set set = listenerTopics.remove(listener); + // listener not found, bail out + if (set == null) { + return; + } + topics = set; + } + + for (Topic topic : topics) { + ByteArrayWrapper holder = new ByteArrayWrapper(serializer.serialize(topic.getTopic())); + + if (topic instanceof ChannelTopic) { + remove(listener, topic, holder, channelMapping, channelsToRemove); + + if (trace) { + String msg = (listener != null ? "listener '" + listener + "'" : "all listeners"); + logger.trace("Removing " + msg + " from channel '" + topic.getTopic() + "'"); + } + } + + else if (topic instanceof PatternTopic) { + remove(listener, topic, holder, patternMapping, patternsToRemove); + + if (trace) { + String msg = (listener != null ? "listener '" + listener + "'" : "all listeners"); + logger.trace("Removing " + msg + " from pattern '" + topic.getTopic() + "'"); + } + } + } + + // check the current listening state + if (listening) { + subscriptionTask.unsubscribeChannel(channelsToRemove.toArray(new byte[channelsToRemove.size()][])); + subscriptionTask.unsubscribePattern(patternsToRemove.toArray(new byte[patternsToRemove.size()][])); + } + } + + private void remove(MessageListener listener, Topic topic, ByteArrayWrapper holder, Map> mapping, List topicToRemove) { + + Collection listeners = mapping.get(holder); + if (listeners != null) { + if (listener != null) { + listeners.remove(listener); + } + // remove all listeners for the given topic + else { + for (MessageListener messageListener : listeners) { + Set topics = listenerTopics.get(messageListener); + if (topics != null) { + topics.remove(topic); + } + if (topics.isEmpty()) { + listenerTopics.remove(messageListener); + } + } + } + + if (listener == null || listeners.isEmpty()) { + mapping.remove(holder); + topicToRemove.add(holder.getArray()); + } + } + } + /** * Runnable used for Redis subscription. Implemented as a dedicated class to provide as many hints @@ -509,12 +649,12 @@ public class RedisMessageListenerContainer implements InitializingBean, Disposab private long WAIT = 500; private long ROUNDS = 3; - + public boolean isLongLived() { return false; } - + public void run() { // wait for subscription to be initialized boolean done = false; @@ -542,12 +682,12 @@ public class RedisMessageListenerContainer implements InitializingBean, Disposab private volatile RedisConnection connection; private final Object localMonitor = new Object(); - + public boolean isLongLived() { return true; } - + public void run() { connection = connectionFactory.getConnection(); try { @@ -556,7 +696,6 @@ public class RedisMessageListenerContainer implements InitializingBean, Disposab } // NB: each Xsubscribe call blocks - synchronized (monitor) { monitor.notify(); } @@ -613,6 +752,9 @@ public class RedisMessageListenerContainer implements InitializingBean, Disposab } void cancel() { + if (!listening) { + return; + } if (connection != null) { synchronized (localMonitor) { if (connection != null) { @@ -694,48 +836,35 @@ public class RedisMessageListenerContainer implements InitializingBean, Disposab */ private class DispatchMessageListener implements MessageListener { - public void onMessage(Message message, byte[] pattern) { - // do channel matching first - byte[] channel = message.getChannel(); + Collection listeners = null; - Collection ch = channelMapping.get(new ByteArrayWrapper(channel)); - Collection pt = null; - - // followed by pattern matching + // if it's a pattern, disregard channel if (pattern != null && pattern.length > 0) { - pt = patternMapping.get(new ByteArrayWrapper(pattern)); + listeners = patternMapping.get(new ByteArrayWrapper(pattern)); + } + else { + pattern = null; + // do channel matching first + listeners = channelMapping.get(new ByteArrayWrapper(message.getChannel())); } - if (!CollectionUtils.isEmpty(ch)) { - dispatchChannels(ch, message); - } - - if (!CollectionUtils.isEmpty(pt)) { - dispatchPatterns(pt, message, pattern); - } - } - - private void dispatchChannels(Collection ch, final Message message) { - for (final MessageListener messageListener : ch) { - taskExecutor.execute(new Runnable() { - - public void run() { - processMessage(messageListener, message, message.getChannel()); - } - }); - } - } - - private void dispatchPatterns(Collection pt, final Message message, final byte[] pattern) { - for (final MessageListener messageListener : pt) { - taskExecutor.execute(new Runnable() { - - public void run() { - processMessage(messageListener, message, pattern.clone()); - } - }); + if (!CollectionUtils.isEmpty(listeners)) { + dispatchMessage(listeners, message, pattern); } } } + + private void dispatchMessage(Collection listeners, final Message message, final byte[] pattern) { + final byte[] source = (pattern != null ? pattern.clone() : message.getChannel()); + + for (final MessageListener messageListener : listeners) { + taskExecutor.execute(new Runnable() { + public void run() { + processMessage(messageListener, message, source); + } + }); + } + } + } \ No newline at end of file diff --git a/src/test/java/org/springframework/data/redis/listener/PubSubResubscribeTests.java b/src/test/java/org/springframework/data/redis/listener/PubSubResubscribeTests.java new file mode 100644 index 000000000..64bd8059f --- /dev/null +++ b/src/test/java/org/springframework/data/redis/listener/PubSubResubscribeTests.java @@ -0,0 +1,153 @@ +/* + * Copyright 2011-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.data.redis.listener; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; + +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.BlockingDeque; +import java.util.concurrent.LinkedBlockingDeque; +import java.util.concurrent.TimeUnit; + +import org.junit.After; +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.Test; +import org.springframework.core.task.SimpleAsyncTaskExecutor; +import org.springframework.core.task.SyncTaskExecutor; +import org.springframework.data.redis.ConnectionFactoryTracker; +import org.springframework.data.redis.SettingsUtils; +import org.springframework.data.redis.connection.RedisConnectionFactory; +import org.springframework.data.redis.connection.jedis.JedisConnectionFactory; +import org.springframework.data.redis.core.RedisTemplate; +import org.springframework.data.redis.core.StringRedisTemplate; +import org.springframework.data.redis.listener.adapter.MessageListenerAdapter; + +/** + * @author Costin Leau + */ +public class PubSubResubscribeTests { + + private static final String CHANNEL = "pubsub::test"; + private final Long ZERO = Long.valueOf(0); + private final Long ONE = Long.valueOf(1); + private final Long TWO = Long.valueOf(2); + + protected RedisMessageListenerContainer container; + protected RedisConnectionFactory factory; + protected RedisTemplate template; + + private final BlockingDeque bag = new LinkedBlockingDeque(99); + + private final Object handler = new Object() { + public void handleMessage(String message) { + System.out.println(message); + bag.add(message); + } + }; + + private final MessageListenerAdapter adapter = new MessageListenerAdapter(handler); + + @AfterClass + public static void cleanUp() { + ConnectionFactoryTracker.cleanUp(); + } + + @Before + public void setUp() throws Exception { + JedisConnectionFactory jedisConnFactory = new JedisConnectionFactory(); + jedisConnFactory.setUsePool(false); + jedisConnFactory.setPort(SettingsUtils.getPort()); + jedisConnFactory.setHostName(SettingsUtils.getHost()); + jedisConnFactory.setDatabase(2); + jedisConnFactory.afterPropertiesSet(); + + factory = jedisConnFactory; + + template = new StringRedisTemplate(jedisConnFactory); + ConnectionFactoryTracker.add(template.getConnectionFactory()); + + bag.clear(); + + adapter.setSerializer(template.getValueSerializer()); + adapter.afterPropertiesSet(); + + container = new RedisMessageListenerContainer(); + container.setConnectionFactory(template.getConnectionFactory()); + container.setBeanName("container"); + container.addMessageListener(adapter, new ChannelTopic(CHANNEL)); + container.setTaskExecutor(new SyncTaskExecutor()); + container.setSubscriptionExecutor(new SimpleAsyncTaskExecutor()); + container.afterPropertiesSet(); + container.start(); + + Thread.sleep(1000); + } + + @After + public void tearDown() throws Exception { + container.destroy(); + } + + + @Test + public void testContainerPatternResubscribe() throws Exception { + String payload1 = "do"; + String payload2 = "re mi"; + + final String PATTERN = "p*"; + final String ANOTHER_CHANNEL = "pubsub::test::extra"; + + MessageListenerAdapter anotherListener = new MessageListenerAdapter(handler); + anotherListener.setSerializer(template.getValueSerializer()); + anotherListener.afterPropertiesSet(); + + // remove adapter from all channels + container.removeMessageListener(adapter); + container.addMessageListener(anotherListener, new PatternTopic(PATTERN)); + + // test no messages are sent just to patterns + assertEquals(ONE, template.convertAndSend(CHANNEL, payload1)); + assertEquals(ONE, template.convertAndSend(ANOTHER_CHANNEL, payload2)); + + List msgs = new ArrayList(); + msgs.add(bag.poll(1, TimeUnit.SECONDS)); + msgs.add(bag.poll(1, TimeUnit.SECONDS)); + + assertEquals(2, msgs.size()); + // bind original listener on another channel + container.addMessageListener(adapter, new ChannelTopic(ANOTHER_CHANNEL)); + + assertEquals(ONE, template.convertAndSend(CHANNEL, payload1)); + assertEquals(TWO, template.convertAndSend(ANOTHER_CHANNEL, payload2)); + + msgs.add(bag.poll(1, TimeUnit.SECONDS)); + msgs.add(bag.poll(1, TimeUnit.SECONDS)); + msgs.add(bag.poll(1, TimeUnit.SECONDS)); + // this message will not arrive on time + assertNull(bag.poll(1, TimeUnit.SECONDS)); + + // same message received first per channel subscription, second based on the pattern + assertEquals(5, msgs.size()); + + assertTrue(msgs.contains(payload1)); + assertTrue(msgs.contains(payload2)); + } +} diff --git a/src/test/java/org/springframework/data/redis/listener/PubSubTests.java b/src/test/java/org/springframework/data/redis/listener/PubSubTests.java index eb3db83d8..e6137254c 100644 --- a/src/test/java/org/springframework/data/redis/listener/PubSubTests.java +++ b/src/test/java/org/springframework/data/redis/listener/PubSubTests.java @@ -16,6 +16,7 @@ package org.springframework.data.redis.listener; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertTrue; import java.util.Arrays; @@ -33,6 +34,8 @@ import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.junit.runners.Parameterized.Parameters; +import org.springframework.core.task.SimpleAsyncTaskExecutor; +import org.springframework.core.task.SyncTaskExecutor; import org.springframework.data.redis.ConnectionFactoryTracker; import org.springframework.data.redis.core.RedisTemplate; import org.springframework.data.redis.listener.adapter.MessageListenerAdapter; @@ -47,6 +50,9 @@ import org.springframework.data.redis.support.collections.ObjectFactory; public class PubSubTests { private static final String CHANNEL = "pubsub::test"; + private final Long ZERO = Long.valueOf(0); + private final Long ONE = Long.valueOf(1); + private final Long TWO = Long.valueOf(2); protected RedisMessageListenerContainer container; protected ObjectFactory factory; @@ -64,6 +70,8 @@ public class PubSubTests { @Before public void setUp() throws Exception { + bag.clear(); + adapter.setSerializer(template.getValueSerializer()); adapter.afterPropertiesSet(); @@ -71,6 +79,8 @@ public class PubSubTests { container.setConnectionFactory(template.getConnectionFactory()); container.setBeanName("container"); container.addMessageListener(adapter, Arrays.asList(new ChannelTopic(CHANNEL))); + container.setTaskExecutor(new SyncTaskExecutor()); + container.setSubscriptionExecutor(new SimpleAsyncTaskExecutor()); container.afterPropertiesSet(); container.start(); @@ -111,15 +121,13 @@ public class PubSubTests { String payload1 = "do"; String payload2 = "re mi"; - template.convertAndSend(CHANNEL, payload1); - template.convertAndSend(CHANNEL, payload2); + assertEquals(ONE, template.convertAndSend(CHANNEL, payload1)); + assertEquals(ONE, template.convertAndSend(CHANNEL, payload2)); Set set = new LinkedHashSet(); set.add(bag.poll(1, TimeUnit.SECONDS)); set.add(bag.poll(1, TimeUnit.SECONDS)); - System.out.println(set); - assertTrue(set.contains(payload1)); assertTrue(set.contains(payload2)); } @@ -134,4 +142,55 @@ public class PubSubTests { Thread.sleep(1000); assertEquals(COUNT, bag.size()); } + + @Test + public void testContainerUnsubscribe() throws Exception { + String payload1 = "do"; + String payload2 = "re mi"; + + container.removeMessageListener(adapter, new ChannelTopic(CHANNEL)); + assertEquals(ZERO, template.convertAndSend(CHANNEL, payload1)); + assertEquals(ZERO, template.convertAndSend(CHANNEL, payload2)); + + Set set = new LinkedHashSet(); + set.add(bag.poll(1, TimeUnit.SECONDS)); + set.add(bag.poll(1, TimeUnit.SECONDS)); + + assertFalse(set.contains(payload1)); + assertFalse(set.contains(payload2)); + } + + @Test + public void testContainerChannelResubscribe() throws Exception { + String payload1 = "do"; + String payload2 = "re mi"; + + String anotherPayload1 = "od"; + String anotherPayload2 = "mi er"; + + String ANOTHER_CHANNEL = "pubsub::test::extra"; + + // bind listener on another channel + container.addMessageListener(adapter, new ChannelTopic(ANOTHER_CHANNEL)); + container.removeMessageListener(null, new ChannelTopic(CHANNEL)); + + assertEquals(ZERO, template.convertAndSend(CHANNEL, payload1)); + assertEquals(ZERO, template.convertAndSend(CHANNEL, payload2)); + + assertEquals(ONE, template.convertAndSend(ANOTHER_CHANNEL, anotherPayload1)); + assertEquals(ONE, template.convertAndSend(ANOTHER_CHANNEL, anotherPayload2)); + + Set set = new LinkedHashSet(); + set.add(bag.poll(1, TimeUnit.SECONDS)); + set.add(bag.poll(1, TimeUnit.SECONDS)); + + System.out.println(set); + + assertFalse(set.contains(payload1)); + assertFalse(set.contains(payload2)); + + assertTrue(set.contains(anotherPayload1)); + assertTrue(set.contains(anotherPayload2)); + } + } \ No newline at end of file diff --git a/src/test/resources/log4j.properties b/src/test/resources/log4j.properties index 945449482..8810f96a4 100644 --- a/src/test/resources/log4j.properties +++ b/src/test/resources/log4j.properties @@ -4,7 +4,7 @@ log4j.appender.stdout=org.apache.log4j.ConsoleAppender log4j.appender.stdout.layout=org.apache.log4j.PatternLayout log4j.appender.stdout.layout.ConversionPattern=%d %p [%c] - <%m>%n -log4j.category.org.springframework.data.keyvalue.redis.listener=TRACE +log4j.category.org.springframework.data.redis.listener=TRACE # for debugging datasource initialization # log4j.category.test.jdbc=DEBUG