diff --git a/spring-cloud-function-web/src/main/java/org/springframework/cloud/function/web/flux/request/FluxHandlerMethodArgumentResolver.java b/spring-cloud-function-web/src/main/java/org/springframework/cloud/function/web/flux/request/FluxHandlerMethodArgumentResolver.java index 05a9d8ffa..6db85765b 100644 --- a/spring-cloud-function-web/src/main/java/org/springframework/cloud/function/web/flux/request/FluxHandlerMethodArgumentResolver.java +++ b/spring-cloud-function-web/src/main/java/org/springframework/cloud/function/web/flux/request/FluxHandlerMethodArgumentResolver.java @@ -31,9 +31,12 @@ import org.apache.commons.logging.LogFactory; import org.springframework.cloud.function.context.FunctionInspector; import org.springframework.cloud.function.web.flux.constants.WebRequestConstants; +import org.springframework.cloud.function.web.util.HeaderUtils; import org.springframework.core.MethodParameter; import org.springframework.core.Ordered; import org.springframework.http.MediaType; +import org.springframework.http.server.ServletServerHttpRequest; +import org.springframework.messaging.support.MessageBuilder; import org.springframework.util.StreamUtils; import org.springframework.web.bind.support.WebDataBinderFactory; import org.springframework.web.context.request.NativeWebRequest; @@ -78,6 +81,7 @@ public class FluxHandlerMethodArgumentResolver if (type == null) { type = Object.class; } + boolean message = inspector.isMessage(inspector.getName(handler)); List body; ContentCachingRequestWrapper nativeRequest = new ContentCachingRequestWrapper( webRequest.getNativeRequest(HttpServletRequest.class)); @@ -100,6 +104,17 @@ public class FluxHandlerMethodArgumentResolver mapper.readValue(nativeRequest.getContentAsByteArray(), type)); } } + if (message) { + List messages = new ArrayList<>(); + for (Object payload : body) { + messages.add(MessageBuilder.withPayload(payload).copyHeaders( + HeaderUtils.fromHttp(new ServletServerHttpRequest( + webRequest.getNativeRequest(HttpServletRequest.class)) + .getHeaders())) + .build()); + } + body = messages; + } return new FluxRequest(body); } diff --git a/spring-cloud-function-web/src/main/java/org/springframework/cloud/function/web/flux/response/FluxResponseBodyEmitter.java b/spring-cloud-function-web/src/main/java/org/springframework/cloud/function/web/flux/response/FluxResponseBodyEmitter.java index 6c99697c0..d8814cca9 100644 --- a/spring-cloud-function-web/src/main/java/org/springframework/cloud/function/web/flux/response/FluxResponseBodyEmitter.java +++ b/spring-cloud-function-web/src/main/java/org/springframework/cloud/function/web/flux/response/FluxResponseBodyEmitter.java @@ -30,25 +30,27 @@ import reactor.core.publisher.Flux; * * @author Dave Syer */ -class FluxResponseBodyEmitter extends ResponseBodyEmitter { +class FluxResponseBodyEmitter extends ResponseBodyEmitter { private final MediaType mediaType; + private ResponseBodyEmitterSubscriber subscriber; - public FluxResponseBodyEmitter(Publisher observable) { - this(null, observable); + public FluxResponseBodyEmitter(Publisher observable) { + this(new HttpHeaders(), null, observable); } - public FluxResponseBodyEmitter(MediaType mediaType, Publisher observable) { + public FluxResponseBodyEmitter(HttpHeaders request, MediaType mediaType, + Publisher observable) { super(); this.mediaType = mediaType; - new ResponseBodyEmitterSubscriber<>(mediaType, observable, this, - MediaType.APPLICATION_JSON.isCompatibleWith(mediaType)); + this.subscriber = new ResponseBodyEmitterSubscriber(request, mediaType, + observable, this, MediaType.APPLICATION_JSON.isCompatibleWith(mediaType)); } @Override protected void extendResponse(ServerHttpResponse outputMessage) { super.extendResponse(outputMessage); - + this.subscriber.extendResponse(outputMessage); HttpHeaders headers = outputMessage.getHeaders(); if (headers.getContentType() == null && this.mediaType != null && !MediaType.ALL.equals(this.mediaType)) { diff --git a/spring-cloud-function-web/src/main/java/org/springframework/cloud/function/web/flux/response/FluxResponseSseEmitter.java b/spring-cloud-function-web/src/main/java/org/springframework/cloud/function/web/flux/response/FluxResponseSseEmitter.java index 2c5725dc9..07d9e0c5b 100644 --- a/spring-cloud-function-web/src/main/java/org/springframework/cloud/function/web/flux/response/FluxResponseSseEmitter.java +++ b/spring-cloud-function-web/src/main/java/org/springframework/cloud/function/web/flux/response/FluxResponseSseEmitter.java @@ -18,7 +18,9 @@ package org.springframework.cloud.function.web.flux.response; import org.reactivestreams.Publisher; +import org.springframework.http.HttpHeaders; import org.springframework.http.MediaType; +import org.springframework.http.server.ServerHttpResponse; import org.springframework.web.servlet.mvc.method.annotation.ResponseBodyEmitter; import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; @@ -30,15 +32,24 @@ import reactor.core.publisher.Flux; * * @author Dave Syer */ -class FluxResponseSseEmitter extends SseEmitter { +class FluxResponseSseEmitter extends SseEmitter { - public FluxResponseSseEmitter(Publisher observable) { - this(MediaType.valueOf("text/plain"), observable); + private ResponseBodyEmitterSubscriber subscriber; + + public FluxResponseSseEmitter(Publisher observable) { + this(new HttpHeaders(), MediaType.valueOf("text/plain"), observable); } - public FluxResponseSseEmitter(MediaType mediaType, Publisher observable) { + public FluxResponseSseEmitter(HttpHeaders request, MediaType mediaType, + Publisher observable) { super(); - new ResponseBodyEmitterSubscriber<>(mediaType, observable, this, false); + this.subscriber = new ResponseBodyEmitterSubscriber(request, mediaType, + observable, this, false); } + @Override + protected void extendResponse(ServerHttpResponse outputMessage) { + super.extendResponse(outputMessage); + this.subscriber.extendResponse(outputMessage); + } } diff --git a/spring-cloud-function-web/src/main/java/org/springframework/cloud/function/web/flux/response/FluxReturnValueHandler.java b/spring-cloud-function-web/src/main/java/org/springframework/cloud/function/web/flux/response/FluxReturnValueHandler.java index af4e991cb..3bd3572df 100644 --- a/spring-cloud-function-web/src/main/java/org/springframework/cloud/function/web/flux/response/FluxReturnValueHandler.java +++ b/spring-cloud-function-web/src/main/java/org/springframework/cloud/function/web/flux/response/FluxReturnValueHandler.java @@ -23,6 +23,7 @@ import java.util.List; import java.util.function.Supplier; import java.util.stream.Stream; +import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; import org.apache.commons.logging.Log; @@ -32,11 +33,15 @@ import org.reactivestreams.Publisher; import org.springframework.cloud.function.context.FunctionInspector; import org.springframework.cloud.function.web.flux.constants.WebRequestConstants; +import org.springframework.cloud.function.web.util.HeaderUtils; import org.springframework.core.MethodParameter; import org.springframework.core.ResolvableType; +import org.springframework.http.HttpHeaders; import org.springframework.http.MediaType; import org.springframework.http.ResponseEntity; import org.springframework.http.converter.HttpMessageConverter; +import org.springframework.http.server.ServletServerHttpRequest; +import org.springframework.messaging.Message; import org.springframework.util.ReflectionUtils; import org.springframework.web.context.request.NativeWebRequest; import org.springframework.web.method.support.AsyncHandlerMethodReturnValueHandler; @@ -130,8 +135,16 @@ public class FluxReturnValueHandler implements AsyncHandlerMethodReturnValueHand if (returnValue instanceof ResponseEntity) { ResponseEntity value = (ResponseEntity) returnValue; adaptFrom = value.getBody(); - webRequest.getNativeResponse(HttpServletResponse.class) - .setStatus(value.getStatusCodeValue()); + HttpServletResponse response = webRequest + .getNativeResponse(HttpServletResponse.class); + response.setStatus(value.getStatusCodeValue()); + HttpHeaders headers = value.getHeaders(); + for (String name : headers.keySet()) { + List list = headers.get(name); + for (String header : list) { + response.addHeader(name, header); + } + } } Publisher flux = (Publisher) adaptFrom; @@ -141,8 +154,13 @@ public class FluxReturnValueHandler implements AsyncHandlerMethodReturnValueHand boolean inputSingle = isInputSingle(webRequest, handler); if (inputSingle && isOutputSingle(handler)) { - single.handleReturnValue(Flux.from(flux).blockFirst(), singleReturnType, - mavContainer, webRequest); + Object result = Flux.from(flux).blockFirst(); + if (result instanceof Message) { + Message message = (Message) result; + result = message.getPayload(); + addHeaders(webRequest, message); + } + single.handleReturnValue(result, singleReturnType, mavContainer, webRequest); return; } @@ -157,10 +175,27 @@ public class FluxReturnValueHandler implements AsyncHandlerMethodReturnValueHand logger.debug( "Handling return value " + type + " with media type: " + mediaType); } - delegate.handleReturnValue(getEmitter(timeout, flux, mediaType), returnType, + ServletServerHttpRequest request = new ServletServerHttpRequest( + webRequest.getNativeRequest(HttpServletRequest.class)); + delegate.handleReturnValue( + getEmitter(timeout, flux, mediaType, request.getHeaders()), returnType, mavContainer, webRequest); } + private void addHeaders(NativeWebRequest webRequest, Message message) { + HttpServletResponse response = webRequest + .getNativeResponse(HttpServletResponse.class); + ServletServerHttpRequest request = new ServletServerHttpRequest( + webRequest.getNativeRequest(HttpServletRequest.class)); + HttpHeaders headers = HeaderUtils.fromMessage(message.getHeaders(), + request.getHeaders()); + for (String name : headers.keySet()) { + for (Object object : headers.get(name)) { + response.addHeader(name, object.toString()); + } + } + } + private boolean isInputSingle(NativeWebRequest webRequest, Object handler) { Boolean single = (Boolean) webRequest.getAttribute( WebRequestConstants.INPUT_SINGLE, NativeWebRequest.SCOPE_REQUEST); @@ -218,15 +253,16 @@ public class FluxReturnValueHandler implements AsyncHandlerMethodReturnValueHand } private ResponseBodyEmitter getEmitter(Long timeout, Publisher flux, - MediaType mediaType) { + MediaType mediaType, HttpHeaders request) { Publisher exported = flux instanceof Mono ? Mono.from(flux) : Flux.from(flux).timeout(Duration.ofMillis(timeout), Flux.empty()); if (!MediaType.ALL.equals(mediaType) && EVENT_STREAM.isCompatibleWith(mediaType)) { // TODO: more subtle content negotiation - return new FluxResponseSseEmitter<>(MediaType.APPLICATION_JSON, exported); + return new FluxResponseSseEmitter(request, MediaType.APPLICATION_JSON, + exported); } - return new FluxResponseBodyEmitter<>(mediaType, exported); + return new FluxResponseBodyEmitter(request, mediaType, exported); } } diff --git a/spring-cloud-function-web/src/main/java/org/springframework/cloud/function/web/flux/response/ResponseBodyEmitterSubscriber.java b/spring-cloud-function-web/src/main/java/org/springframework/cloud/function/web/flux/response/ResponseBodyEmitterSubscriber.java index b55b96cfe..3c5961561 100644 --- a/spring-cloud-function-web/src/main/java/org/springframework/cloud/function/web/flux/response/ResponseBodyEmitterSubscriber.java +++ b/spring-cloud-function-web/src/main/java/org/springframework/cloud/function/web/flux/response/ResponseBodyEmitterSubscriber.java @@ -23,7 +23,11 @@ import org.reactivestreams.Publisher; import org.reactivestreams.Subscriber; import org.reactivestreams.Subscription; +import org.springframework.cloud.function.web.util.HeaderUtils; +import org.springframework.http.HttpHeaders; import org.springframework.http.MediaType; +import org.springframework.http.server.ServerHttpResponse; +import org.springframework.messaging.Message; import org.springframework.web.servlet.mvc.method.annotation.ResponseBodyEmitter; import reactor.core.publisher.Flux; @@ -35,7 +39,7 @@ import reactor.core.publisher.Mono; * * @author Dave Syer */ -class ResponseBodyEmitterSubscriber implements Subscriber { +class ResponseBodyEmitterSubscriber implements Subscriber { private final MediaType mediaType; @@ -49,11 +53,17 @@ class ResponseBodyEmitterSubscriber implements Subscriber { private boolean single; - private boolean json; + private final boolean json; - public ResponseBodyEmitterSubscriber(MediaType mediaType, Publisher observable, - ResponseBodyEmitter responseBodyEmitter, boolean json) { + private Message first; + private final HttpHeaders request; + + public ResponseBodyEmitterSubscriber(HttpHeaders request, MediaType mediaType, + Publisher observable, ResponseBodyEmitter responseBodyEmitter, + boolean json) { + + this.request = request; this.mediaType = mediaType; this.responseBodyEmitter = responseBodyEmitter; this.json = json; @@ -63,6 +73,10 @@ class ResponseBodyEmitterSubscriber implements Subscriber { observable.subscribe(this); } + public void extendResponse(ServerHttpResponse response) { + headers(response); + } + @Override public void onSubscribe(Subscription subscription) { this.subscription = subscription; @@ -70,10 +84,16 @@ class ResponseBodyEmitterSubscriber implements Subscriber { } @Override - public void onNext(T value) { + public void onNext(Object value) { Object object = value; + if (object instanceof Message) { + Message message = (Message) object; + object = message.getPayload(); + this.first = message; + } + try { if (isJson()) { if (!this.firstElementWritten) { @@ -85,9 +105,9 @@ class ResponseBodyEmitterSubscriber implements Subscriber { else { responseBodyEmitter.send(","); } - if (!single && value.getClass() == String.class - && !((String) value).contains("\"")) { - object = "\"" + value + "\""; + if (!single && object.getClass() == String.class + && !((String) object).contains("\"")) { + object = "\"" + object + "\""; } } if (!completed) { @@ -101,6 +121,24 @@ class ResponseBodyEmitterSubscriber implements Subscriber { } } + private void headers(ServerHttpResponse response) { + if (this.first != null) { + Message message = first; + try { + HttpHeaders headers = HeaderUtils.fromMessage(message.getHeaders(), + request); + for (String name : headers.keySet()) { + for (String value : headers.get(name)) { + response.getHeaders().add(name, value); + } + } + } + catch (Exception e) { + // Headers could not be set + } + } + } + @Override public void onError(Throwable e) { if (!completed) { @@ -170,4 +208,5 @@ class ResponseBodyEmitterSubscriber implements Subscriber { ResponseBodyEmitterSubscriber.this.subscription.cancel(); } } + } diff --git a/spring-cloud-function-web/src/main/java/org/springframework/cloud/function/web/util/HeaderUtils.java b/spring-cloud-function-web/src/main/java/org/springframework/cloud/function/web/util/HeaderUtils.java new file mode 100644 index 000000000..932859e33 --- /dev/null +++ b/spring-cloud-function-web/src/main/java/org/springframework/cloud/function/web/util/HeaderUtils.java @@ -0,0 +1,84 @@ +/* + * Copyright 2016-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.cloud.function.web.util; + +import java.util.Arrays; +import java.util.Collection; +import java.util.LinkedHashMap; +import java.util.Map; + +import org.springframework.http.HttpHeaders; +import org.springframework.messaging.MessageHeaders; +import org.springframework.util.ObjectUtils; + +/** + * @author Dave Syer + * + */ +public class HeaderUtils { + + public static HttpHeaders fromMessage(MessageHeaders headers, HttpHeaders request) { + HttpHeaders result = new HttpHeaders(); + for (String name : headers.keySet()) { + Object value = headers.get(name); + name = name.toLowerCase(); + if (MessageHeaders.ID.equals(name)) { + continue; + } + if (request.containsKey(name)) { + if (name.startsWith("x-")) { + if (!name.startsWith("x-forwarded")) { + Collection values = multi(value); + for (Object object : values) { + result.set(name, object.toString()); + } + } + } + } + else { + Collection values = multi(value); + for (Object object : values) { + result.set(name, object.toString()); + } + } + } + return result; + } + + private static Collection multi(Object value) { + if (value instanceof Collection) { + Collection collection = (Collection) value; + return collection; + } + else if (ObjectUtils.isArray(value)) { + Object[] values = ObjectUtils.toObjectArray(value); + return Arrays.asList(values); + } + return Arrays.asList(value); + } + + public static MessageHeaders fromHttp(HttpHeaders headers) { + Map map = new LinkedHashMap<>(); + for (String name : headers.keySet()) { + Collection values = multi(headers.get(name)); + name = name.toLowerCase(); + Object value = values == null ? null + : (values.size() == 1 ? values.iterator().next() : values); + map.put(name, value); + } + return new MessageHeaders(map); + } +} diff --git a/spring-cloud-function-web/src/test/java/org/springframework/cloud/function/web/RestApplicationTests.java b/spring-cloud-function-web/src/test/java/org/springframework/cloud/function/web/RestApplicationTests.java index 44e5a420a..b87f90175 100644 --- a/spring-cloud-function-web/src/test/java/org/springframework/cloud/function/web/RestApplicationTests.java +++ b/spring-cloud-function-web/src/test/java/org/springframework/cloud/function/web/RestApplicationTests.java @@ -45,6 +45,8 @@ import org.springframework.http.HttpStatus; import org.springframework.http.MediaType; import org.springframework.http.RequestEntity; import org.springframework.http.ResponseEntity; +import org.springframework.messaging.Message; +import org.springframework.messaging.support.MessageBuilder; import org.springframework.test.context.junit4.SpringRunner; import org.springframework.util.StringUtils; @@ -256,6 +258,26 @@ public class RestApplicationTests { assertThat(result.getBody()).isEqualTo("[\"(FOO)\",\"(BAR)\"]"); } + @Test + public void messages() throws Exception { + ResponseEntity result = rest.exchange(RequestEntity + .post(new URI("/messages")).contentType(MediaType.APPLICATION_JSON) + .header("x-foo", "bar").body("[\"foo\",\"bar\"]"), String.class); + assertThat(result.getBody()).isEqualTo("[\"(FOO)\",\"(BAR)\"]"); + assertThat(result.getHeaders().getFirst("x-foo")).isEqualTo("bar"); + assertThat(result.getHeaders()).doesNotContainKey("id"); + } + + @Test + public void headers() throws Exception { + ResponseEntity result = rest.exchange(RequestEntity + .post(new URI("/headers")).contentType(MediaType.APPLICATION_JSON) + .body("[\"foo\",\"bar\"]"), String.class); + assertThat(result.getBody()).isEqualTo("[\"(FOO)\",\"(BAR)\"]"); + assertThat(result.getHeaders().getFirst("foo")).isEqualTo("bar"); + assertThat(result.getHeaders()).doesNotContainKey("id"); + } + @Test public void uppercaseSingleValue() throws Exception { ResponseEntity result = rest @@ -406,6 +428,20 @@ public class RestApplicationTests { return value -> "(" + value.trim().toUpperCase() + ")"; } + @Bean + public Function, Message> messages() { + return value -> MessageBuilder + .withPayload("(" + value.getPayload().trim().toUpperCase() + ")") + .copyHeaders(value.getHeaders()).build(); + } + + @Bean + public Function>, Flux>> headers() { + return flux -> flux.map(value -> MessageBuilder + .withPayload("(" + value.getPayload().trim().toUpperCase() + ")") + .setHeader("foo", "bar").build()); + } + @Bean public Function, Flux> upFoos() { return flux -> flux.log()