diff --git a/src/main/java/org/springframework/data/redis/core/RedisAccessor.java b/src/main/java/org/springframework/data/redis/core/RedisAccessor.java index 4a5adbdfb..93b4fef34 100644 --- a/src/main/java/org/springframework/data/redis/core/RedisAccessor.java +++ b/src/main/java/org/springframework/data/redis/core/RedisAccessor.java @@ -19,9 +19,8 @@ import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.springframework.beans.factory.InitializingBean; import org.springframework.data.redis.connection.RedisConnectionFactory; -import org.springframework.lang.NonNull; +import org.springframework.data.redis.util.RedisAssertions; import org.springframework.lang.Nullable; -import org.springframework.util.Assert; /** * Base class for {@link RedisTemplate} defining common properties. Not intended to be used directly. @@ -62,14 +61,8 @@ public class RedisAccessor implements InitializingBean { * @see #getConnectionFactory() * @since 2.0 */ - @NonNull public RedisConnectionFactory getRequiredConnectionFactory() { - - RedisConnectionFactory connectionFactory = getConnectionFactory(); - - Assert.state(connectionFactory != null, "RedisConnectionFactory is required"); - - return connectionFactory; + return RedisAssertions.requireState(getConnectionFactory(), "RedisConnectionFactory is required"); } /** diff --git a/src/main/java/org/springframework/data/redis/util/RedisAssertions.java b/src/main/java/org/springframework/data/redis/util/RedisAssertions.java index 93a5adf6b..1e26b5936 100644 --- a/src/main/java/org/springframework/data/redis/util/RedisAssertions.java +++ b/src/main/java/org/springframework/data/redis/util/RedisAssertions.java @@ -56,4 +56,33 @@ public abstract class RedisAssertions { Assert.notNull(target, message); return target; } + + /** + * Asserts the given {@link Object} is not {@literal null}. + * + * @param {@link Class type} of {@link Object} being asserted. + * @param target {@link Object} to evaluate. + * @param message {@link String} containing the message for the thrown exception. + * @param arguments array of {@link Object} arguments used to format the {@link String message}. + * @return the given {@link Object}. + * @throws IllegalArgumentException if the {@link Object target} is {@literal null}. + * @see #requireObject(Object, Supplier) + */ + public static T requireState(@Nullable T target, String message, Object... arguments) { + return requireState(target, () -> String.format(message, arguments)); + } + + /** + * Asserts the given {@link Object} is not {@literal null}. + * + * @param {@link Class type} of {@link Object} being asserted. + * @param target {@link Object} to evaluate. + * @param message {@link Supplier} supplying the message for the thrown exception. + * @return the given {@link Object}. + * @throws IllegalArgumentException if the {@link Object target} is {@literal null}. + */ + public static T requireState(@Nullable T target, Supplier message) { + Assert.state(target != null, message); + return target; + } } diff --git a/src/test/java/org/springframework/data/redis/core/RedisAccessorUnitTests.java b/src/test/java/org/springframework/data/redis/core/RedisAccessorUnitTests.java new file mode 100644 index 000000000..a87b8055f --- /dev/null +++ b/src/test/java/org/springframework/data/redis/core/RedisAccessorUnitTests.java @@ -0,0 +1,87 @@ +/* + * Copyright 2017-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.redis.core; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatIllegalStateException; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoInteractions; +import static org.mockito.Mockito.verifyNoMoreInteractions; + +import org.junit.jupiter.api.Test; + +import org.springframework.data.redis.connection.RedisConnectionFactory; + +/** + * Unit Tests for {@link RedisAccessor}. + * + * @author John Blum + * @see org.junit.jupiter.api.Test + * @see org.springframework.data.redis.core.RedisAccessor + * @since 3.2.0 + */ +public class RedisAccessorUnitTests { + + @Test + public void setAndGetConnectionFactory() { + + RedisConnectionFactory mockConnectionFactory = mock(RedisConnectionFactory.class); + + RedisAccessor redisAccessor = new RedisAccessor(); + + assertThat(redisAccessor.getConnectionFactory()).isNull(); + + redisAccessor.setConnectionFactory(mockConnectionFactory); + + assertThat(redisAccessor.getConnectionFactory()).isSameAs(mockConnectionFactory); + assertThat(redisAccessor.getRequiredConnectionFactory()).isSameAs(mockConnectionFactory); + + redisAccessor.setConnectionFactory(null); + + assertThat(redisAccessor.getConnectionFactory()).isNull(); + + verifyNoInteractions(mockConnectionFactory); + } + + @Test + public void getRequiredConnectionFactoryWhenNull() { + + assertThatIllegalStateException() + .isThrownBy(() -> new RedisAccessor().getRequiredConnectionFactory()) + .withMessage("RedisConnectionFactory is required") + .withNoCause(); + } + + @Test + public void afterPropertiesSetCallsGetRequiredConnectionFactory() { + + RedisConnectionFactory mockConnectionFactory = mock(RedisConnectionFactory.class); + + RedisAccessor redisAccessor = spy(new RedisAccessor()); + + doReturn(mockConnectionFactory).when(redisAccessor).getRequiredConnectionFactory(); + + redisAccessor.afterPropertiesSet(); + + verify(redisAccessor, times(1)).afterPropertiesSet(); + verify(redisAccessor, times(1)).getRequiredConnectionFactory(); + verifyNoMoreInteractions(redisAccessor); + } +} diff --git a/src/test/java/org/springframework/data/redis/util/RedisAssertionsUnitTests.java b/src/test/java/org/springframework/data/redis/util/RedisAssertionsUnitTests.java index 80cd71346..6605fb13d 100644 --- a/src/test/java/org/springframework/data/redis/util/RedisAssertionsUnitTests.java +++ b/src/test/java/org/springframework/data/redis/util/RedisAssertionsUnitTests.java @@ -17,6 +17,7 @@ package org.springframework.data.redis.util; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; +import static org.assertj.core.api.Assertions.assertThatIllegalStateException; import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -79,4 +80,40 @@ public class RedisAssertionsUnitTests { verify(this.mockSupplier, times(1)).get(); verifyNoMoreInteractions(this.mockSupplier); } + + @Test + public void requireStateWithMessageAndArgumentsIsSuccessful() { + assertThat(RedisAssertions.requireState("test", "Mock message")).isEqualTo("test"); + } + + @Test + public void requireStateWithMessageAndArgumentsThrowsIllegalStateException() { + + assertThatIllegalStateException() + .isThrownBy(() -> RedisAssertions.requireState(null, "This is a %s", "test")) + .withMessage("This is a test") + .withNoCause(); + } + + @Test + public void requireStateWithSupplierIsSuccessful() { + + assertThat(RedisAssertions.requireState("test", this.mockSupplier)).isEqualTo("test"); + + verifyNoInteractions(this.mockSupplier); + } + + @Test + public void requiredStateWithSupplierThrowsIllegalStateException() { + + doReturn("Mock message").when(this.mockSupplier).get(); + + assertThatIllegalStateException() + .isThrownBy(() -> RedisAssertions.requireState(null, this.mockSupplier)) + .withMessage("Mock message") + .withNoCause(); + + verify(this.mockSupplier, times(1)).get(); + verifyNoMoreInteractions(this.mockSupplier); + } }