Add support for Message handling in web functions
If the output is a Message we grab the headers from the first one and send those as response headers. A function can add headers to messages. The HTTP response will contain only headers that are in x-* _or_ were added to the message by user (i.e. they weren't in the request). We certainly do not want to pass through all HTTP headers (request headers can conflict with or invalidate response headers).
This commit is contained in:
@@ -31,9 +31,12 @@ import org.apache.commons.logging.LogFactory;
|
||||
|
||||
import org.springframework.cloud.function.context.FunctionInspector;
|
||||
import org.springframework.cloud.function.web.flux.constants.WebRequestConstants;
|
||||
import org.springframework.cloud.function.web.util.HeaderUtils;
|
||||
import org.springframework.core.MethodParameter;
|
||||
import org.springframework.core.Ordered;
|
||||
import org.springframework.http.MediaType;
|
||||
import org.springframework.http.server.ServletServerHttpRequest;
|
||||
import org.springframework.messaging.support.MessageBuilder;
|
||||
import org.springframework.util.StreamUtils;
|
||||
import org.springframework.web.bind.support.WebDataBinderFactory;
|
||||
import org.springframework.web.context.request.NativeWebRequest;
|
||||
@@ -78,6 +81,7 @@ public class FluxHandlerMethodArgumentResolver
|
||||
if (type == null) {
|
||||
type = Object.class;
|
||||
}
|
||||
boolean message = inspector.isMessage(inspector.getName(handler));
|
||||
List<Object> body;
|
||||
ContentCachingRequestWrapper nativeRequest = new ContentCachingRequestWrapper(
|
||||
webRequest.getNativeRequest(HttpServletRequest.class));
|
||||
@@ -100,6 +104,17 @@ public class FluxHandlerMethodArgumentResolver
|
||||
mapper.readValue(nativeRequest.getContentAsByteArray(), type));
|
||||
}
|
||||
}
|
||||
if (message) {
|
||||
List<Object> messages = new ArrayList<>();
|
||||
for (Object payload : body) {
|
||||
messages.add(MessageBuilder.withPayload(payload).copyHeaders(
|
||||
HeaderUtils.fromHttp(new ServletServerHttpRequest(
|
||||
webRequest.getNativeRequest(HttpServletRequest.class))
|
||||
.getHeaders()))
|
||||
.build());
|
||||
}
|
||||
body = messages;
|
||||
}
|
||||
return new FluxRequest<Object>(body);
|
||||
}
|
||||
|
||||
|
||||
@@ -30,25 +30,27 @@ import reactor.core.publisher.Flux;
|
||||
*
|
||||
* @author Dave Syer
|
||||
*/
|
||||
class FluxResponseBodyEmitter<T> extends ResponseBodyEmitter {
|
||||
class FluxResponseBodyEmitter extends ResponseBodyEmitter {
|
||||
|
||||
private final MediaType mediaType;
|
||||
private ResponseBodyEmitterSubscriber subscriber;
|
||||
|
||||
public FluxResponseBodyEmitter(Publisher<T> observable) {
|
||||
this(null, observable);
|
||||
public FluxResponseBodyEmitter(Publisher<?> observable) {
|
||||
this(new HttpHeaders(), null, observable);
|
||||
}
|
||||
|
||||
public FluxResponseBodyEmitter(MediaType mediaType, Publisher<T> observable) {
|
||||
public FluxResponseBodyEmitter(HttpHeaders request, MediaType mediaType,
|
||||
Publisher<?> observable) {
|
||||
super();
|
||||
this.mediaType = mediaType;
|
||||
new ResponseBodyEmitterSubscriber<>(mediaType, observable, this,
|
||||
MediaType.APPLICATION_JSON.isCompatibleWith(mediaType));
|
||||
this.subscriber = new ResponseBodyEmitterSubscriber(request, mediaType,
|
||||
observable, this, MediaType.APPLICATION_JSON.isCompatibleWith(mediaType));
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void extendResponse(ServerHttpResponse outputMessage) {
|
||||
super.extendResponse(outputMessage);
|
||||
|
||||
this.subscriber.extendResponse(outputMessage);
|
||||
HttpHeaders headers = outputMessage.getHeaders();
|
||||
if (headers.getContentType() == null && this.mediaType != null
|
||||
&& !MediaType.ALL.equals(this.mediaType)) {
|
||||
|
||||
@@ -18,7 +18,9 @@ package org.springframework.cloud.function.web.flux.response;
|
||||
|
||||
import org.reactivestreams.Publisher;
|
||||
|
||||
import org.springframework.http.HttpHeaders;
|
||||
import org.springframework.http.MediaType;
|
||||
import org.springframework.http.server.ServerHttpResponse;
|
||||
import org.springframework.web.servlet.mvc.method.annotation.ResponseBodyEmitter;
|
||||
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
|
||||
|
||||
@@ -30,15 +32,24 @@ import reactor.core.publisher.Flux;
|
||||
*
|
||||
* @author Dave Syer
|
||||
*/
|
||||
class FluxResponseSseEmitter<T> extends SseEmitter {
|
||||
class FluxResponseSseEmitter extends SseEmitter {
|
||||
|
||||
public FluxResponseSseEmitter(Publisher<T> observable) {
|
||||
this(MediaType.valueOf("text/plain"), observable);
|
||||
private ResponseBodyEmitterSubscriber subscriber;
|
||||
|
||||
public FluxResponseSseEmitter(Publisher<?> observable) {
|
||||
this(new HttpHeaders(), MediaType.valueOf("text/plain"), observable);
|
||||
}
|
||||
|
||||
public FluxResponseSseEmitter(MediaType mediaType, Publisher<T> observable) {
|
||||
public FluxResponseSseEmitter(HttpHeaders request, MediaType mediaType,
|
||||
Publisher<?> observable) {
|
||||
super();
|
||||
new ResponseBodyEmitterSubscriber<>(mediaType, observable, this, false);
|
||||
this.subscriber = new ResponseBodyEmitterSubscriber(request, mediaType,
|
||||
observable, this, false);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void extendResponse(ServerHttpResponse outputMessage) {
|
||||
super.extendResponse(outputMessage);
|
||||
this.subscriber.extendResponse(outputMessage);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -23,6 +23,7 @@ import java.util.List;
|
||||
import java.util.function.Supplier;
|
||||
import java.util.stream.Stream;
|
||||
|
||||
import javax.servlet.http.HttpServletRequest;
|
||||
import javax.servlet.http.HttpServletResponse;
|
||||
|
||||
import org.apache.commons.logging.Log;
|
||||
@@ -32,11 +33,15 @@ import org.reactivestreams.Publisher;
|
||||
|
||||
import org.springframework.cloud.function.context.FunctionInspector;
|
||||
import org.springframework.cloud.function.web.flux.constants.WebRequestConstants;
|
||||
import org.springframework.cloud.function.web.util.HeaderUtils;
|
||||
import org.springframework.core.MethodParameter;
|
||||
import org.springframework.core.ResolvableType;
|
||||
import org.springframework.http.HttpHeaders;
|
||||
import org.springframework.http.MediaType;
|
||||
import org.springframework.http.ResponseEntity;
|
||||
import org.springframework.http.converter.HttpMessageConverter;
|
||||
import org.springframework.http.server.ServletServerHttpRequest;
|
||||
import org.springframework.messaging.Message;
|
||||
import org.springframework.util.ReflectionUtils;
|
||||
import org.springframework.web.context.request.NativeWebRequest;
|
||||
import org.springframework.web.method.support.AsyncHandlerMethodReturnValueHandler;
|
||||
@@ -130,8 +135,16 @@ public class FluxReturnValueHandler implements AsyncHandlerMethodReturnValueHand
|
||||
if (returnValue instanceof ResponseEntity) {
|
||||
ResponseEntity<?> value = (ResponseEntity<?>) returnValue;
|
||||
adaptFrom = value.getBody();
|
||||
webRequest.getNativeResponse(HttpServletResponse.class)
|
||||
.setStatus(value.getStatusCodeValue());
|
||||
HttpServletResponse response = webRequest
|
||||
.getNativeResponse(HttpServletResponse.class);
|
||||
response.setStatus(value.getStatusCodeValue());
|
||||
HttpHeaders headers = value.getHeaders();
|
||||
for (String name : headers.keySet()) {
|
||||
List<String> list = headers.get(name);
|
||||
for (String header : list) {
|
||||
response.addHeader(name, header);
|
||||
}
|
||||
}
|
||||
}
|
||||
Publisher<?> flux = (Publisher<?>) adaptFrom;
|
||||
|
||||
@@ -141,8 +154,13 @@ public class FluxReturnValueHandler implements AsyncHandlerMethodReturnValueHand
|
||||
|
||||
boolean inputSingle = isInputSingle(webRequest, handler);
|
||||
if (inputSingle && isOutputSingle(handler)) {
|
||||
single.handleReturnValue(Flux.from(flux).blockFirst(), singleReturnType,
|
||||
mavContainer, webRequest);
|
||||
Object result = Flux.from(flux).blockFirst();
|
||||
if (result instanceof Message) {
|
||||
Message<?> message = (Message<?>) result;
|
||||
result = message.getPayload();
|
||||
addHeaders(webRequest, message);
|
||||
}
|
||||
single.handleReturnValue(result, singleReturnType, mavContainer, webRequest);
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -157,10 +175,27 @@ public class FluxReturnValueHandler implements AsyncHandlerMethodReturnValueHand
|
||||
logger.debug(
|
||||
"Handling return value " + type + " with media type: " + mediaType);
|
||||
}
|
||||
delegate.handleReturnValue(getEmitter(timeout, flux, mediaType), returnType,
|
||||
ServletServerHttpRequest request = new ServletServerHttpRequest(
|
||||
webRequest.getNativeRequest(HttpServletRequest.class));
|
||||
delegate.handleReturnValue(
|
||||
getEmitter(timeout, flux, mediaType, request.getHeaders()), returnType,
|
||||
mavContainer, webRequest);
|
||||
}
|
||||
|
||||
private void addHeaders(NativeWebRequest webRequest, Message<?> message) {
|
||||
HttpServletResponse response = webRequest
|
||||
.getNativeResponse(HttpServletResponse.class);
|
||||
ServletServerHttpRequest request = new ServletServerHttpRequest(
|
||||
webRequest.getNativeRequest(HttpServletRequest.class));
|
||||
HttpHeaders headers = HeaderUtils.fromMessage(message.getHeaders(),
|
||||
request.getHeaders());
|
||||
for (String name : headers.keySet()) {
|
||||
for (Object object : headers.get(name)) {
|
||||
response.addHeader(name, object.toString());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private boolean isInputSingle(NativeWebRequest webRequest, Object handler) {
|
||||
Boolean single = (Boolean) webRequest.getAttribute(
|
||||
WebRequestConstants.INPUT_SINGLE, NativeWebRequest.SCOPE_REQUEST);
|
||||
@@ -218,15 +253,16 @@ public class FluxReturnValueHandler implements AsyncHandlerMethodReturnValueHand
|
||||
}
|
||||
|
||||
private ResponseBodyEmitter getEmitter(Long timeout, Publisher<?> flux,
|
||||
MediaType mediaType) {
|
||||
MediaType mediaType, HttpHeaders request) {
|
||||
Publisher<?> exported = flux instanceof Mono ? Mono.from(flux)
|
||||
: Flux.from(flux).timeout(Duration.ofMillis(timeout), Flux.empty());
|
||||
if (!MediaType.ALL.equals(mediaType)
|
||||
&& EVENT_STREAM.isCompatibleWith(mediaType)) {
|
||||
// TODO: more subtle content negotiation
|
||||
return new FluxResponseSseEmitter<>(MediaType.APPLICATION_JSON, exported);
|
||||
return new FluxResponseSseEmitter(request, MediaType.APPLICATION_JSON,
|
||||
exported);
|
||||
}
|
||||
return new FluxResponseBodyEmitter<>(mediaType, exported);
|
||||
return new FluxResponseBodyEmitter(request, mediaType, exported);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -23,7 +23,11 @@ import org.reactivestreams.Publisher;
|
||||
import org.reactivestreams.Subscriber;
|
||||
import org.reactivestreams.Subscription;
|
||||
|
||||
import org.springframework.cloud.function.web.util.HeaderUtils;
|
||||
import org.springframework.http.HttpHeaders;
|
||||
import org.springframework.http.MediaType;
|
||||
import org.springframework.http.server.ServerHttpResponse;
|
||||
import org.springframework.messaging.Message;
|
||||
import org.springframework.web.servlet.mvc.method.annotation.ResponseBodyEmitter;
|
||||
|
||||
import reactor.core.publisher.Flux;
|
||||
@@ -35,7 +39,7 @@ import reactor.core.publisher.Mono;
|
||||
*
|
||||
* @author Dave Syer
|
||||
*/
|
||||
class ResponseBodyEmitterSubscriber<T> implements Subscriber<T> {
|
||||
class ResponseBodyEmitterSubscriber implements Subscriber<Object> {
|
||||
|
||||
private final MediaType mediaType;
|
||||
|
||||
@@ -49,11 +53,17 @@ class ResponseBodyEmitterSubscriber<T> implements Subscriber<T> {
|
||||
|
||||
private boolean single;
|
||||
|
||||
private boolean json;
|
||||
private final boolean json;
|
||||
|
||||
public ResponseBodyEmitterSubscriber(MediaType mediaType, Publisher<T> observable,
|
||||
ResponseBodyEmitter responseBodyEmitter, boolean json) {
|
||||
private Message<?> first;
|
||||
|
||||
private final HttpHeaders request;
|
||||
|
||||
public ResponseBodyEmitterSubscriber(HttpHeaders request, MediaType mediaType,
|
||||
Publisher<?> observable, ResponseBodyEmitter responseBodyEmitter,
|
||||
boolean json) {
|
||||
|
||||
this.request = request;
|
||||
this.mediaType = mediaType;
|
||||
this.responseBodyEmitter = responseBodyEmitter;
|
||||
this.json = json;
|
||||
@@ -63,6 +73,10 @@ class ResponseBodyEmitterSubscriber<T> implements Subscriber<T> {
|
||||
observable.subscribe(this);
|
||||
}
|
||||
|
||||
public void extendResponse(ServerHttpResponse response) {
|
||||
headers(response);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onSubscribe(Subscription subscription) {
|
||||
this.subscription = subscription;
|
||||
@@ -70,10 +84,16 @@ class ResponseBodyEmitterSubscriber<T> implements Subscriber<T> {
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onNext(T value) {
|
||||
public void onNext(Object value) {
|
||||
|
||||
Object object = value;
|
||||
|
||||
if (object instanceof Message) {
|
||||
Message<?> message = (Message<?>) object;
|
||||
object = message.getPayload();
|
||||
this.first = message;
|
||||
}
|
||||
|
||||
try {
|
||||
if (isJson()) {
|
||||
if (!this.firstElementWritten) {
|
||||
@@ -85,9 +105,9 @@ class ResponseBodyEmitterSubscriber<T> implements Subscriber<T> {
|
||||
else {
|
||||
responseBodyEmitter.send(",");
|
||||
}
|
||||
if (!single && value.getClass() == String.class
|
||||
&& !((String) value).contains("\"")) {
|
||||
object = "\"" + value + "\"";
|
||||
if (!single && object.getClass() == String.class
|
||||
&& !((String) object).contains("\"")) {
|
||||
object = "\"" + object + "\"";
|
||||
}
|
||||
}
|
||||
if (!completed) {
|
||||
@@ -101,6 +121,24 @@ class ResponseBodyEmitterSubscriber<T> implements Subscriber<T> {
|
||||
}
|
||||
}
|
||||
|
||||
private void headers(ServerHttpResponse response) {
|
||||
if (this.first != null) {
|
||||
Message<?> message = first;
|
||||
try {
|
||||
HttpHeaders headers = HeaderUtils.fromMessage(message.getHeaders(),
|
||||
request);
|
||||
for (String name : headers.keySet()) {
|
||||
for (String value : headers.get(name)) {
|
||||
response.getHeaders().add(name, value);
|
||||
}
|
||||
}
|
||||
}
|
||||
catch (Exception e) {
|
||||
// Headers could not be set
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onError(Throwable e) {
|
||||
if (!completed) {
|
||||
@@ -170,4 +208,5 @@ class ResponseBodyEmitterSubscriber<T> implements Subscriber<T> {
|
||||
ResponseBodyEmitterSubscriber.this.subscription.cancel();
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -0,0 +1,84 @@
|
||||
/*
|
||||
* Copyright 2016-2017 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.springframework.cloud.function.web.util;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.Collection;
|
||||
import java.util.LinkedHashMap;
|
||||
import java.util.Map;
|
||||
|
||||
import org.springframework.http.HttpHeaders;
|
||||
import org.springframework.messaging.MessageHeaders;
|
||||
import org.springframework.util.ObjectUtils;
|
||||
|
||||
/**
|
||||
* @author Dave Syer
|
||||
*
|
||||
*/
|
||||
public class HeaderUtils {
|
||||
|
||||
public static HttpHeaders fromMessage(MessageHeaders headers, HttpHeaders request) {
|
||||
HttpHeaders result = new HttpHeaders();
|
||||
for (String name : headers.keySet()) {
|
||||
Object value = headers.get(name);
|
||||
name = name.toLowerCase();
|
||||
if (MessageHeaders.ID.equals(name)) {
|
||||
continue;
|
||||
}
|
||||
if (request.containsKey(name)) {
|
||||
if (name.startsWith("x-")) {
|
||||
if (!name.startsWith("x-forwarded")) {
|
||||
Collection<?> values = multi(value);
|
||||
for (Object object : values) {
|
||||
result.set(name, object.toString());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
else {
|
||||
Collection<?> values = multi(value);
|
||||
for (Object object : values) {
|
||||
result.set(name, object.toString());
|
||||
}
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
private static Collection<?> multi(Object value) {
|
||||
if (value instanceof Collection) {
|
||||
Collection<?> collection = (Collection<?>) value;
|
||||
return collection;
|
||||
}
|
||||
else if (ObjectUtils.isArray(value)) {
|
||||
Object[] values = ObjectUtils.toObjectArray(value);
|
||||
return Arrays.asList(values);
|
||||
}
|
||||
return Arrays.asList(value);
|
||||
}
|
||||
|
||||
public static MessageHeaders fromHttp(HttpHeaders headers) {
|
||||
Map<String, Object> map = new LinkedHashMap<>();
|
||||
for (String name : headers.keySet()) {
|
||||
Collection<?> values = multi(headers.get(name));
|
||||
name = name.toLowerCase();
|
||||
Object value = values == null ? null
|
||||
: (values.size() == 1 ? values.iterator().next() : values);
|
||||
map.put(name, value);
|
||||
}
|
||||
return new MessageHeaders(map);
|
||||
}
|
||||
}
|
||||
@@ -45,6 +45,8 @@ import org.springframework.http.HttpStatus;
|
||||
import org.springframework.http.MediaType;
|
||||
import org.springframework.http.RequestEntity;
|
||||
import org.springframework.http.ResponseEntity;
|
||||
import org.springframework.messaging.Message;
|
||||
import org.springframework.messaging.support.MessageBuilder;
|
||||
import org.springframework.test.context.junit4.SpringRunner;
|
||||
import org.springframework.util.StringUtils;
|
||||
|
||||
@@ -256,6 +258,26 @@ public class RestApplicationTests {
|
||||
assertThat(result.getBody()).isEqualTo("[\"(FOO)\",\"(BAR)\"]");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void messages() throws Exception {
|
||||
ResponseEntity<String> result = rest.exchange(RequestEntity
|
||||
.post(new URI("/messages")).contentType(MediaType.APPLICATION_JSON)
|
||||
.header("x-foo", "bar").body("[\"foo\",\"bar\"]"), String.class);
|
||||
assertThat(result.getBody()).isEqualTo("[\"(FOO)\",\"(BAR)\"]");
|
||||
assertThat(result.getHeaders().getFirst("x-foo")).isEqualTo("bar");
|
||||
assertThat(result.getHeaders()).doesNotContainKey("id");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void headers() throws Exception {
|
||||
ResponseEntity<String> result = rest.exchange(RequestEntity
|
||||
.post(new URI("/headers")).contentType(MediaType.APPLICATION_JSON)
|
||||
.body("[\"foo\",\"bar\"]"), String.class);
|
||||
assertThat(result.getBody()).isEqualTo("[\"(FOO)\",\"(BAR)\"]");
|
||||
assertThat(result.getHeaders().getFirst("foo")).isEqualTo("bar");
|
||||
assertThat(result.getHeaders()).doesNotContainKey("id");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void uppercaseSingleValue() throws Exception {
|
||||
ResponseEntity<String> result = rest
|
||||
@@ -406,6 +428,20 @@ public class RestApplicationTests {
|
||||
return value -> "(" + value.trim().toUpperCase() + ")";
|
||||
}
|
||||
|
||||
@Bean
|
||||
public Function<Message<String>, Message<String>> messages() {
|
||||
return value -> MessageBuilder
|
||||
.withPayload("(" + value.getPayload().trim().toUpperCase() + ")")
|
||||
.copyHeaders(value.getHeaders()).build();
|
||||
}
|
||||
|
||||
@Bean
|
||||
public Function<Flux<Message<String>>, Flux<Message<String>>> headers() {
|
||||
return flux -> flux.map(value -> MessageBuilder
|
||||
.withPayload("(" + value.getPayload().trim().toUpperCase() + ")")
|
||||
.setHeader("foo", "bar").build());
|
||||
}
|
||||
|
||||
@Bean
|
||||
public Function<Flux<Foo>, Flux<Foo>> upFoos() {
|
||||
return flux -> flux.log()
|
||||
|
||||
Reference in New Issue
Block a user