diff --git a/src/main/java/org/springframework/data/r2dbc/core/NamedParameterUtils.java b/src/main/java/org/springframework/data/r2dbc/core/NamedParameterUtils.java index 88ff313..11a1d7a 100644 --- a/src/main/java/org/springframework/data/r2dbc/core/NamedParameterUtils.java +++ b/src/main/java/org/springframework/data/r2dbc/core/NamedParameterUtils.java @@ -452,6 +452,12 @@ abstract class NamedParameterUtils { List bindMarkers = getBindMarkers(identifier); + if (bindMarkers == null) { + + target.bindNull(identifier, valueType); + return; + } + if (bindMarkers.size() == 1) { bindMarkers.get(0).bindNull(target, valueType); return; diff --git a/src/test/java/org/springframework/data/r2dbc/core/DefaultDatabaseClientUnitTests.java b/src/test/java/org/springframework/data/r2dbc/core/DefaultDatabaseClientUnitTests.java index b4b7bcc..c0dd2a0 100644 --- a/src/test/java/org/springframework/data/r2dbc/core/DefaultDatabaseClientUnitTests.java +++ b/src/test/java/org/springframework/data/r2dbc/core/DefaultDatabaseClientUnitTests.java @@ -19,9 +19,11 @@ import static org.mockito.Mockito.*; import io.r2dbc.spi.Connection; import io.r2dbc.spi.ConnectionFactory; +import io.r2dbc.spi.Statement; import reactor.core.CoreSubscriber; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; import org.junit.Before; import org.junit.Test; @@ -30,9 +32,8 @@ import org.mockito.Mock; import org.mockito.junit.MockitoJUnitRunner; import org.reactivestreams.Publisher; import org.reactivestreams.Subscription; -import org.springframework.data.r2dbc.core.DatabaseClient; -import org.springframework.data.r2dbc.core.DefaultDatabaseClient; -import org.springframework.data.r2dbc.core.ReactiveDataAccessStrategy; + +import org.springframework.data.r2dbc.dialect.PostgresDialect; import org.springframework.data.r2dbc.support.R2dbcExceptionTranslator; /** @@ -45,7 +46,6 @@ public class DefaultDatabaseClientUnitTests { @Mock ConnectionFactory connectionFactory; @Mock Connection connection; - @Mock ReactiveDataAccessStrategy strategy; @Mock R2dbcExceptionTranslator translator; @Before @@ -58,7 +58,9 @@ public class DefaultDatabaseClientUnitTests { public void shouldCloseConnectionOnlyOnce() { DefaultDatabaseClient databaseClient = (DefaultDatabaseClient) DatabaseClient.builder() - .connectionFactory(connectionFactory).dataAccessStrategy(strategy).exceptionTranslator(translator).build(); + .connectionFactory(connectionFactory) + .dataAccessStrategy(new DefaultReactiveDataAccessStrategy(PostgresDialect.INSTANCE)) + .exceptionTranslator(translator).build(); Flux flux = databaseClient.inConnectionMany(it -> { return Flux.empty(); @@ -87,4 +89,100 @@ public class DefaultDatabaseClientUnitTests { verify(connection, times(1)).close(); } + + @Test // gh-128 + public void executeShouldBindNullValues() { + + Statement statement = mock(Statement.class); + when(connection.createStatement("SELECT * FROM table WHERE key = $1")).thenReturn(statement); + when(statement.execute()).thenReturn(Mono.empty()); + + DefaultDatabaseClient databaseClient = (DefaultDatabaseClient) DatabaseClient.builder() + .connectionFactory(connectionFactory) + .dataAccessStrategy(new DefaultReactiveDataAccessStrategy(PostgresDialect.INSTANCE)).build(); + + databaseClient.execute("SELECT * FROM table WHERE key = $1") // + .bindNull(0, String.class) // + .then() // + .as(StepVerifier::create) // + .verifyComplete(); + + verify(statement).bindNull(0, String.class); + + databaseClient.execute("SELECT * FROM table WHERE key = $1") // + .bindNull("$1", String.class) // + .then() // + .as(StepVerifier::create) // + .verifyComplete(); + + verify(statement).bindNull("$1", String.class); + } + + @Test // gh-128 + public void executeShouldBindNamedNullValues() { + + Statement statement = mock(Statement.class); + when(connection.createStatement("SELECT * FROM table WHERE key = $1")).thenReturn(statement); + when(statement.execute()).thenReturn(Mono.empty()); + + DefaultDatabaseClient databaseClient = (DefaultDatabaseClient) DatabaseClient.builder() + .connectionFactory(connectionFactory) + .dataAccessStrategy(new DefaultReactiveDataAccessStrategy(PostgresDialect.INSTANCE)).build(); + + databaseClient.execute("SELECT * FROM table WHERE key = :key") // + .bindNull("key", String.class) // + .then() // + .as(StepVerifier::create) // + .verifyComplete(); + + verify(statement).bindNull(0, String.class); + } + + @Test // gh-128 + public void executeShouldBindValues() { + + Statement statement = mock(Statement.class); + when(connection.createStatement("SELECT * FROM table WHERE key = $1")).thenReturn(statement); + when(statement.execute()).thenReturn(Mono.empty()); + + DefaultDatabaseClient databaseClient = (DefaultDatabaseClient) DatabaseClient.builder() + .connectionFactory(connectionFactory) + .dataAccessStrategy(new DefaultReactiveDataAccessStrategy(PostgresDialect.INSTANCE)).build(); + + databaseClient.execute("SELECT * FROM table WHERE key = $1") // + .bind(0, "foo") // + .then() // + .as(StepVerifier::create) // + .verifyComplete(); + + verify(statement).bind(0, "foo"); + + databaseClient.execute("SELECT * FROM table WHERE key = $1") // + .bind("$1", "foo") // + .then() // + .as(StepVerifier::create) // + .verifyComplete(); + + verify(statement).bind("$1", "foo"); + } + + @Test // gh-128 + public void executeShouldBindNamedValuesByIndex() { + + Statement statement = mock(Statement.class); + when(connection.createStatement("SELECT * FROM table WHERE key = $1")).thenReturn(statement); + when(statement.execute()).thenReturn(Mono.empty()); + + DefaultDatabaseClient databaseClient = (DefaultDatabaseClient) DatabaseClient.builder() + .connectionFactory(connectionFactory) + .dataAccessStrategy(new DefaultReactiveDataAccessStrategy(PostgresDialect.INSTANCE)).build(); + + databaseClient.execute("SELECT * FROM table WHERE key = :key") // + .bind("key", "foo") // + .then() // + .as(StepVerifier::create) // + .verifyComplete(); + + verify(statement).bind(0, "foo"); + } }