Track subscriptions and unsubscriptions in LettuceReactiveRedisConnection.

We now track subscriptions and unsubscriptions in the reactive API to ensure that we do not prematurely unsubscribe from a channel or pattern if the topic was subscribed multiple times.

Original Pull Request: #2467
This commit is contained in:
Mark Paluch
2022-11-30 15:55:11 +01:00
committed by Christoph Strobl
parent 7b6a697265
commit 230c764c69
7 changed files with 484 additions and 43 deletions

View File

@@ -20,14 +20,21 @@ import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicLongFieldUpdater;
import java.util.function.BiFunction;
import java.util.function.Function;
import org.reactivestreams.Publisher;
import org.springframework.data.redis.connection.ReactivePubSubCommands;
import org.springframework.data.redis.connection.ReactiveSubscription;
import org.springframework.data.redis.connection.ReactiveSubscription.ChannelMessage;
import org.springframework.data.redis.connection.SubscriptionListener;
import org.springframework.data.redis.connection.util.ByteArrayWrapper;
import org.springframework.data.redis.util.ByteUtils;
import org.springframework.util.Assert;
/**
@@ -39,16 +46,27 @@ class LettuceReactivePubSubCommands implements ReactivePubSubCommands {
private final LettuceReactiveRedisConnection connection;
private final Map<ByteArrayWrapper, Target> channels = new ConcurrentHashMap<>();
private final Map<ByteArrayWrapper, Target> patterns = new ConcurrentHashMap<>();
LettuceReactivePubSubCommands(LettuceReactiveRedisConnection connection) {
this.connection = connection;
}
public Map<ByteArrayWrapper, Target> getChannels() {
return channels;
}
public Map<ByteArrayWrapper, Target> getPatterns() {
return patterns;
}
@Override
public Mono<ReactiveSubscription> createSubscription(SubscriptionListener listener) {
return connection.getPubSubConnection()
.map(pubSubConnection -> new LettuceReactiveSubscription(listener, pubSubConnection,
connection.translateException()));
return connection.getPubSubConnection().map(pubSubConnection -> new LettuceReactiveSubscription(listener,
pubSubConnection, this, connection.translateException()));
}
@Override
@@ -65,20 +83,157 @@ class LettuceReactivePubSubCommands implements ReactivePubSubCommands {
Assert.notNull(channels, "Channels must not be null");
Target.trackSubscriptions(channels, this.channels); // track usage but do not limit what to subscribe to
return doWithPubSub(commands -> commands.subscribe(channels));
}
public Mono<Void> unsubscribe(ByteBuffer... channels) {
Assert.notNull(patterns, "Patterns must not be null");
ByteBuffer[] actualUnsubscribe = Target.trackUnsubscriptions(channels, this.channels);
if (actualUnsubscribe.length == 0 && channels.length != 0) {
return Mono.empty();
}
return doWithPubSub(commands -> commands.unsubscribe(actualUnsubscribe));
}
@Override
public Mono<Void> pSubscribe(ByteBuffer... patterns) {
Assert.notNull(patterns, "Patterns must not be null");
Target.trackSubscriptions(patterns, this.patterns); // track usage but do not limit what to subscribe to
return doWithPubSub(commands -> commands.psubscribe(patterns));
}
public Mono<Void> pUnsubscribe(ByteBuffer... patterns) {
Assert.notNull(patterns, "Patterns must not be null");
ByteBuffer[] actualUnsubscribe = Target.trackUnsubscriptions(patterns, this.patterns);
if (actualUnsubscribe.length == 0 && patterns.length != 0) {
return Mono.empty();
}
return doWithPubSub(commands -> commands.punsubscribe(actualUnsubscribe));
}
private <T> Mono<T> doWithPubSub(Function<RedisPubSubReactiveCommands<ByteBuffer, ByteBuffer>, Mono<T>> function) {
return connection.getPubSubConnection().flatMap(pubSubConnection -> function.apply(pubSubConnection.reactive()))
.onErrorMap(connection.translateException());
}
static class Target {
private static final AtomicLongFieldUpdater<Target> SUBSCRIBERS = AtomicLongFieldUpdater.newUpdater(Target.class,
"subscribers");
private final byte[] raw;
private volatile long subscribers;
Target(byte[] raw) {
this.raw = raw;
}
/**
* Record the subscriptions to {@code targets} and store these in {@code targetMap}.
*
* @param targets
* @param targetMap
*/
public static void trackSubscriptions(ByteBuffer[] targets, Map<ByteArrayWrapper, Target> targetMap) {
doWithTargets(targets, targetMap, Target::allocate);
}
/**
* Record the un-subscriptions to {@code targets} and store these in {@code targetMap}. Returns the targets to
* actually unsubscribe from if there are no subscribers to a particular target.
*
* @param targets
* @param targetMap
*/
public static ByteBuffer[] trackUnsubscriptions(ByteBuffer[] targets, Map<ByteArrayWrapper, Target> targetMap) {
return doWithTargets(targets, targetMap, Target::deallocate);
}
static ByteBuffer[] doWithTargets(ByteBuffer[] targets, Map<ByteArrayWrapper, Target> targetMap,
BiFunction<ByteBuffer, Map<ByteArrayWrapper, Target>, Boolean> f) {
List<ByteBuffer> toSubscribe = new ArrayList<>(targets.length);
synchronized (targetMap) {
for (ByteBuffer target : targets) {
if (f.apply(target, targetMap)) {
toSubscribe.add(target);
}
}
}
return toSubscribe.toArray(new ByteBuffer[0]);
}
boolean increment() {
return SUBSCRIBERS.incrementAndGet(this) == 1;
}
boolean decrement() {
long l = SUBSCRIBERS.get(this);
if (l > 0) {
if (SUBSCRIBERS.compareAndSet(this, l, l - 1)) {
return l == 1; // return true if this was the last subscriber
}
}
return false;
}
static boolean allocate(ByteBuffer buffer, Map<ByteArrayWrapper, Target> targets) {
byte[] raw = ByteUtils.getBytes(buffer);
ByteArrayWrapper wrapper = new ByteArrayWrapper(raw);
Target targetToUse = targets.get(wrapper);
if (targetToUse == null) {
targetToUse = new Target(raw);
targets.put(wrapper, targetToUse);
}
return targetToUse.increment();
}
static boolean deallocate(ByteBuffer buffer, Map<ByteArrayWrapper, Target> targets) {
byte[] raw = ByteUtils.getBytes(buffer);
ByteArrayWrapper wrapper = new ByteArrayWrapper(raw);
Target targetToUse = targets.get(wrapper);
if (targetToUse == null) {
return false;
}
if (targetToUse.decrement()) {
targets.remove(wrapper);
return true;
}
return false;
}
@Override
public String toString() {
return String.format("%s: Subscribers: %s", new String(raw), SUBSCRIBERS.get(this));
}
}
}

