diff --git a/spring-beans/src/main/java/org/springframework/beans/factory/generator/BeanDefinitionsContribution.java b/spring-beans/src/main/java/org/springframework/beans/factory/generator/BeanDefinitionsContribution.java index 83e6d9a012..cb07da8a17 100644 --- a/spring-beans/src/main/java/org/springframework/beans/factory/generator/BeanDefinitionsContribution.java +++ b/spring-beans/src/main/java/org/springframework/beans/factory/generator/BeanDefinitionsContribution.java @@ -20,8 +20,11 @@ import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Objects; +import java.util.function.BiPredicate; import java.util.function.Consumer; +import org.springframework.beans.factory.config.BeanDefinition; import org.springframework.beans.factory.support.DefaultListableBeanFactory; import org.springframework.beans.factory.support.RootBeanDefinition; import org.springframework.core.io.support.SpringFactoriesLoader; @@ -69,6 +72,16 @@ public class BeanDefinitionsContribution implements BeanFactoryContribution { writeBeanDefinitions(initialization); } + @Override + public BiPredicate getBeanDefinitionExcludeFilter() { + List> predicates = new ArrayList<>(); + for (String beanName : this.beanFactory.getBeanDefinitionNames()) { + handleMergedBeanDefinition(beanName, beanDefinition -> predicates.add( + getBeanRegistrationContribution(beanName, beanDefinition).getBeanDefinitionExcludeFilter())); + } + return predicates.stream().filter(Objects::nonNull).reduce((n, d) -> false, BiPredicate::or); + } + private void writeBeanDefinitions(BeanFactoryInitialization initialization) { for (String beanName : this.beanFactory.getBeanDefinitionNames()) { handleMergedBeanDefinition(beanName, beanDefinition -> { diff --git a/spring-beans/src/main/java/org/springframework/beans/factory/generator/BeanFactoryContribution.java b/spring-beans/src/main/java/org/springframework/beans/factory/generator/BeanFactoryContribution.java index 2eeeac9a98..87402bc468 100644 --- a/spring-beans/src/main/java/org/springframework/beans/factory/generator/BeanFactoryContribution.java +++ b/spring-beans/src/main/java/org/springframework/beans/factory/generator/BeanFactoryContribution.java @@ -16,6 +16,10 @@ package org.springframework.beans.factory.generator; +import java.util.function.BiPredicate; + +import org.springframework.beans.factory.config.BeanDefinition; + /** * Contribute optimizations ahead of time to initialize a bean factory. * @@ -31,4 +35,14 @@ public interface BeanFactoryContribution { */ void applyTo(BeanFactoryInitialization initialization); + /** + * Return a predicate that determines if a particular bean definition + * should be excluded from processing. Can be used to exclude infrastructure + * that has been optimized using generated code. + * @return the predicate to use + */ + default BiPredicate getBeanDefinitionExcludeFilter() { + return (beanName, beanDefinition) -> false; + } + } diff --git a/spring-beans/src/test/java/org/springframework/beans/factory/generator/BeanDefinitionsContributionTests.java b/spring-beans/src/test/java/org/springframework/beans/factory/generator/BeanDefinitionsContributionTests.java index 17b1eb0b15..68fd627099 100644 --- a/spring-beans/src/test/java/org/springframework/beans/factory/generator/BeanDefinitionsContributionTests.java +++ b/spring-beans/src/test/java/org/springframework/beans/factory/generator/BeanDefinitionsContributionTests.java @@ -17,6 +17,7 @@ package org.springframework.beans.factory.generator; import java.util.List; +import java.util.function.BiPredicate; import org.junit.jupiter.api.Test; import org.mockito.ArgumentMatchers; @@ -26,6 +27,7 @@ import org.mockito.Mockito; import org.springframework.aot.generator.DefaultGeneratedTypeContext; import org.springframework.aot.generator.GeneratedType; import org.springframework.aot.generator.GeneratedTypeContext; +import org.springframework.beans.factory.config.BeanDefinition; import org.springframework.beans.factory.support.BeanDefinitionBuilder; import org.springframework.beans.factory.support.DefaultListableBeanFactory; import org.springframework.beans.factory.support.RootBeanDefinition; @@ -90,6 +92,37 @@ class BeanDefinitionsContributionTests { """); } + @Test + void getBeanDefinitionWithNoUnderlyingContributorReturnFalseByDefault() { + DefaultListableBeanFactory beanFactory = new DefaultListableBeanFactory(); + BiPredicate excludeFilter = new BeanDefinitionsContribution(beanFactory) + .getBeanDefinitionExcludeFilter(); + assertThat(excludeFilter.test("foo", new RootBeanDefinition())).isFalse(); + } + + @Test + @SuppressWarnings("unchecked") + void getBeanDefinitionExcludeFilterWrapsUnderlyingFilter() { + DefaultListableBeanFactory beanFactory = new DefaultListableBeanFactory(); + beanFactory.registerBeanDefinition("bean1", new RootBeanDefinition()); + beanFactory.registerBeanDefinition("bean2", new RootBeanDefinition()); + BiPredicate excludeFilter1 = Mockito.mock(BiPredicate.class); + BDDMockito.given(excludeFilter1.test(ArgumentMatchers.eq("bean1"), ArgumentMatchers.any(BeanDefinition.class))).willReturn(Boolean.TRUE); + BDDMockito.given(excludeFilter1.test(ArgumentMatchers.eq("bean2"), ArgumentMatchers.any(BeanDefinition.class))).willReturn(Boolean.FALSE); + BiPredicate excludeFilter2 = Mockito.mock(BiPredicate.class); + BDDMockito.given(excludeFilter2.test(ArgumentMatchers.eq("bean2"), ArgumentMatchers.any(BeanDefinition.class))).willReturn(Boolean.TRUE); + BiPredicate excludeFilter = new BeanDefinitionsContribution(beanFactory, List.of( + new TestBeanRegistrationContributionProvider("bean1", mockExcludeFilter(excludeFilter1)), + new TestBeanRegistrationContributionProvider("bean2", mockExcludeFilter(excludeFilter2))) + ).getBeanDefinitionExcludeFilter(); + assertThat(excludeFilter.test("bean2", new RootBeanDefinition())).isTrue(); + Mockito.verify(excludeFilter1).test(ArgumentMatchers.eq("bean2"), ArgumentMatchers.any(BeanDefinition.class)); + Mockito.verify(excludeFilter2).test(ArgumentMatchers.eq("bean2"), ArgumentMatchers.any(BeanDefinition.class)); + assertThat(excludeFilter.test("bean1", new RootBeanDefinition())).isTrue(); + Mockito.verify(excludeFilter1).test(ArgumentMatchers.eq("bean1"), ArgumentMatchers.any(BeanDefinition.class)); + Mockito.verifyNoMoreInteractions(excludeFilter2); + } + private CodeSnippet contribute(DefaultListableBeanFactory beanFactory, GeneratedTypeContext generationContext) { BeanDefinitionsContribution contribution = new BeanDefinitionsContribution(beanFactory); BeanFactoryInitialization initialization = new BeanFactoryInitialization(generationContext); @@ -102,6 +135,12 @@ class BeanDefinitionsContributionTests { GeneratedType.of(ClassName.get(packageName, "Test"))); } + private BeanFactoryContribution mockExcludeFilter(BiPredicate excludeFilter) { + BeanFactoryContribution contribution = Mockito.mock(BeanFactoryContribution.class); + BDDMockito.given(contribution.getBeanDefinitionExcludeFilter()).willReturn(excludeFilter); + return contribution; + } + static class TestBeanRegistrationContributionProvider implements BeanRegistrationContributionProvider { private final String beanName; diff --git a/spring-context/spring-context.gradle b/spring-context/spring-context.gradle index ffacfd8eb5..b038d42892 100644 --- a/spring-context/spring-context.gradle +++ b/spring-context/spring-context.gradle @@ -25,6 +25,7 @@ dependencies { testImplementation(testFixtures(project(":spring-aop"))) testImplementation(testFixtures(project(":spring-beans"))) testImplementation(testFixtures(project(":spring-core"))) + testImplementation(project(":spring-core-test")) testImplementation("io.projectreactor:reactor-core") testImplementation("org.apache.groovy:groovy-jsr223") testImplementation("org.apache.groovy:groovy-xml") diff --git a/spring-context/src/main/java/org/springframework/context/generator/ApplicationContextAotGenerator.java b/spring-context/src/main/java/org/springframework/context/generator/ApplicationContextAotGenerator.java new file mode 100644 index 0000000000..5e800e0a8e --- /dev/null +++ b/spring-context/src/main/java/org/springframework/context/generator/ApplicationContextAotGenerator.java @@ -0,0 +1,172 @@ +/* + * Copyright 2002-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.context.generator; + +import java.util.ArrayList; +import java.util.Comparator; +import java.util.List; +import java.util.Objects; +import java.util.function.BiPredicate; +import java.util.stream.Stream; + +import javax.lang.model.element.Modifier; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; + +import org.springframework.aot.generator.GeneratedType; +import org.springframework.aot.generator.GeneratedTypeContext; +import org.springframework.beans.factory.BeanFactory; +import org.springframework.beans.factory.config.BeanDefinition; +import org.springframework.beans.factory.config.ConfigurableListableBeanFactory; +import org.springframework.beans.factory.generator.AotContributingBeanFactoryPostProcessor; +import org.springframework.beans.factory.generator.AotContributingBeanPostProcessor; +import org.springframework.beans.factory.generator.BeanDefinitionsContribution; +import org.springframework.beans.factory.generator.BeanFactoryContribution; +import org.springframework.beans.factory.support.DefaultListableBeanFactory; +import org.springframework.context.ApplicationContext; +import org.springframework.context.ApplicationContextInitializer; +import org.springframework.context.support.GenericApplicationContext; +import org.springframework.core.OrderComparator; +import org.springframework.javapoet.CodeBlock; +import org.springframework.javapoet.MethodSpec; +import org.springframework.javapoet.ParameterizedTypeName; + +/** + * Process an {@link ApplicationContext} and its {@link BeanFactory} to generate + * code that represents the state of the bean factory, as well as the necessary + * hints that can be used at runtime in a constrained environment. + * + * @author Stephane Nicoll + * @since 6.0 + */ +public class ApplicationContextAotGenerator { + + private static final Log logger = LogFactory.getLog(ApplicationContextAotGenerator.class); + + /** + * Refresh the specified {@link GenericApplicationContext} and generate the + * necessary code to restore the state of its {@link BeanFactory}, using the + * specified {@link GeneratedTypeContext}. + * @param applicationContext the application context to handle + * @param generationContext the generation context to use + */ + public void generateApplicationContext(GenericApplicationContext applicationContext, + GeneratedTypeContext generationContext) { + applicationContext.refreshForAotProcessing(); + + DefaultListableBeanFactory beanFactory = applicationContext.getDefaultListableBeanFactory(); + List contributions = resolveBeanFactoryContributions(beanFactory); + + filterBeanFactory(contributions, beanFactory); + ApplicationContextInitialization applicationContextInitialization = new ApplicationContextInitialization(generationContext); + applyContributions(contributions, applicationContextInitialization); + + GeneratedType mainGeneratedType = generationContext.getMainGeneratedType(); + mainGeneratedType.customizeType(type -> type.addSuperinterface(ParameterizedTypeName.get( + ApplicationContextInitializer.class, GenericApplicationContext.class))); + mainGeneratedType.addMethod(initializeMethod(applicationContextInitialization.toCodeBlock())); + } + + private MethodSpec.Builder initializeMethod(CodeBlock methodBody) { + MethodSpec.Builder method = MethodSpec.methodBuilder("initialize").addModifiers(Modifier.PUBLIC) + .addParameter(GenericApplicationContext.class, "context").addAnnotation(Override.class); + method.addCode(methodBody); + return method; + } + + private void filterBeanFactory(List contributions, DefaultListableBeanFactory beanFactory) { + BiPredicate filter = Stream.concat(Stream.of(aotContributingExcludeFilter()), + contributions.stream().map(BeanFactoryContribution::getBeanDefinitionExcludeFilter)) + .filter(Objects::nonNull).reduce((n, d) -> false, BiPredicate::or); + for (String beanName : beanFactory.getBeanDefinitionNames()) { + BeanDefinition bd = beanFactory.getMergedBeanDefinition(beanName); + if (filter.test(beanName, bd)) { + if (logger.isDebugEnabled()) { + logger.debug("Filtering out bean with name" + beanName + ": " + bd); + } + beanFactory.removeBeanDefinition(beanName); + } + } + } + + // TODO: is this right? + private BiPredicate aotContributingExcludeFilter() { + return (beanName, beanDefinition) -> { + Class type = beanDefinition.getResolvableType().toClass(); + return AotContributingBeanFactoryPostProcessor.class.isAssignableFrom(type) || + AotContributingBeanPostProcessor.class.isAssignableFrom(type); + }; + } + + + private void applyContributions(List contributions, + ApplicationContextInitialization initialization) { + for (BeanFactoryContribution contribution : contributions) { + contribution.applyTo(initialization); + } + } + + /** + * Resolve the {@link BeanFactoryContribution} available in the specified + * bean factory. Infrastructure is contributed first, and bean definitions + * registration last. + * @param beanFactory the bean factory to process + * @return the contribution to apply + * @see InfrastructureContribution + * @see BeanDefinitionsContribution + */ + private List resolveBeanFactoryContributions(DefaultListableBeanFactory beanFactory) { + List contributions = new ArrayList<>(); + contributions.add(new InfrastructureContribution()); + List postProcessors = getAotContributingBeanFactoryPostProcessors(beanFactory); + for (AotContributingBeanFactoryPostProcessor postProcessor : postProcessors) { + BeanFactoryContribution contribution = postProcessor.contribute(beanFactory); + if (contribution != null) { + contributions.add(contribution); + } + } + contributions.add(new BeanDefinitionsContribution(beanFactory)); + return contributions; + } + + private static List getAotContributingBeanFactoryPostProcessors(DefaultListableBeanFactory beanFactory) { + String[] postProcessorNames = beanFactory.getBeanNamesForType(AotContributingBeanFactoryPostProcessor.class, true, false); + List postProcessors = new ArrayList<>(); + for (String ppName : postProcessorNames) { + postProcessors.add(beanFactory.getBean(ppName, AotContributingBeanFactoryPostProcessor.class)); + } + sortPostProcessors(postProcessors, beanFactory); + return postProcessors; + } + + private static void sortPostProcessors(List postProcessors, ConfigurableListableBeanFactory beanFactory) { + // Nothing to sort? + if (postProcessors.size() <= 1) { + return; + } + Comparator comparatorToUse = null; + if (beanFactory instanceof DefaultListableBeanFactory) { + comparatorToUse = ((DefaultListableBeanFactory) beanFactory).getDependencyComparator(); + } + if (comparatorToUse == null) { + comparatorToUse = OrderComparator.INSTANCE; + } + postProcessors.sort(comparatorToUse); + } + +} diff --git a/spring-context/src/main/java/org/springframework/context/generator/ApplicationContextInitialization.java b/spring-context/src/main/java/org/springframework/context/generator/ApplicationContextInitialization.java new file mode 100644 index 0000000000..9f2704f2c0 --- /dev/null +++ b/spring-context/src/main/java/org/springframework/context/generator/ApplicationContextInitialization.java @@ -0,0 +1,35 @@ +/* + * Copyright 2002-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.context.generator; + +import org.springframework.aot.generator.GeneratedTypeContext; +import org.springframework.beans.factory.generator.BeanFactoryInitialization; +import org.springframework.context.ApplicationContext; + +/** + * The initialization of an {@link ApplicationContext}. + * + * @author Andy Wilkinson + * @since 6.0 + */ +public class ApplicationContextInitialization extends BeanFactoryInitialization { + + public ApplicationContextInitialization(GeneratedTypeContext generatedTypeContext) { + super(generatedTypeContext); + } + +} diff --git a/spring-context/src/main/java/org/springframework/context/generator/InfrastructureContribution.java b/spring-context/src/main/java/org/springframework/context/generator/InfrastructureContribution.java new file mode 100644 index 0000000000..6a0ce89e3a --- /dev/null +++ b/spring-context/src/main/java/org/springframework/context/generator/InfrastructureContribution.java @@ -0,0 +1,43 @@ +/* + * Copyright 2002-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.context.generator; + +import org.springframework.beans.factory.generator.BeanFactoryContribution; +import org.springframework.beans.factory.generator.BeanFactoryInitialization; +import org.springframework.beans.factory.support.DefaultListableBeanFactory; +import org.springframework.context.annotation.ContextAnnotationAutowireCandidateResolver; + +/** + * A {@link BeanFactoryContribution} that configures the low-level + * infrastructure necessary to process an AOT context. + * + * @author Stephane Nicoll + */ +class InfrastructureContribution implements BeanFactoryContribution { + + @Override + public void applyTo(BeanFactoryInitialization initialization) { + initialization.contribute(code -> { + code.add("// infrastructure\n"); + code.addStatement("$T beanFactory = context.getDefaultListableBeanFactory()", + DefaultListableBeanFactory.class); + code.addStatement("beanFactory.setAutowireCandidateResolver(new $T())", + ContextAnnotationAutowireCandidateResolver.class); + }); + } + +} diff --git a/spring-context/src/main/java/org/springframework/context/generator/package-info.java b/spring-context/src/main/java/org/springframework/context/generator/package-info.java new file mode 100644 index 0000000000..af0aaf7294 --- /dev/null +++ b/spring-context/src/main/java/org/springframework/context/generator/package-info.java @@ -0,0 +1,10 @@ +/** + * Support for generating code that represents the state of an application + * context. + */ +@NonNullApi +@NonNullFields +package org.springframework.context.generator; + +import org.springframework.lang.NonNullApi; +import org.springframework.lang.NonNullFields; diff --git a/spring-context/src/test/java/org/springframework/context/generator/ApplicationContextAotGeneratorTests.java b/spring-context/src/test/java/org/springframework/context/generator/ApplicationContextAotGeneratorTests.java new file mode 100644 index 0000000000..4935de2965 --- /dev/null +++ b/spring-context/src/test/java/org/springframework/context/generator/ApplicationContextAotGeneratorTests.java @@ -0,0 +1,318 @@ +/* + * Copyright 2002-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.context.generator; + +import java.io.IOException; +import java.io.StringWriter; +import java.util.function.BiPredicate; +import java.util.function.Consumer; +import java.util.function.Supplier; + +import org.junit.jupiter.api.Test; + +import org.springframework.aot.generator.DefaultGeneratedTypeContext; +import org.springframework.aot.generator.GeneratedType; +import org.springframework.aot.generator.GeneratedTypeContext; +import org.springframework.aot.test.generator.compile.TestCompiler; +import org.springframework.aot.test.generator.file.SourceFile; +import org.springframework.aot.test.generator.file.SourceFiles; +import org.springframework.beans.factory.annotation.AutowiredAnnotationBeanPostProcessor; +import org.springframework.beans.factory.config.BeanDefinition; +import org.springframework.beans.factory.config.ConfigurableListableBeanFactory; +import org.springframework.beans.factory.generator.AotContributingBeanFactoryPostProcessor; +import org.springframework.beans.factory.generator.AotContributingBeanPostProcessor; +import org.springframework.beans.factory.generator.BeanFactoryContribution; +import org.springframework.beans.factory.generator.BeanFactoryInitialization; +import org.springframework.beans.factory.generator.BeanInstantiationContribution; +import org.springframework.beans.factory.support.BeanDefinitionBuilder; +import org.springframework.beans.factory.support.DefaultListableBeanFactory; +import org.springframework.beans.factory.support.RootBeanDefinition; +import org.springframework.context.ApplicationContextInitializer; +import org.springframework.context.annotation.AnnotationConfigUtils; +import org.springframework.context.support.GenericApplicationContext; +import org.springframework.context.testfixture.context.generator.SimpleComponent; +import org.springframework.context.testfixture.context.generator.annotation.AutowiredComponent; +import org.springframework.javapoet.ClassName; +import org.springframework.javapoet.CodeBlock; +import org.springframework.javapoet.JavaFile; +import org.springframework.lang.Nullable; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; + +/** + * Tests for {@link ApplicationContextAotGenerator}. + * + * @author Stephane Nicoll + */ +class ApplicationContextAotGeneratorTests { + + private static final ClassName MAIN_GENERATED_TYPE = ClassName.get("com.example", "Test"); + + @Test + void generateApplicationContextWithSimpleBean() { + GenericApplicationContext context = new GenericApplicationContext(); + context.registerBeanDefinition("test", new RootBeanDefinition(SimpleComponent.class)); + compile(context, toFreshApplicationContext(GenericApplicationContext::new, aotContext -> { + assertThat(aotContext.getBeanDefinitionNames()).containsOnly("test"); + assertThat(aotContext.getBean("test")).isInstanceOf(SimpleComponent.class); + })); + } + + @Test + void generateApplicationContextWithAutowiring() { + GenericApplicationContext context = new GenericApplicationContext(); + context.registerBeanDefinition(AnnotationConfigUtils.AUTOWIRED_ANNOTATION_PROCESSOR_BEAN_NAME, + BeanDefinitionBuilder.rootBeanDefinition(AutowiredAnnotationBeanPostProcessor.class) + .setRole(BeanDefinition.ROLE_INFRASTRUCTURE).getBeanDefinition()); + context.registerBeanDefinition("autowiredComponent", new RootBeanDefinition(AutowiredComponent.class)); + context.registerBeanDefinition("number", BeanDefinitionBuilder.rootBeanDefinition(Integer.class, "valueOf") + .addConstructorArgValue("42").getBeanDefinition()); + compile(context, toFreshApplicationContext(GenericApplicationContext::new, aotContext -> { + assertThat(aotContext.getBeanDefinitionNames()).containsOnly("autowiredComponent", "number"); + AutowiredComponent bean = aotContext.getBean(AutowiredComponent.class); + assertThat(bean.getEnvironment()).isSameAs(aotContext.getEnvironment()); + assertThat(bean.getCounter()).isEqualTo(42L); + })); + } + + @Test + void generateApplicationContextWitNoContributors() { + GeneratedTypeContext generationContext = createGenerationContext(); + ApplicationContextAotGenerator generator = new ApplicationContextAotGenerator(); + generator.generateApplicationContext(new GenericApplicationContext(), generationContext); + assertThat(write(generationContext.getMainGeneratedType())).contains(""" + public class Test implements ApplicationContextInitializer { + @Override + public void initialize(GenericApplicationContext context) { + // infrastructure + DefaultListableBeanFactory beanFactory = context.getDefaultListableBeanFactory(); + beanFactory.setAutowireCandidateResolver(new ContextAnnotationAutowireCandidateResolver()); + } + } + """); + } + + @Test + void generateApplicationContextApplyContributionAsIsWithNewLineAtTheEnd() { + GenericApplicationContext applicationContext = new GenericApplicationContext(); + registerAotContributingBeanDefinition(applicationContext, "bpp", code -> code.add("// Hello")); + GeneratedTypeContext generationContext = createGenerationContext(); + ApplicationContextAotGenerator generator = new ApplicationContextAotGenerator(); + generator.generateApplicationContext(applicationContext, generationContext); + assertThat(write(generationContext.getMainGeneratedType())).contains(""" + public class Test implements ApplicationContextInitializer { + @Override + public void initialize(GenericApplicationContext context) { + // infrastructure + DefaultListableBeanFactory beanFactory = context.getDefaultListableBeanFactory(); + beanFactory.setAutowireCandidateResolver(new ContextAnnotationAutowireCandidateResolver()); + // Hello + } + } + """); + } + + @Test + void generateApplicationContextApplyMultipleContributionAsIsWithNewLineAtTheEnd() { + GeneratedTypeContext generationContext = createGenerationContext(); + GenericApplicationContext applicationContext = new GenericApplicationContext(); + registerAotContributingBeanDefinition(applicationContext, "bpp", code -> code.add("// Hello")); + registerAotContributingBeanDefinition(applicationContext, "bpp2", code -> code.add("// World")); + ApplicationContextAotGenerator generator = new ApplicationContextAotGenerator(); + generator.generateApplicationContext(applicationContext, generationContext); + assertThat(write(generationContext.getMainGeneratedType())).contains(""" + public class Test implements ApplicationContextInitializer { + @Override + public void initialize(GenericApplicationContext context) { + // infrastructure + DefaultListableBeanFactory beanFactory = context.getDefaultListableBeanFactory(); + beanFactory.setAutowireCandidateResolver(new ContextAnnotationAutowireCandidateResolver()); + // Hello + // World + } + } + """); + } + + @Test + void generateApplicationContextExcludeAotContributingBeanFactoryPostProcessorByDefault() { + GenericApplicationContext context = new GenericApplicationContext(); + context.registerBeanDefinition("test", new RootBeanDefinition(NoOpAotContributingBeanFactoryPostProcessor.class)); + compile(context, toFreshApplicationContext(GenericApplicationContext::new, aotContext -> + assertThat(aotContext.getBeanDefinitionNames()).isEmpty())); + } + + @Test + void generateApplicationContextExcludeAotContributingBeanPostProcessorByDefault() { + GenericApplicationContext context = new GenericApplicationContext(); + context.registerBeanDefinition("test", new RootBeanDefinition(NoOpAotContributingBeanPostProcessor.class)); + compile(context, toFreshApplicationContext(GenericApplicationContext::new, aotContext -> + assertThat(aotContext.getBeanDefinitionNames()).isEmpty())); + } + + @Test + void generateApplicationContextInvokeExcludePredicateInOrder() { + GeneratedTypeContext generationContext = createGenerationContext(); + GenericApplicationContext applicationContext = new GenericApplicationContext(); + DefaultListableBeanFactory beanFactory = applicationContext.getDefaultListableBeanFactory(); + BiPredicate excludeFilter = mock(BiPredicate.class); + given(excludeFilter.test(eq("bean1"), any(BeanDefinition.class))).willReturn(Boolean.FALSE); + given(excludeFilter.test(eq("bean2"), any(BeanDefinition.class))).willReturn(Boolean.TRUE); + applicationContext.registerBeanDefinition("bean2", new RootBeanDefinition(SimpleComponent.class)); + applicationContext.registerBeanDefinition("bean1", new RootBeanDefinition(SimpleComponent.class)); + registerAotContributingBeanDefinition(applicationContext, "bpp", code -> {}, excludeFilter); + ApplicationContextAotGenerator generator = new ApplicationContextAotGenerator(); + generator.generateApplicationContext(applicationContext, generationContext); + assertThat(write(generationContext.getMainGeneratedType())) + .doesNotContain("bean2").doesNotContain("bpp") + .contains("BeanDefinitionRegistrar.of(\"bean1\", SimpleComponent.class)"); + verify(excludeFilter).test(eq("bean2"), any(BeanDefinition.class)); + verify(excludeFilter).test("bean1", beanFactory.getMergedBeanDefinition("bean1")); + } + + + @SuppressWarnings("rawtypes") + private void compile(GenericApplicationContext applicationContext, Consumer initializer) { + DefaultGeneratedTypeContext generationContext = createGenerationContext(); + ApplicationContextAotGenerator generator = new ApplicationContextAotGenerator(); + generator.generateApplicationContext(applicationContext, generationContext); + SourceFiles sourceFiles = SourceFiles.none(); + for (JavaFile javaFile : generationContext.toJavaFiles()) { + sourceFiles = sourceFiles.and(SourceFile.of((javaFile::writeTo))); + } + TestCompiler.forSystem().withSources(sourceFiles).compile(compiled -> { + ApplicationContextInitializer instance = compiled.getInstance(ApplicationContextInitializer.class, MAIN_GENERATED_TYPE.canonicalName()); + initializer.accept(instance); + }); + } + + @SuppressWarnings({ "rawtypes", "unchecked" }) + private Consumer toFreshApplicationContext( + Supplier applicationContextFactory, Consumer context) { + return applicationContextInitializer -> { + T applicationContext = applicationContextFactory.get(); + applicationContextInitializer.initialize(applicationContext); + applicationContext.refresh(); + context.accept(applicationContext); + }; + } + + private DefaultGeneratedTypeContext createGenerationContext() { + return new DefaultGeneratedTypeContext(MAIN_GENERATED_TYPE.packageName(), packageName -> + GeneratedType.of(ClassName.get(packageName, MAIN_GENERATED_TYPE.simpleName()))); + } + + private String write(GeneratedType generatedType) { + try { + StringWriter out = new StringWriter(); + generatedType.toJavaFile().writeTo(out); + return out.toString(); + } + catch (IOException ex) { + throw new IllegalStateException("Failed to write " + generatedType, ex); + } + } + + private void registerAotContributingBeanDefinition(GenericApplicationContext context, String name, + Consumer code) { + registerAotContributingBeanDefinition(context, name, code, + (beanName, beanDefinition) -> name.equals(beanName)); + } + + private void registerAotContributingBeanDefinition(GenericApplicationContext context, String name, + Consumer code, BiPredicate excludeFilter) { + BeanFactoryContribution contribution = new TestBeanFactoryContribution( + initialization -> initialization.contribute(code), excludeFilter); + context.registerBeanDefinition(name, BeanDefinitionBuilder.rootBeanDefinition( + TestAotContributingBeanFactoryPostProcessor.class, () -> + new TestAotContributingBeanFactoryPostProcessor(contribution)).getBeanDefinition()); + } + + + static class TestAotContributingBeanFactoryPostProcessor implements AotContributingBeanFactoryPostProcessor { + + @Nullable + private final BeanFactoryContribution beanFactoryContribution; + + TestAotContributingBeanFactoryPostProcessor(@Nullable BeanFactoryContribution beanFactoryContribution) { + this.beanFactoryContribution = beanFactoryContribution; + } + + TestAotContributingBeanFactoryPostProcessor() { + this(null); + } + + @Override + public BeanFactoryContribution contribute(ConfigurableListableBeanFactory beanFactory) { + return this.beanFactoryContribution; + } + + } + + static class NoOpAotContributingBeanFactoryPostProcessor implements AotContributingBeanFactoryPostProcessor { + + @Override + public BeanFactoryContribution contribute(ConfigurableListableBeanFactory beanFactory) { + return null; + } + } + + static class NoOpAotContributingBeanPostProcessor implements AotContributingBeanPostProcessor { + + @Override + public BeanInstantiationContribution contribute(RootBeanDefinition beanDefinition, Class beanType, String beanName) { + return null; + } + + @Override + public int getOrder() { + return 0; + } + + } + + static class TestBeanFactoryContribution implements BeanFactoryContribution { + private final Consumer contribution; + + private final BiPredicate excludeFilter; + + private int order; + + public TestBeanFactoryContribution(Consumer contribution, + BiPredicate excludeFilter) { + this.contribution = contribution; + this.excludeFilter = excludeFilter; + } + + @Override + public void applyTo(BeanFactoryInitialization initialization) { + this.contribution.accept(initialization); + } + + @Override + public BiPredicate getBeanDefinitionExcludeFilter() { + return this.excludeFilter; + } + + } + +} diff --git a/spring-context/src/test/java/org/springframework/context/generator/InfrastructureContributionTests.java b/spring-context/src/test/java/org/springframework/context/generator/InfrastructureContributionTests.java new file mode 100644 index 0000000000..0616826eb1 --- /dev/null +++ b/spring-context/src/test/java/org/springframework/context/generator/InfrastructureContributionTests.java @@ -0,0 +1,61 @@ +/* + * Copyright 2002-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.context.generator; + +import org.junit.jupiter.api.Test; + +import org.springframework.aot.generator.DefaultGeneratedTypeContext; +import org.springframework.aot.generator.GeneratedType; +import org.springframework.aot.generator.GeneratedTypeContext; +import org.springframework.beans.factory.generator.BeanFactoryInitialization; +import org.springframework.context.annotation.ContextAnnotationAutowireCandidateResolver; +import org.springframework.javapoet.ClassName; +import org.springframework.javapoet.support.CodeSnippet; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Tests for {@link InfrastructureContribution}. + * + * @author Stephane Nicoll + */ +class InfrastructureContributionTests { + + @Test + void contributeInfrastructure() { + CodeSnippet codeSnippet = contribute(createGenerationContext()); + assertThat(codeSnippet.getSnippet()).isEqualTo(""" + // infrastructure + DefaultListableBeanFactory beanFactory = context.getDefaultListableBeanFactory(); + beanFactory.setAutowireCandidateResolver(new ContextAnnotationAutowireCandidateResolver()); + """); + assertThat(codeSnippet.hasImport(ContextAnnotationAutowireCandidateResolver.class)).isTrue(); + } + + private CodeSnippet contribute(GeneratedTypeContext generationContext) { + InfrastructureContribution contribution = new InfrastructureContribution(); + BeanFactoryInitialization initialization = new BeanFactoryInitialization(generationContext); + contribution.applyTo(initialization); + return CodeSnippet.of(initialization.toCodeBlock()); + } + + private GeneratedTypeContext createGenerationContext() { + return new DefaultGeneratedTypeContext("com.example", packageName -> + GeneratedType.of(ClassName.get(packageName, "Test"))); + } + +} diff --git a/spring-context/src/testFixtures/java/org/springframework/context/testfixture/context/generator/SimpleComponent.java b/spring-context/src/testFixtures/java/org/springframework/context/testfixture/context/generator/SimpleComponent.java new file mode 100644 index 0000000000..7004b8e0c6 --- /dev/null +++ b/spring-context/src/testFixtures/java/org/springframework/context/testfixture/context/generator/SimpleComponent.java @@ -0,0 +1,20 @@ +/* + * Copyright 2002-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.context.testfixture.context.generator; + +public class SimpleComponent { +} diff --git a/spring-context/src/testFixtures/java/org/springframework/context/testfixture/context/generator/annotation/AutowiredComponent.java b/spring-context/src/testFixtures/java/org/springframework/context/testfixture/context/generator/annotation/AutowiredComponent.java new file mode 100644 index 0000000000..8bc7dae191 --- /dev/null +++ b/spring-context/src/testFixtures/java/org/springframework/context/testfixture/context/generator/annotation/AutowiredComponent.java @@ -0,0 +1,46 @@ +/* + * Copyright 2002-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.context.testfixture.context.generator.annotation; + +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.core.env.Environment; + +public class AutowiredComponent { + + private Environment environment; + + private Integer counter; + + public Environment getEnvironment() { + return this.environment; + } + + @Autowired + public void setEnvironment(Environment environment) { + this.environment = environment; + } + + public Integer getCounter() { + return this.counter; + } + + @Autowired + public void setCounter(Integer counter) { + this.counter = counter; + } + +} diff --git a/spring-context/src/testFixtures/java/org/springframework/context/testfixture/context/generator/annotation/SimpleConfiguration.java b/spring-context/src/testFixtures/java/org/springframework/context/testfixture/context/generator/annotation/SimpleConfiguration.java new file mode 100644 index 0000000000..34cd687999 --- /dev/null +++ b/spring-context/src/testFixtures/java/org/springframework/context/testfixture/context/generator/annotation/SimpleConfiguration.java @@ -0,0 +1,30 @@ +/* + * Copyright 2002-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.context.testfixture.context.generator.annotation; + +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; + +@Configuration(proxyBeanMethods = false) +public class SimpleConfiguration { + + @Bean + public String stringBean() { + return "Hello"; + } + +}