diff --git a/spring-aop/src/test/java/org/springframework/aop/scope/ScopedProxyBeanRegistrationAotProcessorTests.java b/spring-aop/src/test/java/org/springframework/aop/scope/ScopedProxyBeanRegistrationAotProcessorTests.java index 634d0a097a..d9897fa588 100644 --- a/spring-aop/src/test/java/org/springframework/aop/scope/ScopedProxyBeanRegistrationAotProcessorTests.java +++ b/spring-aop/src/test/java/org/springframework/aop/scope/ScopedProxyBeanRegistrationAotProcessorTests.java @@ -42,6 +42,7 @@ import org.springframework.beans.factory.support.RootBeanDefinition; import org.springframework.beans.testfixture.beans.factory.aot.MockBeanFactoryInitializationCode; import org.springframework.beans.testfixture.beans.factory.generator.factory.NumberHolder; import org.springframework.core.ResolvableType; +import org.springframework.core.testfixture.aot.generate.TestGenerationContext; import org.springframework.util.ReflectionUtils; import static org.assertj.core.api.Assertions.assertThat; @@ -71,7 +72,7 @@ class ScopedProxyBeanRegistrationAotProcessorTests { this.beanFactory = new DefaultListableBeanFactory(); this.processor = new TestBeanRegistrationsAotProcessor(); this.generatedFiles = new InMemoryGeneratedFiles(); - this.generationContext = new DefaultGenerationContext(this.generatedFiles); + this.generationContext = new TestGenerationContext(this.generatedFiles); this.beanFactoryInitializationCode = new MockBeanFactoryInitializationCode(); } diff --git a/spring-beans/src/main/java/org/springframework/beans/factory/annotation/AutowiredAnnotationBeanPostProcessor.java b/spring-beans/src/main/java/org/springframework/beans/factory/annotation/AutowiredAnnotationBeanPostProcessor.java index dc4db39242..c75d94aa32 100644 --- a/spring-beans/src/main/java/org/springframework/beans/factory/annotation/AutowiredAnnotationBeanPostProcessor.java +++ b/spring-beans/src/main/java/org/springframework/beans/factory/annotation/AutowiredAnnotationBeanPostProcessor.java @@ -41,6 +41,7 @@ import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.springframework.aot.generate.AccessVisibility; +import org.springframework.aot.generate.GeneratedClass; import org.springframework.aot.generate.GenerationContext; import org.springframework.aot.generate.MethodReference; import org.springframework.aot.hint.ExecutableHint; @@ -79,11 +80,8 @@ import org.springframework.core.annotation.AnnotationAttributes; import org.springframework.core.annotation.AnnotationUtils; import org.springframework.core.annotation.MergedAnnotation; import org.springframework.core.annotation.MergedAnnotations; -import org.springframework.javapoet.ClassName; import org.springframework.javapoet.CodeBlock; -import org.springframework.javapoet.JavaFile; import org.springframework.javapoet.MethodSpec; -import org.springframework.javapoet.TypeSpec; import org.springframework.lang.Nullable; import org.springframework.util.Assert; import org.springframework.util.ClassUtils; @@ -910,30 +908,28 @@ public class AutowiredAnnotationBeanPostProcessor implements SmartInstantiationA @Override public void applyTo(GenerationContext generationContext, BeanRegistrationCode beanRegistrationCode) { - - ClassName className = generationContext.getClassNameGenerator() - .generateClassName(this.target, "Autowiring"); - TypeSpec.Builder classBuilder = TypeSpec.classBuilder(className); - classBuilder.addJavadoc("Autowiring for {@link $T}.", this.target); - classBuilder.addModifiers(javax.lang.model.element.Modifier.PUBLIC); - classBuilder.addMethod(generateMethod(generationContext.getRuntimeHints())); - JavaFile javaFile = JavaFile - .builder(className.packageName(), classBuilder.build()).build(); - generationContext.getGeneratedFiles().addSourceFile(javaFile); + GeneratedClass generatedClass = generationContext.getGeneratedClasses() + .forFeatureComponent("Autowiring", this.target) + .generate(type -> { + type.addJavadoc("Autowiring for {@link $T}.", this.target); + type.addModifiers(javax.lang.model.element.Modifier.PUBLIC); + }); + generatedClass.getMethodGenerator().generateMethod(APPLY_METHOD) + .using(generateMethod(generationContext.getRuntimeHints())); beanRegistrationCode.addInstancePostProcessor( - MethodReference.ofStatic(className, APPLY_METHOD)); + MethodReference.ofStatic(generatedClass.getName(), APPLY_METHOD)); } - private MethodSpec generateMethod(RuntimeHints hints) { - MethodSpec.Builder builder = MethodSpec.methodBuilder(APPLY_METHOD); - builder.addJavadoc("Apply the autowiring."); - builder.addModifiers(javax.lang.model.element.Modifier.PUBLIC, - javax.lang.model.element.Modifier.STATIC); - builder.addParameter(RegisteredBean.class, REGISTERED_BEAN_PARAMETER); - builder.addParameter(this.target, INSTANCE_PARAMETER); - builder.returns(this.target); - builder.addCode(generateMethodCode(hints)); - return builder.build(); + private Consumer generateMethod(RuntimeHints hints) { + return method -> { + method.addJavadoc("Apply the autowiring."); + method.addModifiers(javax.lang.model.element.Modifier.PUBLIC, + javax.lang.model.element.Modifier.STATIC); + method.addParameter(RegisteredBean.class, REGISTERED_BEAN_PARAMETER); + method.addParameter(this.target, INSTANCE_PARAMETER); + method.returns(this.target); + method.addCode(generateMethodCode(hints)); + }; } private CodeBlock generateMethodCode(RuntimeHints hints) { diff --git a/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanDefinitionMethodGenerator.java b/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanDefinitionMethodGenerator.java index de604409b1..ed7682a882 100644 --- a/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanDefinitionMethodGenerator.java +++ b/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanDefinitionMethodGenerator.java @@ -21,10 +21,8 @@ import java.util.List; import javax.lang.model.element.Modifier; -import org.springframework.aot.generate.ClassGenerator.JavaFileGenerator; import org.springframework.aot.generate.GeneratedClass; import org.springframework.aot.generate.GeneratedMethod; -import org.springframework.aot.generate.GeneratedMethods; import org.springframework.aot.generate.GenerationContext; import org.springframework.aot.generate.MethodGenerator; import org.springframework.aot.generate.MethodNameGenerator; @@ -32,8 +30,6 @@ import org.springframework.aot.generate.MethodReference; import org.springframework.beans.factory.config.BeanDefinition; import org.springframework.beans.factory.support.RegisteredBean; import org.springframework.javapoet.ClassName; -import org.springframework.javapoet.JavaFile; -import org.springframework.javapoet.TypeSpec; import org.springframework.lang.Nullable; /** @@ -45,6 +41,8 @@ import org.springframework.lang.Nullable; */ class BeanDefinitionMethodGenerator { + private static final String FEATURE_NAME = "BeanDefinitions"; + private final BeanDefinitionMethodGeneratorFactory methodGeneratorFactory; private final RegisteredBean registeredBean; @@ -81,22 +79,23 @@ class BeanDefinitionMethodGenerator { * Generate the method that returns the {@link BeanDefinition} to be * registered. * @param generationContext the generation context - * @param featureNamePrefix the prefix to use for the feature name * @param beanRegistrationsCode the bean registrations code * @return a reference to the generated method. */ MethodReference generateBeanDefinitionMethod(GenerationContext generationContext, - String featureNamePrefix, BeanRegistrationsCode beanRegistrationsCode) { + BeanRegistrationsCode beanRegistrationsCode) { BeanRegistrationCodeFragments codeFragments = getCodeFragments(generationContext, - beanRegistrationsCode, featureNamePrefix); + beanRegistrationsCode); Class target = codeFragments.getTarget(this.registeredBean, this.constructorOrFactoryMethod); if (!target.getName().startsWith("java.")) { - String featureName = featureNamePrefix + "BeanDefinitions"; - GeneratedClass generatedClass = generationContext.getClassGenerator() - .getOrGenerateClass(new BeanDefinitionsJavaFileGenerator(target), - target, featureName); + GeneratedClass generatedClass = generationContext.getGeneratedClasses() + .forFeatureComponent(FEATURE_NAME, target) + .getOrGenerate(FEATURE_NAME, type -> { + type.addJavadoc("Bean definitions for {@link $T}", target); + type.addModifiers(Modifier.PUBLIC); + }); MethodGenerator methodGenerator = generatedClass.getMethodGenerator() .withName(getName()); GeneratedMethod generatedMethod = generateBeanDefinitionMethod( @@ -115,11 +114,10 @@ class BeanDefinitionMethodGenerator { } private BeanRegistrationCodeFragments getCodeFragments(GenerationContext generationContext, - BeanRegistrationsCode beanRegistrationsCode, String featureNamePrefix) { + BeanRegistrationsCode beanRegistrationsCode) { BeanRegistrationCodeFragments codeFragments = new DefaultBeanRegistrationCodeFragments( - beanRegistrationsCode, this.registeredBean, this.methodGeneratorFactory, - featureNamePrefix); + beanRegistrationsCode, this.registeredBean, this.methodGeneratorFactory); for (BeanRegistrationAotContribution aotContribution : this.aotContributions) { codeFragments = aotContribution.customizeBeanRegistrationCodeFragments(generationContext, codeFragments); } @@ -172,41 +170,4 @@ class BeanDefinitionMethodGenerator { return beanName; } - - /** - * {@link BeanDefinitionsJavaFileGenerator} to create the - * {@code BeanDefinitions} file. - */ - private static class BeanDefinitionsJavaFileGenerator implements JavaFileGenerator { - - private final Class target; - - - BeanDefinitionsJavaFileGenerator(Class target) { - this.target = target; - } - - - @Override - public JavaFile generateJavaFile(ClassName className, GeneratedMethods methods) { - TypeSpec.Builder classBuilder = TypeSpec.classBuilder(className); - classBuilder.addJavadoc("Bean definitions for {@link $T}", this.target); - classBuilder.addModifiers(Modifier.PUBLIC); - methods.doWithMethodSpecs(classBuilder::addMethod); - return JavaFile.builder(className.packageName(), classBuilder.build()) - .build(); - } - - @Override - public int hashCode() { - return getClass().hashCode(); - } - - @Override - public boolean equals(Object obj) { - return getClass() == obj.getClass(); - } - - } - } diff --git a/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanFactoryInitializationCode.java b/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanFactoryInitializationCode.java index ceefb8e2e6..92e250ba7e 100644 --- a/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanFactoryInitializationCode.java +++ b/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanFactoryInitializationCode.java @@ -18,7 +18,6 @@ package org.springframework.beans.factory.aot; import org.springframework.aot.generate.MethodGenerator; import org.springframework.aot.generate.MethodReference; -import org.springframework.lang.Nullable; /** * Interface that can be used to configure the code that will be generated to @@ -35,24 +34,6 @@ public interface BeanFactoryInitializationCode { */ String BEAN_FACTORY_VARIABLE = "beanFactory"; - /** - * Return the target class for this bean factory or {@code null} if there is - * no target. - * @return the target - */ - @Nullable - default Class getTarget() { - return null; - } - - /** - * Return the name of the bean factory or and empty string if no ID is available. - * @return the bean factory name - */ - default String getName() { - return ""; - } - /** * Return a {@link MethodGenerator} that can be used to add more methods to * the Initializing code. diff --git a/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanRegistrationsAotContribution.java b/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanRegistrationsAotContribution.java index 52e661b2c1..7b29174c23 100644 --- a/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanRegistrationsAotContribution.java +++ b/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanRegistrationsAotContribution.java @@ -20,17 +20,15 @@ import java.util.Map; import javax.lang.model.element.Modifier; +import org.springframework.aot.generate.GeneratedClass; import org.springframework.aot.generate.GeneratedMethod; -import org.springframework.aot.generate.GeneratedMethods; import org.springframework.aot.generate.GenerationContext; import org.springframework.aot.generate.MethodGenerator; import org.springframework.aot.generate.MethodReference; import org.springframework.beans.factory.support.DefaultListableBeanFactory; import org.springframework.javapoet.ClassName; import org.springframework.javapoet.CodeBlock; -import org.springframework.javapoet.JavaFile; import org.springframework.javapoet.MethodSpec; -import org.springframework.javapoet.TypeSpec; /** * AOT contribution from a {@link BeanRegistrationsAotProcessor} used to @@ -61,24 +59,23 @@ class BeanRegistrationsAotContribution public void applyTo(GenerationContext generationContext, BeanFactoryInitializationCode beanFactoryInitializationCode) { - ClassName className = generationContext.getClassNameGenerator().generateClassName( - beanFactoryInitializationCode.getTarget(), - beanFactoryInitializationCode.getName() + "BeanFactoryRegistrations"); + GeneratedClass generatedClass = generationContext.getGeneratedClasses() + .forFeature("BeanFactoryRegistrations").generate(type -> { + type.addJavadoc("Register bean definitions for the bean factory."); + type.addModifiers(Modifier.PUBLIC); + }); BeanRegistrationsCodeGenerator codeGenerator = new BeanRegistrationsCodeGenerator( - className); + generatedClass); GeneratedMethod registerMethod = codeGenerator.getMethodGenerator() .generateMethod("registerBeanDefinitions") .using(builder -> generateRegisterMethod(builder, generationContext, - beanFactoryInitializationCode.getName(), codeGenerator)); - JavaFile javaFile = codeGenerator.generatedJavaFile(className); - generationContext.getGeneratedFiles().addSourceFile(javaFile); beanFactoryInitializationCode - .addInitializer(MethodReference.of(className, registerMethod.getName())); + .addInitializer(MethodReference.of(generatedClass.getName(), registerMethod.getName())); } private void generateRegisterMethod(MethodSpec.Builder builder, - GenerationContext generationContext, String featureNamePrefix, + GenerationContext generationContext, BeanRegistrationsCode beanRegistrationsCode) { builder.addJavadoc("Register the bean definitions."); @@ -88,7 +85,7 @@ class BeanRegistrationsAotContribution CodeBlock.Builder code = CodeBlock.builder(); this.registrations.forEach((beanName, beanDefinitionMethodGenerator) -> { MethodReference beanDefinitionMethod = beanDefinitionMethodGenerator - .generateBeanDefinitionMethod(generationContext, featureNamePrefix, + .generateBeanDefinitionMethod(generationContext, beanRegistrationsCode); code.addStatement("$L.registerBeanDefinition($S, $L)", BEAN_FACTORY_PARAMETER_NAME, beanName, @@ -103,33 +100,21 @@ class BeanRegistrationsAotContribution */ static class BeanRegistrationsCodeGenerator implements BeanRegistrationsCode { - private final ClassName className; + private final GeneratedClass generatedClass; - private final GeneratedMethods generatedMethods = new GeneratedMethods(); - - - public BeanRegistrationsCodeGenerator(ClassName className) { - this.className = className; + public BeanRegistrationsCodeGenerator(GeneratedClass generatedClass) { + this.generatedClass = generatedClass; } @Override public ClassName getClassName() { - return this.className; + return this.generatedClass.getName(); } @Override public MethodGenerator getMethodGenerator() { - return this.generatedMethods; - } - - JavaFile generatedJavaFile(ClassName className) { - TypeSpec.Builder classBuilder = TypeSpec.classBuilder(className); - classBuilder.addJavadoc("Register bean definitions for the bean factory."); - classBuilder.addModifiers(Modifier.PUBLIC); - this.generatedMethods.doWithMethodSpecs(classBuilder::addMethod); - return JavaFile.builder(className.packageName(), classBuilder.build()) - .build(); + return this.generatedClass.getMethodGenerator(); } } diff --git a/spring-beans/src/main/java/org/springframework/beans/factory/aot/DefaultBeanRegistrationCodeFragments.java b/spring-beans/src/main/java/org/springframework/beans/factory/aot/DefaultBeanRegistrationCodeFragments.java index f5f5137f0f..4f00fa68a4 100644 --- a/spring-beans/src/main/java/org/springframework/beans/factory/aot/DefaultBeanRegistrationCodeFragments.java +++ b/spring-beans/src/main/java/org/springframework/beans/factory/aot/DefaultBeanRegistrationCodeFragments.java @@ -54,18 +54,14 @@ class DefaultBeanRegistrationCodeFragments extends BeanRegistrationCodeFragments private final BeanDefinitionMethodGeneratorFactory beanDefinitionMethodGeneratorFactory; - private final String featureNamePrefix; - DefaultBeanRegistrationCodeFragments(BeanRegistrationsCode beanRegistrationsCode, RegisteredBean registeredBean, - BeanDefinitionMethodGeneratorFactory beanDefinitionMethodGeneratorFactory, - String featureNamePrefix) { + BeanDefinitionMethodGeneratorFactory beanDefinitionMethodGeneratorFactory) { this.beanRegistrationsCode = beanRegistrationsCode; this.registeredBean = registeredBean; this.beanDefinitionMethodGeneratorFactory = beanDefinitionMethodGeneratorFactory; - this.featureNamePrefix = featureNamePrefix; } @@ -124,7 +120,7 @@ class DefaultBeanRegistrationCodeFragments extends BeanRegistrationCodeFragments .getBeanDefinitionMethodGenerator(innerRegisteredBean, name); Assert.state(methodGenerator != null, "Unexpected filtering of inner-bean"); MethodReference generatedMethod = methodGenerator - .generateBeanDefinitionMethod(generationContext, this.featureNamePrefix, + .generateBeanDefinitionMethod(generationContext, this.beanRegistrationsCode); return generatedMethod.toInvokeCodeBlock(); } diff --git a/spring-beans/src/test/java/org/springframework/beans/factory/annotation/AutowiredAnnotationBeanRegistrationAotContributionTests.java b/spring-beans/src/test/java/org/springframework/beans/factory/annotation/AutowiredAnnotationBeanRegistrationAotContributionTests.java index 2c88fb2230..4237c126bf 100644 --- a/spring-beans/src/test/java/org/springframework/beans/factory/annotation/AutowiredAnnotationBeanRegistrationAotContributionTests.java +++ b/spring-beans/src/test/java/org/springframework/beans/factory/annotation/AutowiredAnnotationBeanRegistrationAotContributionTests.java @@ -25,7 +25,6 @@ import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.springframework.aot.generate.DefaultGenerationContext; -import org.springframework.aot.generate.GenerationContext; import org.springframework.aot.generate.InMemoryGeneratedFiles; import org.springframework.aot.generate.MethodReference; import org.springframework.aot.hint.RuntimeHints; @@ -40,6 +39,7 @@ import org.springframework.beans.factory.support.RootBeanDefinition; import org.springframework.beans.testfixture.beans.factory.aot.MockBeanRegistrationCode; import org.springframework.core.env.Environment; import org.springframework.core.env.StandardEnvironment; +import org.springframework.core.testfixture.aot.generate.TestGenerationContext; import org.springframework.javapoet.CodeBlock; import org.springframework.javapoet.JavaFile; import org.springframework.javapoet.MethodSpec; @@ -59,7 +59,7 @@ class AutowiredAnnotationBeanRegistrationAotContributionTests { private InMemoryGeneratedFiles generatedFiles; - private GenerationContext generationContext; + private DefaultGenerationContext generationContext; private RuntimeHints runtimeHints; @@ -70,7 +70,7 @@ class AutowiredAnnotationBeanRegistrationAotContributionTests { @BeforeEach void setup() { this.generatedFiles = new InMemoryGeneratedFiles(); - this.generationContext = new DefaultGenerationContext(this.generatedFiles); + this.generationContext = new TestGenerationContext(this.generatedFiles); this.runtimeHints = this.generationContext.getRuntimeHints(); this.beanRegistrationCode = new MockBeanRegistrationCode(); this.beanFactory = new DefaultListableBeanFactory(); @@ -169,6 +169,7 @@ class AutowiredAnnotationBeanRegistrationAotContributionTests { @SuppressWarnings("unchecked") private void testCompiledResult(RegisteredBean registeredBean, BiConsumer, Compiled> result) { + this.generationContext.writeGeneratedContent(); JavaFile javaFile = createJavaFile(registeredBean.getBeanClass()); TestCompiler.forSystem().withFiles(this.generatedFiles).compile(javaFile::writeTo, compiled -> result.accept(compiled.getInstance(BiFunction.class), diff --git a/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanDefinitionMethodGeneratorTests.java b/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanDefinitionMethodGeneratorTests.java index 572b70ac2e..a8183ce4cb 100644 --- a/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanDefinitionMethodGeneratorTests.java +++ b/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanDefinitionMethodGeneratorTests.java @@ -50,6 +50,7 @@ import org.springframework.beans.testfixture.beans.TestBean; import org.springframework.beans.testfixture.beans.factory.aot.MockBeanRegistrationsCode; import org.springframework.core.ResolvableType; import org.springframework.core.mock.MockSpringFactoriesLoader; +import org.springframework.core.testfixture.aot.generate.TestGenerationContext; import org.springframework.javapoet.ClassName; import org.springframework.javapoet.CodeBlock; import org.springframework.javapoet.JavaFile; @@ -80,7 +81,7 @@ class BeanDefinitionMethodGeneratorTests { @BeforeEach void setup() { this.generatedFiles = new InMemoryGeneratedFiles(); - this.generationContext = new DefaultGenerationContext(this.generatedFiles); + this.generationContext = new TestGenerationContext(this.generatedFiles); this.beanFactory = new DefaultListableBeanFactory(); this.methodGeneratorFactory = new BeanDefinitionMethodGeneratorFactory( new AotFactoriesLoader(this.beanFactory, new MockSpringFactoriesLoader())); @@ -96,7 +97,7 @@ class BeanDefinitionMethodGeneratorTests { this.methodGeneratorFactory, registeredBean, null, Collections.emptyList()); MethodReference method = generator.generateBeanDefinitionMethod( - this.generationContext, "", this.beanRegistrationsCode); + this.generationContext, this.beanRegistrationsCode); testCompiledResult(method, (actual, compiled) -> { SourceFile sourceFile = compiled.getSourceFile(".*BeanDefinitions"); assertThat(sourceFile).contains("Get the bean definition for 'testBean'"); @@ -114,7 +115,7 @@ class BeanDefinitionMethodGeneratorTests { this.methodGeneratorFactory, registeredBean, null, Collections.emptyList()); MethodReference method = generator.generateBeanDefinitionMethod( - this.generationContext, "", this.beanRegistrationsCode); + this.generationContext, this.beanRegistrationsCode); testCompiledResult(method, (actual, compiled) -> { assertThat(actual.getResolvableType().resolve()).isEqualTo(GenericBean.class); SourceFile sourceFile = compiled.getSourceFile(".*BeanDefinitions"); @@ -147,7 +148,7 @@ class BeanDefinitionMethodGeneratorTests { BeanDefinitionMethodGenerator generator = new BeanDefinitionMethodGenerator( this.methodGeneratorFactory, registeredBean, null, aotContributions); MethodReference method = generator.generateBeanDefinitionMethod( - this.generationContext, "", this.beanRegistrationsCode); + this.generationContext, this.beanRegistrationsCode); testCompiledResult(method, (actual, compiled) -> { assertThat(actual.getBeanClass()).isEqualTo(TestBean.class); InstanceSupplier supplier = (InstanceSupplier) actual @@ -173,7 +174,7 @@ class BeanDefinitionMethodGeneratorTests { BeanDefinitionMethodGenerator generator = new BeanDefinitionMethodGenerator( this.methodGeneratorFactory, registeredBean, null, aotContributions); MethodReference method = generator.generateBeanDefinitionMethod( - this.generationContext, "", this.beanRegistrationsCode); + this.generationContext, this.beanRegistrationsCode); testCompiledResult(method, (actual, compiled) -> { assertThat(actual.getBeanClass()).isEqualTo(TestBean.class); SourceFile sourceFile = compiled.getSourceFile(".*BeanDefinitions"); @@ -213,7 +214,7 @@ class BeanDefinitionMethodGeneratorTests { this.methodGeneratorFactory, registeredBean, null, aotContributions); MethodReference method = generator.generateBeanDefinitionMethod( - this.generationContext, "", this.beanRegistrationsCode); + this.generationContext, this.beanRegistrationsCode); testCompiledResult(method, (actual, compiled) -> { assertThat(actual.getAttribute("a")).isEqualTo("A"); assertThat(actual.getAttribute("b")).isNull(); @@ -246,7 +247,7 @@ class BeanDefinitionMethodGeneratorTests { this.methodGeneratorFactory, innerBean, "testInnerBean", Collections.emptyList()); MethodReference method = generator.generateBeanDefinitionMethod( - this.generationContext, "", this.beanRegistrationsCode); + this.generationContext, this.beanRegistrationsCode); testCompiledResult(method, (actual, compiled) -> { assertThat(compiled.getSourceFile(".*BeanDefinitions")) .contains("Get the inner-bean definition for 'testInnerBean'"); @@ -267,7 +268,7 @@ class BeanDefinitionMethodGeneratorTests { this.methodGeneratorFactory, registeredBean, null, Collections.emptyList()); MethodReference method = generator.generateBeanDefinitionMethod( - this.generationContext, "", this.beanRegistrationsCode); + this.generationContext, this.beanRegistrationsCode); testCompiledResult(method, (actual, compiled) -> { RootBeanDefinition actualInnerBeanDefinition = (RootBeanDefinition) actual .getPropertyValues().get("name"); @@ -301,7 +302,7 @@ class BeanDefinitionMethodGeneratorTests { this.methodGeneratorFactory, registeredBean, null, Collections.emptyList()); MethodReference method = generator.generateBeanDefinitionMethod( - this.generationContext, "", this.beanRegistrationsCode); + this.generationContext, this.beanRegistrationsCode); testCompiledResult(method, (actual, compiled) -> { RootBeanDefinition actualInnerBeanDefinition = (RootBeanDefinition) actual .getConstructorArgumentValues() @@ -334,7 +335,7 @@ class BeanDefinitionMethodGeneratorTests { BeanDefinitionMethodGenerator generator = new BeanDefinitionMethodGenerator( this.methodGeneratorFactory, registeredBean, null, aotContributions); MethodReference method = generator.generateBeanDefinitionMethod( - this.generationContext, "", this.beanRegistrationsCode); + this.generationContext, this.beanRegistrationsCode); testCompiledResult(method, (actual, compiled) -> { SourceFile sourceFile = compiled.getSourceFile(".*BeanDefinitions"); assertThat(sourceFile).contains("AotContributedMethod()"); @@ -351,7 +352,7 @@ class BeanDefinitionMethodGeneratorTests { this.methodGeneratorFactory, registeredBean, null, Collections.emptyList()); MethodReference method = generator.generateBeanDefinitionMethod( - this.generationContext, "", this.beanRegistrationsCode); + this.generationContext, this.beanRegistrationsCode); testCompiledResult(method, (actual, compiled) -> { DefaultListableBeanFactory freshBeanFactory = new DefaultListableBeanFactory(); freshBeanFactory.registerBeanDefinition("test", actual); diff --git a/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanRegistrationsAotContributionTests.java b/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanRegistrationsAotContributionTests.java index ae91b492a7..580855d946 100644 --- a/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanRegistrationsAotContributionTests.java +++ b/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanRegistrationsAotContributionTests.java @@ -29,6 +29,7 @@ import javax.lang.model.element.Modifier; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import org.springframework.aot.generate.ClassNameGenerator; import org.springframework.aot.generate.DefaultGenerationContext; import org.springframework.aot.generate.GenerationContext; import org.springframework.aot.generate.InMemoryGeneratedFiles; @@ -42,6 +43,8 @@ import org.springframework.beans.factory.support.RootBeanDefinition; import org.springframework.beans.testfixture.beans.TestBean; import org.springframework.beans.testfixture.beans.factory.aot.MockBeanFactoryInitializationCode; import org.springframework.core.mock.MockSpringFactoriesLoader; +import org.springframework.core.testfixture.aot.generate.TestGenerationContext; +import org.springframework.core.testfixture.aot.generate.TestTarget; import org.springframework.javapoet.CodeBlock; import org.springframework.javapoet.JavaFile; import org.springframework.javapoet.MethodSpec; @@ -72,7 +75,7 @@ class BeanRegistrationsAotContributionTests { @BeforeEach void setup() { this.generatedFiles = new InMemoryGeneratedFiles(); - this.generationContext = new DefaultGenerationContext(this.generatedFiles); + this.generationContext = new TestGenerationContext(this.generatedFiles); this.beanFactory = new DefaultListableBeanFactory(); this.springFactoriesLoader = new MockSpringFactoriesLoader(); this.methodGeneratorFactory = new BeanDefinitionMethodGeneratorFactory( @@ -100,7 +103,9 @@ class BeanRegistrationsAotContributionTests { @Test void applyToWhenHasNameGeneratesPrefixedFeatureName() { - this.beanFactoryInitializationCode = new MockBeanFactoryInitializationCode("Management"); + this.generationContext = new DefaultGenerationContext( + new ClassNameGenerator(TestTarget.class, "Management"), this.generatedFiles); + this.beanFactoryInitializationCode = new MockBeanFactoryInitializationCode(); Map registrations = new LinkedHashMap<>(); RegisteredBean registeredBean = registerBean( new RootBeanDefinition(TestBean.class)); @@ -129,11 +134,11 @@ class BeanRegistrationsAotContributionTests { @Override MethodReference generateBeanDefinitionMethod( - GenerationContext generationContext, String featureNamePrefix, + GenerationContext generationContext, BeanRegistrationsCode beanRegistrationsCode) { beanRegistrationsCodes.add(beanRegistrationsCode); return super.generateBeanDefinitionMethod(generationContext, - featureNamePrefix, beanRegistrationsCode); + beanRegistrationsCode); } }; diff --git a/spring-beans/src/test/java/org/springframework/beans/factory/aot/InstanceSupplierCodeGeneratorTests.java b/spring-beans/src/test/java/org/springframework/beans/factory/aot/InstanceSupplierCodeGeneratorTests.java index d4de4edc3b..beb867620c 100644 --- a/spring-beans/src/test/java/org/springframework/beans/factory/aot/InstanceSupplierCodeGeneratorTests.java +++ b/spring-beans/src/test/java/org/springframework/beans/factory/aot/InstanceSupplierCodeGeneratorTests.java @@ -52,6 +52,7 @@ import org.springframework.beans.testfixture.beans.factory.generator.factory.Num import org.springframework.beans.testfixture.beans.factory.generator.factory.SampleFactory; import org.springframework.beans.testfixture.beans.factory.generator.injection.InjectionComponent; import org.springframework.core.env.StandardEnvironment; +import org.springframework.core.testfixture.aot.generate.TestGenerationContext; import org.springframework.javapoet.ClassName; import org.springframework.javapoet.CodeBlock; import org.springframework.javapoet.JavaFile; @@ -82,7 +83,7 @@ class InstanceSupplierCodeGeneratorTests { @BeforeEach void setup() { this.generatedFiles = new InMemoryGeneratedFiles(); - this.generationContext = new DefaultGenerationContext(this.generatedFiles); + this.generationContext = new TestGenerationContext(this.generatedFiles); } diff --git a/spring-beans/src/testFixtures/java/org/springframework/beans/testfixture/beans/factory/aot/MockBeanFactoryInitializationCode.java b/spring-beans/src/testFixtures/java/org/springframework/beans/testfixture/beans/factory/aot/MockBeanFactoryInitializationCode.java index d1aee20458..b931cccdeb 100644 --- a/spring-beans/src/testFixtures/java/org/springframework/beans/testfixture/beans/factory/aot/MockBeanFactoryInitializationCode.java +++ b/spring-beans/src/testFixtures/java/org/springframework/beans/testfixture/beans/factory/aot/MockBeanFactoryInitializationCode.java @@ -35,21 +35,6 @@ public class MockBeanFactoryInitializationCode implements BeanFactoryInitializat private final List initializers = new ArrayList<>(); - private final String name; - - public MockBeanFactoryInitializationCode() { - this(""); - } - - public MockBeanFactoryInitializationCode(String name) { - this.name = name; - } - - @Override - public String getName() { - return this.name; - } - @Override public GeneratedMethods getMethodGenerator() { return this.generatedMethods; diff --git a/spring-context/src/main/java/org/springframework/context/aot/ApplicationContextAotGenerator.java b/spring-context/src/main/java/org/springframework/context/aot/ApplicationContextAotGenerator.java index 0ea7827427..f5a450d1e6 100644 --- a/spring-context/src/main/java/org/springframework/context/aot/ApplicationContextAotGenerator.java +++ b/spring-context/src/main/java/org/springframework/context/aot/ApplicationContextAotGenerator.java @@ -16,14 +16,14 @@ package org.springframework.context.aot; +import org.springframework.aot.generate.GeneratedClass; import org.springframework.aot.generate.GenerationContext; import org.springframework.beans.factory.BeanFactory; import org.springframework.beans.factory.support.DefaultListableBeanFactory; import org.springframework.context.ApplicationContext; +import org.springframework.context.ApplicationContextInitializer; import org.springframework.context.support.GenericApplicationContext; import org.springframework.javapoet.ClassName; -import org.springframework.javapoet.JavaFile; -import org.springframework.lang.Nullable; /** * Process an {@link ApplicationContext} and its {@link BeanFactory} to generate @@ -42,41 +42,20 @@ public class ApplicationContextAotGenerator { * specified {@link GenerationContext}. * @param applicationContext the application context to handle * @param generationContext the generation context to use - * @param generatedInitializerClassName the class name to use for the - * generated application context initializer + * @return the class name of the {@link ApplicationContextInitializer} entry point */ - public void generateApplicationContext(GenericApplicationContext applicationContext, - GenerationContext generationContext, - ClassName generatedInitializerClassName) { - - generateApplicationContext(applicationContext, null, null, generationContext, - generatedInitializerClassName); - } - - /** - * Refresh the specified {@link GenericApplicationContext} and generate the - * necessary code to restore the state of its {@link BeanFactory}, using the - * specified {@link GenerationContext}. - * @param applicationContext the application context to handle - * @param target the target class for the generated initializer (used when generating class names) - * @param name the name of the application context (used when generating class names) - * @param generationContext the generation context to use - * @param generatedInitializerClassName the class name to use for the - * generated application context initializer - */ - public void generateApplicationContext(GenericApplicationContext applicationContext, - @Nullable Class target, @Nullable String name, GenerationContext generationContext, - ClassName generatedInitializerClassName) { - + public ClassName generateApplicationContext(GenericApplicationContext applicationContext, + GenerationContext generationContext) { applicationContext.refreshForAotProcessing(); DefaultListableBeanFactory beanFactory = applicationContext .getDefaultListableBeanFactory(); - ApplicationContextInitializationCodeGenerator codeGenerator = new ApplicationContextInitializationCodeGenerator( - target, name); + ApplicationContextInitializationCodeGenerator codeGenerator = new ApplicationContextInitializationCodeGenerator(); new BeanFactoryInitializationAotContributions(beanFactory).applyTo(generationContext, codeGenerator); - JavaFile javaFile = codeGenerator.generateJavaFile(generatedInitializerClassName); - generationContext.getGeneratedFiles().addSourceFile(javaFile); + GeneratedClass applicationContextInitializer = generationContext.getGeneratedClasses() + .forFeature("ApplicationContextInitializer") + .generate(codeGenerator.generateJavaFile()); + return applicationContextInitializer.getName(); } } diff --git a/spring-context/src/main/java/org/springframework/context/aot/ApplicationContextInitializationCodeGenerator.java b/spring-context/src/main/java/org/springframework/context/aot/ApplicationContextInitializationCodeGenerator.java index ab4ae35665..fe532edda2 100644 --- a/spring-context/src/main/java/org/springframework/context/aot/ApplicationContextInitializationCodeGenerator.java +++ b/spring-context/src/main/java/org/springframework/context/aot/ApplicationContextInitializationCodeGenerator.java @@ -18,6 +18,7 @@ package org.springframework.context.aot; import java.util.ArrayList; import java.util.List; +import java.util.function.Consumer; import javax.lang.model.element.Modifier; @@ -29,14 +30,10 @@ import org.springframework.beans.factory.support.DefaultListableBeanFactory; import org.springframework.context.ApplicationContextInitializer; import org.springframework.context.annotation.ContextAnnotationAutowireCandidateResolver; import org.springframework.context.support.GenericApplicationContext; -import org.springframework.javapoet.ClassName; import org.springframework.javapoet.CodeBlock; -import org.springframework.javapoet.JavaFile; import org.springframework.javapoet.MethodSpec; import org.springframework.javapoet.ParameterizedTypeName; import org.springframework.javapoet.TypeSpec; -import org.springframework.lang.Nullable; -import org.springframework.util.StringUtils; /** * Internal code generator to create the application context initializer. @@ -50,33 +47,11 @@ class ApplicationContextInitializationCodeGenerator private static final String APPLICATION_CONTEXT_VARIABLE = "applicationContext"; - @Nullable - private final Class target; - - private final String name; - private final GeneratedMethods generatedMethods = new GeneratedMethods(); private final List initializers = new ArrayList<>(); - ApplicationContextInitializationCodeGenerator(@Nullable Class target, @Nullable String name) { - this.target = target; - this.name = (!StringUtils.hasText(name)) ? "" : name; - } - - - @Override - @Nullable - public Class getTarget() { - return this.target; - } - - @Override - public String getName() { - return this.name; - } - @Override public MethodGenerator getMethodGenerator() { return this.generatedMethods; @@ -87,17 +62,17 @@ class ApplicationContextInitializationCodeGenerator this.initializers.add(methodReference); } - JavaFile generateJavaFile(ClassName className) { - TypeSpec.Builder builder = TypeSpec.classBuilder(className); - builder.addJavadoc( - "{@link $T} to restore an application context based on previous AOT processing.", - ApplicationContextInitializer.class); - builder.addModifiers(Modifier.PUBLIC); - builder.addSuperinterface(ParameterizedTypeName.get( - ApplicationContextInitializer.class, GenericApplicationContext.class)); - builder.addMethod(generateInitializeMethod()); - this.generatedMethods.doWithMethodSpecs(builder::addMethod); - return JavaFile.builder(className.packageName(), builder.build()).build(); + Consumer generateJavaFile() { + return builder -> { + builder.addJavadoc( + "{@link $T} to restore an application context based on previous AOT processing.", + ApplicationContextInitializer.class); + builder.addModifiers(Modifier.PUBLIC); + builder.addSuperinterface(ParameterizedTypeName.get( + ApplicationContextInitializer.class, GenericApplicationContext.class)); + builder.addMethod(generateInitializeMethod()); + this.generatedMethods.doWithMethodSpecs(builder::addMethod); + }; } private MethodSpec generateInitializeMethod() { diff --git a/spring-context/src/test/java/org/springframework/context/annotation/ConfigurationClassPostProcessorAotContributionTests.java b/spring-context/src/test/java/org/springframework/context/annotation/ConfigurationClassPostProcessorAotContributionTests.java index 46d7080c16..393315fd48 100644 --- a/spring-context/src/test/java/org/springframework/context/annotation/ConfigurationClassPostProcessorAotContributionTests.java +++ b/spring-context/src/test/java/org/springframework/context/annotation/ConfigurationClassPostProcessorAotContributionTests.java @@ -37,6 +37,7 @@ import org.springframework.beans.testfixture.beans.factory.aot.MockBeanFactoryIn import org.springframework.beans.testfixture.beans.factory.generator.SimpleConfiguration; import org.springframework.context.testfixture.context.generator.annotation.ImportAwareConfiguration; import org.springframework.context.testfixture.context.generator.annotation.ImportConfiguration; +import org.springframework.core.testfixture.aot.generate.TestGenerationContext; import org.springframework.javapoet.CodeBlock; import org.springframework.javapoet.JavaFile; import org.springframework.javapoet.MethodSpec; @@ -59,7 +60,7 @@ class ConfigurationClassPostProcessorAotContributionTests { private InMemoryGeneratedFiles generatedFiles = new InMemoryGeneratedFiles(); - private DefaultGenerationContext generationContext = new DefaultGenerationContext( + private DefaultGenerationContext generationContext = new TestGenerationContext( this.generatedFiles); private MockBeanFactoryInitializationCode beanFactoryInitializationCode = new MockBeanFactoryInitializationCode(); diff --git a/spring-context/src/test/java/org/springframework/context/aot/ApplicationContextAotGeneratorTests.java b/spring-context/src/test/java/org/springframework/context/aot/ApplicationContextAotGeneratorTests.java index f3a5baf964..30e30f3416 100644 --- a/spring-context/src/test/java/org/springframework/context/aot/ApplicationContextAotGeneratorTests.java +++ b/spring-context/src/test/java/org/springframework/context/aot/ApplicationContextAotGeneratorTests.java @@ -44,7 +44,7 @@ import org.springframework.context.support.GenericApplicationContext; import org.springframework.context.testfixture.context.generator.SimpleComponent; import org.springframework.context.testfixture.context.generator.annotation.AutowiredComponent; import org.springframework.context.testfixture.context.generator.annotation.InitDestroyComponent; -import org.springframework.javapoet.ClassName; +import org.springframework.core.testfixture.aot.generate.TestGenerationContext; import static org.assertj.core.api.Assertions.assertThat; @@ -56,9 +56,6 @@ import static org.assertj.core.api.Assertions.assertThat; */ class ApplicationContextAotGeneratorTests { - private static final ClassName MAIN_GENERATED_TYPE = ClassName.get("__", - "TestInitializer"); - @Test void generateApplicationContextWhenHasSimpleBean() { GenericApplicationContext applicationContext = new GenericApplicationContext(); @@ -191,10 +188,9 @@ class ApplicationContextAotGeneratorTests { BiConsumer, Compiled> result) { ApplicationContextAotGenerator generator = new ApplicationContextAotGenerator(); InMemoryGeneratedFiles generatedFiles = new InMemoryGeneratedFiles(); - DefaultGenerationContext generationContext = new DefaultGenerationContext( + DefaultGenerationContext generationContext = new TestGenerationContext( generatedFiles); - generator.generateApplicationContext(applicationContext, generationContext, - MAIN_GENERATED_TYPE); + generator.generateApplicationContext(applicationContext, generationContext); generationContext.writeGeneratedContent(); TestCompiler.forSystem().withFiles(generatedFiles) .compile(compiled -> result.accept( diff --git a/spring-context/src/test/java/org/springframework/context/aot/ReflectiveProcessorBeanRegistrationAotProcessorTests.java b/spring-context/src/test/java/org/springframework/context/aot/ReflectiveProcessorBeanRegistrationAotProcessorTests.java index 6cb85f517f..06a5829850 100644 --- a/spring-context/src/test/java/org/springframework/context/aot/ReflectiveProcessorBeanRegistrationAotProcessorTests.java +++ b/spring-context/src/test/java/org/springframework/context/aot/ReflectiveProcessorBeanRegistrationAotProcessorTests.java @@ -24,9 +24,7 @@ import java.lang.annotation.Target; import org.junit.jupiter.api.Test; -import org.springframework.aot.generate.DefaultGenerationContext; import org.springframework.aot.generate.GenerationContext; -import org.springframework.aot.generate.InMemoryGeneratedFiles; import org.springframework.aot.hint.MemberCategory; import org.springframework.aot.hint.RuntimeHints; import org.springframework.aot.hint.RuntimeHintsPredicates; @@ -39,6 +37,7 @@ import org.springframework.beans.factory.support.RegisteredBean; import org.springframework.beans.factory.support.RootBeanDefinition; import org.springframework.core.annotation.AliasFor; import org.springframework.core.annotation.SynthesizedAnnotation; +import org.springframework.core.testfixture.aot.generate.TestGenerationContext; import org.springframework.lang.Nullable; import static org.assertj.core.api.Assertions.assertThat; @@ -54,8 +53,7 @@ class ReflectiveProcessorBeanRegistrationAotProcessorTests { private final ReflectiveProcessorBeanRegistrationAotProcessor processor = new ReflectiveProcessorBeanRegistrationAotProcessor(); - private final GenerationContext generationContext = new DefaultGenerationContext( - new InMemoryGeneratedFiles()); + private final GenerationContext generationContext = new TestGenerationContext(); @Test void shouldIgnoreNonAnnotatedType() { diff --git a/spring-context/src/test/java/org/springframework/context/aot/RuntimeHintsBeanFactoryInitializationAotProcessorTests.java b/spring-context/src/test/java/org/springframework/context/aot/RuntimeHintsBeanFactoryInitializationAotProcessorTests.java index 8ff69e4bd5..16dd7185fb 100644 --- a/spring-context/src/test/java/org/springframework/context/aot/RuntimeHintsBeanFactoryInitializationAotProcessorTests.java +++ b/spring-context/src/test/java/org/springframework/context/aot/RuntimeHintsBeanFactoryInitializationAotProcessorTests.java @@ -25,9 +25,7 @@ import java.util.stream.Stream; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; -import org.springframework.aot.generate.DefaultGenerationContext; import org.springframework.aot.generate.GenerationContext; -import org.springframework.aot.generate.InMemoryGeneratedFiles; import org.springframework.aot.hint.ResourceBundleHint; import org.springframework.aot.hint.RuntimeHints; import org.springframework.aot.hint.RuntimeHintsRegistrar; @@ -38,7 +36,7 @@ import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; import org.springframework.context.annotation.ImportRuntimeHints; import org.springframework.context.support.GenericApplicationContext; -import org.springframework.javapoet.ClassName; +import org.springframework.core.testfixture.aot.generate.TestGenerationContext; import org.springframework.lang.Nullable; import static org.assertj.core.api.Assertions.assertThat; @@ -51,17 +49,13 @@ import static org.assertj.core.api.Assertions.assertThatThrownBy; */ class RuntimeHintsBeanFactoryInitializationAotProcessorTests { - private static final ClassName MAIN_GENERATED_TYPE = ClassName.get("__", - "TestInitializer"); - private GenerationContext generationContext; private ApplicationContextAotGenerator generator; @BeforeEach void setup() { - this.generationContext = new DefaultGenerationContext( - new InMemoryGeneratedFiles()); + this.generationContext = new TestGenerationContext(); this.generator = new ApplicationContextAotGenerator(); } @@ -70,7 +64,7 @@ class RuntimeHintsBeanFactoryInitializationAotProcessorTests { GenericApplicationContext applicationContext = createApplicationContext( ConfigurationWithHints.class); this.generator.generateApplicationContext(applicationContext, - this.generationContext, MAIN_GENERATED_TYPE); + this.generationContext); assertThatSampleRegistrarContributed(); } @@ -79,7 +73,7 @@ class RuntimeHintsBeanFactoryInitializationAotProcessorTests { GenericApplicationContext applicationContext = createApplicationContext( ConfigurationWithBeanDeclaringHints.class); this.generator.generateApplicationContext(applicationContext, - this.generationContext, MAIN_GENERATED_TYPE); + this.generationContext); assertThatSampleRegistrarContributed(); } @@ -89,7 +83,7 @@ class RuntimeHintsBeanFactoryInitializationAotProcessorTests { applicationContext.setClassLoader( new TestSpringFactoriesClassLoader("test-runtime-hints-aot.factories")); this.generator.generateApplicationContext(applicationContext, - this.generationContext, MAIN_GENERATED_TYPE); + this.generationContext); assertThatSampleRegistrarContributed(); } @@ -104,7 +98,7 @@ class RuntimeHintsBeanFactoryInitializationAotProcessorTests { new TestSpringFactoriesClassLoader("test-duplicated-runtime-hints-aot.factories")); IncrementalRuntimeHintsRegistrar.counter.set(0); this.generator.generateApplicationContext(applicationContext, - this.generationContext, MAIN_GENERATED_TYPE); + this.generationContext); RuntimeHints runtimeHints = this.generationContext.getRuntimeHints(); assertThat(runtimeHints.resources().resourceBundles().map(ResourceBundleHint::getBaseName)) .containsOnly("com.example.example0", "sample"); @@ -116,7 +110,7 @@ class RuntimeHintsBeanFactoryInitializationAotProcessorTests { GenericApplicationContext applicationContext = createApplicationContext( ConfigurationWithIllegalRegistrar.class); assertThatThrownBy(() -> this.generator.generateApplicationContext( - applicationContext, this.generationContext, MAIN_GENERATED_TYPE)) + applicationContext, this.generationContext)) .isInstanceOf(BeanInstantiationException.class); } diff --git a/spring-core/src/main/java/org/springframework/aot/generate/ClassGenerator.java b/spring-core/src/main/java/org/springframework/aot/generate/ClassGenerator.java deleted file mode 100644 index 2accace2a0..0000000000 --- a/spring-core/src/main/java/org/springframework/aot/generate/ClassGenerator.java +++ /dev/null @@ -1,75 +0,0 @@ -/* - * Copyright 2002-2022 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.springframework.aot.generate; - -import java.util.Collection; -import java.util.Collections; - -import org.springframework.javapoet.ClassName; -import org.springframework.javapoet.JavaFile; - -/** - * Generates new {@link GeneratedClass} instances. - * - * @author Phillip Webb - * @since 6.0 - * @see GeneratedMethods - */ -public interface ClassGenerator { - - /** - * Get or generate a new {@link GeneratedClass} for a given java file - * generator, target and feature name. - * @param javaFileGenerator the java file generator - * @param target the target of the newly generated class - * @param featureName the name of the feature that the generated class - * supports - * @return a {@link GeneratedClass} instance - */ - GeneratedClass getOrGenerateClass(JavaFileGenerator javaFileGenerator, - Class target, String featureName); - - - /** - * Strategy used to generate the java file for the generated class. - * Implementations of this interface are included as part of the key used to - * identify classes that have already been created and as such should be - * static final instances or implement a valid - * {@code equals}/{@code hashCode}. - */ - @FunctionalInterface - interface JavaFileGenerator { - - /** - * Generate the file {@link JavaFile} to be written. - * @param className the class name of the file - * @param methods the generated methods that must be included - * @return the generated files - */ - JavaFile generateJavaFile(ClassName className, GeneratedMethods methods); - - /** - * Return method names that must not be generated. - * @return the reserved method names - */ - default Collection getReservedMethodNames() { - return Collections.emptySet(); - } - - } - -} diff --git a/spring-core/src/main/java/org/springframework/aot/generate/ClassNameGenerator.java b/spring-core/src/main/java/org/springframework/aot/generate/ClassNameGenerator.java index ec6d7cfd99..02d4d6bccb 100644 --- a/spring-core/src/main/java/org/springframework/aot/generate/ClassNameGenerator.java +++ b/spring-core/src/main/java/org/springframework/aot/generate/ClassNameGenerator.java @@ -27,10 +27,9 @@ import org.springframework.util.ClassUtils; import org.springframework.util.StringUtils; /** - * Generate unique class names based on an optional target {@link Class} and - * a feature name. This class is stateful so the same instance should be used - * for all name generation. Most commonly the class name generator is obtained - * via a {@link GenerationContext}. + * Generate unique class names based on target {@link Class} and a feature + * name. This class is stateful so the same instance should be used for all + * name generation. * * @author Phillip Webb * @author Stephane Nicoll @@ -40,38 +39,92 @@ public final class ClassNameGenerator { private static final String SEPARATOR = "__"; - private static final String AOT_PACKAGE = "__."; - private static final String AOT_FEATURE = "Aot"; - private final Map sequenceGenerator = new ConcurrentHashMap<>(); + private final Class defaultTarget; + + private final String featureNamePrefix; + + private final Map sequenceGenerator; + + /** + * Create a new instance using the specified {@code defaultTarget} and no + * feature name prefix. + * @param defaultTarget the default target class to use + */ + public ClassNameGenerator(Class defaultTarget) { + this(defaultTarget, ""); + } + + /** + * Create a new instance using the specified {@code defaultTarget} and + * feature name prefix. + * @param defaultTarget the default target class to use + * @param featureNamePrefix the prefix to use to qualify feature names + */ + public ClassNameGenerator(Class defaultTarget, String featureNamePrefix) { + this(defaultTarget, featureNamePrefix, new ConcurrentHashMap<>()); + } + + private ClassNameGenerator(Class defaultTarget, String featureNamePrefix, + Map sequenceGenerator) { + this.defaultTarget = defaultTarget; + this.featureNamePrefix = (!StringUtils.hasText(featureNamePrefix) ? "" : featureNamePrefix); + this.sequenceGenerator = sequenceGenerator; + } /** - * Generate a unique {@link ClassName} based on the specified {@code target} - * class and {@code featureName}. If a {@code target} is specified, the - * generated class name is a suffixed version of it. - *

For instance, a {@code com.example.Demo} target with an - * {@code Initializer} feature name leads to a - * {@code com.example.Demo__Initializer} generated class name. If such a - * feature was already requested for this target, a counter is used to - * ensure uniqueness. - *

If there is no target, the {@code featureName} is used to generate the - * class name in the {@value #AOT_PACKAGE} package. + * Generate a unique {@link ClassName} based on the specified + * {@code featureName} and {@code target}. If the {@code target} is + * {@code null}, the configured main target of this instance is used. + *

The class name is a suffixed version of the target. For instance, a + * {@code com.example.Demo} target with an {@code Initializer} feature name + * leads to a {@code com.example.Demo__Initializer} generated class name. + * The feature name is qualified by the configured feature name prefix, + * if any. + *

Generated class names are unique. If such a feature was already + * requested for this target, a counter is used to ensure uniqueness. * @param target the class the newly generated class relates to, or - * {@code null} if there is not target + * {@code null} to use the main target * @param featureName the name of the feature that the generated class * supports * @return a unique generated class name */ public ClassName generateClassName(@Nullable Class target, String featureName) { + return generateSequencedClassName(getClassName(target, featureName)); + } + + /** + * Return a class name based on the specified {@code target} and + * {@code featureName}. This uses the same algorithm as + * {@link #generateClassName(Class, String)} but does not register + * the class name, nor add a unique suffix to it if necessary. + * @param target the class the newly generated class relates to, or + * {@code null} to use the main target + * @param featureName the name of the feature that the generated class + * supports + * @return the class name + */ + String getClassName(@Nullable Class target, String featureName) { Assert.hasLength(featureName, "'featureName' must not be empty"); featureName = clean(featureName); - if (target != null) { - return generateSequencedClassName(target.getName().replace("$", "_") - + SEPARATOR + StringUtils.capitalize(featureName)); - } - return generateSequencedClassName(AOT_PACKAGE + featureName); + Class targetToUse = (target != null ? target : this.defaultTarget); + String featureNameToUse = this.featureNamePrefix + featureName; + return targetToUse.getName().replace("$", "_") + + SEPARATOR + StringUtils.capitalize(featureNameToUse); + } + + /** + * Return a new {@link ClassNameGenerator} instance for the specified + * feature name prefix, keeping track of all the class names generated + * by this instance. + * @param featureNamePrefix the feature name prefix to use + * @return a new instance for the specified feature name prefix + */ + ClassNameGenerator usingFeatureNamePrefix(String featureNamePrefix) { + return new ClassNameGenerator(this.defaultTarget, featureNamePrefix, + this.sequenceGenerator); } private String clean(String name) { diff --git a/spring-core/src/main/java/org/springframework/aot/generate/DefaultGenerationContext.java b/spring-core/src/main/java/org/springframework/aot/generate/DefaultGenerationContext.java index 4698d28dbb..dd500d0b84 100644 --- a/spring-core/src/main/java/org/springframework/aot/generate/DefaultGenerationContext.java +++ b/spring-core/src/main/java/org/springframework/aot/generate/DefaultGenerationContext.java @@ -17,12 +17,15 @@ package org.springframework.aot.generate; import java.io.IOException; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicInteger; import org.springframework.aot.hint.RuntimeHints; import org.springframework.util.Assert; /** - * Default implementation of {@link GenerationContext}. + * Default {@link GenerationContext} implementation. * * @author Phillip Webb * @author Stephane Nicoll @@ -30,7 +33,7 @@ import org.springframework.util.Assert; */ public class DefaultGenerationContext implements GenerationContext { - private final ClassNameGenerator classNameGenerator; + private final Map sequenceGenerator; private final GeneratedClasses generatedClasses; @@ -41,39 +44,45 @@ public class DefaultGenerationContext implements GenerationContext { /** * Create a new {@link DefaultGenerationContext} instance backed by the - * specified {@code generatedFiles}. + * specified {@link ClassNameGenerator} and {@link GeneratedFiles}. + * @param classNameGenerator the naming convention to use for generated + * class names * @param generatedFiles the generated files */ - public DefaultGenerationContext(GeneratedFiles generatedFiles) { - this(new ClassNameGenerator(), generatedFiles, new RuntimeHints()); + public DefaultGenerationContext(ClassNameGenerator classNameGenerator, GeneratedFiles generatedFiles) { + this(new GeneratedClasses(classNameGenerator), generatedFiles, new RuntimeHints()); } /** * Create a new {@link DefaultGenerationContext} instance backed by the * specified items. - * @param classNameGenerator the class name generator + * @param generatedClasses the generated classes * @param generatedFiles the generated files * @param runtimeHints the runtime hints */ - public DefaultGenerationContext(ClassNameGenerator classNameGenerator, + public DefaultGenerationContext(GeneratedClasses generatedClasses, GeneratedFiles generatedFiles, RuntimeHints runtimeHints) { - Assert.notNull(classNameGenerator, "'classNameGenerator' must not be null"); + Assert.notNull(generatedClasses, "'generatedClasses' must not be null"); Assert.notNull(generatedFiles, "'generatedFiles' must not be null"); Assert.notNull(runtimeHints, "'runtimeHints' must not be null"); - this.classNameGenerator = classNameGenerator; - this.generatedClasses = new GeneratedClasses(classNameGenerator); + this.sequenceGenerator = new ConcurrentHashMap<>(); + this.generatedClasses = generatedClasses; this.generatedFiles = generatedFiles; this.runtimeHints = runtimeHints; } - - @Override - public ClassNameGenerator getClassNameGenerator() { - return this.classNameGenerator; + private DefaultGenerationContext(DefaultGenerationContext existing, String name) { + int sequence = existing.sequenceGenerator + .computeIfAbsent(name, key -> new AtomicInteger()).getAndIncrement(); + String nameToUse = (sequence > 0 ? name + sequence : name); + this.sequenceGenerator = existing.sequenceGenerator; + this.generatedClasses = existing.generatedClasses.withName(nameToUse); + this.generatedFiles = existing.generatedFiles; + this.runtimeHints = existing.runtimeHints; } @Override - public GeneratedClasses getClassGenerator() { + public GeneratedClasses getGeneratedClasses() { return this.generatedClasses; } @@ -87,6 +96,11 @@ public class DefaultGenerationContext implements GenerationContext { return this.runtimeHints; } + @Override + public GenerationContext withName(String name) { + return new DefaultGenerationContext(this, name); + } + /** * Write any generated content out to the generated files. */ diff --git a/spring-core/src/main/java/org/springframework/aot/generate/GeneratedClass.java b/spring-core/src/main/java/org/springframework/aot/generate/GeneratedClass.java index caef4730e0..12c34c76c9 100644 --- a/spring-core/src/main/java/org/springframework/aot/generate/GeneratedClass.java +++ b/spring-core/src/main/java/org/springframework/aot/generate/GeneratedClass.java @@ -16,22 +16,24 @@ package org.springframework.aot.generate; -import org.springframework.aot.generate.ClassGenerator.JavaFileGenerator; +import java.util.function.Consumer; + import org.springframework.javapoet.ClassName; import org.springframework.javapoet.JavaFile; -import org.springframework.util.Assert; +import org.springframework.javapoet.TypeSpec; +import org.springframework.javapoet.TypeSpec.Builder; /** - * A generated class. + * A generated class is a container for generated methods. * * @author Phillip Webb + * @author Stephane Nicoll * @since 6.0 * @see GeneratedClasses - * @see ClassGenerator */ public final class GeneratedClass { - private final JavaFileGenerator JavaFileGenerator; + private final Consumer typeSpecCustomizer; private final ClassName name; @@ -44,12 +46,10 @@ public final class GeneratedClass { * {@link GeneratedClasses}. * @param name the generated name */ - GeneratedClass(JavaFileGenerator javaFileGenerator, ClassName name) { - MethodNameGenerator methodNameGenerator = new MethodNameGenerator( - javaFileGenerator.getReservedMethodNames()); - this.JavaFileGenerator = javaFileGenerator; + GeneratedClass(Consumer typeSpecCustomizer, ClassName name) { + this.typeSpecCustomizer = typeSpecCustomizer; this.name = name; - this.methods = new GeneratedMethods(methodNameGenerator); + this.methods = new GeneratedMethods(new MethodNameGenerator()); } @@ -70,15 +70,11 @@ public final class GeneratedClass { } JavaFile generateJavaFile() { - JavaFile javaFile = this.JavaFileGenerator.generateJavaFile(this.name, - this.methods); - Assert.state(this.name.packageName().equals(javaFile.packageName), - () -> "Generated JavaFile should be in package '" - + this.name.packageName() + "'"); - Assert.state(this.name.simpleName().equals(javaFile.typeSpec.name), - () -> "Generated JavaFile should be named '" + this.name.simpleName() - + "'"); - return javaFile; + TypeSpec.Builder typeSpecBuilder = TypeSpec.classBuilder(this.name); + this.typeSpecCustomizer.accept(typeSpecBuilder); + this.methods.doWithMethodSpecs(typeSpecBuilder::addMethod); + return JavaFile.builder(this.name.packageName(), typeSpecBuilder.build()) + .build(); } } diff --git a/spring-core/src/main/java/org/springframework/aot/generate/GeneratedClasses.java b/spring-core/src/main/java/org/springframework/aot/generate/GeneratedClasses.java index 09e654b3a1..7d86b7faf0 100644 --- a/spring-core/src/main/java/org/springframework/aot/generate/GeneratedClasses.java +++ b/spring-core/src/main/java/org/springframework/aot/generate/GeneratedClasses.java @@ -22,59 +22,143 @@ import java.util.Comparator; import java.util.List; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; +import java.util.function.Consumer; +import org.springframework.javapoet.ClassName; +import org.springframework.javapoet.TypeSpec; +import org.springframework.lang.Nullable; import org.springframework.util.Assert; /** - * A managed collection of generated classes. + * A managed collection of generated classes. This class is stateful so the + * same instance should be used for all class generation. * * @author Phillip Webb + * @author Stephane Nicoll * @since 6.0 * @see GeneratedClass */ -public class GeneratedClasses implements ClassGenerator { +public class GeneratedClasses { private final ClassNameGenerator classNameGenerator; - private final Map classes = new ConcurrentHashMap<>(); + private final List classes; + private final Map classesByOwner; + /** + * Create a new instance using the specified naming conventions. + * @param classNameGenerator the class name generator to use + */ public GeneratedClasses(ClassNameGenerator classNameGenerator) { - Assert.notNull(classNameGenerator, "'classNameGenerator' must not be null"); - this.classNameGenerator = classNameGenerator; + this(classNameGenerator, new ArrayList<>(), new ConcurrentHashMap<>()); } - - @Override - public GeneratedClass getOrGenerateClass(JavaFileGenerator javaFileGenerator, - Class target, String featureName) { - - Assert.notNull(javaFileGenerator, "'javaFileGenerator' must not be null"); - Assert.notNull(target, "'target' must not be null"); - Assert.hasLength(featureName, "'featureName' must not be empty"); - Owner owner = new Owner(javaFileGenerator, target.getName(), featureName); - return this.classes.computeIfAbsent(owner, - key -> new GeneratedClass(javaFileGenerator, - this.classNameGenerator.generateClassName(target, featureName))); + private GeneratedClasses(ClassNameGenerator classNameGenerator, + List classes, Map classesByOwner) { + Assert.notNull(classNameGenerator, "'classNameGenerator' must not be null"); + this.classNameGenerator = classNameGenerator; + this.classes = classes; + this.classesByOwner = classesByOwner; } /** - * Write generated Spring {@code .factories} files to the given + * Prepare a {@link GeneratedClass} for the specified {@code featureName} + * targeting the specified {@code component}. + * @param featureName the name of the feature to associate with the generated class + * @param component the target component + * @return a {@link Builder} for further configuration + */ + public Builder forFeatureComponent(String featureName, Class component) { + Assert.hasLength(featureName, "'featureName' must not be empty"); + Assert.notNull(component, "'component' must not be null"); + return new Builder(featureName, component); + } + + /** + * Prepare a {@link GeneratedClass} for the specified {@code featureName} + * and no particular component. This should be used for high-level code + * generation that are widely applicable and for entry points. + * @param featureName the name of the feature to associate with the generated class + * @return a {@link Builder} for further configuration + */ + public Builder forFeature(String featureName) { + Assert.hasLength(featureName, "'featureName' must not be empty"); + return new Builder(featureName, null); + } + + /** + * Write the {@link GeneratedClass generated classes} using the given * {@link GeneratedFiles} instance. - * @param generatedFiles where to write the generated files + * @param generatedFiles where to write the generated classes * @throws IOException on IO error */ public void writeTo(GeneratedFiles generatedFiles) throws IOException { Assert.notNull(generatedFiles, "'generatedFiles' must not be null"); - List generatedClasses = new ArrayList<>(this.classes.values()); + List generatedClasses = new ArrayList<>(this.classes); generatedClasses.sort(Comparator.comparing(GeneratedClass::getName)); for (GeneratedClass generatedClass : generatedClasses) { generatedFiles.addSourceFile(generatedClass.generateJavaFile()); } } - private record Owner(JavaFileGenerator javaFileGenerator, String target, - String featureName) { + GeneratedClasses withName(String name) { + return new GeneratedClasses(this.classNameGenerator.usingFeatureNamePrefix(name), + this.classes, this.classesByOwner); + } + + private record Owner(String id, String className) { + + } + + public class Builder { + + private final String featureName; + + @Nullable + private final Class target; + + + Builder(String featureName, @Nullable Class target) { + this.target = target; + this.featureName = featureName; + } + + /** + * Generate a new {@link GeneratedClass} using the specified type + * customizer. + * @param typeSpecCustomizer a customizer for the {@link TypeSpec.Builder} + * @return a new {@link GeneratedClass} + */ + public GeneratedClass generate(Consumer typeSpecCustomizer) { + Assert.notNull(typeSpecCustomizer, "'typeSpecCustomizer' must not be null"); + return createGeneratedClass(typeSpecCustomizer); + } + + + /** + * Get or generate a new {@link GeneratedClass} for the specified {@code id}. + * @param id a unique identifier + * @param typeSpecCustomizer a customizer for the {@link TypeSpec.Builder} + * @return a {@link GeneratedClass} instance + */ + public GeneratedClass getOrGenerate(String id, + Consumer typeSpecCustomizer) { + Assert.hasLength(id, "'id' must not be empty"); + Assert.notNull(typeSpecCustomizer, "'typeSpecCustomizer' must not be null"); + Owner owner = new Owner(id, GeneratedClasses.this.classNameGenerator + .getClassName(this.target, this.featureName)); + return GeneratedClasses.this.classesByOwner.computeIfAbsent(owner, + key -> createGeneratedClass(typeSpecCustomizer)); + } + + private GeneratedClass createGeneratedClass(Consumer typeSpecCustomizer) { + ClassName className = GeneratedClasses.this.classNameGenerator + .generateClassName(this.target, this.featureName); + GeneratedClass generatedClass = new GeneratedClass(typeSpecCustomizer, className); + GeneratedClasses.this.classes.add(generatedClass); + return generatedClass; + } } diff --git a/spring-core/src/main/java/org/springframework/aot/generate/GenerationContext.java b/spring-core/src/main/java/org/springframework/aot/generate/GenerationContext.java index 2cd1c770bf..d74d97e98f 100644 --- a/spring-core/src/main/java/org/springframework/aot/generate/GenerationContext.java +++ b/spring-core/src/main/java/org/springframework/aot/generate/GenerationContext.java @@ -24,38 +24,31 @@ import org.springframework.aot.hint.SerializationHints; /** * Central interface used for code generation. - *

- * A generation context provides: + * + *

A generation context provides: *

    - *
  • Support for {@link #getClassNameGenerator() class name generation}.
  • - *
  • Central management of all {@link #getGeneratedFiles() generated - * files}.
  • - *
  • Support for the recording of {@link #getRuntimeHints() runtime - * hints}.
  • + *
  • Management of all {@link #getGeneratedClasses()} generated classes}, + * including naming convention support.
  • + *
  • Central management of all {@link #getGeneratedFiles() generated files}.
  • + *
  • Support for the recording of {@link #getRuntimeHints() runtime hints}.
  • *
* + *

If a dedicated round of code generation is required while processing, it + * is possible to create a specialized context using {@link #withName(String)}. + * * @author Phillip Webb * @author Stephane Nicoll * @since 6.0 */ public interface GenerationContext { - /** - * Return the {@link ClassNameGenerator} being used by the context. Allows - * new class names to be generated before they are added to the - * {@link #getGeneratedFiles() generated files}. - * @return the class name generator - * @see #getGeneratedFiles() - */ - ClassNameGenerator getClassNameGenerator(); - /** * Return the {@link GeneratedClasses} being used by the context. Allows a * single generated class to be shared across multiple AOT processors. All * generated classes are written at the end of AOT processing. * @return the generated classes */ - ClassGenerator getClassGenerator(); + GeneratedClasses getGeneratedClasses(); /** * Return the {@link GeneratedFiles} being used by the context. Used to @@ -73,4 +66,14 @@ public interface GenerationContext { */ RuntimeHints getRuntimeHints(); + /** + * Return a new {@link GenerationContext} instance using the specified + * name to qualify generated assets for a dedicated round of code + * generation. If this name is already in use, a unique sequence is added + * to ensure the name is unique. + * @param name the name to use + * @return a specialized {@link GenerationContext} for the specified name + */ + GenerationContext withName(String name); + } diff --git a/spring-core/src/test/java/org/springframework/aot/generate/ClassNameGeneratorTests.java b/spring-core/src/test/java/org/springframework/aot/generate/ClassNameGeneratorTests.java index f45c3e1618..ae8354341d 100644 --- a/spring-core/src/test/java/org/springframework/aot/generate/ClassNameGeneratorTests.java +++ b/spring-core/src/test/java/org/springframework/aot/generate/ClassNameGeneratorTests.java @@ -32,12 +32,26 @@ import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException */ class ClassNameGeneratorTests { - private final ClassNameGenerator generator = new ClassNameGenerator(); + private final ClassNameGenerator generator = new ClassNameGenerator(Object.class); @Test - void generateClassNameWhenTargetClassIsNullUsesAotPackage() { - ClassName generated = this.generator.generateClassName((Class) null, "test"); - assertThat(generated).hasToString("__.Test"); + void generateClassNameWhenTargetClassIsNullUsesMainTarget() { + ClassName generated = this.generator.generateClassName(null, "test"); + assertThat(generated).hasToString("java.lang.Object__Test"); + } + + @Test + void generateClassNameUseFeatureNamePrefix() { + ClassName generated = new ClassNameGenerator(Object.class, "One") + .generateClassName(InputStream.class, "test"); + assertThat(generated).hasToString("java.io.InputStream__OneTest"); + } + + @Test + void generateClassNameWithNoTextFeatureNamePrefix() { + ClassName generated = new ClassNameGenerator(Object.class, " ") + .generateClassName(InputStream.class, "test"); + assertThat(generated).hasToString("java.io.InputStream__Test"); } @Test @@ -59,8 +73,7 @@ class ClassNameGeneratorTests { @Test void generateClassNameWithClassWhenLowercaseFeatureNameGeneratesName() { - ClassName generated = this.generator.generateClassName(InputStream.class, - "bytes"); + ClassName generated = this.generator.generateClassName(InputStream.class, "bytes"); assertThat(generated).hasToString("java.io.InputStream__Bytes"); } @@ -68,7 +81,7 @@ class ClassNameGeneratorTests { void generateClassNameWithClassWhenInnerClassGeneratesName() { ClassName generated = this.generator.generateClassName(TestBean.class, "EventListener"); assertThat(generated) - .hasToString("org.springframework.aot.generate.ClassNameGeneratorTests_TestBean__EventListener"); + .hasToString("org.springframework.aot.generate.ClassNameGeneratorTests_TestBean__EventListener"); } @Test @@ -81,6 +94,15 @@ class ClassNameGeneratorTests { assertThat(generated3).hasToString("java.io.InputStream__Bytes2"); } + @Test + void getClassNameWhenMultipleCallsReturnsSameName() { + String name1 = this.generator.getClassName(InputStream.class, "bytes"); + String name2 = this.generator.getClassName(InputStream.class, "bytes"); + String name3 = this.generator.getClassName(InputStream.class, "bytes"); + assertThat(name1).hasToString("java.io.InputStream__Bytes") + .isEqualTo(name2).isEqualTo(name3); + } + static class TestBean { } diff --git a/spring-core/src/test/java/org/springframework/aot/generate/DefaultGenerationContextTests.java b/spring-core/src/test/java/org/springframework/aot/generate/DefaultGenerationContextTests.java index a6d84e87eb..306a5ae047 100644 --- a/spring-core/src/test/java/org/springframework/aot/generate/DefaultGenerationContextTests.java +++ b/spring-core/src/test/java/org/springframework/aot/generate/DefaultGenerationContextTests.java @@ -16,9 +16,14 @@ package org.springframework.aot.generate; +import java.util.function.Consumer; + import org.junit.jupiter.api.Test; +import org.springframework.aot.generate.GeneratedFiles.Kind; import org.springframework.aot.hint.RuntimeHints; +import org.springframework.core.testfixture.aot.generate.TestTarget; +import org.springframework.javapoet.TypeSpec.Builder; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; @@ -31,9 +36,12 @@ import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException */ class DefaultGenerationContextTests { - private final ClassNameGenerator classNameGenerator = new ClassNameGenerator(); + private static final Consumer typeSpecCustomizer = type -> {}; - private final GeneratedFiles generatedFiles = new InMemoryGeneratedFiles(); + private final GeneratedClasses generatedClasses = new GeneratedClasses( + new ClassNameGenerator(TestTarget.class)); + + private final InMemoryGeneratedFiles generatedFiles = new InMemoryGeneratedFiles(); private final RuntimeHints runtimeHints = new RuntimeHints(); @@ -41,9 +49,7 @@ class DefaultGenerationContextTests { @Test void createWithOnlyGeneratedFilesCreatesContext() { DefaultGenerationContext context = new DefaultGenerationContext( - this.generatedFiles); - assertThat(context.getClassNameGenerator()) - .isInstanceOf(ClassNameGenerator.class); + new ClassNameGenerator(TestTarget.class), this.generatedFiles); assertThat(context.getGeneratedFiles()).isSameAs(this.generatedFiles); assertThat(context.getRuntimeHints()).isInstanceOf(RuntimeHints.class); } @@ -51,24 +57,23 @@ class DefaultGenerationContextTests { @Test void createCreatesContext() { DefaultGenerationContext context = new DefaultGenerationContext( - this.classNameGenerator, this.generatedFiles, this.runtimeHints); - assertThat(context.getClassNameGenerator()).isNotNull(); + this.generatedClasses, this.generatedFiles, this.runtimeHints); assertThat(context.getGeneratedFiles()).isNotNull(); assertThat(context.getRuntimeHints()).isNotNull(); } @Test - void createWhenClassNameGeneratorIsNullThrowsException() { + void createWhenGeneratedClassesIsNullThrowsException() { assertThatIllegalArgumentException() .isThrownBy(() -> new DefaultGenerationContext(null, this.generatedFiles, this.runtimeHints)) - .withMessage("'classNameGenerator' must not be null"); + .withMessage("'generatedClasses' must not be null"); } @Test void createWhenGeneratedFilesIsNullThrowsException() { assertThatIllegalArgumentException() - .isThrownBy(() -> new DefaultGenerationContext(this.classNameGenerator, + .isThrownBy(() -> new DefaultGenerationContext(this.generatedClasses, null, this.runtimeHints)) .withMessage("'generatedFiles' must not be null"); } @@ -76,30 +81,71 @@ class DefaultGenerationContextTests { @Test void createWhenRuntimeHintsIsNullThrowsException() { assertThatIllegalArgumentException() - .isThrownBy(() -> new DefaultGenerationContext(this.classNameGenerator, + .isThrownBy(() -> new DefaultGenerationContext(this.generatedClasses, this.generatedFiles, null)) .withMessage("'runtimeHints' must not be null"); } @Test - void getClassNameGeneratorReturnsClassNameGenerator() { + void getGeneratedClassesReturnsClassNameGenerator() { DefaultGenerationContext context = new DefaultGenerationContext( - this.classNameGenerator, this.generatedFiles, this.runtimeHints); - assertThat(context.getClassNameGenerator()).isSameAs(this.classNameGenerator); + this.generatedClasses, this.generatedFiles, this.runtimeHints); + assertThat(context.getGeneratedClasses()).isSameAs(this.generatedClasses); } @Test void getGeneratedFilesReturnsGeneratedFiles() { DefaultGenerationContext context = new DefaultGenerationContext( - this.classNameGenerator, this.generatedFiles, this.runtimeHints); + this.generatedClasses, this.generatedFiles, this.runtimeHints); assertThat(context.getGeneratedFiles()).isSameAs(this.generatedFiles); } @Test void getRuntimeHintsReturnsRuntimeHints() { DefaultGenerationContext context = new DefaultGenerationContext( - this.classNameGenerator, this.generatedFiles, this.runtimeHints); + this.generatedClasses, this.generatedFiles, this.runtimeHints); assertThat(context.getRuntimeHints()).isSameAs(this.runtimeHints); } + @Test + void withNameUpdateNamingConvention() { + DefaultGenerationContext context = new DefaultGenerationContext( + new ClassNameGenerator(TestTarget.class), this.generatedFiles); + GenerationContext anotherContext = context.withName("Another"); + GeneratedClass generatedClass = anotherContext.getGeneratedClasses() + .forFeature("Test").generate(typeSpecCustomizer); + assertThat(generatedClass.getName().simpleName()).endsWith("__AnotherTest"); + } + + @Test + void withNameKeepTrackOfAllGeneratedFiles() { + DefaultGenerationContext context = new DefaultGenerationContext( + new ClassNameGenerator(TestTarget.class), this.generatedFiles); + context.getGeneratedClasses().forFeature("Test").generate(typeSpecCustomizer); + GenerationContext anotherContext = context.withName("Another"); + assertThat(anotherContext.getGeneratedClasses()).isNotSameAs(context.getGeneratedClasses()); + assertThat(anotherContext.getGeneratedFiles()).isSameAs(context.getGeneratedFiles()); + assertThat(anotherContext.getRuntimeHints()).isSameAs(context.getRuntimeHints()); + anotherContext.getGeneratedClasses().forFeature("Test").generate(typeSpecCustomizer); + context.writeGeneratedContent(); + assertThat(this.generatedFiles.getGeneratedFiles(Kind.SOURCE)).hasSize(2); + } + + @Test + void withNameGenerateUniqueName() { + DefaultGenerationContext context = new DefaultGenerationContext( + new ClassNameGenerator(Object.class), this.generatedFiles); + context.withName("Test").getGeneratedClasses() + .forFeature("Feature").generate(typeSpecCustomizer); + context.withName("Test").getGeneratedClasses() + .forFeature("Feature").generate(typeSpecCustomizer); + context.withName("Test").getGeneratedClasses() + .forFeature("Feature").generate(typeSpecCustomizer); + context.writeGeneratedContent(); + assertThat(this.generatedFiles.getGeneratedFiles(Kind.SOURCE)).containsOnlyKeys( + "java/lang/Object__TestFeature.java", + "java/lang/Object__Test1Feature.java", + "java/lang/Object__Test2Feature.java"); + } + } diff --git a/spring-core/src/test/java/org/springframework/aot/generate/GeneratedClassTests.java b/spring-core/src/test/java/org/springframework/aot/generate/GeneratedClassTests.java index 35ef736c8c..df0245c652 100644 --- a/spring-core/src/test/java/org/springframework/aot/generate/GeneratedClassTests.java +++ b/spring-core/src/test/java/org/springframework/aot/generate/GeneratedClassTests.java @@ -16,77 +16,43 @@ package org.springframework.aot.generate; +import java.util.function.Consumer; + import org.junit.jupiter.api.Test; import org.springframework.javapoet.ClassName; -import org.springframework.javapoet.JavaFile; -import org.springframework.javapoet.TypeSpec; +import org.springframework.javapoet.TypeSpec.Builder; import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatIllegalStateException; /** * Tests for {@link GeneratedClass}. * * @author Phillip Webb + * @author Stephane Nicoll */ class GeneratedClassTests { @Test void getNameReturnsName() { ClassName name = ClassName.bestGuess("com.example.Test"); - GeneratedClass generatedClass = new GeneratedClass(this::generateJavaFile, name); + GeneratedClass generatedClass = new GeneratedClass(emptyTypeSpec(), name); assertThat(generatedClass.getName()).isSameAs(name); } @Test - void generateJavaFileSuppliesGeneratedMethods() { + void generateJavaFileIncludesGeneratedMethods() { ClassName name = ClassName.bestGuess("com.example.Test"); - GeneratedClass generatedClass = new GeneratedClass(this::generateJavaFile, name); + GeneratedClass generatedClass = new GeneratedClass(emptyTypeSpec(), name); MethodGenerator methodGenerator = generatedClass.getMethodGenerator(); methodGenerator.generateMethod("test") .using(builder -> builder.addJavadoc("Test Method")); assertThat(generatedClass.generateJavaFile().toString()).contains("Test Method"); } - @Test - void generateJavaFileWhenHasBadPackageThrowsException() { - ClassName name = ClassName.bestGuess("com.example.Test"); - GeneratedClass generatedClass = new GeneratedClass( - this::generateBadPackageJavaFile, name); - assertThatIllegalStateException() - .isThrownBy( - () -> assertThat(generatedClass.generateJavaFile().toString())) - .withMessageContaining("should be in package"); - } - @Test - void generateJavaFileWhenHasBadNameThrowsException() { - ClassName name = ClassName.bestGuess("com.example.Test"); - GeneratedClass generatedClass = new GeneratedClass(this::generateBadNameJavaFile, - name); - assertThatIllegalStateException() - .isThrownBy( - () -> assertThat(generatedClass.generateJavaFile().toString())) - .withMessageContaining("should be named"); - } - - private JavaFile generateJavaFile(ClassName className, GeneratedMethods methods) { - TypeSpec.Builder classBuilder = TypeSpec.classBuilder(className); - methods.doWithMethodSpecs(classBuilder::addMethod); - return JavaFile.builder(className.packageName(), classBuilder.build()).build(); - } - - private JavaFile generateBadPackageJavaFile(ClassName className, - GeneratedMethods methods) { - TypeSpec.Builder classBuilder = TypeSpec.classBuilder(className); - return JavaFile.builder("naughty", classBuilder.build()).build(); - } - - private JavaFile generateBadNameJavaFile(ClassName className, - GeneratedMethods methods) { - TypeSpec.Builder classBuilder = TypeSpec.classBuilder("Naughty"); - return JavaFile.builder(className.packageName(), classBuilder.build()).build(); + private Consumer emptyTypeSpec() { + return type -> {}; } } diff --git a/spring-core/src/test/java/org/springframework/aot/generate/GeneratedClassesTests.java b/spring-core/src/test/java/org/springframework/aot/generate/GeneratedClassesTests.java index 420e610f19..7aca293319 100644 --- a/spring-core/src/test/java/org/springframework/aot/generate/GeneratedClassesTests.java +++ b/spring-core/src/test/java/org/springframework/aot/generate/GeneratedClassesTests.java @@ -16,27 +16,34 @@ package org.springframework.aot.generate; +import java.io.IOException; +import java.util.function.Consumer; + import org.junit.jupiter.api.Test; -import org.springframework.aot.generate.ClassGenerator.JavaFileGenerator; -import org.springframework.javapoet.ClassName; -import org.springframework.javapoet.JavaFile; +import org.springframework.aot.generate.GeneratedFiles.Kind; import org.springframework.javapoet.TypeSpec; +import org.springframework.javapoet.TypeSpec.Builder; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoInteractions; /** * Tests for {@link GeneratedClasses}. * * @author Phillip Webb + * @author Stephane Nicoll */ class GeneratedClassesTests { - private GeneratedClasses generatedClasses = new GeneratedClasses( - new ClassNameGenerator()); + private static final Consumer emptyTypeCustomizer = type -> {}; - private static final JavaFileGenerator JAVA_FILE_GENERATOR = GeneratedClassesTests::generateJavaFile; + private final GeneratedClasses generatedClasses = new GeneratedClasses( + new ClassNameGenerator(Object.class)); @Test void createWhenClassNameGeneratorIsNullThrowsException() { @@ -45,61 +52,118 @@ class GeneratedClassesTests { } @Test - void getOrGenerateWithClassTargetWhenJavaFileGeneratorIsNullThrowsException() { + void forFeatureComponentWhenTargetIsNullThrowsException() { assertThatIllegalArgumentException() - .isThrownBy(() -> this.generatedClasses.getOrGenerateClass(null, - TestTarget.class, "test")) - .withMessage("'javaFileGenerator' must not be null"); + .isThrownBy(() -> this.generatedClasses.forFeatureComponent("test", null)) + .withMessage("'component' must not be null"); } @Test - void getOrGenerateWithClassTargetWhenTargetIsNullThrowsException() { + void forFeatureComponentWhenFeatureNameIsEmptyThrowsException() { assertThatIllegalArgumentException() - .isThrownBy(() -> this.generatedClasses - .getOrGenerateClass(JAVA_FILE_GENERATOR, (Class) null, "test")) - .withMessage("'target' must not be null"); - } - - @Test - void getOrGenerateWithClassTargetWhenFeatureIsNullThrowsException() { - assertThatIllegalArgumentException() - .isThrownBy(() -> this.generatedClasses - .getOrGenerateClass(JAVA_FILE_GENERATOR, TestTarget.class, null)) + .isThrownBy(() -> this.generatedClasses.forFeatureComponent("", TestComponent.class)) .withMessage("'featureName' must not be empty"); } @Test - void getOrGenerateWhenNewReturnsGeneratedMethod() { + void forFeatureWhenFeatureNameIsEmptyThrowsException() { + assertThatIllegalArgumentException() + .isThrownBy(() -> this.generatedClasses.forFeature("")) + .withMessage("'featureName' must not be empty"); + } + + @Test + void generateWhenTypeSpecCustomizerIsNullThrowsException() { + assertThatIllegalArgumentException() + .isThrownBy(() -> this.generatedClasses + .forFeatureComponent("test", TestComponent.class).generate(null)) + .withMessage("'typeSpecCustomizer' must not be null"); + } + + @Test + void forFeatureUsesDefaultTarget() { + GeneratedClass generatedClass = this.generatedClasses + .forFeature("Test").generate(emptyTypeCustomizer); + assertThat(generatedClass.getName()).hasToString("java.lang.Object__Test"); + } + + @Test + void forFeatureComponentUsesComponent() { + GeneratedClass generatedClass = this.generatedClasses + .forFeatureComponent("Test", TestComponent.class).generate(emptyTypeCustomizer); + assertThat(generatedClass.getName().toString()).endsWith("TestComponent__Test"); + } + + @Test + void generateReturnsDifferentInstances() { + Consumer typeCustomizer = mockTypeCustomizer(); GeneratedClass generatedClass1 = this.generatedClasses - .getOrGenerateClass(JAVA_FILE_GENERATOR, TestTarget.class, "one"); + .forFeatureComponent("one", TestComponent.class).generate(typeCustomizer); GeneratedClass generatedClass2 = this.generatedClasses - .getOrGenerateClass(JAVA_FILE_GENERATOR, TestTarget.class, "two"); + .forFeatureComponent("one", TestComponent.class).generate(typeCustomizer); + assertThat(generatedClass1).isNotSameAs(generatedClass2); + assertThat(generatedClass1.getName().simpleName()).endsWith("__One"); + assertThat(generatedClass2.getName().simpleName()).endsWith("__One1"); + } + + @Test + void getOrGenerateWhenNewReturnsGeneratedMethod() { + Consumer typeCustomizer = mockTypeCustomizer(); + GeneratedClass generatedClass1 = this.generatedClasses + .forFeatureComponent("one", TestComponent.class).getOrGenerate("facet", typeCustomizer); + GeneratedClass generatedClass2 = this.generatedClasses + .forFeatureComponent("two", TestComponent.class).getOrGenerate("facet", typeCustomizer); assertThat(generatedClass1).isNotNull().isNotEqualTo(generatedClass2); assertThat(generatedClass2).isNotNull(); } @Test void getOrGenerateWhenRepeatReturnsSameGeneratedMethod() { - GeneratedClasses generated = this.generatedClasses; - GeneratedClass generatedClass1 = generated.getOrGenerateClass(JAVA_FILE_GENERATOR, - TestTarget.class, "one"); - GeneratedClass generatedClass2 = generated.getOrGenerateClass(JAVA_FILE_GENERATOR, - TestTarget.class, "one"); - GeneratedClass generatedClass3 = generated.getOrGenerateClass(JAVA_FILE_GENERATOR, - TestTarget.class, "one"); - GeneratedClass generatedClass4 = generated.getOrGenerateClass(JAVA_FILE_GENERATOR, - TestTarget.class, "two"); + Consumer typeCustomizer = mockTypeCustomizer(); + GeneratedClass generatedClass1 = this.generatedClasses + .forFeatureComponent("one", TestComponent.class).getOrGenerate("facet", typeCustomizer); + GeneratedClass generatedClass2 = this.generatedClasses + .forFeatureComponent("one", TestComponent.class).getOrGenerate("facet", typeCustomizer); + GeneratedClass generatedClass3 = this.generatedClasses + .forFeatureComponent("one", TestComponent.class).getOrGenerate("facet", typeCustomizer); assertThat(generatedClass1).isNotNull().isSameAs(generatedClass2) - .isSameAs(generatedClass3).isNotSameAs(generatedClass4); + .isSameAs(generatedClass3); + verifyNoInteractions(typeCustomizer); + generatedClass1.generateJavaFile(); + verify(typeCustomizer).accept(any()); } - static JavaFile generateJavaFile(ClassName className, - GeneratedMethods generatedMethods) { - TypeSpec typeSpec = TypeSpec.classBuilder(className).addJavadoc("Test").build(); - return JavaFile.builder(className.packageName(), typeSpec).build(); + @Test + @SuppressWarnings("unchecked") + void writeToInvokeTypeSpecCustomizer() throws IOException { + Consumer typeSpecCustomizer = mock(Consumer.class); + this.generatedClasses.forFeatureComponent("one", TestComponent.class) + .generate(typeSpecCustomizer); + verifyNoInteractions(typeSpecCustomizer); + InMemoryGeneratedFiles generatedFiles = new InMemoryGeneratedFiles(); + this.generatedClasses.writeTo(generatedFiles); + verify(typeSpecCustomizer).accept(any()); + assertThat(generatedFiles.getGeneratedFiles(Kind.SOURCE)).hasSize(1); } - private static class TestTarget { + @Test + void withNameUpdatesNamingConventions() { + GeneratedClass generatedClass1 = this.generatedClasses + .forFeatureComponent("one", TestComponent.class).generate(emptyTypeCustomizer); + GeneratedClass generatedClass2 = this.generatedClasses.withName("Another") + .forFeatureComponent("one", TestComponent.class).generate(emptyTypeCustomizer); + assertThat(generatedClass1.getName().toString()).endsWith("TestComponent__One"); + assertThat(generatedClass2.getName().toString()).endsWith("TestComponent__AnotherOne"); + } + + + @SuppressWarnings("unchecked") + private Consumer mockTypeCustomizer() { + return mock(Consumer.class); + } + + + private static class TestComponent { } diff --git a/spring-core/src/testFixtures/java/org/springframework/core/testfixture/aot/generate/TestGenerationContext.java b/spring-core/src/testFixtures/java/org/springframework/core/testfixture/aot/generate/TestGenerationContext.java new file mode 100644 index 0000000000..ef50d4d4ca --- /dev/null +++ b/spring-core/src/testFixtures/java/org/springframework/core/testfixture/aot/generate/TestGenerationContext.java @@ -0,0 +1,40 @@ +/* + * Copyright 2002-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.core.testfixture.aot.generate; + +import org.springframework.aot.generate.ClassNameGenerator; +import org.springframework.aot.generate.DefaultGenerationContext; +import org.springframework.aot.generate.GeneratedFiles; +import org.springframework.aot.generate.GenerationContext; +import org.springframework.aot.generate.InMemoryGeneratedFiles; + +/** + * Test {@link GenerationContext} implementation that uses + * {@link TestTarget} as the main target. + * + * @author Stephane Nicoll + */ +public class TestGenerationContext extends DefaultGenerationContext { + + public TestGenerationContext(GeneratedFiles generatedFiles) { + super(new ClassNameGenerator(TestTarget.class), generatedFiles); + } + + public TestGenerationContext() { + this(new InMemoryGeneratedFiles()); + } +} diff --git a/spring-core/src/testFixtures/java/org/springframework/core/testfixture/aot/generate/TestTarget.java b/spring-core/src/testFixtures/java/org/springframework/core/testfixture/aot/generate/TestTarget.java new file mode 100644 index 0000000000..d1b5568c28 --- /dev/null +++ b/spring-core/src/testFixtures/java/org/springframework/core/testfixture/aot/generate/TestTarget.java @@ -0,0 +1,25 @@ +/* + * Copyright 2002-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.core.testfixture.aot.generate; + +/** + * A target used by tests of code generation. + * + * @author Stephane Nicoll + */ +public class TestTarget { +} diff --git a/spring-orm/src/main/java/org/springframework/orm/jpa/support/PersistenceAnnotationBeanPostProcessor.java b/spring-orm/src/main/java/org/springframework/orm/jpa/support/PersistenceAnnotationBeanPostProcessor.java index 68b4a178ed..3363a6ec01 100644 --- a/spring-orm/src/main/java/org/springframework/orm/jpa/support/PersistenceAnnotationBeanPostProcessor.java +++ b/spring-orm/src/main/java/org/springframework/orm/jpa/support/PersistenceAnnotationBeanPostProcessor.java @@ -30,6 +30,7 @@ import java.util.Map; import java.util.Properties; import java.util.TreeSet; import java.util.concurrent.ConcurrentHashMap; +import java.util.function.Consumer; import jakarta.persistence.EntityManager; import jakarta.persistence.EntityManagerFactory; @@ -39,11 +40,10 @@ import jakarta.persistence.PersistenceProperty; import jakarta.persistence.PersistenceUnit; import jakarta.persistence.SynchronizationType; +import org.springframework.aot.generate.GeneratedClass; import org.springframework.aot.generate.GeneratedMethod; -import org.springframework.aot.generate.GeneratedMethods; import org.springframework.aot.generate.GenerationContext; import org.springframework.aot.generate.MethodGenerator; -import org.springframework.aot.generate.MethodNameGenerator; import org.springframework.aot.generate.MethodReference; import org.springframework.aot.hint.RuntimeHints; import org.springframework.beans.BeanUtils; @@ -70,11 +70,9 @@ import org.springframework.core.BridgeMethodResolver; import org.springframework.core.Ordered; import org.springframework.core.PriorityOrdered; import org.springframework.core.annotation.AnnotationUtils; -import org.springframework.javapoet.ClassName; import org.springframework.javapoet.CodeBlock; -import org.springframework.javapoet.JavaFile; import org.springframework.javapoet.MethodSpec; -import org.springframework.javapoet.TypeSpec; +import org.springframework.javapoet.MethodSpec.Builder; import org.springframework.jndi.JndiLocatorDelegate; import org.springframework.jndi.JndiTemplate; import org.springframework.lang.Nullable; @@ -789,34 +787,27 @@ public class PersistenceAnnotationBeanPostProcessor implements InstantiationAwar @Override public void applyTo(GenerationContext generationContext, BeanRegistrationCode beanRegistrationCode) { - ClassName className = generationContext.getClassNameGenerator() - .generateClassName(this.target, "PersistenceInjection"); - TypeSpec.Builder classBuilder = TypeSpec.classBuilder(className); - classBuilder.addJavadoc("Persistence injection for {@link $T}.", this.target); - classBuilder.addModifiers(javax.lang.model.element.Modifier.PUBLIC); - GeneratedMethods methods = new GeneratedMethods( - new MethodNameGenerator(APPLY_METHOD)); - classBuilder.addMethod(generateMethod(generationContext.getRuntimeHints(), - className, methods)); - methods.doWithMethodSpecs(classBuilder::addMethod); - JavaFile javaFile = JavaFile - .builder(className.packageName(), classBuilder.build()).build(); - generationContext.getGeneratedFiles().addSourceFile(javaFile); + GeneratedClass generatedClass = generationContext.getGeneratedClasses() + .forFeatureComponent("PersistenceInjection", this.target).generate(type -> { + type.addJavadoc("Persistence injection for {@link $T}.", this.target); + type.addModifiers(javax.lang.model.element.Modifier.PUBLIC); + }); + generatedClass.getMethodGenerator().generateMethod(APPLY_METHOD) + .using(generateMethod(generationContext.getRuntimeHints(), generatedClass.getMethodGenerator())); beanRegistrationCode.addInstancePostProcessor( - MethodReference.ofStatic(className, APPLY_METHOD)); + MethodReference.ofStatic(generatedClass.getName(), APPLY_METHOD)); } - private MethodSpec generateMethod(RuntimeHints hints, ClassName className, - MethodGenerator methodGenerator) { - MethodSpec.Builder builder = MethodSpec.methodBuilder(APPLY_METHOD); - builder.addJavadoc("Apply the persistence injection."); - builder.addModifiers(javax.lang.model.element.Modifier.PUBLIC, - javax.lang.model.element.Modifier.STATIC); - builder.addParameter(RegisteredBean.class, REGISTERED_BEAN_PARAMETER); - builder.addParameter(this.target, INSTANCE_PARAMETER); - builder.returns(this.target); - builder.addCode(generateMethodCode(hints, methodGenerator)); - return builder.build(); + private Consumer generateMethod(RuntimeHints hints, MethodGenerator methodGenerator) { + return method -> { + method.addJavadoc("Apply the persistence injection."); + method.addModifiers(javax.lang.model.element.Modifier.PUBLIC, + javax.lang.model.element.Modifier.STATIC); + method.addParameter(RegisteredBean.class, REGISTERED_BEAN_PARAMETER); + method.addParameter(this.target, INSTANCE_PARAMETER); + method.returns(this.target); + method.addCode(generateMethodCode(hints, methodGenerator)); + }; } private CodeBlock generateMethodCode(RuntimeHints hints, diff --git a/spring-orm/src/test/java/org/springframework/orm/jpa/support/PersistenceAnnotationBeanPostProcessorAotContributionTests.java b/spring-orm/src/test/java/org/springframework/orm/jpa/support/PersistenceAnnotationBeanPostProcessorAotContributionTests.java index f38e1da8eb..8430b95e09 100644 --- a/spring-orm/src/test/java/org/springframework/orm/jpa/support/PersistenceAnnotationBeanPostProcessorAotContributionTests.java +++ b/spring-orm/src/test/java/org/springframework/orm/jpa/support/PersistenceAnnotationBeanPostProcessorAotContributionTests.java @@ -43,6 +43,7 @@ import org.springframework.beans.factory.aot.BeanRegistrationCode; import org.springframework.beans.factory.support.DefaultListableBeanFactory; import org.springframework.beans.factory.support.RegisteredBean; import org.springframework.beans.factory.support.RootBeanDefinition; +import org.springframework.core.testfixture.aot.generate.TestGenerationContext; import org.springframework.util.ReflectionUtils; import static org.assertj.core.api.Assertions.assertThat; @@ -67,7 +68,7 @@ class PersistenceAnnotationBeanPostProcessorAotContributionTests { void setup() { this.beanFactory = new DefaultListableBeanFactory(); this.generatedFiles = new InMemoryGeneratedFiles(); - this.generationContext = new DefaultGenerationContext(generatedFiles); + this.generationContext = new TestGenerationContext(generatedFiles); } @Test @@ -183,6 +184,7 @@ class PersistenceAnnotationBeanPostProcessorAotContributionTests { .processAheadOfTime(registeredBean); BeanRegistrationCode beanRegistrationCode = mock(BeanRegistrationCode.class); contribution.applyTo(generationContext, beanRegistrationCode); + generationContext.writeGeneratedContent(); TestCompiler.forSystem().withFiles(generatedFiles) .compile(compiled -> result.accept(new Invoker(compiled), compiled)); }