Commit 5d591edb authored by Jakub Kubrynski's avatar Jakub Kubrynski Committed by Dave Syer

Consider FactoryBean classes in OnBeanCondition

Update OnBeanCondition to attempt to consider FactoryBean classes
for bean type matches. To ensure early instantiation does not occur, the
object type from the FactoryBean is deduced by resolving generics on the
declaration.

Fixes gh-355
parent fa9a506e
...@@ -20,15 +20,24 @@ import java.lang.annotation.Annotation; ...@@ -20,15 +20,24 @@ import java.lang.annotation.Annotation;
import java.lang.reflect.Method; import java.lang.reflect.Method;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.LinkedHashSet;
import java.util.List; import java.util.List;
import java.util.Set;
import org.springframework.beans.factory.BeanFactory; import org.springframework.beans.factory.BeanFactory;
import org.springframework.beans.factory.BeanFactoryUtils; import org.springframework.beans.factory.BeanFactoryUtils;
import org.springframework.beans.factory.FactoryBean;
import org.springframework.beans.factory.HierarchicalBeanFactory;
import org.springframework.beans.factory.ListableBeanFactory;
import org.springframework.beans.factory.config.BeanDefinition;
import org.springframework.beans.factory.config.ConfigurableListableBeanFactory; import org.springframework.beans.factory.config.ConfigurableListableBeanFactory;
import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Condition; import org.springframework.context.annotation.Condition;
import org.springframework.context.annotation.ConditionContext; import org.springframework.context.annotation.ConditionContext;
import org.springframework.context.annotation.ConfigurationCondition; import org.springframework.context.annotation.ConfigurationCondition;
import org.springframework.core.ResolvableType;
import org.springframework.core.type.AnnotatedTypeMetadata; import org.springframework.core.type.AnnotatedTypeMetadata;
import org.springframework.core.type.MethodMetadata; import org.springframework.core.type.MethodMetadata;
import org.springframework.util.Assert; import org.springframework.util.Assert;
...@@ -43,6 +52,7 @@ import org.springframework.util.StringUtils; ...@@ -43,6 +52,7 @@ import org.springframework.util.StringUtils;
* *
* @author Phillip Webb * @author Phillip Webb
* @author Dave Syer * @author Dave Syer
* @author Jakub Kubrynski
*/ */
class OnBeanCondition extends SpringBootCondition implements ConfigurationCondition { class OnBeanCondition extends SpringBootCondition implements ConfigurationCondition {
...@@ -100,8 +110,8 @@ class OnBeanCondition extends SpringBootCondition implements ConfigurationCondit ...@@ -100,8 +110,8 @@ class OnBeanCondition extends SpringBootCondition implements ConfigurationCondit
boolean considerHierarchy = beans.getStrategy() == SearchStrategy.ALL; boolean considerHierarchy = beans.getStrategy() == SearchStrategy.ALL;
for (String type : beans.getTypes()) { for (String type : beans.getTypes()) {
beanNames.addAll(Arrays.asList(getBeanNamesForType(beanFactory, type, beanNames.addAll(getBeanNamesForType(beanFactory, type,
context.getClassLoader(), considerHierarchy))); context.getClassLoader(), considerHierarchy));
} }
for (String annotation : beans.getAnnotations()) { for (String annotation : beans.getAnnotations()) {
...@@ -126,23 +136,92 @@ class OnBeanCondition extends SpringBootCondition implements ConfigurationCondit ...@@ -126,23 +136,92 @@ class OnBeanCondition extends SpringBootCondition implements ConfigurationCondit
return beanFactory.containsLocalBean(beanName); return beanFactory.containsLocalBean(beanName);
} }
private String[] getBeanNamesForType(ConfigurableListableBeanFactory beanFactory, private Collection<String> getBeanNamesForType(
String type, ClassLoader classLoader, boolean considerHierarchy) ConfigurableListableBeanFactory beanFactory, String type,
throws LinkageError { ClassLoader classLoader, boolean considerHierarchy) throws LinkageError {
// eagerInit set to false to prevent early instantiation (some
// factory beans will not be able to determine their object type at this
// stage, so those are not eligible for matching this condition)
try { try {
Class<?> typeClass = ClassUtils.forName(type, classLoader); Set<String> result = new LinkedHashSet<String>();
if (considerHierarchy) { collectBeanNamesForType(result, beanFactory,
return BeanFactoryUtils.beanNamesForTypeIncludingAncestors(beanFactory, ClassUtils.forName(type, classLoader), considerHierarchy);
typeClass, false, false); return result;
}
return beanFactory.getBeanNamesForType(typeClass, false, false);
} }
catch (ClassNotFoundException ex) { catch (ClassNotFoundException ex) {
return NO_BEANS; return Collections.emptySet();
}
}
private void collectBeanNamesForType(Set<String> result,
ListableBeanFactory beanFactory, Class<?> type, boolean considerHierarchy) {
// eagerInit set to false to prevent early instantiation
result.addAll(Arrays.asList(beanFactory.getBeanNamesForType(type, true, false)));
if (beanFactory instanceof ConfigurableListableBeanFactory) {
collectBeanNamesForTypeFromFactoryBeans(result,
(ConfigurableListableBeanFactory) beanFactory, type);
}
if (considerHierarchy && beanFactory instanceof HierarchicalBeanFactory) {
BeanFactory parent = ((HierarchicalBeanFactory) beanFactory)
.getParentBeanFactory();
if (parent instanceof ListableBeanFactory) {
collectBeanNamesForType(result, (ListableBeanFactory) parent, type,
considerHierarchy);
}
}
}
/**
* Attempt to collect bean names for type by considering FactoryBean generics. Some
* factory beans will not be able to determine their object type at this stage, so
* those are not eligible for matching this condition.
*/
private void collectBeanNamesForTypeFromFactoryBeans(Set<String> result,
ConfigurableListableBeanFactory beanFactory, Class<?> type) {
String[] names = beanFactory.getBeanNamesForType(FactoryBean.class, true, false);
for (String name : names) {
name = BeanFactoryUtils.transformedBeanName(name);
BeanDefinition beanDefinition = beanFactory.getBeanDefinition(name);
Class<?> generic = getFactoryBeanGeneric(beanFactory, beanDefinition);
if (generic != null && ClassUtils.isAssignable(type, generic)) {
result.add(name);
}
}
}
private Class<?> getFactoryBeanGeneric(ConfigurableListableBeanFactory beanFactory,
BeanDefinition definition) {
try {
if (StringUtils.hasLength(definition.getFactoryBeanName())
&& StringUtils.hasLength(definition.getFactoryMethodName())) {
return getConfigurationClassFactoryBeanGeneric(beanFactory, definition);
}
if (StringUtils.hasLength(definition.getBeanClassName())) {
return getDirectFactoryBeanGeneric(beanFactory, definition);
}
}
catch (Exception ex) {
} }
return null;
}
private Class<?> getConfigurationClassFactoryBeanGeneric(
ConfigurableListableBeanFactory beanFactory, BeanDefinition definition)
throws Exception {
BeanDefinition factoryDefinition = beanFactory.getBeanDefinition(definition
.getFactoryBeanName());
Class<?> factoryClass = ClassUtils.forName(factoryDefinition.getBeanClassName(),
beanFactory.getBeanClassLoader());
Method method = ReflectionUtils.findMethod(factoryClass,
definition.getFactoryMethodName());
return ResolvableType.forMethodReturnType(method).as(FactoryBean.class)
.resolveGeneric();
}
private Class<?> getDirectFactoryBeanGeneric(
ConfigurableListableBeanFactory beanFactory, BeanDefinition definition)
throws ClassNotFoundException, LinkageError {
Class<?> factoryBeanClass = ClassUtils.forName(definition.getBeanClassName(),
beanFactory.getBeanClassLoader());
return ResolvableType.forClass(factoryBeanClass).as(FactoryBean.class)
.resolveGeneric();
} }
private String[] getBeanNamesForAnnotation( private String[] getBeanNamesForAnnotation(
......
...@@ -16,7 +16,6 @@ ...@@ -16,7 +16,6 @@
package org.springframework.boot.autoconfigure.condition; package org.springframework.boot.autoconfigure.condition;
import org.junit.Ignore;
import org.junit.Test; import org.junit.Test;
import org.springframework.beans.factory.FactoryBean; import org.springframework.beans.factory.FactoryBean;
import org.springframework.boot.autoconfigure.PropertyPlaceholderAutoConfiguration; import org.springframework.boot.autoconfigure.PropertyPlaceholderAutoConfiguration;
...@@ -38,6 +37,7 @@ import static org.junit.Assert.assertTrue; ...@@ -38,6 +37,7 @@ import static org.junit.Assert.assertTrue;
* *
* @author Dave Syer * @author Dave Syer
* @author Phillip Webb * @author Phillip Webb
* @author Jakub Kubrynski
*/ */
@SuppressWarnings("resource") @SuppressWarnings("resource")
public class ConditionalOnMissingBeanTests { public class ConditionalOnMissingBeanTests {
...@@ -102,7 +102,7 @@ public class ConditionalOnMissingBeanTests { ...@@ -102,7 +102,7 @@ public class ConditionalOnMissingBeanTests {
@Test @Test
public void testAnnotationOnMissingBeanConditionWithEagerFactoryBean() { public void testAnnotationOnMissingBeanConditionWithEagerFactoryBean() {
this.context.register(FooConfiguration.class, OnAnnotationConfiguration.class, this.context.register(FooConfiguration.class, OnAnnotationConfiguration.class,
ConfigurationWithFactoryBean.class, FactoryBeanXmlConfiguration.class,
PropertyPlaceholderAutoConfiguration.class); PropertyPlaceholderAutoConfiguration.class);
this.context.refresh(); this.context.refresh();
assertFalse(this.context.containsBean("bar")); assertFalse(this.context.containsBean("bar"));
...@@ -111,22 +111,44 @@ public class ConditionalOnMissingBeanTests { ...@@ -111,22 +111,44 @@ public class ConditionalOnMissingBeanTests {
} }
@Test @Test
@Ignore("This will never work - you need to use XML for FactoryBeans, or else call getObject() inside the @Bean method")
public void testOnMissingBeanConditionWithFactoryBean() { public void testOnMissingBeanConditionWithFactoryBean() {
this.context.register(ExampleBeanAndFactoryBeanConfiguration.class, this.context.register(FactoryBeanConfiguration.class,
ConditionalOnFactoryBean.class,
PropertyPlaceholderAutoConfiguration.class); PropertyPlaceholderAutoConfiguration.class);
this.context.refresh(); this.context.refresh();
// There should be only one assertThat(this.context.getBean(ExampleBean.class).toString(),
this.context.getBean(ExampleBean.class); equalTo("fromFactory"));
}
@Test
public void testOnMissingBeanConditionWithConcreteFactoryBean() {
this.context.register(ConcreteFactoryBeanConfiguration.class,
ConditionalOnFactoryBean.class,
PropertyPlaceholderAutoConfiguration.class);
this.context.refresh();
assertThat(this.context.getBean(ExampleBean.class).toString(),
equalTo("fromFactory"));
}
@Test
public void testOnMissingBeanConditionWithUnhelpfulFactoryBean() {
this.context.register(UnhelpfulFactoryBeanConfiguration.class,
ConditionalOnFactoryBean.class,
PropertyPlaceholderAutoConfiguration.class);
this.context.refresh();
// We could not tell that the FactoryBean would ultimately create an ExampleBean
assertThat(this.context.getBeansOfType(ExampleBean.class).values().size(),
equalTo(2));
} }
@Test @Test
public void testOnMissingBeanConditionWithFactoryBeanInXml() { public void testOnMissingBeanConditionWithFactoryBeanInXml() {
this.context.register(ConfigurationWithFactoryBean.class, this.context.register(FactoryBeanXmlConfiguration.class,
ConditionalOnFactoryBean.class,
PropertyPlaceholderAutoConfiguration.class); PropertyPlaceholderAutoConfiguration.class);
this.context.refresh(); this.context.refresh();
// There should be only one assertThat(this.context.getBean(ExampleBean.class).toString(),
this.context.getBean(ExampleBean.class); equalTo("fromFactory"));
} }
@Configuration @Configuration
...@@ -139,17 +161,41 @@ public class ConditionalOnMissingBeanTests { ...@@ -139,17 +161,41 @@ public class ConditionalOnMissingBeanTests {
} }
@Configuration @Configuration
protected static class ExampleBeanAndFactoryBeanConfiguration { protected static class FactoryBeanConfiguration {
@Bean @Bean
public FactoryBean<ExampleBean> exampleBeanFactoryBean() { public FactoryBean<ExampleBean> exampleBeanFactoryBean() {
return new ExampleFactoryBean("foo"); return new ExampleFactoryBean("foo");
} }
}
@Configuration
protected static class ConcreteFactoryBeanConfiguration {
@Bean
public ExampleFactoryBean exampleBeanFactoryBean() {
return new ExampleFactoryBean("foo");
}
}
@Configuration
protected static class UnhelpfulFactoryBeanConfiguration {
@Bean
@SuppressWarnings("rawtypes")
public FactoryBean exampleBeanFactoryBean() {
return new ExampleFactoryBean("foo");
}
}
@Configuration
@ImportResource("org/springframework/boot/autoconfigure/condition/factorybean.xml")
protected static class FactoryBeanXmlConfiguration {
}
@Configuration
protected static class ConditionalOnFactoryBean {
@Bean @Bean
@ConditionalOnMissingBean(ExampleBean.class) @ConditionalOnMissingBean(ExampleBean.class)
public ExampleBean createExampleBean() { public ExampleBean createExampleBean() {
return new ExampleBean(); return new ExampleBean("direct");
} }
} }
...@@ -162,11 +208,6 @@ public class ConditionalOnMissingBeanTests { ...@@ -162,11 +208,6 @@ public class ConditionalOnMissingBeanTests {
} }
} }
@Configuration
@ImportResource("org/springframework/boot/autoconfigure/condition/factorybean.xml")
protected static class ConfigurationWithFactoryBean {
}
@Configuration @Configuration
@EnableScheduling @EnableScheduling
protected static class FooConfiguration { protected static class FooConfiguration {
...@@ -198,7 +239,7 @@ public class ConditionalOnMissingBeanTests { ...@@ -198,7 +239,7 @@ public class ConditionalOnMissingBeanTests {
protected static class ExampleBeanConfiguration { protected static class ExampleBeanConfiguration {
@Bean @Bean
public ExampleBean exampleBean() { public ExampleBean exampleBean() {
return new ExampleBean(); return new ExampleBean("test");
} }
} }
...@@ -208,12 +249,24 @@ public class ConditionalOnMissingBeanTests { ...@@ -208,12 +249,24 @@ public class ConditionalOnMissingBeanTests {
@Bean @Bean
@ConditionalOnMissingBean @ConditionalOnMissingBean
public ExampleBean exampleBean2() { public ExampleBean exampleBean2() {
return new ExampleBean(); return new ExampleBean("test");
} }
} }
public static class ExampleBean { public static class ExampleBean {
private String value;
public ExampleBean(String value) {
this.value = value;
}
@Override
public String toString() {
return this.value;
}
} }
public static class ExampleFactoryBean implements FactoryBean<ExampleBean> { public static class ExampleFactoryBean implements FactoryBean<ExampleBean> {
...@@ -224,7 +277,7 @@ public class ConditionalOnMissingBeanTests { ...@@ -224,7 +277,7 @@ public class ConditionalOnMissingBeanTests {
@Override @Override
public ExampleBean getObject() throws Exception { public ExampleBean getObject() throws Exception {
return new ExampleBean(); return new ExampleBean("fromFactory");
} }
@Override @Override
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment