diff --git a/src/main/java/org/springframework/guice/module/SpringModule.java b/src/main/java/org/springframework/guice/module/SpringModule.java index f36c329..4d40812 100644 --- a/src/main/java/org/springframework/guice/module/SpringModule.java +++ b/src/main/java/org/springframework/guice/module/SpringModule.java @@ -44,7 +44,9 @@ import com.google.inject.name.Named; import com.google.inject.name.Names; import com.google.inject.spi.ProvisionListener; +import com.google.inject.util.Types; import org.springframework.beans.factory.BeanFactoryUtils; +import org.springframework.beans.factory.FactoryBean; import org.springframework.beans.factory.annotation.AnnotatedBeanDefinition; import org.springframework.beans.factory.config.BeanDefinition; import org.springframework.beans.factory.config.ConfigurableListableBeanFactory; @@ -136,30 +138,42 @@ public class SpringModule extends AbstractModule { if (definition.isAutowireCandidate() && definition.getRole() == AbstractBeanDefinition.ROLE_APPLICATION) { Type type; - RootBeanDefinition rootBeanDefinition = (RootBeanDefinition) beanFactory - .getMergedBeanDefinition(name); - if (rootBeanDefinition.getFactoryBeanName() != null - && rootBeanDefinition.getResolvedFactoryMethod() != null) { - type = rootBeanDefinition.getResolvedFactoryMethod() - .getGenericReturnType(); - } - else { - type = rootBeanDefinition.getResolvableType().getType(); - } - if (type == null) { - continue; - } - final String beanName = name; - Provider typeProvider = BeanFactoryProvider.typed(beanFactory, type, - bindingAnnotation); - Provider namedProvider = BeanFactoryProvider.named(beanFactory, - beanName, type, bindingAnnotation); - - Class clazz = (type instanceof Class) ? (Class) type - : beanFactory.getType(beanName); + Class clazz = beanFactory.getType(name); if (clazz == null) { continue; } + if (clazz.getTypeParameters().length > 0) { + RootBeanDefinition rootBeanDefinition = (RootBeanDefinition) beanFactory + .getMergedBeanDefinition(name); + if (rootBeanDefinition.getFactoryBeanName() != null + && rootBeanDefinition.getResolvedFactoryMethod() != null) { + type = rootBeanDefinition.getResolvedFactoryMethod() + .getGenericReturnType(); + } + else { + type = rootBeanDefinition.getResolvableType().getType(); + } + if (type instanceof ParameterizedType) { + ParameterizedType parameterizedType = (ParameterizedType) type; + if (parameterizedType.getRawType() instanceof Class && + FactoryBean.class.isAssignableFrom((Class) parameterizedType.getRawType())) { + type = Types.newParameterizedTypeWithOwner(parameterizedType.getOwnerType(), + clazz, parameterizedType.getActualTypeArguments()); + } + } + + } else { + type = clazz; + } + + if (type == null) { + continue; + } + Provider typeProvider = BeanFactoryProvider.typed(beanFactory, type, + bindingAnnotation); + Provider namedProvider = BeanFactoryProvider.named(beanFactory, + name, type, bindingAnnotation); + if (!clazz.isInterface() && !ClassUtils.isCglibProxyClass(clazz)) { bindConditionally(binder(), name, clazz, typeProvider, namedProvider, bindingAnnotation); diff --git a/src/test/java/org/springframework/guice/SuperClassTests.java b/src/test/java/org/springframework/guice/SuperClassTests.java index 762431c..8aee26f 100644 --- a/src/test/java/org/springframework/guice/SuperClassTests.java +++ b/src/test/java/org/springframework/guice/SuperClassTests.java @@ -6,6 +6,7 @@ import com.google.inject.Key; import com.google.inject.TypeLiteral; import org.junit.Test; +import org.springframework.beans.factory.FactoryBean; import org.springframework.context.annotation.AnnotationConfigApplicationContext; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.ComponentScan; @@ -308,6 +309,52 @@ public class SuperClassTests { assertTrue(integerFoo instanceof SubIntegerFoo); } + @Test + public void testSpringFactoryBean() { + baseTestSpringFactoryBean(ModulesConfig.class); + } + + @Test + public void testImportSpringFactoryBean() { + baseTestSpringFactoryBean(ImportConfig.class); + } + + @Test + public void testComponentScanSpringFactoryBean() { + baseTestSpringFactoryBean(ComponentScanConfig.class); + } + + private void baseTestSpringFactoryBean(Class configClass) { + AnnotationConfigApplicationContext context = new AnnotationConfigApplicationContext( + configClass); + + Bar bar = context.getBean(Bar.class); + assertTrue(bar instanceof Bar); + } + + @Test + public void testGuiceFactoryBean() { + baseTestGuiceFactoryBean(ModulesConfig.class); + } + + @Test + public void testImportGuiceFactoryBean() { + baseTestGuiceFactoryBean(ImportConfig.class); + } + + @Test + public void testComponentScanGuiceFactoryBean() { + baseTestGuiceFactoryBean(ComponentScanConfig.class); + } + + private void baseTestGuiceFactoryBean(Class configClass) { + AnnotationConfigApplicationContext context = new AnnotationConfigApplicationContext( + configClass); + Injector injector = context.getBean(Injector.class); + Bar bar = injector.getInstance(Bar.class); + assertTrue(bar instanceof Bar); + } + static class DisableJITConfig { @Bean public AbstractModule disableJITModule() { @@ -354,12 +401,17 @@ public class SuperClassTests { return new SubIntegerFoo(); } + @Bean + BarFactory barFactory() { + return new BarFactory(); + } + } @Configuration @EnableGuiceModules @Import({ IGrandChildImpl.class, IGrandChildString.class, IGrandChildInteger.class, - SubFoo.class, SubStringFoo.class, SubIntegerFoo.class }) + SubFoo.class, SubStringFoo.class, SubIntegerFoo.class, BarFactory.class }) static class ImportConfig extends DisableJITConfig { } @@ -443,4 +495,20 @@ public class SuperClassTests { } + public static class Bar {} + + + @Component + public static class BarFactory implements FactoryBean { + + @Override + public Bar getObject() { + return new Bar(); + } + + @Override + public Class getObjectType() { + return Bar.class; + } + } }