Commit 905346d0 authored by Andy Wilkinson's avatar Andy Wilkinson

Consider @Bean methods with args to determine type created by factory

Previously, BeanTypeRegistry would only look for a @Bean method
with no arguments when trying to determine the type that will be
created by a factory bean. This meant that the type produced by a
factory bean declared via a @Bean that has one or more arguments would
be unknown and any on missing bean conditions look for a bean of the
type produced by the factory bean would match in error.

This commit updates BeanTypeRegistry to, where possible, use the
factory method metadata for the bean definition when determining the
type that will be created. This allows it to determine the type for
factory bean created by @Bean methods that take arguments and also
avoids the use reflection to find the factory method. Where factory
method metadata is not available, the existing reflection-based
approach is used as a fallback.

Closes gh-3657
parent 2c0ec1b4
...@@ -33,11 +33,14 @@ import org.springframework.beans.factory.CannotLoadBeanClassException; ...@@ -33,11 +33,14 @@ import org.springframework.beans.factory.CannotLoadBeanClassException;
import org.springframework.beans.factory.FactoryBean; import org.springframework.beans.factory.FactoryBean;
import org.springframework.beans.factory.ListableBeanFactory; import org.springframework.beans.factory.ListableBeanFactory;
import org.springframework.beans.factory.SmartInitializingSingleton; import org.springframework.beans.factory.SmartInitializingSingleton;
import org.springframework.beans.factory.annotation.AnnotatedBeanDefinition;
import org.springframework.beans.factory.config.BeanDefinition; import org.springframework.beans.factory.config.BeanDefinition;
import org.springframework.beans.factory.config.ConfigurableListableBeanFactory; import org.springframework.beans.factory.config.ConfigurableListableBeanFactory;
import org.springframework.beans.factory.support.DefaultListableBeanFactory; import org.springframework.beans.factory.support.DefaultListableBeanFactory;
import org.springframework.beans.factory.support.RootBeanDefinition; import org.springframework.beans.factory.support.RootBeanDefinition;
import org.springframework.core.ResolvableType; import org.springframework.core.ResolvableType;
import org.springframework.core.type.MethodMetadata;
import org.springframework.core.type.StandardMethodMetadata;
import org.springframework.util.ClassUtils; import org.springframework.util.ClassUtils;
import org.springframework.util.ReflectionUtils; import org.springframework.util.ReflectionUtils;
import org.springframework.util.StringUtils; import org.springframework.util.StringUtils;
...@@ -108,12 +111,7 @@ abstract class BeanTypeRegistry { ...@@ -108,12 +111,7 @@ abstract class BeanTypeRegistry {
private Class<?> getConfigurationClassFactoryBeanGeneric( private Class<?> getConfigurationClassFactoryBeanGeneric(
ConfigurableListableBeanFactory beanFactory, BeanDefinition definition, ConfigurableListableBeanFactory beanFactory, BeanDefinition definition,
String name) throws Exception { String name) throws Exception {
BeanDefinition factoryDefinition = beanFactory.getBeanDefinition(definition Method method = getFactoryMethod(beanFactory, definition);
.getFactoryBeanName());
Class<?> factoryClass = ClassUtils.forName(factoryDefinition.getBeanClassName(),
beanFactory.getBeanClassLoader());
Method method = ReflectionUtils.findMethod(factoryClass,
definition.getFactoryMethodName());
Class<?> generic = ResolvableType.forMethodReturnType(method) Class<?> generic = ResolvableType.forMethodReturnType(method)
.as(FactoryBean.class).resolveGeneric(); .as(FactoryBean.class).resolveGeneric();
if ((generic == null || generic.equals(Object.class)) if ((generic == null || generic.equals(Object.class))
...@@ -124,6 +122,24 @@ abstract class BeanTypeRegistry { ...@@ -124,6 +122,24 @@ abstract class BeanTypeRegistry {
return generic; return generic;
} }
private Method getFactoryMethod(ConfigurableListableBeanFactory beanFactory,
BeanDefinition definition) throws Exception {
if (definition instanceof AnnotatedBeanDefinition) {
MethodMetadata factoryMethodMetadata = ((AnnotatedBeanDefinition) definition)
.getFactoryMethodMetadata();
if (factoryMethodMetadata instanceof StandardMethodMetadata) {
return ((StandardMethodMetadata) factoryMethodMetadata)
.getIntrospectedMethod();
}
}
BeanDefinition factoryDefinition = beanFactory.getBeanDefinition(definition
.getFactoryBeanName());
Class<?> factoryClass = ClassUtils.forName(factoryDefinition.getBeanClassName(),
beanFactory.getBeanClassLoader());
return ReflectionUtils
.findMethod(factoryClass, definition.getFactoryMethodName());
}
private Class<?> getDirectFactoryBeanGeneric( private Class<?> getDirectFactoryBeanGeneric(
ConfigurableListableBeanFactory beanFactory, BeanDefinition definition, ConfigurableListableBeanFactory beanFactory, BeanDefinition definition,
String name) throws ClassNotFoundException, LinkageError { String name) throws ClassNotFoundException, LinkageError {
......
...@@ -18,9 +18,11 @@ package org.springframework.boot.autoconfigure.condition; ...@@ -18,9 +18,11 @@ package org.springframework.boot.autoconfigure.condition;
import org.junit.Test; import org.junit.Test;
import org.springframework.beans.factory.FactoryBean; import org.springframework.beans.factory.FactoryBean;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.beans.factory.support.BeanDefinitionBuilder; import org.springframework.beans.factory.support.BeanDefinitionBuilder;
import org.springframework.beans.factory.support.BeanDefinitionRegistry; import org.springframework.beans.factory.support.BeanDefinitionRegistry;
import org.springframework.boot.autoconfigure.PropertyPlaceholderAutoConfiguration; import org.springframework.boot.autoconfigure.PropertyPlaceholderAutoConfiguration;
import org.springframework.boot.test.EnvironmentTestUtils;
import org.springframework.context.annotation.AnnotationConfigApplicationContext; import org.springframework.context.annotation.AnnotationConfigApplicationContext;
import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration; import org.springframework.context.annotation.Configuration;
...@@ -127,6 +129,17 @@ public class ConditionalOnMissingBeanTests { ...@@ -127,6 +129,17 @@ public class ConditionalOnMissingBeanTests {
equalTo("fromFactory")); equalTo("fromFactory"));
} }
@Test
public void testOnMissingBeanConditionWithFactoryBeanWithBeanMethodArguments() {
this.context.register(FactoryBeanWithBeanMethodArgumentsConfiguration.class,
ConditionalOnFactoryBean.class,
PropertyPlaceholderAutoConfiguration.class);
EnvironmentTestUtils.addEnvironment(this.context, "theValue:foo");
this.context.refresh();
assertThat(this.context.getBean(ExampleBean.class).toString(),
equalTo("fromFactory"));
}
@Test @Test
public void testOnMissingBeanConditionWithConcreteFactoryBean() { public void testOnMissingBeanConditionWithConcreteFactoryBean() {
this.context.register(ConcreteFactoryBeanConfiguration.class, this.context.register(ConcreteFactoryBeanConfiguration.class,
...@@ -227,6 +240,15 @@ public class ConditionalOnMissingBeanTests { ...@@ -227,6 +240,15 @@ public class ConditionalOnMissingBeanTests {
} }
} }
@Configuration
protected static class FactoryBeanWithBeanMethodArgumentsConfiguration {
@Bean
public FactoryBean<ExampleBean> exampleBeanFactoryBean(
@Value("${theValue}") String value) {
return new ExampleFactoryBean(value);
}
}
@Configuration @Configuration
protected static class ConcreteFactoryBeanConfiguration { protected static class ConcreteFactoryBeanConfiguration {
@Bean @Bean
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment