Add HTTP headers to outgoing messages by default

Some care is required to prevent request-specific headers being
reflected and interfering with content negotiation.

Fixes gh-207
This commit is contained in:
Dave Syer
2018-09-12 15:33:24 +01:00
parent ed3a532f96
commit d1c423e161
5 changed files with 77 additions and 20 deletions

View File

@@ -78,14 +78,17 @@ public class RequestProcessor {
public Mono<ResponseEntity<?>> get(FunctionWrapper wrapper) {
if (wrapper.function() != null) {
return response(wrapper.function(), value(wrapper.function(), wrapper.argument()), true, true);
return response(wrapper, wrapper.function(),
value(wrapper.function(), wrapper.argument()), true, true);
}
else {
return response(wrapper.supplier(), wrapper.supplier().get(), null, true);
return response(wrapper, wrapper.supplier(), wrapper.supplier().get(), null,
true);
}
}
public Mono<ResponseEntity<?>> post(FunctionWrapper wrapper, String body, boolean stream) {
public Mono<ResponseEntity<?>> post(FunctionWrapper wrapper, String body,
boolean stream) {
Mono<ResponseEntity<?>> responseEntityMono;
Object function = wrapper.handler();
@@ -117,14 +120,15 @@ public class RequestProcessor {
public Mono<ResponseEntity<?>> stream(FunctionWrapper request) {
Publisher<?> result = request.function() != null
? value(request.function(), request.argument())
: request.supplier().get();
: request.supplier().get();
return stream(request, result);
}
private Mono<ResponseEntity<?>> post(FunctionWrapper wrapper, Object body,
MultiValueMap<String, String> params, boolean stream) {
Iterable<?> iterable = body instanceof Collection ? (List<?>) body : Collections.singletonList(body);
Iterable<?> iterable = body instanceof Collection ? (List<?>) body
: Collections.singletonList(body);
Function<Publisher<?>, Publisher<?>> function = wrapper.function();
Consumer<Publisher<?>> consumer = wrapper.consumer();
@@ -146,13 +150,15 @@ public class RequestProcessor {
responseEntityMono = stream(wrapper, result);
}
else {
responseEntityMono = response(function, result, body == null ? null : !(body instanceof Collection), false);
responseEntityMono = response(wrapper, function, result,
body == null ? null : !(body instanceof Collection), false);
}
}
else if (consumer != null) {
consumer.accept(flux);
logger.debug("Handled POST with consumer");
responseEntityMono = Mono.just(ResponseEntity.status(HttpStatus.ACCEPTED).build());
responseEntityMono = Mono
.just(ResponseEntity.status(HttpStatus.ACCEPTED).build());
}
return responseEntityMono;
}
@@ -175,13 +181,16 @@ public class RequestProcessor {
.map(message -> MessageUtils.unpack(request.handler(), message)
.getPayload());
}
else {
builder.headers(HeaderUtils.sanitize(request.headers()));
}
Publisher<?> output = result;
return Flux.from(output).then(Mono.fromSupplier(() -> builder.body(output)));
}
private Mono<ResponseEntity<?>> response(Object handler, Publisher<?> result,
Boolean single, boolean getter) {
private Mono<ResponseEntity<?>> response(FunctionWrapper request, Object handler,
Publisher<?> result, Boolean single, boolean getter) {
BodyBuilder builder = ResponseEntity.ok();
if (inspector.isMessage(handler)) {
@@ -189,8 +198,12 @@ public class RequestProcessor {
.doOnNext(value -> addHeaders(builder, (Message<?>) value))
.map(message -> MessageUtils.unpack(handler, message).getPayload());
}
else {
builder.headers(HeaderUtils.sanitize(request.headers()));
}
if (isOutputSingle(handler) && (single != null && single || getter || isInputMultiple(handler))) {
if (isOutputSingle(handler)
&& (single != null && single || getter || isInputMultiple(handler))) {
result = Mono.from(result);
}
@@ -213,7 +226,8 @@ public class RequestProcessor {
return false;
}
else {
return wrapper == type || Mono.class.equals(wrapper) || Optional.class.equals(wrapper);
return wrapper == type || Mono.class.equals(wrapper)
|| Optional.class.equals(wrapper);
}
}

View File

@@ -18,6 +18,7 @@ package org.springframework.cloud.function.web.util;
import java.util.Arrays;
import java.util.Collection;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import org.springframework.http.HttpHeaders;
@@ -30,10 +31,17 @@ import org.springframework.messaging.MessageHeaders;
public class HeaderUtils {
private static HttpHeaders IGNORED = new HttpHeaders();
private static HttpHeaders REQUEST_ONLY = new HttpHeaders();
static {
IGNORED.add(MessageHeaders.ID, "");
IGNORED.add(HttpHeaders.CONTENT_LENGTH, "0");
// Headers that would typically be added by a downstream client
REQUEST_ONLY.add(HttpHeaders.ACCEPT, "");
REQUEST_ONLY.add(HttpHeaders.CONTENT_LENGTH, "");
REQUEST_ONLY.add(HttpHeaders.CONTENT_TYPE, "");
REQUEST_ONLY.add(HttpHeaders.HOST, "");
}
public static HttpHeaders fromMessage(MessageHeaders headers, HttpHeaders request) {
@@ -51,6 +59,18 @@ public class HeaderUtils {
return result;
}
public static HttpHeaders sanitize(HttpHeaders request) {
HttpHeaders result = new HttpHeaders();
for (String name : request.keySet()) {
List<String> value = request.get(name);
name = name.toLowerCase();
if (!IGNORED.containsKey(name) && !REQUEST_ONLY.containsKey(name)) {
result.put(name, value);
}
}
return result;
}
private static Collection<?> multi(Object value) {
return value instanceof Collection ? (Collection<?>) value : Arrays.asList(value);
}

View File

@@ -16,6 +16,7 @@
package org.springframework.cloud.function.web.mvc;
import java.net.URI;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.function.Function;
@@ -30,6 +31,8 @@ import org.springframework.boot.test.web.client.TestRestTemplate;
import org.springframework.cloud.function.web.RestApplication;
import org.springframework.cloud.function.web.mvc.HeadersToMessageTests.TestConfiguration;
import org.springframework.context.annotation.Bean;
import org.springframework.http.HttpEntity;
import org.springframework.http.RequestEntity;
import org.springframework.http.ResponseEntity;
import org.springframework.messaging.Message;
import org.springframework.messaging.support.MessageBuilder;
@@ -46,8 +49,9 @@ import static org.junit.Assert.assertTrue;
*/
@RunWith(SpringRunner.class)
@SpringBootTest(webEnvironment = WebEnvironment.RANDOM_PORT, properties = {
"spring.main.web-application-type=servlet", "spring.cloud.function.web.path=/functions" })
@ContextConfiguration(classes= {RestApplication.class, TestConfiguration.class})
"spring.main.web-application-type=servlet",
"spring.cloud.function.web.path=/functions" })
@ContextConfiguration(classes = { RestApplication.class, TestConfiguration.class })
public class HeadersToMessageTests {
@Autowired
@@ -65,6 +69,17 @@ public class HeadersToMessageTests {
assertEquals("bar", postForEntity.getHeaders().get("foo").get(0));
}
@Test
public void testHeadersPropagatedByDefault() throws Exception {
HttpEntity<String> postForEntity = rest.exchange(RequestEntity
.post(new URI("/functions/vanilla")).header("x-context-type", "rubbish")
.body("{\"name\":\"Bob\",\"age\":25}"), String.class);
assertEquals("{\"name\":\"Bob\",\"age\":25,\"foo\":\"bar\"}",
postForEntity.getBody());
assertTrue(postForEntity.getHeaders().containsKey("x-context-type"));
assertEquals("rubbish", postForEntity.getHeaders().get("x-context-type").get(0));
}
@EnableAutoConfiguration
@org.springframework.boot.test.context.TestConfiguration
protected static class TestConfiguration {
@@ -78,5 +93,14 @@ public class HeadersToMessageTests {
return message;
};
}
@Bean
public Function<Map<String, Object>, Map<String, Object>> vanilla() {
return request -> {
Map<String, Object> message = new LinkedHashMap<>(request);
message.put("foo", "bar");
return message;
};
}
}
}