From a611c5fc335c7fec82830123719120f03469bcc3 Mon Sep 17 00:00:00 2001 From: Jens Schauder Date: Mon, 23 Dec 2024 14:09:53 +0100 Subject: [PATCH] `JdbcAggregateTemplate` honors columns specified in query. If no columns are given, all columns are selected by default. If columns are specified, only these are selected. Joins normally triggered by columns from 1:1 relationships are not implemented, and the corresponding columns don't get loaded and can't be specified in a query. Limiting columns is not supported for single query loading. Closes #1803 Original pull request: #1967 --- .../data/jdbc/core/convert/SqlGenerator.java | 112 +++++++++++++----- ...JdbcAggregateTemplateIntegrationTests.java | 26 +++- .../core/convert/SqlGeneratorUnitTests.java | 45 ++++++- .../data/relational/core/query/Query.java | 3 +- 4 files changed, 154 insertions(+), 32 deletions(-) diff --git a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/SqlGenerator.java b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/SqlGenerator.java index 7ac637e8..064ab548 100644 --- a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/SqlGenerator.java +++ b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/SqlGenerator.java @@ -39,6 +39,7 @@ import org.springframework.data.util.Lazy; import org.springframework.jdbc.core.namedparam.MapSqlParameterSource; import org.springframework.lang.Nullable; import org.springframework.util.Assert; +import org.springframework.util.CollectionUtils; /** * Generates SQL statements to be used by {@link SimpleJdbcRepository} @@ -507,47 +508,87 @@ class SqlGenerator { } private SelectBuilder.SelectWhere selectBuilder() { - return selectBuilder(Collections.emptyList()); + return selectBuilder(Collections.emptyList(), Query.empty()); + } + + private SelectBuilder.SelectWhere selectBuilder(Query query) { + return selectBuilder(Collections.emptyList(), query); } private SelectBuilder.SelectWhere selectBuilder(Collection keyColumns) { + return selectBuilder(keyColumns, Query.empty()); + } + + private SelectBuilder.SelectWhere selectBuilder(Collection keyColumns, Query query) { Table table = getTable(); - Set columnExpressions = new LinkedHashSet<>(); - - List joinTables = new ArrayList<>(); - for (PersistentPropertyPath path : mappingContext - .findPersistentPropertyPaths(entity.getType(), p -> true)) { - - AggregatePath extPath = mappingContext.getAggregatePath(path); - - // add a join if necessary - Join join = getJoin(extPath); - if (join != null) { - joinTables.add(join); - } - - Column column = getColumn(extPath); - if (column != null) { - columnExpressions.add(column); - } - } - - for (SqlIdentifier keyColumn : keyColumns) { - columnExpressions.add(table.column(keyColumn).as(keyColumn)); - } - - SelectBuilder.SelectAndFrom selectBuilder = StatementBuilder.select(columnExpressions); + Projection projection = getProjection(keyColumns, query, table); + SelectBuilder.SelectAndFrom selectBuilder = StatementBuilder.select(projection.columns()); SelectBuilder.SelectJoin baseSelect = selectBuilder.from(table); - for (Join join : joinTables) { + for (Join join : projection.joins()) { baseSelect = baseSelect.leftOuterJoin(join.joinTable).on(join.joinColumn).equals(join.parentId); } return (SelectBuilder.SelectWhere) baseSelect; } + private Projection getProjection(Collection keyColumns, Query query, Table table) { + + Set columns = new LinkedHashSet<>(); + Set joins = new LinkedHashSet<>(); + + if (!CollectionUtils.isEmpty(query.getColumns())) { + for (SqlIdentifier columnName : query.getColumns()) { + + String columnNameString = columnName.getReference(); + RelationalPersistentProperty property = entity.getPersistentProperty(columnNameString); + if (property != null) { + + AggregatePath aggregatePath = mappingContext.getAggregatePath( + mappingContext.getPersistentPropertyPath(columnNameString, entity.getTypeInformation())); + gatherColumn(aggregatePath, joins, columns); + } else { + columns.add(Column.create(columnName, table)); + } + } + } else { + for (PersistentPropertyPath path : mappingContext + .findPersistentPropertyPaths(entity.getType(), p -> true)) { + + AggregatePath aggregatePath = mappingContext.getAggregatePath(path); + + gatherColumn(aggregatePath, joins, columns); + } + } + + for (SqlIdentifier keyColumn : keyColumns) { + columns.add(table.column(keyColumn).as(keyColumn)); + } + + return new Projection(columns, joins); + } + + private void gatherColumn(AggregatePath aggregatePath, Set joins, Set columns) { + + joins.addAll(getJoins(aggregatePath)); + + Column column = getColumn(aggregatePath); + if (column != null) { + columns.add(column); + } + } + + /** + * Projection including its source joins. + * + * @param columns + * @param joins + */ + record Projection(Set columns, Set joins) { + } + private SelectBuilder.SelectOrdered selectBuilder(Collection keyColumns, Sort sort, Pageable pageable) { @@ -611,9 +652,24 @@ class SqlGenerator { return sqlContext.getColumn(path); } + List getJoins(AggregatePath path) { + + List joins = new ArrayList<>(); + while (!path.isRoot()) { + Join join = getJoin(path); + if (join != null) { + joins.add(join); + } + + path = path.getParentPath(); + } + return joins; + } + @Nullable Join getJoin(AggregatePath path) { + // TODO: This doesn't handle paths with length > 1 correctly if (!path.isEntity() || path.isEmbedded() || path.isMultiValued()) { return null; } @@ -876,7 +932,7 @@ class SqlGenerator { Assert.notNull(parameterSource, "parameterSource must not be null"); - SelectBuilder.SelectWhere selectBuilder = selectBuilder(); + SelectBuilder.SelectWhere selectBuilder = selectBuilder(query); Select select = applyQueryOnSelect(query, parameterSource, selectBuilder) // .build(); diff --git a/spring-data-jdbc/src/test/java/org/springframework/data/jdbc/core/AbstractJdbcAggregateTemplateIntegrationTests.java b/spring-data-jdbc/src/test/java/org/springframework/data/jdbc/core/AbstractJdbcAggregateTemplateIntegrationTests.java index 1cd62f19..28f7f1ed 100644 --- a/spring-data-jdbc/src/test/java/org/springframework/data/jdbc/core/AbstractJdbcAggregateTemplateIntegrationTests.java +++ b/spring-data-jdbc/src/test/java/org/springframework/data/jdbc/core/AbstractJdbcAggregateTemplateIntegrationTests.java @@ -29,6 +29,7 @@ import java.util.function.Function; import java.util.stream.IntStream; import java.util.stream.Stream; +import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.ApplicationContext; @@ -236,7 +237,25 @@ abstract class AbstractJdbcAggregateTemplateIntegrationTests { Query query = Query.query(criteria); Iterable reloadedById = template.findAll(query, SimpleListParent.class); - assertThat(reloadedById).extracting(e -> e.id, e -> e.content.size()).containsExactly(tuple(two.id, 2)); + assertThat(reloadedById) // + .extracting(e -> e.id, e-> e.name, e -> e.content.size()) // + .containsExactly(tuple(two.id, two.name, 2)); + } + + @Test // GH-1803 + void findAllByQueryWithColumns() { + + template.save(SimpleListParent.of("one", "one_1")); + SimpleListParent two = template.save(SimpleListParent.of("two", "two_1", "two_2")); + template.save(SimpleListParent.of("three", "three_1", "three_2", "three_3")); + + CriteriaDefinition criteria = CriteriaDefinition.from(Criteria.where("id").is(two.id)); + Query query = Query.query(criteria).columns("id"); + Iterable reloadedById = template.findAll(query, SimpleListParent.class); + + assertThat(reloadedById) // + .extracting(e -> e.id, e-> e.name, e -> e.content.size()) // + .containsExactly(tuple(two.id, null, 2)); } @Test // GH-1601 @@ -2335,5 +2354,10 @@ abstract class AbstractJdbcAggregateTemplateIntegrationTests { static class JdbcAggregateTemplateSingleQueryLoadingIntegrationTests extends AbstractJdbcAggregateTemplateIntegrationTests { + @Disabled + @Override + void findAllByQueryWithColumns() { + super.findAllByQueryWithColumns(); + } } } diff --git a/spring-data-jdbc/src/test/java/org/springframework/data/jdbc/core/convert/SqlGeneratorUnitTests.java b/spring-data-jdbc/src/test/java/org/springframework/data/jdbc/core/convert/SqlGeneratorUnitTests.java index cc264cbe..51637b72 100644 --- a/spring-data-jdbc/src/test/java/org/springframework/data/jdbc/core/convert/SqlGeneratorUnitTests.java +++ b/spring-data-jdbc/src/test/java/org/springframework/data/jdbc/core/convert/SqlGeneratorUnitTests.java @@ -351,12 +351,13 @@ class SqlGeneratorUnitTests { @Test // GH-1919 void selectByQuery() { - Query query = Query.query(Criteria.where("id").is(23L)); + Query query = Query.query(Criteria.where("id").is(23L)).columns(new String[0]); String sql = sqlGenerator.selectByQuery(query, new MapSqlParameterSource()); assertThat(sql).contains( // "SELECT", // + "dummy_entity.id1 AS id1, dummy_entity.x_name AS x_name", // "FROM dummy_entity", // "LEFT OUTER JOIN referenced_entity ref ON ref.dummy_entity = dummy_entity.id1", // "LEFT OUTER JOIN second_level_referenced_entity ref_further ON ref_further.referenced_entity = ref.x_l1id", // @@ -364,6 +365,45 @@ class SqlGeneratorUnitTests { ); } + @Test // GH-1803 + void selectByQueryWithColumnLimit() { + + Query query = Query.empty().columns("id", "alpha", "beta", "gamma"); + + String sql = sqlGenerator.selectByQuery(query, new MapSqlParameterSource()); + + assertThat(sql).contains( // + "SELECT dummy_entity.id1 AS id1, dummy_entity.alpha, dummy_entity.beta, dummy_entity.gamma", // + "FROM dummy_entity" // + ); + } + + @Test // GH-1803 + void selectingSetContentSelectsAllColumns() { + + Query query = Query.empty().columns("elements.content"); + + String sql = sqlGenerator.selectByQuery(query, new MapSqlParameterSource()); + + assertThat(sql).contains( // + "SELECT dummy_entity.id1 AS id1, dummy_entity.x_name AS x_name"// + ); + } + + @Test // GH-1803 + void selectByQueryWithMappedColumnPathsRendersCorrectSelection() { + + Query query = Query.empty().columns("ref.content"); + + String sql = sqlGenerator.selectByQuery(query, new MapSqlParameterSource()); + + assertThat(sql).contains( // + "SELECT", // + "ref.id1 AS id1, ref.content AS x_content", // + "FROM dummy_entity", // + "LEFT OUTER JOIN referenced_entity ref ON ref.dummy_entity = dummy_entity.id1"); + } + @Test // GH-1919 void selectBySortedQuery() { @@ -381,7 +421,8 @@ class SqlGeneratorUnitTests { "ORDER BY dummy_entity.id1 ASC" // ); assertThat(sql).containsOnlyOnce("LEFT OUTER JOIN referenced_entity ref ON ref.dummy_entity = dummy_entity.id1"); - assertThat(sql).containsOnlyOnce("LEFT OUTER JOIN second_level_referenced_entity ref_further ON ref_further.referenced_entity = ref.x_l1id"); + assertThat(sql).containsOnlyOnce( + "LEFT OUTER JOIN second_level_referenced_entity ref_further ON ref_further.referenced_entity = ref.x_l1id"); } @Test // DATAJDBC-131, DATAJDBC-111 diff --git a/spring-data-relational/src/main/java/org/springframework/data/relational/core/query/Query.java b/spring-data-relational/src/main/java/org/springframework/data/relational/core/query/Query.java index 3a8e9d72..6d1ed69d 100644 --- a/spring-data-relational/src/main/java/org/springframework/data/relational/core/query/Query.java +++ b/spring-data-relational/src/main/java/org/springframework/data/relational/core/query/Query.java @@ -41,6 +41,7 @@ import org.springframework.util.Assert; */ public class Query { + private static final Query EMPTY = new Query(null); private static final int NO_LIMIT = -1; private final @Nullable CriteriaDefinition criteria; @@ -84,7 +85,7 @@ public class Query { * @return */ public static Query empty() { - return new Query(null); + return EMPTY; } /**