diff --git a/src/main/java/org/springframework/data/r2dbc/core/R2dbcEntityTemplate.java b/src/main/java/org/springframework/data/r2dbc/core/R2dbcEntityTemplate.java index aa6c5b3..7893dc0 100644 --- a/src/main/java/org/springframework/data/r2dbc/core/R2dbcEntityTemplate.java +++ b/src/main/java/org/springframework/data/r2dbc/core/R2dbcEntityTemplate.java @@ -327,7 +327,7 @@ public class R2dbcEntityTemplate implements R2dbcEntityOperations, BeanFactoryAw Expression countExpression = entity.hasIdProperty() ? table.column(entity.getRequiredIdProperty().getColumnName()) - : Expressions.asterisk(); + : Expressions.just("1"); return spec.withProjection(Functions.count(countExpression)); }); @@ -362,13 +362,14 @@ public class R2dbcEntityTemplate implements R2dbcEntityOperations, BeanFactoryAw RelationalPersistentEntity entity = getRequiredEntity(entityClass); StatementMapper statementMapper = dataAccessStrategy.getStatementMapper().forType(entityClass); - SqlIdentifier columnName = entity.hasIdProperty() ? entity.getRequiredIdProperty().getColumnName() - : SqlIdentifier.unquoted("*"); + StatementMapper.SelectSpec selectSpec = statementMapper.createSelect(tableName).limit(1); + if (entity.hasIdProperty()) { + selectSpec = selectSpec // + .withProjection(entity.getRequiredIdProperty().getColumnName()); - StatementMapper.SelectSpec selectSpec = statementMapper // - .createSelect(tableName) // - .withProjection(columnName) // - .limit(1); + } else { + selectSpec = selectSpec.withProjection(Expressions.just("1")); + } Optional criteria = query.getCriteria(); if (criteria.isPresent()) { diff --git a/src/main/java/org/springframework/data/r2dbc/query/QueryMapper.java b/src/main/java/org/springframework/data/r2dbc/query/QueryMapper.java index 1cbd8c4..2bde0d7 100644 --- a/src/main/java/org/springframework/data/r2dbc/query/QueryMapper.java +++ b/src/main/java/org/springframework/data/r2dbc/query/QueryMapper.java @@ -153,7 +153,8 @@ public class QueryMapper { */ public Expression getMappedObject(Expression expression, @Nullable RelationalPersistentEntity entity) { - if (entity == null || expression instanceof AsteriskFromTable) { + if (entity == null || expression instanceof AsteriskFromTable + || expression instanceof Expressions.SimpleExpression) { return expression; } diff --git a/src/main/java/org/springframework/data/r2dbc/repository/query/R2dbcQueryCreator.java b/src/main/java/org/springframework/data/r2dbc/repository/query/R2dbcQueryCreator.java index b147e89..a49f1b9 100644 --- a/src/main/java/org/springframework/data/r2dbc/repository/query/R2dbcQueryCreator.java +++ b/src/main/java/org/springframework/data/r2dbc/repository/query/R2dbcQueryCreator.java @@ -154,26 +154,22 @@ class R2dbcQueryCreator extends RelationalQueryCreator> { for (String projectedProperty : projectedProperties) { RelationalPersistentProperty property = entity.getPersistentProperty(projectedProperty); - Column column = table.column(property != null ? property.getColumnName() : SqlIdentifier.unquoted(projectedProperty)); + Column column = table + .column(property != null ? property.getColumnName() : SqlIdentifier.unquoted(projectedProperty)); expressions.add(column); } - } else if (tree.isExistsProjection()) { - - expressions = dataAccessStrategy.getIdentifierColumns(entityToRead).stream() - .map(table::column) - .collect(Collectors.toList()); - } else if (tree.isCountProjection()) { + } else if (tree.isExistsProjection() || tree.isCountProjection()) { Expression countExpression = entityMetadata.getTableEntity().hasIdProperty() ? table.column(entityMetadata.getTableEntity().getRequiredIdProperty().getColumnName()) - : Expressions.asterisk(); + : Expressions.just("1"); - expressions = Collections.singletonList(Functions.count(countExpression)); + expressions = Collections + .singletonList(tree.isCountProjection() ? Functions.count(countExpression) : countExpression); } else { - expressions = dataAccessStrategy.getAllColumns(entityToRead).stream() - .map(table::column) - .collect(Collectors.toList()); + expressions = dataAccessStrategy.getAllColumns(entityToRead).stream().map(table::column) + .collect(Collectors.toList()); } return expressions.toArray(new Expression[0]); diff --git a/src/test/java/org/springframework/data/r2dbc/core/R2dbcEntityTemplateUnitTests.java b/src/test/java/org/springframework/data/r2dbc/core/R2dbcEntityTemplateUnitTests.java index 849d912..486798b 100644 --- a/src/test/java/org/springframework/data/r2dbc/core/R2dbcEntityTemplateUnitTests.java +++ b/src/test/java/org/springframework/data/r2dbc/core/R2dbcEntityTemplateUnitTests.java @@ -90,8 +90,7 @@ public class R2dbcEntityTemplateUnitTests { MockRowMetadata metadata = MockRowMetadata.builder() .columnMetadata(MockColumnMetadata.builder().name("name").type(R2dbcType.VARCHAR).build()).build(); - MockResult result = MockResult.builder() - .row(MockRow.builder().identified(0, Long.class, 1L).build()).build(); + MockResult result = MockResult.builder().row(MockRow.builder().identified(0, Long.class, 1L).build()).build(); recorder.addStubbing(s -> s.startsWith("SELECT"), result); @@ -109,10 +108,7 @@ public class R2dbcEntityTemplateUnitTests { @Test // gh-469 void shouldProjectExistsResult() { - MockRowMetadata metadata = MockRowMetadata.builder() - .columnMetadata(MockColumnMetadata.builder().name("name").type(R2dbcType.VARCHAR).build()).build(); - MockResult result = MockResult.builder() - .row(MockRow.builder().identified(0, Object.class, null).build()).build(); + MockResult result = MockResult.builder().row(MockRow.builder().identified(0, Object.class, null).build()).build(); recorder.addStubbing(s -> s.startsWith("SELECT"), result); @@ -124,13 +120,36 @@ public class R2dbcEntityTemplateUnitTests { .verifyComplete(); } + @Test // gh-773 + void shouldProjectExistsResultWithoutId() { + + MockResult result = MockResult.builder().row(MockRow.builder().identified(0, Object.class, null).build()).build(); + + recorder.addStubbing(s -> s.startsWith("SELECT 1"), result); + + entityTemplate.select(WithoutId.class).exists() // + .as(StepVerifier::create) // + .expectNext(true).verifyComplete(); + } + + @Test // gh-773 + void shouldProjectCountResultWithoutId() { + + MockResult result = MockResult.builder().row(MockRow.builder().identified(0, Long.class, 1L).build()).build(); + + recorder.addStubbing(s -> s.startsWith("SELECT COUNT(1)"), result); + + entityTemplate.select(WithoutId.class).count() // + .as(StepVerifier::create) // + .expectNext(1L).verifyComplete(); + } + @Test // gh-469 void shouldExistsByCriteria() { MockRowMetadata metadata = MockRowMetadata.builder() .columnMetadata(MockColumnMetadata.builder().name("name").type(R2dbcType.VARCHAR).build()).build(); - MockResult result = MockResult.builder() - .row(MockRow.builder().identified(0, Long.class, 1L).build()).build(); + MockResult result = MockResult.builder().row(MockRow.builder().identified(0, Long.class, 1L).build()).build(); recorder.addStubbing(s -> s.startsWith("SELECT"), result); @@ -480,6 +499,12 @@ public class R2dbcEntityTemplateUnitTests { Parameter.from("before-save")); } + @Value + static class WithoutId { + + String name; + } + @Value @With static class Person { diff --git a/src/test/java/org/springframework/data/r2dbc/repository/query/PartTreeR2dbcQueryUnitTests.java b/src/test/java/org/springframework/data/r2dbc/repository/query/PartTreeR2dbcQueryUnitTests.java index 602f4a6..ea57a4a 100644 --- a/src/test/java/org/springframework/data/r2dbc/repository/query/PartTreeR2dbcQueryUnitTests.java +++ b/src/test/java/org/springframework/data/r2dbc/repository/query/PartTreeR2dbcQueryUnitTests.java @@ -77,8 +77,6 @@ class PartTreeR2dbcQueryUnitTests { ".age", ".active" }; private static final String[] ALL_FIELDS_ARRAY_PREFIXED = Arrays.stream(ALL_FIELDS_ARRAY).map(f -> TABLE + f) .toArray(String[]::new); - private static final String ALL_FIELDS = String.join(", ", ALL_FIELDS_ARRAY_PREFIXED); - private static final String DISTINCT = "DISTINCT"; @Mock ConnectionFactory connectionFactory; @Mock R2dbcConverter r2dbcConverter; @@ -698,6 +696,32 @@ class PartTreeR2dbcQueryUnitTests { .where(TABLE + ".first_name = $1"); } + @Test // GH-773 + void createsQueryWithoutIdForCountProjection() throws Exception { + + R2dbcQueryMethod queryMethod = getQueryMethod(WithoutIdRepository.class, "countByFirstName", String.class); + PartTreeR2dbcQuery r2dbcQuery = new PartTreeR2dbcQuery(queryMethod, operations, r2dbcConverter, dataAccessStrategy); + PreparedOperation query = createQuery(queryMethod, r2dbcQuery, "John"); + + PreparedOperationAssert.assertThat(query) // + .selects("COUNT(1)") // + .from(TABLE) // + .where(TABLE + ".first_name = $1"); + } + + @Test // GH-773 + void createsQueryWithoutIdForExistsProjection() throws Exception { + + R2dbcQueryMethod queryMethod = getQueryMethod(WithoutIdRepository.class, "existsByFirstName", String.class); + PartTreeR2dbcQuery r2dbcQuery = new PartTreeR2dbcQuery(queryMethod, operations, r2dbcConverter, dataAccessStrategy); + PreparedOperation query = createQuery(queryMethod, r2dbcQuery, "John"); + + PreparedOperationAssert.assertThat(query) // + .selects("1") // + .from(TABLE) // + .where(TABLE + ".first_name = $1 LIMIT 1"); + } + private PreparedOperation createQuery(R2dbcQueryMethod queryMethod, PartTreeR2dbcQuery r2dbcQuery, Object... parameters) { return createQuery(r2dbcQuery, getAccessor(queryMethod, parameters)); @@ -709,8 +733,13 @@ class PartTreeR2dbcQueryUnitTests { } private R2dbcQueryMethod getQueryMethod(String methodName, Class... parameterTypes) throws Exception { - Method method = UserRepository.class.getMethod(methodName, parameterTypes); - return new R2dbcQueryMethod(method, new DefaultRepositoryMetadata(UserRepository.class), + return getQueryMethod(UserRepository.class, methodName, parameterTypes); + } + + private R2dbcQueryMethod getQueryMethod(Class repository, String methodName, Class... parameterTypes) + throws Exception { + Method method = repository.getMethod(methodName, parameterTypes); + return new R2dbcQueryMethod(method, new DefaultRepositoryMetadata(repository), new SpelAwareProxyProjectionFactory(), mappingContext); } @@ -887,6 +916,13 @@ class PartTreeR2dbcQueryUnitTests { Mono countByFirstName(String firstName); } + interface WithoutIdRepository extends Repository { + + Mono existsByFirstName(String firstName); + + Mono countByFirstName(String firstName); + } + @Table("users") @Data private static class User { @@ -899,6 +935,13 @@ class PartTreeR2dbcQueryUnitTests { private Boolean active; } + @Table("users") + @Data + private static class WithoutId { + + private String firstName; + } + interface UserProjection { String getFirstName();