Use LocalVariableNameFactory in repository contributor.

Closes #3875
This commit is contained in:
Mark Paluch
2025-05-08 16:27:07 +02:00
parent 65214fb0b1
commit e1cec313f3
4 changed files with 96 additions and 77 deletions

View File

@@ -81,7 +81,7 @@ class JpaCodeBlocks {
private final AotQueryMethodGenerationContext context;
private final JpaQueryMethod queryMethod;
private String queryVariableName = "query";
private String queryVariableName;
private @Nullable AotQueries queries;
private MergedAnnotation<QueryHints> queryHints = MergedAnnotation.missing();
private @Nullable AotEntityGraph entityGraph;
@@ -92,11 +92,12 @@ class JpaCodeBlocks {
private QueryBlockBuilder(AotQueryMethodGenerationContext context, JpaQueryMethod queryMethod) {
this.context = context;
this.queryMethod = queryMethod;
this.queryVariableName = context.localVariable("query");
}
public QueryBlockBuilder usingQueryVariableName(String queryVariableName) {
this.queryVariableName = queryVariableName;
this.queryVariableName = context.localVariable(queryVariableName);
return this;
}
@@ -153,14 +154,13 @@ class JpaCodeBlocks {
}
CodeBlock.Builder builder = CodeBlock.builder();
builder.add("\n");
String queryStringVariableName = null;
String queryRewriterName = null;
if (queries.result() instanceof StringAotQuery && queryRewriter != QueryRewriter.IdentityQueryRewriter.class) {
queryRewriterName = "queryRewriter";
queryRewriterName = context.localVariable("queryRewriter");
builder.addStatement("$T $L = new $T()", queryRewriter, queryRewriterName, queryRewriter);
}
@@ -171,11 +171,13 @@ class JpaCodeBlocks {
}
String countQueryStringNameVariableName = null;
String countQueryVariableName = "count%s".formatted(StringUtils.capitalize(queryVariableName));
String countQueryVariableName = context
.localVariable("count%s".formatted(StringUtils.capitalize(queryVariableName)));
if (queryMethod.isPageQuery() && queries.count() instanceof StringAotQuery sq) {
countQueryStringNameVariableName = "count%sString".formatted(StringUtils.capitalize(queryVariableName));
countQueryStringNameVariableName = context
.localVariable("count%sString".formatted(StringUtils.capitalize(queryVariableName)));
builder.add(buildQueryString(sq, countQueryStringNameVariableName));
}
@@ -201,7 +203,7 @@ class JpaCodeBlocks {
if (queryMethod.isPageQuery()) {
builder.beginControlFlow("$T $L = () ->", LongSupplier.class, "countAll");
builder.beginControlFlow("$T $L = () ->", LongSupplier.class, context.localVariable("countAll"));
boolean queryHints = this.queryHints.isPresent() && this.queryHints.getBoolean("forCounting");
@@ -235,17 +237,21 @@ class JpaCodeBlocks {
builder.beginControlFlow("if ($L.isSorted())", sort);
}
builder.addStatement("$T declaredQuery = $T.$L($L)", DeclaredQuery.class, DeclaredQuery.class,
builder.addStatement("$T $L = $T.$L($L)", DeclaredQuery.class, context.localVariable("declaredQuery"),
DeclaredQuery.class,
queries != null && queries.isNative() ? "nativeQuery" : "jpqlQuery", queryString);
boolean hasDynamicReturnType = StringUtils.hasText(dynamicReturnType);
if (hasSort && hasDynamicReturnType) {
builder.addStatement("$L = rewriteQuery(declaredQuery, $L, $L)", queryString, sort, dynamicReturnType);
builder.addStatement("$L = rewriteQuery($L, $L, $L)", queryString, context.localVariable("declaredQuery"), sort,
dynamicReturnType);
} else if (hasSort) {
builder.addStatement("$L = rewriteQuery(declaredQuery, $L, $T.class)", queryString, sort, actualReturnType);
builder.addStatement("$L = rewriteQuery($L, $L, $T.class)", queryString, context.localVariable("declaredQuery"),
sort, actualReturnType);
} else if (hasDynamicReturnType) {
builder.addStatement("$L = rewriteQuery(declaredQuery, $T.unsorted(), $L)", queryString, Sort.class,
builder.addStatement("$L = rewriteQuery($L, $T.unsorted(), $L)", context.localVariable("declaredQuery"),
queryString, Sort.class,
dynamicReturnType);
}
@@ -470,19 +476,21 @@ class JpaCodeBlocks {
if (StringUtils.hasText(entityGraph.name())) {
builder.addStatement("$T<?> entityGraph = $L.getEntityGraph($S)", jakarta.persistence.EntityGraph.class,
builder.addStatement("$T<?> $L = $L.getEntityGraph($S)", jakarta.persistence.EntityGraph.class,
context.localVariable("entityGraph"),
context.fieldNameOf(EntityManager.class), entityGraph.name());
} else {
builder.addStatement("$T<$T> entityGraph = $L.createEntityGraph($T.class)",
builder.addStatement("$T<$T> $L = $L.createEntityGraph($T.class)",
jakarta.persistence.EntityGraph.class, context.getActualReturnType().getType(),
context.localVariable("entityGraph"),
context.fieldNameOf(EntityManager.class), context.getActualReturnType().getType());
for (String attributePath : entityGraph.attributePaths()) {
String[] pathComponents = StringUtils.delimitedListToStringArray(attributePath, ".");
StringBuilder chain = new StringBuilder("entityGraph");
StringBuilder chain = new StringBuilder(context.localVariable("entityGraph"));
for (int i = 0; i < pathComponents.length; i++) {
if (i < pathComponents.length - 1) {
@@ -495,7 +503,8 @@ class JpaCodeBlocks {
builder.addStatement(chain.toString(), (Object[]) pathComponents);
}
builder.addStatement("$L.setHint($S, entityGraph)", queryVariableName, entityGraph.type().getKey());
builder.addStatement("$L.setHint($S, $L)", queryVariableName, entityGraph.type().getKey(),
context.localVariable("entityGraph"));
}
return builder.build();
@@ -521,17 +530,19 @@ class JpaCodeBlocks {
private final AotQueryMethodGenerationContext context;
private final JpaQueryMethod queryMethod;
private @Nullable AotQuery aotQuery;
private String queryVariableName = "query";
private String queryVariableName;
private MergedAnnotation<Modifying> modifying = MergedAnnotation.missing();
private QueryExecutionBlockBuilder(AotQueryMethodGenerationContext context, JpaQueryMethod queryMethod) {
this.context = context;
this.queryMethod = queryMethod;
this.queryVariableName = context.localVariable("query");
}
public QueryExecutionBlockBuilder referencing(String queryVariableName) {
this.queryVariableName = queryVariableName;
this.queryVariableName = context.localVariable(queryVariableName);
return this;
}
@@ -567,7 +578,7 @@ class JpaCodeBlocks {
Class<?> returnType = context.getMethod().getReturnType();
if (returnsModifying(returnType)) {
builder.addStatement("int result = $L.executeUpdate()", queryVariableName);
builder.addStatement("int $L = $L.executeUpdate()", context.localVariable("result"), queryVariableName);
} else {
builder.addStatement("$L.executeUpdate()", queryVariableName);
}
@@ -577,11 +588,11 @@ class JpaCodeBlocks {
}
if (returnType == int.class || returnType == long.class || returnType == Integer.class) {
builder.addStatement("return result");
builder.addStatement("return $L", context.localVariable("result"));
}
if (returnType == Long.class) {
builder.addStatement("return (long) result");
builder.addStatement("return (long) $L", context.localVariable("result"));
}
return builder.build();
@@ -589,16 +600,20 @@ class JpaCodeBlocks {
if (aotQuery != null && aotQuery.isDelete()) {
builder.addStatement("$T<$T> resultList = $L.getResultList()", List.class, actualReturnType, queryVariableName);
builder.addStatement("resultList.forEach($L::remove)", context.fieldNameOf(EntityManager.class));
builder.addStatement("$T<$T> $L = $L.getResultList()", List.class, actualReturnType,
context.localVariable("resultList"), queryVariableName);
builder.addStatement("$L.forEach($L::remove)", context.localVariable("resultList"),
context.fieldNameOf(EntityManager.class));
if (!context.getReturnType().isAssignableFrom(List.class)) {
if (ClassUtils.isAssignable(Number.class, context.getMethod().getReturnType())) {
builder.addStatement("return $T.valueOf(resultList.size())", context.getMethod().getReturnType());
builder.addStatement("return $T.valueOf($L.size())", context.getMethod().getReturnType(),
context.localVariable("resultList"));
} else {
builder.addStatement("return resultList.isEmpty() ? null : resultList.iterator().next()");
builder.addStatement("return $L.isEmpty() ? null : $L.iterator().next()",
context.localVariable("resultList"), context.localVariable("resultList"));
}
} else {
builder.addStatement("return resultList");
builder.addStatement("return $L", context.localVariable("resultList"));
}
} else if (aotQuery != null && aotQuery.isExists()) {
builder.addStatement("return !$L.getResultList().isEmpty()", queryVariableName);
@@ -609,25 +624,29 @@ class JpaCodeBlocks {
TypeName queryResultType = TypeName.get(context.getActualReturnType().toClass());
if (queryMethod.isCollectionQuery()) {
builder.addStatement("return ($T) convertMany(query.getResultList(), $L, $T.class)",
context.getReturnTypeName(), aotQuery.isNative(), queryResultType);
builder.addStatement("return ($T) convertMany($L.getResultList(), $L, $T.class)",
context.getReturnTypeName(), queryVariableName, aotQuery.isNative(), queryResultType);
} else if (queryMethod.isStreamQuery()) {
builder.addStatement("return ($T) convertMany(query.getResultStream(), $L, $T.class)",
context.getReturnTypeName(), aotQuery.isNative(), queryResultType);
builder.addStatement("return ($T) convertMany($L.getResultStream(), $L, $T.class)",
context.getReturnTypeName(), queryVariableName, aotQuery.isNative(), queryResultType);
} else if (queryMethod.isPageQuery()) {
builder.addStatement(
"return $T.getPage(($T<$T>) convertMany($L.getResultList(), $L, $T.class), $L, countAll)",
"return $T.getPage(($T<$T>) convertMany($L.getResultList(), $L, $T.class), $L, $L)",
PageableExecutionUtils.class, List.class, actualReturnType, queryVariableName, aotQuery.isNative(),
queryResultType, context.getPageableParameterName());
queryResultType, context.getPageableParameterName(), context.localVariable("countAll"));
} else if (queryMethod.isSliceQuery()) {
builder.addStatement("$T<$T> resultList = ($T<$T>) convertMany($L.getResultList(), $L, $T.class)",
List.class, actualReturnType, List.class, actualReturnType, queryVariableName, aotQuery.isNative(),
builder.addStatement("$T<$T> $L = ($T<$T>) convertMany($L.getResultList(), $L, $T.class)", List.class,
actualReturnType, context.localVariable("resultList"), List.class, actualReturnType, queryVariableName,
aotQuery.isNative(),
queryResultType);
builder.addStatement("boolean hasNext = $L.isPaged() && resultList.size() > $L.getPageSize()",
context.getPageableParameterName(), context.getPageableParameterName());
builder.addStatement("boolean $L = $L.isPaged() && $L.size() > $L.getPageSize()",
context.localVariable("hasNext"), context.getPageableParameterName(),
context.localVariable("resultList"), context.getPageableParameterName());
builder.addStatement(
"return new $T<>(hasNext ? resultList.subList(0, $L.getPageSize()) : resultList, $L, hasNext)",
SliceImpl.class, context.getPageableParameterName(), context.getPageableParameterName());
"return new $T<>($L ? $L.subList(0, $L.getPageSize()) : $L, $L, $L)", SliceImpl.class,
context.localVariable("hasNext"), context.localVariable("resultList"),
context.getPageableParameterName(), context.localVariable("resultList"),
context.getPageableParameterName(), context.localVariable("hasNext"));
} else {
if (Optional.class.isAssignableFrom(context.getReturnType().toClass())) {
@@ -642,21 +661,24 @@ class JpaCodeBlocks {
} else {
if (queryMethod.isCollectionQuery()) {
builder.addStatement("return ($T) query.getResultList()", context.getReturnTypeName());
builder.addStatement("return ($T) $L.getResultList()", context.getReturnTypeName(), queryVariableName);
} else if (queryMethod.isStreamQuery()) {
builder.addStatement("return ($T) query.getResultStream()", context.getReturnTypeName());
builder.addStatement("return ($T) $L.getResultStream()", context.getReturnTypeName(), queryVariableName);
} else if (queryMethod.isPageQuery()) {
builder.addStatement("return $T.getPage(($T<$T>) $L.getResultList(), $L, countAll)",
builder.addStatement("return $T.getPage(($T<$T>) $L.getResultList(), $L, $L)",
PageableExecutionUtils.class, List.class, actualReturnType, queryVariableName,
context.getPageableParameterName());
context.getPageableParameterName(), context.localVariable("countAll"));
} else if (queryMethod.isSliceQuery()) {
builder.addStatement("$T<$T> resultList = $L.getResultList()", List.class, actualReturnType,
queryVariableName);
builder.addStatement("boolean hasNext = $L.isPaged() && resultList.size() > $L.getPageSize()",
context.getPageableParameterName(), context.getPageableParameterName());
builder.addStatement("$T<$T> $L = $L.getResultList()", List.class, actualReturnType,
context.localVariable("resultList"), queryVariableName);
builder.addStatement("boolean $L = $L.isPaged() && $L.size() > $L.getPageSize()",
context.localVariable("hasNext"), context.getPageableParameterName(),
context.localVariable("resultList"), context.getPageableParameterName());
builder.addStatement(
"return new $T<>(hasNext ? resultList.subList(0, $L.getPageSize()) : resultList, $L, hasNext)",
SliceImpl.class, context.getPageableParameterName(), context.getPageableParameterName());
"return new $T<>($L ? $L.subList(0, $L.getPageSize()) : $L, $L, $L)", SliceImpl.class,
context.localVariable("hasNext"), context.localVariable("resultList"),
context.getPageableParameterName(), context.localVariable("resultList"),
context.getPageableParameterName(), context.localVariable("hasNext"));
} else {
if (Optional.class.isAssignableFrom(context.getReturnType().toClass())) {

View File

@@ -36,20 +36,18 @@ import org.springframework.data.jpa.repository.query.JpaParameters;
import org.springframework.data.jpa.repository.query.JpaQueryMethod;
import org.springframework.data.jpa.repository.query.Procedure;
import org.springframework.data.jpa.repository.query.QueryEnhancerSelector;
import org.springframework.data.repository.aot.generate.AotRepositoryClassBuilder;
import org.springframework.data.repository.aot.generate.AotRepositoryConstructorBuilder;
import org.springframework.data.repository.aot.generate.AotRepositoryFragmentMetadata;
import org.springframework.data.repository.aot.generate.MethodContributor;
import org.springframework.data.repository.aot.generate.QueryMetadata;
import org.springframework.data.repository.aot.generate.RepositoryContributor;
import org.springframework.data.repository.config.AotRepositoryContext;
import org.springframework.data.repository.core.RepositoryInformation;
import org.springframework.data.repository.core.support.RepositoryFactoryBeanSupport;
import org.springframework.data.repository.query.QueryMethod;
import org.springframework.data.repository.query.ReturnedType;
import org.springframework.data.util.TypeInformation;
import org.springframework.javapoet.CodeBlock;
import org.springframework.javapoet.TypeName;
import org.springframework.javapoet.TypeSpec;
import org.springframework.util.ClassUtils;
import org.springframework.util.StringUtils;
@@ -88,9 +86,8 @@ public class JpaRepositoryContributor extends RepositoryContributor {
}
@Override
protected void customizeClass(RepositoryInformation information, AotRepositoryFragmentMetadata metadata,
TypeSpec.Builder builder) {
builder.superclass(TypeName.get(AotRepositoryFragmentSupport.class));
protected void customizeClass(AotRepositoryClassBuilder classBuilder) {
classBuilder.customize(builder -> builder.superclass(TypeName.get(AotRepositoryFragmentSupport.class)));
}
@Override
@@ -102,16 +99,15 @@ public class JpaRepositoryContributor extends RepositoryContributor {
constructorBuilder.addParameter("context", RepositoryFactoryBeanSupport.FragmentCreationContext.class);
// TODO: Pick up the configured QueryEnhancerSelector
constructorBuilder.customize((repositoryInformation, builder) -> {
constructorBuilder.customize(builder -> {
builder.addStatement("super($T.DEFAULT_SELECTOR, context)", QueryEnhancerSelector.class);
});
}
@Override
protected @Nullable MethodContributor<? extends QueryMethod> contributeQueryMethod(Method method,
RepositoryInformation repositoryInformation) {
protected @Nullable MethodContributor<? extends QueryMethod> contributeQueryMethod(Method method) {
JpaQueryMethod queryMethod = new JpaQueryMethod(method, repositoryInformation, getProjectionFactory(),
JpaQueryMethod queryMethod = new JpaQueryMethod(method, getRepositoryInformation(), getProjectionFactory(),
persistenceProvider);
// meh!
@@ -125,7 +121,6 @@ public class JpaRepositoryContributor extends RepositoryContributor {
MethodContributor.QueryMethodMetadataContributorBuilder<JpaQueryMethod> builder = MethodContributor
.forQueryMethod(queryMethod);
if (procedure != null) {
if (StringUtils.hasText(procedure.name())) {
@@ -150,7 +145,7 @@ public class JpaRepositoryContributor extends RepositoryContributor {
MergedAnnotation<Query> query = MergedAnnotations.from(method).get(Query.class);
AotQueries aotQueries = queriesFactory.createQueries(repositoryInformation, query, selector, queryMethod,
AotQueries aotQueries = queriesFactory.createQueries(getRepositoryInformation(), query, selector, queryMethod,
returnedType);
// no KeysetScrolling for now.
@@ -167,7 +162,7 @@ public class JpaRepositoryContributor extends RepositoryContributor {
if (queryMethod.isModifyingQuery()) {
TypeInformation<?> returnType = repositoryInformation.getReturnType(method);
TypeInformation<?> returnType = getRepositoryInformation().getReturnType(method);
boolean returnsCount = JpaCodeBlocks.QueryExecutionBlockBuilder.returnsModifying(returnType.getType());
@@ -182,26 +177,26 @@ public class JpaRepositoryContributor extends RepositoryContributor {
return MethodContributor.forQueryMethod(queryMethod).withMetadata(aotQueries.toMetadata(queryMethod.isPageQuery()))
.contribute(context -> {
CodeBlock.Builder body = CodeBlock.builder();
CodeBlock.Builder body = CodeBlock.builder();
MergedAnnotation<NativeQuery> nativeQuery = context.getAnnotation(NativeQuery.class);
MergedAnnotation<QueryHints> queryHints = context.getAnnotation(QueryHints.class);
MergedAnnotation<EntityGraph> entityGraph = context.getAnnotation(EntityGraph.class);
MergedAnnotation<Modifying> modifying = context.getAnnotation(Modifying.class);
MergedAnnotation<NativeQuery> nativeQuery = context.getAnnotation(NativeQuery.class);
MergedAnnotation<QueryHints> queryHints = context.getAnnotation(QueryHints.class);
MergedAnnotation<EntityGraph> entityGraph = context.getAnnotation(EntityGraph.class);
MergedAnnotation<Modifying> modifying = context.getAnnotation(Modifying.class);
AotEntityGraph aotEntityGraph = entityGraphLookup.findEntityGraph(entityGraph, repositoryInformation,
returnedType, queryMethod);
AotEntityGraph aotEntityGraph = entityGraphLookup.findEntityGraph(entityGraph, getRepositoryInformation(),
returnedType, queryMethod);
body.add(JpaCodeBlocks.queryBuilder(context, queryMethod).filter(aotQueries)
.queryReturnType(QueriesFactory.getQueryReturnType(aotQueries.result(), returnedType, context))
.nativeQuery(nativeQuery).queryHints(queryHints).entityGraph(aotEntityGraph)
.queryRewriter(query.isPresent() ? query.getClass("queryRewriter") : null).build());
body.add(JpaCodeBlocks.queryBuilder(context, queryMethod).filter(aotQueries)
.queryReturnType(QueriesFactory.getQueryReturnType(aotQueries.result(), returnedType, context))
.nativeQuery(nativeQuery).queryHints(queryHints).entityGraph(aotEntityGraph)
.queryRewriter(query.isPresent() ? query.getClass("queryRewriter") : null).build());
body.add(
JpaCodeBlocks.executionBuilder(context, queryMethod).modifying(modifying).query(aotQueries.result()).build());
body.add(JpaCodeBlocks.executionBuilder(context, queryMethod).modifying(modifying).query(aotQueries.result())
.build());
return body.build();
});
return body.build();
});
}
record StoredProcedureMetadata(String procedure) implements QueryMetadata {

View File

@@ -104,8 +104,9 @@ interface UserRepository extends CrudRepository<User, Integer> {
@Query("select u from User u where u.lastname like ?1%")
List<User> findAnnotatedQueryByLastname(String lastname, Limit limit, Sort sort);
// nasty parameter names
@Query("select u from User u where u.lastname like ?1%")
List<User> findAnnotatedQueryByLastname(String lastname, Pageable pageable);
List<User> findAnnotatedQueryByLastname(String query, Pageable queryString);
@Query("select u from User u where u.lastname like ?1%")
Page<User> findAnnotatedQueryPageOfUsersByLastname(String lastname, Pageable pageable);

View File

@@ -19,7 +19,8 @@
<!-- <logger name="org.testcontainers" level="debug" />-->
<!-- <logger name="org.springframework.data.repository.aot.generate.RepositoryContributor" level="trace" /> -->
<logger name="org.springframework.data.repository.aot.generate.RepositoryContributor"
level="warn"/>
<root level="error">
<appender-ref ref="console"/>