remove listeners programatically in RMLContainer

DATAREDIS-107
listeners can be removed not just added programatically
fix bug causing double message dispatch for pattern subscriptions
introduce more pubsub tests
This commit is contained in:
Costin Leau
2013-01-25 00:11:47 +02:00
parent 12c66277f7
commit 613c2e09e7
4 changed files with 399 additions and 58 deletions

View File

@@ -20,6 +20,7 @@ import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CopyOnWriteArraySet;
import java.util.concurrent.Executor;
@@ -42,6 +43,7 @@ import org.springframework.data.redis.connection.util.ByteArrayWrapper;
import org.springframework.data.redis.serializer.RedisSerializer;
import org.springframework.data.redis.serializer.StringRedisSerializer;
import org.springframework.scheduling.SchedulingAwareRunnable;
import org.springframework.util.Assert;
import org.springframework.util.ClassUtils;
import org.springframework.util.CollectionUtils;
import org.springframework.util.ErrorHandler;
@@ -55,7 +57,11 @@ import org.springframework.util.ErrorHandler;
* the message dispatch being done through the task executor.
*
* <p/>
* Note the container uses the connection in a lazy fashion (the connection is used only if at least one listener is configured).
* Note the container uses the connection in a lazy fashion (the connection is used only if at least one listener is configured).
*
* <p/>
* Adding and removing listeners at the same time has undefined results. It is strongly recommended to synchronize/order these
* methods accordingly.
*
* @author Costin Leau
*/
@@ -65,7 +71,6 @@ public class RedisMessageListenerContainer implements InitializingBean, Disposab
protected final Log logger = LogFactory.getLog(getClass());
/**
* Default thread name prefix: "RedisListeningContainer-".
*/
@@ -104,13 +109,15 @@ public class RedisMessageListenerContainer implements InitializingBean, Disposab
private final Map<ByteArrayWrapper, Collection<MessageListener>> patternMapping = new ConcurrentHashMap<ByteArrayWrapper, Collection<MessageListener>>();
// lookup map between channels and listeners
private final Map<ByteArrayWrapper, Collection<MessageListener>> channelMapping = new ConcurrentHashMap<ByteArrayWrapper, Collection<MessageListener>>();
// lookup map between listeners and channels
private final Map<MessageListener, Set<Topic>> listenerTopics = new ConcurrentHashMap<MessageListener, Set<Topic>>();
private final SubscriptionTask subscriptionTask = new SubscriptionTask();
private volatile RedisSerializer<String> serializer = new StringRedisSerializer();
public void afterPropertiesSet() {
if (taskExecutor == null) {
manageExecutor = true;
@@ -135,7 +142,7 @@ public class RedisMessageListenerContainer implements InitializingBean, Disposab
return new SimpleAsyncTaskExecutor(threadNamePrefix);
}
public void destroy() throws Exception {
initialized = false;
@@ -152,29 +159,29 @@ public class RedisMessageListenerContainer implements InitializingBean, Disposab
}
}
public boolean isAutoStartup() {
return true;
}
public void stop(Runnable callback) {
stop();
callback.run();
}
public int getPhase() {
// start the latest
return Integer.MAX_VALUE;
}
public boolean isRunning() {
return running;
}
public void start() {
if (!running) {
running = true;
@@ -196,13 +203,14 @@ public class RedisMessageListenerContainer implements InitializingBean, Disposab
}
}
public void stop() {
if (isRunning()) {
running = false;
synchronized (monitor) {
boolean shouldWait = listening;
subscriptionTask.cancel();
listening = false;
if (shouldWait) {
try {
monitor.wait(initWait);
@@ -300,7 +308,7 @@ public class RedisMessageListenerContainer implements InitializingBean, Disposab
this.connectionFactory = connectionFactory;
}
public void setBeanName(String name) {
this.beanName = name;
}
@@ -388,14 +396,56 @@ public class RedisMessageListenerContainer implements InitializingBean, Disposab
addMessageListener(listener, Collections.singleton(topic));
}
/**
* Removes a message listener from the given topics. If the container is running,
* the listener stops receiving (matching) messages as soon as possible.
* <p/>
* Note that this method obeys the Redis (p)unsubscribe semantics - meaning an empty/null collection will remove
* listener from all channels.
* Similarly a null listener will unsubscribe all listeners from the given topic.
*
* @param listener message listener
* @param topics message listener topics
*/
public void removeMessageListener(MessageListener listener, Collection<? extends Topic> topics) {
removeListener(listener, topics);
}
/**
* Removes a message listener from the from the given topic. If the container is running,
* the listener stops receiving (matching) messages as soon as possible.
*
* <p/>
* Note that this method obeys the Redis (p)unsubscribe semantics - meaning an empty/null collection will remove
* listener from all channels. Similarly a null listener will unsubscribe all listeners from the given topic.
*
* @param listener message listener
* @param topic message topic
*/
public void removeMessageListener(MessageListener listener, Topic topic) {
removeMessageListener(listener, Collections.singleton(topic));
}
/**
* Removes the given message listener completely (from all topics). If the container is running,
* the listener stops receiving (matching) messages as soon as possible.
* Similarly a null listener will unsubscribe all listeners from the given topic.
*
* @param listener message listener
*/
public void removeMessageListener(MessageListener listener) {
removeMessageListener(listener, Collections.<Topic> emptySet());
}
private void initMapping(Map<? extends MessageListener, Collection<? extends Topic>> listeners) {
// stop the listener if currently running
if (isRunning()) {
stop();
subscriptionTask.cancel();
}
patternMapping.clear();
channelMapping.clear();
listenerTopics.clear();
if (!CollectionUtils.isEmpty(listeners)) {
for (Map.Entry<? extends MessageListener, Collection<? extends Topic>> entry : listeners.entrySet()) {
@@ -440,11 +490,22 @@ public class RedisMessageListenerContainer implements InitializingBean, Disposab
}
private void addListener(MessageListener listener, Collection<? extends Topic> topics) {
Assert.notNull(listener, "a valid listener is required");
Assert.notEmpty(topics, "at least one topic is required");
List<byte[]> channels = new ArrayList<byte[]>(topics.size());
List<byte[]> patterns = new ArrayList<byte[]>(topics.size());
boolean trace = logger.isTraceEnabled();
// add listener mapping
Set<Topic> set = listenerTopics.get(listener);
if (set == null) {
set = new CopyOnWriteArraySet<Topic>();
listenerTopics.put(listener, set);
}
set.addAll(topics);
for (Topic topic : topics) {
ByteArrayWrapper holder = new ByteArrayWrapper(serializer.serialize(topic.getTopic()));
@@ -487,6 +548,85 @@ public class RedisMessageListenerContainer implements InitializingBean, Disposab
}
}
private void removeListener(MessageListener listener, Collection<? extends Topic> topics) {
boolean trace = logger.isTraceEnabled();
// check stop listening case
if (listener == null && CollectionUtils.isEmpty(topics)) {
subscriptionTask.cancel();
logger.debug("Stopped listening for Redis messages");
return;
}
List<byte[]> channelsToRemove = new ArrayList<byte[]>();
List<byte[]> patternsToRemove = new ArrayList<byte[]>();
// check unsubscribe all topics case
if (CollectionUtils.isEmpty(topics)) {
Set<Topic> set = listenerTopics.remove(listener);
// listener not found, bail out
if (set == null) {
return;
}
topics = set;
}
for (Topic topic : topics) {
ByteArrayWrapper holder = new ByteArrayWrapper(serializer.serialize(topic.getTopic()));
if (topic instanceof ChannelTopic) {
remove(listener, topic, holder, channelMapping, channelsToRemove);
if (trace) {
String msg = (listener != null ? "listener '" + listener + "'" : "all listeners");
logger.trace("Removing " + msg + " from channel '" + topic.getTopic() + "'");
}
}
else if (topic instanceof PatternTopic) {
remove(listener, topic, holder, patternMapping, patternsToRemove);
if (trace) {
String msg = (listener != null ? "listener '" + listener + "'" : "all listeners");
logger.trace("Removing " + msg + " from pattern '" + topic.getTopic() + "'");
}
}
}
// check the current listening state
if (listening) {
subscriptionTask.unsubscribeChannel(channelsToRemove.toArray(new byte[channelsToRemove.size()][]));
subscriptionTask.unsubscribePattern(patternsToRemove.toArray(new byte[patternsToRemove.size()][]));
}
}
private void remove(MessageListener listener, Topic topic, ByteArrayWrapper holder, Map<ByteArrayWrapper, Collection<MessageListener>> mapping, List<byte[]> topicToRemove) {
Collection<MessageListener> listeners = mapping.get(holder);
if (listeners != null) {
if (listener != null) {
listeners.remove(listener);
}
// remove all listeners for the given topic
else {
for (MessageListener messageListener : listeners) {
Set<Topic> topics = listenerTopics.get(messageListener);
if (topics != null) {
topics.remove(topic);
}
if (topics.isEmpty()) {
listenerTopics.remove(messageListener);
}
}
}
if (listener == null || listeners.isEmpty()) {
mapping.remove(holder);
topicToRemove.add(holder.getArray());
}
}
}
/**
* Runnable used for Redis subscription. Implemented as a dedicated class to provide as many hints
@@ -509,12 +649,12 @@ public class RedisMessageListenerContainer implements InitializingBean, Disposab
private long WAIT = 500;
private long ROUNDS = 3;
public boolean isLongLived() {
return false;
}
public void run() {
// wait for subscription to be initialized
boolean done = false;
@@ -542,12 +682,12 @@ public class RedisMessageListenerContainer implements InitializingBean, Disposab
private volatile RedisConnection connection;
private final Object localMonitor = new Object();
public boolean isLongLived() {
return true;
}
public void run() {
connection = connectionFactory.getConnection();
try {
@@ -556,7 +696,6 @@ public class RedisMessageListenerContainer implements InitializingBean, Disposab
}
// NB: each Xsubscribe call blocks
synchronized (monitor) {
monitor.notify();
}
@@ -613,6 +752,9 @@ public class RedisMessageListenerContainer implements InitializingBean, Disposab
}
void cancel() {
if (!listening) {
return;
}
if (connection != null) {
synchronized (localMonitor) {
if (connection != null) {
@@ -694,48 +836,35 @@ public class RedisMessageListenerContainer implements InitializingBean, Disposab
*/
private class DispatchMessageListener implements MessageListener {
public void onMessage(Message message, byte[] pattern) {
// do channel matching first
byte[] channel = message.getChannel();
Collection<MessageListener> listeners = null;
Collection<MessageListener> ch = channelMapping.get(new ByteArrayWrapper(channel));
Collection<MessageListener> pt = null;
// followed by pattern matching
// if it's a pattern, disregard channel
if (pattern != null && pattern.length > 0) {
pt = patternMapping.get(new ByteArrayWrapper(pattern));
listeners = patternMapping.get(new ByteArrayWrapper(pattern));
}
else {
pattern = null;
// do channel matching first
listeners = channelMapping.get(new ByteArrayWrapper(message.getChannel()));
}
if (!CollectionUtils.isEmpty(ch)) {
dispatchChannels(ch, message);
}
if (!CollectionUtils.isEmpty(pt)) {
dispatchPatterns(pt, message, pattern);
}
}
private void dispatchChannels(Collection<MessageListener> ch, final Message message) {
for (final MessageListener messageListener : ch) {
taskExecutor.execute(new Runnable() {
public void run() {
processMessage(messageListener, message, message.getChannel());
}
});
}
}
private void dispatchPatterns(Collection<MessageListener> pt, final Message message, final byte[] pattern) {
for (final MessageListener messageListener : pt) {
taskExecutor.execute(new Runnable() {
public void run() {
processMessage(messageListener, message, pattern.clone());
}
});
if (!CollectionUtils.isEmpty(listeners)) {
dispatchMessage(listeners, message, pattern);
}
}
}
private void dispatchMessage(Collection<MessageListener> listeners, final Message message, final byte[] pattern) {
final byte[] source = (pattern != null ? pattern.clone() : message.getChannel());
for (final MessageListener messageListener : listeners) {
taskExecutor.execute(new Runnable() {
public void run() {
processMessage(messageListener, message, source);
}
});
}
}
}

View File

@@ -0,0 +1,153 @@
/*
* Copyright 2011-2013 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.data.redis.listener;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertTrue;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.BlockingDeque;
import java.util.concurrent.LinkedBlockingDeque;
import java.util.concurrent.TimeUnit;
import org.junit.After;
import org.junit.AfterClass;
import org.junit.Before;
import org.junit.Test;
import org.springframework.core.task.SimpleAsyncTaskExecutor;
import org.springframework.core.task.SyncTaskExecutor;
import org.springframework.data.redis.ConnectionFactoryTracker;
import org.springframework.data.redis.SettingsUtils;
import org.springframework.data.redis.connection.RedisConnectionFactory;
import org.springframework.data.redis.connection.jedis.JedisConnectionFactory;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.data.redis.core.StringRedisTemplate;
import org.springframework.data.redis.listener.adapter.MessageListenerAdapter;
/**
* @author Costin Leau
*/
public class PubSubResubscribeTests {
private static final String CHANNEL = "pubsub::test";
private final Long ZERO = Long.valueOf(0);
private final Long ONE = Long.valueOf(1);
private final Long TWO = Long.valueOf(2);
protected RedisMessageListenerContainer container;
protected RedisConnectionFactory factory;
protected RedisTemplate template;
private final BlockingDeque<String> bag = new LinkedBlockingDeque<String>(99);
private final Object handler = new Object() {
public void handleMessage(String message) {
System.out.println(message);
bag.add(message);
}
};
private final MessageListenerAdapter adapter = new MessageListenerAdapter(handler);
@AfterClass
public static void cleanUp() {
ConnectionFactoryTracker.cleanUp();
}
@Before
public void setUp() throws Exception {
JedisConnectionFactory jedisConnFactory = new JedisConnectionFactory();
jedisConnFactory.setUsePool(false);
jedisConnFactory.setPort(SettingsUtils.getPort());
jedisConnFactory.setHostName(SettingsUtils.getHost());
jedisConnFactory.setDatabase(2);
jedisConnFactory.afterPropertiesSet();
factory = jedisConnFactory;
template = new StringRedisTemplate(jedisConnFactory);
ConnectionFactoryTracker.add(template.getConnectionFactory());
bag.clear();
adapter.setSerializer(template.getValueSerializer());
adapter.afterPropertiesSet();
container = new RedisMessageListenerContainer();
container.setConnectionFactory(template.getConnectionFactory());
container.setBeanName("container");
container.addMessageListener(adapter, new ChannelTopic(CHANNEL));
container.setTaskExecutor(new SyncTaskExecutor());
container.setSubscriptionExecutor(new SimpleAsyncTaskExecutor());
container.afterPropertiesSet();
container.start();
Thread.sleep(1000);
}
@After
public void tearDown() throws Exception {
container.destroy();
}
@Test
public void testContainerPatternResubscribe() throws Exception {
String payload1 = "do";
String payload2 = "re mi";
final String PATTERN = "p*";
final String ANOTHER_CHANNEL = "pubsub::test::extra";
MessageListenerAdapter anotherListener = new MessageListenerAdapter(handler);
anotherListener.setSerializer(template.getValueSerializer());
anotherListener.afterPropertiesSet();
// remove adapter from all channels
container.removeMessageListener(adapter);
container.addMessageListener(anotherListener, new PatternTopic(PATTERN));
// test no messages are sent just to patterns
assertEquals(ONE, template.convertAndSend(CHANNEL, payload1));
assertEquals(ONE, template.convertAndSend(ANOTHER_CHANNEL, payload2));
List<String> msgs = new ArrayList<String>();
msgs.add(bag.poll(1, TimeUnit.SECONDS));
msgs.add(bag.poll(1, TimeUnit.SECONDS));
assertEquals(2, msgs.size());
// bind original listener on another channel
container.addMessageListener(adapter, new ChannelTopic(ANOTHER_CHANNEL));
assertEquals(ONE, template.convertAndSend(CHANNEL, payload1));
assertEquals(TWO, template.convertAndSend(ANOTHER_CHANNEL, payload2));
msgs.add(bag.poll(1, TimeUnit.SECONDS));
msgs.add(bag.poll(1, TimeUnit.SECONDS));
msgs.add(bag.poll(1, TimeUnit.SECONDS));
// this message will not arrive on time
assertNull(bag.poll(1, TimeUnit.SECONDS));
// same message received first per channel subscription, second based on the pattern
assertEquals(5, msgs.size());
assertTrue(msgs.contains(payload1));
assertTrue(msgs.contains(payload2));
}
}

View File

@@ -16,6 +16,7 @@
package org.springframework.data.redis.listener;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
import java.util.Arrays;
@@ -33,6 +34,8 @@ import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.junit.runners.Parameterized.Parameters;
import org.springframework.core.task.SimpleAsyncTaskExecutor;
import org.springframework.core.task.SyncTaskExecutor;
import org.springframework.data.redis.ConnectionFactoryTracker;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.data.redis.listener.adapter.MessageListenerAdapter;
@@ -47,6 +50,9 @@ import org.springframework.data.redis.support.collections.ObjectFactory;
public class PubSubTests<T> {
private static final String CHANNEL = "pubsub::test";
private final Long ZERO = Long.valueOf(0);
private final Long ONE = Long.valueOf(1);
private final Long TWO = Long.valueOf(2);
protected RedisMessageListenerContainer container;
protected ObjectFactory<T> factory;
@@ -64,6 +70,8 @@ public class PubSubTests<T> {
@Before
public void setUp() throws Exception {
bag.clear();
adapter.setSerializer(template.getValueSerializer());
adapter.afterPropertiesSet();
@@ -71,6 +79,8 @@ public class PubSubTests<T> {
container.setConnectionFactory(template.getConnectionFactory());
container.setBeanName("container");
container.addMessageListener(adapter, Arrays.asList(new ChannelTopic(CHANNEL)));
container.setTaskExecutor(new SyncTaskExecutor());
container.setSubscriptionExecutor(new SimpleAsyncTaskExecutor());
container.afterPropertiesSet();
container.start();
@@ -111,15 +121,13 @@ public class PubSubTests<T> {
String payload1 = "do";
String payload2 = "re mi";
template.convertAndSend(CHANNEL, payload1);
template.convertAndSend(CHANNEL, payload2);
assertEquals(ONE, template.convertAndSend(CHANNEL, payload1));
assertEquals(ONE, template.convertAndSend(CHANNEL, payload2));
Set<String> set = new LinkedHashSet<String>();
set.add(bag.poll(1, TimeUnit.SECONDS));
set.add(bag.poll(1, TimeUnit.SECONDS));
System.out.println(set);
assertTrue(set.contains(payload1));
assertTrue(set.contains(payload2));
}
@@ -134,4 +142,55 @@ public class PubSubTests<T> {
Thread.sleep(1000);
assertEquals(COUNT, bag.size());
}
@Test
public void testContainerUnsubscribe() throws Exception {
String payload1 = "do";
String payload2 = "re mi";
container.removeMessageListener(adapter, new ChannelTopic(CHANNEL));
assertEquals(ZERO, template.convertAndSend(CHANNEL, payload1));
assertEquals(ZERO, template.convertAndSend(CHANNEL, payload2));
Set<String> set = new LinkedHashSet<String>();
set.add(bag.poll(1, TimeUnit.SECONDS));
set.add(bag.poll(1, TimeUnit.SECONDS));
assertFalse(set.contains(payload1));
assertFalse(set.contains(payload2));
}
@Test
public void testContainerChannelResubscribe() throws Exception {
String payload1 = "do";
String payload2 = "re mi";
String anotherPayload1 = "od";
String anotherPayload2 = "mi er";
String ANOTHER_CHANNEL = "pubsub::test::extra";
// bind listener on another channel
container.addMessageListener(adapter, new ChannelTopic(ANOTHER_CHANNEL));
container.removeMessageListener(null, new ChannelTopic(CHANNEL));
assertEquals(ZERO, template.convertAndSend(CHANNEL, payload1));
assertEquals(ZERO, template.convertAndSend(CHANNEL, payload2));
assertEquals(ONE, template.convertAndSend(ANOTHER_CHANNEL, anotherPayload1));
assertEquals(ONE, template.convertAndSend(ANOTHER_CHANNEL, anotherPayload2));
Set<String> set = new LinkedHashSet<String>();
set.add(bag.poll(1, TimeUnit.SECONDS));
set.add(bag.poll(1, TimeUnit.SECONDS));
System.out.println(set);
assertFalse(set.contains(payload1));
assertFalse(set.contains(payload2));
assertTrue(set.contains(anotherPayload1));
assertTrue(set.contains(anotherPayload2));
}
}

View File

@@ -4,7 +4,7 @@ log4j.appender.stdout=org.apache.log4j.ConsoleAppender
log4j.appender.stdout.layout=org.apache.log4j.PatternLayout
log4j.appender.stdout.layout.ConversionPattern=%d %p [%c] - <%m>%n
log4j.category.org.springframework.data.keyvalue.redis.listener=TRACE
log4j.category.org.springframework.data.redis.listener=TRACE
# for debugging datasource initialization
# log4j.category.test.jdbc=DEBUG