From 8d6ebb4e9904b2e2fc0f142f582f50864e75d062 Mon Sep 17 00:00:00 2001 From: John Blum Date: Tue, 8 Aug 2023 17:05:42 -0700 Subject: [PATCH] Add configuration for `TaskExecutor` used by `ClusterCommandsExecutor`. This change allows users to leverage the VirtualThread facilities and AsyncTaskExecutor implementations provided in and by the core Spring Framework as part of our Loom support theme. Closes #2594 Original pull request: #2669 --- .../connection/ClusterCommandExecutor.java | 46 ++++---- .../connection/RedisClusterConfiguration.java | 55 ++++++---- .../redis/connection/RedisConfiguration.java | 96 ++++++++++------- .../jedis/JedisConnectionFactory.java | 24 ++++- .../lettuce/LettuceConnectionFactory.java | 100 +++++++++++++----- .../JedisConnectionFactoryUnitTests.java | 89 ++++++++++++---- .../LettuceConnectionFactoryUnitTests.java | 52 ++++++--- 7 files changed, 319 insertions(+), 143 deletions(-) diff --git a/src/main/java/org/springframework/data/redis/connection/ClusterCommandExecutor.java b/src/main/java/org/springframework/data/redis/connection/ClusterCommandExecutor.java index 26d9cfcba..7ef1e522f 100644 --- a/src/main/java/org/springframework/data/redis/connection/ClusterCommandExecutor.java +++ b/src/main/java/org/springframework/data/redis/connection/ClusterCommandExecutor.java @@ -233,6 +233,7 @@ public class ClusterCommandExecutor implements DisposableBean { while (!done) { done = true; + for (Map.Entry>> entry : futures.entrySet()) { if (!entry.getValue().isDone() && !entry.getValue().isCancelled()) { @@ -240,9 +241,11 @@ public class ClusterCommandExecutor implements DisposableBean { } else { NodeExecution execution = entry.getKey(); + try { String futureId = ObjectUtils.getIdentityHexString(entry.getValue()); + if (!saveGuard.contains(futureId)) { if (execution.isPositional()) { @@ -250,19 +253,22 @@ public class ClusterCommandExecutor implements DisposableBean { } else { result.add(entry.getValue().get()); } + saveGuard.add(futureId); } - } catch (ExecutionException e) { + } catch (ExecutionException cause) { - RuntimeException ex = convertToDataAccessException((Exception) e.getCause()); + RuntimeException exception = convertToDataAccessException((Exception) cause.getCause()); - exceptions.put(execution.getNode(), ex != null ? ex : e.getCause()); - } catch (InterruptedException e) { + exceptions.put(execution.getNode(), exception != null ? exception : cause.getCause()); + } catch (InterruptedException cause) { Thread.currentThread().interrupt(); - RuntimeException ex = convertToDataAccessException((Exception) e.getCause()); - exceptions.put(execution.getNode(), ex != null ? ex : e.getCause()); + RuntimeException exception = convertToDataAccessException((Exception) cause.getCause()); + + exceptions.put(execution.getNode(), exception != null ? exception : cause.getCause()); + break; } } @@ -271,7 +277,6 @@ public class ClusterCommandExecutor implements DisposableBean { try { Thread.sleep(10); } catch (InterruptedException e) { - done = true; Thread.currentThread().interrupt(); } @@ -280,18 +285,19 @@ public class ClusterCommandExecutor implements DisposableBean { if (!exceptions.isEmpty()) { throw new ClusterCommandExecutionFailureException(new ArrayList<>(exceptions.values())); } + return result; } /** * Run {@link MultiKeyClusterCommandCallback} with on a curated set of nodes serving one or more keys. * - * @param cmd must not be {@literal null}. + * @param commandCallback must not be {@literal null}. * @return never {@literal null}. * @throws ClusterCommandExecutionFailureException if a failure occurs while executing the given * {@link MultiKeyClusterCommandCallback command}. */ - public MultiNodeResult executeMultiKeyCommand(MultiKeyClusterCommandCallback cmd, + public MultiNodeResult executeMultiKeyCommand(MultiKeyClusterCommandCallback commandCallback, Iterable keys) { Map nodeKeyMap = new HashMap<>(); @@ -309,8 +315,8 @@ public class ClusterCommandExecutor implements DisposableBean { if (entry.getKey().isMaster()) { for (PositionalKey key : entry.getValue()) { - futures.put(new NodeExecution(entry.getKey(), key), - executor.submit(() -> executeMultiKeyCommandOnSingleNode(cmd, entry.getKey(), key.getBytes()))); + futures.put(new NodeExecution(entry.getKey(), key), this.executor.submit(() -> + executeMultiKeyCommandOnSingleNode(commandCallback, entry.getKey(), key.getBytes()))); } } } @@ -318,10 +324,10 @@ public class ClusterCommandExecutor implements DisposableBean { return collectResults(futures); } - private NodeResult executeMultiKeyCommandOnSingleNode(MultiKeyClusterCommandCallback cmd, + private NodeResult executeMultiKeyCommandOnSingleNode(MultiKeyClusterCommandCallback commandCallback, RedisClusterNode node, byte[] key) { - Assert.notNull(cmd, "MultiKeyCommandCallback must not be null"); + Assert.notNull(commandCallback, "MultiKeyCommandCallback must not be null"); Assert.notNull(node, "RedisClusterNode must not be null"); Assert.notNull(key, "Keys for execution must not be null"); @@ -330,7 +336,7 @@ public class ClusterCommandExecutor implements DisposableBean { Assert.notNull(client, "Could not acquire resource for node; Is your cluster info up to date"); try { - return new NodeResult<>(node, cmd.doInCluster(client, key), key); + return new NodeResult<>(node, commandCallback.doInCluster(client, key), key); } catch (RuntimeException ex) { RuntimeException translatedException = convertToDataAccessException(ex); @@ -345,8 +351,8 @@ public class ClusterCommandExecutor implements DisposableBean { } @Nullable - private DataAccessException convertToDataAccessException(Exception e) { - return exceptionTranslationStrategy.translate(e); + private DataAccessException convertToDataAccessException(Exception cause) { + return exceptionTranslationStrategy.translate(cause); } /** @@ -361,12 +367,12 @@ public class ClusterCommandExecutor implements DisposableBean { @Override public void destroy() throws Exception { - if (executor instanceof DisposableBean) { - ((DisposableBean) executor).destroy(); + if (this.executor instanceof DisposableBean disposableBean) { + disposableBean.destroy(); } - if (resourceProvider instanceof DisposableBean) { - ((DisposableBean) resourceProvider).destroy(); + if (this.resourceProvider instanceof DisposableBean disposableBean) { + disposableBean.destroy(); } } diff --git a/src/main/java/org/springframework/data/redis/connection/RedisClusterConfiguration.java b/src/main/java/org/springframework/data/redis/connection/RedisClusterConfiguration.java index 2068f48fd..d1fb52c59 100644 --- a/src/main/java/org/springframework/data/redis/connection/RedisClusterConfiguration.java +++ b/src/main/java/org/springframework/data/redis/connection/RedisClusterConfiguration.java @@ -26,6 +26,7 @@ import java.util.Set; import org.springframework.core.env.MapPropertySource; import org.springframework.core.env.PropertySource; +import org.springframework.core.task.AsyncTaskExecutor; import org.springframework.data.redis.connection.RedisConfiguration.ClusterConfiguration; import org.springframework.data.redis.util.RedisAssertions; import org.springframework.lang.Nullable; @@ -51,6 +52,8 @@ public class RedisClusterConfiguration implements RedisConfiguration, ClusterCon private @Nullable Integer maxRedirects; + private @Nullable AsyncTaskExecutor executor; + private RedisPassword password = RedisPassword.none(); private final Set clusterNodes; @@ -109,6 +112,13 @@ public class RedisClusterConfiguration implements RedisConfiguration, ClusterCon } } + private void appendClusterNodes(Set hostAndPorts) { + + for (String hostAndPort : hostAndPorts) { + addClusterNode(RedisNode.fromString(hostAndPort)); + } + } + /** * Set {@literal cluster nodes} to connect to. * @@ -139,6 +149,15 @@ public class RedisClusterConfiguration implements RedisConfiguration, ClusterCon this.clusterNodes.add(RedisAssertions.requireNonNull(node, "ClusterNode must not be null")); } + /** + * @param host Redis cluster node host name or ip address. + * @param port Redis cluster node port. + * @return this. + */ + public RedisClusterConfiguration clusterNode(String host, Integer port) { + return clusterNode(new RedisNode(host, port)); + } + /** * @return this. */ @@ -149,11 +168,6 @@ public class RedisClusterConfiguration implements RedisConfiguration, ClusterCon return this; } - @Override - public Integer getMaxRedirects() { - return maxRedirects != null && maxRedirects > Integer.MIN_VALUE ? maxRedirects : null; - } - /** * @param maxRedirects the max number of redirects to follow. */ @@ -164,20 +178,9 @@ public class RedisClusterConfiguration implements RedisConfiguration, ClusterCon this.maxRedirects = maxRedirects; } - /** - * @param host Redis cluster node host name or ip address. - * @param port Redis cluster node port. - * @return this. - */ - public RedisClusterConfiguration clusterNode(String host, Integer port) { - return clusterNode(new RedisNode(host, port)); - } - - private void appendClusterNodes(Set hostAndPorts) { - - for (String hostAndPort : hostAndPorts) { - addClusterNode(RedisNode.fromString(hostAndPort)); - } + @Override + public Integer getMaxRedirects() { + return maxRedirects != null && maxRedirects > Integer.MIN_VALUE ? maxRedirects : null; } @Override @@ -191,14 +194,24 @@ public class RedisClusterConfiguration implements RedisConfiguration, ClusterCon return this.username; } + @Override + public void setPassword(RedisPassword password) { + this.password = RedisAssertions.requireNonNull(password, "RedisPassword must not be null"); + } + @Override public RedisPassword getPassword() { return password; } @Override - public void setPassword(RedisPassword password) { - this.password = RedisAssertions.requireNonNull(password, "RedisPassword must not be null"); + public void setAsyncTaskExecutor(@Nullable AsyncTaskExecutor executor) { + this.executor = executor; + } + + @Nullable @Override + public AsyncTaskExecutor getAsyncTaskExecutor() { + return this.executor; } @Override diff --git a/src/main/java/org/springframework/data/redis/connection/RedisConfiguration.java b/src/main/java/org/springframework/data/redis/connection/RedisConfiguration.java index d2d451b54..c4c200614 100644 --- a/src/main/java/org/springframework/data/redis/connection/RedisConfiguration.java +++ b/src/main/java/org/springframework/data/redis/connection/RedisConfiguration.java @@ -21,6 +21,7 @@ import java.util.Set; import java.util.function.IntSupplier; import java.util.function.Supplier; +import org.springframework.core.task.AsyncTaskExecutor; import org.springframework.lang.Nullable; import org.springframework.util.Assert; @@ -205,6 +206,14 @@ public interface RedisConfiguration { */ void setUsername(@Nullable String username); + /** + * Get the username to use when connecting. + * + * @return {@literal null} if none set. + */ + @Nullable + String getUsername(); + /** * Create and set a {@link RedisPassword} for given {@link String}. * @@ -230,14 +239,6 @@ public interface RedisConfiguration { */ void setPassword(RedisPassword password); - /** - * Get the username to use when connecting. - * - * @return {@literal null} if none set. - */ - @Nullable - String getUsername(); - /** * Get the RedisPassword to use when connecting. * @@ -337,6 +338,53 @@ public interface RedisConfiguration { String getSocket(); } + /** + * Configuration interface suitable for Redis cluster environments. + * + * @author Christoph Strobl + * @since 2.1 + */ + interface ClusterConfiguration extends WithPassword { + + /** + * Configures the {@link AsyncTaskExecutor} used to execute commands asynchronously across the cluster. + * + * @param executor {@link AsyncTaskExecutor} used to execute commands asynchronously across the cluster. + */ + void setAsyncTaskExecutor(AsyncTaskExecutor executor); + + /** + * Returns the configured {@link AsyncTaskExecutor} used to execute commands asynchronously across the cluster. + * + * @return the configured {@link AsyncTaskExecutor} used to execute commands asynchronously across the cluster. + */ + AsyncTaskExecutor getAsyncTaskExecutor(); + + /** + * Returns an {@link Collections#unmodifiableSet(Set) Set} of {@link RedisNode cluster nodes}. + * + * @return {@link Set} of {@link RedisNode cluster nodes}. Never {@literal null}. + */ + Set getClusterNodes(); + + /** + * @return max number of redirects to follow or {@literal null} if not set. + */ + @Nullable + Integer getMaxRedirects(); + + } + + /** + * Configuration interface suitable for single node redis connections using local unix domain socket. + * + * @author Christoph Strobl + * @since 2.1 + */ + interface DomainSocketConfiguration extends WithDomainSocket, WithDatabaseIndex, WithPassword { + + } + /** * Configuration interface suitable for Redis Sentinel environments. * @@ -459,28 +507,6 @@ public interface RedisConfiguration { } - /** - * Configuration interface suitable for Redis cluster environments. - * - * @author Christoph Strobl - * @since 2.1 - */ - interface ClusterConfiguration extends WithPassword { - - /** - * Returns an {@link Collections#unmodifiableSet(Set)} of {@literal cluster nodes}. - * - * @return {@link Set} of nodes. Never {@literal null}. - */ - Set getClusterNodes(); - - /** - * @return max number of redirects to follow or {@literal null} if not set. - */ - @Nullable - Integer getMaxRedirects(); - } - /** * Configuration interface suitable for Redis master/replica environments with fixed hosts. * @@ -495,14 +521,4 @@ public interface RedisConfiguration { */ List getNodes(); } - - /** - * Configuration interface suitable for single node redis connections using local unix domain socket. - * - * @author Christoph Strobl - * @since 2.1 - */ - interface DomainSocketConfiguration extends WithDomainSocket, WithDatabaseIndex, WithPassword { - - } } diff --git a/src/main/java/org/springframework/data/redis/connection/jedis/JedisConnectionFactory.java b/src/main/java/org/springframework/data/redis/connection/jedis/JedisConnectionFactory.java index e620b6671..c29345f58 100644 --- a/src/main/java/org/springframework/data/redis/connection/jedis/JedisConnectionFactory.java +++ b/src/main/java/org/springframework/data/redis/connection/jedis/JedisConnectionFactory.java @@ -46,6 +46,7 @@ import org.apache.commons.pool2.impl.GenericObjectPoolConfig; import org.springframework.beans.factory.DisposableBean; import org.springframework.beans.factory.InitializingBean; import org.springframework.context.SmartLifecycle; +import org.springframework.core.task.AsyncTaskExecutor; import org.springframework.dao.DataAccessException; import org.springframework.dao.InvalidDataAccessApiUsageException; import org.springframework.dao.InvalidDataAccessResourceUsageException; @@ -337,12 +338,9 @@ public class JedisConnectionFactory } if (isRedisClusterAware()) { - this.cluster = createCluster(); this.topologyProvider = createTopologyProvider(this.cluster); - this.clusterCommandExecutor = new ClusterCommandExecutor(this.topologyProvider, - new JedisClusterConnection.JedisClusterNodeResourceProvider(this.cluster, this.topologyProvider), - EXCEPTION_TRANSLATION); + this.clusterCommandExecutor = newClusterCommandExecutor(); } this.state.set(State.STARTED); @@ -353,6 +351,24 @@ public class JedisConnectionFactory return State.CREATED.equals(state) || State.STOPPED.equals(state); } + private ClusterCommandExecutor newClusterCommandExecutor() { + + return new ClusterCommandExecutor(this.topologyProvider, newClusterNodeResourceProvider(), + EXCEPTION_TRANSLATION, resolveTaskExecutor(this.configuration)); + } + + private ClusterNodeResourceProvider newClusterNodeResourceProvider() { + return new JedisClusterConnection.JedisClusterNodeResourceProvider(this.cluster, this.topologyProvider); + } + + @Nullable + private AsyncTaskExecutor resolveTaskExecutor(@Nullable RedisConfiguration redisConfiguration) { + + return redisConfiguration instanceof RedisConfiguration.ClusterConfiguration clusterConfiguration + ? clusterConfiguration.getAsyncTaskExecutor() + : null; + } + @Override public void stop() { diff --git a/src/main/java/org/springframework/data/redis/connection/lettuce/LettuceConnectionFactory.java b/src/main/java/org/springframework/data/redis/connection/lettuce/LettuceConnectionFactory.java index 5883fee84..c74ed1c89 100644 --- a/src/main/java/org/springframework/data/redis/connection/lettuce/LettuceConnectionFactory.java +++ b/src/main/java/org/springframework/data/redis/connection/lettuce/LettuceConnectionFactory.java @@ -15,7 +15,7 @@ */ package org.springframework.data.redis.connection.lettuce; -import static org.springframework.data.redis.connection.lettuce.LettuceConnection.*; +import static org.springframework.data.redis.connection.lettuce.LettuceConnection.PipeliningFlushPolicy; import io.lettuce.core.AbstractRedisClient; import io.lettuce.core.ClientOptions; @@ -49,16 +49,29 @@ import org.apache.commons.logging.LogFactory; import org.springframework.beans.factory.DisposableBean; import org.springframework.beans.factory.InitializingBean; import org.springframework.context.SmartLifecycle; +import org.springframework.core.task.AsyncTaskExecutor; import org.springframework.dao.DataAccessException; import org.springframework.dao.InvalidDataAccessApiUsageException; import org.springframework.data.redis.ExceptionTranslationStrategy; import org.springframework.data.redis.PassThroughExceptionTranslationStrategy; import org.springframework.data.redis.RedisConnectionFailureException; -import org.springframework.data.redis.connection.*; +import org.springframework.data.redis.connection.ClusterCommandExecutor; +import org.springframework.data.redis.connection.ClusterTopologyProvider; +import org.springframework.data.redis.connection.ReactiveRedisConnectionFactory; +import org.springframework.data.redis.connection.RedisClusterConfiguration; +import org.springframework.data.redis.connection.RedisClusterConnection; +import org.springframework.data.redis.connection.RedisConfiguration; import org.springframework.data.redis.connection.RedisConfiguration.ClusterConfiguration; -import org.springframework.data.redis.connection.RedisConfiguration.DomainSocketConfiguration; import org.springframework.data.redis.connection.RedisConfiguration.WithDatabaseIndex; import org.springframework.data.redis.connection.RedisConfiguration.WithPassword; +import org.springframework.data.redis.connection.RedisConnection; +import org.springframework.data.redis.connection.RedisConnectionFactory; +import org.springframework.data.redis.connection.RedisPassword; +import org.springframework.data.redis.connection.RedisSentinelConfiguration; +import org.springframework.data.redis.connection.RedisSentinelConnection; +import org.springframework.data.redis.connection.RedisSocketConfiguration; +import org.springframework.data.redis.connection.RedisStandaloneConfiguration; +import org.springframework.data.redis.connection.RedisStaticMasterReplicaConfiguration; import org.springframework.data.util.Optionals; import org.springframework.lang.Nullable; import org.springframework.util.Assert; @@ -342,28 +355,28 @@ public class LettuceConnectionFactory implements RedisConnectionFactory, Reactiv this.configuration = this.standaloneConfig; } + @Nullable + protected ClusterCommandExecutor getClusterCommandExecutor() { + return this.clusterCommandExecutor; + } + @Override public void start() { - State current = state.getAndUpdate(state -> - State.CREATED.equals(state) || State.STOPPED.equals(state) ? State.STARTING : state); + State current = this.state.getAndUpdate(state -> isCreatedOrStopped(state) ? State.STARTING : state); - if (State.CREATED.equals(current) || State.STOPPED.equals(current)) { + if (isCreatedOrStopped(current)) { this.client = createClient(); - - this.connectionProvider = new ExceptionTranslatingConnectionProvider(createConnectionProvider(client, CODEC)); - this.reactiveConnectionProvider = new ExceptionTranslatingConnectionProvider( - createConnectionProvider(client, LettuceReactiveRedisConnection.CODEC)); + this.connectionProvider = newExceptionTranslatingConnectionProvider(this.client, LettuceConnection.CODEC); + this.reactiveConnectionProvider = newExceptionTranslatingConnectionProvider(this.client, + LettuceReactiveRedisConnection.CODEC); if (isClusterAware()) { - this.clusterCommandExecutor = new ClusterCommandExecutor( - new LettuceClusterTopologyProvider((RedisClusterClient) client), - new LettuceClusterConnection.LettuceClusterNodeResourceProvider(this.connectionProvider), - EXCEPTION_TRANSLATION); + this.clusterCommandExecutor = newClusterCommandExecutor(); } - state.set(State.STARTED); + this.state.set(State.STARTED); if (getEagerInitialization() && getShareNativeConnection()) { initConnection(); @@ -371,6 +384,38 @@ public class LettuceConnectionFactory implements RedisConnectionFactory, Reactiv } } + private boolean isCreatedOrStopped(@Nullable State state) { + return State.CREATED.equals(state) || State.STOPPED.equals(state); + } + + private ClusterCommandExecutor newClusterCommandExecutor() { + + return new ClusterCommandExecutor(newClusterTopologyProvider(), newClusterNodeResourceProvider(), + EXCEPTION_TRANSLATION, resolveTaskExecutor(this.configuration)); + } + + private LettuceClusterConnection.LettuceClusterNodeResourceProvider newClusterNodeResourceProvider() { + return new LettuceClusterConnection.LettuceClusterNodeResourceProvider(this.connectionProvider); + } + + private LettuceClusterTopologyProvider newClusterTopologyProvider() { + return new LettuceClusterTopologyProvider((RedisClusterClient) this.client); + } + + @Nullable + private AsyncTaskExecutor resolveTaskExecutor(RedisConfiguration redisConfiguration) { + + return redisConfiguration instanceof ClusterConfiguration clusterConfiguration + ? clusterConfiguration.getAsyncTaskExecutor() + : null; + } + + private ExceptionTranslatingConnectionProvider newExceptionTranslatingConnectionProvider(AbstractRedisClient client, + RedisCodec codec) { + + return new ExceptionTranslatingConnectionProvider(createConnectionProvider(client, codec)); + } + @Override public void stop() { @@ -420,7 +465,7 @@ public class LettuceConnectionFactory implements RedisConnectionFactory, Reactiv @Override public boolean isRunning() { - return State.STARTED.equals(state.get()); + return State.STARTED.equals(this.state.get()); } @Override @@ -434,17 +479,20 @@ public class LettuceConnectionFactory implements RedisConnectionFactory, Reactiv public void destroy() { stop(); - client = null; + this.client = null; + + ClusterCommandExecutor clusterCommandExecutor = getClusterCommandExecutor(); if (clusterCommandExecutor != null) { try { clusterCommandExecutor.destroy(); - } catch (Exception ex) { - log.warn("Cannot properly close cluster command executor", ex); + this.clusterCommandExecutor = null; + } catch (Exception cause) { + log.warn("Cannot properly close cluster command executor", cause); } } - state.set(State.DESTROYED); + this.state.set(State.DESTROYED); } private void dispose(@Nullable LettuceConnectionProvider connectionProvider) { @@ -472,7 +520,7 @@ public class LettuceConnectionFactory implements RedisConnectionFactory, Reactiv LettuceConnection connection = doCreateLettuceConnection(getSharedConnection(), connectionProvider, getTimeout(), getDatabase()); - connection.setConvertPipelineAndTxResults(convertPipelineAndTxResults); + connection.setConvertPipelineAndTxResults(this.convertPipelineAndTxResults); return connection; } @@ -492,8 +540,8 @@ public class LettuceConnectionFactory implements RedisConnectionFactory, Reactiv LettuceClusterTopologyProvider topologyProvider = new LettuceClusterTopologyProvider(clusterClient); - return doCreateLettuceClusterConnection(sharedConnection, connectionProvider, topologyProvider, - clusterCommandExecutor, clientConfiguration.getCommandTimeout()); + return doCreateLettuceClusterConnection(sharedConnection, this.connectionProvider, topologyProvider, + getClusterCommandExecutor(), this.clientConfiguration.getCommandTimeout()); } /** @@ -819,7 +867,7 @@ public class LettuceConnectionFactory implements RedisConnectionFactory, Reactiv * @return native connection shared. */ public boolean getShareNativeConnection() { - return shareNativeConnection; + return this.shareNativeConnection; } /** @@ -842,7 +890,7 @@ public class LettuceConnectionFactory implements RedisConnectionFactory, Reactiv * @since 2.2 */ public boolean getEagerInitialization() { - return eagerInitialization; + return this.eagerInitialization; } /** @@ -1164,7 +1212,7 @@ public class LettuceConnectionFactory implements RedisConnectionFactory, Reactiv return shareNativeConnection ? getOrCreateSharedReactiveConnection().getConnection() : null; } - private LettuceConnectionProvider createConnectionProvider(AbstractRedisClient client, RedisCodec codec) { + LettuceConnectionProvider createConnectionProvider(AbstractRedisClient client, RedisCodec codec) { LettuceConnectionProvider connectionProvider = doCreateConnectionProvider(client, codec); diff --git a/src/test/java/org/springframework/data/redis/connection/jedis/JedisConnectionFactoryUnitTests.java b/src/test/java/org/springframework/data/redis/connection/jedis/JedisConnectionFactoryUnitTests.java index a18b2ecb2..de6bd7f4d 100644 --- a/src/test/java/org/springframework/data/redis/connection/jedis/JedisConnectionFactoryUnitTests.java +++ b/src/test/java/org/springframework/data/redis/connection/jedis/JedisConnectionFactoryUnitTests.java @@ -15,12 +15,16 @@ */ package org.springframework.data.redis.connection.jedis; -import static org.assertj.core.api.Assertions.*; -import static org.mockito.Mockito.*; - -import redis.clients.jedis.JedisClientConfig; -import redis.clients.jedis.JedisCluster; -import redis.clients.jedis.JedisPoolConfig; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatIllegalStateException; +import static org.mockito.Mockito.any; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; import java.io.IOException; import java.security.NoSuchAlgorithmException; @@ -33,15 +37,25 @@ import javax.net.ssl.SSLContext; import javax.net.ssl.SSLParameters; import javax.net.ssl.SSLSocketFactory; -import org.apache.commons.pool2.impl.GenericObjectPoolConfig; import org.junit.jupiter.api.Test; + +import org.springframework.core.task.AsyncTaskExecutor; +import org.springframework.data.redis.connection.ClusterCommandExecutor; +import org.springframework.data.redis.connection.ClusterTopologyProvider; import org.springframework.data.redis.connection.RedisClusterConfiguration; import org.springframework.data.redis.connection.RedisPassword; import org.springframework.data.redis.connection.RedisSentinelConfiguration; import org.springframework.data.redis.connection.RedisStandaloneConfiguration; import org.springframework.data.redis.connection.jedis.JedisConnectionFactory.State; +import org.springframework.lang.Nullable; import org.springframework.test.util.ReflectionTestUtils; +import org.apache.commons.pool2.impl.GenericObjectPoolConfig; + +import redis.clients.jedis.JedisClientConfig; +import redis.clients.jedis.JedisCluster; +import redis.clients.jedis.JedisPoolConfig; + /** * Unit tests for {@link JedisConnectionFactory}. * @@ -335,24 +349,59 @@ class JedisConnectionFactoryUnitTests { assertThat(connectionFactory.isRunning()).isTrue(); } - private JedisConnectionFactory initSpyedConnectionFactory(RedisSentinelConfiguration sentinelConfig, - JedisPoolConfig poolConfig) { + @Test // GH-2594 + void configuresCustomTaskExecutorCorrectly() { + + AsyncTaskExecutor mockTaskExecutor = mock(AsyncTaskExecutor.class); + ClusterTopologyProvider mockClusterTopologyProvider = mock(ClusterTopologyProvider.class); + JedisCluster mockJedisCluster = mock(JedisCluster.class); + + RedisClusterConfiguration clusterConfiguration = new RedisClusterConfiguration(); + + clusterConfiguration.setAsyncTaskExecutor(mockTaskExecutor); + + JedisConnectionFactory connectionFactory = initSpyedConnectionFactory(clusterConfiguration, null); + + doReturn(false).when(connectionFactory).getUsePool(); + doReturn(mockJedisCluster).when(connectionFactory).createCluster(); + doReturn(mockClusterTopologyProvider).when(connectionFactory).createTopologyProvider(eq(mockJedisCluster)); + + connectionFactory.start(); + + assertThat(connectionFactory.isRunning()).isTrue(); + + ClusterCommandExecutor clusterCommandExecutor = connectionFactory.getClusterCommandExecutor(); + + assertThat(clusterCommandExecutor).isNotNull(); + assertThat(ReflectionTestUtils.getField(clusterCommandExecutor, "executor")).isEqualTo(mockTaskExecutor); + } + + private JedisConnectionFactory initSpyedConnectionFactory(RedisSentinelConfiguration sentinelConfiguration, + @Nullable JedisPoolConfig poolConfig) { // we have to use a spy here as jedis would start connecting to redis sentinels when the pool is created. - JedisConnectionFactory factorySpy = spy(new JedisConnectionFactory(sentinelConfig, poolConfig)); - doReturn(null).when(factorySpy).createRedisSentinelPool(any(RedisSentinelConfiguration.class)); - doReturn(null).when(factorySpy).createRedisPool(); - return factorySpy; + JedisConnectionFactory connectionFactorySpy = spy(new JedisConnectionFactory(sentinelConfiguration, poolConfig)); + + doReturn(null).when(connectionFactorySpy) + .createRedisSentinelPool(any(RedisSentinelConfiguration.class)); + + doReturn(null).when(connectionFactorySpy).createRedisPool(); + + return connectionFactorySpy; } - private JedisConnectionFactory initSpyedConnectionFactory(RedisClusterConfiguration clusterConfig, - JedisPoolConfig poolConfig) { + private JedisConnectionFactory initSpyedConnectionFactory(RedisClusterConfiguration clusterConfiguration, + @Nullable JedisPoolConfig poolConfig) { JedisCluster clusterMock = mock(JedisCluster.class); - JedisConnectionFactory factorySpy = spy(new JedisConnectionFactory(clusterConfig)); - doReturn(clusterMock).when(factorySpy).createCluster(any(RedisClusterConfiguration.class), - any(GenericObjectPoolConfig.class)); - doReturn(null).when(factorySpy).createRedisPool(); - return factorySpy; + + JedisConnectionFactory connectionFactorySpy = spy(new JedisConnectionFactory(clusterConfiguration, poolConfig)); + + doReturn(clusterMock).when(connectionFactorySpy) + .createCluster(any(RedisClusterConfiguration.class), any(GenericObjectPoolConfig.class)); + + doReturn(null).when(connectionFactorySpy).createRedisPool(); + + return connectionFactorySpy; } } diff --git a/src/test/java/org/springframework/data/redis/connection/lettuce/LettuceConnectionFactoryUnitTests.java b/src/test/java/org/springframework/data/redis/connection/lettuce/LettuceConnectionFactoryUnitTests.java index ae8a4e456..84200c8e3 100644 --- a/src/test/java/org/springframework/data/redis/connection/lettuce/LettuceConnectionFactoryUnitTests.java +++ b/src/test/java/org/springframework/data/redis/connection/lettuce/LettuceConnectionFactoryUnitTests.java @@ -15,12 +15,15 @@ */ package org.springframework.data.redis.connection.lettuce; -import static org.assertj.core.api.Assertions.*; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; +import static org.assertj.core.api.Assertions.assertThatIllegalStateException; import static org.mockito.Mockito.*; -import static org.springframework.data.redis.connection.ClusterTestVariables.*; -import static org.springframework.data.redis.connection.RedisConfiguration.*; -import static org.springframework.data.redis.test.extension.LettuceTestClientResources.*; -import static org.springframework.test.util.ReflectionTestUtils.*; +import static org.springframework.data.redis.connection.ClusterTestVariables.CLUSTER_NODE_1; +import static org.springframework.data.redis.connection.RedisConfiguration.WithHostAndPort; +import static org.springframework.data.redis.test.extension.LettuceTestClientResources.getSharedClientResources; +import static org.springframework.test.util.ReflectionTestUtils.getField; import io.lettuce.core.AbstractRedisClient; import io.lettuce.core.ClientOptions; @@ -43,6 +46,7 @@ import java.util.Collections; import java.util.Objects; import java.util.concurrent.CompletableFuture; +import org.assertj.core.api.InstanceOfAssertFactories; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Disabled; @@ -50,8 +54,10 @@ import org.junit.jupiter.api.Test; import org.mockito.ArgumentMatchers; import org.springframework.beans.DirectFieldAccessor; import org.springframework.beans.factory.DisposableBean; +import org.springframework.core.task.AsyncTaskExecutor; import org.springframework.data.redis.ConnectionFactoryTracker; import org.springframework.data.redis.RedisConnectionFailureException; +import org.springframework.data.redis.connection.ClusterCommandExecutor; import org.springframework.data.redis.connection.PoolException; import org.springframework.data.redis.connection.RedisClusterConfiguration; import org.springframework.data.redis.connection.RedisClusterConnection; @@ -63,9 +69,6 @@ import org.springframework.data.redis.connection.RedisSentinelConfiguration; import org.springframework.data.redis.connection.RedisSocketConfiguration; import org.springframework.data.redis.connection.RedisStandaloneConfiguration; import org.springframework.data.redis.test.extension.LettuceTestClientResources; -import org.springframework.test.util.ReflectionTestUtils; - -import org.assertj.core.api.InstanceOfAssertFactories; /** * Unit tests for {@link LettuceConnectionFactory}. @@ -823,7 +826,7 @@ class LettuceConnectionFactoryUnitTests { ConnectionFactoryTracker.add(connectionFactory); RedisClusterConnection clusterConnection = connectionFactory.getClusterConnection(); - assertThat(ReflectionTestUtils.getField(clusterConnection, "timeout")).isEqualTo(2000L); + assertThat(getField(clusterConnection, "timeout")).isEqualTo(2000L); clusterConnection.close(); } @@ -839,7 +842,7 @@ class LettuceConnectionFactoryUnitTests { ConnectionFactoryTracker.add(connectionFactory); RedisClusterConnection clusterConnection = connectionFactory.getClusterConnection(); - assertThat(ReflectionTestUtils.getField(clusterConnection, "timeout")).isEqualTo(2000L); + assertThat(getField(clusterConnection, "timeout")).isEqualTo(2000L); clusterConnection.close(); } @@ -1250,8 +1253,8 @@ class LettuceConnectionFactoryUnitTests { .withNoCause()); } - @Test - public void createRedisConfigurationWithValidRedisUriString() { + @Test // GH-2594 + void createRedisConfigurationWithValidRedisUriString() { RedisConfiguration redisConfiguration = LettuceConnectionFactory.createRedisConfiguration("redis://skullbox:6789"); @@ -1269,6 +1272,31 @@ class LettuceConnectionFactoryUnitTests { .isEqualTo(6789); } + @Test // GH-2594 + void configuresCustomTaskExecutorCorrectly() { + + AsyncTaskExecutor mockTaskExecutor = mock(AsyncTaskExecutor.class); + LettuceConnectionProvider mockConnectionProvider = mock(LettuceConnectionProvider.class); + RedisClusterClient mockRedisClient = mock(RedisClusterClient.class); + + RedisClusterConfiguration clusterConfiguration = new RedisClusterConfiguration(); + + clusterConfiguration.setAsyncTaskExecutor(mockTaskExecutor); + + LettuceConnectionFactory connectionFactory = spy(new LettuceConnectionFactory(clusterConfiguration)); + + doReturn(mockRedisClient).when(connectionFactory).createClient(); + doReturn(mockConnectionProvider).when(connectionFactory).createConnectionProvider(eq(mockRedisClient), any()); + + connectionFactory.start(); + + assertThat(connectionFactory.isRunning()).isTrue(); + + ClusterCommandExecutor clusterCommandExecutor = connectionFactory.getClusterCommandExecutor(); + + assertThat(getField(clusterCommandExecutor, "executor")).isEqualTo(mockTaskExecutor); + } + static class CustomRedisConfiguration implements RedisConfiguration, WithHostAndPort { private String hostName;