View File

@@ -51,6 +51,8 @@ class LettuceReactiveRedisConnection implements ReactiveRedisConnection {
private final AsyncConnect<StatefulConnection<ByteBuffer, ByteBuffer>> dedicatedConnection;
private final AsyncConnect<StatefulRedisPubSubConnection<ByteBuffer, ByteBuffer>> pubSubConnection;
private final LettuceReactivePubSubCommands pubSub = new LettuceReactivePubSubCommands(this);
private @Nullable Mono<StatefulConnection<ByteBuffer, ByteBuffer>> sharedConnection;
/**
@@ -137,7 +139,7 @@ class LettuceReactiveRedisConnection implements ReactiveRedisConnection {
@Override
public ReactivePubSubCommands pubSubCommands() {
return new LettuceReactivePubSubCommands(this);
return pubSub;
}
@Override

View File

@@ -23,18 +23,17 @@ import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import java.nio.ByteBuffer;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Set;
import java.util.concurrent.ConcurrentSkipListSet;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import org.springframework.data.redis.connection.ReactiveSubscription;
import org.springframework.data.redis.connection.SubscriptionListener;
import org.springframework.data.redis.connection.util.ByteArrayWrapper;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
import org.springframework.util.ObjectUtils;
@@ -50,19 +49,22 @@ class LettuceReactiveSubscription implements ReactiveSubscription {
private final LettuceByteBufferPubSubListenerWrapper listener;
private final StatefulRedisPubSubConnection<ByteBuffer, ByteBuffer> connection;
private final RedisPubSubReactiveCommands<ByteBuffer, ByteBuffer> commands;
private final RedisPubSubReactiveCommands<ByteBuffer, ByteBuffer> reactive;
private final LettuceReactivePubSubCommands commands;
private final State patternState;
private final State channelState;
LettuceReactiveSubscription(SubscriptionListener subscriptionListener,
StatefulRedisPubSubConnection<ByteBuffer, ByteBuffer> connection,
StatefulRedisPubSubConnection<ByteBuffer, ByteBuffer> connection, LettuceReactivePubSubCommands commands,
Function<Throwable, Throwable> exceptionTranslator) {
this.listener = new LettuceByteBufferPubSubListenerWrapper(
new LettuceMessageListener((messages, pattern) -> {}, subscriptionListener));
this.connection = connection;
this.commands = connection.reactive();
this.reactive = connection.reactive();
this.commands = commands;
connection.addListener(listener);
this.patternState = new State(exceptionTranslator);
@@ -84,7 +86,7 @@ class LettuceReactiveSubscription implements ReactiveSubscription {
Assert.notNull(patterns, "Patterns must not be null");
Assert.noNullElements(patterns, "Patterns must not contain null elements");
return patternState.subscribe(patterns, commands::psubscribe);
return patternState.subscribe(patterns, commands::pSubscribe);
}
@Override
@@ -112,7 +114,7 @@ class LettuceReactiveSubscription implements ReactiveSubscription {
Assert.notNull(patterns, "Patterns must not be null");
Assert.noNullElements(patterns, "Patterns must not contain null elements");
return ObjectUtils.isEmpty(patterns) ? Mono.empty() : patternState.unsubscribe(patterns, commands::punsubscribe);
return ObjectUtils.isEmpty(patterns) ? Mono.empty() : patternState.unsubscribe(patterns, commands::pUnsubscribe);
}
@Override
@@ -128,12 +130,12 @@ class LettuceReactiveSubscription implements ReactiveSubscription {
@Override
public Flux<Message<ByteBuffer, ByteBuffer>> receive() {
Flux<Message<ByteBuffer, ByteBuffer>> channelMessages = channelState.receive(() -> commands.observeChannels() //
.filter(message -> channelState.getTargets().contains(message.getChannel())) //
Flux<Message<ByteBuffer, ByteBuffer>> channelMessages = channelState.receive(() -> reactive.observeChannels() //
.filter(message -> channelState.contains(message.getChannel())) //
.map(message -> new ChannelMessage<>(message.getChannel(), message.getMessage())));
Flux<Message<ByteBuffer, ByteBuffer>> patternMessages = patternState.receive(() -> commands.observePatterns() //
.filter(message -> patternState.getTargets().contains(message.getPattern())) //
Flux<Message<ByteBuffer, ByteBuffer>> patternMessages = patternState.receive(() -> reactive.observePatterns() //
.filter(message -> patternState.contains(message.getPattern())) //
.map(message -> new PatternMessage<>(message.getPattern(), message.getChannel(), message.getMessage())));
return channelMessages.mergeWith(patternMessages);
@@ -149,7 +151,7 @@ class LettuceReactiveSubscription implements ReactiveSubscription {
// this is to ensure completion of the futures and result processing. Since we're unsubscribing first, we expect
// that we receive pub/sub confirmations before the PING response.
return commands.ping().then(Mono.fromRunnable(() -> {
return reactive.ping().then(Mono.fromRunnable(() -> {
connection.removeListener(listener);
}));
}));
@@ -162,7 +164,7 @@ class LettuceReactiveSubscription implements ReactiveSubscription {
*/
static class State {
private final Set<ByteBuffer> targets = new ConcurrentSkipListSet<>();
private final Set<ByteArrayWrapper> targets = new ConcurrentSkipListSet<>();
private final AtomicLong subscribers = new AtomicLong();
private final AtomicReference<Flux<?>> flux = new AtomicReference<>();
private final Function<Throwable, Throwable> exceptionTranslator;
@@ -182,8 +184,12 @@ class LettuceReactiveSubscription implements ReactiveSubscription {
*/
Mono<Void> subscribe(ByteBuffer[] targets, Function<ByteBuffer[], Mono<Void>> subscribeFunction) {
return subscribeFunction.apply(targets).doOnSuccess((discard) -> this.targets.addAll(Arrays.asList(targets)))
.onErrorMap(exceptionTranslator);
return subscribeFunction.apply(targets).doOnSuccess((discard) -> {
for (ByteBuffer target : targets) {
this.targets.add(getWrapper(target));
}
}).onErrorMap(exceptionTranslator);
}
/**
@@ -198,16 +204,18 @@ class LettuceReactiveSubscription implements ReactiveSubscription {
return Mono.defer(() -> {
List<ByteBuffer> targetCollection = Arrays.asList(targets);
return unsubscribeFunction.apply(targets).doOnSuccess((discard) -> {
this.targets.removeAll(targetCollection);
for (ByteBuffer byteBuffer : targets) {
this.targets.remove(getWrapper(byteBuffer));
}
}).onErrorMap(exceptionTranslator);
});
}
Set<ByteBuffer> getTargets() {
return Collections.unmodifiableSet(targets);
return targets.stream().map(ByteArrayWrapper::getArray).map(ByteBuffer::wrap)
.collect(Collectors.toUnmodifiableSet());
}
/**
@@ -263,5 +271,13 @@ class LettuceReactiveSubscription implements ReactiveSubscription {
disposable.dispose();
}
}
public boolean contains(ByteBuffer target) {
return this.targets.contains(getWrapper(target));
}
private static ByteArrayWrapper getWrapper(ByteBuffer byteBuffer) {
return new ByteArrayWrapper(byteBuffer);
}
}
}

View File

@@ -15,8 +15,10 @@
*/
package org.springframework.data.redis.connection.util;
import java.nio.ByteBuffer;
import java.util.Arrays;
import org.springframework.data.redis.util.ByteUtils;
import org.springframework.lang.Nullable;
/**
@@ -24,11 +26,15 @@ import org.springframework.lang.Nullable;
*
* @author Costin Leau
*/
public class ByteArrayWrapper {
public class ByteArrayWrapper implements Comparable<ByteArrayWrapper> {
private final byte[] array;
private final int hashCode;
public ByteArrayWrapper(ByteBuffer buffer) {
this(ByteUtils.getBytes(buffer.asReadOnlyBuffer()));
}
public ByteArrayWrapper(byte[] array) {
this.array = array;
this.hashCode = Arrays.hashCode(array);
@@ -54,4 +60,14 @@ public class ByteArrayWrapper {
public byte[] getArray() {
return array;
}
@Override
public String toString() {
return new String(array);
}
@Override
public int compareTo(ByteArrayWrapper o) {
return Arrays.compare(this.array, o.array);
}
}