Add support for refreshing a GenericApplicationContext for AOT

This commit adds a way to refresh a GenericApplicationContext for ahead
of time processing: refreshForAotProcessing() processes the bean factory
up to a point where it is about to create bean instances.

MergedBeanDefinitionPostProcessor implementations are the only bean
post processors that are invoked during this phase.

Closes gh-28065
This commit is contained in:
Stephane Nicoll
2022-03-06 18:10:31 +01:00
parent 9ba927215e
commit b5695b9248
4 changed files with 341 additions and 4 deletions

View File

@@ -1,5 +1,5 @@
/*
* Copyright 2002-2020 the original author or authors.
* Copyright 2002-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -396,6 +396,15 @@ class AnnotationConfigApplicationContextTests {
assertThat(context.getBeanNamesForType(TypedFactoryBean.class)).hasSize(1);
}
@Test
void refreshForAotProcessingWithConfiguration() {
AnnotationConfigApplicationContext context = new AnnotationConfigApplicationContext();
context.register(Config.class);
context.refreshForAotProcessing();
assertThat(context.getBeanFactory().getBeanDefinitionNames()).contains(
"annotationConfigApplicationContextTests.Config", "testBean");
}
@Configuration
static class Config {

View File

@@ -1,5 +1,5 @@
/*
* Copyright 2002-2020 the original author or authors.
* Copyright 2002-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -17,22 +17,35 @@
package org.springframework.context.support;
import org.junit.jupiter.api.Test;
import org.mockito.ArgumentCaptor;
import org.springframework.beans.factory.NoUniqueBeanDefinitionException;
import org.springframework.beans.factory.config.AbstractFactoryBean;
import org.springframework.beans.factory.config.BeanDefinition;
import org.springframework.beans.factory.config.BeanFactoryPostProcessor;
import org.springframework.beans.factory.support.BeanDefinitionBuilder;
import org.springframework.beans.factory.support.GenericBeanDefinition;
import org.springframework.beans.factory.support.MergedBeanDefinitionPostProcessor;
import org.springframework.beans.factory.support.RootBeanDefinition;
import org.springframework.context.ApplicationContext;
import org.springframework.context.ApplicationContextAware;
import org.springframework.core.env.ConfigurableEnvironment;
import org.springframework.core.env.Environment;
import org.springframework.core.metrics.jfr.FlightRecorderApplicationStartup;
import org.springframework.util.ObjectUtils;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
import static org.assertj.core.api.Assertions.assertThatIllegalStateException;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
/**
* @author Juergen Hoeller
* @author Chris Beams
* @author Stephane Nicoll
*/
public class GenericApplicationContextTests {
@@ -210,6 +223,153 @@ public class GenericApplicationContextTests {
assertThat(context.getBeanFactory().getApplicationStartup()).isEqualTo(applicationStartup);
}
@Test
void refreshForAotSetsContextActive() {
GenericApplicationContext context = new GenericApplicationContext();
assertThat(context.isActive()).isFalse();
context.refreshForAotProcessing();
assertThat(context.isActive()).isTrue();
}
@Test
void refreshForAotRegistersEnvironment() {
ConfigurableEnvironment environment = mock(ConfigurableEnvironment.class);
GenericApplicationContext context = new GenericApplicationContext();
context.setEnvironment(environment);
context.refreshForAotProcessing();
assertThat(context.getBean(Environment.class)).isEqualTo(environment);
}
@Test
void refreshForAotLoadsBeanClassName() {
GenericApplicationContext context = new GenericApplicationContext();
context.registerBeanDefinition("number", new RootBeanDefinition("java.lang.Integer"));
context.refreshForAotProcessing();
assertThat(getBeanDefinition(context, "number").getBeanClass()).isEqualTo(Integer.class);
}
@Test
void refreshForAotLoadsBeanClassNameOfConstructorArgumentInnerBeanDefinition() {
GenericApplicationContext context = new GenericApplicationContext();
RootBeanDefinition beanDefinition = new RootBeanDefinition(String.class);
GenericBeanDefinition innerBeanDefinition = new GenericBeanDefinition();
innerBeanDefinition.setBeanClassName("java.lang.Integer");
beanDefinition.getConstructorArgumentValues().addIndexedArgumentValue(0, innerBeanDefinition);
context.registerBeanDefinition("test",beanDefinition);
context.refreshForAotProcessing();
RootBeanDefinition bd = getBeanDefinition(context, "test");
GenericBeanDefinition value = (GenericBeanDefinition) bd.getConstructorArgumentValues()
.getIndexedArgumentValue(0, GenericBeanDefinition.class).getValue();
assertThat(value.hasBeanClass()).isTrue();
assertThat(value.getBeanClass()).isEqualTo(Integer.class);
}
@Test
void refreshForAotLoadsBeanClassNameOfPropertyValueInnerBeanDefinition() {
GenericApplicationContext context = new GenericApplicationContext();
RootBeanDefinition beanDefinition = new RootBeanDefinition(String.class);
GenericBeanDefinition innerBeanDefinition = new GenericBeanDefinition();
innerBeanDefinition.setBeanClassName("java.lang.Integer");
beanDefinition.getPropertyValues().add("inner", innerBeanDefinition);
context.registerBeanDefinition("test",beanDefinition);
context.refreshForAotProcessing();
RootBeanDefinition bd = getBeanDefinition(context, "test");
GenericBeanDefinition value = (GenericBeanDefinition) bd.getPropertyValues().get("inner");
assertThat(value.hasBeanClass()).isTrue();
assertThat(value.getBeanClass()).isEqualTo(Integer.class);
}
@Test
void refreshForAotInvokesBeanFactoryPostProcessors() {
GenericApplicationContext context = new GenericApplicationContext();
BeanFactoryPostProcessor bfpp = mock(BeanFactoryPostProcessor.class);
context.addBeanFactoryPostProcessor(bfpp);
context.refreshForAotProcessing();
verify(bfpp).postProcessBeanFactory(context.getBeanFactory());
}
@Test
void refreshForAotInvokesMergedBeanDefinitionPostProcessors() {
GenericApplicationContext context = new GenericApplicationContext();
context.registerBeanDefinition("test", new RootBeanDefinition(String.class));
context.registerBeanDefinition("number", new RootBeanDefinition("java.lang.Integer"));
MergedBeanDefinitionPostProcessor bpp = registerMockMergedBeanDefinitionPostProcessor(context);
context.refreshForAotProcessing();
verify(bpp).postProcessMergedBeanDefinition(getBeanDefinition(context, "test"), String.class, "test");
verify(bpp).postProcessMergedBeanDefinition(getBeanDefinition(context, "number"), Integer.class, "number");
}
@Test
void refreshForAotInvokesMergedBeanDefinitionPostProcessorsOnConstructorArgument() {
GenericApplicationContext context = new GenericApplicationContext();
RootBeanDefinition beanDefinition = new RootBeanDefinition(BeanD.class);
GenericBeanDefinition innerBeanDefinition = new GenericBeanDefinition();
innerBeanDefinition.setBeanClassName("java.lang.Integer");
beanDefinition.getConstructorArgumentValues().addIndexedArgumentValue(0, innerBeanDefinition);
context.registerBeanDefinition("test", beanDefinition);
MergedBeanDefinitionPostProcessor bpp = registerMockMergedBeanDefinitionPostProcessor(context);
context.refreshForAotProcessing();
ArgumentCaptor<String> captor = ArgumentCaptor.forClass(String.class);
verify(bpp).postProcessMergedBeanDefinition(getBeanDefinition(context, "test"), BeanD.class, "test");
verify(bpp).postProcessMergedBeanDefinition(any(RootBeanDefinition.class), eq(Integer.class), captor.capture());
assertThat(captor.getValue()).startsWith("(inner bean)");
}
@Test
void refreshForAotInvokesMergedBeanDefinitionPostProcessorsOnPropertyValue() {
GenericApplicationContext context = new GenericApplicationContext();
RootBeanDefinition beanDefinition = new RootBeanDefinition(BeanD.class);
GenericBeanDefinition innerBeanDefinition = new GenericBeanDefinition();
innerBeanDefinition.setBeanClassName("java.lang.Integer");
beanDefinition.getPropertyValues().add("counter", innerBeanDefinition);
context.registerBeanDefinition("test", beanDefinition);
MergedBeanDefinitionPostProcessor bpp = registerMockMergedBeanDefinitionPostProcessor(context);
context.refreshForAotProcessing();
ArgumentCaptor<String> captor = ArgumentCaptor.forClass(String.class);
verify(bpp).postProcessMergedBeanDefinition(getBeanDefinition(context, "test"), BeanD.class, "test");
verify(bpp).postProcessMergedBeanDefinition(any(RootBeanDefinition.class), eq(Integer.class), captor.capture());
assertThat(captor.getValue()).startsWith("(inner bean)");
}
@Test
void refreshForAotFailsOnAnActiveContext() {
GenericApplicationContext context = new GenericApplicationContext();
context.refresh();
assertThatIllegalStateException().isThrownBy(context::refreshForAotProcessing)
.withMessageContaining("does not support multiple refresh attempts");
}
@Test
void refreshForAotDoesNotInitializeFactoryBeansEarly() {
GenericApplicationContext context = new GenericApplicationContext();
context.registerBeanDefinition("genericFactoryBean",
new RootBeanDefinition(TestAotFactoryBean.class));
context.refreshForAotProcessing();
}
@Test
void refreshForAotDoesNotInstantiateBean() {
GenericApplicationContext context = new GenericApplicationContext();
context.registerBeanDefinition("test", BeanDefinitionBuilder.rootBeanDefinition(String.class, () -> {
throw new IllegalStateException("Should not be invoked");
}).getBeanDefinition());
context.refreshForAotProcessing();
}
private MergedBeanDefinitionPostProcessor registerMockMergedBeanDefinitionPostProcessor(GenericApplicationContext context) {
MergedBeanDefinitionPostProcessor bpp = mock(MergedBeanDefinitionPostProcessor.class);
context.registerBeanDefinition("bpp", BeanDefinitionBuilder.rootBeanDefinition(
MergedBeanDefinitionPostProcessor.class, () -> bpp)
.setRole(BeanDefinition.ROLE_INFRASTRUCTURE).getBeanDefinition());
return bpp;
}
private RootBeanDefinition getBeanDefinition(GenericApplicationContext context, String name) {
return (RootBeanDefinition) context.getBeanFactory().getMergedBeanDefinition(name);
}
static class BeanA {
@@ -237,4 +397,39 @@ public class GenericApplicationContextTests {
static class BeanC {}
static class BeanD {
private Integer counter;
BeanD(Integer counter) {
this.counter = counter;
}
public BeanD() {
}
public void setCounter(Integer counter) {
this.counter = counter;
}
}
static class TestAotFactoryBean<T> extends AbstractFactoryBean<T> {
TestAotFactoryBean() {
throw new IllegalStateException("FactoryBean should not be instantied early");
}
@Override
public Class<?> getObjectType() {
return Object.class;
}
@SuppressWarnings("unchecked")
@Override
protected T createInstance() {
return (T) new Object();
}
}
}