From b61562508bee1f4bda52fbe5f3868b5184386018 Mon Sep 17 00:00:00 2001 From: Dave Syer Date: Thu, 15 Jun 2017 15:37:55 +0100 Subject: [PATCH] Add support for Message handling in web functions If the output is a Message we grab the headers from the first one and send those as response headers. A function can add headers to messages. The HTTP response will contain only headers that are in x-* _or_ were added to the message by user (i.e. they weren't in the request). We certainly do not want to pass through all HTTP headers (request headers can conflict with or invalidate response headers). --- .../FluxHandlerMethodArgumentResolver.java | 15 ++++ .../response/FluxResponseBodyEmitter.java | 16 ++-- .../flux/response/FluxResponseSseEmitter.java | 21 +++-- .../flux/response/FluxReturnValueHandler.java | 52 ++++++++++-- .../ResponseBodyEmitterSubscriber.java | 55 ++++++++++-- .../cloud/function/web/util/HeaderUtils.java | 84 +++++++++++++++++++ .../function/web/RestApplicationTests.java | 36 ++++++++ 7 files changed, 251 insertions(+), 28 deletions(-) create mode 100644 spring-cloud-function-web/src/main/java/org/springframework/cloud/function/web/util/HeaderUtils.java diff --git a/spring-cloud-function-web/src/main/java/org/springframework/cloud/function/web/flux/request/FluxHandlerMethodArgumentResolver.java b/spring-cloud-function-web/src/main/java/org/springframework/cloud/function/web/flux/request/FluxHandlerMethodArgumentResolver.java index 05a9d8ffa..6db85765b 100644 --- a/spring-cloud-function-web/src/main/java/org/springframework/cloud/function/web/flux/request/FluxHandlerMethodArgumentResolver.java +++ b/spring-cloud-function-web/src/main/java/org/springframework/cloud/function/web/flux/request/FluxHandlerMethodArgumentResolver.java @@ -31,9 +31,12 @@ import org.apache.commons.logging.LogFactory; import org.springframework.cloud.function.context.FunctionInspector; import org.springframework.cloud.function.web.flux.constants.WebRequestConstants; +import org.springframework.cloud.function.web.util.HeaderUtils; import org.springframework.core.MethodParameter; import org.springframework.core.Ordered; import org.springframework.http.MediaType; +import org.springframework.http.server.ServletServerHttpRequest; +import org.springframework.messaging.support.MessageBuilder; import org.springframework.util.StreamUtils; import org.springframework.web.bind.support.WebDataBinderFactory; import org.springframework.web.context.request.NativeWebRequest; @@ -78,6 +81,7 @@ public class FluxHandlerMethodArgumentResolver if (type == null) { type = Object.class; } + boolean message = inspector.isMessage(inspector.getName(handler)); List body; ContentCachingRequestWrapper nativeRequest = new ContentCachingRequestWrapper( webRequest.getNativeRequest(HttpServletRequest.class)); @@ -100,6 +104,17 @@ public class FluxHandlerMethodArgumentResolver mapper.readValue(nativeRequest.getContentAsByteArray(), type)); } } + if (message) { + List messages = new ArrayList<>(); + for (Object payload : body) { + messages.add(MessageBuilder.withPayload(payload).copyHeaders( + HeaderUtils.fromHttp(new ServletServerHttpRequest( + webRequest.getNativeRequest(HttpServletRequest.class)) + .getHeaders())) + .build()); + } + body = messages; + } return new FluxRequest(body); } diff --git a/spring-cloud-function-web/src/main/java/org/springframework/cloud/function/web/flux/response/FluxResponseBodyEmitter.java b/spring-cloud-function-web/src/main/java/org/springframework/cloud/function/web/flux/response/FluxResponseBodyEmitter.java index 6c99697c0..d8814cca9 100644 --- a/spring-cloud-function-web/src/main/java/org/springframework/cloud/function/web/flux/response/FluxResponseBodyEmitter.java +++ b/spring-cloud-function-web/src/main/java/org/springframework/cloud/function/web/flux/response/FluxResponseBodyEmitter.java @@ -30,25 +30,27 @@ import reactor.core.publisher.Flux; * * @author Dave Syer */ -class FluxResponseBodyEmitter extends ResponseBodyEmitter { +class FluxResponseBodyEmitter extends ResponseBodyEmitter { private final MediaType mediaType; + private ResponseBodyEmitterSubscriber subscriber; - public FluxResponseBodyEmitter(Publisher observable) { - this(null, observable); + public FluxResponseBodyEmitter(Publisher observable) { + this(new HttpHeaders(), null, observable); } - public FluxResponseBodyEmitter(MediaType mediaType, Publisher observable) { + public FluxResponseBodyEmitter(HttpHeaders request, MediaType mediaType, + Publisher observable) { super(); this.mediaType = mediaType; - new ResponseBodyEmitterSubscriber<>(mediaType, observable, this, - MediaType.APPLICATION_JSON.isCompatibleWith(mediaType)); + this.subscriber = new ResponseBodyEmitterSubscriber(request, mediaType, + observable, this, MediaType.APPLICATION_JSON.isCompatibleWith(mediaType)); } @Override protected void extendResponse(ServerHttpResponse outputMessage) { super.extendResponse(outputMessage); - + this.subscriber.extendResponse(outputMessage); HttpHeaders headers = outputMessage.getHeaders(); if (headers.getContentType() == null && this.mediaType != null && !MediaType.ALL.equals(this.mediaType)) { diff --git a/spring-cloud-function-web/src/main/java/org/springframework/cloud/function/web/flux/response/FluxResponseSseEmitter.java b/spring-cloud-function-web/src/main/java/org/springframework/cloud/function/web/flux/response/FluxResponseSseEmitter.java index 2c5725dc9..07d9e0c5b 100644 --- a/spring-cloud-function-web/src/main/java/org/springframework/cloud/function/web/flux/response/FluxResponseSseEmitter.java +++ b/spring-cloud-function-web/src/main/java/org/springframework/cloud/function/web/flux/response/FluxResponseSseEmitter.java @@ -18,7 +18,9 @@ package org.springframework.cloud.function.web.flux.response; import org.reactivestreams.Publisher; +import org.springframework.http.HttpHeaders; import org.springframework.http.MediaType; +import org.springframework.http.server.ServerHttpResponse; import org.springframework.web.servlet.mvc.method.annotation.ResponseBodyEmitter; import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; @@ -30,15 +32,24 @@ import reactor.core.publisher.Flux; * * @author Dave Syer */ -class FluxResponseSseEmitter extends SseEmitter { +class FluxResponseSseEmitter extends SseEmitter { - public FluxResponseSseEmitter(Publisher observable) { - this(MediaType.valueOf("text/plain"), observable); + private ResponseBodyEmitterSubscriber subscriber; + + public FluxResponseSseEmitter(Publisher observable) { + this(new HttpHeaders(), MediaType.valueOf("text/plain"), observable); } - public FluxResponseSseEmitter(MediaType mediaType, Publisher observable) { + public FluxResponseSseEmitter(HttpHeaders request, MediaType mediaType, + Publisher observable) { super(); - new ResponseBodyEmitterSubscriber<>(mediaType, observable, this, false); + this.subscriber = new ResponseBodyEmitterSubscriber(request, mediaType, + observable, this, false); } + @Override + protected void extendResponse(ServerHttpResponse outputMessage) { + super.extendResponse(outputMessage); + this.subscriber.extendResponse(outputMessage); + } } diff --git a/spring-cloud-function-web/src/main/java/org/springframework/cloud/function/web/flux/response/FluxReturnValueHandler.java b/spring-cloud-function-web/src/main/java/org/springframework/cloud/function/web/flux/response/FluxReturnValueHandler.java index af4e991cb..3bd3572df 100644 --- a/spring-cloud-function-web/src/main/java/org/springframework/cloud/function/web/flux/response/FluxReturnValueHandler.java +++ b/spring-cloud-function-web/src/main/java/org/springframework/cloud/function/web/flux/response/FluxReturnValueHandler.java @@ -23,6 +23,7 @@ import java.util.List; import java.util.function.Supplier; import java.util.stream.Stream; +import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; import org.apache.commons.logging.Log; @@ -32,11 +33,15 @@ import org.reactivestreams.Publisher; import org.springframework.cloud.function.context.FunctionInspector; import org.springframework.cloud.function.web.flux.constants.WebRequestConstants; +import org.springframework.cloud.function.web.util.HeaderUtils; import org.springframework.core.MethodParameter; import org.springframework.core.ResolvableType; +import org.springframework.http.HttpHeaders; import org.springframework.http.MediaType; import org.springframework.http.ResponseEntity; import org.springframework.http.converter.HttpMessageConverter; +import org.springframework.http.server.ServletServerHttpRequest; +import org.springframework.messaging.Message; import org.springframework.util.ReflectionUtils; import org.springframework.web.context.request.NativeWebRequest; import org.springframework.web.method.support.AsyncHandlerMethodReturnValueHandler; @@ -130,8 +135,16 @@ public class FluxReturnValueHandler implements AsyncHandlerMethodReturnValueHand if (returnValue instanceof ResponseEntity) { ResponseEntity value = (ResponseEntity) returnValue; adaptFrom = value.getBody(); - webRequest.getNativeResponse(HttpServletResponse.class) - .setStatus(value.getStatusCodeValue()); + HttpServletResponse response = webRequest + .getNativeResponse(HttpServletResponse.class); + response.setStatus(value.getStatusCodeValue()); + HttpHeaders headers = value.getHeaders(); + for (String name : headers.keySet()) { + List list = headers.get(name); + for (String header : list) { + response.addHeader(name, header); + } + } } Publisher flux = (Publisher) adaptFrom; @@ -141,8 +154,13 @@ public class FluxReturnValueHandler implements AsyncHandlerMethodReturnValueHand boolean inputSingle = isInputSingle(webRequest, handler); if (inputSingle && isOutputSingle(handler)) { - single.handleReturnValue(Flux.from(flux).blockFirst(), singleReturnType, - mavContainer, webRequest); + Object result = Flux.from(flux).blockFirst(); + if (result instanceof Message) { + Message message = (Message) result; + result = message.getPayload(); + addHeaders(webRequest, message); + } + single.handleReturnValue(result, singleReturnType, mavContainer, webRequest); return; } @@ -157,10 +175,27 @@ public class FluxReturnValueHandler implements AsyncHandlerMethodReturnValueHand logger.debug( "Handling return value " + type + " with media type: " + mediaType); } - delegate.handleReturnValue(getEmitter(timeout, flux, mediaType), returnType, + ServletServerHttpRequest request = new ServletServerHttpRequest( + webRequest.getNativeRequest(HttpServletRequest.class)); + delegate.handleReturnValue( + getEmitter(timeout, flux, mediaType, request.getHeaders()), returnType, mavContainer, webRequest); } + private void addHeaders(NativeWebRequest webRequest, Message message) { + HttpServletResponse response = webRequest + .getNativeResponse(HttpServletResponse.class); + ServletServerHttpRequest request = new ServletServerHttpRequest( + webRequest.getNativeRequest(HttpServletRequest.class)); + HttpHeaders headers = HeaderUtils.fromMessage(message.getHeaders(), + request.getHeaders()); + for (String name : headers.keySet()) { + for (Object object : headers.get(name)) { + response.addHeader(name, object.toString()); + } + } + } + private boolean isInputSingle(NativeWebRequest webRequest, Object handler) { Boolean single = (Boolean) webRequest.getAttribute( WebRequestConstants.INPUT_SINGLE, NativeWebRequest.SCOPE_REQUEST); @@ -218,15 +253,16 @@ public class FluxReturnValueHandler implements AsyncHandlerMethodReturnValueHand } private ResponseBodyEmitter getEmitter(Long timeout, Publisher flux, - MediaType mediaType) { + MediaType mediaType, HttpHeaders request) { Publisher exported = flux instanceof Mono ? Mono.from(flux) : Flux.from(flux).timeout(Duration.ofMillis(timeout), Flux.empty()); if (!MediaType.ALL.equals(mediaType) && EVENT_STREAM.isCompatibleWith(mediaType)) { // TODO: more subtle content negotiation - return new FluxResponseSseEmitter<>(MediaType.APPLICATION_JSON, exported); + return new FluxResponseSseEmitter(request, MediaType.APPLICATION_JSON, + exported); } - return new FluxResponseBodyEmitter<>(mediaType, exported); + return new FluxResponseBodyEmitter(request, mediaType, exported); } } diff --git a/spring-cloud-function-web/src/main/java/org/springframework/cloud/function/web/flux/response/ResponseBodyEmitterSubscriber.java b/spring-cloud-function-web/src/main/java/org/springframework/cloud/function/web/flux/response/ResponseBodyEmitterSubscriber.java index b55b96cfe..3c5961561 100644 --- a/spring-cloud-function-web/src/main/java/org/springframework/cloud/function/web/flux/response/ResponseBodyEmitterSubscriber.java +++ b/spring-cloud-function-web/src/main/java/org/springframework/cloud/function/web/flux/response/ResponseBodyEmitterSubscriber.java @@ -23,7 +23,11 @@ import org.reactivestreams.Publisher; import org.reactivestreams.Subscriber; import org.reactivestreams.Subscription; +import org.springframework.cloud.function.web.util.HeaderUtils; +import org.springframework.http.HttpHeaders; import org.springframework.http.MediaType; +import org.springframework.http.server.ServerHttpResponse; +import org.springframework.messaging.Message; import org.springframework.web.servlet.mvc.method.annotation.ResponseBodyEmitter; import reactor.core.publisher.Flux; @@ -35,7 +39,7 @@ import reactor.core.publisher.Mono; * * @author Dave Syer */ -class ResponseBodyEmitterSubscriber implements Subscriber { +class ResponseBodyEmitterSubscriber implements Subscriber { private final MediaType mediaType; @@ -49,11 +53,17 @@ class ResponseBodyEmitterSubscriber implements Subscriber { private boolean single; - private boolean json; + private final boolean json; - public ResponseBodyEmitterSubscriber(MediaType mediaType, Publisher observable, - ResponseBodyEmitter responseBodyEmitter, boolean json) { + private Message first; + private final HttpHeaders request; + + public ResponseBodyEmitterSubscriber(HttpHeaders request, MediaType mediaType, + Publisher observable, ResponseBodyEmitter responseBodyEmitter, + boolean json) { + + this.request = request; this.mediaType = mediaType; this.responseBodyEmitter = responseBodyEmitter; this.json = json; @@ -63,6 +73,10 @@ class ResponseBodyEmitterSubscriber implements Subscriber { observable.subscribe(this); } + public void extendResponse(ServerHttpResponse response) { + headers(response); + } + @Override public void onSubscribe(Subscription subscription) { this.subscription = subscription; @@ -70,10 +84,16 @@ class ResponseBodyEmitterSubscriber implements Subscriber { } @Override - public void onNext(T value) { + public void onNext(Object value) { Object object = value; + if (object instanceof Message) { + Message message = (Message) object; + object = message.getPayload(); + this.first = message; + } + try { if (isJson()) { if (!this.firstElementWritten) { @@ -85,9 +105,9 @@ class ResponseBodyEmitterSubscriber implements Subscriber { else { responseBodyEmitter.send(","); } - if (!single && value.getClass() == String.class - && !((String) value).contains("\"")) { - object = "\"" + value + "\""; + if (!single && object.getClass() == String.class + && !((String) object).contains("\"")) { + object = "\"" + object + "\""; } } if (!completed) { @@ -101,6 +121,24 @@ class ResponseBodyEmitterSubscriber implements Subscriber { } } + private void headers(ServerHttpResponse response) { + if (this.first != null) { + Message message = first; + try { + HttpHeaders headers = HeaderUtils.fromMessage(message.getHeaders(), + request); + for (String name : headers.keySet()) { + for (String value : headers.get(name)) { + response.getHeaders().add(name, value); + } + } + } + catch (Exception e) { + // Headers could not be set + } + } + } + @Override public void onError(Throwable e) { if (!completed) { @@ -170,4 +208,5 @@ class ResponseBodyEmitterSubscriber implements Subscriber { ResponseBodyEmitterSubscriber.this.subscription.cancel(); } } + } diff --git a/spring-cloud-function-web/src/main/java/org/springframework/cloud/function/web/util/HeaderUtils.java b/spring-cloud-function-web/src/main/java/org/springframework/cloud/function/web/util/HeaderUtils.java new file mode 100644 index 000000000..932859e33 --- /dev/null +++ b/spring-cloud-function-web/src/main/java/org/springframework/cloud/function/web/util/HeaderUtils.java @@ -0,0 +1,84 @@ +/* + * Copyright 2016-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.cloud.function.web.util; + +import java.util.Arrays; +import java.util.Collection; +import java.util.LinkedHashMap; +import java.util.Map; + +import org.springframework.http.HttpHeaders; +import org.springframework.messaging.MessageHeaders; +import org.springframework.util.ObjectUtils; + +/** + * @author Dave Syer + * + */ +public class HeaderUtils { + + public static HttpHeaders fromMessage(MessageHeaders headers, HttpHeaders request) { + HttpHeaders result = new HttpHeaders(); + for (String name : headers.keySet()) { + Object value = headers.get(name); + name = name.toLowerCase(); + if (MessageHeaders.ID.equals(name)) { + continue; + } + if (request.containsKey(name)) { + if (name.startsWith("x-")) { + if (!name.startsWith("x-forwarded")) { + Collection values = multi(value); + for (Object object : values) { + result.set(name, object.toString()); + } + } + } + } + else { + Collection values = multi(value); + for (Object object : values) { + result.set(name, object.toString()); + } + } + } + return result; + } + + private static Collection multi(Object value) { + if (value instanceof Collection) { + Collection collection = (Collection) value; + return collection; + } + else if (ObjectUtils.isArray(value)) { + Object[] values = ObjectUtils.toObjectArray(value); + return Arrays.asList(values); + } + return Arrays.asList(value); + } + + public static MessageHeaders fromHttp(HttpHeaders headers) { + Map map = new LinkedHashMap<>(); + for (String name : headers.keySet()) { + Collection values = multi(headers.get(name)); + name = name.toLowerCase(); + Object value = values == null ? null + : (values.size() == 1 ? values.iterator().next() : values); + map.put(name, value); + } + return new MessageHeaders(map); + } +} diff --git a/spring-cloud-function-web/src/test/java/org/springframework/cloud/function/web/RestApplicationTests.java b/spring-cloud-function-web/src/test/java/org/springframework/cloud/function/web/RestApplicationTests.java index 44e5a420a..b87f90175 100644 --- a/spring-cloud-function-web/src/test/java/org/springframework/cloud/function/web/RestApplicationTests.java +++ b/spring-cloud-function-web/src/test/java/org/springframework/cloud/function/web/RestApplicationTests.java @@ -45,6 +45,8 @@ import org.springframework.http.HttpStatus; import org.springframework.http.MediaType; import org.springframework.http.RequestEntity; import org.springframework.http.ResponseEntity; +import org.springframework.messaging.Message; +import org.springframework.messaging.support.MessageBuilder; import org.springframework.test.context.junit4.SpringRunner; import org.springframework.util.StringUtils; @@ -256,6 +258,26 @@ public class RestApplicationTests { assertThat(result.getBody()).isEqualTo("[\"(FOO)\",\"(BAR)\"]"); } + @Test + public void messages() throws Exception { + ResponseEntity result = rest.exchange(RequestEntity + .post(new URI("/messages")).contentType(MediaType.APPLICATION_JSON) + .header("x-foo", "bar").body("[\"foo\",\"bar\"]"), String.class); + assertThat(result.getBody()).isEqualTo("[\"(FOO)\",\"(BAR)\"]"); + assertThat(result.getHeaders().getFirst("x-foo")).isEqualTo("bar"); + assertThat(result.getHeaders()).doesNotContainKey("id"); + } + + @Test + public void headers() throws Exception { + ResponseEntity result = rest.exchange(RequestEntity + .post(new URI("/headers")).contentType(MediaType.APPLICATION_JSON) + .body("[\"foo\",\"bar\"]"), String.class); + assertThat(result.getBody()).isEqualTo("[\"(FOO)\",\"(BAR)\"]"); + assertThat(result.getHeaders().getFirst("foo")).isEqualTo("bar"); + assertThat(result.getHeaders()).doesNotContainKey("id"); + } + @Test public void uppercaseSingleValue() throws Exception { ResponseEntity result = rest @@ -406,6 +428,20 @@ public class RestApplicationTests { return value -> "(" + value.trim().toUpperCase() + ")"; } + @Bean + public Function, Message> messages() { + return value -> MessageBuilder + .withPayload("(" + value.getPayload().trim().toUpperCase() + ")") + .copyHeaders(value.getHeaders()).build(); + } + + @Bean + public Function>, Flux>> headers() { + return flux -> flux.map(value -> MessageBuilder + .withPayload("(" + value.getPayload().trim().toUpperCase() + ")") + .setHeader("foo", "bar").build()); + } + @Bean public Function, Flux> upFoos() { return flux -> flux.log()