Add HTTP headers to outgoing messages by default

Some care is required to prevent request-specific headers being
reflected and interfering with content negotiation.

Fixes gh-207
This commit is contained in:
Dave Syer
2018-09-12 15:33:24 +01:00
parent ed3a532f96
commit d1c423e161
5 changed files with 77 additions and 20 deletions

View File

@@ -188,7 +188,7 @@ We recommend the http://eclipse.org/m2e/[m2eclipse] eclipse plugin when working
eclipse. If you don't already have m2eclipse installed it is available from the "eclipse
marketplace".
Also, you will need to install Kotlin pug-in from the "eclipse marketplace" to ensure projects
Also, you will need to install Kotlin plug-in from the "eclipse marketplace" to ensure projects
that depend on Kotlin complile under `m2eclipse` plugin (mentioned above).
NOTE: If you still see compile errors after installing Kotlin plugin, simply right-click on the project with error and _remove_ and then _add_ Kotlin Nature via ***Configure Kotlin*** feature.

View File

@@ -31,7 +31,6 @@ import org.springframework.cloud.function.core.FluxConsumer;
import org.springframework.cloud.function.core.FluxFunction;
import org.springframework.cloud.function.core.FluxSupplier;
import org.springframework.util.Assert;
import org.springframework.util.ObjectUtils;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
@@ -51,7 +50,8 @@ public class FunctionRegistration<T> {
private FunctionType type;
/**
* @deprecated as of v1.0.0 in favor of {@link #FunctionRegistration(Object, String...)}
* @deprecated as of v1.0.0 in favor of
* {@link #FunctionRegistration(Object, String...)}
*/
@Deprecated
public FunctionRegistration(T target) {
@@ -59,13 +59,12 @@ public class FunctionRegistration<T> {
this.target = target;
}
/**
* Creates instance of FunctionRegistration.
*
* @param target instance of {@link Supplier}, {@link Function} or {@link Consumer}
* @param names additional set of names for this registration. Additional names
* can be provided {@link #name(String)} or {@link #names(String...)} operations.
* @param names additional set of names for this registration. Additional names can be
* provided {@link #name(String)} or {@link #names(String...)} operations.
*/
public FunctionRegistration(T target, String... names) {
Assert.notNull(target, "'target' must not be null");

View File

@@ -78,14 +78,17 @@ public class RequestProcessor {
public Mono<ResponseEntity<?>> get(FunctionWrapper wrapper) {
if (wrapper.function() != null) {
return response(wrapper.function(), value(wrapper.function(), wrapper.argument()), true, true);
return response(wrapper, wrapper.function(),
value(wrapper.function(), wrapper.argument()), true, true);
}
else {
return response(wrapper.supplier(), wrapper.supplier().get(), null, true);
return response(wrapper, wrapper.supplier(), wrapper.supplier().get(), null,
true);
}
}
public Mono<ResponseEntity<?>> post(FunctionWrapper wrapper, String body, boolean stream) {
public Mono<ResponseEntity<?>> post(FunctionWrapper wrapper, String body,
boolean stream) {
Mono<ResponseEntity<?>> responseEntityMono;
Object function = wrapper.handler();
@@ -117,14 +120,15 @@ public class RequestProcessor {
public Mono<ResponseEntity<?>> stream(FunctionWrapper request) {
Publisher<?> result = request.function() != null
? value(request.function(), request.argument())
: request.supplier().get();
: request.supplier().get();
return stream(request, result);
}
private Mono<ResponseEntity<?>> post(FunctionWrapper wrapper, Object body,
MultiValueMap<String, String> params, boolean stream) {
Iterable<?> iterable = body instanceof Collection ? (List<?>) body : Collections.singletonList(body);
Iterable<?> iterable = body instanceof Collection ? (List<?>) body
: Collections.singletonList(body);
Function<Publisher<?>, Publisher<?>> function = wrapper.function();
Consumer<Publisher<?>> consumer = wrapper.consumer();
@@ -146,13 +150,15 @@ public class RequestProcessor {
responseEntityMono = stream(wrapper, result);
}
else {
responseEntityMono = response(function, result, body == null ? null : !(body instanceof Collection), false);
responseEntityMono = response(wrapper, function, result,
body == null ? null : !(body instanceof Collection), false);
}
}
else if (consumer != null) {
consumer.accept(flux);
logger.debug("Handled POST with consumer");
responseEntityMono = Mono.just(ResponseEntity.status(HttpStatus.ACCEPTED).build());
responseEntityMono = Mono
.just(ResponseEntity.status(HttpStatus.ACCEPTED).build());
}
return responseEntityMono;
}
@@ -175,13 +181,16 @@ public class RequestProcessor {
.map(message -> MessageUtils.unpack(request.handler(), message)
.getPayload());
}
else {
builder.headers(HeaderUtils.sanitize(request.headers()));
}
Publisher<?> output = result;
return Flux.from(output).then(Mono.fromSupplier(() -> builder.body(output)));
}
private Mono<ResponseEntity<?>> response(Object handler, Publisher<?> result,
Boolean single, boolean getter) {
private Mono<ResponseEntity<?>> response(FunctionWrapper request, Object handler,
Publisher<?> result, Boolean single, boolean getter) {
BodyBuilder builder = ResponseEntity.ok();
if (inspector.isMessage(handler)) {
@@ -189,8 +198,12 @@ public class RequestProcessor {
.doOnNext(value -> addHeaders(builder, (Message<?>) value))
.map(message -> MessageUtils.unpack(handler, message).getPayload());
}
else {
builder.headers(HeaderUtils.sanitize(request.headers()));
}
if (isOutputSingle(handler) && (single != null && single || getter || isInputMultiple(handler))) {
if (isOutputSingle(handler)
&& (single != null && single || getter || isInputMultiple(handler))) {
result = Mono.from(result);
}
@@ -213,7 +226,8 @@ public class RequestProcessor {
return false;
}
else {
return wrapper == type || Mono.class.equals(wrapper) || Optional.class.equals(wrapper);
return wrapper == type || Mono.class.equals(wrapper)
|| Optional.class.equals(wrapper);
}
}

View File

@@ -18,6 +18,7 @@ package org.springframework.cloud.function.web.util;
import java.util.Arrays;
import java.util.Collection;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import org.springframework.http.HttpHeaders;
@@ -30,10 +31,17 @@ import org.springframework.messaging.MessageHeaders;
public class HeaderUtils {
private static HttpHeaders IGNORED = new HttpHeaders();
private static HttpHeaders REQUEST_ONLY = new HttpHeaders();
static {
IGNORED.add(MessageHeaders.ID, "");
IGNORED.add(HttpHeaders.CONTENT_LENGTH, "0");
// Headers that would typically be added by a downstream client
REQUEST_ONLY.add(HttpHeaders.ACCEPT, "");
REQUEST_ONLY.add(HttpHeaders.CONTENT_LENGTH, "");
REQUEST_ONLY.add(HttpHeaders.CONTENT_TYPE, "");
REQUEST_ONLY.add(HttpHeaders.HOST, "");
}
public static HttpHeaders fromMessage(MessageHeaders headers, HttpHeaders request) {
@@ -51,6 +59,18 @@ public class HeaderUtils {
return result;
}
public static HttpHeaders sanitize(HttpHeaders request) {
HttpHeaders result = new HttpHeaders();
for (String name : request.keySet()) {
List<String> value = request.get(name);
name = name.toLowerCase();
if (!IGNORED.containsKey(name) && !REQUEST_ONLY.containsKey(name)) {
result.put(name, value);
}
}
return result;
}
private static Collection<?> multi(Object value) {
return value instanceof Collection ? (Collection<?>) value : Arrays.asList(value);
}

View File

@@ -16,6 +16,7 @@
package org.springframework.cloud.function.web.mvc;
import java.net.URI;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.function.Function;
@@ -30,6 +31,8 @@ import org.springframework.boot.test.web.client.TestRestTemplate;
import org.springframework.cloud.function.web.RestApplication;
import org.springframework.cloud.function.web.mvc.HeadersToMessageTests.TestConfiguration;
import org.springframework.context.annotation.Bean;
import org.springframework.http.HttpEntity;
import org.springframework.http.RequestEntity;
import org.springframework.http.ResponseEntity;
import org.springframework.messaging.Message;
import org.springframework.messaging.support.MessageBuilder;
@@ -46,8 +49,9 @@ import static org.junit.Assert.assertTrue;
*/
@RunWith(SpringRunner.class)
@SpringBootTest(webEnvironment = WebEnvironment.RANDOM_PORT, properties = {
"spring.main.web-application-type=servlet", "spring.cloud.function.web.path=/functions" })
@ContextConfiguration(classes= {RestApplication.class, TestConfiguration.class})
"spring.main.web-application-type=servlet",
"spring.cloud.function.web.path=/functions" })
@ContextConfiguration(classes = { RestApplication.class, TestConfiguration.class })
public class HeadersToMessageTests {
@Autowired
@@ -65,6 +69,17 @@ public class HeadersToMessageTests {
assertEquals("bar", postForEntity.getHeaders().get("foo").get(0));
}
@Test
public void testHeadersPropagatedByDefault() throws Exception {
HttpEntity<String> postForEntity = rest.exchange(RequestEntity
.post(new URI("/functions/vanilla")).header("x-context-type", "rubbish")
.body("{\"name\":\"Bob\",\"age\":25}"), String.class);
assertEquals("{\"name\":\"Bob\",\"age\":25,\"foo\":\"bar\"}",
postForEntity.getBody());
assertTrue(postForEntity.getHeaders().containsKey("x-context-type"));
assertEquals("rubbish", postForEntity.getHeaders().get("x-context-type").get(0));
}
@EnableAutoConfiguration
@org.springframework.boot.test.context.TestConfiguration
protected static class TestConfiguration {
@@ -78,5 +93,14 @@ public class HeadersToMessageTests {
return message;
};
}
@Bean
public Function<Map<String, Object>, Map<String, Object>> vanilla() {
return request -> {
Map<String, Object> message = new LinkedHashMap<>(request);
message.put("foo", "bar");
return message;
};
}
}
}