Commit eb120041 authored by Andy Wilkinson's avatar Andy Wilkinson

Derive a ConnectionFactoryBuilder from an existing ConnectionFactory

Closes gh-25788
parent 870d9b19
......@@ -29,6 +29,10 @@ import io.r2dbc.pool.ConnectionPool;
import io.r2dbc.pool.PoolMetrics;
import io.r2dbc.spi.ConnectionFactory;
import io.r2dbc.spi.Option;
import io.r2dbc.spi.Wrapped;
import org.assertj.core.api.InstanceOfAssertFactories;
import org.assertj.core.api.InstanceOfAssertFactory;
import org.assertj.core.api.ObjectAssert;
import org.junit.jupiter.api.Test;
import org.springframework.beans.factory.BeanCreationException;
......@@ -36,6 +40,7 @@ import org.springframework.boot.autoconfigure.AutoConfigurations;
import org.springframework.boot.autoconfigure.jdbc.DataSourceAutoConfiguration;
import org.springframework.boot.autoconfigure.r2dbc.SimpleConnectionFactoryProvider.SimpleTestConnectionFactory;
import org.springframework.boot.r2dbc.EmbeddedDatabaseConnection;
import org.springframework.boot.r2dbc.OptionsCapableConnectionFactory;
import org.springframework.boot.test.context.FilteredClassLoader;
import org.springframework.boot.test.context.runner.ApplicationContextRunner;
import org.springframework.context.annotation.Bean;
......@@ -58,8 +63,14 @@ class R2dbcAutoConfigurationTests {
@Test
void configureWithUrlCreateConnectionPoolByDefault() {
this.contextRunner.withPropertyValues("spring.r2dbc.url:r2dbc:h2:mem:///" + randomDatabaseName())
.run((context) -> assertThat(context).hasSingleBean(ConnectionFactory.class)
.hasSingleBean(ConnectionPool.class));
.run((context) -> {
assertThat(context).hasSingleBean(ConnectionFactory.class).hasSingleBean(ConnectionPool.class);
assertThat(context.getBean(ConnectionPool.class)).extracting(ConnectionPool::unwrap)
.satisfies((connectionFactory) -> assertThat(connectionFactory)
.asInstanceOf(type(OptionsCapableConnectionFactory.class))
.extracting(Wrapped<ConnectionFactory>::unwrap)
.isExactlyInstanceOf(H2ConnectionFactory.class));
});
}
@Test
......@@ -113,7 +124,10 @@ class R2dbcAutoConfigurationTests {
this.contextRunner.withPropertyValues("spring.r2dbc.pool.enabled=false", "spring.r2dbc.url:r2dbc:h2:mem:///"
+ randomDatabaseName() + "?options=DB_CLOSE_DELAY=-1;DB_CLOSE_ON_EXIT=FALSE").run((context) -> {
assertThat(context).hasSingleBean(ConnectionFactory.class).doesNotHaveBean(ConnectionPool.class);
assertThat(context.getBean(ConnectionFactory.class)).isExactlyInstanceOf(H2ConnectionFactory.class);
assertThat(context.getBean(ConnectionFactory.class))
.asInstanceOf(type(OptionsCapableConnectionFactory.class))
.extracting(Wrapped<ConnectionFactory>::unwrap)
.isExactlyInstanceOf(H2ConnectionFactory.class);
});
}
......@@ -122,8 +136,10 @@ class R2dbcAutoConfigurationTests {
this.contextRunner.with(hideConnectionPool()).withPropertyValues("spring.r2dbc.url:r2dbc:h2:mem:///"
+ randomDatabaseName() + "?options=DB_CLOSE_DELAY=-1;DB_CLOSE_ON_EXIT=FALSE").run((context) -> {
assertThat(context).hasSingleBean(ConnectionFactory.class);
ConnectionFactory bean = context.getBean(ConnectionFactory.class);
assertThat(bean).isExactlyInstanceOf(H2ConnectionFactory.class);
assertThat(context.getBean(ConnectionFactory.class))
.asInstanceOf(type(OptionsCapableConnectionFactory.class))
.extracting(Wrapped<ConnectionFactory>::unwrap)
.isExactlyInstanceOf(H2ConnectionFactory.class);
});
}
......@@ -142,11 +158,10 @@ class R2dbcAutoConfigurationTests {
.withPropertyValues("spring.r2dbc.pool.enabled=false", "spring.r2dbc.url:r2dbc:simple://host/database")
.withUserConfiguration(CustomizerConfiguration.class).run((context) -> {
assertThat(context).hasSingleBean(ConnectionFactory.class).doesNotHaveBean(ConnectionPool.class);
ConnectionFactory bean = context.getBean(ConnectionFactory.class);
assertThat(bean).isExactlyInstanceOf(SimpleTestConnectionFactory.class);
SimpleTestConnectionFactory connectionFactory = (SimpleTestConnectionFactory) bean;
assertThat(connectionFactory.getOptions().getRequiredValue(Option.<Boolean>valueOf("customized")))
.isTrue();
ConnectionFactory connectionFactory = context.getBean(ConnectionFactory.class);
assertThat(connectionFactory).asInstanceOf(type(OptionsCapableConnectionFactory.class))
.extracting(OptionsCapableConnectionFactory::getOptions).satisfies((options) -> assertThat(
options.getRequiredValue(Option.<Boolean>valueOf("customized"))).isTrue());
});
}
......@@ -155,11 +170,11 @@ class R2dbcAutoConfigurationTests {
this.contextRunner.withPropertyValues("spring.r2dbc.url:r2dbc:simple://host/database")
.withUserConfiguration(CustomizerConfiguration.class).run((context) -> {
assertThat(context).hasSingleBean(ConnectionFactory.class).hasSingleBean(ConnectionPool.class);
ConnectionFactory bean = context.getBean(ConnectionFactory.class);
SimpleTestConnectionFactory connectionFactory = (SimpleTestConnectionFactory) ((ConnectionPool) bean)
.unwrap();
assertThat(connectionFactory.getOptions().getRequiredValue(Option.<Boolean>valueOf("customized")))
.isTrue();
ConnectionFactory pool = context.getBean(ConnectionFactory.class);
ConnectionFactory connectionFactory = ((ConnectionPool) pool).unwrap();
assertThat(connectionFactory).asInstanceOf(type(OptionsCapableConnectionFactory.class))
.extracting(OptionsCapableConnectionFactory::getOptions).satisfies((options) -> assertThat(
options.getRequiredValue(Option.<Boolean>valueOf("customized"))).isTrue());
});
}
......@@ -174,8 +189,10 @@ class R2dbcAutoConfigurationTests {
this.contextRunner.withPropertyValues("spring.r2dbc.pool.enabled=false", "spring.r2dbc.url:r2dbc:simple://foo")
.withClassLoader(new FilteredClassLoader("org.springframework.jdbc")).run((context) -> {
assertThat(context).hasSingleBean(ConnectionFactory.class);
ConnectionFactory connectionFactory = context.getBean(ConnectionFactory.class);
assertThat(connectionFactory).isInstanceOf(SimpleTestConnectionFactory.class);
assertThat(context.getBean(ConnectionFactory.class))
.asInstanceOf(type(OptionsCapableConnectionFactory.class))
.extracting(Wrapped<ConnectionFactory>::unwrap)
.isExactlyInstanceOf(SimpleTestConnectionFactory.class);
});
}
......@@ -183,9 +200,12 @@ class R2dbcAutoConfigurationTests {
void configureWithoutPoolShouldApplyAdditionalProperties() {
this.contextRunner.withPropertyValues("spring.r2dbc.pool.enabled=false", "spring.r2dbc.url:r2dbc:simple://foo",
"spring.r2dbc.properties.test=value", "spring.r2dbc.properties.another=2").run((context) -> {
SimpleTestConnectionFactory connectionFactory = context.getBean(SimpleTestConnectionFactory.class);
assertThat(getRequiredOptionsValue(connectionFactory, "test")).isEqualTo("value");
assertThat(getRequiredOptionsValue(connectionFactory, "another")).isEqualTo("2");
ConnectionFactory connectionFactory = context.getBean(ConnectionFactory.class);
assertThat(connectionFactory).asInstanceOf(type(OptionsCapableConnectionFactory.class))
.extracting(OptionsCapableConnectionFactory::getOptions).satisfies((options) -> {
assertThat(options.getRequiredValue(Option.<String>valueOf("test"))).isEqualTo("value");
assertThat(options.getRequiredValue(Option.<String>valueOf("another"))).isEqualTo("2");
});
});
}
......@@ -194,17 +214,15 @@ class R2dbcAutoConfigurationTests {
this.contextRunner.withPropertyValues("spring.r2dbc.url:r2dbc:simple://foo",
"spring.r2dbc.properties.test=value", "spring.r2dbc.properties.another=2").run((context) -> {
assertThat(context).hasSingleBean(ConnectionFactory.class).hasSingleBean(ConnectionPool.class);
SimpleTestConnectionFactory connectionFactory = (SimpleTestConnectionFactory) context
.getBean(ConnectionPool.class).unwrap();
assertThat(getRequiredOptionsValue(connectionFactory, "test")).isEqualTo("value");
assertThat(getRequiredOptionsValue(connectionFactory, "another")).isEqualTo("2");
ConnectionFactory connectionFactory = context.getBean(ConnectionPool.class).unwrap();
assertThat(connectionFactory).asInstanceOf(type(OptionsCapableConnectionFactory.class))
.extracting(OptionsCapableConnectionFactory::getOptions).satisfies((options) -> {
assertThat(options.getRequiredValue(Option.<String>valueOf("test"))).isEqualTo("value");
assertThat(options.getRequiredValue(Option.<String>valueOf("another"))).isEqualTo("2");
});
});
}
private Object getRequiredOptionsValue(SimpleTestConnectionFactory connectionFactory, String name) {
return connectionFactory.options.getRequiredValue(Option.valueOf(name));
}
@Test
void configureWithoutUrlShouldCreateEmbeddedConnectionPoolByDefault() {
this.contextRunner.run((context) -> assertThat(context).hasSingleBean(ConnectionFactory.class)
......@@ -215,7 +233,9 @@ class R2dbcAutoConfigurationTests {
void configureWithoutUrlAndPollPoolDisabledCreateGenericConnectionFactory() {
this.contextRunner.withPropertyValues("spring.r2dbc.pool.enabled=false").run((context) -> {
assertThat(context).hasSingleBean(ConnectionFactory.class).doesNotHaveBean(ConnectionPool.class);
assertThat(context.getBean(ConnectionFactory.class)).isExactlyInstanceOf(H2ConnectionFactory.class);
assertThat(context.getBean(ConnectionFactory.class))
.asInstanceOf(type(OptionsCapableConnectionFactory.class))
.extracting(Wrapped<ConnectionFactory>::unwrap).isExactlyInstanceOf(H2ConnectionFactory.class);
});
}
......@@ -260,6 +280,10 @@ class R2dbcAutoConfigurationTests {
.doesNotHaveBean(DatabaseClient.class));
}
private <T> InstanceOfAssertFactory<T, ObjectAssert<T>> type(Class<T> type) {
return InstanceOfAssertFactories.type(type);
}
private String randomDatabaseName() {
return "testdb-" + UUID.randomUUID();
}
......
......@@ -16,8 +16,8 @@
package org.springframework.boot.autoconfigure.r2dbc;
import io.r2dbc.pool.ConnectionPool;
import io.r2dbc.spi.ConnectionFactory;
import io.r2dbc.spi.Wrapped;
import org.springframework.boot.autoconfigure.r2dbc.SimpleConnectionFactoryProvider.SimpleTestConnectionFactory;
import org.springframework.r2dbc.core.binding.BindMarkersFactory;
......@@ -38,9 +38,10 @@ public class SimpleBindMarkerFactoryProvider implements BindMarkerFactoryProvide
return null;
}
@SuppressWarnings("unchecked")
private ConnectionFactory unwrapIfNecessary(ConnectionFactory connectionFactory) {
if (connectionFactory instanceof ConnectionPool) {
return ((ConnectionPool) connectionFactory).unwrap();
if (connectionFactory instanceof Wrapped) {
return unwrapIfNecessary(((Wrapped<ConnectionFactory>) connectionFactory).unwrap());
}
return connectionFactory;
}
......
......@@ -25,15 +25,16 @@ import org.reactivestreams.Publisher;
import reactor.core.publisher.Mono;
/**
* Simple driver to capture {@link ConnectionFactoryOptions}.
* Simple driver for testing.
*
* @author Mark Paluch
* @author Andy Wilkinson
*/
public class SimpleConnectionFactoryProvider implements ConnectionFactoryProvider {
@Override
public ConnectionFactory create(ConnectionFactoryOptions connectionFactoryOptions) {
return new SimpleTestConnectionFactory(connectionFactoryOptions);
return new SimpleTestConnectionFactory();
}
@Override
......@@ -48,12 +49,6 @@ public class SimpleConnectionFactoryProvider implements ConnectionFactoryProvide
public static class SimpleTestConnectionFactory implements ConnectionFactory {
final ConnectionFactoryOptions options;
SimpleTestConnectionFactory(ConnectionFactoryOptions options) {
this.options = options;
}
@Override
public Publisher<? extends Connection> create() {
return Mono.error(new UnsupportedOperationException());
......@@ -64,10 +59,6 @@ public class SimpleConnectionFactoryProvider implements ConnectionFactoryProvide
return SimpleConnectionFactoryProvider.class::getName;
}
public ConnectionFactoryOptions getOptions() {
return this.options;
}
}
}
......@@ -35,6 +35,7 @@ dependencies {
optional("io.netty:netty-tcnative-boringssl-static")
optional("io.projectreactor:reactor-tools")
optional("io.projectreactor.netty:reactor-netty-http")
optional("io.r2dbc:r2dbc-pool")
optional("io.rsocket:rsocket-core")
optional("io.rsocket:rsocket-transport-netty")
optional("io.undertow:undertow-servlet") {
......
......@@ -16,14 +16,24 @@
package org.springframework.boot.r2dbc;
import java.time.Duration;
import java.util.Locale;
import java.util.function.Consumer;
import java.util.function.Function;
import io.r2dbc.pool.ConnectionPool;
import io.r2dbc.pool.ConnectionPoolConfiguration;
import io.r2dbc.pool.PoolingConnectionFactoryProvider;
import io.r2dbc.spi.ConnectionFactories;
import io.r2dbc.spi.ConnectionFactory;
import io.r2dbc.spi.ConnectionFactoryOptions;
import io.r2dbc.spi.ConnectionFactoryOptions.Builder;
import io.r2dbc.spi.ValidationDepth;
import io.r2dbc.spi.Wrapped;
import org.springframework.boot.context.properties.PropertyMapper;
import org.springframework.util.Assert;
import org.springframework.util.ClassUtils;
/**
* Builder for {@link ConnectionFactory}.
......@@ -31,10 +41,24 @@ import org.springframework.util.Assert;
* @author Mark Paluch
* @author Tadaya Tsuyukubo
* @author Stephane Nicoll
* @author Andy Wilkinson
* @since 2.5.0
*/
public final class ConnectionFactoryBuilder {
private static final OptionsCapableWrapper optionsCapableWrapper;
static {
if (ClassUtils.isPresent("io.r2dbc.pool.ConnectionPool", ConnectionFactoryBuilder.class.getClassLoader())) {
optionsCapableWrapper = new PoolingAwareOptionsCapableWrapper();
}
else {
optionsCapableWrapper = new OptionsCapableWrapper();
}
}
private static final String COLON = ":";
private final Builder optionsBuilder;
private ConnectionFactoryBuilder(Builder optionsBuilder) {
......@@ -63,6 +87,35 @@ public final class ConnectionFactoryBuilder {
return new ConnectionFactoryBuilder(options);
}
/**
* Initialize a new {@link ConnectionFactoryBuilder} derived from the options of the
* specified {@code connectionFactory}.
* @param connectionFactory the connection factory whose options are to be used to
* initialize the builder
* @return a new builder initialized with the options from the connection factory
*/
public static ConnectionFactoryBuilder derivefrom(ConnectionFactory connectionFactory) {
ConnectionFactoryOptions options = extractOptionsIfPossible(connectionFactory);
if (options == null) {
throw new IllegalArgumentException(
"ConnectionFactoryOptions could not be extracted from " + connectionFactory);
}
return withOptions(options.mutate());
}
private static ConnectionFactoryOptions extractOptionsIfPossible(ConnectionFactory connectionFactory) {
if (connectionFactory instanceof OptionsCapableConnectionFactory) {
return ((OptionsCapableConnectionFactory) connectionFactory).getOptions();
}
if (connectionFactory instanceof Wrapped) {
Object unwrapped = ((Wrapped<?>) connectionFactory).unwrap();
if (unwrapped instanceof ConnectionFactory) {
return extractOptionsIfPossible((ConnectionFactory) unwrapped);
}
}
return null;
}
/**
* Configure additional options.
* @param options a {@link Consumer} to customize the options
......@@ -123,7 +176,8 @@ public final class ConnectionFactoryBuilder {
* @return a connection factory
*/
public ConnectionFactory build() {
return ConnectionFactories.get(buildOptions());
ConnectionFactoryOptions options = buildOptions();
return optionsCapableWrapper.buildAndWrap(options);
}
/**
......@@ -134,4 +188,100 @@ public final class ConnectionFactoryBuilder {
return this.optionsBuilder.build();
}
private static class OptionsCapableWrapper {
ConnectionFactory buildAndWrap(ConnectionFactoryOptions options) {
ConnectionFactory connectionFactory = ConnectionFactories.get(options);
return new OptionsCapableConnectionFactory(options, connectionFactory);
}
}
static final class PoolingAwareOptionsCapableWrapper extends OptionsCapableWrapper {
private final PoolingConnectionFactoryProvider poolingProvider = new PoolingConnectionFactoryProvider();
@Override
ConnectionFactory buildAndWrap(ConnectionFactoryOptions options) {
if (!this.poolingProvider.supports(options)) {
return super.buildAndWrap(options);
}
ConnectionFactoryOptions delegateOptions = delegateFactoryOptions(options);
ConnectionFactory connectionFactory = super.buildAndWrap(delegateOptions);
ConnectionPoolConfiguration poolConfiguration = connectionPoolConfiguration(delegateOptions,
connectionFactory);
return new ConnectionPool(poolConfiguration);
}
private ConnectionFactoryOptions delegateFactoryOptions(ConnectionFactoryOptions options) {
String protocol = options.getRequiredValue(ConnectionFactoryOptions.PROTOCOL);
if (protocol.trim().length() == 0) {
throw new IllegalArgumentException(String.format("Protocol %s is not valid.", protocol));
}
String[] protocols = protocol.split(COLON, 2);
String driverDelegate = protocols[0];
String protocolDelegate = (protocols.length != 2) ? "" : protocols[1];
ConnectionFactoryOptions newOptions = ConnectionFactoryOptions.builder().from(options)
.option(ConnectionFactoryOptions.DRIVER, driverDelegate)
.option(ConnectionFactoryOptions.PROTOCOL, protocolDelegate).build();
return newOptions;
}
ConnectionPoolConfiguration connectionPoolConfiguration(ConnectionFactoryOptions options,
ConnectionFactory connectionFactory) {
ConnectionPoolConfiguration.Builder builder = ConnectionPoolConfiguration.builder(connectionFactory);
PropertyMapper map = PropertyMapper.get().alwaysApplyingWhenNonNull();
map.from((Object) options.getValue(PoolingConnectionFactoryProvider.INITIAL_SIZE)).as(this::toInteger)
.to(builder::initialSize);
map.from((Object) options.getValue(PoolingConnectionFactoryProvider.MAX_SIZE)).as(this::toInteger)
.to(builder::maxSize);
map.from((Object) options.getValue(PoolingConnectionFactoryProvider.ACQUIRE_RETRY)).as(this::toInteger)
.to(builder::acquireRetry);
map.from((Object) options.getValue(PoolingConnectionFactoryProvider.MAX_LIFE_TIME)).as(this::toDuration)
.to(builder::maxLifeTime);
map.from((Object) options.getValue(PoolingConnectionFactoryProvider.MAX_ACQUIRE_TIME)).as(this::toDuration)
.to(builder::maxAcquireTime);
map.from((Object) options.getValue(PoolingConnectionFactoryProvider.MAX_IDLE_TIME)).as(this::toDuration)
.to(builder::maxIdleTime);
map.from((Object) options.getValue(PoolingConnectionFactoryProvider.MAX_CREATE_CONNECTION_TIME))
.as(this::toDuration).to(builder::maxCreateConnectionTime);
map.from(options.getValue(PoolingConnectionFactoryProvider.POOL_NAME)).to(builder::name);
map.from((Object) options.getValue(PoolingConnectionFactoryProvider.REGISTER_JMX)).as(this::toBoolean)
.to(builder::registerJmx);
map.from(options.getValue(PoolingConnectionFactoryProvider.VALIDATION_QUERY)).to(builder::validationQuery);
map.from((Object) options.getValue(PoolingConnectionFactoryProvider.VALIDATION_DEPTH))
.as(this::toValidationDepth).to(builder::validationDepth);
ConnectionPoolConfiguration build = builder.build();
return build;
}
private Integer toInteger(Object object) {
return toType(Integer.class, object, Integer::valueOf);
}
private Duration toDuration(Object object) {
return toType(Duration.class, object, Duration::parse);
}
private Boolean toBoolean(Object object) {
return toType(Boolean.class, object, Boolean::valueOf);
}
private ValidationDepth toValidationDepth(Object object) {
return toType(ValidationDepth.class, object,
(string) -> ValidationDepth.valueOf(string.toUpperCase(Locale.ENGLISH)));
}
private <T> T toType(Class<T> type, Object object, Function<String, T> converter) {
if (type.isInstance(object)) {
return type.cast(object);
}
if (object instanceof String) {
return converter.apply((String) object);
}
throw new IllegalArgumentException("Cannot convert '" + object + "' to " + type.getName());
}
}
}
/*
* Copyright 2012-2021 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.boot.r2dbc;
import io.r2dbc.spi.Connection;
import io.r2dbc.spi.ConnectionFactory;
import io.r2dbc.spi.ConnectionFactoryMetadata;
import io.r2dbc.spi.ConnectionFactoryOptions;
import io.r2dbc.spi.Wrapped;
import org.reactivestreams.Publisher;
/**
* {@link ConnectionFactory} capable of providing access to the
* {@link ConnectionFactoryOptions} from which it was built.
*
* @author Andy Wilkinson
* @since 2.5.0
*/
public class OptionsCapableConnectionFactory implements Wrapped<ConnectionFactory>, ConnectionFactory {
private final ConnectionFactoryOptions options;
private final ConnectionFactory delegate;
/**
* Create a new {@code OptionsCapableConnectionFactory} that will provide access to
* the given {@code options} that were used to build the given {@code delegate}
* {@link ConnectionFactory}.
* @param options the options from which the connection factory was built
* @param delegate the delegate connection factory that was built with options
*/
public OptionsCapableConnectionFactory(ConnectionFactoryOptions options, ConnectionFactory delegate) {
this.options = options;
this.delegate = delegate;
}
public ConnectionFactoryOptions getOptions() {
return this.options;
}
@Override
public Publisher<? extends Connection> create() {
return this.delegate.create();
}
@Override
public ConnectionFactoryMetadata getMetadata() {
return this.delegate.getMetadata();
}
@Override
public ConnectionFactory unwrap() {
return this.delegate;
}
}
......@@ -16,16 +16,30 @@
package org.springframework.boot.r2dbc;
import java.time.Duration;
import java.util.ArrayList;
import java.util.List;
import java.util.UUID;
import io.r2dbc.h2.H2ConnectionFactoryMetadata;
import io.r2dbc.pool.ConnectionPool;
import io.r2dbc.pool.ConnectionPoolConfiguration;
import io.r2dbc.pool.PoolingConnectionFactoryProvider;
import io.r2dbc.spi.ConnectionFactory;
import io.r2dbc.spi.ConnectionFactoryOptions;
import io.r2dbc.spi.Option;
import io.r2dbc.spi.ValidationDepth;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.springframework.boot.r2dbc.ConnectionFactoryBuilder.PoolingAwareOptionsCapableWrapper;
import org.springframework.util.ReflectionUtils;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
import static org.mockito.Mockito.mock;
/**
* Tests for {@link ConnectionFactoryBuilder}.
......@@ -120,4 +134,146 @@ class ConnectionFactoryBuilderTests {
assertThat(connectionFactory.getMetadata().getName()).isEqualTo(H2ConnectionFactoryMetadata.NAME);
}
@Test
void buildWhenDerivedWithNewDatabaseReturnsNewConnectionFactory() {
String intialDatabaseName = UUID.randomUUID().toString();
ConnectionFactory connectionFactory = ConnectionFactoryBuilder
.withUrl(EmbeddedDatabaseConnection.H2.getUrl(intialDatabaseName)).build();
ConnectionFactoryOptions initialOptions = ((OptionsCapableConnectionFactory) connectionFactory).getOptions();
String derivedDatabaseName = UUID.randomUUID().toString();
ConnectionFactory derived = ConnectionFactoryBuilder.derivefrom(connectionFactory).database(derivedDatabaseName)
.build();
ConnectionFactoryOptions derivedOptions = ((OptionsCapableConnectionFactory) derived).getOptions();
assertThat(derivedOptions.getRequiredValue(ConnectionFactoryOptions.DATABASE)).isEqualTo(derivedDatabaseName);
assertMatchingOptions(derivedOptions, initialOptions, ConnectionFactoryOptions.CONNECT_TIMEOUT,
ConnectionFactoryOptions.DRIVER, ConnectionFactoryOptions.HOST, ConnectionFactoryOptions.PASSWORD,
ConnectionFactoryOptions.PORT, ConnectionFactoryOptions.PROTOCOL, ConnectionFactoryOptions.SSL,
ConnectionFactoryOptions.USER);
}
@Test
void buildWhenDerivedWithNewCredentialsReturnsNewConnectionFactory() {
ConnectionFactory connectionFactory = ConnectionFactoryBuilder
.withUrl(EmbeddedDatabaseConnection.H2.getUrl(UUID.randomUUID().toString())).build();
ConnectionFactoryOptions initialOptions = ((OptionsCapableConnectionFactory) connectionFactory).getOptions();
ConnectionFactory derived = ConnectionFactoryBuilder.derivefrom(connectionFactory).username("admin")
.password("secret").build();
ConnectionFactoryOptions derivedOptions = ((OptionsCapableConnectionFactory) derived).getOptions();
assertThat(derivedOptions.getRequiredValue(ConnectionFactoryOptions.USER)).isEqualTo("admin");
assertThat(derivedOptions.getRequiredValue(ConnectionFactoryOptions.PASSWORD)).isEqualTo("secret");
assertMatchingOptions(derivedOptions, initialOptions, ConnectionFactoryOptions.CONNECT_TIMEOUT,
ConnectionFactoryOptions.DATABASE, ConnectionFactoryOptions.DRIVER, ConnectionFactoryOptions.HOST,
ConnectionFactoryOptions.PORT, ConnectionFactoryOptions.PROTOCOL, ConnectionFactoryOptions.SSL);
}
@Test
void buildWhenDerivedFromPoolReturnsNewNonPooledConnectionFactory() {
ConnectionFactory connectionFactory = ConnectionFactoryBuilder
.withUrl(EmbeddedDatabaseConnection.H2.getUrl(UUID.randomUUID().toString())).build();
ConnectionFactoryOptions initialOptions = ((OptionsCapableConnectionFactory) connectionFactory).getOptions();
ConnectionPoolConfiguration poolConfiguration = ConnectionPoolConfiguration.builder(connectionFactory).build();
ConnectionPool pool = new ConnectionPool(poolConfiguration);
ConnectionFactory derived = ConnectionFactoryBuilder.derivefrom(pool).username("admin").password("secret")
.build();
assertThat(derived).isNotInstanceOf(ConnectionPool.class).isInstanceOf(OptionsCapableConnectionFactory.class);
ConnectionFactoryOptions derivedOptions = ((OptionsCapableConnectionFactory) derived).getOptions();
assertThat(derivedOptions.getRequiredValue(ConnectionFactoryOptions.USER)).isEqualTo("admin");
assertThat(derivedOptions.getRequiredValue(ConnectionFactoryOptions.PASSWORD)).isEqualTo("secret");
assertMatchingOptions(derivedOptions, initialOptions, ConnectionFactoryOptions.CONNECT_TIMEOUT,
ConnectionFactoryOptions.DATABASE, ConnectionFactoryOptions.DRIVER, ConnectionFactoryOptions.HOST,
ConnectionFactoryOptions.PORT, ConnectionFactoryOptions.PROTOCOL, ConnectionFactoryOptions.SSL);
}
@ParameterizedTest
@SuppressWarnings({ "rawtypes", "unchecked" })
@MethodSource("poolingConnectionProviderOptions")
void optionIsMappedWhenCreatingPoolConfiguration(Option option) {
String url = "r2dbc:pool:h2:mem:///" + UUID.randomUUID().toString();
ExpectedOption expectedOption = ExpectedOption.get(option);
ConnectionFactoryOptions options = ConnectionFactoryBuilder.withUrl(url).configure((builder) -> builder
.option(PoolingConnectionFactoryProvider.POOL_NAME, "defaultName").option(option, expectedOption.value))
.buildOptions();
ConnectionPoolConfiguration configuration = new PoolingAwareOptionsCapableWrapper()
.connectionPoolConfiguration(options, mock(ConnectionFactory.class));
assertThat(configuration).extracting(expectedOption.property).isEqualTo(expectedOption.value);
}
@ParameterizedTest
@SuppressWarnings({ "rawtypes", "unchecked" })
@MethodSource("poolingConnectionProviderOptions")
void stringlyTypedOptionIsMappedWhenCreatingPoolConfiguration(Option option) {
String url = "r2dbc:pool:h2:mem:///" + UUID.randomUUID().toString();
ExpectedOption expectedOption = ExpectedOption.get(option);
ConnectionFactoryOptions options = ConnectionFactoryBuilder.withUrl(url)
.configure((builder) -> builder.option(PoolingConnectionFactoryProvider.POOL_NAME, "defaultName")
.option(option, expectedOption.value.toString()))
.buildOptions();
ConnectionPoolConfiguration configuration = new PoolingAwareOptionsCapableWrapper()
.connectionPoolConfiguration(options, mock(ConnectionFactory.class));
assertThat(configuration).extracting(expectedOption.property).isEqualTo(expectedOption.value);
}
private void assertMatchingOptions(ConnectionFactoryOptions actualOptions, ConnectionFactoryOptions expectedOptions,
Option<?>... optionsToCheck) {
for (Option<?> option : optionsToCheck) {
assertThat(actualOptions.getValue(option)).as(option.name()).isEqualTo(expectedOptions.getValue(option));
}
}
private static Iterable<Arguments> poolingConnectionProviderOptions() {
List<Arguments> arguments = new ArrayList<>();
ReflectionUtils.doWithFields(PoolingConnectionFactoryProvider.class,
(field) -> arguments.add(Arguments.of((Option<?>) ReflectionUtils.getField(field, null))),
(field) -> Option.class.equals(field.getType()));
return arguments;
}
private enum ExpectedOption {
ACQUIRE_RETRY(PoolingConnectionFactoryProvider.ACQUIRE_RETRY, 4, "acquireRetry"),
INITIAL_SIZE(PoolingConnectionFactoryProvider.INITIAL_SIZE, 2, "initialSize"),
MAX_SIZE(PoolingConnectionFactoryProvider.MAX_SIZE, 8, "maxSize"),
MAX_LIFE_TIME(PoolingConnectionFactoryProvider.MAX_LIFE_TIME, Duration.ofMinutes(2), "maxLifeTime"),
MAX_ACQUIRE_TIME(PoolingConnectionFactoryProvider.MAX_ACQUIRE_TIME, Duration.ofSeconds(30), "maxAcquireTime"),
MAX_IDLE_TIME(PoolingConnectionFactoryProvider.MAX_IDLE_TIME, Duration.ofMinutes(1), "maxIdleTime"),
MAX_CREATE_CONNECTION_TIME(PoolingConnectionFactoryProvider.MAX_CREATE_CONNECTION_TIME, Duration.ofSeconds(10),
"maxCreateConnectionTime"),
POOL_NAME(PoolingConnectionFactoryProvider.POOL_NAME, "testPool", "name"),
REGISTER_JMX(PoolingConnectionFactoryProvider.REGISTER_JMX, true, "registerJmx"),
VALIDATION_QUERY(PoolingConnectionFactoryProvider.VALIDATION_QUERY, "SELECT 1", "validationQuery"),
VALIDATION_DEPTH(PoolingConnectionFactoryProvider.VALIDATION_DEPTH, ValidationDepth.REMOTE, "validationDepth");
private final Option<?> option;
private final Object value;
private final String property;
ExpectedOption(Option<?> option, Object value, String property) {
this.option = option;
this.value = value;
this.property = property;
}
static ExpectedOption get(Option<?> option) {
for (ExpectedOption expectedOption : values()) {
if (expectedOption.option == option) {
return expectedOption;
}
}
throw new IllegalArgumentException("Unexpected option: '" + option + "'");
}
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment