diff --git a/spring-cloud-function-context/src/main/java/org/springframework/cloud/function/context/catalog/BeanFactoryAwareFunctionRegistry.java b/spring-cloud-function-context/src/main/java/org/springframework/cloud/function/context/catalog/BeanFactoryAwareFunctionRegistry.java index 940a632b2..48f1b1e17 100644 --- a/spring-cloud-function-context/src/main/java/org/springframework/cloud/function/context/catalog/BeanFactoryAwareFunctionRegistry.java +++ b/spring-cloud-function-context/src/main/java/org/springframework/cloud/function/context/catalog/BeanFactoryAwareFunctionRegistry.java @@ -90,11 +90,10 @@ import org.springframework.util.StringUtils; * * @author Oleg Zhurakousky * @author Eric Botard - * * @since 3.0 */ public class BeanFactoryAwareFunctionRegistry - implements FunctionRegistry, FunctionInspector, ApplicationContextAware { + implements FunctionRegistry, FunctionInspector, ApplicationContextAware { private static Log logger = LogFactory.getLog(BeanFactoryAwareFunctionRegistry.class); @@ -119,7 +118,7 @@ public class BeanFactoryAwareFunctionRegistry private final CompositeMessageConverter messageConverter; public BeanFactoryAwareFunctionRegistry(ConversionService conversionService, - @Nullable CompositeMessageConverter messageConverter) { + @Nullable CompositeMessageConverter messageConverter) { this.conversionService = conversionService; this.messageConverter = messageConverter; } @@ -132,8 +131,8 @@ public class BeanFactoryAwareFunctionRegistry @Override public int size() { return this.applicationContext.getBeanNamesForType(Supplier.class).length + - this.applicationContext.getBeanNamesForType(Function.class).length + - this.applicationContext.getBeanNamesForType(Consumer.class).length; + this.applicationContext.getBeanNamesForType(Function.class).length + + this.applicationContext.getBeanNamesForType(Consumer.class).length; } @Override @@ -142,7 +141,8 @@ public class BeanFactoryAwareFunctionRegistry if (!StringUtils.hasText(definition)) { definition = this.applicationContext.getEnvironment().getProperty("spring.cloud.function.definition"); } - Object function = this.proxyInvokerIfNecessary((FunctionInvocationWrapper) this.compose(null, definition, acceptedOutputTypes)); + Object function = this + .proxyInvokerIfNecessary((FunctionInvocationWrapper) this.compose(null, definition, acceptedOutputTypes)); return (T) function; } @@ -150,11 +150,14 @@ public class BeanFactoryAwareFunctionRegistry @Override public Set getNames(Class type) { Set registeredNames = registrationsByFunction.values().stream().flatMap(reg -> reg.getNames().stream()) - .collect(Collectors.toSet()); + .collect(Collectors.toSet()); if (type == null) { - registeredNames.addAll(CollectionUtils.arrayToList(this.applicationContext.getBeanNamesForType(Function.class))); - registeredNames.addAll(CollectionUtils.arrayToList(this.applicationContext.getBeanNamesForType(Supplier.class))); - registeredNames.addAll(CollectionUtils.arrayToList(this.applicationContext.getBeanNamesForType(Consumer.class))); + registeredNames + .addAll(CollectionUtils.arrayToList(this.applicationContext.getBeanNamesForType(Function.class))); + registeredNames + .addAll(CollectionUtils.arrayToList(this.applicationContext.getBeanNamesForType(Supplier.class))); + registeredNames + .addAll(CollectionUtils.arrayToList(this.applicationContext.getBeanNamesForType(Consumer.class))); } else { registeredNames.addAll(CollectionUtils.arrayToList(this.applicationContext.getBeanNamesForType(type))); @@ -194,16 +197,18 @@ public class BeanFactoryAwareFunctionRegistry } if (function != null && this.notFunction(function.getClass()) - && this.applicationContext.containsBean(name + FunctionRegistration.REGISTRATION_NAME_SUFFIX)) { // e.g., Kotlin lambdas - function = this.applicationContext.getBean(name + FunctionRegistration.REGISTRATION_NAME_SUFFIX, FunctionRegistration.class); + && this.applicationContext + .containsBean(name + FunctionRegistration.REGISTRATION_NAME_SUFFIX)) { // e.g., Kotlin lambdas + function = this.applicationContext + .getBean(name + FunctionRegistration.REGISTRATION_NAME_SUFFIX, FunctionRegistration.class); } return function; } private boolean notFunction(Class functionClass) { return !Function.class.isAssignableFrom(functionClass) - && !Supplier.class.isAssignableFrom(functionClass) - && !Consumer.class.isAssignableFrom(functionClass); + && !Supplier.class.isAssignableFrom(functionClass) + && !Consumer.class.isAssignableFrom(functionClass); } private Type discoverFunctionType(Object function, String... names) { @@ -211,34 +216,39 @@ public class BeanFactoryAwareFunctionRegistry for (int i = 0; i < names.length && !beanDefinitionExists; i++) { beanDefinitionExists = this.applicationContext.getBeanFactory().containsBeanDefinition(names[i]); if (this.applicationContext.containsBean("&" + names[i])) { - Class objectType = this.applicationContext.getBean("&" + names[i], FactoryBean.class).getObjectType(); + Class objectType = this.applicationContext.getBean("&" + names[i], FactoryBean.class) + .getObjectType(); return FunctionTypeUtils.discoverFunctionTypeFromClass(objectType); } } if (!beanDefinitionExists) { logger.info("BeanDefinition for function name(s) '" + Arrays.asList(names) + - "' can not be located. FunctionType will be based on " + function.getClass()); + "' can not be located. FunctionType will be based on " + function.getClass()); } return beanDefinitionExists - ? FunctionType.of(FunctionContextUtils.findType(applicationContext.getBeanFactory(), names)).getType() - : new FunctionType(function.getClass()).getType(); + ? FunctionType.of(FunctionContextUtils.findType(applicationContext.getBeanFactory(), names)).getType() + : new FunctionType(function.getClass()).getType(); } private String discoverDefaultDefinitionIfNecessary(String definition) { if (StringUtils.isEmpty(definition)) { // the underscores are for Kotlin function registrations (see KotlinLambdaToFunctionAutoConfiguration) - String[] functionNames = Stream.of(this.applicationContext.getBeanNamesForType(Function.class)) - .filter(n -> !n.endsWith(FunctionRegistration.REGISTRATION_NAME_SUFFIX) && !n.equals(RoutingFunction.FUNCTION_NAME)).toArray(String[]::new); - String[] consumerNames = Stream.of(this.applicationContext.getBeanNamesForType(Consumer.class)) - .filter(n -> !n.endsWith(FunctionRegistration.REGISTRATION_NAME_SUFFIX) && !n.equals(RoutingFunction.FUNCTION_NAME)).toArray(String[]::new); - String[] supplierNames = Stream.of(this.applicationContext.getBeanNamesForType(Supplier.class)) - .filter(n -> !n.endsWith(FunctionRegistration.REGISTRATION_NAME_SUFFIX) && !n.equals(RoutingFunction.FUNCTION_NAME)).toArray(String[]::new); + String[] functionNames = Stream.of(this.applicationContext.getBeanNamesForType(Function.class)) + .filter(n -> !n.endsWith(FunctionRegistration.REGISTRATION_NAME_SUFFIX) && !n + .equals(RoutingFunction.FUNCTION_NAME)).toArray(String[]::new); + String[] consumerNames = Stream.of(this.applicationContext.getBeanNamesForType(Consumer.class)) + .filter(n -> !n.endsWith(FunctionRegistration.REGISTRATION_NAME_SUFFIX) && !n + .equals(RoutingFunction.FUNCTION_NAME)).toArray(String[]::new); + String[] supplierNames = Stream.of(this.applicationContext.getBeanNamesForType(Supplier.class)) + .filter(n -> !n.endsWith(FunctionRegistration.REGISTRATION_NAME_SUFFIX) && !n + .equals(RoutingFunction.FUNCTION_NAME)).toArray(String[]::new); /* * we may need to add BiFunction and BiConsumer at some point */ List names = Stream - .concat(Stream.of(functionNames), Stream.concat(Stream.of(consumerNames), Stream.of(supplierNames))).collect(Collectors.toList()); + .concat(Stream.of(functionNames), Stream.concat(Stream.of(consumerNames), Stream.of(supplierNames))) + .collect(Collectors.toList()); if (!ObjectUtils.isEmpty(names)) { if (names.size() > 1) { @@ -253,15 +263,18 @@ public class BeanFactoryAwareFunctionRegistry } else { if (this.registrationsByName.size() > 0) { - Assert.isTrue(this.registrationsByName.size() == 1, "Found more then one function in local registry"); + Assert + .isTrue(this.registrationsByName.size() == 1, "Found more then one function in local registry"); definition = this.registrationsByName.keySet().iterator().next(); } } if (StringUtils.hasText(definition) && this.applicationContext.containsBean(definition)) { Type functionType = discoverFunctionType(this.applicationContext.getBean(definition), definition); - if (!FunctionTypeUtils.isSupplier(functionType) && !FunctionTypeUtils.isFunction(functionType) && !FunctionTypeUtils.isConsumer(functionType)) { - logger.info("Discovered functional instance of bean '" + definition + "' as a default function, however its " + if (!FunctionTypeUtils.isSupplier(functionType) && !FunctionTypeUtils + .isFunction(functionType) && !FunctionTypeUtils.isConsumer(functionType)) { + logger + .info("Discovered functional instance of bean '" + definition + "' as a default function, however its " + "function argument types can not be determined. Discarding."); definition = null; } @@ -270,10 +283,11 @@ public class BeanFactoryAwareFunctionRegistry return definition; } - @SuppressWarnings({ "unchecked", "rawtypes" }) + @SuppressWarnings({"unchecked", "rawtypes"}) private Function compose(Class type, String definition, String... acceptedOutputTypes) { if (logger.isInfoEnabled()) { - logger.info("Looking up function '" + definition + "' with acceptedOutputTypes: " + Arrays.asList(acceptedOutputTypes)); + logger.info("Looking up function '" + definition + "' with acceptedOutputTypes: " + Arrays + .asList(acceptedOutputTypes)); } definition = discoverDefaultDefinitionIfNecessary(definition); if (StringUtils.isEmpty(definition)) { @@ -281,7 +295,7 @@ public class BeanFactoryAwareFunctionRegistry } Function resultFunction = null; if (this.registrationsByName.containsKey(definition)) { - Object targetFunction = this.registrationsByName.get(definition).getTarget(); + Object targetFunction = this.registrationsByName.get(definition).getTarget(); Type functionType = this.registrationsByName.get(definition).getType().getType(); resultFunction = new FunctionInvocationWrapper(targetFunction, functionType, definition, acceptedOutputTypes); } @@ -295,13 +309,14 @@ public class BeanFactoryAwareFunctionRegistry Object function = this.locateFunction(name); if (function == null) { logger.warn("!!! Failed to discover function '" + definition + "' in function catalog. " - + "Function available in catalog are: " + this.getNames(null)); + + "Function available in catalog are: " + this.getNames(null)); return null; } else { Type functionType = FunctionContextUtils.findType(applicationContext.getBeanFactory(), name); if (functionType != null && functionType.toString().contains("org.apache.kafka.streams.")) { - logger.debug("Kafka Streams function '" + definition + "' is not supported by spring-cloud-function."); + logger + .debug("Kafka Streams function '" + definition + "' is not supported by spring-cloud-function."); return null; } } @@ -314,7 +329,8 @@ public class BeanFactoryAwareFunctionRegistry if (function instanceof FunctionRegistration) { registration = (FunctionRegistration) function; - currentFunctionType = currentFunctionType == null ? registration.getType().getType() : currentFunctionType; + currentFunctionType = currentFunctionType == null ? registration.getType() + .getType() : currentFunctionType; function = registration.getTarget(); } else { @@ -324,7 +340,8 @@ public class BeanFactoryAwareFunctionRegistry function = this.proxyTarget(function, functionalMethod); } String[] aliasNames = this.getAliases(name).toArray(new String[] {}); - currentFunctionType = currentFunctionType == null ? this.discoverFunctionType(function, aliasNames) : currentFunctionType; + currentFunctionType = currentFunctionType == null ? this + .discoverFunctionType(function, aliasNames) : currentFunctionType; registration = new FunctionRegistration<>(function, name).type(currentFunctionType); } @@ -343,12 +360,12 @@ public class BeanFactoryAwareFunctionRegistry else { originFunctionType = FunctionTypeUtils.compose(originFunctionType, currentFunctionType); resultFunction = new FunctionInvocationWrapper(resultFunction.andThen((Function) function), - originFunctionType, composedNameBuilder.toString(), acceptedOutputTypes); + originFunctionType, composedNameBuilder.toString(), acceptedOutputTypes); } prefix = "|"; } FunctionRegistration registration = new FunctionRegistration(resultFunction, definition) - .type(originFunctionType); + .type(originFunctionType); registrationsByFunction.putIfAbsent(resultFunction, registration); registrationsByName.putIfAbsent(definition, registration); } @@ -357,8 +374,8 @@ public class BeanFactoryAwareFunctionRegistry private boolean isFunctionPojo(Object function) { return !function.getClass().isSynthetic() - && !(function instanceof Supplier) && !(function instanceof Function) && !(function instanceof Consumer) - && !function.getClass().getPackage().getName().startsWith("org.springframework.cloud.function.compiler"); + && !(function instanceof Supplier) && !(function instanceof Function) && !(function instanceof Consumer) + && !function.getClass().getPackage().getName().startsWith("org.springframework.cloud.function.compiler"); } /* @@ -375,7 +392,9 @@ public class BeanFactoryAwareFunctionRegistry private Object proxyInvokerIfNecessary(FunctionInvocationWrapper functionInvoker) { if (functionInvoker != null && AopUtils.isCglibProxy(functionInvoker.getTarget())) { if (logger.isInfoEnabled()) { - logger.info("Proxying POJO function: " + functionInvoker.functionDefinition + ". . ." + functionInvoker.target.getClass()); + logger + .info("Proxying POJO function: " + functionInvoker.functionDefinition + ". . ." + functionInvoker.target + .getClass()); } ProxyFactory pf = new ProxyFactory(functionInvoker.getTarget()); pf.setProxyTargetClass(true); @@ -385,7 +404,7 @@ public class BeanFactoryAwareFunctionRegistry public Object invoke(MethodInvocation invocation) throws Throwable { // this will trigger the INNER PROXY if (ObjectUtils.isEmpty(invocation.getArguments())) { - Object o = functionInvoker.get(); + Object o = functionInvoker.get(); return o; } else { @@ -439,7 +458,7 @@ public class BeanFactoryAwareFunctionRegistry if (source instanceof StandardMethodMetadata) { StandardMethodMetadata metadata = (StandardMethodMetadata) source; Qualifier qualifier = AnnotatedElementUtils.findMergedAnnotation(metadata.getIntrospectedMethod(), - Qualifier.class); + Qualifier.class); if (qualifier != null && qualifier.value().length() > 0) { return qualifier.value(); } @@ -453,7 +472,6 @@ public class BeanFactoryAwareFunctionRegistry * catalog. * * @author Oleg Zhurakousky - * */ public class FunctionInvocationWrapper implements Function, Consumer, Supplier { @@ -491,7 +509,8 @@ public class BeanFactoryAwareFunctionRegistry /** * !! Experimental, may change. Is not yet intended as public API !! - * @param input input value + * + * @param input input value * @param enricher enricher function instance * @return the result */ @@ -507,14 +526,15 @@ public class BeanFactoryAwareFunctionRegistry /** * !! Experimental, may change. Is not yet intended as public API !! + * * @param enricher enricher function instance * @return the result */ @SuppressWarnings("rawtypes") public Object get(Function enricher) { Object input = FunctionTypeUtils.isMono(this.functionType) - ? Mono.empty() - : (FunctionTypeUtils.isMono(this.functionType) ? Flux.empty() : null); + ? Mono.empty() + : (FunctionTypeUtils.isMono(this.functionType) ? Flux.empty() : null); return this.doApply(input, false, enricher); } @@ -535,7 +555,7 @@ public class BeanFactoryAwareFunctionRegistry return target; } - @SuppressWarnings({ "rawtypes", "unchecked" }) + @SuppressWarnings({"rawtypes", "unchecked"}) private Object invokeFunction(Object input) { Object invocationResult = null; if (this.target instanceof Function) { @@ -563,12 +583,13 @@ public class BeanFactoryAwareFunctionRegistry } if (!(this.target instanceof Consumer) && logger.isDebugEnabled()) { - logger.debug("Result of invocation of \"" + this.functionDefinition + "\" function is '" + invocationResult + "'"); + logger + .debug("Result of invocation of \"" + this.functionDefinition + "\" function is '" + invocationResult + "'"); } return invocationResult; } - @SuppressWarnings({ "unchecked", "rawtypes" }) + @SuppressWarnings({"unchecked", "rawtypes"}) private Object doApply(Object input, boolean consumer, Function enricher) { if (logger.isDebugEnabled()) { logger.debug("Applying function: " + this.functionDefinition); @@ -577,60 +598,65 @@ public class BeanFactoryAwareFunctionRegistry Object result; if (input instanceof Publisher) { input = this.composed ? input : - this.convertInputPublisherIfNecessary((Publisher) input, FunctionTypeUtils.getInputType(this.functionType, 0)); + this.convertInputPublisherIfNecessary((Publisher) input, FunctionTypeUtils + .getInputType(this.functionType, 0)); if (FunctionTypeUtils.isReactive(FunctionTypeUtils.getInputType(this.functionType, 0))) { result = this.invokeFunction(input); } else { if (this.composed) { return input instanceof Mono - ? Mono.from((Publisher) input).transform((Function) this.target) - : Flux.from((Publisher) input).transform((Function) this.target); + ? Mono.from((Publisher) input).transform((Function) this.target) + : Flux.from((Publisher) input).transform((Function) this.target); } else { if (FunctionTypeUtils.isConsumer(functionType)) { result = input instanceof Mono - ? Mono.from((Publisher) input).doOnNext((Consumer) this.target).then() - : Flux.from((Publisher) input).doOnNext((Consumer) this.target).then(); + ? Mono.from((Publisher) input).doOnNext((Consumer) this.target).then() + : Flux.from((Publisher) input).doOnNext((Consumer) this.target).then(); } else { result = input instanceof Mono - ? Mono.from((Publisher) input).map(value -> this.invokeFunction(value)) - : Flux.from((Publisher) input).map(value -> this.invokeFunction(value)); + ? Mono.from((Publisher) input).map(value -> this.invokeFunction(value)) + : Flux.from((Publisher) input).map(value -> this.invokeFunction(value)); } } } } else { Type type = FunctionTypeUtils.getInputType(this.functionType, 0); - if (!this.composed && !FunctionTypeUtils.isMultipleInputArguments(this.functionType) && FunctionTypeUtils.isReactive(type)) { + if (!this.composed && !FunctionTypeUtils + .isMultipleInputArguments(this.functionType) && FunctionTypeUtils.isReactive(type)) { Publisher publisher = FunctionTypeUtils.isFlux(type) - ? input == null ? Flux.empty() : Flux.just(input) - : input == null ? Mono.empty() : Mono.just(input); + ? input == null ? Flux.empty() : Flux.just(input) + : input == null ? Mono.empty() : Mono.just(input); if (logger.isDebugEnabled()) { logger.debug("Invoking reactive function '" + this.functionType + "' with non-reactive input " - + "should at least assume reactive output (e.g., Function> f3 = catalog.lookup(\"echoFlux\");), " - + "otherwise invocation will result in ClassCastException."); + + "should at least assume reactive output (e.g., Function> f3 = catalog.lookup(\"echoFlux\");), " + + "otherwise invocation will result in ClassCastException."); } - result = this.invokeFunction(this.convertInputPublisherIfNecessary(publisher, FunctionTypeUtils.getInputType(this.functionType, 0))); + result = this.invokeFunction(this.convertInputPublisherIfNecessary(publisher, FunctionTypeUtils + .getInputType(this.functionType, 0))); } else { result = this.invokeFunction(this.composed ? input - : (input == null ? input : this.convertInputValueIfNecessary(input, FunctionTypeUtils.getInputType(this.functionType, 0)))); + : (input == null ? input : this + .convertInputValueIfNecessary(input, FunctionTypeUtils.getInputType(this.functionType, 0)))); } } // Outputs will be converted only if we're told how (via acceptedOutputMimeTypes), otherwise output returned as is. if (result != null && !ObjectUtils.isEmpty(this.acceptedOutputMimeTypes)) { result = result instanceof Publisher - ? this.convertOutputPublisherIfNecessary((Publisher) result, enricher, this.acceptedOutputMimeTypes) - : this.convertOutputValueIfNecessary(result, enricher, this.acceptedOutputMimeTypes); + ? this + .convertOutputPublisherIfNecessary((Publisher) result, enricher, this.acceptedOutputMimeTypes) + : this.convertOutputValueIfNecessary(result, enricher, this.acceptedOutputMimeTypes); } return result; } - @SuppressWarnings({ "rawtypes", "unchecked" }) + @SuppressWarnings({"rawtypes", "unchecked"}) private Object convertOutputValueIfNecessary(Object value, Function enricher, String... acceptedOutputMimeTypes) { logger.debug("Applying type conversion on output value"); Object convertedValue = null; @@ -642,20 +668,22 @@ public class BeanFactoryAwareFunctionRegistry Object outputArgument = parsed.getValue(value); try { convertedInputArray[i] = outputArgument instanceof Publisher - ? this.convertOutputPublisherIfNecessary((Publisher) outputArgument, enricher, acceptedOutputMimeTypes[i]) - : this.convertOutputValueIfNecessary(outputArgument, enricher, acceptedOutputMimeTypes[i]); + ? this + .convertOutputPublisherIfNecessary((Publisher) outputArgument, enricher, acceptedOutputMimeTypes[i]) + : this.convertOutputValueIfNecessary(outputArgument, enricher, acceptedOutputMimeTypes[i]); } catch (ArrayIndexOutOfBoundsException e) { throw new IllegalStateException("The number of 'acceptedOutputMimeTypes' for function '" + this.functionDefinition - + "' is (" + acceptedOutputMimeTypes.length - + "), which does not match the number of actual outputs of this function which is (" + outputCount + ").", e); + + "' is (" + acceptedOutputMimeTypes.length + + "), which does not match the number of actual outputs of this function which is (" + outputCount + ").", e); } } convertedValue = Tuples.fromArray(convertedInputArray); } else { - List acceptedContentTypes = MimeTypeUtils.parseMimeTypes(acceptedOutputMimeTypes[0].toString()); + List acceptedContentTypes = MimeTypeUtils + .parseMimeTypes(acceptedOutputMimeTypes[0].toString()); if (CollectionUtils.isEmpty(acceptedContentTypes)) { convertedValue = value; } @@ -672,7 +700,8 @@ public class BeanFactoryAwareFunctionRegistry } } else if (value instanceof byte[]) { - convertedValue = MessageBuilder.withPayload(value).setHeader(MessageHeaders.CONTENT_TYPE, acceptedContentType).build(); + convertedValue = MessageBuilder.withPayload(value) + .setHeader(MessageHeaders.CONTENT_TYPE, acceptedContentType).build(); } else if (value instanceof Iterable || ObjectUtils.isArray(value)) { boolean isArray = ObjectUtils.isArray(value); @@ -681,7 +710,9 @@ public class BeanFactoryAwareFunctionRegistry } AtomicReference> messages = new AtomicReference>(new ArrayList<>()); ((Iterable) value).forEach(element -> - messages.get().add((Message) convertOutputValueIfNecessary(element, enricher, acceptedContentType.toString()))); + messages.get() + .add((Message) convertOutputValueIfNecessary(element, enricher, acceptedContentType + .toString()))); convertedValue = messages.get(); } else { @@ -698,19 +729,30 @@ public class BeanFactoryAwareFunctionRegistry return convertedValue; } - @SuppressWarnings({ "rawtypes", "unchecked" }) - private Message convertValueToMessage(Object value, Function enricher, MimeType acceptedContentType) { + @SuppressWarnings({"rawtypes", "unchecked"}) + private Message convertValueToMessage(Object value, Function enricher, MimeType acceptedContentType) { Message outputMessage = null; if (value instanceof Message) { MessageHeaders headers = ((Message) value).getHeaders(); - if (!headers.containsKey(MessageHeaders.CONTENT_TYPE)) { + if (!headers.containsKey(NegotiatingMessageConverterWrapper.ACCEPT)) { Map headersMap = (Map) ReflectionUtils - .getField(this.headersField, headers); - headersMap.put(MessageHeaders.CONTENT_TYPE, acceptedContentType); + .getField(this.headersField, headers); + headersMap.put(NegotiatingMessageConverterWrapper.ACCEPT, acceptedContentType); + // Set the contentType header to the value of accept for "legacy" reasons. But, do not set the + // contentType header to the value of accept if it is a wildcard type, as this doesn't make sense. + // This also applies to the else branch below. + if (acceptedContentType.isConcrete()) { + headersMap.put(MessageHeaders.CONTENT_TYPE, acceptedContentType); + } } } else { - value = MessageBuilder.withPayload(value).setHeader(MessageHeaders.CONTENT_TYPE, acceptedContentType).build(); + MessageBuilder builder = MessageBuilder.withPayload(value) + .setHeader(NegotiatingMessageConverterWrapper.ACCEPT, acceptedContentType); + if (acceptedContentType.isConcrete()) { + builder.setHeader(MessageHeaders.CONTENT_TYPE, acceptedContentType); + } + value = builder.build(); } if (enricher != null) { value = enricher.apply((Message) value); @@ -726,8 +768,10 @@ public class BeanFactoryAwareFunctionRegistry } Publisher result = publisher instanceof Mono - ? Mono.from(publisher) .map(value -> this.convertOutputValueIfNecessary(value, enricher, acceptedOutputMimeTypes)) - : Flux.from(publisher).map(value -> this.convertOutputValueIfNecessary(value, enricher, acceptedOutputMimeTypes)); + ? Mono.from(publisher) + .map(value -> this.convertOutputValueIfNecessary(value, enricher, acceptedOutputMimeTypes)) + : Flux.from(publisher) + .map(value -> this.convertOutputValueIfNecessary(value, enricher, acceptedOutputMimeTypes)); return result; } @@ -737,8 +781,8 @@ public class BeanFactoryAwareFunctionRegistry } Publisher result = publisher instanceof Mono - ? Mono.from(publisher).map(value -> this.convertInputValueIfNecessary(value, type)) - : Flux.from(publisher).map(value -> this.convertInputValueIfNecessary(value, type)); + ? Mono.from(publisher).map(value -> this.convertInputValueIfNecessary(value, type)) + : Flux.from(publisher).map(value -> this.convertInputValueIfNecessary(value, type)); return result; } @@ -756,15 +800,17 @@ public class BeanFactoryAwareFunctionRegistry Expression parsed = new SpelExpressionParser().parseExpression("getT" + (i + 1) + "()"); Object inptArgument = parsed.getValue(value); inptArgument = inptArgument instanceof Publisher - ? this.convertInputPublisherIfNecessary((Publisher) inptArgument, FunctionTypeUtils.getInputType(functionType, i)) - : this.convertInputValueIfNecessary(inptArgument, FunctionTypeUtils.getInputType(functionType, i)); + ? this.convertInputPublisherIfNecessary((Publisher) inptArgument, FunctionTypeUtils + .getInputType(functionType, i)) + : this + .convertInputValueIfNecessary(inptArgument, FunctionTypeUtils.getInputType(functionType, i)); convertedInputArray[i] = inptArgument; } convertedValue = Tuples.fromArray(convertedInputArray); } else { // this needs revisiting as the type is not always Class (think really complex types) - Type rawType = FunctionTypeUtils.unwrapActualTypeByIndex(type, 0); + Type rawType = FunctionTypeUtils.unwrapActualTypeByIndex(type, 0); if (logger.isDebugEnabled()) { logger.debug("Raw type of value: " + value + "is " + rawType); } @@ -775,13 +821,14 @@ public class BeanFactoryAwareFunctionRegistry if (value instanceof Message) { // see AWS adapter with Optional payload if (messageNeedsConversion(rawType, (Message) value)) { convertedValue = FunctionTypeUtils.isTypeCollection(type) - ? messageConverter.fromMessage((Message) value, (Class) rawType, type) - : messageConverter.fromMessage((Message) value, (Class) rawType); + ? messageConverter.fromMessage((Message) value, (Class) rawType, type) + : messageConverter.fromMessage((Message) value, (Class) rawType); if (logger.isDebugEnabled()) { logger.debug("Converted from Message: " + convertedValue); } if (FunctionTypeUtils.isMessage(type)) { - convertedValue = MessageBuilder.withPayload(convertedValue).copyHeaders(((Message) value).getHeaders()).build(); + convertedValue = MessageBuilder.withPayload(convertedValue) + .copyHeaders(((Message) value).getHeaders()).build(); } } else if (!FunctionTypeUtils.isMessage(type)) { @@ -794,7 +841,8 @@ public class BeanFactoryAwareFunctionRegistry } catch (Exception e) { if (value instanceof String || value instanceof byte[]) { - convertedValue = messageConverter.fromMessage(new GenericMessage(value), (Class) rawType); + convertedValue = messageConverter + .fromMessage(new GenericMessage(value), (Class) rawType); } } } @@ -810,9 +858,9 @@ public class BeanFactoryAwareFunctionRegistry private boolean messageNeedsConversion(Type rawType, Message message) { Boolean skipConversion = message.getHeaders().containsKey(FunctionProperties.SKIP_CONVERSION_HEADER) - ? message.getHeaders().get(FunctionProperties.SKIP_CONVERSION_HEADER, Boolean.class) - : false; - if (skipConversion) { + ? message.getHeaders().get(FunctionProperties.SKIP_CONVERSION_HEADER, Boolean.class) + : false; + if (skipConversion) { return false; } return rawType instanceof Class diff --git a/spring-cloud-function-context/src/main/java/org/springframework/cloud/function/context/catalog/NegotiatingMessageConverterWrapper.java b/spring-cloud-function-context/src/main/java/org/springframework/cloud/function/context/catalog/NegotiatingMessageConverterWrapper.java new file mode 100644 index 000000000..7a1f9c06f --- /dev/null +++ b/spring-cloud-function-context/src/main/java/org/springframework/cloud/function/context/catalog/NegotiatingMessageConverterWrapper.java @@ -0,0 +1,90 @@ +/* + * Copyright 2019-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.cloud.function.context.catalog; + +import org.springframework.messaging.Message; +import org.springframework.messaging.MessageHeaders; +import org.springframework.messaging.converter.AbstractMessageConverter; +import org.springframework.messaging.converter.SmartMessageConverter; +import org.springframework.messaging.support.MessageHeaderAccessor; +import org.springframework.util.MimeType; + +/** + * A {@link org.springframework.messaging.converter.AbstractMessageConverter} wrapper that supports the concept of wildcard + * negotiation when producing messages. To that effect, messages should contain an "accept" header, that may + * contain a wildcard type (such as {@code text/*}, which may be tested against every + * {@link AbstractMessageConverter#getSupportedMimeTypes() supported mime type} of the delegate MessageConverter. + */ +public final class NegotiatingMessageConverterWrapper implements SmartMessageConverter { + + /** + * The Message Header key that may contain the list of (possibly wildcard) MimeTypes to convert to. + */ + public static final String ACCEPT = "accept"; + + private final AbstractMessageConverter delegate; + + private NegotiatingMessageConverterWrapper(AbstractMessageConverter delegate) { + this.delegate = delegate; + } + + public static NegotiatingMessageConverterWrapper wrap(AbstractMessageConverter delegate) { + return new NegotiatingMessageConverterWrapper(delegate); + } + + @Override + public Object fromMessage(Message message, Class targetClass, Object conversionHint) { + return delegate.fromMessage(message, targetClass, conversionHint); + } + + @Override + public Message toMessage(Object payload, MessageHeaders headers, Object conversionHint) { + MimeType accepted = headers.get(ACCEPT, MimeType.class); + MessageHeaderAccessor accessor = new MessageHeaderAccessor(); + accessor.copyHeaders(headers); + accessor.removeHeader(ACCEPT); + // Fall back to (concrete) 'contentType' header if 'accept' is not present. + // MimeType.includes() below should then amount to equality. + if (accepted == null) { + accepted = headers.get(MessageHeaders.CONTENT_TYPE, MimeType.class); + } + + if (accepted != null) { + for (MimeType supportedConcreteType : delegate.getSupportedMimeTypes()) { + if (accepted.includes(supportedConcreteType)) { + // Note the use of setHeader() which will set the value even if already present. + accessor.setHeader(MessageHeaders.CONTENT_TYPE, supportedConcreteType); + Message result = delegate.toMessage(payload, accessor.toMessageHeaders(), conversionHint); + if (result != null) { + return result; + } + } + } + } + return null; + } + + @Override + public Object fromMessage(Message message, Class targetClass) { + return fromMessage(message, targetClass, null); + } + + @Override + public Message toMessage(Object payload, MessageHeaders headers) { + return toMessage(payload, headers, null); + } +} diff --git a/spring-cloud-function-context/src/main/java/org/springframework/cloud/function/context/config/ContextFunctionCatalogAutoConfiguration.java b/spring-cloud-function-context/src/main/java/org/springframework/cloud/function/context/config/ContextFunctionCatalogAutoConfiguration.java index c9241a335..e70e4be36 100644 --- a/spring-cloud-function-context/src/main/java/org/springframework/cloud/function/context/config/ContextFunctionCatalogAutoConfiguration.java +++ b/spring-cloud-function-context/src/main/java/org/springframework/cloud/function/context/config/ContextFunctionCatalogAutoConfiguration.java @@ -38,6 +38,7 @@ import org.springframework.cloud.function.context.FunctionProperties; import org.springframework.cloud.function.context.FunctionRegistry; import org.springframework.cloud.function.context.catalog.BeanFactoryAwareFunctionRegistry; import org.springframework.cloud.function.context.catalog.FunctionInspector; +import org.springframework.cloud.function.context.catalog.NegotiatingMessageConverterWrapper; import org.springframework.cloud.function.json.GsonMapper; import org.springframework.cloud.function.json.JacksonMapper; import org.springframework.context.ConfigurableApplicationContext; @@ -105,9 +106,9 @@ public class ContextFunctionCatalogAutoConfiguration { } MappingJackson2MessageConverter jsonConverter = new MappingJackson2MessageConverter(); jsonConverter.setObjectMapper(objectMapper); - mcList.add(jsonConverter); - mcList.add(new ByteArrayMessageConverter()); - mcList.add(new StringMessageConverter()); + mcList.add(NegotiatingMessageConverterWrapper.wrap(jsonConverter)); + mcList.add(NegotiatingMessageConverterWrapper.wrap(new ByteArrayMessageConverter())); + mcList.add(NegotiatingMessageConverterWrapper.wrap(new StringMessageConverter())); } if (!CollectionUtils.isEmpty(mcList)) { messageConverter = new CompositeMessageConverter(mcList);