diff --git a/README.adoc b/README.adoc index e90c4b3f9..4bf63fd4c 100644 --- a/README.adoc +++ b/README.adoc @@ -188,7 +188,7 @@ We recommend the http://eclipse.org/m2e/[m2eclipse] eclipse plugin when working eclipse. If you don't already have m2eclipse installed it is available from the "eclipse marketplace". -Also, you will need to install Kotlin pug-in from the "eclipse marketplace" to ensure projects +Also, you will need to install Kotlin plug-in from the "eclipse marketplace" to ensure projects that depend on Kotlin complile under `m2eclipse` plugin (mentioned above). NOTE: If you still see compile errors after installing Kotlin plugin, simply right-click on the project with error and _remove_ and then _add_ Kotlin Nature via ***Configure Kotlin*** feature. diff --git a/spring-cloud-function-context/src/main/java/org/springframework/cloud/function/context/FunctionRegistration.java b/spring-cloud-function-context/src/main/java/org/springframework/cloud/function/context/FunctionRegistration.java index f342498be..def633e8b 100644 --- a/spring-cloud-function-context/src/main/java/org/springframework/cloud/function/context/FunctionRegistration.java +++ b/spring-cloud-function-context/src/main/java/org/springframework/cloud/function/context/FunctionRegistration.java @@ -31,7 +31,6 @@ import org.springframework.cloud.function.core.FluxConsumer; import org.springframework.cloud.function.core.FluxFunction; import org.springframework.cloud.function.core.FluxSupplier; import org.springframework.util.Assert; -import org.springframework.util.ObjectUtils; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; @@ -51,7 +50,8 @@ public class FunctionRegistration { private FunctionType type; /** - * @deprecated as of v1.0.0 in favor of {@link #FunctionRegistration(Object, String...)} + * @deprecated as of v1.0.0 in favor of + * {@link #FunctionRegistration(Object, String...)} */ @Deprecated public FunctionRegistration(T target) { @@ -59,13 +59,12 @@ public class FunctionRegistration { this.target = target; } - /** * Creates instance of FunctionRegistration. * * @param target instance of {@link Supplier}, {@link Function} or {@link Consumer} - * @param names additional set of names for this registration. Additional names - * can be provided {@link #name(String)} or {@link #names(String...)} operations. + * @param names additional set of names for this registration. Additional names can be + * provided {@link #name(String)} or {@link #names(String...)} operations. */ public FunctionRegistration(T target, String... names) { Assert.notNull(target, "'target' must not be null"); diff --git a/spring-cloud-function-web/src/main/java/org/springframework/cloud/function/web/RequestProcessor.java b/spring-cloud-function-web/src/main/java/org/springframework/cloud/function/web/RequestProcessor.java index c613c4d45..36998076a 100644 --- a/spring-cloud-function-web/src/main/java/org/springframework/cloud/function/web/RequestProcessor.java +++ b/spring-cloud-function-web/src/main/java/org/springframework/cloud/function/web/RequestProcessor.java @@ -78,14 +78,17 @@ public class RequestProcessor { public Mono> get(FunctionWrapper wrapper) { if (wrapper.function() != null) { - return response(wrapper.function(), value(wrapper.function(), wrapper.argument()), true, true); + return response(wrapper, wrapper.function(), + value(wrapper.function(), wrapper.argument()), true, true); } else { - return response(wrapper.supplier(), wrapper.supplier().get(), null, true); + return response(wrapper, wrapper.supplier(), wrapper.supplier().get(), null, + true); } } - public Mono> post(FunctionWrapper wrapper, String body, boolean stream) { + public Mono> post(FunctionWrapper wrapper, String body, + boolean stream) { Mono> responseEntityMono; Object function = wrapper.handler(); @@ -117,14 +120,15 @@ public class RequestProcessor { public Mono> stream(FunctionWrapper request) { Publisher result = request.function() != null ? value(request.function(), request.argument()) - : request.supplier().get(); + : request.supplier().get(); return stream(request, result); } private Mono> post(FunctionWrapper wrapper, Object body, MultiValueMap params, boolean stream) { - Iterable iterable = body instanceof Collection ? (List) body : Collections.singletonList(body); + Iterable iterable = body instanceof Collection ? (List) body + : Collections.singletonList(body); Function, Publisher> function = wrapper.function(); Consumer> consumer = wrapper.consumer(); @@ -146,13 +150,15 @@ public class RequestProcessor { responseEntityMono = stream(wrapper, result); } else { - responseEntityMono = response(function, result, body == null ? null : !(body instanceof Collection), false); + responseEntityMono = response(wrapper, function, result, + body == null ? null : !(body instanceof Collection), false); } } else if (consumer != null) { consumer.accept(flux); logger.debug("Handled POST with consumer"); - responseEntityMono = Mono.just(ResponseEntity.status(HttpStatus.ACCEPTED).build()); + responseEntityMono = Mono + .just(ResponseEntity.status(HttpStatus.ACCEPTED).build()); } return responseEntityMono; } @@ -175,13 +181,16 @@ public class RequestProcessor { .map(message -> MessageUtils.unpack(request.handler(), message) .getPayload()); } + else { + builder.headers(HeaderUtils.sanitize(request.headers())); + } Publisher output = result; return Flux.from(output).then(Mono.fromSupplier(() -> builder.body(output))); } - private Mono> response(Object handler, Publisher result, - Boolean single, boolean getter) { + private Mono> response(FunctionWrapper request, Object handler, + Publisher result, Boolean single, boolean getter) { BodyBuilder builder = ResponseEntity.ok(); if (inspector.isMessage(handler)) { @@ -189,8 +198,12 @@ public class RequestProcessor { .doOnNext(value -> addHeaders(builder, (Message) value)) .map(message -> MessageUtils.unpack(handler, message).getPayload()); } + else { + builder.headers(HeaderUtils.sanitize(request.headers())); + } - if (isOutputSingle(handler) && (single != null && single || getter || isInputMultiple(handler))) { + if (isOutputSingle(handler) + && (single != null && single || getter || isInputMultiple(handler))) { result = Mono.from(result); } @@ -213,7 +226,8 @@ public class RequestProcessor { return false; } else { - return wrapper == type || Mono.class.equals(wrapper) || Optional.class.equals(wrapper); + return wrapper == type || Mono.class.equals(wrapper) + || Optional.class.equals(wrapper); } } diff --git a/spring-cloud-function-web/src/main/java/org/springframework/cloud/function/web/util/HeaderUtils.java b/spring-cloud-function-web/src/main/java/org/springframework/cloud/function/web/util/HeaderUtils.java index 4e27fa274..d5cf9d0ec 100644 --- a/spring-cloud-function-web/src/main/java/org/springframework/cloud/function/web/util/HeaderUtils.java +++ b/spring-cloud-function-web/src/main/java/org/springframework/cloud/function/web/util/HeaderUtils.java @@ -18,6 +18,7 @@ package org.springframework.cloud.function.web.util; import java.util.Arrays; import java.util.Collection; import java.util.LinkedHashMap; +import java.util.List; import java.util.Map; import org.springframework.http.HttpHeaders; @@ -30,10 +31,17 @@ import org.springframework.messaging.MessageHeaders; public class HeaderUtils { private static HttpHeaders IGNORED = new HttpHeaders(); - + + private static HttpHeaders REQUEST_ONLY = new HttpHeaders(); + static { IGNORED.add(MessageHeaders.ID, ""); IGNORED.add(HttpHeaders.CONTENT_LENGTH, "0"); + // Headers that would typically be added by a downstream client + REQUEST_ONLY.add(HttpHeaders.ACCEPT, ""); + REQUEST_ONLY.add(HttpHeaders.CONTENT_LENGTH, ""); + REQUEST_ONLY.add(HttpHeaders.CONTENT_TYPE, ""); + REQUEST_ONLY.add(HttpHeaders.HOST, ""); } public static HttpHeaders fromMessage(MessageHeaders headers, HttpHeaders request) { @@ -51,6 +59,18 @@ public class HeaderUtils { return result; } + public static HttpHeaders sanitize(HttpHeaders request) { + HttpHeaders result = new HttpHeaders(); + for (String name : request.keySet()) { + List value = request.get(name); + name = name.toLowerCase(); + if (!IGNORED.containsKey(name) && !REQUEST_ONLY.containsKey(name)) { + result.put(name, value); + } + } + return result; + } + private static Collection multi(Object value) { return value instanceof Collection ? (Collection) value : Arrays.asList(value); } diff --git a/spring-cloud-function-web/src/test/java/org/springframework/cloud/function/web/mvc/HeadersToMessageTests.java b/spring-cloud-function-web/src/test/java/org/springframework/cloud/function/web/mvc/HeadersToMessageTests.java index 967eb3cfb..ae78ae9b5 100644 --- a/spring-cloud-function-web/src/test/java/org/springframework/cloud/function/web/mvc/HeadersToMessageTests.java +++ b/spring-cloud-function-web/src/test/java/org/springframework/cloud/function/web/mvc/HeadersToMessageTests.java @@ -16,6 +16,7 @@ package org.springframework.cloud.function.web.mvc; import java.net.URI; +import java.util.LinkedHashMap; import java.util.Map; import java.util.function.Function; @@ -30,6 +31,8 @@ import org.springframework.boot.test.web.client.TestRestTemplate; import org.springframework.cloud.function.web.RestApplication; import org.springframework.cloud.function.web.mvc.HeadersToMessageTests.TestConfiguration; import org.springframework.context.annotation.Bean; +import org.springframework.http.HttpEntity; +import org.springframework.http.RequestEntity; import org.springframework.http.ResponseEntity; import org.springframework.messaging.Message; import org.springframework.messaging.support.MessageBuilder; @@ -46,8 +49,9 @@ import static org.junit.Assert.assertTrue; */ @RunWith(SpringRunner.class) @SpringBootTest(webEnvironment = WebEnvironment.RANDOM_PORT, properties = { - "spring.main.web-application-type=servlet", "spring.cloud.function.web.path=/functions" }) -@ContextConfiguration(classes= {RestApplication.class, TestConfiguration.class}) + "spring.main.web-application-type=servlet", + "spring.cloud.function.web.path=/functions" }) +@ContextConfiguration(classes = { RestApplication.class, TestConfiguration.class }) public class HeadersToMessageTests { @Autowired @@ -65,6 +69,17 @@ public class HeadersToMessageTests { assertEquals("bar", postForEntity.getHeaders().get("foo").get(0)); } + @Test + public void testHeadersPropagatedByDefault() throws Exception { + HttpEntity postForEntity = rest.exchange(RequestEntity + .post(new URI("/functions/vanilla")).header("x-context-type", "rubbish") + .body("{\"name\":\"Bob\",\"age\":25}"), String.class); + assertEquals("{\"name\":\"Bob\",\"age\":25,\"foo\":\"bar\"}", + postForEntity.getBody()); + assertTrue(postForEntity.getHeaders().containsKey("x-context-type")); + assertEquals("rubbish", postForEntity.getHeaders().get("x-context-type").get(0)); + } + @EnableAutoConfiguration @org.springframework.boot.test.context.TestConfiguration protected static class TestConfiguration { @@ -78,5 +93,14 @@ public class HeadersToMessageTests { return message; }; } + + @Bean + public Function, Map> vanilla() { + return request -> { + Map message = new LinkedHashMap<>(request); + message.put("foo", "bar"); + return message; + }; + } } }