diff --git a/spring-cloud-function-context/src/main/java/org/springframework/cloud/function/context/ContextFunctionCatalogAutoConfiguration.java b/spring-cloud-function-context/src/main/java/org/springframework/cloud/function/context/ContextFunctionCatalogAutoConfiguration.java index b0139f201..2bdfd9f03 100644 --- a/spring-cloud-function-context/src/main/java/org/springframework/cloud/function/context/ContextFunctionCatalogAutoConfiguration.java +++ b/spring-cloud-function-context/src/main/java/org/springframework/cloud/function/context/ContextFunctionCatalogAutoConfiguration.java @@ -159,6 +159,7 @@ public class ContextFunctionCatalogAutoConfiguration { private BeanDefinitionRegistry registry; private ConversionService conversionService; private Map registrations = new HashMap<>(); + private Map>> types = new HashMap<>(); @Override public void postProcessBeanDefinitionRegistry(BeanDefinitionRegistry registry) { @@ -166,7 +167,8 @@ public class ContextFunctionCatalogAutoConfiguration { } @Override - public void postProcessBeanFactory(ConfigurableListableBeanFactory factory) throws BeansException { + public void postProcessBeanFactory(ConfigurableListableBeanFactory factory) + throws BeansException { for (String name : factory.getBeanDefinitionNames()) { if (isGeneric(factory, name, Supplier.class)) { this.suppliers.add(name); @@ -198,7 +200,8 @@ public class ContextFunctionCatalogAutoConfiguration { // Add consumers that were not already registered for (String key : consumers.keySet()) { if (!targets.containsKey(consumers.get(key))) { - FunctionRegistration target = new FunctionRegistration(consumers.get(key)).names(getAliases(key)); + FunctionRegistration target = new FunctionRegistration( + consumers.get(key)).names(getAliases(key)); targets.put(target.getTarget(), key); registrations.add(target); } @@ -206,7 +209,8 @@ public class ContextFunctionCatalogAutoConfiguration { // Add suppliers that were not already registered for (String key : suppliers.keySet()) { if (!targets.containsKey(suppliers.get(key))) { - FunctionRegistration target = new FunctionRegistration(suppliers.get(key)).names(getAliases(key)); + FunctionRegistration target = new FunctionRegistration( + suppliers.get(key)).names(getAliases(key)); targets.put(target.getTarget(), key); registrations.add(target); } @@ -214,7 +218,8 @@ public class ContextFunctionCatalogAutoConfiguration { // Add functions that were not already registered for (String key : functions.keySet()) { if (!targets.containsKey(functions.get(key))) { - FunctionRegistration target = new FunctionRegistration(functions.get(key)).names(getAliases(key)); + FunctionRegistration target = new FunctionRegistration( + functions.get(key)).names(getAliases(key)); targets.put(target.getTarget(), key); registrations.add(target); } @@ -230,9 +235,12 @@ public class ContextFunctionCatalogAutoConfiguration { } private Object convert(Object function, String value) { - if (conversionService == null && registry instanceof ConfigurableListableBeanFactory) { - ConversionService conversionService = ((ConfigurableBeanFactory) this.registry).getConversionService(); - this.conversionService = conversionService != null ? conversionService : new DefaultConversionService(); + if (conversionService == null + && registry instanceof ConfigurableListableBeanFactory) { + ConversionService conversionService = ((ConfigurableBeanFactory) this.registry) + .getConversionService(); + this.conversionService = conversionService != null ? conversionService + : new DefaultConversionService(); } Class type = findType(function, ParamType.INPUT); return conversionService.canConvert(String.class, type) @@ -254,12 +262,16 @@ public class ContextFunctionCatalogAutoConfiguration { Object target = registration.getTarget(); this.registrations.put(target, key); if (target instanceof Supplier) { + findType(target, ParamType.OUTPUT); registration.target(target((Supplier) target, key)); } else if (target instanceof Consumer) { + findType(target, ParamType.INPUT); registration.target(target((Consumer) target, key)); } else if (target instanceof Function) { + findType(target, ParamType.INPUT); + findType(target, ParamType.OUTPUT); registration.target(target((Function) target, key)); } registrations.remove(target); @@ -284,14 +296,17 @@ public class ContextFunctionCatalogAutoConfiguration { @SuppressWarnings({ "unchecked", "rawtypes" }) private T target(T target, String key) { - if (target instanceof Supplier && !isFluxSupplier(key, (Supplier)target)) { - target = (T) new FluxSupplier((Supplier)target); + if (target instanceof Supplier + && !isFluxSupplier(key, (Supplier) target)) { + target = (T) new FluxSupplier((Supplier) target); } - else if (target instanceof Function && !isFluxFunction(key, (Function)target)) { - target = (T) new FluxFunction((Function)target); + else if (target instanceof Function + && !isFluxFunction(key, (Function) target)) { + target = (T) new FluxFunction((Function) target); } - else if (target instanceof Consumer && !isFluxConsumer(key, (Consumer)target)) { - target = (T) new FluxConsumer((Consumer)target); + else if (target instanceof Consumer + && !isFluxConsumer(key, (Consumer) target)) { + target = (T) new FluxConsumer((Consumer) target); } return target; } @@ -312,25 +327,36 @@ public class ContextFunctionCatalogAutoConfiguration { } private boolean hasFluxTypes(Object function) { - return FunctionInspector.isWrapper(findType(function, ParamType.INPUT_WRAPPER)) - || FunctionInspector.isWrapper(findType(function, ParamType.OUTPUT_WRAPPER)); + return FunctionInspector + .isWrapper(findType(function, ParamType.INPUT_WRAPPER)) + || FunctionInspector + .isWrapper(findType(function, ParamType.OUTPUT_WRAPPER)); } - private boolean isGeneric(ConfigurableListableBeanFactory factory, String name, Class functionalInterface) { + private boolean isGeneric(ConfigurableListableBeanFactory factory, String name, + Class functionalInterface) { ResolvableType matchingType = null; ResolvableType[] nonMatchingTypes = null; if (functionalInterface.isAssignableFrom(Function.class)) { - matchingType = ResolvableType.forClassWithGenerics(Function.class, Flux.class, Flux.class); - nonMatchingTypes = new ResolvableType[]{ResolvableType.forClassWithGenerics(Flux.class, String.class), ResolvableType.forClassWithGenerics(Flux.class, String.class)}; + matchingType = ResolvableType.forClassWithGenerics(Function.class, + Flux.class, Flux.class); + nonMatchingTypes = new ResolvableType[] { + ResolvableType.forClassWithGenerics(Flux.class, String.class), + ResolvableType.forClassWithGenerics(Flux.class, String.class) }; } else { - nonMatchingTypes = new ResolvableType[]{ResolvableType.forClassWithGenerics(Flux.class, String.class)}; + nonMatchingTypes = new ResolvableType[] { + ResolvableType.forClassWithGenerics(Flux.class, String.class) }; if (functionalInterface.isAssignableFrom(Consumer.class)) { - matchingType = ResolvableType.forClassWithGenerics(Consumer.class, Flux.class); + matchingType = ResolvableType.forClassWithGenerics(Consumer.class, + Flux.class); } - matchingType = ResolvableType.forClassWithGenerics(Supplier.class, Flux.class); + matchingType = ResolvableType.forClassWithGenerics(Supplier.class, + Flux.class); } - return factory.isTypeMatch(name, matchingType) && !factory.isTypeMatch(name, ResolvableType.forClassWithGenerics(functionalInterface, nonMatchingTypes)); + return factory.isTypeMatch(name, matchingType) + && !factory.isTypeMatch(name, ResolvableType + .forClassWithGenerics(functionalInterface, nonMatchingTypes)); } private Class findType(String name, AbstractBeanDefinition definition, @@ -353,14 +379,8 @@ public class ContextFunctionCatalogAutoConfiguration { } else if (source instanceof Resource) { try { - Class beanType = ClassUtils.forName(definition.getBeanClassName(), - null); - for (Type type : beanType.getGenericInterfaces()) { - if (type.getTypeName().startsWith("java.util.function")) { - param = extractType(type, paramType, index); - break; - } - } + Class beanType = resolveBeanClass(definition); + param = findTypeFromBeanClass(beanType, paramType); if (param == null) { // Last chance param = beanType; @@ -386,6 +406,11 @@ public class ContextFunctionCatalogAutoConfiguration { } } } + Class result = extractClass(name, param, paramType); + return result; + } + + private Class extractClass(String name, Type param, ParamType paramType) { if (param instanceof ParameterizedType) { ParameterizedType concrete = (ParameterizedType) param; param = concrete.getRawType(); @@ -393,11 +418,40 @@ public class ContextFunctionCatalogAutoConfiguration { if (param == null) { // Last ditch attempt to guess: Flux if (paramType.isWrapper()) { - return Flux.class; + param = Flux.class; + } + else { + param = String.class; } - return String.class; } - return (Class) param; + Class result = param instanceof Class ? (Class) param : null; + if (result != null) { + Map> values = types.computeIfAbsent(name, + key -> new HashMap<>()); + values.put(paramType, result); + } + return result; + } + + private Type findTypeFromBeanClass(Class beanType, ParamType paramType) { + int index = paramType.isOutput() ? 1 : 0; + for (Type type : beanType.getGenericInterfaces()) { + if (type.getTypeName().startsWith("java.util.function")) { + return extractType(type, paramType, index); + } + } + return null; + } + + private Class resolveBeanClass(AbstractBeanDefinition definition) + throws ClassNotFoundException, LinkageError { + try { + return ClassUtils.forName(definition.getBeanClassName(), null); + } + catch (ClassNotFoundException e) { + return ClassUtils.forName(definition.getBeanClassName(), + getClass().getClassLoader()); + } } private Type findBeanType(AbstractBeanDefinition definition, @@ -481,10 +535,26 @@ public class ContextFunctionCatalogAutoConfiguration { private Class findType(Object function, ParamType type) { String name = registrations.get(function); + if (types.containsKey(name)) { + Map> values = types.get(name); + if (values.containsKey(type)) { + return values.get(type); + } + } if (name == null || !registry.containsBeanDefinition(name)) { + if (function != null) { + Type param = findTypeFromBeanClass(function.getClass(), type); + if (param != null) { + Class result = extractClass(name, param, type); + if (result != null) { + return result; + } + } + } return Object.class; } - return findType(name, (AbstractBeanDefinition) registry.getBeanDefinition(name), type); + return findType(name, + (AbstractBeanDefinition) registry.getBeanDefinition(name), type); } } @@ -497,8 +567,7 @@ public class ContextFunctionCatalogAutoConfiguration { } public boolean isInput() { - return this == INPUT || this == INPUT_WRAPPER - || this == INPUT_INNER_WRAPPER; + return this == INPUT || this == INPUT_WRAPPER || this == INPUT_INNER_WRAPPER; } public boolean isWrapper() { diff --git a/spring-cloud-function-context/src/test/java/org/springframework/cloud/function/context/ContextFunctionCatalogAutoConfigurationTests.java b/spring-cloud-function-context/src/test/java/org/springframework/cloud/function/context/ContextFunctionCatalogAutoConfigurationTests.java index 8bcb63c84..b54a2b6ed 100644 --- a/spring-cloud-function-context/src/test/java/org/springframework/cloud/function/context/ContextFunctionCatalogAutoConfigurationTests.java +++ b/spring-cloud-function-context/src/test/java/org/springframework/cloud/function/context/ContextFunctionCatalogAutoConfigurationTests.java @@ -29,7 +29,10 @@ import java.util.stream.Collectors; import org.junit.After; import org.junit.Test; +import org.springframework.beans.BeansException; import org.springframework.beans.factory.annotation.Qualifier; +import org.springframework.beans.factory.config.BeanFactoryPostProcessor; +import org.springframework.beans.factory.config.ConfigurableListableBeanFactory; import org.springframework.boot.autoconfigure.EnableAutoConfiguration; import org.springframework.boot.builder.SpringApplicationBuilder; import org.springframework.cloud.function.compiler.CompiledFunctionFactory; @@ -158,6 +161,17 @@ public class ContextFunctionCatalogAutoConfigurationTests { .isAssignableFrom(Map.class); } + @Test + public void singletonFunction() { + create(SingletonConfiguration.class); + assertThat(context.getBean("function")).isInstanceOf(Function.class); + assertThat(catalog.lookupFunction("function")).isInstanceOf(Function.class); + assertThat(inspector.getInputType(catalog.lookupFunction("function"))) + .isAssignableFrom(Integer.class); + assertThat(inspector.getInputWrapper(catalog.lookupFunction("function"))) + .isAssignableFrom(Integer.class); + } + @Test public void componentScanBeanFunction() { create(ComponentScanBeanConfiguration.class); @@ -399,6 +413,26 @@ public class ContextFunctionCatalogAutoConfigurationTests { protected static class ExternalConfiguration { } + @EnableAutoConfiguration + @Configuration + protected static class SingletonConfiguration implements BeanFactoryPostProcessor { + + @Override + public void postProcessBeanFactory(ConfigurableListableBeanFactory beanFactory) + throws BeansException { + beanFactory.registerSingleton("function", new SingletonFunction()); + } + } + + protected static class SingletonFunction implements Function { + + @Override + public String apply(Integer input) { + return "value=" + input; + } + + } + @EnableAutoConfiguration @Configuration @ComponentScan(basePackageClasses = GenericFunction.class) diff --git a/spring-cloud-function-stream/src/main/java/org/springframework/cloud/function/stream/StreamListeningFunctionInvoker.java b/spring-cloud-function-stream/src/main/java/org/springframework/cloud/function/stream/StreamListeningFunctionInvoker.java index d0b91abb8..46df7e6ee 100644 --- a/spring-cloud-function-stream/src/main/java/org/springframework/cloud/function/stream/StreamListeningFunctionInvoker.java +++ b/spring-cloud-function-stream/src/main/java/org/springframework/cloud/function/stream/StreamListeningFunctionInvoker.java @@ -152,8 +152,14 @@ public class StreamListeningFunctionInvoker implements SmartInitializingSingleto } else { for (String candidate : names) { - Class inputType = functionInspector - .getInputType(functionCatalog.lookupFunction(candidate)); + Object function = functionCatalog.lookupFunction(candidate); + if (function==null) { + function = functionCatalog.lookupConsumer(candidate); + } + if (function==null) { + continue; + } + Class inputType = functionInspector.getInputType(function); Object value = this.converter.fromMessage(input, inputType); if (value != null && inputType.isInstance(value)) { matches.add(candidate); @@ -210,7 +216,7 @@ public class StreamListeningFunctionInvoker implements SmartInitializingSingleto else { result = this.converter.fromMessage(m, inputType); } - if (result==null) { + if (result == null) { result = UNCONVERTED; } return result; diff --git a/spring-cloud-function-web/src/test/java/org/springframework/cloud/function/web/SingletonTests.java b/spring-cloud-function-web/src/test/java/org/springframework/cloud/function/web/SingletonTests.java index 3f6d113ab..f678390b4 100644 --- a/spring-cloud-function-web/src/test/java/org/springframework/cloud/function/web/SingletonTests.java +++ b/spring-cloud-function-web/src/test/java/org/springframework/cloud/function/web/SingletonTests.java @@ -80,7 +80,7 @@ public class SingletonTests { @Override public void postProcessBeanDefinitionRegistry( BeanDefinitionRegistry registry) throws BeansException { - // Simulates what happens whem you add a compiled function + // Simulates what happens when you add a compiled function RootBeanDefinition beanDefinition = new RootBeanDefinition(MySupplier.class); registry.registerBeanDefinition("words", beanDefinition); }