diff --git a/spring-cloud-stream/src/main/java/org/springframework/cloud/stream/binder/EmbeddedHeadersMessageConverter.java b/spring-cloud-stream/src/main/java/org/springframework/cloud/stream/binder/EmbeddedHeadersMessageConverter.java index b71f0c076..8ef2a3b9b 100644 --- a/spring-cloud-stream/src/main/java/org/springframework/cloud/stream/binder/EmbeddedHeadersMessageConverter.java +++ b/spring-cloud-stream/src/main/java/org/springframework/cloud/stream/binder/EmbeddedHeadersMessageConverter.java @@ -1,5 +1,5 @@ /* - * Copyright 2014-2015 the original author or authors. + * Copyright 2014-2017 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -26,6 +26,7 @@ import javax.xml.bind.DatatypeConverter; import org.springframework.integration.IntegrationMessageHeaderAccessor; import org.springframework.integration.support.json.Jackson2JsonObjectMapper; import org.springframework.messaging.Message; +import org.springframework.messaging.MessageHeaders; /** * Encodes requested headers into payload with format @@ -39,6 +40,7 @@ import org.springframework.messaging.Message; * * @author Eric Bottard * @author Gary Russell + * @author Ilayaperumal Gopinathan */ public class EmbeddedHeadersMessageConverter { @@ -93,38 +95,54 @@ public class EmbeddedHeadersMessageConverter { * have been promoted back to actual headers. The new payload is now the original * payload. * - * @param message the message to extract headers - * @param copyRequestHeaders boolean value to specify if original headers should be - * copied + * @param message the message to extract headers + * @param copyRequestHeaders boolean value to specify if the request headers should be + * copied */ public MessageValues extractHeaders(Message message, boolean copyRequestHeaders) throws Exception { - byte[] bytes = message.getPayload(); - ByteBuffer byteBuffer = ByteBuffer.wrap(bytes); + return extractHeaders(message.getPayload(), copyRequestHeaders, message.getHeaders()); + } + + private MessageValues extractHeaders(byte[] payload, boolean copyRequestHeaders, MessageHeaders requestHeaders) throws Exception { + ByteBuffer byteBuffer = ByteBuffer.wrap(payload); int headerCount = byteBuffer.get() & 0xff; if (headerCount < 255) { - return oldExtractHeaders(byteBuffer, bytes, headerCount, message, copyRequestHeaders); + return oldExtractHeaders(byteBuffer, payload, headerCount, copyRequestHeaders, requestHeaders); } else { headerCount = byteBuffer.get() & 0xff; Map headers = new HashMap(); for (int i = 0; i < headerCount; i++) { int len = byteBuffer.get() & 0xff; - String headerName = new String(bytes, byteBuffer.position(), len, "UTF-8"); + String headerName = new String(payload, byteBuffer.position(), len, "UTF-8"); byteBuffer.position(byteBuffer.position() + len); len = byteBuffer.getInt(); - String headerValue = new String(bytes, byteBuffer.position(), len, "UTF-8"); + String headerValue = new String(payload, byteBuffer.position(), len, "UTF-8"); Object headerContent = this.objectMapper.fromJson(headerValue, Object.class); headers.put(headerName, headerContent); byteBuffer.position(byteBuffer.position() + len); } byte[] newPayload = new byte[byteBuffer.remaining()]; byteBuffer.get(newPayload); - return buildMessageValues(message, newPayload, headers, copyRequestHeaders); + return buildMessageValues(newPayload, headers, copyRequestHeaders, requestHeaders); } } + /** + * Return a message where headers, that were originally embedded into the payload, + * have been promoted back to actual headers. The new payload is now the original + * payload. + * + * @param payload the message payload + * @return the message with extracted headers + * @throws Exception + */ + public MessageValues extractHeaders(byte[] payload) throws Exception { + return extractHeaders(payload, false, null); + } + private MessageValues oldExtractHeaders(ByteBuffer byteBuffer, byte[] bytes, int headerCount, - Message message, boolean copyRequestHeaders) + boolean copyRequestHeaders, MessageHeaders requestHeaders) throws UnsupportedEncodingException { Map headers = new HashMap(); for (int i = 0; i < headerCount; i++) { @@ -144,14 +162,14 @@ public class EmbeddedHeadersMessageConverter { } byte[] newPayload = new byte[byteBuffer.remaining()]; byteBuffer.get(newPayload); - return buildMessageValues(message, newPayload, headers, copyRequestHeaders); + return buildMessageValues(newPayload, headers, copyRequestHeaders, requestHeaders); } - private MessageValues buildMessageValues(Message message, byte[] payload, Map headers, - boolean copyRequestHeaders) { + private MessageValues buildMessageValues(byte[] payload, Map headers, + boolean copyRequestHeaders, MessageHeaders requestHeaders) { MessageValues messageValues = new MessageValues(payload, headers); - if (copyRequestHeaders) { - messageValues.copyHeadersIfAbsent(message.getHeaders()); + if (copyRequestHeaders && requestHeaders != null) { + messageValues.copyHeadersIfAbsent(requestHeaders); } return messageValues; } diff --git a/spring-cloud-stream/src/test/java/org/springframework/cloud/stream/binder/MessageConverterTests.java b/spring-cloud-stream/src/test/java/org/springframework/cloud/stream/binder/MessageConverterTests.java index 43cac1489..66306b175 100644 --- a/spring-cloud-stream/src/test/java/org/springframework/cloud/stream/binder/MessageConverterTests.java +++ b/spring-cloud-stream/src/test/java/org/springframework/cloud/stream/binder/MessageConverterTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2015 the original author or authors. + * Copyright 2002-2017 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -27,6 +27,7 @@ import static org.assertj.core.api.Assertions.assertThat; /** * @author Gary Russell + * @author Ilayaperumal Gopinathan * @since 1.0 * */ @@ -48,6 +49,23 @@ public class MessageConverterTests { assertThat(extracted.get("baz")).isEqualTo("quxx"); } + @Test + public void testHeaderExtractionWithDirectPayload() throws Exception { + EmbeddedHeadersMessageConverter converter = new EmbeddedHeadersMessageConverter(); + Message message = MessageBuilder.withPayload("Hello".getBytes()).setHeader("foo", "bar") + .setHeader("baz", "quxx").build(); + byte[] embedded = converter.embedHeaders(new MessageValues(message), "foo", "baz"); + assertThat(embedded[0] & 0xff).isEqualTo(0xff); + assertThat(new String(embedded).substring(1)).isEqualTo( + "\u0002\u0003foo\u0000\u0000\u0000\u0005\"bar\"\u0003baz\u0000\u0000\u0000\u0006\"quxx\"Hello"); + + MessageValues extracted = converter.extractHeaders(embedded); + assertThat(new String((byte[]) extracted.getPayload())).isEqualTo("Hello"); + assertThat(extracted.get("foo")).isEqualTo("bar"); + assertThat(extracted.get("baz")).isEqualTo("quxx"); + } + + @Test public void testUnicodeHeader() throws Exception { EmbeddedHeadersMessageConverter converter = new EmbeddedHeadersMessageConverter();