Use PreparedStatementCreator for query/update with indexed params
Closes gh-31122
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
/*
|
||||
* Copyright 2002-2022 the original author or authors.
|
||||
* Copyright 2002-2023 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
@@ -46,8 +46,9 @@ public class PreparedStatementCreatorFactory {
|
||||
/** The SQL, which won't change when the parameters change. */
|
||||
private final String sql;
|
||||
|
||||
/** List of SqlParameter objects (may not be {@code null}). */
|
||||
private final List<SqlParameter> declaredParameters;
|
||||
/** List of SqlParameter objects (may be {@code null}). */
|
||||
@Nullable
|
||||
private List<SqlParameter> declaredParameters;
|
||||
|
||||
private int resultSetType = ResultSet.TYPE_FORWARD_ONLY;
|
||||
|
||||
@@ -66,7 +67,6 @@ public class PreparedStatementCreatorFactory {
|
||||
*/
|
||||
public PreparedStatementCreatorFactory(String sql) {
|
||||
this.sql = sql;
|
||||
this.declaredParameters = new ArrayList<>();
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -104,6 +104,9 @@ public class PreparedStatementCreatorFactory {
|
||||
* @param param the parameter to add to the list of declared parameters
|
||||
*/
|
||||
public void addParameter(SqlParameter param) {
|
||||
if (this.declaredParameters == null) {
|
||||
this.declaredParameters = new ArrayList<>();
|
||||
}
|
||||
this.declaredParameters.add(param);
|
||||
}
|
||||
|
||||
@@ -180,7 +183,7 @@ public class PreparedStatementCreatorFactory {
|
||||
*/
|
||||
public PreparedStatementCreator newPreparedStatementCreator(String sqlToUse, @Nullable Object[] params) {
|
||||
return new PreparedStatementCreatorImpl(
|
||||
sqlToUse, params != null ? Arrays.asList(params) : Collections.emptyList());
|
||||
sqlToUse, (params != null ? Arrays.asList(params) : Collections.emptyList()));
|
||||
}
|
||||
|
||||
|
||||
@@ -201,7 +204,7 @@ public class PreparedStatementCreatorFactory {
|
||||
public PreparedStatementCreatorImpl(String actualSql, List<?> parameters) {
|
||||
this.actualSql = actualSql;
|
||||
this.parameters = parameters;
|
||||
if (parameters.size() != declaredParameters.size()) {
|
||||
if (declaredParameters != null && parameters.size() != declaredParameters.size()) {
|
||||
// Account for named parameters being used multiple times
|
||||
Set<String> names = new HashSet<>();
|
||||
for (int i = 0; i < parameters.size(); i++) {
|
||||
@@ -249,14 +252,14 @@ public class PreparedStatementCreatorFactory {
|
||||
int sqlColIndx = 1;
|
||||
for (int i = 0; i < this.parameters.size(); i++) {
|
||||
Object in = this.parameters.get(i);
|
||||
SqlParameter declaredParameter;
|
||||
SqlParameter declaredParameter = null;
|
||||
// SqlParameterValue overrides declared parameter meta-data, in particular for
|
||||
// independence from the declared parameter position in case of named parameters.
|
||||
if (in instanceof SqlParameterValue sqlParameterValue) {
|
||||
in = sqlParameterValue.getValue();
|
||||
declaredParameter = sqlParameterValue;
|
||||
}
|
||||
else {
|
||||
else if (declaredParameters != null) {
|
||||
if (declaredParameters.size() <= i) {
|
||||
throw new InvalidDataAccessApiUsageException(
|
||||
"SQL [" + sql + "]: unable to access parameter number " + (i + 1) +
|
||||
@@ -265,7 +268,10 @@ public class PreparedStatementCreatorFactory {
|
||||
}
|
||||
declaredParameter = declaredParameters.get(i);
|
||||
}
|
||||
if (in instanceof Iterable<?> entries && declaredParameter.getSqlType() != Types.ARRAY) {
|
||||
if (declaredParameter == null) {
|
||||
StatementCreatorUtils.setParameterValue(ps, sqlColIndx++, SqlTypeValue.TYPE_UNKNOWN, in);
|
||||
}
|
||||
else if (in instanceof Iterable<?> entries && declaredParameter.getSqlType() != Types.ARRAY) {
|
||||
for (Object entry : entries) {
|
||||
if (entry instanceof Object[] valueArray) {
|
||||
for (Object argValue : valueArray) {
|
||||
|
||||
@@ -28,6 +28,8 @@ import javax.sql.DataSource;
|
||||
import org.springframework.beans.BeanUtils;
|
||||
import org.springframework.jdbc.core.JdbcOperations;
|
||||
import org.springframework.jdbc.core.JdbcTemplate;
|
||||
import org.springframework.jdbc.core.PreparedStatementCreator;
|
||||
import org.springframework.jdbc.core.PreparedStatementCreatorFactory;
|
||||
import org.springframework.jdbc.core.ResultSetExtractor;
|
||||
import org.springframework.jdbc.core.RowCallbackHandler;
|
||||
import org.springframework.jdbc.core.RowMapper;
|
||||
@@ -206,7 +208,7 @@ final class DefaultJdbcClient implements JdbcClient {
|
||||
namedParamOps.query(this.sql, this.namedParams, rch);
|
||||
}
|
||||
else {
|
||||
classicOps.query(this.sql, rch, this.indexedParams.toArray());
|
||||
classicOps.query(getPreparedStatementCreatorForIndexedParams(), rch);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -214,7 +216,7 @@ final class DefaultJdbcClient implements JdbcClient {
|
||||
public <T> T query(ResultSetExtractor<T> rse) {
|
||||
T result = (useNamedParams() ?
|
||||
namedParamOps.query(this.sql, this.namedParams, rse) :
|
||||
classicOps.query(this.sql, rse, this.indexedParams.toArray()));
|
||||
classicOps.query(getPreparedStatementCreatorForIndexedParams(), rse));
|
||||
Assert.state(result != null, "No result from ResultSetExtractor");
|
||||
return result;
|
||||
}
|
||||
@@ -223,14 +225,14 @@ final class DefaultJdbcClient implements JdbcClient {
|
||||
public int update() {
|
||||
return (useNamedParams() ?
|
||||
namedParamOps.update(this.sql, this.namedParamSource) :
|
||||
classicOps.update(this.sql, this.indexedParams.toArray()));
|
||||
classicOps.update(getPreparedStatementCreatorForIndexedParams()));
|
||||
}
|
||||
|
||||
@Override
|
||||
public int update(KeyHolder generatedKeyHolder) {
|
||||
return (useNamedParams() ?
|
||||
namedParamOps.update(this.sql, this.namedParamSource, generatedKeyHolder) :
|
||||
classicOps.update(this.sql, this.indexedParams.toArray(), generatedKeyHolder));
|
||||
classicOps.update(getPreparedStatementCreatorForIndexedParams(), generatedKeyHolder));
|
||||
}
|
||||
|
||||
private boolean useNamedParams() {
|
||||
@@ -245,6 +247,10 @@ final class DefaultJdbcClient implements JdbcClient {
|
||||
return hasNamedParams;
|
||||
}
|
||||
|
||||
private PreparedStatementCreator getPreparedStatementCreatorForIndexedParams() {
|
||||
return new PreparedStatementCreatorFactory(this.sql).newPreparedStatementCreator(this.indexedParams);
|
||||
}
|
||||
|
||||
|
||||
private class IndexedParamResultQuerySpec implements ResultQuerySpec {
|
||||
|
||||
|
||||
@@ -20,6 +20,7 @@ import java.sql.Connection;
|
||||
import java.sql.DatabaseMetaData;
|
||||
import java.sql.PreparedStatement;
|
||||
import java.sql.ResultSet;
|
||||
import java.sql.ResultSetMetaData;
|
||||
import java.sql.SQLException;
|
||||
import java.sql.Types;
|
||||
import java.util.ArrayList;
|
||||
@@ -35,6 +36,8 @@ import org.junit.jupiter.api.Test;
|
||||
|
||||
import org.springframework.jdbc.Customer;
|
||||
import org.springframework.jdbc.core.SqlParameterValue;
|
||||
import org.springframework.jdbc.support.GeneratedKeyHolder;
|
||||
import org.springframework.jdbc.support.KeyHolder;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
import static org.mockito.ArgumentMatchers.anyString;
|
||||
@@ -56,6 +59,9 @@ public class JdbcClientIndexedParameterTests {
|
||||
private static final String UPDATE_NAMED_PARAMETERS =
|
||||
"update seat_status set booking_id = null where performance_id = ? and price_band_id = ?";
|
||||
|
||||
private static final String INSERT_GENERATE_KEYS =
|
||||
"insert into show (name) values(?)";
|
||||
|
||||
private static final String[] COLUMN_NAMES = new String[] {"id", "forename"};
|
||||
|
||||
|
||||
@@ -67,6 +73,8 @@ public class JdbcClientIndexedParameterTests {
|
||||
|
||||
private ResultSet resultSet = mock();
|
||||
|
||||
private ResultSetMetaData resultSetMetaData = mock();
|
||||
|
||||
private DatabaseMetaData databaseMetaData = mock();
|
||||
|
||||
private JdbcClient client = JdbcClient.create(dataSource);
|
||||
@@ -329,4 +337,28 @@ public class JdbcClientIndexedParameterTests {
|
||||
verify(connection).close();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testUpdateAndGeneratedKeys() throws SQLException {
|
||||
given(resultSetMetaData.getColumnCount()).willReturn(1);
|
||||
given(resultSetMetaData.getColumnLabel(1)).willReturn("1");
|
||||
given(resultSet.getMetaData()).willReturn(resultSetMetaData);
|
||||
given(resultSet.next()).willReturn(true, false);
|
||||
given(resultSet.getObject(1)).willReturn(11);
|
||||
given(preparedStatement.executeUpdate()).willReturn(1);
|
||||
given(preparedStatement.getGeneratedKeys()).willReturn(resultSet);
|
||||
given(connection.prepareStatement(INSERT_GENERATE_KEYS, PreparedStatement.RETURN_GENERATED_KEYS))
|
||||
.willReturn(preparedStatement);
|
||||
|
||||
KeyHolder generatedKeyHolder = new GeneratedKeyHolder();
|
||||
int rowsAffected = client.sql(INSERT_GENERATE_KEYS).param("rod").update(generatedKeyHolder);
|
||||
|
||||
assertThat(rowsAffected).isEqualTo(1);
|
||||
assertThat(generatedKeyHolder.getKeyList()).hasSize(1);
|
||||
assertThat(generatedKeyHolder.getKey()).isEqualTo(11);
|
||||
verify(preparedStatement).setString(1, "rod");
|
||||
verify(resultSet).close();
|
||||
verify(preparedStatement).close();
|
||||
verify(connection).close();
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -20,6 +20,7 @@ import java.sql.Connection;
|
||||
import java.sql.DatabaseMetaData;
|
||||
import java.sql.PreparedStatement;
|
||||
import java.sql.ResultSet;
|
||||
import java.sql.ResultSetMetaData;
|
||||
import java.sql.SQLException;
|
||||
import java.sql.Types;
|
||||
import java.util.ArrayList;
|
||||
@@ -37,6 +38,8 @@ import org.junit.jupiter.api.Test;
|
||||
|
||||
import org.springframework.jdbc.Customer;
|
||||
import org.springframework.jdbc.core.SqlParameterValue;
|
||||
import org.springframework.jdbc.support.GeneratedKeyHolder;
|
||||
import org.springframework.jdbc.support.KeyHolder;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
import static org.mockito.ArgumentMatchers.anyString;
|
||||
@@ -62,6 +65,11 @@ public class JdbcClientNamedParameterTests {
|
||||
private static final String UPDATE_NAMED_PARAMETERS_PARSED =
|
||||
"update seat_status set booking_id = null where performance_id = ? and price_band_id = ?";
|
||||
|
||||
private static final String INSERT_GENERATE_KEYS =
|
||||
"insert into show (name) values(:name)";
|
||||
private static final String INSERT_GENERATE_KEYS_PARSED =
|
||||
"insert into show (name) values(?)";
|
||||
|
||||
private static final String[] COLUMN_NAMES = new String[] {"id", "forename"};
|
||||
|
||||
|
||||
@@ -73,6 +81,8 @@ public class JdbcClientNamedParameterTests {
|
||||
|
||||
private ResultSet resultSet = mock();
|
||||
|
||||
private ResultSetMetaData resultSetMetaData = mock();
|
||||
|
||||
private DatabaseMetaData databaseMetaData = mock();
|
||||
|
||||
private JdbcClient client = JdbcClient.create(dataSource);
|
||||
@@ -335,4 +345,28 @@ public class JdbcClientNamedParameterTests {
|
||||
verify(connection).close();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testUpdateAndGeneratedKeys() throws SQLException {
|
||||
given(resultSetMetaData.getColumnCount()).willReturn(1);
|
||||
given(resultSetMetaData.getColumnLabel(1)).willReturn("1");
|
||||
given(resultSet.getMetaData()).willReturn(resultSetMetaData);
|
||||
given(resultSet.next()).willReturn(true, false);
|
||||
given(resultSet.getObject(1)).willReturn(11);
|
||||
given(preparedStatement.executeUpdate()).willReturn(1);
|
||||
given(preparedStatement.getGeneratedKeys()).willReturn(resultSet);
|
||||
given(connection.prepareStatement(INSERT_GENERATE_KEYS_PARSED, PreparedStatement.RETURN_GENERATED_KEYS))
|
||||
.willReturn(preparedStatement);
|
||||
|
||||
KeyHolder generatedKeyHolder = new GeneratedKeyHolder();
|
||||
int rowsAffected = client.sql(INSERT_GENERATE_KEYS).param("name", "rod").update(generatedKeyHolder);
|
||||
|
||||
assertThat(rowsAffected).isEqualTo(1);
|
||||
assertThat(generatedKeyHolder.getKeyList()).hasSize(1);
|
||||
assertThat(generatedKeyHolder.getKey()).isEqualTo(11);
|
||||
verify(preparedStatement).setString(1, "rod");
|
||||
verify(resultSet).close();
|
||||
verify(preparedStatement).close();
|
||||
verify(connection).close();
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -211,9 +211,8 @@ public class SqlUpdateTests {
|
||||
given(resultSet.getObject(1)).willReturn(11);
|
||||
given(preparedStatement.executeUpdate()).willReturn(1);
|
||||
given(preparedStatement.getGeneratedKeys()).willReturn(resultSet);
|
||||
given(connection.prepareStatement(INSERT_GENERATE_KEYS,
|
||||
PreparedStatement.RETURN_GENERATED_KEYS)
|
||||
).willReturn(preparedStatement);
|
||||
given(connection.prepareStatement(INSERT_GENERATE_KEYS, PreparedStatement.RETURN_GENERATED_KEYS))
|
||||
.willReturn(preparedStatement);
|
||||
|
||||
GeneratedKeysUpdater pc = new GeneratedKeysUpdater();
|
||||
KeyHolder generatedKeyHolder = new GeneratedKeyHolder();
|
||||
@@ -294,6 +293,7 @@ public class SqlUpdateTests {
|
||||
pc::run);
|
||||
}
|
||||
|
||||
|
||||
private class Updater extends SqlUpdate {
|
||||
|
||||
public Updater() {
|
||||
|
||||
Reference in New Issue
Block a user