diff --git a/spring-context/src/main/java/org/springframework/context/support/GenericApplicationContext.java b/spring-context/src/main/java/org/springframework/context/support/GenericApplicationContext.java index accdb15031..50a7582ff2 100644 --- a/spring-context/src/main/java/org/springframework/context/support/GenericApplicationContext.java +++ b/spring-context/src/main/java/org/springframework/context/support/GenericApplicationContext.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2020 the original author or authors. + * Copyright 2002-2022 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -30,7 +30,9 @@ import org.springframework.beans.factory.config.BeanDefinition; import org.springframework.beans.factory.config.BeanDefinitionCustomizer; import org.springframework.beans.factory.config.ConfigurableListableBeanFactory; import org.springframework.beans.factory.support.BeanDefinitionRegistry; +import org.springframework.beans.factory.support.BeanDefinitionRegistryPostProcessor; import org.springframework.beans.factory.support.DefaultListableBeanFactory; +import org.springframework.beans.factory.support.MergedBeanDefinitionPostProcessor; import org.springframework.beans.factory.support.RootBeanDefinition; import org.springframework.context.ApplicationContext; import org.springframework.core.io.Resource; @@ -60,6 +62,10 @@ import org.springframework.util.Assert; * this context is available right from the start, to be able to register bean * definitions on it. {@link #refresh()} may only be called once. * + *

This ApplicationContext implementation is suitable for Ahead of Time + * processing, using {@link #refreshForAotProcessing()} as an alternative to the + * regular {@link #refresh()}. + * *

Usage example: * *

@@ -86,6 +92,7 @@ import org.springframework.util.Assert;
  *
  * @author Juergen Hoeller
  * @author Chris Beams
+ * @author Stephane Nicoll
  * @since 1.1.2
  * @see #registerBeanDefinition
  * @see #refresh()
@@ -361,6 +368,34 @@ public class GenericApplicationContext extends AbstractApplicationContext implem
 	}
 
 
+	//---------------------------------------------------------------------
+	// AOT processing
+	//---------------------------------------------------------------------
+
+	/**
+	 * Load or refresh the persistent representation of the configuration up to
+	 * a point where the underlying bean factory is ready to create bean
+	 * instances.
+	 * 

This variant of {@link #refresh()} is used by Ahead of Time processing + * that optimizes the application context, typically at build-time. + *

In this mode, only {@link BeanDefinitionRegistryPostProcessor} and + * {@link MergedBeanDefinitionPostProcessor} are invoked. + * @throws BeansException if the bean factory could not be initialized + * @throws IllegalStateException if already initialized and multiple refresh + * attempts are not supported + */ + public void refreshForAotProcessing() { + if (logger.isDebugEnabled()) { + logger.debug("Preparing bean factory for AOT processing"); + } + prepareRefresh(); + obtainFreshBeanFactory(); + prepareBeanFactory(this.beanFactory); + postProcessBeanFactory(this.beanFactory); + invokeBeanFactoryPostProcessors(this.beanFactory); + PostProcessorRegistrationDelegate.invokeMergedBeanDefinitionPostProcessors(this.beanFactory); + } + //--------------------------------------------------------------------- // Convenient methods for registering individual beans //--------------------------------------------------------------------- diff --git a/spring-context/src/main/java/org/springframework/context/support/PostProcessorRegistrationDelegate.java b/spring-context/src/main/java/org/springframework/context/support/PostProcessorRegistrationDelegate.java index ebf935ab51..849c6a9e9c 100644 --- a/spring-context/src/main/java/org/springframework/context/support/PostProcessorRegistrationDelegate.java +++ b/spring-context/src/main/java/org/springframework/context/support/PostProcessorRegistrationDelegate.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2021 the original author or authors. + * Copyright 2002-2022 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -22,19 +22,25 @@ import java.util.Comparator; import java.util.HashSet; import java.util.List; import java.util.Set; +import java.util.function.BiConsumer; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; +import org.springframework.beans.PropertyValue; import org.springframework.beans.factory.config.BeanDefinition; import org.springframework.beans.factory.config.BeanFactoryPostProcessor; import org.springframework.beans.factory.config.BeanPostProcessor; import org.springframework.beans.factory.config.ConfigurableListableBeanFactory; +import org.springframework.beans.factory.config.ConstructorArgumentValues.ValueHolder; +import org.springframework.beans.factory.support.AbstractBeanDefinition; import org.springframework.beans.factory.support.AbstractBeanFactory; import org.springframework.beans.factory.support.BeanDefinitionRegistry; import org.springframework.beans.factory.support.BeanDefinitionRegistryPostProcessor; +import org.springframework.beans.factory.support.BeanDefinitionValueResolver; import org.springframework.beans.factory.support.DefaultListableBeanFactory; import org.springframework.beans.factory.support.MergedBeanDefinitionPostProcessor; +import org.springframework.beans.factory.support.RootBeanDefinition; import org.springframework.core.OrderComparator; import org.springframework.core.Ordered; import org.springframework.core.PriorityOrdered; @@ -47,6 +53,7 @@ import org.springframework.lang.Nullable; * * @author Juergen Hoeller * @author Sam Brannen + * @author Stephane Nicoll * @since 4.0 */ final class PostProcessorRegistrationDelegate { @@ -280,6 +287,36 @@ final class PostProcessorRegistrationDelegate { beanFactory.addBeanPostProcessor(new ApplicationListenerDetector(applicationContext)); } + /** + * Load and sort the post-processors of the specified type. + * @param beanFactory the bean factory to use + * @param beanPostProcessorType the post-processor type + * @param the post-processor type + * @return a list of sorted post-processors for the specified type + */ + static List loadBeanPostProcessors( + ConfigurableListableBeanFactory beanFactory, Class beanPostProcessorType) { + + String[] postProcessorNames = beanFactory.getBeanNamesForType(beanPostProcessorType, true, false); + List postProcessors = new ArrayList<>(); + for (String ppName : postProcessorNames) { + postProcessors.add(beanFactory.getBean(ppName, beanPostProcessorType)); + } + sortPostProcessors(postProcessors, beanFactory); + return postProcessors; + + } + + /** + * Selectively invoke {@link MergedBeanDefinitionPostProcessor} instances + * registered in the specified bean factory, resolving bean definitions as + * well as any inner bean definitions that they may contain. + * @param beanFactory the bean factory to use + */ + static void invokeMergedBeanDefinitionPostProcessors(DefaultListableBeanFactory beanFactory) { + new MergedBeanDefinitionPostProcessorInvoker(beanFactory).invokeMergedBeanDefinitionPostProcessors(); + } + private static void sortPostProcessors(List postProcessors, ConfigurableListableBeanFactory beanFactory) { // Nothing to sort? if (postProcessors.size() <= 1) { @@ -386,4 +423,65 @@ final class PostProcessorRegistrationDelegate { } } + private static final class MergedBeanDefinitionPostProcessorInvoker { + + private final DefaultListableBeanFactory beanFactory; + + private MergedBeanDefinitionPostProcessorInvoker(DefaultListableBeanFactory beanFactory) { + this.beanFactory = beanFactory; + } + + private void invokeMergedBeanDefinitionPostProcessors() { + List postProcessors = PostProcessorRegistrationDelegate.loadBeanPostProcessors( + this.beanFactory, MergedBeanDefinitionPostProcessor.class); + for (String beanName : this.beanFactory.getBeanDefinitionNames()) { + RootBeanDefinition bd = (RootBeanDefinition) this.beanFactory.getMergedBeanDefinition(beanName); + Class beanType = resolveBeanType(bd); + postProcessRootBeanDefinition(postProcessors, beanName, beanType, bd); + } + } + + private void postProcessRootBeanDefinition(List postProcessors, + String beanName, Class beanType, RootBeanDefinition bd) { + BeanDefinitionValueResolver valueResolver = new BeanDefinitionValueResolver(this.beanFactory, beanName, bd); + postProcessors.forEach(postProcessor -> postProcessor.postProcessMergedBeanDefinition(bd, beanType, beanName)); + for (PropertyValue propertyValue : bd.getPropertyValues().getPropertyValueList()) { + Object value = propertyValue.getValue(); + if (value instanceof AbstractBeanDefinition innerBd) { + Class innerBeanType = resolveBeanType(innerBd); + resolveInnerBeanDefinition(valueResolver, innerBd, (innerBeanName, innerBeanDefinition) + -> postProcessRootBeanDefinition(postProcessors, innerBeanName, innerBeanType, innerBeanDefinition)); + } + } + for (ValueHolder valueHolder : bd.getConstructorArgumentValues().getIndexedArgumentValues().values()) { + Object value = valueHolder.getValue(); + if (value instanceof AbstractBeanDefinition innerBd) { + Class innerBeanType = resolveBeanType(innerBd); + resolveInnerBeanDefinition(valueResolver, innerBd, (innerBeanName, innerBeanDefinition) + -> postProcessRootBeanDefinition(postProcessors, innerBeanName, innerBeanType, innerBeanDefinition)); + } + } + } + + private void resolveInnerBeanDefinition(BeanDefinitionValueResolver valueResolver, BeanDefinition innerBeanDefinition, + BiConsumer resolver) { + valueResolver.resolveInnerBean(null, innerBeanDefinition, (name, rbd) -> { + resolver.accept(name, rbd); + return Void.class; + }); + } + + private Class resolveBeanType(AbstractBeanDefinition bd) { + if (!bd.hasBeanClass()) { + try { + bd.resolveBeanClass(this.beanFactory.getBeanClassLoader()); + } + catch (ClassNotFoundException ex) { + // ignore + } + } + return bd.getResolvableType().toClass(); + } + } + } diff --git a/spring-context/src/test/java/org/springframework/context/annotation/AnnotationConfigApplicationContextTests.java b/spring-context/src/test/java/org/springframework/context/annotation/AnnotationConfigApplicationContextTests.java index 566a445288..7ea2b74278 100644 --- a/spring-context/src/test/java/org/springframework/context/annotation/AnnotationConfigApplicationContextTests.java +++ b/spring-context/src/test/java/org/springframework/context/annotation/AnnotationConfigApplicationContextTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2020 the original author or authors. + * Copyright 2002-2022 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -396,6 +396,15 @@ class AnnotationConfigApplicationContextTests { assertThat(context.getBeanNamesForType(TypedFactoryBean.class)).hasSize(1); } + @Test + void refreshForAotProcessingWithConfiguration() { + AnnotationConfigApplicationContext context = new AnnotationConfigApplicationContext(); + context.register(Config.class); + context.refreshForAotProcessing(); + assertThat(context.getBeanFactory().getBeanDefinitionNames()).contains( + "annotationConfigApplicationContextTests.Config", "testBean"); + } + @Configuration static class Config { diff --git a/spring-context/src/test/java/org/springframework/context/support/GenericApplicationContextTests.java b/spring-context/src/test/java/org/springframework/context/support/GenericApplicationContextTests.java index 36604c6c31..bba6ddf9cc 100644 --- a/spring-context/src/test/java/org/springframework/context/support/GenericApplicationContextTests.java +++ b/spring-context/src/test/java/org/springframework/context/support/GenericApplicationContextTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2020 the original author or authors. + * Copyright 2002-2022 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,22 +17,35 @@ package org.springframework.context.support; import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; import org.springframework.beans.factory.NoUniqueBeanDefinitionException; +import org.springframework.beans.factory.config.AbstractFactoryBean; import org.springframework.beans.factory.config.BeanDefinition; +import org.springframework.beans.factory.config.BeanFactoryPostProcessor; +import org.springframework.beans.factory.support.BeanDefinitionBuilder; +import org.springframework.beans.factory.support.GenericBeanDefinition; +import org.springframework.beans.factory.support.MergedBeanDefinitionPostProcessor; import org.springframework.beans.factory.support.RootBeanDefinition; import org.springframework.context.ApplicationContext; import org.springframework.context.ApplicationContextAware; +import org.springframework.core.env.ConfigurableEnvironment; +import org.springframework.core.env.Environment; import org.springframework.core.metrics.jfr.FlightRecorderApplicationStartup; import org.springframework.util.ObjectUtils; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatExceptionOfType; import static org.assertj.core.api.Assertions.assertThatIllegalStateException; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; /** * @author Juergen Hoeller * @author Chris Beams + * @author Stephane Nicoll */ public class GenericApplicationContextTests { @@ -210,6 +223,153 @@ public class GenericApplicationContextTests { assertThat(context.getBeanFactory().getApplicationStartup()).isEqualTo(applicationStartup); } + @Test + void refreshForAotSetsContextActive() { + GenericApplicationContext context = new GenericApplicationContext(); + assertThat(context.isActive()).isFalse(); + context.refreshForAotProcessing(); + assertThat(context.isActive()).isTrue(); + } + + @Test + void refreshForAotRegistersEnvironment() { + ConfigurableEnvironment environment = mock(ConfigurableEnvironment.class); + GenericApplicationContext context = new GenericApplicationContext(); + context.setEnvironment(environment); + context.refreshForAotProcessing(); + assertThat(context.getBean(Environment.class)).isEqualTo(environment); + } + + @Test + void refreshForAotLoadsBeanClassName() { + GenericApplicationContext context = new GenericApplicationContext(); + context.registerBeanDefinition("number", new RootBeanDefinition("java.lang.Integer")); + context.refreshForAotProcessing(); + assertThat(getBeanDefinition(context, "number").getBeanClass()).isEqualTo(Integer.class); + } + + @Test + void refreshForAotLoadsBeanClassNameOfConstructorArgumentInnerBeanDefinition() { + GenericApplicationContext context = new GenericApplicationContext(); + RootBeanDefinition beanDefinition = new RootBeanDefinition(String.class); + GenericBeanDefinition innerBeanDefinition = new GenericBeanDefinition(); + innerBeanDefinition.setBeanClassName("java.lang.Integer"); + beanDefinition.getConstructorArgumentValues().addIndexedArgumentValue(0, innerBeanDefinition); + context.registerBeanDefinition("test",beanDefinition); + context.refreshForAotProcessing(); + RootBeanDefinition bd = getBeanDefinition(context, "test"); + GenericBeanDefinition value = (GenericBeanDefinition) bd.getConstructorArgumentValues() + .getIndexedArgumentValue(0, GenericBeanDefinition.class).getValue(); + assertThat(value.hasBeanClass()).isTrue(); + assertThat(value.getBeanClass()).isEqualTo(Integer.class); + + } + + @Test + void refreshForAotLoadsBeanClassNameOfPropertyValueInnerBeanDefinition() { + GenericApplicationContext context = new GenericApplicationContext(); + RootBeanDefinition beanDefinition = new RootBeanDefinition(String.class); + GenericBeanDefinition innerBeanDefinition = new GenericBeanDefinition(); + innerBeanDefinition.setBeanClassName("java.lang.Integer"); + beanDefinition.getPropertyValues().add("inner", innerBeanDefinition); + context.registerBeanDefinition("test",beanDefinition); + context.refreshForAotProcessing(); + RootBeanDefinition bd = getBeanDefinition(context, "test"); + GenericBeanDefinition value = (GenericBeanDefinition) bd.getPropertyValues().get("inner"); + assertThat(value.hasBeanClass()).isTrue(); + assertThat(value.getBeanClass()).isEqualTo(Integer.class); + } + + @Test + void refreshForAotInvokesBeanFactoryPostProcessors() { + GenericApplicationContext context = new GenericApplicationContext(); + BeanFactoryPostProcessor bfpp = mock(BeanFactoryPostProcessor.class); + context.addBeanFactoryPostProcessor(bfpp); + context.refreshForAotProcessing(); + verify(bfpp).postProcessBeanFactory(context.getBeanFactory()); + } + + @Test + void refreshForAotInvokesMergedBeanDefinitionPostProcessors() { + GenericApplicationContext context = new GenericApplicationContext(); + context.registerBeanDefinition("test", new RootBeanDefinition(String.class)); + context.registerBeanDefinition("number", new RootBeanDefinition("java.lang.Integer")); + MergedBeanDefinitionPostProcessor bpp = registerMockMergedBeanDefinitionPostProcessor(context); + context.refreshForAotProcessing(); + verify(bpp).postProcessMergedBeanDefinition(getBeanDefinition(context, "test"), String.class, "test"); + verify(bpp).postProcessMergedBeanDefinition(getBeanDefinition(context, "number"), Integer.class, "number"); + } + + @Test + void refreshForAotInvokesMergedBeanDefinitionPostProcessorsOnConstructorArgument() { + GenericApplicationContext context = new GenericApplicationContext(); + RootBeanDefinition beanDefinition = new RootBeanDefinition(BeanD.class); + GenericBeanDefinition innerBeanDefinition = new GenericBeanDefinition(); + innerBeanDefinition.setBeanClassName("java.lang.Integer"); + beanDefinition.getConstructorArgumentValues().addIndexedArgumentValue(0, innerBeanDefinition); + context.registerBeanDefinition("test", beanDefinition); + MergedBeanDefinitionPostProcessor bpp = registerMockMergedBeanDefinitionPostProcessor(context); + context.refreshForAotProcessing(); + ArgumentCaptor captor = ArgumentCaptor.forClass(String.class); + verify(bpp).postProcessMergedBeanDefinition(getBeanDefinition(context, "test"), BeanD.class, "test"); + verify(bpp).postProcessMergedBeanDefinition(any(RootBeanDefinition.class), eq(Integer.class), captor.capture()); + assertThat(captor.getValue()).startsWith("(inner bean)"); + } + + @Test + void refreshForAotInvokesMergedBeanDefinitionPostProcessorsOnPropertyValue() { + GenericApplicationContext context = new GenericApplicationContext(); + RootBeanDefinition beanDefinition = new RootBeanDefinition(BeanD.class); + GenericBeanDefinition innerBeanDefinition = new GenericBeanDefinition(); + innerBeanDefinition.setBeanClassName("java.lang.Integer"); + beanDefinition.getPropertyValues().add("counter", innerBeanDefinition); + context.registerBeanDefinition("test", beanDefinition); + MergedBeanDefinitionPostProcessor bpp = registerMockMergedBeanDefinitionPostProcessor(context); + context.refreshForAotProcessing(); + ArgumentCaptor captor = ArgumentCaptor.forClass(String.class); + verify(bpp).postProcessMergedBeanDefinition(getBeanDefinition(context, "test"), BeanD.class, "test"); + verify(bpp).postProcessMergedBeanDefinition(any(RootBeanDefinition.class), eq(Integer.class), captor.capture()); + assertThat(captor.getValue()).startsWith("(inner bean)"); + } + + @Test + void refreshForAotFailsOnAnActiveContext() { + GenericApplicationContext context = new GenericApplicationContext(); + context.refresh(); + assertThatIllegalStateException().isThrownBy(context::refreshForAotProcessing) + .withMessageContaining("does not support multiple refresh attempts"); + } + + @Test + void refreshForAotDoesNotInitializeFactoryBeansEarly() { + GenericApplicationContext context = new GenericApplicationContext(); + context.registerBeanDefinition("genericFactoryBean", + new RootBeanDefinition(TestAotFactoryBean.class)); + context.refreshForAotProcessing(); + } + + @Test + void refreshForAotDoesNotInstantiateBean() { + GenericApplicationContext context = new GenericApplicationContext(); + context.registerBeanDefinition("test", BeanDefinitionBuilder.rootBeanDefinition(String.class, () -> { + throw new IllegalStateException("Should not be invoked"); + }).getBeanDefinition()); + context.refreshForAotProcessing(); + } + + private MergedBeanDefinitionPostProcessor registerMockMergedBeanDefinitionPostProcessor(GenericApplicationContext context) { + MergedBeanDefinitionPostProcessor bpp = mock(MergedBeanDefinitionPostProcessor.class); + context.registerBeanDefinition("bpp", BeanDefinitionBuilder.rootBeanDefinition( + MergedBeanDefinitionPostProcessor.class, () -> bpp) + .setRole(BeanDefinition.ROLE_INFRASTRUCTURE).getBeanDefinition()); + return bpp; + } + + + private RootBeanDefinition getBeanDefinition(GenericApplicationContext context, String name) { + return (RootBeanDefinition) context.getBeanFactory().getMergedBeanDefinition(name); + } + static class BeanA { @@ -237,4 +397,39 @@ public class GenericApplicationContextTests { static class BeanC {} + static class BeanD { + + private Integer counter; + + BeanD(Integer counter) { + this.counter = counter; + } + + public BeanD() { + } + + public void setCounter(Integer counter) { + this.counter = counter; + } + + } + + static class TestAotFactoryBean extends AbstractFactoryBean { + + TestAotFactoryBean() { + throw new IllegalStateException("FactoryBean should not be instantied early"); + } + + @Override + public Class getObjectType() { + return Object.class; + } + + @SuppressWarnings("unchecked") + @Override + protected T createInstance() { + return (T) new Object(); + } + } + }