From 9b07457d06d0bfc5f87d157a79a227b66b9aee2b Mon Sep 17 00:00:00 2001 From: Stephane Nicoll Date: Sun, 6 Mar 2022 18:40:27 +0100 Subject: [PATCH] Introduce ApplicationContextAotGenerator This commit introduces a way to process a GenericApplicationContext ahead of time. Components that can contribute in that phase are invoked, and their contributions are recorded in the GeneratedTypeContext. This commit also expands BeanFactoryContribution so that it can exclude bean definitions that are no longer required. Closes gh-28150 --- .../BeanDefinitionsContribution.java | 13 + .../generator/BeanFactoryContribution.java | 14 + .../BeanDefinitionsContributionTests.java | 39 +++ spring-context/spring-context.gradle | 1 + .../ApplicationContextAotGenerator.java | 172 ++++++++++ .../ApplicationContextInitialization.java | 35 ++ .../generator/InfrastructureContribution.java | 43 +++ .../context/generator/package-info.java | 10 + .../ApplicationContextAotGeneratorTests.java | 318 ++++++++++++++++++ .../InfrastructureContributionTests.java | 61 ++++ .../context/generator/SimpleComponent.java | 20 ++ .../annotation/AutowiredComponent.java | 46 +++ .../annotation/SimpleConfiguration.java | 30 ++ 13 files changed, 802 insertions(+) create mode 100644 spring-context/src/main/java/org/springframework/context/generator/ApplicationContextAotGenerator.java create mode 100644 spring-context/src/main/java/org/springframework/context/generator/ApplicationContextInitialization.java create mode 100644 spring-context/src/main/java/org/springframework/context/generator/InfrastructureContribution.java create mode 100644 spring-context/src/main/java/org/springframework/context/generator/package-info.java create mode 100644 spring-context/src/test/java/org/springframework/context/generator/ApplicationContextAotGeneratorTests.java create mode 100644 spring-context/src/test/java/org/springframework/context/generator/InfrastructureContributionTests.java create mode 100644 spring-context/src/testFixtures/java/org/springframework/context/testfixture/context/generator/SimpleComponent.java create mode 100644 spring-context/src/testFixtures/java/org/springframework/context/testfixture/context/generator/annotation/AutowiredComponent.java create mode 100644 spring-context/src/testFixtures/java/org/springframework/context/testfixture/context/generator/annotation/SimpleConfiguration.java diff --git a/spring-beans/src/main/java/org/springframework/beans/factory/generator/BeanDefinitionsContribution.java b/spring-beans/src/main/java/org/springframework/beans/factory/generator/BeanDefinitionsContribution.java index 83e6d9a012..cb07da8a17 100644 --- a/spring-beans/src/main/java/org/springframework/beans/factory/generator/BeanDefinitionsContribution.java +++ b/spring-beans/src/main/java/org/springframework/beans/factory/generator/BeanDefinitionsContribution.java @@ -20,8 +20,11 @@ import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Objects; +import java.util.function.BiPredicate; import java.util.function.Consumer; +import org.springframework.beans.factory.config.BeanDefinition; import org.springframework.beans.factory.support.DefaultListableBeanFactory; import org.springframework.beans.factory.support.RootBeanDefinition; import org.springframework.core.io.support.SpringFactoriesLoader; @@ -69,6 +72,16 @@ public class BeanDefinitionsContribution implements BeanFactoryContribution { writeBeanDefinitions(initialization); } + @Override + public BiPredicate getBeanDefinitionExcludeFilter() { + List> predicates = new ArrayList<>(); + for (String beanName : this.beanFactory.getBeanDefinitionNames()) { + handleMergedBeanDefinition(beanName, beanDefinition -> predicates.add( + getBeanRegistrationContribution(beanName, beanDefinition).getBeanDefinitionExcludeFilter())); + } + return predicates.stream().filter(Objects::nonNull).reduce((n, d) -> false, BiPredicate::or); + } + private void writeBeanDefinitions(BeanFactoryInitialization initialization) { for (String beanName : this.beanFactory.getBeanDefinitionNames()) { handleMergedBeanDefinition(beanName, beanDefinition -> { diff --git a/spring-beans/src/main/java/org/springframework/beans/factory/generator/BeanFactoryContribution.java b/spring-beans/src/main/java/org/springframework/beans/factory/generator/BeanFactoryContribution.java index 2eeeac9a98..87402bc468 100644 --- a/spring-beans/src/main/java/org/springframework/beans/factory/generator/BeanFactoryContribution.java +++ b/spring-beans/src/main/java/org/springframework/beans/factory/generator/BeanFactoryContribution.java @@ -16,6 +16,10 @@ package org.springframework.beans.factory.generator; +import java.util.function.BiPredicate; + +import org.springframework.beans.factory.config.BeanDefinition; + /** * Contribute optimizations ahead of time to initialize a bean factory. * @@ -31,4 +35,14 @@ public interface BeanFactoryContribution { */ void applyTo(BeanFactoryInitialization initialization); + /** + * Return a predicate that determines if a particular bean definition + * should be excluded from processing. Can be used to exclude infrastructure + * that has been optimized using generated code. + * @return the predicate to use + */ + default BiPredicate getBeanDefinitionExcludeFilter() { + return (beanName, beanDefinition) -> false; + } + } diff --git a/spring-beans/src/test/java/org/springframework/beans/factory/generator/BeanDefinitionsContributionTests.java b/spring-beans/src/test/java/org/springframework/beans/factory/generator/BeanDefinitionsContributionTests.java index 17b1eb0b15..68fd627099 100644 --- a/spring-beans/src/test/java/org/springframework/beans/factory/generator/BeanDefinitionsContributionTests.java +++ b/spring-beans/src/test/java/org/springframework/beans/factory/generator/BeanDefinitionsContributionTests.java @@ -17,6 +17,7 @@ package org.springframework.beans.factory.generator; import java.util.List; +import java.util.function.BiPredicate; import org.junit.jupiter.api.Test; import org.mockito.ArgumentMatchers; @@ -26,6 +27,7 @@ import org.mockito.Mockito; import org.springframework.aot.generator.DefaultGeneratedTypeContext; import org.springframework.aot.generator.GeneratedType; import org.springframework.aot.generator.GeneratedTypeContext; +import org.springframework.beans.factory.config.BeanDefinition; import org.springframework.beans.factory.support.BeanDefinitionBuilder; import org.springframework.beans.factory.support.DefaultListableBeanFactory; import org.springframework.beans.factory.support.RootBeanDefinition; @@ -90,6 +92,37 @@ class BeanDefinitionsContributionTests { """); } + @Test + void getBeanDefinitionWithNoUnderlyingContributorReturnFalseByDefault() { + DefaultListableBeanFactory beanFactory = new DefaultListableBeanFactory(); + BiPredicate excludeFilter = new BeanDefinitionsContribution(beanFactory) + .getBeanDefinitionExcludeFilter(); + assertThat(excludeFilter.test("foo", new RootBeanDefinition())).isFalse(); + } + + @Test + @SuppressWarnings("unchecked") + void getBeanDefinitionExcludeFilterWrapsUnderlyingFilter() { + DefaultListableBeanFactory beanFactory = new DefaultListableBeanFactory(); + beanFactory.registerBeanDefinition("bean1", new RootBeanDefinition()); + beanFactory.registerBeanDefinition("bean2", new RootBeanDefinition()); + BiPredicate excludeFilter1 = Mockito.mock(BiPredicate.class); + BDDMockito.given(excludeFilter1.test(ArgumentMatchers.eq("bean1"), ArgumentMatchers.any(BeanDefinition.class))).willReturn(Boolean.TRUE); + BDDMockito.given(excludeFilter1.test(ArgumentMatchers.eq("bean2"), ArgumentMatchers.any(BeanDefinition.class))).willReturn(Boolean.FALSE); + BiPredicate excludeFilter2 = Mockito.mock(BiPredicate.class); + BDDMockito.given(excludeFilter2.test(ArgumentMatchers.eq("bean2"), ArgumentMatchers.any(BeanDefinition.class))).willReturn(Boolean.TRUE); + BiPredicate excludeFilter = new BeanDefinitionsContribution(beanFactory, List.of( + new TestBeanRegistrationContributionProvider("bean1", mockExcludeFilter(excludeFilter1)), + new TestBeanRegistrationContributionProvider("bean2", mockExcludeFilter(excludeFilter2))) + ).getBeanDefinitionExcludeFilter(); + assertThat(excludeFilter.test("bean2", new RootBeanDefinition())).isTrue(); + Mockito.verify(excludeFilter1).test(ArgumentMatchers.eq("bean2"), ArgumentMatchers.any(BeanDefinition.class)); + Mockito.verify(excludeFilter2).test(ArgumentMatchers.eq("bean2"), ArgumentMatchers.any(BeanDefinition.class)); + assertThat(excludeFilter.test("bean1", new RootBeanDefinition())).isTrue(); + Mockito.verify(excludeFilter1).test(ArgumentMatchers.eq("bean1"), ArgumentMatchers.any(BeanDefinition.class)); + Mockito.verifyNoMoreInteractions(excludeFilter2); + } + private CodeSnippet contribute(DefaultListableBeanFactory beanFactory, GeneratedTypeContext generationContext) { BeanDefinitionsContribution contribution = new BeanDefinitionsContribution(beanFactory); BeanFactoryInitialization initialization = new BeanFactoryInitialization(generationContext); @@ -102,6 +135,12 @@ class BeanDefinitionsContributionTests { GeneratedType.of(ClassName.get(packageName, "Test"))); } + private BeanFactoryContribution mockExcludeFilter(BiPredicate excludeFilter) { + BeanFactoryContribution contribution = Mockito.mock(BeanFactoryContribution.class); + BDDMockito.given(contribution.getBeanDefinitionExcludeFilter()).willReturn(excludeFilter); + return contribution; + } + static class TestBeanRegistrationContributionProvider implements BeanRegistrationContributionProvider { private final String beanName; diff --git a/spring-context/spring-context.gradle b/spring-context/spring-context.gradle index ffacfd8eb5..b038d42892 100644 --- a/spring-context/spring-context.gradle +++ b/spring-context/spring-context.gradle @@ -25,6 +25,7 @@ dependencies { testImplementation(testFixtures(project(":spring-aop"))) testImplementation(testFixtures(project(":spring-beans"))) testImplementation(testFixtures(project(":spring-core"))) + testImplementation(project(":spring-core-test")) testImplementation("io.projectreactor:reactor-core") testImplementation("org.apache.groovy:groovy-jsr223") testImplementation("org.apache.groovy:groovy-xml") diff --git a/spring-context/src/main/java/org/springframework/context/generator/ApplicationContextAotGenerator.java b/spring-context/src/main/java/org/springframework/context/generator/ApplicationContextAotGenerator.java new file mode 100644 index 0000000000..5e800e0a8e --- /dev/null +++ b/spring-context/src/main/java/org/springframework/context/generator/ApplicationContextAotGenerator.java @@ -0,0 +1,172 @@ +/* + * Copyright 2002-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.context.generator; + +import java.util.ArrayList; +import java.util.Comparator; +import java.util.List; +import java.util.Objects; +import java.util.function.BiPredicate; +import java.util.stream.Stream; + +import javax.lang.model.element.Modifier; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; + +import org.springframework.aot.generator.GeneratedType; +import org.springframework.aot.generator.GeneratedTypeContext; +import org.springframework.beans.factory.BeanFactory; +import org.springframework.beans.factory.config.BeanDefinition; +import org.springframework.beans.factory.config.ConfigurableListableBeanFactory; +import org.springframework.beans.factory.generator.AotContributingBeanFactoryPostProcessor; +import org.springframework.beans.factory.generator.AotContributingBeanPostProcessor; +import org.springframework.beans.factory.generator.BeanDefinitionsContribution; +import org.springframework.beans.factory.generator.BeanFactoryContribution; +import org.springframework.beans.factory.support.DefaultListableBeanFactory; +import org.springframework.context.ApplicationContext; +import org.springframework.context.ApplicationContextInitializer; +import org.springframework.context.support.GenericApplicationContext; +import org.springframework.core.OrderComparator; +import org.springframework.javapoet.CodeBlock; +import org.springframework.javapoet.MethodSpec; +import org.springframework.javapoet.ParameterizedTypeName; + +/** + * Process an {@link ApplicationContext} and its {@link BeanFactory} to generate + * code that represents the state of the bean factory, as well as the necessary + * hints that can be used at runtime in a constrained environment. + * + * @author Stephane Nicoll + * @since 6.0 + */ +public class ApplicationContextAotGenerator { + + private static final Log logger = LogFactory.getLog(ApplicationContextAotGenerator.class); + + /** + * Refresh the specified {@link GenericApplicationContext} and generate the + * necessary code to restore the state of its {@link BeanFactory}, using the + * specified {@link GeneratedTypeContext}. + * @param applicationContext the application context to handle + * @param generationContext the generation context to use + */ + public void generateApplicationContext(GenericApplicationContext applicationContext, + GeneratedTypeContext generationContext) { + applicationContext.refreshForAotProcessing(); + + DefaultListableBeanFactory beanFactory = applicationContext.getDefaultListableBeanFactory(); + List contributions = resolveBeanFactoryContributions(beanFactory); + + filterBeanFactory(contributions, beanFactory); + ApplicationContextInitialization applicationContextInitialization = new ApplicationContextInitialization(generationContext); + applyContributions(contributions, applicationContextInitialization); + + GeneratedType mainGeneratedType = generationContext.getMainGeneratedType(); + mainGeneratedType.customizeType(type -> type.addSuperinterface(ParameterizedTypeName.get( + ApplicationContextInitializer.class, GenericApplicationContext.class))); + mainGeneratedType.addMethod(initializeMethod(applicationContextInitialization.toCodeBlock())); + } + + private MethodSpec.Builder initializeMethod(CodeBlock methodBody) { + MethodSpec.Builder method = MethodSpec.methodBuilder("initialize").addModifiers(Modifier.PUBLIC) + .addParameter(GenericApplicationContext.class, "context").addAnnotation(Override.class); + method.addCode(methodBody); + return method; + } + + private void filterBeanFactory(List contributions, DefaultListableBeanFactory beanFactory) { + BiPredicate filter = Stream.concat(Stream.of(aotContributingExcludeFilter()), + contributions.stream().map(BeanFactoryContribution::getBeanDefinitionExcludeFilter)) + .filter(Objects::nonNull).reduce((n, d) -> false, BiPredicate::or); + for (String beanName : beanFactory.getBeanDefinitionNames()) { + BeanDefinition bd = beanFactory.getMergedBeanDefinition(beanName); + if (filter.test(beanName, bd)) { + if (logger.isDebugEnabled()) { + logger.debug("Filtering out bean with name" + beanName + ": " + bd); + } + beanFactory.removeBeanDefinition(beanName); + } + } + } + + // TODO: is this right? + private BiPredicate aotContributingExcludeFilter() { + return (beanName, beanDefinition) -> { + Class type = beanDefinition.getResolvableType().toClass(); + return AotContributingBeanFactoryPostProcessor.class.isAssignableFrom(type) || + AotContributingBeanPostProcessor.class.isAssignableFrom(type); + }; + } + + + private void applyContributions(List contributions, + ApplicationContextInitialization initialization) { + for (BeanFactoryContribution contribution : contributions) { + contribution.applyTo(initialization); + } + } + + /** + * Resolve the {@link BeanFactoryContribution} available in the specified + * bean factory. Infrastructure is contributed first, and bean definitions + * registration last. + * @param beanFactory the bean factory to process + * @return the contribution to apply + * @see InfrastructureContribution + * @see BeanDefinitionsContribution + */ + private List resolveBeanFactoryContributions(DefaultListableBeanFactory beanFactory) { + List contributions = new ArrayList<>(); + contributions.add(new InfrastructureContribution()); + List postProcessors = getAotContributingBeanFactoryPostProcessors(beanFactory); + for (AotContributingBeanFactoryPostProcessor postProcessor : postProcessors) { + BeanFactoryContribution contribution = postProcessor.contribute(beanFactory); + if (contribution != null) { + contributions.add(contribution); + } + } + contributions.add(new BeanDefinitionsContribution(beanFactory)); + return contributions; + } + + private static List getAotContributingBeanFactoryPostProcessors(DefaultListableBeanFactory beanFactory) { + String[] postProcessorNames = beanFactory.getBeanNamesForType(AotContributingBeanFactoryPostProcessor.class, true, false); + List postProcessors = new ArrayList<>(); + for (String ppName : postProcessorNames) { + postProcessors.add(beanFactory.getBean(ppName, AotContributingBeanFactoryPostProcessor.class)); + } + sortPostProcessors(postProcessors, beanFactory); + return postProcessors; + } + + private static void sortPostProcessors(List postProcessors, ConfigurableListableBeanFactory beanFactory) { + // Nothing to sort? + if (postProcessors.size() <= 1) { + return; + } + Comparator comparatorToUse = null; + if (beanFactory instanceof DefaultListableBeanFactory) { + comparatorToUse = ((DefaultListableBeanFactory) beanFactory).getDependencyComparator(); + } + if (comparatorToUse == null) { + comparatorToUse = OrderComparator.INSTANCE; + } + postProcessors.sort(comparatorToUse); + } + +} diff --git a/spring-context/src/main/java/org/springframework/context/generator/ApplicationContextInitialization.java b/spring-context/src/main/java/org/springframework/context/generator/ApplicationContextInitialization.java new file mode 100644 index 0000000000..9f2704f2c0 --- /dev/null +++ b/spring-context/src/main/java/org/springframework/context/generator/ApplicationContextInitialization.java @@ -0,0 +1,35 @@ +/* + * Copyright 2002-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.context.generator; + +import org.springframework.aot.generator.GeneratedTypeContext; +import org.springframework.beans.factory.generator.BeanFactoryInitialization; +import org.springframework.context.ApplicationContext; + +/** + * The initialization of an {@link ApplicationContext}. + * + * @author Andy Wilkinson + * @since 6.0 + */ +public class ApplicationContextInitialization extends BeanFactoryInitialization { + + public ApplicationContextInitialization(GeneratedTypeContext generatedTypeContext) { + super(generatedTypeContext); + } + +} diff --git a/spring-context/src/main/java/org/springframework/context/generator/InfrastructureContribution.java b/spring-context/src/main/java/org/springframework/context/generator/InfrastructureContribution.java new file mode 100644 index 0000000000..6a0ce89e3a --- /dev/null +++ b/spring-context/src/main/java/org/springframework/context/generator/InfrastructureContribution.java @@ -0,0 +1,43 @@ +/* + * Copyright 2002-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.context.generator; + +import org.springframework.beans.factory.generator.BeanFactoryContribution; +import org.springframework.beans.factory.generator.BeanFactoryInitialization; +import org.springframework.beans.factory.support.DefaultListableBeanFactory; +import org.springframework.context.annotation.ContextAnnotationAutowireCandidateResolver; + +/** + * A {@link BeanFactoryContribution} that configures the low-level + * infrastructure necessary to process an AOT context. + * + * @author Stephane Nicoll + */ +class InfrastructureContribution implements BeanFactoryContribution { + + @Override + public void applyTo(BeanFactoryInitialization initialization) { + initialization.contribute(code -> { + code.add("// infrastructure\n"); + code.addStatement("$T beanFactory = context.getDefaultListableBeanFactory()", + DefaultListableBeanFactory.class); + code.addStatement("beanFactory.setAutowireCandidateResolver(new $T())", + ContextAnnotationAutowireCandidateResolver.class); + }); + } + +} diff --git a/spring-context/src/main/java/org/springframework/context/generator/package-info.java b/spring-context/src/main/java/org/springframework/context/generator/package-info.java new file mode 100644 index 0000000000..af0aaf7294 --- /dev/null +++ b/spring-context/src/main/java/org/springframework/context/generator/package-info.java @@ -0,0 +1,10 @@ +/** + * Support for generating code that represents the state of an application + * context. + */ +@NonNullApi +@NonNullFields +package org.springframework.context.generator; + +import org.springframework.lang.NonNullApi; +import org.springframework.lang.NonNullFields; diff --git a/spring-context/src/test/java/org/springframework/context/generator/ApplicationContextAotGeneratorTests.java b/spring-context/src/test/java/org/springframework/context/generator/ApplicationContextAotGeneratorTests.java new file mode 100644 index 0000000000..4935de2965 --- /dev/null +++ b/spring-context/src/test/java/org/springframework/context/generator/ApplicationContextAotGeneratorTests.java @@ -0,0 +1,318 @@ +/* + * Copyright 2002-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.context.generator; + +import java.io.IOException; +import java.io.StringWriter; +import java.util.function.BiPredicate; +import java.util.function.Consumer; +import java.util.function.Supplier; + +import org.junit.jupiter.api.Test; + +import org.springframework.aot.generator.DefaultGeneratedTypeContext; +import org.springframework.aot.generator.GeneratedType; +import org.springframework.aot.generator.GeneratedTypeContext; +import org.springframework.aot.test.generator.compile.TestCompiler; +import org.springframework.aot.test.generator.file.SourceFile; +import org.springframework.aot.test.generator.file.SourceFiles; +import org.springframework.beans.factory.annotation.AutowiredAnnotationBeanPostProcessor; +import org.springframework.beans.factory.config.BeanDefinition; +import org.springframework.beans.factory.config.ConfigurableListableBeanFactory; +import org.springframework.beans.factory.generator.AotContributingBeanFactoryPostProcessor; +import org.springframework.beans.factory.generator.AotContributingBeanPostProcessor; +import org.springframework.beans.factory.generator.BeanFactoryContribution; +import org.springframework.beans.factory.generator.BeanFactoryInitialization; +import org.springframework.beans.factory.generator.BeanInstantiationContribution; +import org.springframework.beans.factory.support.BeanDefinitionBuilder; +import org.springframework.beans.factory.support.DefaultListableBeanFactory; +import org.springframework.beans.factory.support.RootBeanDefinition; +import org.springframework.context.ApplicationContextInitializer; +import org.springframework.context.annotation.AnnotationConfigUtils; +import org.springframework.context.support.GenericApplicationContext; +import org.springframework.context.testfixture.context.generator.SimpleComponent; +import org.springframework.context.testfixture.context.generator.annotation.AutowiredComponent; +import org.springframework.javapoet.ClassName; +import org.springframework.javapoet.CodeBlock; +import org.springframework.javapoet.JavaFile; +import org.springframework.lang.Nullable; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; + +/** + * Tests for {@link ApplicationContextAotGenerator}. + * + * @author Stephane Nicoll + */ +class ApplicationContextAotGeneratorTests { + + private static final ClassName MAIN_GENERATED_TYPE = ClassName.get("com.example", "Test"); + + @Test + void generateApplicationContextWithSimpleBean() { + GenericApplicationContext context = new GenericApplicationContext(); + context.registerBeanDefinition("test", new RootBeanDefinition(SimpleComponent.class)); + compile(context, toFreshApplicationContext(GenericApplicationContext::new, aotContext -> { + assertThat(aotContext.getBeanDefinitionNames()).containsOnly("test"); + assertThat(aotContext.getBean("test")).isInstanceOf(SimpleComponent.class); + })); + } + + @Test + void generateApplicationContextWithAutowiring() { + GenericApplicationContext context = new GenericApplicationContext(); + context.registerBeanDefinition(AnnotationConfigUtils.AUTOWIRED_ANNOTATION_PROCESSOR_BEAN_NAME, + BeanDefinitionBuilder.rootBeanDefinition(AutowiredAnnotationBeanPostProcessor.class) + .setRole(BeanDefinition.ROLE_INFRASTRUCTURE).getBeanDefinition()); + context.registerBeanDefinition("autowiredComponent", new RootBeanDefinition(AutowiredComponent.class)); + context.registerBeanDefinition("number", BeanDefinitionBuilder.rootBeanDefinition(Integer.class, "valueOf") + .addConstructorArgValue("42").getBeanDefinition()); + compile(context, toFreshApplicationContext(GenericApplicationContext::new, aotContext -> { + assertThat(aotContext.getBeanDefinitionNames()).containsOnly("autowiredComponent", "number"); + AutowiredComponent bean = aotContext.getBean(AutowiredComponent.class); + assertThat(bean.getEnvironment()).isSameAs(aotContext.getEnvironment()); + assertThat(bean.getCounter()).isEqualTo(42L); + })); + } + + @Test + void generateApplicationContextWitNoContributors() { + GeneratedTypeContext generationContext = createGenerationContext(); + ApplicationContextAotGenerator generator = new ApplicationContextAotGenerator(); + generator.generateApplicationContext(new GenericApplicationContext(), generationContext); + assertThat(write(generationContext.getMainGeneratedType())).contains(""" + public class Test implements ApplicationContextInitializer { + @Override + public void initialize(GenericApplicationContext context) { + // infrastructure + DefaultListableBeanFactory beanFactory = context.getDefaultListableBeanFactory(); + beanFactory.setAutowireCandidateResolver(new ContextAnnotationAutowireCandidateResolver()); + } + } + """); + } + + @Test + void generateApplicationContextApplyContributionAsIsWithNewLineAtTheEnd() { + GenericApplicationContext applicationContext = new GenericApplicationContext(); + registerAotContributingBeanDefinition(applicationContext, "bpp", code -> code.add("// Hello")); + GeneratedTypeContext generationContext = createGenerationContext(); + ApplicationContextAotGenerator generator = new ApplicationContextAotGenerator(); + generator.generateApplicationContext(applicationContext, generationContext); + assertThat(write(generationContext.getMainGeneratedType())).contains(""" + public class Test implements ApplicationContextInitializer { + @Override + public void initialize(GenericApplicationContext context) { + // infrastructure + DefaultListableBeanFactory beanFactory = context.getDefaultListableBeanFactory(); + beanFactory.setAutowireCandidateResolver(new ContextAnnotationAutowireCandidateResolver()); + // Hello + } + } + """); + } + + @Test + void generateApplicationContextApplyMultipleContributionAsIsWithNewLineAtTheEnd() { + GeneratedTypeContext generationContext = createGenerationContext(); + GenericApplicationContext applicationContext = new GenericApplicationContext(); + registerAotContributingBeanDefinition(applicationContext, "bpp", code -> code.add("// Hello")); + registerAotContributingBeanDefinition(applicationContext, "bpp2", code -> code.add("// World")); + ApplicationContextAotGenerator generator = new ApplicationContextAotGenerator(); + generator.generateApplicationContext(applicationContext, generationContext); + assertThat(write(generationContext.getMainGeneratedType())).contains(""" + public class Test implements ApplicationContextInitializer { + @Override + public void initialize(GenericApplicationContext context) { + // infrastructure + DefaultListableBeanFactory beanFactory = context.getDefaultListableBeanFactory(); + beanFactory.setAutowireCandidateResolver(new ContextAnnotationAutowireCandidateResolver()); + // Hello + // World + } + } + """); + } + + @Test + void generateApplicationContextExcludeAotContributingBeanFactoryPostProcessorByDefault() { + GenericApplicationContext context = new GenericApplicationContext(); + context.registerBeanDefinition("test", new RootBeanDefinition(NoOpAotContributingBeanFactoryPostProcessor.class)); + compile(context, toFreshApplicationContext(GenericApplicationContext::new, aotContext -> + assertThat(aotContext.getBeanDefinitionNames()).isEmpty())); + } + + @Test + void generateApplicationContextExcludeAotContributingBeanPostProcessorByDefault() { + GenericApplicationContext context = new GenericApplicationContext(); + context.registerBeanDefinition("test", new RootBeanDefinition(NoOpAotContributingBeanPostProcessor.class)); + compile(context, toFreshApplicationContext(GenericApplicationContext::new, aotContext -> + assertThat(aotContext.getBeanDefinitionNames()).isEmpty())); + } + + @Test + void generateApplicationContextInvokeExcludePredicateInOrder() { + GeneratedTypeContext generationContext = createGenerationContext(); + GenericApplicationContext applicationContext = new GenericApplicationContext(); + DefaultListableBeanFactory beanFactory = applicationContext.getDefaultListableBeanFactory(); + BiPredicate excludeFilter = mock(BiPredicate.class); + given(excludeFilter.test(eq("bean1"), any(BeanDefinition.class))).willReturn(Boolean.FALSE); + given(excludeFilter.test(eq("bean2"), any(BeanDefinition.class))).willReturn(Boolean.TRUE); + applicationContext.registerBeanDefinition("bean2", new RootBeanDefinition(SimpleComponent.class)); + applicationContext.registerBeanDefinition("bean1", new RootBeanDefinition(SimpleComponent.class)); + registerAotContributingBeanDefinition(applicationContext, "bpp", code -> {}, excludeFilter); + ApplicationContextAotGenerator generator = new ApplicationContextAotGenerator(); + generator.generateApplicationContext(applicationContext, generationContext); + assertThat(write(generationContext.getMainGeneratedType())) + .doesNotContain("bean2").doesNotContain("bpp") + .contains("BeanDefinitionRegistrar.of(\"bean1\", SimpleComponent.class)"); + verify(excludeFilter).test(eq("bean2"), any(BeanDefinition.class)); + verify(excludeFilter).test("bean1", beanFactory.getMergedBeanDefinition("bean1")); + } + + + @SuppressWarnings("rawtypes") + private void compile(GenericApplicationContext applicationContext, Consumer initializer) { + DefaultGeneratedTypeContext generationContext = createGenerationContext(); + ApplicationContextAotGenerator generator = new ApplicationContextAotGenerator(); + generator.generateApplicationContext(applicationContext, generationContext); + SourceFiles sourceFiles = SourceFiles.none(); + for (JavaFile javaFile : generationContext.toJavaFiles()) { + sourceFiles = sourceFiles.and(SourceFile.of((javaFile::writeTo))); + } + TestCompiler.forSystem().withSources(sourceFiles).compile(compiled -> { + ApplicationContextInitializer instance = compiled.getInstance(ApplicationContextInitializer.class, MAIN_GENERATED_TYPE.canonicalName()); + initializer.accept(instance); + }); + } + + @SuppressWarnings({ "rawtypes", "unchecked" }) + private Consumer toFreshApplicationContext( + Supplier applicationContextFactory, Consumer context) { + return applicationContextInitializer -> { + T applicationContext = applicationContextFactory.get(); + applicationContextInitializer.initialize(applicationContext); + applicationContext.refresh(); + context.accept(applicationContext); + }; + } + + private DefaultGeneratedTypeContext createGenerationContext() { + return new DefaultGeneratedTypeContext(MAIN_GENERATED_TYPE.packageName(), packageName -> + GeneratedType.of(ClassName.get(packageName, MAIN_GENERATED_TYPE.simpleName()))); + } + + private String write(GeneratedType generatedType) { + try { + StringWriter out = new StringWriter(); + generatedType.toJavaFile().writeTo(out); + return out.toString(); + } + catch (IOException ex) { + throw new IllegalStateException("Failed to write " + generatedType, ex); + } + } + + private void registerAotContributingBeanDefinition(GenericApplicationContext context, String name, + Consumer code) { + registerAotContributingBeanDefinition(context, name, code, + (beanName, beanDefinition) -> name.equals(beanName)); + } + + private void registerAotContributingBeanDefinition(GenericApplicationContext context, String name, + Consumer code, BiPredicate excludeFilter) { + BeanFactoryContribution contribution = new TestBeanFactoryContribution( + initialization -> initialization.contribute(code), excludeFilter); + context.registerBeanDefinition(name, BeanDefinitionBuilder.rootBeanDefinition( + TestAotContributingBeanFactoryPostProcessor.class, () -> + new TestAotContributingBeanFactoryPostProcessor(contribution)).getBeanDefinition()); + } + + + static class TestAotContributingBeanFactoryPostProcessor implements AotContributingBeanFactoryPostProcessor { + + @Nullable + private final BeanFactoryContribution beanFactoryContribution; + + TestAotContributingBeanFactoryPostProcessor(@Nullable BeanFactoryContribution beanFactoryContribution) { + this.beanFactoryContribution = beanFactoryContribution; + } + + TestAotContributingBeanFactoryPostProcessor() { + this(null); + } + + @Override + public BeanFactoryContribution contribute(ConfigurableListableBeanFactory beanFactory) { + return this.beanFactoryContribution; + } + + } + + static class NoOpAotContributingBeanFactoryPostProcessor implements AotContributingBeanFactoryPostProcessor { + + @Override + public BeanFactoryContribution contribute(ConfigurableListableBeanFactory beanFactory) { + return null; + } + } + + static class NoOpAotContributingBeanPostProcessor implements AotContributingBeanPostProcessor { + + @Override + public BeanInstantiationContribution contribute(RootBeanDefinition beanDefinition, Class beanType, String beanName) { + return null; + } + + @Override + public int getOrder() { + return 0; + } + + } + + static class TestBeanFactoryContribution implements BeanFactoryContribution { + private final Consumer contribution; + + private final BiPredicate excludeFilter; + + private int order; + + public TestBeanFactoryContribution(Consumer contribution, + BiPredicate excludeFilter) { + this.contribution = contribution; + this.excludeFilter = excludeFilter; + } + + @Override + public void applyTo(BeanFactoryInitialization initialization) { + this.contribution.accept(initialization); + } + + @Override + public BiPredicate getBeanDefinitionExcludeFilter() { + return this.excludeFilter; + } + + } + +} diff --git a/spring-context/src/test/java/org/springframework/context/generator/InfrastructureContributionTests.java b/spring-context/src/test/java/org/springframework/context/generator/InfrastructureContributionTests.java new file mode 100644 index 0000000000..0616826eb1 --- /dev/null +++ b/spring-context/src/test/java/org/springframework/context/generator/InfrastructureContributionTests.java @@ -0,0 +1,61 @@ +/* + * Copyright 2002-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.context.generator; + +import org.junit.jupiter.api.Test; + +import org.springframework.aot.generator.DefaultGeneratedTypeContext; +import org.springframework.aot.generator.GeneratedType; +import org.springframework.aot.generator.GeneratedTypeContext; +import org.springframework.beans.factory.generator.BeanFactoryInitialization; +import org.springframework.context.annotation.ContextAnnotationAutowireCandidateResolver; +import org.springframework.javapoet.ClassName; +import org.springframework.javapoet.support.CodeSnippet; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Tests for {@link InfrastructureContribution}. + * + * @author Stephane Nicoll + */ +class InfrastructureContributionTests { + + @Test + void contributeInfrastructure() { + CodeSnippet codeSnippet = contribute(createGenerationContext()); + assertThat(codeSnippet.getSnippet()).isEqualTo(""" + // infrastructure + DefaultListableBeanFactory beanFactory = context.getDefaultListableBeanFactory(); + beanFactory.setAutowireCandidateResolver(new ContextAnnotationAutowireCandidateResolver()); + """); + assertThat(codeSnippet.hasImport(ContextAnnotationAutowireCandidateResolver.class)).isTrue(); + } + + private CodeSnippet contribute(GeneratedTypeContext generationContext) { + InfrastructureContribution contribution = new InfrastructureContribution(); + BeanFactoryInitialization initialization = new BeanFactoryInitialization(generationContext); + contribution.applyTo(initialization); + return CodeSnippet.of(initialization.toCodeBlock()); + } + + private GeneratedTypeContext createGenerationContext() { + return new DefaultGeneratedTypeContext("com.example", packageName -> + GeneratedType.of(ClassName.get(packageName, "Test"))); + } + +} diff --git a/spring-context/src/testFixtures/java/org/springframework/context/testfixture/context/generator/SimpleComponent.java b/spring-context/src/testFixtures/java/org/springframework/context/testfixture/context/generator/SimpleComponent.java new file mode 100644 index 0000000000..7004b8e0c6 --- /dev/null +++ b/spring-context/src/testFixtures/java/org/springframework/context/testfixture/context/generator/SimpleComponent.java @@ -0,0 +1,20 @@ +/* + * Copyright 2002-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.context.testfixture.context.generator; + +public class SimpleComponent { +} diff --git a/spring-context/src/testFixtures/java/org/springframework/context/testfixture/context/generator/annotation/AutowiredComponent.java b/spring-context/src/testFixtures/java/org/springframework/context/testfixture/context/generator/annotation/AutowiredComponent.java new file mode 100644 index 0000000000..8bc7dae191 --- /dev/null +++ b/spring-context/src/testFixtures/java/org/springframework/context/testfixture/context/generator/annotation/AutowiredComponent.java @@ -0,0 +1,46 @@ +/* + * Copyright 2002-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.context.testfixture.context.generator.annotation; + +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.core.env.Environment; + +public class AutowiredComponent { + + private Environment environment; + + private Integer counter; + + public Environment getEnvironment() { + return this.environment; + } + + @Autowired + public void setEnvironment(Environment environment) { + this.environment = environment; + } + + public Integer getCounter() { + return this.counter; + } + + @Autowired + public void setCounter(Integer counter) { + this.counter = counter; + } + +} diff --git a/spring-context/src/testFixtures/java/org/springframework/context/testfixture/context/generator/annotation/SimpleConfiguration.java b/spring-context/src/testFixtures/java/org/springframework/context/testfixture/context/generator/annotation/SimpleConfiguration.java new file mode 100644 index 0000000000..34cd687999 --- /dev/null +++ b/spring-context/src/testFixtures/java/org/springframework/context/testfixture/context/generator/annotation/SimpleConfiguration.java @@ -0,0 +1,30 @@ +/* + * Copyright 2002-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.context.testfixture.context.generator.annotation; + +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; + +@Configuration(proxyBeanMethods = false) +public class SimpleConfiguration { + + @Bean + public String stringBean() { + return "Hello"; + } + +}