Add HTTP headers to outgoing messages by default
Some care is required to prevent request-specific headers being reflected and interfering with content negotiation. Fixes gh-207
This commit is contained in:
@@ -78,14 +78,17 @@ public class RequestProcessor {
|
||||
|
||||
public Mono<ResponseEntity<?>> get(FunctionWrapper wrapper) {
|
||||
if (wrapper.function() != null) {
|
||||
return response(wrapper.function(), value(wrapper.function(), wrapper.argument()), true, true);
|
||||
return response(wrapper, wrapper.function(),
|
||||
value(wrapper.function(), wrapper.argument()), true, true);
|
||||
}
|
||||
else {
|
||||
return response(wrapper.supplier(), wrapper.supplier().get(), null, true);
|
||||
return response(wrapper, wrapper.supplier(), wrapper.supplier().get(), null,
|
||||
true);
|
||||
}
|
||||
}
|
||||
|
||||
public Mono<ResponseEntity<?>> post(FunctionWrapper wrapper, String body, boolean stream) {
|
||||
public Mono<ResponseEntity<?>> post(FunctionWrapper wrapper, String body,
|
||||
boolean stream) {
|
||||
Mono<ResponseEntity<?>> responseEntityMono;
|
||||
Object function = wrapper.handler();
|
||||
|
||||
@@ -117,14 +120,15 @@ public class RequestProcessor {
|
||||
public Mono<ResponseEntity<?>> stream(FunctionWrapper request) {
|
||||
Publisher<?> result = request.function() != null
|
||||
? value(request.function(), request.argument())
|
||||
: request.supplier().get();
|
||||
: request.supplier().get();
|
||||
return stream(request, result);
|
||||
}
|
||||
|
||||
private Mono<ResponseEntity<?>> post(FunctionWrapper wrapper, Object body,
|
||||
MultiValueMap<String, String> params, boolean stream) {
|
||||
|
||||
Iterable<?> iterable = body instanceof Collection ? (List<?>) body : Collections.singletonList(body);
|
||||
Iterable<?> iterable = body instanceof Collection ? (List<?>) body
|
||||
: Collections.singletonList(body);
|
||||
|
||||
Function<Publisher<?>, Publisher<?>> function = wrapper.function();
|
||||
Consumer<Publisher<?>> consumer = wrapper.consumer();
|
||||
@@ -146,13 +150,15 @@ public class RequestProcessor {
|
||||
responseEntityMono = stream(wrapper, result);
|
||||
}
|
||||
else {
|
||||
responseEntityMono = response(function, result, body == null ? null : !(body instanceof Collection), false);
|
||||
responseEntityMono = response(wrapper, function, result,
|
||||
body == null ? null : !(body instanceof Collection), false);
|
||||
}
|
||||
}
|
||||
else if (consumer != null) {
|
||||
consumer.accept(flux);
|
||||
logger.debug("Handled POST with consumer");
|
||||
responseEntityMono = Mono.just(ResponseEntity.status(HttpStatus.ACCEPTED).build());
|
||||
responseEntityMono = Mono
|
||||
.just(ResponseEntity.status(HttpStatus.ACCEPTED).build());
|
||||
}
|
||||
return responseEntityMono;
|
||||
}
|
||||
@@ -175,13 +181,16 @@ public class RequestProcessor {
|
||||
.map(message -> MessageUtils.unpack(request.handler(), message)
|
||||
.getPayload());
|
||||
}
|
||||
else {
|
||||
builder.headers(HeaderUtils.sanitize(request.headers()));
|
||||
}
|
||||
|
||||
Publisher<?> output = result;
|
||||
return Flux.from(output).then(Mono.fromSupplier(() -> builder.body(output)));
|
||||
}
|
||||
|
||||
private Mono<ResponseEntity<?>> response(Object handler, Publisher<?> result,
|
||||
Boolean single, boolean getter) {
|
||||
private Mono<ResponseEntity<?>> response(FunctionWrapper request, Object handler,
|
||||
Publisher<?> result, Boolean single, boolean getter) {
|
||||
|
||||
BodyBuilder builder = ResponseEntity.ok();
|
||||
if (inspector.isMessage(handler)) {
|
||||
@@ -189,8 +198,12 @@ public class RequestProcessor {
|
||||
.doOnNext(value -> addHeaders(builder, (Message<?>) value))
|
||||
.map(message -> MessageUtils.unpack(handler, message).getPayload());
|
||||
}
|
||||
else {
|
||||
builder.headers(HeaderUtils.sanitize(request.headers()));
|
||||
}
|
||||
|
||||
if (isOutputSingle(handler) && (single != null && single || getter || isInputMultiple(handler))) {
|
||||
if (isOutputSingle(handler)
|
||||
&& (single != null && single || getter || isInputMultiple(handler))) {
|
||||
result = Mono.from(result);
|
||||
}
|
||||
|
||||
@@ -213,7 +226,8 @@ public class RequestProcessor {
|
||||
return false;
|
||||
}
|
||||
else {
|
||||
return wrapper == type || Mono.class.equals(wrapper) || Optional.class.equals(wrapper);
|
||||
return wrapper == type || Mono.class.equals(wrapper)
|
||||
|| Optional.class.equals(wrapper);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -18,6 +18,7 @@ package org.springframework.cloud.function.web.util;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collection;
|
||||
import java.util.LinkedHashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
import org.springframework.http.HttpHeaders;
|
||||
@@ -30,10 +31,17 @@ import org.springframework.messaging.MessageHeaders;
|
||||
public class HeaderUtils {
|
||||
|
||||
private static HttpHeaders IGNORED = new HttpHeaders();
|
||||
|
||||
|
||||
private static HttpHeaders REQUEST_ONLY = new HttpHeaders();
|
||||
|
||||
static {
|
||||
IGNORED.add(MessageHeaders.ID, "");
|
||||
IGNORED.add(HttpHeaders.CONTENT_LENGTH, "0");
|
||||
// Headers that would typically be added by a downstream client
|
||||
REQUEST_ONLY.add(HttpHeaders.ACCEPT, "");
|
||||
REQUEST_ONLY.add(HttpHeaders.CONTENT_LENGTH, "");
|
||||
REQUEST_ONLY.add(HttpHeaders.CONTENT_TYPE, "");
|
||||
REQUEST_ONLY.add(HttpHeaders.HOST, "");
|
||||
}
|
||||
|
||||
public static HttpHeaders fromMessage(MessageHeaders headers, HttpHeaders request) {
|
||||
@@ -51,6 +59,18 @@ public class HeaderUtils {
|
||||
return result;
|
||||
}
|
||||
|
||||
public static HttpHeaders sanitize(HttpHeaders request) {
|
||||
HttpHeaders result = new HttpHeaders();
|
||||
for (String name : request.keySet()) {
|
||||
List<String> value = request.get(name);
|
||||
name = name.toLowerCase();
|
||||
if (!IGNORED.containsKey(name) && !REQUEST_ONLY.containsKey(name)) {
|
||||
result.put(name, value);
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
private static Collection<?> multi(Object value) {
|
||||
return value instanceof Collection ? (Collection<?>) value : Arrays.asList(value);
|
||||
}
|
||||
|
||||
@@ -16,6 +16,7 @@
|
||||
package org.springframework.cloud.function.web.mvc;
|
||||
|
||||
import java.net.URI;
|
||||
import java.util.LinkedHashMap;
|
||||
import java.util.Map;
|
||||
import java.util.function.Function;
|
||||
|
||||
@@ -30,6 +31,8 @@ import org.springframework.boot.test.web.client.TestRestTemplate;
|
||||
import org.springframework.cloud.function.web.RestApplication;
|
||||
import org.springframework.cloud.function.web.mvc.HeadersToMessageTests.TestConfiguration;
|
||||
import org.springframework.context.annotation.Bean;
|
||||
import org.springframework.http.HttpEntity;
|
||||
import org.springframework.http.RequestEntity;
|
||||
import org.springframework.http.ResponseEntity;
|
||||
import org.springframework.messaging.Message;
|
||||
import org.springframework.messaging.support.MessageBuilder;
|
||||
@@ -46,8 +49,9 @@ import static org.junit.Assert.assertTrue;
|
||||
*/
|
||||
@RunWith(SpringRunner.class)
|
||||
@SpringBootTest(webEnvironment = WebEnvironment.RANDOM_PORT, properties = {
|
||||
"spring.main.web-application-type=servlet", "spring.cloud.function.web.path=/functions" })
|
||||
@ContextConfiguration(classes= {RestApplication.class, TestConfiguration.class})
|
||||
"spring.main.web-application-type=servlet",
|
||||
"spring.cloud.function.web.path=/functions" })
|
||||
@ContextConfiguration(classes = { RestApplication.class, TestConfiguration.class })
|
||||
public class HeadersToMessageTests {
|
||||
|
||||
@Autowired
|
||||
@@ -65,6 +69,17 @@ public class HeadersToMessageTests {
|
||||
assertEquals("bar", postForEntity.getHeaders().get("foo").get(0));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testHeadersPropagatedByDefault() throws Exception {
|
||||
HttpEntity<String> postForEntity = rest.exchange(RequestEntity
|
||||
.post(new URI("/functions/vanilla")).header("x-context-type", "rubbish")
|
||||
.body("{\"name\":\"Bob\",\"age\":25}"), String.class);
|
||||
assertEquals("{\"name\":\"Bob\",\"age\":25,\"foo\":\"bar\"}",
|
||||
postForEntity.getBody());
|
||||
assertTrue(postForEntity.getHeaders().containsKey("x-context-type"));
|
||||
assertEquals("rubbish", postForEntity.getHeaders().get("x-context-type").get(0));
|
||||
}
|
||||
|
||||
@EnableAutoConfiguration
|
||||
@org.springframework.boot.test.context.TestConfiguration
|
||||
protected static class TestConfiguration {
|
||||
@@ -78,5 +93,14 @@ public class HeadersToMessageTests {
|
||||
return message;
|
||||
};
|
||||
}
|
||||
|
||||
@Bean
|
||||
public Function<Map<String, Object>, Map<String, Object>> vanilla() {
|
||||
return request -> {
|
||||
Map<String, Object> message = new LinkedHashMap<>(request);
|
||||
message.put("foo", "bar");
|
||||
return message;
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user