Introduce AOT run-time support in the TestContext framework

This commit introduces initial AOT run-time support in the Spring
TestContext Framework.

- DefaultCacheAwareContextLoaderDelegate: when running in AOT mode, now
  loads a test's ApplicationContext via the AotContextLoader SPI
  instead of via the standard SmartContextLoader and ContextLoader SPIs.

- DependencyInjectionTestExecutionListener: when running in AOT mode,
  now injects dependencies into a test instance using a local instance
  of AutowiredAnnotationBeanPostProcessor instead of relying on
  AutowireCapableBeanFactory support.

Closes gh-28205
This commit is contained in:
Sam Brannen
2022-08-20 15:01:53 +02:00
parent ada0880f3c
commit 8a6c1ba198
4 changed files with 221 additions and 15 deletions

View File

@@ -19,13 +19,21 @@ package org.springframework.test.context.cache;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.aot.AotDetector;
import org.springframework.context.ApplicationContext;
import org.springframework.context.ApplicationContextInitializer;
import org.springframework.context.ConfigurableApplicationContext;
import org.springframework.context.support.GenericApplicationContext;
import org.springframework.core.log.LogMessage;
import org.springframework.lang.Nullable;
import org.springframework.test.annotation.DirtiesContext.HierarchyMode;
import org.springframework.test.context.CacheAwareContextLoaderDelegate;
import org.springframework.test.context.ContextLoader;
import org.springframework.test.context.MergedContextConfiguration;
import org.springframework.test.context.SmartContextLoader;
import org.springframework.test.context.aot.AotContextLoader;
import org.springframework.test.context.aot.AotTestMappings;
import org.springframework.test.context.aot.TestContextAotException;
import org.springframework.util.Assert;
/**
@@ -48,6 +56,8 @@ public class DefaultCacheAwareContextLoaderDelegate implements CacheAwareContext
*/
static final ContextCache defaultContextCache = new DefaultContextCache();
private final AotTestMappings aotTestMappings = getAotTestMappings();
private final ContextCache contextCache;
@@ -87,7 +97,12 @@ public class DefaultCacheAwareContextLoaderDelegate implements CacheAwareContext
ApplicationContext context = this.contextCache.get(mergedContextConfiguration);
if (context == null) {
try {
context = loadContextInternal(mergedContextConfiguration);
if (runningInAotMode(mergedContextConfiguration.getTestClass())) {
context = loadContextInAotMode(mergedContextConfiguration);
}
else {
context = loadContextInternal(mergedContextConfiguration);
}
if (logger.isDebugEnabled()) {
logger.debug(String.format("Storing ApplicationContext [%s] in cache under key [%s]",
System.identityHashCode(context), mergedContextConfiguration));
@@ -149,4 +164,45 @@ public class DefaultCacheAwareContextLoaderDelegate implements CacheAwareContext
}
}
protected ApplicationContext loadContextInAotMode(MergedContextConfiguration mergedConfig) throws Exception {
Class<?> testClass = mergedConfig.getTestClass();
ApplicationContextInitializer<ConfigurableApplicationContext> contextInitializer =
this.aotTestMappings.getContextInitializer(testClass);
Assert.state(contextInitializer != null,
() -> "Failed to load AOT ApplicationContextInitializer for test class [%s]"
.formatted(testClass.getName()));
logger.info(LogMessage.format("Loading ApplicationContext in AOT mode for %s", mergedConfig));
ContextLoader contextLoader = mergedConfig.getContextLoader();
if (!((contextLoader instanceof AotContextLoader aotContextLoader) &&
(aotContextLoader.loadContextForAotRuntime(mergedConfig, contextInitializer)
instanceof GenericApplicationContext gac))) {
throw new TestContextAotException("""
Cannot load ApplicationContext for AOT runtime for %s. The configured \
ContextLoader [%s] must be an AotContextLoader and must create a \
GenericApplicationContext."""
.formatted(mergedConfig, contextLoader.getClass().getName()));
}
gac.registerShutdownHook();
return gac;
}
/**
* Determine if we are running in AOT mode for the supplied test class.
*/
private boolean runningInAotMode(Class<?> testClass) {
return (this.aotTestMappings != null && this.aotTestMappings.isSupportedTestClass(testClass));
}
private static AotTestMappings getAotTestMappings() {
if (AotDetector.useGeneratedArtifacts()) {
try {
return new AotTestMappings();
}
catch (Exception ex) {
throw new IllegalStateException("Failed to instantiate AotTestMappings", ex);
}
}
return null;
}
}

