Add support for Message handling in web functions

If the output is a Message we grab the headers from the first one
and send those as response headers.

A function can add headers to messages. The HTTP response will
contain only headers that are in x-* _or_ were added to the message
by user (i.e. they weren't in the request).
We certainly do not want to pass through all HTTP headers (request
headers can conflict with or invalidate response headers).
This commit is contained in:
Dave Syer
2017-06-15 15:37:55 +01:00
parent 1d53cd1234
commit b61562508b
7 changed files with 251 additions and 28 deletions

View File

@@ -31,9 +31,12 @@ import org.apache.commons.logging.LogFactory;
import org.springframework.cloud.function.context.FunctionInspector;
import org.springframework.cloud.function.web.flux.constants.WebRequestConstants;
import org.springframework.cloud.function.web.util.HeaderUtils;
import org.springframework.core.MethodParameter;
import org.springframework.core.Ordered;
import org.springframework.http.MediaType;
import org.springframework.http.server.ServletServerHttpRequest;
import org.springframework.messaging.support.MessageBuilder;
import org.springframework.util.StreamUtils;
import org.springframework.web.bind.support.WebDataBinderFactory;
import org.springframework.web.context.request.NativeWebRequest;
@@ -78,6 +81,7 @@ public class FluxHandlerMethodArgumentResolver
if (type == null) {
type = Object.class;
}
boolean message = inspector.isMessage(inspector.getName(handler));
List<Object> body;
ContentCachingRequestWrapper nativeRequest = new ContentCachingRequestWrapper(
webRequest.getNativeRequest(HttpServletRequest.class));
@@ -100,6 +104,17 @@ public class FluxHandlerMethodArgumentResolver
mapper.readValue(nativeRequest.getContentAsByteArray(), type));
}
}
if (message) {
List<Object> messages = new ArrayList<>();
for (Object payload : body) {
messages.add(MessageBuilder.withPayload(payload).copyHeaders(
HeaderUtils.fromHttp(new ServletServerHttpRequest(
webRequest.getNativeRequest(HttpServletRequest.class))
.getHeaders()))
.build());
}
body = messages;
}
return new FluxRequest<Object>(body);
}

View File

