Introduce first-class support for programmatic bean registration

This commit introduces a new BeanRegistrar interface that can be
implemented to register beans programmatically in a concise and
flexible way.

Those bean registrar implementations are typically imported with
an `@Import` annotation on `@Configuration` classes.

See BeanRegistrarConfigurationTests for a concrete example.

See gh-18353
This commit is contained in:
Sébastien Deleuze
2025-03-06 18:51:09 +01:00
parent aeaf52ee96
commit 496be9ca98
14 changed files with 1356 additions and 8 deletions

View File

@@ -24,6 +24,7 @@ import java.util.Set;
import org.jspecify.annotations.Nullable;
import org.springframework.beans.factory.BeanRegistrar;
import org.springframework.beans.factory.parsing.Location;
import org.springframework.beans.factory.parsing.Problem;
import org.springframework.beans.factory.parsing.ProblemReporter;
@@ -65,6 +66,8 @@ final class ConfigurationClass {
private final Map<String, Class<? extends BeanDefinitionReader>> importedResources =
new LinkedHashMap<>();
private final Set<BeanRegistrar> beanRegistrars = new LinkedHashSet<>();
private final Map<ImportBeanDefinitionRegistrar, AnnotationMetadata> importBeanDefinitionRegistrars =
new LinkedHashMap<>();
@@ -219,6 +222,14 @@ final class ConfigurationClass {
return this.importedResources;
}
void addBeanRegistrar(BeanRegistrar beanRegistrar) {
this.beanRegistrars.add(beanRegistrar);
}
public Set<BeanRegistrar> getBeanRegistrars() {
return this.beanRegistrars;
}
void addImportBeanDefinitionRegistrar(ImportBeanDefinitionRegistrar registrar, AnnotationMetadata importingClassMetadata) {
this.importBeanDefinitionRegistrars.put(registrar, importingClassMetadata);
}

View File

@@ -1,5 +1,5 @@
/*
* Copyright 2002-2024 the original author or authors.
* Copyright 2002-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -29,6 +29,8 @@ import org.apache.commons.logging.LogFactory;
import org.jspecify.annotations.Nullable;
import org.springframework.beans.factory.BeanDefinitionStoreException;
import org.springframework.beans.factory.BeanRegistrar;
import org.springframework.beans.factory.ListableBeanFactory;
import org.springframework.beans.factory.annotation.AnnotatedBeanDefinition;
import org.springframework.beans.factory.annotation.AnnotatedGenericBeanDefinition;
import org.springframework.beans.factory.config.BeanDefinition;
@@ -41,6 +43,7 @@ import org.springframework.beans.factory.support.BeanDefinitionOverrideException
import org.springframework.beans.factory.support.BeanDefinitionReader;
import org.springframework.beans.factory.support.BeanDefinitionRegistry;
import org.springframework.beans.factory.support.BeanNameGenerator;
import org.springframework.beans.factory.support.BeanRegistryAdapter;
import org.springframework.beans.factory.support.RootBeanDefinition;
import org.springframework.beans.factory.xml.XmlBeanDefinitionReader;
import org.springframework.context.annotation.ConfigurationCondition.ConfigurationPhase;
@@ -146,7 +149,8 @@ class ConfigurationClassBeanDefinitionReader {
}
loadBeanDefinitionsFromImportedResources(configClass.getImportedResources());
loadBeanDefinitionsFromRegistrars(configClass.getImportBeanDefinitionRegistrars());
loadBeanDefinitionsFromImportBeanDefinitionRegistrars(configClass.getImportBeanDefinitionRegistrars());
loadBeanDefinitionsFromBeanRegistrars(configClass.getBeanRegistrars());
}
/**
@@ -395,11 +399,19 @@ class ConfigurationClassBeanDefinitionReader {
});
}
private void loadBeanDefinitionsFromRegistrars(Map<ImportBeanDefinitionRegistrar, AnnotationMetadata> registrars) {
private void loadBeanDefinitionsFromImportBeanDefinitionRegistrars(Map<ImportBeanDefinitionRegistrar, AnnotationMetadata> registrars) {
registrars.forEach((registrar, metadata) ->
registrar.registerBeanDefinitions(metadata, this.registry, this.importBeanNameGenerator));
}
private void loadBeanDefinitionsFromBeanRegistrars(Set<BeanRegistrar> registrars) {
Assert.isInstanceOf(ListableBeanFactory.class, this.registry,
"Cannot support bean registrars since " + this.registry.getClass().getName() +
" does not implement BeanDefinitionRegistry");
registrars.forEach(registrar -> registrar.register(new BeanRegistryAdapter(this.registry,
(ListableBeanFactory) this.registry, registrar.getClass()), this.environment));
}
/**
* {@link RootBeanDefinition} marker subclass used to signify that a bean definition

View File

@@ -40,6 +40,7 @@ import org.apache.commons.logging.LogFactory;
import org.jspecify.annotations.Nullable;
import org.springframework.beans.factory.BeanDefinitionStoreException;
import org.springframework.beans.factory.BeanRegistrar;
import org.springframework.beans.factory.annotation.AnnotatedBeanDefinition;
import org.springframework.beans.factory.config.BeanDefinition;
import org.springframework.beans.factory.config.BeanDefinitionHolder;
@@ -597,6 +598,13 @@ class ConfigurationClassParser {
processImports(configClass, currentSourceClass, importSourceClasses, filter, false);
}
}
else if (candidate.isAssignable(BeanRegistrar.class)) {
Class<?> candidateClass = candidate.loadClass();
BeanRegistrar registrar =
ParserStrategyUtils.instantiateClass(candidateClass, BeanRegistrar.class,
this.environment, this.resourceLoader, this.registry);
configClass.addBeanRegistrar(registrar);
}
else if (candidate.isAssignable(ImportBeanDefinitionRegistrar.class)) {
// Candidate class is an ImportBeanDefinitionRegistrar ->
// delegate to it to register additional bean definitions

View File

@@ -20,7 +20,9 @@ import java.io.IOException;
import java.io.UncheckedIOException;
import java.lang.reflect.Constructor;
import java.lang.reflect.Executable;
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashMap;
@@ -40,8 +42,12 @@ import org.jspecify.annotations.Nullable;
import org.springframework.aop.framework.autoproxy.AutoProxyUtils;
import org.springframework.aot.generate.GeneratedMethod;
import org.springframework.aot.generate.GeneratedMethods;
import org.springframework.aot.generate.GenerationContext;
import org.springframework.aot.generate.MethodReference;
import org.springframework.aot.hint.ExecutableMode;
import org.springframework.aot.hint.MemberCategory;
import org.springframework.aot.hint.ReflectionHints;
import org.springframework.aot.hint.ResourceHints;
import org.springframework.aot.hint.RuntimeHints;
import org.springframework.aot.hint.TypeReference;
@@ -49,7 +55,10 @@ import org.springframework.beans.PropertyValues;
import org.springframework.beans.factory.BeanClassLoaderAware;
import org.springframework.beans.factory.BeanDefinitionStoreException;
import org.springframework.beans.factory.BeanFactory;
import org.springframework.beans.factory.BeanRegistrar;
import org.springframework.beans.factory.ListableBeanFactory;
import org.springframework.beans.factory.annotation.AnnotatedBeanDefinition;
import org.springframework.beans.factory.aot.AotServices;
import org.springframework.beans.factory.aot.BeanFactoryInitializationAotContribution;
import org.springframework.beans.factory.aot.BeanFactoryInitializationAotProcessor;
import org.springframework.beans.factory.aot.BeanFactoryInitializationCode;
@@ -60,6 +69,7 @@ import org.springframework.beans.factory.aot.BeanRegistrationCodeFragments;
import org.springframework.beans.factory.aot.BeanRegistrationCodeFragmentsDecorator;
import org.springframework.beans.factory.aot.InstanceSupplierCodeGenerator;
import org.springframework.beans.factory.config.BeanDefinition;
import org.springframework.beans.factory.config.BeanDefinitionCustomizer;
import org.springframework.beans.factory.config.BeanDefinitionHolder;
import org.springframework.beans.factory.config.BeanFactoryPostProcessor;
import org.springframework.beans.factory.config.ConfigurableListableBeanFactory;
@@ -73,6 +83,7 @@ import org.springframework.beans.factory.support.AbstractBeanDefinition;
import org.springframework.beans.factory.support.BeanDefinitionRegistry;
import org.springframework.beans.factory.support.BeanDefinitionRegistryPostProcessor;
import org.springframework.beans.factory.support.BeanNameGenerator;
import org.springframework.beans.factory.support.BeanRegistryAdapter;
import org.springframework.beans.factory.support.DefaultListableBeanFactory;
import org.springframework.beans.factory.support.RegisteredBean;
import org.springframework.beans.factory.support.RegisteredBean.InstantiationDescriptor;
@@ -99,6 +110,7 @@ import org.springframework.core.type.AnnotationMetadata;
import org.springframework.core.type.MethodMetadata;
import org.springframework.core.type.classreading.CachingMetadataReaderFactory;
import org.springframework.core.type.classreading.MetadataReaderFactory;
import org.springframework.javapoet.ClassName;
import org.springframework.javapoet.CodeBlock;
import org.springframework.javapoet.CodeBlock.Builder;
import org.springframework.javapoet.MethodSpec;
@@ -106,6 +118,10 @@ import org.springframework.javapoet.ParameterizedTypeName;
import org.springframework.util.Assert;
import org.springframework.util.ClassUtils;
import org.springframework.util.CollectionUtils;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
import org.springframework.util.ObjectUtils;
import org.springframework.util.ReflectionUtils;
/**
* {@link BeanFactoryPostProcessor} used for bootstrapping processing of
@@ -181,6 +197,8 @@ public class ConfigurationClassPostProcessor implements BeanDefinitionRegistryPo
@SuppressWarnings("NullAway.Init")
private List<PropertySourceDescriptor> propertySourceDescriptors;
private Set<BeanRegistrar> beanRegistrars = new LinkedHashSet<>();
@Override
public int getOrder() {
@@ -323,7 +341,8 @@ public class ConfigurationClassPostProcessor implements BeanDefinitionRegistryPo
public @Nullable BeanFactoryInitializationAotContribution processAheadOfTime(ConfigurableListableBeanFactory beanFactory) {
boolean hasPropertySourceDescriptors = !CollectionUtils.isEmpty(this.propertySourceDescriptors);
boolean hasImportRegistry = beanFactory.containsBean(IMPORT_REGISTRY_BEAN_NAME);
if (hasPropertySourceDescriptors || hasImportRegistry) {
boolean hasBeanRegistrars = !this.beanRegistrars.isEmpty();
if (hasPropertySourceDescriptors || hasImportRegistry || hasBeanRegistrars) {
return (generationContext, code) -> {
if (hasPropertySourceDescriptors) {
new PropertySourcesAotContribution(this.propertySourceDescriptors, this::resolvePropertySourceLocation)
@@ -332,6 +351,9 @@ public class ConfigurationClassPostProcessor implements BeanDefinitionRegistryPo
if (hasImportRegistry) {
new ImportAwareAotContribution(beanFactory).applyTo(generationContext, code);
}
if (hasBeanRegistrars) {
new BeanRegistrarAotContribution(this.beanRegistrars, beanFactory).applyTo(generationContext, code);
}
};
}
return null;
@@ -420,6 +442,9 @@ public class ConfigurationClassPostProcessor implements BeanDefinitionRegistryPo
this.importBeanNameGenerator, parser.getImportRegistry());
}
this.reader.loadBeanDefinitions(configClasses);
for (ConfigurationClass configClass : configClasses) {
this.beanRegistrars.addAll(configClass.getBeanRegistrars());
}
alreadyParsed.addAll(configClasses);
processConfig.tag("classCount", () -> String.valueOf(configClasses.size())).end();
@@ -815,4 +840,182 @@ public class ConfigurationClassPostProcessor implements BeanDefinitionRegistryPo
}
}
private static class BeanRegistrarAotContribution implements BeanFactoryInitializationAotContribution {
private static final String CUSTOMIZER_MAP_VARIABLE = "customizers";
private static final String ENVIRONMENT_VARIABLE = "environment";
private final Set<BeanRegistrar> beanRegistrars;
private final ConfigurableListableBeanFactory beanFactory;
private final AotServices<BeanRegistrationAotProcessor> aotProcessors;
public BeanRegistrarAotContribution(Set<BeanRegistrar> beanRegistrars, ConfigurableListableBeanFactory beanFactory) {
this.beanRegistrars = beanRegistrars;
this.beanFactory = beanFactory;
this.aotProcessors = AotServices.factoriesAndBeans(this.beanFactory).load(BeanRegistrationAotProcessor.class);
}
@Override
public void applyTo(GenerationContext generationContext, BeanFactoryInitializationCode beanFactoryInitializationCode) {
GeneratedMethod generatedMethod = beanFactoryInitializationCode.getMethods().add(
"applyBeanRegistrars", builder -> this.generateApplyBeanRegistrarsMethod(builder, generationContext));
beanFactoryInitializationCode.addInitializer(generatedMethod.toMethodReference());
}
private void generateApplyBeanRegistrarsMethod(MethodSpec.Builder method, GenerationContext generationContext) {
ReflectionHints reflectionHints = generationContext.getRuntimeHints().reflection();
method.addJavadoc("Apply bean registrars.");
method.addModifiers(Modifier.PRIVATE);
method.addParameter(ListableBeanFactory.class, BeanFactoryInitializationCode.BEAN_FACTORY_VARIABLE);
method.addParameter(Environment.class, ENVIRONMENT_VARIABLE);
method.addCode(generateCustomizerMap());
for (String name : this.beanFactory.getBeanDefinitionNames()) {
BeanDefinition beanDefinition = this.beanFactory.getMergedBeanDefinition(name);
if (beanDefinition.getSource() instanceof Class<?> sourceClass
&& BeanRegistrar.class.isAssignableFrom(sourceClass)) {
for (BeanRegistrationAotProcessor aotProcessor : this.aotProcessors) {
BeanRegistrationAotContribution contribution =
aotProcessor.processAheadOfTime(RegisteredBean.of(this.beanFactory, name));
if (contribution != null) {
contribution.applyTo(generationContext,
new UnsupportedBeanRegistrationCode(name, aotProcessor.getClass()));
}
}
if (beanDefinition instanceof RootBeanDefinition rootBeanDefinition) {
if (rootBeanDefinition.getPreferredConstructors() != null) {
for (Constructor<?> constructor : rootBeanDefinition.getPreferredConstructors()) {
reflectionHints.registerConstructor(constructor, ExecutableMode.INVOKE);
}
}
if (!ObjectUtils.isEmpty(rootBeanDefinition.getInitMethodNames())) {
method.addCode(generateInitDestroyMethods(name, rootBeanDefinition,
rootBeanDefinition.getInitMethodNames(), "setInitMethodNames", reflectionHints));
}
if (!ObjectUtils.isEmpty(rootBeanDefinition.getDestroyMethodNames())) {
method.addCode(generateInitDestroyMethods(name, rootBeanDefinition,
rootBeanDefinition.getDestroyMethodNames(), "setDestroyMethodNames", reflectionHints));
}
checkUnsupportedFeatures(rootBeanDefinition);
}
}
}
method.addCode(generateRegisterCode());
}
private void checkUnsupportedFeatures(AbstractBeanDefinition beanDefinition) {
if (!ObjectUtils.isEmpty(beanDefinition.getFactoryBeanName())) {
throw new UnsupportedOperationException("AOT post processing of the factory bean name is not supported yet with BeanRegistrar");
}
if (beanDefinition.hasConstructorArgumentValues()) {
throw new UnsupportedOperationException("AOT post processing of argument values is not supported yet with BeanRegistrar");
}
if (!beanDefinition.getQualifiers().isEmpty()) {
throw new UnsupportedOperationException("AOT post processing of qualifiers is not supported yet with BeanRegistrar");
}
for (String attributeName : beanDefinition.attributeNames()) {
if (!attributeName.equals(AbstractBeanDefinition.ORDER_ATTRIBUTE)
&& !attributeName.equals("aotProcessingIgnoreRegistration")) {
throw new UnsupportedOperationException("AOT post processing of attribute " + attributeName +
" is not supported yet with BeanRegistrar");
}
}
}
private CodeBlock generateCustomizerMap() {
Builder code = CodeBlock.builder();
code.addStatement("$T<$T, $T> $L = new $T<>()", MultiValueMap.class, String.class, BeanDefinitionCustomizer.class,
CUSTOMIZER_MAP_VARIABLE, LinkedMultiValueMap.class);
return code.build();
}
private CodeBlock generateRegisterCode() {
Builder code = CodeBlock.builder();
for (BeanRegistrar beanRegistrar : this.beanRegistrars) {
code.addStatement("new $T().register(new $T(($T)$L, $L, $T.class, $L), $L)", beanRegistrar.getClass(),
BeanRegistryAdapter.class, BeanDefinitionRegistry.class, BeanFactoryInitializationCode.BEAN_FACTORY_VARIABLE,
BeanFactoryInitializationCode.BEAN_FACTORY_VARIABLE, beanRegistrar.getClass(), CUSTOMIZER_MAP_VARIABLE,
ENVIRONMENT_VARIABLE);
}
return code.build();
}
private CodeBlock generateInitDestroyMethods(String beanName, AbstractBeanDefinition beanDefinition,
String[] methodNames, String method, ReflectionHints reflectionHints) {
Builder code = CodeBlock.builder();
// For Publisher-based destroy methods
reflectionHints.registerType(TypeReference.of("org.reactivestreams.Publisher"));
Class<?> beanType = ClassUtils.getUserClass(beanDefinition.getResolvableType().toClass());
Arrays.stream(methodNames).forEach(methodName -> addInitDestroyHint(beanType, methodName, reflectionHints));
CodeBlock arguments = Arrays.stream(methodNames)
.map(name -> CodeBlock.of("$S", name))
.collect(CodeBlock.joining(", "));
code.addStatement("$L.add($S, $L -> (($T)$L).$L($L))", CUSTOMIZER_MAP_VARIABLE, beanName, "bd",
AbstractBeanDefinition.class, "bd", method, arguments);
return code.build();
}
// Inspired from BeanDefinitionPropertiesCodeGenerator#addInitDestroyHint
private static void addInitDestroyHint(Class<?> beanUserClass, String methodName, ReflectionHints reflectionHints) {
Class<?> methodDeclaringClass = beanUserClass;
// Parse fully-qualified method name if necessary.
int indexOfDot = methodName.lastIndexOf('.');
if (indexOfDot > 0) {
String className = methodName.substring(0, indexOfDot);
methodName = methodName.substring(indexOfDot + 1);
if (!beanUserClass.getName().equals(className)) {
try {
methodDeclaringClass = ClassUtils.forName(className, beanUserClass.getClassLoader());
}
catch (Throwable ex) {
throw new IllegalStateException("Failed to load Class [" + className +
"] from ClassLoader [" + beanUserClass.getClassLoader() + "]", ex);
}
}
}
Method method = ReflectionUtils.findMethod(methodDeclaringClass, methodName);
if (method != null) {
reflectionHints.registerMethod(method, ExecutableMode.INVOKE);
Method publiclyAccessibleMethod = ClassUtils.getPubliclyAccessibleMethodIfPossible(method, beanUserClass);
if (!publiclyAccessibleMethod.equals(method)) {
reflectionHints.registerMethod(publiclyAccessibleMethod, ExecutableMode.INVOKE);
}
}
}
static class UnsupportedBeanRegistrationCode implements BeanRegistrationCode {
private final String message;
public UnsupportedBeanRegistrationCode(String beanName, Class<?> aotProcessorClass) {
this.message = "Code generation attempted for bean " + beanName + " by the AOT Processor " +
aotProcessorClass + " is not supported with BeanRegistrar yet";
}
@Override
public ClassName getClassName() {
throw new UnsupportedOperationException(this.message);
}
@Override
public GeneratedMethods getMethods() {
throw new UnsupportedOperationException(this.message);
}
@Override
public void addInstancePostProcessor(MethodReference methodReference) {
throw new UnsupportedOperationException(this.message);
}
}
}
}

View File

@@ -1,5 +1,5 @@
/*
* Copyright 2002-2019 the original author or authors.
* Copyright 2002-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -22,6 +22,8 @@ import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
import org.springframework.beans.factory.BeanRegistrar;
/**
* Indicates one or more <em>component classes</em> to import &mdash; typically
* {@link Configuration @Configuration} classes.
@@ -57,7 +59,8 @@ public @interface Import {
/**
* {@link Configuration @Configuration}, {@link ImportSelector},
* {@link ImportBeanDefinitionRegistrar}, or regular component classes to import.
* {@link ImportBeanDefinitionRegistrar}, {@link BeanRegistrar} or regular
* component classes to import.
*/
Class<?>[] value();

View File

@@ -30,6 +30,7 @@ import org.springframework.aot.generate.GeneratedMethods;
import org.springframework.aot.generate.GenerationContext;
import org.springframework.aot.generate.MethodReference;
import org.springframework.aot.generate.MethodReference.ArgumentCodeGenerator;
import org.springframework.beans.factory.ListableBeanFactory;
import org.springframework.beans.factory.aot.BeanFactoryInitializationCode;
import org.springframework.beans.factory.config.ConfigurableListableBeanFactory;
import org.springframework.beans.factory.support.DefaultListableBeanFactory;
@@ -150,6 +151,7 @@ class ApplicationContextInitializationCodeGenerator implements BeanFactoryInitia
private @Nullable CodeBlock apply(ClassName className) {
String name = className.canonicalName();
if (name.equals(DefaultListableBeanFactory.class.getName())
|| name.equals(ListableBeanFactory.class.getName())
|| name.equals(ConfigurableListableBeanFactory.class.getName())) {
return CodeBlock.of(BEAN_FACTORY_VARIABLE);
}