View File

@@ -1,5 +1,5 @@
/*
* Copyright 2002-2018 the original author or authors.
* Copyright 2002-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -19,9 +19,15 @@ package org.springframework.test.context.support;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.aot.AotDetector;
import org.springframework.beans.factory.annotation.AutowiredAnnotationBeanPostProcessor;
import org.springframework.beans.factory.config.AutowireCapableBeanFactory;
import org.springframework.beans.factory.config.ConfigurableListableBeanFactory;
import org.springframework.context.ApplicationContext;
import org.springframework.context.support.GenericApplicationContext;
import org.springframework.core.Conventions;
import org.springframework.test.context.TestContext;
import org.springframework.test.context.aot.AotTestMappings;
/**
* {@code TestExecutionListener} which provides support for dependency
@@ -53,6 +59,8 @@ public class DependencyInjectionTestExecutionListener extends AbstractTestExecut
private static final Log logger = LogFactory.getLog(DependencyInjectionTestExecutionListener.class);
private final AotTestMappings aotTestMappings = getAotTestMappings();
/**
* Returns {@code 2000}.
@@ -78,9 +86,14 @@ public class DependencyInjectionTestExecutionListener extends AbstractTestExecut
@Override
public void prepareTestInstance(TestContext testContext) throws Exception {
if (logger.isDebugEnabled()) {
logger.debug("Performing dependency injection for test context [" + testContext + "].");
logger.debug("Performing dependency injection for test context " + testContext);
}
if (runningInAotMode(testContext.getTestClass())) {
injectDependenciesInAotMode(testContext);
}
else {
injectDependencies(testContext);
}
injectDependencies(testContext);
}
/**
@@ -96,7 +109,12 @@ public class DependencyInjectionTestExecutionListener extends AbstractTestExecut
if (logger.isDebugEnabled()) {
logger.debug("Reinjecting dependencies for test context [" + testContext + "].");
}
injectDependencies(testContext);
if (runningInAotMode(testContext.getTestClass())) {
injectDependenciesInAotMode(testContext);
}
else {
injectDependencies(testContext);
}
}
}
@@ -121,4 +139,40 @@ public class DependencyInjectionTestExecutionListener extends AbstractTestExecut
testContext.removeAttribute(REINJECT_DEPENDENCIES_ATTRIBUTE);
}
private void injectDependenciesInAotMode(TestContext testContext) throws Exception {
ApplicationContext applicationContext = testContext.getApplicationContext();
if (!(applicationContext instanceof GenericApplicationContext gac)) {
throw new IllegalStateException("AOT ApplicationContext must be a GenericApplicationContext instead of " +
applicationContext.getClass().getName());
}
Object bean = testContext.getTestInstance();
Class<?> clazz = testContext.getTestClass();
ConfigurableListableBeanFactory beanFactory = gac.getBeanFactory();
AutowiredAnnotationBeanPostProcessor beanPostProcessor = new AutowiredAnnotationBeanPostProcessor();
beanPostProcessor.setBeanFactory(beanFactory);
beanPostProcessor.processInjection(bean);
beanFactory.initializeBean(bean, clazz.getName() + AutowireCapableBeanFactory.ORIGINAL_INSTANCE_SUFFIX);
testContext.removeAttribute(REINJECT_DEPENDENCIES_ATTRIBUTE);
}
/**
* Determine if we are running in AOT mode for the supplied test class.
*/
private boolean runningInAotMode(Class<?> testClass) {
return (this.aotTestMappings != null && this.aotTestMappings.isSupportedTestClass(testClass));
}
private static AotTestMappings getAotTestMappings() {
if (AotDetector.useGeneratedArtifacts()) {
try {
return new AotTestMappings();
}
catch (Exception ex) {
throw new IllegalStateException("Failed to instantiate AotTestMappings", ex);
}
}
return null;
}
}