@@ -30,25 +30,27 @@ import reactor.core.publisher.Flux;
*
* @author Dave Syer
*/
class FluxResponseBodyEmitter<T> extends ResponseBodyEmitter {
class FluxResponseBodyEmitter extends ResponseBodyEmitter {
private final MediaType mediaType;
private ResponseBodyEmitterSubscriber subscriber;
public FluxResponseBodyEmitter(Publisher<T> observable) {
this(null, observable);
public FluxResponseBodyEmitter(Publisher<?> observable) {
this(new HttpHeaders(), null, observable);
}
public FluxResponseBodyEmitter(MediaType mediaType, Publisher<T> observable) {
public FluxResponseBodyEmitter(HttpHeaders request, MediaType mediaType,
Publisher<?> observable) {
super();
this.mediaType = mediaType;
new ResponseBodyEmitterSubscriber<>(mediaType, observable, this,
MediaType.APPLICATION_JSON.isCompatibleWith(mediaType));
this.subscriber = new ResponseBodyEmitterSubscriber(request, mediaType,
observable, this, MediaType.APPLICATION_JSON.isCompatibleWith(mediaType));
}
@Override
protected void extendResponse(ServerHttpResponse outputMessage) {
super.extendResponse(outputMessage);
this.subscriber.extendResponse(outputMessage);
HttpHeaders headers = outputMessage.getHeaders();
if (headers.getContentType() == null && this.mediaType != null
&& !MediaType.ALL.equals(this.mediaType)) {

View File

@@ -18,7 +18,9 @@ package org.springframework.cloud.function.web.flux.response;
import org.reactivestreams.Publisher;
import org.springframework.http.HttpHeaders;
import org.springframework.http.MediaType;
import org.springframework.http.server.ServerHttpResponse;
import org.springframework.web.servlet.mvc.method.annotation.ResponseBodyEmitter;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
@@ -30,15 +32,24 @@ import reactor.core.publisher.Flux;
*
* @author Dave Syer
*/
class FluxResponseSseEmitter<T> extends SseEmitter {
class FluxResponseSseEmitter extends SseEmitter {
public FluxResponseSseEmitter(Publisher<T> observable) {
this(MediaType.valueOf("text/plain"), observable);
private ResponseBodyEmitterSubscriber subscriber;
public FluxResponseSseEmitter(Publisher<?> observable) {
this(new HttpHeaders(), MediaType.valueOf("text/plain"), observable);
}
public FluxResponseSseEmitter(MediaType mediaType, Publisher<T> observable) {
public FluxResponseSseEmitter(HttpHeaders request, MediaType mediaType,
Publisher<?> observable) {
super();
new ResponseBodyEmitterSubscriber<>(mediaType, observable, this, false);
this.subscriber = new ResponseBodyEmitterSubscriber(request, mediaType,
observable, this, false);
}
@Override
protected void extendResponse(ServerHttpResponse outputMessage) {
super.extendResponse(outputMessage);
this.subscriber.extendResponse(outputMessage);
}
}

View File

@@ -23,6 +23,7 @@ import java.util.List;
import java.util.function.Supplier;
import java.util.stream.Stream;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import org.apache.commons.logging.Log;
@@ -32,11 +33,15 @@ import org.reactivestreams.Publisher;
import org.springframework.cloud.function.context.FunctionInspector;
import org.springframework.cloud.function.web.flux.constants.WebRequestConstants;
import org.springframework.cloud.function.web.util.HeaderUtils;
import org.springframework.core.MethodParameter;
import org.springframework.core.ResolvableType;
import org.springframework.http.HttpHeaders;
import org.springframework.http.MediaType;
import org.springframework.http.ResponseEntity;
import org.springframework.http.converter.HttpMessageConverter;
import org.springframework.http.server.ServletServerHttpRequest;
import org.springframework.messaging.Message;
import org.springframework.util.ReflectionUtils;
import org.springframework.web.context.request.NativeWebRequest;
import org.springframework.web.method.support.AsyncHandlerMethodReturnValueHandler;
@@ -130,8 +135,16 @@ public class FluxReturnValueHandler implements AsyncHandlerMethodReturnValueHand
if (returnValue instanceof ResponseEntity) {
ResponseEntity<?> value = (ResponseEntity<?>) returnValue;
adaptFrom = value.getBody();
webRequest.getNativeResponse(HttpServletResponse.class)
.setStatus(value.getStatusCodeValue());
HttpServletResponse response = webRequest
.getNativeResponse(HttpServletResponse.class);
response.setStatus(value.getStatusCodeValue());
HttpHeaders headers = value.getHeaders();
for (String name : headers.keySet()) {
List<String> list = headers.get(name);
for (String header : list) {
response.addHeader(name, header);
}
}
}
Publisher<?> flux = (Publisher<?>) adaptFrom;
@@ -141,8 +154,13 @@ public class FluxReturnValueHandler implements AsyncHandlerMethodReturnValueHand
boolean inputSingle = isInputSingle(webRequest, handler);
if (inputSingle && isOutputSingle(handler)) {
single.handleReturnValue(Flux.from(flux).blockFirst(), singleReturnType,
mavContainer, webRequest);
Object result = Flux.from(flux).blockFirst();
if (result instanceof Message) {
Message<?> message = (Message<?>) result;
result = message.getPayload();
addHeaders(webRequest, message);
}
single.handleReturnValue(result, singleReturnType, mavContainer, webRequest);
return;
}
@@ -157,10 +175,27 @@ public class FluxReturnValueHandler implements AsyncHandlerMethodReturnValueHand
logger.debug(
"Handling return value " + type + " with media type: " + mediaType);
}
delegate.handleReturnValue(getEmitter(timeout, flux, mediaType), returnType,
ServletServerHttpRequest request = new ServletServerHttpRequest(
webRequest.getNativeRequest(HttpServletRequest.class));
delegate.handleReturnValue(
getEmitter(timeout, flux, mediaType, request.getHeaders()), returnType,
mavContainer, webRequest);
}
private void addHeaders(NativeWebRequest webRequest, Message<?> message) {
HttpServletResponse response = webRequest
.getNativeResponse(HttpServletResponse.class);
ServletServerHttpRequest request = new ServletServerHttpRequest(
webRequest.getNativeRequest(HttpServletRequest.class));
HttpHeaders headers = HeaderUtils.fromMessage(message.getHeaders(),
request.getHeaders());
for (String name : headers.keySet()) {
for (Object object : headers.get(name)) {
response.addHeader(name, object.toString());
}
}
}
private boolean isInputSingle(NativeWebRequest webRequest, Object handler) {
Boolean single = (Boolean) webRequest.getAttribute(
WebRequestConstants.INPUT_SINGLE, NativeWebRequest.SCOPE_REQUEST);
@@ -218,15 +253,16 @@ public class FluxReturnValueHandler implements AsyncHandlerMethodReturnValueHand
}
private ResponseBodyEmitter getEmitter(Long timeout, Publisher<?> flux,
MediaType mediaType) {
MediaType mediaType, HttpHeaders request) {
Publisher<?> exported = flux instanceof Mono ? Mono.from(flux)
: Flux.from(flux).timeout(Duration.ofMillis(timeout), Flux.empty());
if (!MediaType.ALL.equals(mediaType)
&& EVENT_STREAM.isCompatibleWith(mediaType)) {
// TODO: more subtle content negotiation
return new FluxResponseSseEmitter<>(MediaType.APPLICATION_JSON, exported);
return new FluxResponseSseEmitter(request, MediaType.APPLICATION_JSON,
exported);
}
return new FluxResponseBodyEmitter<>(mediaType, exported);
return new FluxResponseBodyEmitter(request, mediaType, exported);
}
}

View File

@@ -23,7 +23,11 @@ import org.reactivestreams.Publisher;
import org.reactivestreams.Subscriber;
import org.reactivestreams.Subscription;
import org.springframework.cloud.function.web.util.HeaderUtils;
import org.springframework.http.HttpHeaders;
import org.springframework.http.MediaType;
import org.springframework.http.server.ServerHttpResponse;
import org.springframework.messaging.Message;
import org.springframework.web.servlet.mvc.method.annotation.ResponseBodyEmitter;
import reactor.core.publisher.Flux;
@@ -35,7 +39,7 @@ import reactor.core.publisher.Mono;
*
* @author Dave Syer
*/
class ResponseBodyEmitterSubscriber<T> implements Subscriber<T> {
class ResponseBodyEmitterSubscriber implements Subscriber<Object> {
private final MediaType mediaType;
@@ -49,11 +53,17 @@ class ResponseBodyEmitterSubscriber<T> implements Subscriber<T> {
private boolean single;
private boolean json;
private final boolean json;
public ResponseBodyEmitterSubscriber(MediaType mediaType, Publisher<T> observable,
ResponseBodyEmitter responseBodyEmitter, boolean json) {
private Message<?> first;
private final HttpHeaders request;
public ResponseBodyEmitterSubscriber(HttpHeaders request, MediaType mediaType,
Publisher<?> observable, ResponseBodyEmitter responseBodyEmitter,
boolean json) {
this.request = request;
this.mediaType = mediaType;
this.responseBodyEmitter = responseBodyEmitter;
this.json = json;
@@ -63,6 +73,10 @@ class ResponseBodyEmitterSubscriber<T> implements Subscriber<T> {
observable.subscribe(this);
}
public void extendResponse(ServerHttpResponse response) {
headers(response);
}
@Override
public void onSubscribe(Subscription subscription) {
this.subscription = subscription;
@@ -70,10 +84,16 @@ class ResponseBodyEmitterSubscriber<T> implements Subscriber<T> {
}
@Override
public void onNext(T value) {
public void onNext(Object value) {
Object object = value;
if (object instanceof Message) {
Message<?> message = (Message<?>) object;
object = message.getPayload();
this.first = message;
}
try {
if (isJson()) {
if (!this.firstElementWritten) {
@@ -85,9 +105,9 @@ class ResponseBodyEmitterSubscriber<T> implements Subscriber<T> {
else {
responseBodyEmitter.send(",");
}
if (!single && value.getClass() == String.class
&& !((String) value).contains("\"")) {
object = "\"" + value + "\"";
if (!single && object.getClass() == String.class
&& !((String) object).contains("\"")) {
object = "\"" + object + "\"";
}
}
if (!completed) {
@@ -101,6 +121,24 @@ class ResponseBodyEmitterSubscriber<T> implements Subscriber<T> {
}
}
private void headers(ServerHttpResponse response) {
if (this.first != null) {
Message<?> message = first;
try {
HttpHeaders headers = HeaderUtils.fromMessage(message.getHeaders(),
request);
for (String name : headers.keySet()) {
for (String value : headers.get(name)) {
response.getHeaders().add(name, value);
}
}
}
catch (Exception e) {
// Headers could not be set
}
}
}
@Override
public void onError(Throwable e) {
if (!completed) {
@@ -170,4 +208,5 @@ class ResponseBodyEmitterSubscriber<T> implements Subscriber<T> {
ResponseBodyEmitterSubscriber.this.subscription.cancel();
}
}
}

View File

@@ -0,0 +1,84 @@
/*
* Copyright 2016-2017 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.cloud.function.web.util;
import java.util.Arrays;
import java.util.Collection;
import java.util.LinkedHashMap;
import java.util.Map;
import org.springframework.http.HttpHeaders;
import org.springframework.messaging.MessageHeaders;
import org.springframework.util.ObjectUtils;
/**
* @author Dave Syer
*
*/
public class HeaderUtils {
public static HttpHeaders fromMessage(MessageHeaders headers, HttpHeaders request) {
HttpHeaders result = new HttpHeaders();
for (String name : headers.keySet()) {
Object value = headers.get(name);
name = name.toLowerCase();
if (MessageHeaders.ID.equals(name)) {
continue;
}
if (request.containsKey(name)) {
if (name.startsWith("x-")) {
if (!name.startsWith("x-forwarded")) {
Collection<?> values = multi(value);
for (Object object : values) {
result.set(name, object.toString());
}
}
}
}
else {
Collection<?> values = multi(value);
for (Object object : values) {
result.set(name, object.toString());
}
}
}
return result;
}
private static Collection<?> multi(Object value) {
if (value instanceof Collection) {
Collection<?> collection = (Collection<?>) value;
return collection;
}
else if (ObjectUtils.isArray(value)) {
Object[] values = ObjectUtils.toObjectArray(value);
return Arrays.asList(values);
}
return Arrays.asList(value);
}
public static MessageHeaders fromHttp(HttpHeaders headers) {
Map<String, Object> map = new LinkedHashMap<>();
for (String name : headers.keySet()) {
Collection<?> values = multi(headers.get(name));
name = name.toLowerCase();
Object value = values == null ? null
: (values.size() == 1 ? values.iterator().next() : values);
map.put(name, value);
}
return new MessageHeaders(map);
}
}

View File

@@ -45,6 +45,8 @@ import org.springframework.http.HttpStatus;
import org.springframework.http.MediaType;
import org.springframework.http.RequestEntity;
import org.springframework.http.ResponseEntity;
import org.springframework.messaging.Message;
import org.springframework.messaging.support.MessageBuilder;
import org.springframework.test.context.junit4.SpringRunner;
import org.springframework.util.StringUtils;
@@ -256,6 +258,26 @@ public class RestApplicationTests {
assertThat(result.getBody()).isEqualTo("[\"(FOO)\",\"(BAR)\"]");
}
@Test
public void messages() throws Exception {
ResponseEntity<String> result = rest.exchange(RequestEntity
.post(new URI("/messages")).contentType(MediaType.APPLICATION_JSON)
.header("x-foo", "bar").body("[\"foo\",\"bar\"]"), String.class);
assertThat(result.getBody()).isEqualTo("[\"(FOO)\",\"(BAR)\"]");
assertThat(result.getHeaders().getFirst("x-foo")).isEqualTo("bar");
assertThat(result.getHeaders()).doesNotContainKey("id");
}
@Test
public void headers() throws Exception {
ResponseEntity<String> result = rest.exchange(RequestEntity
.post(new URI("/headers")).contentType(MediaType.APPLICATION_JSON)
.body("[\"foo\",\"bar\"]"), String.class);
assertThat(result.getBody()).isEqualTo("[\"(FOO)\",\"(BAR)\"]");
assertThat(result.getHeaders().getFirst("foo")).isEqualTo("bar");
assertThat(result.getHeaders()).doesNotContainKey("id");
}
@Test
public void uppercaseSingleValue() throws Exception {
ResponseEntity<String> result = rest
@@ -406,6 +428,20 @@ public class RestApplicationTests {
return value -> "(" + value.trim().toUpperCase() + ")";
}
@Bean
public Function<Message<String>, Message<String>> messages() {
return value -> MessageBuilder
.withPayload("(" + value.getPayload().trim().toUpperCase() + ")")
.copyHeaders(value.getHeaders()).build();
}
@Bean
public Function<Flux<Message<String>>, Flux<Message<String>>> headers() {
return flux -> flux.map(value -> MessageBuilder
.withPayload("(" + value.getPayload().trim().toUpperCase() + ")")
.setHeader("foo", "bar").build());
}
@Bean
public Function<Flux<Foo>, Flux<Foo>> upFoos() {
return flux -> flux.log()