diff --git a/spring-cloud-function-context/src/main/java/org/springframework/cloud/function/context/catalog/BeanFactoryAwareFunctionRegistry.java b/spring-cloud-function-context/src/main/java/org/springframework/cloud/function/context/catalog/BeanFactoryAwareFunctionRegistry.java index dbfca43c4..4f6f284ae 100644 --- a/spring-cloud-function-context/src/main/java/org/springframework/cloud/function/context/catalog/BeanFactoryAwareFunctionRegistry.java +++ b/spring-cloud-function-context/src/main/java/org/springframework/cloud/function/context/catalog/BeanFactoryAwareFunctionRegistry.java @@ -16,13 +16,13 @@ package org.springframework.cloud.function.context.catalog; +import java.lang.reflect.Field; import java.lang.reflect.Method; import java.lang.reflect.ParameterizedType; import java.lang.reflect.Type; import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; -import java.util.Collections; import java.util.HashMap; import java.util.LinkedHashSet; import java.util.List; @@ -79,6 +79,7 @@ import org.springframework.util.CollectionUtils; import org.springframework.util.MimeType; import org.springframework.util.MimeTypeUtils; import org.springframework.util.ObjectUtils; +import org.springframework.util.ReflectionUtils; import org.springframework.util.StringUtils; @@ -466,12 +467,16 @@ public class BeanFactoryAwareFunctionRegistry private final String functionDefinition; + private final Field headersField; + FunctionInvocationWrapper(Object target, Type functionType, String functionDefinition, String... acceptedOutputMimeTypes) { this.target = target; this.composed = functionDefinition.contains("|") || target instanceof RoutingFunction; this.functionType = functionType; this.acceptedOutputMimeTypes = acceptedOutputMimeTypes; this.functionDefinition = functionDefinition; + this.headersField = ReflectionUtils.findField(MessageHeaders.class, "headers"); + this.headersField.setAccessible(true); } @Override @@ -680,25 +685,24 @@ public class BeanFactoryAwareFunctionRegistry return convertedValue; } - @SuppressWarnings("rawtypes") + @SuppressWarnings({ "rawtypes", "unchecked" }) private Message convertValueToMessage(Object value, Function enricher, MimeType acceptedContentType) { Message outputMessage = null; - if (enricher != null) { - if (!(value instanceof Message)) { - value = MessageBuilder.withPayload(value).setHeader(MessageHeaders.CONTENT_TYPE, acceptedContentType).build(); + if (value instanceof Message) { + MessageHeaders headers = ((Message) value).getHeaders(); + if (!headers.containsKey(MessageHeaders.CONTENT_TYPE)) { + Map headersMap = (Map) ReflectionUtils + .getField(this.headersField, headers); + headersMap.put(MessageHeaders.CONTENT_TYPE, acceptedContentType); } - value = enricher.apply((Message) value); - outputMessage = messageConverter.toMessage(((Message) value).getPayload(), ((Message) value).getHeaders()); } else { - if (value instanceof Message) { - outputMessage = messageConverter.toMessage(((Message) value).getPayload(), ((Message) value).getHeaders()); - } - else { - outputMessage = messageConverter.toMessage(value, - new MessageHeaders(Collections.singletonMap(MessageHeaders.CONTENT_TYPE, acceptedContentType))); - } + value = MessageBuilder.withPayload(value).setHeader(MessageHeaders.CONTENT_TYPE, acceptedContentType).build(); } + if (enricher != null) { + value = enricher.apply((Message) value); + } + outputMessage = messageConverter.toMessage(((Message) value).getPayload(), ((Message) value).getHeaders()); return outputMessage; } diff --git a/spring-cloud-function-context/src/test/java/org/springframework/cloud/function/context/catalog/BeanFactoryAwareFunctionRegistryTests.java b/spring-cloud-function-context/src/test/java/org/springframework/cloud/function/context/catalog/BeanFactoryAwareFunctionRegistryTests.java index 8e897fff2..10e471472 100644 --- a/spring-cloud-function-context/src/test/java/org/springframework/cloud/function/context/catalog/BeanFactoryAwareFunctionRegistryTests.java +++ b/spring-cloud-function-context/src/test/java/org/springframework/cloud/function/context/catalog/BeanFactoryAwareFunctionRegistryTests.java @@ -83,7 +83,10 @@ public class BeanFactoryAwareFunctionRegistryTests { assertThat(((boolean) field.get(function))).isFalse(); //== System.setProperty("spring.cloud.function.definition", "uppercase|uppercaseFlux"); - function = catalog.lookup(""); + function = catalog.lookup("", "application/json"); + Function, Flux>> typedFunction = (Function, Flux>>) function; + Object blockFirst = typedFunction.apply(Flux.just("hello")).blockFirst(); + System.out.println(blockFirst); assertThat(function).isNotNull(); field = ReflectionUtils.findField(FunctionInvocationWrapper.class, "composed"); field.setAccessible(true);