Use PreparedStatementCreator for query/update with indexed params

Closes gh-31122
This commit is contained in:
Juergen Hoeller
2023-09-03 00:44:11 +02:00
parent 7595465c21
commit 855fe39b7f
5 changed files with 94 additions and 16 deletions

View File

@@ -1,5 +1,5 @@
/*
* Copyright 2002-2022 the original author or authors.
* Copyright 2002-2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -46,8 +46,9 @@ public class PreparedStatementCreatorFactory {
/** The SQL, which won't change when the parameters change. */
private final String sql;
/** List of SqlParameter objects (may not be {@code null}). */
private final List<SqlParameter> declaredParameters;
/** List of SqlParameter objects (may be {@code null}). */
@Nullable
private List<SqlParameter> declaredParameters;
private int resultSetType = ResultSet.TYPE_FORWARD_ONLY;
@@ -66,7 +67,6 @@ public class PreparedStatementCreatorFactory {
*/
public PreparedStatementCreatorFactory(String sql) {
this.sql = sql;
this.declaredParameters = new ArrayList<>();
}
/**
@@ -104,6 +104,9 @@ public class PreparedStatementCreatorFactory {
* @param param the parameter to add to the list of declared parameters
*/
public void addParameter(SqlParameter param) {
if (this.declaredParameters == null) {
this.declaredParameters = new ArrayList<>();
}
this.declaredParameters.add(param);
}
@@ -180,7 +183,7 @@ public class PreparedStatementCreatorFactory {
*/
public PreparedStatementCreator newPreparedStatementCreator(String sqlToUse, @Nullable Object[] params) {
return new PreparedStatementCreatorImpl(
sqlToUse, params != null ? Arrays.asList(params) : Collections.emptyList());
sqlToUse, (params != null ? Arrays.asList(params) : Collections.emptyList()));
}
@@ -201,7 +204,7 @@ public class PreparedStatementCreatorFactory {
public PreparedStatementCreatorImpl(String actualSql, List<?> parameters) {
this.actualSql = actualSql;
this.parameters = parameters;
if (parameters.size() != declaredParameters.size()) {
if (declaredParameters != null && parameters.size() != declaredParameters.size()) {
// Account for named parameters being used multiple times
Set<String> names = new HashSet<>();
for (int i = 0; i < parameters.size(); i++) {
@@ -249,14 +252,14 @@ public class PreparedStatementCreatorFactory {
int sqlColIndx = 1;
for (int i = 0; i < this.parameters.size(); i++) {
Object in = this.parameters.get(i);
SqlParameter declaredParameter;
SqlParameter declaredParameter = null;
// SqlParameterValue overrides declared parameter meta-data, in particular for
// independence from the declared parameter position in case of named parameters.
if (in instanceof SqlParameterValue sqlParameterValue) {
in = sqlParameterValue.getValue();
declaredParameter = sqlParameterValue;
}
else {
else if (declaredParameters != null) {
if (declaredParameters.size() <= i) {
throw new InvalidDataAccessApiUsageException(
"SQL [" + sql + "]: unable to access parameter number " + (i + 1) +
@@ -265,7 +268,10 @@ public class PreparedStatementCreatorFactory {
}
declaredParameter = declaredParameters.get(i);
}
if (in instanceof Iterable<?> entries && declaredParameter.getSqlType() != Types.ARRAY) {
if (declaredParameter == null) {
StatementCreatorUtils.setParameterValue(ps, sqlColIndx++, SqlTypeValue.TYPE_UNKNOWN, in);
}
else if (in instanceof Iterable<?> entries && declaredParameter.getSqlType() != Types.ARRAY) {
for (Object entry : entries) {
if (entry instanceof Object[] valueArray) {
for (Object argValue : valueArray) {

View File

@@ -28,6 +28,8 @@ import javax.sql.DataSource;
import org.springframework.beans.BeanUtils;
import org.springframework.jdbc.core.JdbcOperations;
import org.springframework.jdbc.core.JdbcTemplate;
import org.springframework.jdbc.core.PreparedStatementCreator;
import org.springframework.jdbc.core.PreparedStatementCreatorFactory;
import org.springframework.jdbc.core.ResultSetExtractor;
import org.springframework.jdbc.core.RowCallbackHandler;
import org.springframework.jdbc.core.RowMapper;
@@ -206,7 +208,7 @@ final class DefaultJdbcClient implements JdbcClient {
namedParamOps.query(this.sql, this.namedParams, rch);
}
else {
classicOps.query(this.sql, rch, this.indexedParams.toArray());
classicOps.query(getPreparedStatementCreatorForIndexedParams(), rch);
}
}
@@ -214,7 +216,7 @@ final class DefaultJdbcClient implements JdbcClient {
public <T> T query(ResultSetExtractor<T> rse) {
T result = (useNamedParams() ?
namedParamOps.query(this.sql, this.namedParams, rse) :
classicOps.query(this.sql, rse, this.indexedParams.toArray()));
classicOps.query(getPreparedStatementCreatorForIndexedParams(), rse));
Assert.state(result != null, "No result from ResultSetExtractor");
return result;
}
@@ -223,14 +225,14 @@ final class DefaultJdbcClient implements JdbcClient {
public int update() {
return (useNamedParams() ?
namedParamOps.update(this.sql, this.namedParamSource) :
classicOps.update(this.sql, this.indexedParams.toArray()));
classicOps.update(getPreparedStatementCreatorForIndexedParams()));
}
@Override
public int update(KeyHolder generatedKeyHolder) {
return (useNamedParams() ?
namedParamOps.update(this.sql, this.namedParamSource, generatedKeyHolder) :
classicOps.update(this.sql, this.indexedParams.toArray(), generatedKeyHolder));
classicOps.update(getPreparedStatementCreatorForIndexedParams(), generatedKeyHolder));
}
private boolean useNamedParams() {
@@ -245,6 +247,10 @@ final class DefaultJdbcClient implements JdbcClient {
return hasNamedParams;
}
private PreparedStatementCreator getPreparedStatementCreatorForIndexedParams() {
return new PreparedStatementCreatorFactory(this.sql).newPreparedStatementCreator(this.indexedParams);
}
private class IndexedParamResultQuerySpec implements ResultQuerySpec {

View File

@@ -20,6 +20,7 @@ import java.sql.Connection;
import java.sql.DatabaseMetaData;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.ResultSetMetaData;
import java.sql.SQLException;
import java.sql.Types;
import java.util.ArrayList;
@@ -35,6 +36,8 @@ import org.junit.jupiter.api.Test;
import org.springframework.jdbc.Customer;
import org.springframework.jdbc.core.SqlParameterValue;
import org.springframework.jdbc.support.GeneratedKeyHolder;
import org.springframework.jdbc.support.KeyHolder;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.anyString;
@@ -56,6 +59,9 @@ public class JdbcClientIndexedParameterTests {
private static final String UPDATE_NAMED_PARAMETERS =
"update seat_status set booking_id = null where performance_id = ? and price_band_id = ?";
private static final String INSERT_GENERATE_KEYS =
"insert into show (name) values(?)";
private static final String[] COLUMN_NAMES = new String[] {"id", "forename"};
@@ -67,6 +73,8 @@ public class JdbcClientIndexedParameterTests {
private ResultSet resultSet = mock();
private ResultSetMetaData resultSetMetaData = mock();
private DatabaseMetaData databaseMetaData = mock();
private JdbcClient client = JdbcClient.create(dataSource);
@@ -329,4 +337,28 @@ public class JdbcClientIndexedParameterTests {
verify(connection).close();
}
@Test
public void testUpdateAndGeneratedKeys() throws SQLException {
given(resultSetMetaData.getColumnCount()).willReturn(1);
given(resultSetMetaData.getColumnLabel(1)).willReturn("1");
given(resultSet.getMetaData()).willReturn(resultSetMetaData);
given(resultSet.next()).willReturn(true, false);
given(resultSet.getObject(1)).willReturn(11);
given(preparedStatement.executeUpdate()).willReturn(1);
given(preparedStatement.getGeneratedKeys()).willReturn(resultSet);
given(connection.prepareStatement(INSERT_GENERATE_KEYS, PreparedStatement.RETURN_GENERATED_KEYS))
.willReturn(preparedStatement);
KeyHolder generatedKeyHolder = new GeneratedKeyHolder();
int rowsAffected = client.sql(INSERT_GENERATE_KEYS).param("rod").update(generatedKeyHolder);
assertThat(rowsAffected).isEqualTo(1);
assertThat(generatedKeyHolder.getKeyList()).hasSize(1);
assertThat(generatedKeyHolder.getKey()).isEqualTo(11);
verify(preparedStatement).setString(1, "rod");
verify(resultSet).close();
verify(preparedStatement).close();
verify(connection).close();
}
}

View File

@@ -20,6 +20,7 @@ import java.sql.Connection;
import java.sql.DatabaseMetaData;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.ResultSetMetaData;
import java.sql.SQLException;
import java.sql.Types;
import java.util.ArrayList;
@@ -37,6 +38,8 @@ import org.junit.jupiter.api.Test;
import org.springframework.jdbc.Customer;
import org.springframework.jdbc.core.SqlParameterValue;
import org.springframework.jdbc.support.GeneratedKeyHolder;
import org.springframework.jdbc.support.KeyHolder;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.anyString;
@@ -62,6 +65,11 @@ public class JdbcClientNamedParameterTests {
private static final String UPDATE_NAMED_PARAMETERS_PARSED =
"update seat_status set booking_id = null where performance_id = ? and price_band_id = ?";
private static final String INSERT_GENERATE_KEYS =
"insert into show (name) values(:name)";
private static final String INSERT_GENERATE_KEYS_PARSED =
"insert into show (name) values(?)";
private static final String[] COLUMN_NAMES = new String[] {"id", "forename"};
@@ -73,6 +81,8 @@ public class JdbcClientNamedParameterTests {
private ResultSet resultSet = mock();
private ResultSetMetaData resultSetMetaData = mock();
private DatabaseMetaData databaseMetaData = mock();
private JdbcClient client = JdbcClient.create(dataSource);
@@ -335,4 +345,28 @@ public class JdbcClientNamedParameterTests {
verify(connection).close();
}
@Test
public void testUpdateAndGeneratedKeys() throws SQLException {
given(resultSetMetaData.getColumnCount()).willReturn(1);
given(resultSetMetaData.getColumnLabel(1)).willReturn("1");
given(resultSet.getMetaData()).willReturn(resultSetMetaData);
given(resultSet.next()).willReturn(true, false);
given(resultSet.getObject(1)).willReturn(11);
given(preparedStatement.executeUpdate()).willReturn(1);
given(preparedStatement.getGeneratedKeys()).willReturn(resultSet);
given(connection.prepareStatement(INSERT_GENERATE_KEYS_PARSED, PreparedStatement.RETURN_GENERATED_KEYS))
.willReturn(preparedStatement);
KeyHolder generatedKeyHolder = new GeneratedKeyHolder();
int rowsAffected = client.sql(INSERT_GENERATE_KEYS).param("name", "rod").update(generatedKeyHolder);
assertThat(rowsAffected).isEqualTo(1);
assertThat(generatedKeyHolder.getKeyList()).hasSize(1);
assertThat(generatedKeyHolder.getKey()).isEqualTo(11);
verify(preparedStatement).setString(1, "rod");
verify(resultSet).close();
verify(preparedStatement).close();
verify(connection).close();
}
}

View File

@@ -211,9 +211,8 @@ public class SqlUpdateTests {
given(resultSet.getObject(1)).willReturn(11);
given(preparedStatement.executeUpdate()).willReturn(1);
given(preparedStatement.getGeneratedKeys()).willReturn(resultSet);
given(connection.prepareStatement(INSERT_GENERATE_KEYS,
PreparedStatement.RETURN_GENERATED_KEYS)
).willReturn(preparedStatement);
given(connection.prepareStatement(INSERT_GENERATE_KEYS, PreparedStatement.RETURN_GENERATED_KEYS))
.willReturn(preparedStatement);
GeneratedKeysUpdater pc = new GeneratedKeysUpdater();
KeyHolder generatedKeyHolder = new GeneratedKeyHolder();
@@ -294,6 +293,7 @@ public class SqlUpdateTests {
pc::run);
}
private class Updater extends SqlUpdate {
public Updater() {