diff --git a/src/main/java/org/springframework/data/redis/listener/RedisMessageListenerContainer.java b/src/main/java/org/springframework/data/redis/listener/RedisMessageListenerContainer.java
index 071167dfb..6eace700d 100644
--- a/src/main/java/org/springframework/data/redis/listener/RedisMessageListenerContainer.java
+++ b/src/main/java/org/springframework/data/redis/listener/RedisMessageListenerContainer.java
@@ -20,6 +20,7 @@ import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Map;
+import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CopyOnWriteArraySet;
import java.util.concurrent.Executor;
@@ -42,6 +43,7 @@ import org.springframework.data.redis.connection.util.ByteArrayWrapper;
import org.springframework.data.redis.serializer.RedisSerializer;
import org.springframework.data.redis.serializer.StringRedisSerializer;
import org.springframework.scheduling.SchedulingAwareRunnable;
+import org.springframework.util.Assert;
import org.springframework.util.ClassUtils;
import org.springframework.util.CollectionUtils;
import org.springframework.util.ErrorHandler;
@@ -55,7 +57,11 @@ import org.springframework.util.ErrorHandler;
* the message dispatch being done through the task executor.
*
*
- * Note the container uses the connection in a lazy fashion (the connection is used only if at least one listener is configured).
+ * Note the container uses the connection in a lazy fashion (the connection is used only if at least one listener is configured).
+ *
+ *
+ * Adding and removing listeners at the same time has undefined results. It is strongly recommended to synchronize/order these
+ * methods accordingly.
*
* @author Costin Leau
*/
@@ -65,7 +71,6 @@ public class RedisMessageListenerContainer implements InitializingBean, Disposab
protected final Log logger = LogFactory.getLog(getClass());
-
/**
* Default thread name prefix: "RedisListeningContainer-".
*/
@@ -104,13 +109,15 @@ public class RedisMessageListenerContainer implements InitializingBean, Disposab
private final Map> patternMapping = new ConcurrentHashMap>();
// lookup map between channels and listeners
private final Map> channelMapping = new ConcurrentHashMap>();
+ // lookup map between listeners and channels
+ private final Map> listenerTopics = new ConcurrentHashMap>();
private final SubscriptionTask subscriptionTask = new SubscriptionTask();
private volatile RedisSerializer serializer = new StringRedisSerializer();
-
+
public void afterPropertiesSet() {
if (taskExecutor == null) {
manageExecutor = true;
@@ -135,7 +142,7 @@ public class RedisMessageListenerContainer implements InitializingBean, Disposab
return new SimpleAsyncTaskExecutor(threadNamePrefix);
}
-
+
public void destroy() throws Exception {
initialized = false;
@@ -152,29 +159,29 @@ public class RedisMessageListenerContainer implements InitializingBean, Disposab
}
}
-
+
public boolean isAutoStartup() {
return true;
}
-
+
public void stop(Runnable callback) {
stop();
callback.run();
}
-
+
public int getPhase() {
// start the latest
return Integer.MAX_VALUE;
}
-
+
public boolean isRunning() {
return running;
}
-
+
public void start() {
if (!running) {
running = true;
@@ -196,13 +203,14 @@ public class RedisMessageListenerContainer implements InitializingBean, Disposab
}
}
-
+
public void stop() {
if (isRunning()) {
running = false;
synchronized (monitor) {
boolean shouldWait = listening;
subscriptionTask.cancel();
+ listening = false;
if (shouldWait) {
try {
monitor.wait(initWait);
@@ -300,7 +308,7 @@ public class RedisMessageListenerContainer implements InitializingBean, Disposab
this.connectionFactory = connectionFactory;
}
-
+
public void setBeanName(String name) {
this.beanName = name;
}
@@ -388,14 +396,56 @@ public class RedisMessageListenerContainer implements InitializingBean, Disposab
addMessageListener(listener, Collections.singleton(topic));
}
+ /**
+ * Removes a message listener from the given topics. If the container is running,
+ * the listener stops receiving (matching) messages as soon as possible.
+ *
+ * Note that this method obeys the Redis (p)unsubscribe semantics - meaning an empty/null collection will remove
+ * listener from all channels.
+ * Similarly a null listener will unsubscribe all listeners from the given topic.
+ *
+ * @param listener message listener
+ * @param topics message listener topics
+ */
+ public void removeMessageListener(MessageListener listener, Collection extends Topic> topics) {
+ removeListener(listener, topics);
+ }
+
+ /**
+ * Removes a message listener from the from the given topic. If the container is running,
+ * the listener stops receiving (matching) messages as soon as possible.
+ *
+ *
+ * Note that this method obeys the Redis (p)unsubscribe semantics - meaning an empty/null collection will remove
+ * listener from all channels. Similarly a null listener will unsubscribe all listeners from the given topic.
+ *
+ * @param listener message listener
+ * @param topic message topic
+ */
+ public void removeMessageListener(MessageListener listener, Topic topic) {
+ removeMessageListener(listener, Collections.singleton(topic));
+ }
+
+ /**
+ * Removes the given message listener completely (from all topics). If the container is running,
+ * the listener stops receiving (matching) messages as soon as possible.
+ * Similarly a null listener will unsubscribe all listeners from the given topic.
+ *
+ * @param listener message listener
+ */
+ public void removeMessageListener(MessageListener listener) {
+ removeMessageListener(listener, Collections. emptySet());
+ }
+
private void initMapping(Map extends MessageListener, Collection extends Topic>> listeners) {
// stop the listener if currently running
if (isRunning()) {
- stop();
+ subscriptionTask.cancel();
}
patternMapping.clear();
channelMapping.clear();
+ listenerTopics.clear();
if (!CollectionUtils.isEmpty(listeners)) {
for (Map.Entry extends MessageListener, Collection extends Topic>> entry : listeners.entrySet()) {
@@ -440,11 +490,22 @@ public class RedisMessageListenerContainer implements InitializingBean, Disposab
}
private void addListener(MessageListener listener, Collection extends Topic> topics) {
+ Assert.notNull(listener, "a valid listener is required");
+ Assert.notEmpty(topics, "at least one topic is required");
+
List channels = new ArrayList(topics.size());
List patterns = new ArrayList(topics.size());
boolean trace = logger.isTraceEnabled();
+ // add listener mapping
+ Set set = listenerTopics.get(listener);
+ if (set == null) {
+ set = new CopyOnWriteArraySet();
+ listenerTopics.put(listener, set);
+ }
+ set.addAll(topics);
+
for (Topic topic : topics) {
ByteArrayWrapper holder = new ByteArrayWrapper(serializer.serialize(topic.getTopic()));
@@ -487,6 +548,85 @@ public class RedisMessageListenerContainer implements InitializingBean, Disposab
}
}
+ private void removeListener(MessageListener listener, Collection extends Topic> topics) {
+ boolean trace = logger.isTraceEnabled();
+
+ // check stop listening case
+ if (listener == null && CollectionUtils.isEmpty(topics)) {
+ subscriptionTask.cancel();
+ logger.debug("Stopped listening for Redis messages");
+ return;
+ }
+
+ List channelsToRemove = new ArrayList();
+ List patternsToRemove = new ArrayList();
+
+ // check unsubscribe all topics case
+ if (CollectionUtils.isEmpty(topics)) {
+ Set set = listenerTopics.remove(listener);
+ // listener not found, bail out
+ if (set == null) {
+ return;
+ }
+ topics = set;
+ }
+
+ for (Topic topic : topics) {
+ ByteArrayWrapper holder = new ByteArrayWrapper(serializer.serialize(topic.getTopic()));
+
+ if (topic instanceof ChannelTopic) {
+ remove(listener, topic, holder, channelMapping, channelsToRemove);
+
+ if (trace) {
+ String msg = (listener != null ? "listener '" + listener + "'" : "all listeners");
+ logger.trace("Removing " + msg + " from channel '" + topic.getTopic() + "'");
+ }
+ }
+
+ else if (topic instanceof PatternTopic) {
+ remove(listener, topic, holder, patternMapping, patternsToRemove);
+
+ if (trace) {
+ String msg = (listener != null ? "listener '" + listener + "'" : "all listeners");
+ logger.trace("Removing " + msg + " from pattern '" + topic.getTopic() + "'");
+ }
+ }
+ }
+
+ // check the current listening state
+ if (listening) {
+ subscriptionTask.unsubscribeChannel(channelsToRemove.toArray(new byte[channelsToRemove.size()][]));
+ subscriptionTask.unsubscribePattern(patternsToRemove.toArray(new byte[patternsToRemove.size()][]));
+ }
+ }
+
+ private void remove(MessageListener listener, Topic topic, ByteArrayWrapper holder, Map> mapping, List topicToRemove) {
+
+ Collection listeners = mapping.get(holder);
+ if (listeners != null) {
+ if (listener != null) {
+ listeners.remove(listener);
+ }
+ // remove all listeners for the given topic
+ else {
+ for (MessageListener messageListener : listeners) {
+ Set topics = listenerTopics.get(messageListener);
+ if (topics != null) {
+ topics.remove(topic);
+ }
+ if (topics.isEmpty()) {
+ listenerTopics.remove(messageListener);
+ }
+ }
+ }
+
+ if (listener == null || listeners.isEmpty()) {
+ mapping.remove(holder);
+ topicToRemove.add(holder.getArray());
+ }
+ }
+ }
+
/**
* Runnable used for Redis subscription. Implemented as a dedicated class to provide as many hints
@@ -509,12 +649,12 @@ public class RedisMessageListenerContainer implements InitializingBean, Disposab
private long WAIT = 500;
private long ROUNDS = 3;
-
+
public boolean isLongLived() {
return false;
}
-
+
public void run() {
// wait for subscription to be initialized
boolean done = false;
@@ -542,12 +682,12 @@ public class RedisMessageListenerContainer implements InitializingBean, Disposab
private volatile RedisConnection connection;
private final Object localMonitor = new Object();
-
+
public boolean isLongLived() {
return true;
}
-
+
public void run() {
connection = connectionFactory.getConnection();
try {
@@ -556,7 +696,6 @@ public class RedisMessageListenerContainer implements InitializingBean, Disposab
}
// NB: each Xsubscribe call blocks
-
synchronized (monitor) {
monitor.notify();
}
@@ -613,6 +752,9 @@ public class RedisMessageListenerContainer implements InitializingBean, Disposab
}
void cancel() {
+ if (!listening) {
+ return;
+ }
if (connection != null) {
synchronized (localMonitor) {
if (connection != null) {
@@ -694,48 +836,35 @@ public class RedisMessageListenerContainer implements InitializingBean, Disposab
*/
private class DispatchMessageListener implements MessageListener {
-
public void onMessage(Message message, byte[] pattern) {
- // do channel matching first
- byte[] channel = message.getChannel();
+ Collection listeners = null;
- Collection ch = channelMapping.get(new ByteArrayWrapper(channel));
- Collection pt = null;
-
- // followed by pattern matching
+ // if it's a pattern, disregard channel
if (pattern != null && pattern.length > 0) {
- pt = patternMapping.get(new ByteArrayWrapper(pattern));
+ listeners = patternMapping.get(new ByteArrayWrapper(pattern));
+ }
+ else {
+ pattern = null;
+ // do channel matching first
+ listeners = channelMapping.get(new ByteArrayWrapper(message.getChannel()));
}
- if (!CollectionUtils.isEmpty(ch)) {
- dispatchChannels(ch, message);
- }
-
- if (!CollectionUtils.isEmpty(pt)) {
- dispatchPatterns(pt, message, pattern);
- }
- }
-
- private void dispatchChannels(Collection ch, final Message message) {
- for (final MessageListener messageListener : ch) {
- taskExecutor.execute(new Runnable() {
-
- public void run() {
- processMessage(messageListener, message, message.getChannel());
- }
- });
- }
- }
-
- private void dispatchPatterns(Collection pt, final Message message, final byte[] pattern) {
- for (final MessageListener messageListener : pt) {
- taskExecutor.execute(new Runnable() {
-
- public void run() {
- processMessage(messageListener, message, pattern.clone());
- }
- });
+ if (!CollectionUtils.isEmpty(listeners)) {
+ dispatchMessage(listeners, message, pattern);
}
}
}
+
+ private void dispatchMessage(Collection listeners, final Message message, final byte[] pattern) {
+ final byte[] source = (pattern != null ? pattern.clone() : message.getChannel());
+
+ for (final MessageListener messageListener : listeners) {
+ taskExecutor.execute(new Runnable() {
+ public void run() {
+ processMessage(messageListener, message, source);
+ }
+ });
+ }
+ }
+
}
\ No newline at end of file
diff --git a/src/test/java/org/springframework/data/redis/listener/PubSubResubscribeTests.java b/src/test/java/org/springframework/data/redis/listener/PubSubResubscribeTests.java
new file mode 100644
index 000000000..64bd8059f
--- /dev/null
+++ b/src/test/java/org/springframework/data/redis/listener/PubSubResubscribeTests.java
@@ -0,0 +1,153 @@
+/*
+ * Copyright 2011-2013 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.data.redis.listener;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNull;
+import static org.junit.Assert.assertTrue;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.concurrent.BlockingDeque;
+import java.util.concurrent.LinkedBlockingDeque;
+import java.util.concurrent.TimeUnit;
+
+import org.junit.After;
+import org.junit.AfterClass;
+import org.junit.Before;
+import org.junit.Test;
+import org.springframework.core.task.SimpleAsyncTaskExecutor;
+import org.springframework.core.task.SyncTaskExecutor;
+import org.springframework.data.redis.ConnectionFactoryTracker;
+import org.springframework.data.redis.SettingsUtils;
+import org.springframework.data.redis.connection.RedisConnectionFactory;
+import org.springframework.data.redis.connection.jedis.JedisConnectionFactory;
+import org.springframework.data.redis.core.RedisTemplate;
+import org.springframework.data.redis.core.StringRedisTemplate;
+import org.springframework.data.redis.listener.adapter.MessageListenerAdapter;
+
+/**
+ * @author Costin Leau
+ */
+public class PubSubResubscribeTests {
+
+ private static final String CHANNEL = "pubsub::test";
+ private final Long ZERO = Long.valueOf(0);
+ private final Long ONE = Long.valueOf(1);
+ private final Long TWO = Long.valueOf(2);
+
+ protected RedisMessageListenerContainer container;
+ protected RedisConnectionFactory factory;
+ protected RedisTemplate template;
+
+ private final BlockingDeque bag = new LinkedBlockingDeque(99);
+
+ private final Object handler = new Object() {
+ public void handleMessage(String message) {
+ System.out.println(message);
+ bag.add(message);
+ }
+ };
+
+ private final MessageListenerAdapter adapter = new MessageListenerAdapter(handler);
+
+ @AfterClass
+ public static void cleanUp() {
+ ConnectionFactoryTracker.cleanUp();
+ }
+
+ @Before
+ public void setUp() throws Exception {
+ JedisConnectionFactory jedisConnFactory = new JedisConnectionFactory();
+ jedisConnFactory.setUsePool(false);
+ jedisConnFactory.setPort(SettingsUtils.getPort());
+ jedisConnFactory.setHostName(SettingsUtils.getHost());
+ jedisConnFactory.setDatabase(2);
+ jedisConnFactory.afterPropertiesSet();
+
+ factory = jedisConnFactory;
+
+ template = new StringRedisTemplate(jedisConnFactory);
+ ConnectionFactoryTracker.add(template.getConnectionFactory());
+
+ bag.clear();
+
+ adapter.setSerializer(template.getValueSerializer());
+ adapter.afterPropertiesSet();
+
+ container = new RedisMessageListenerContainer();
+ container.setConnectionFactory(template.getConnectionFactory());
+ container.setBeanName("container");
+ container.addMessageListener(adapter, new ChannelTopic(CHANNEL));
+ container.setTaskExecutor(new SyncTaskExecutor());
+ container.setSubscriptionExecutor(new SimpleAsyncTaskExecutor());
+ container.afterPropertiesSet();
+ container.start();
+
+ Thread.sleep(1000);
+ }
+
+ @After
+ public void tearDown() throws Exception {
+ container.destroy();
+ }
+
+
+ @Test
+ public void testContainerPatternResubscribe() throws Exception {
+ String payload1 = "do";
+ String payload2 = "re mi";
+
+ final String PATTERN = "p*";
+ final String ANOTHER_CHANNEL = "pubsub::test::extra";
+
+ MessageListenerAdapter anotherListener = new MessageListenerAdapter(handler);
+ anotherListener.setSerializer(template.getValueSerializer());
+ anotherListener.afterPropertiesSet();
+
+ // remove adapter from all channels
+ container.removeMessageListener(adapter);
+ container.addMessageListener(anotherListener, new PatternTopic(PATTERN));
+
+ // test no messages are sent just to patterns
+ assertEquals(ONE, template.convertAndSend(CHANNEL, payload1));
+ assertEquals(ONE, template.convertAndSend(ANOTHER_CHANNEL, payload2));
+
+ List msgs = new ArrayList();
+ msgs.add(bag.poll(1, TimeUnit.SECONDS));
+ msgs.add(bag.poll(1, TimeUnit.SECONDS));
+
+ assertEquals(2, msgs.size());
+ // bind original listener on another channel
+ container.addMessageListener(adapter, new ChannelTopic(ANOTHER_CHANNEL));
+
+ assertEquals(ONE, template.convertAndSend(CHANNEL, payload1));
+ assertEquals(TWO, template.convertAndSend(ANOTHER_CHANNEL, payload2));
+
+ msgs.add(bag.poll(1, TimeUnit.SECONDS));
+ msgs.add(bag.poll(1, TimeUnit.SECONDS));
+ msgs.add(bag.poll(1, TimeUnit.SECONDS));
+ // this message will not arrive on time
+ assertNull(bag.poll(1, TimeUnit.SECONDS));
+
+ // same message received first per channel subscription, second based on the pattern
+ assertEquals(5, msgs.size());
+
+ assertTrue(msgs.contains(payload1));
+ assertTrue(msgs.contains(payload2));
+ }
+}
diff --git a/src/test/java/org/springframework/data/redis/listener/PubSubTests.java b/src/test/java/org/springframework/data/redis/listener/PubSubTests.java
index eb3db83d8..e6137254c 100644
--- a/src/test/java/org/springframework/data/redis/listener/PubSubTests.java
+++ b/src/test/java/org/springframework/data/redis/listener/PubSubTests.java
@@ -16,6 +16,7 @@
package org.springframework.data.redis.listener;
import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
import java.util.Arrays;
@@ -33,6 +34,8 @@ import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.junit.runners.Parameterized.Parameters;
+import org.springframework.core.task.SimpleAsyncTaskExecutor;
+import org.springframework.core.task.SyncTaskExecutor;
import org.springframework.data.redis.ConnectionFactoryTracker;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.data.redis.listener.adapter.MessageListenerAdapter;
@@ -47,6 +50,9 @@ import org.springframework.data.redis.support.collections.ObjectFactory;
public class PubSubTests {
private static final String CHANNEL = "pubsub::test";
+ private final Long ZERO = Long.valueOf(0);
+ private final Long ONE = Long.valueOf(1);
+ private final Long TWO = Long.valueOf(2);
protected RedisMessageListenerContainer container;
protected ObjectFactory factory;
@@ -64,6 +70,8 @@ public class PubSubTests {
@Before
public void setUp() throws Exception {
+ bag.clear();
+
adapter.setSerializer(template.getValueSerializer());
adapter.afterPropertiesSet();
@@ -71,6 +79,8 @@ public class PubSubTests {
container.setConnectionFactory(template.getConnectionFactory());
container.setBeanName("container");
container.addMessageListener(adapter, Arrays.asList(new ChannelTopic(CHANNEL)));
+ container.setTaskExecutor(new SyncTaskExecutor());
+ container.setSubscriptionExecutor(new SimpleAsyncTaskExecutor());
container.afterPropertiesSet();
container.start();
@@ -111,15 +121,13 @@ public class PubSubTests {
String payload1 = "do";
String payload2 = "re mi";
- template.convertAndSend(CHANNEL, payload1);
- template.convertAndSend(CHANNEL, payload2);
+ assertEquals(ONE, template.convertAndSend(CHANNEL, payload1));
+ assertEquals(ONE, template.convertAndSend(CHANNEL, payload2));
Set set = new LinkedHashSet();
set.add(bag.poll(1, TimeUnit.SECONDS));
set.add(bag.poll(1, TimeUnit.SECONDS));
- System.out.println(set);
-
assertTrue(set.contains(payload1));
assertTrue(set.contains(payload2));
}
@@ -134,4 +142,55 @@ public class PubSubTests {
Thread.sleep(1000);
assertEquals(COUNT, bag.size());
}
+
+ @Test
+ public void testContainerUnsubscribe() throws Exception {
+ String payload1 = "do";
+ String payload2 = "re mi";
+
+ container.removeMessageListener(adapter, new ChannelTopic(CHANNEL));
+ assertEquals(ZERO, template.convertAndSend(CHANNEL, payload1));
+ assertEquals(ZERO, template.convertAndSend(CHANNEL, payload2));
+
+ Set set = new LinkedHashSet();
+ set.add(bag.poll(1, TimeUnit.SECONDS));
+ set.add(bag.poll(1, TimeUnit.SECONDS));
+
+ assertFalse(set.contains(payload1));
+ assertFalse(set.contains(payload2));
+ }
+
+ @Test
+ public void testContainerChannelResubscribe() throws Exception {
+ String payload1 = "do";
+ String payload2 = "re mi";
+
+ String anotherPayload1 = "od";
+ String anotherPayload2 = "mi er";
+
+ String ANOTHER_CHANNEL = "pubsub::test::extra";
+
+ // bind listener on another channel
+ container.addMessageListener(adapter, new ChannelTopic(ANOTHER_CHANNEL));
+ container.removeMessageListener(null, new ChannelTopic(CHANNEL));
+
+ assertEquals(ZERO, template.convertAndSend(CHANNEL, payload1));
+ assertEquals(ZERO, template.convertAndSend(CHANNEL, payload2));
+
+ assertEquals(ONE, template.convertAndSend(ANOTHER_CHANNEL, anotherPayload1));
+ assertEquals(ONE, template.convertAndSend(ANOTHER_CHANNEL, anotherPayload2));
+
+ Set set = new LinkedHashSet();
+ set.add(bag.poll(1, TimeUnit.SECONDS));
+ set.add(bag.poll(1, TimeUnit.SECONDS));
+
+ System.out.println(set);
+
+ assertFalse(set.contains(payload1));
+ assertFalse(set.contains(payload2));
+
+ assertTrue(set.contains(anotherPayload1));
+ assertTrue(set.contains(anotherPayload2));
+ }
+
}
\ No newline at end of file
diff --git a/src/test/resources/log4j.properties b/src/test/resources/log4j.properties
index 945449482..8810f96a4 100644
--- a/src/test/resources/log4j.properties
+++ b/src/test/resources/log4j.properties
@@ -4,7 +4,7 @@ log4j.appender.stdout=org.apache.log4j.ConsoleAppender
log4j.appender.stdout.layout=org.apache.log4j.PatternLayout
log4j.appender.stdout.layout.ConversionPattern=%d %p [%c] - <%m>%n
-log4j.category.org.springframework.data.keyvalue.redis.listener=TRACE
+log4j.category.org.springframework.data.redis.listener=TRACE
# for debugging datasource initialization
# log4j.category.test.jdbc=DEBUG