diff --git a/spring-cloud-function-rsocket/src/main/java/org/springframework/cloud/function/rsocket/MessageAwareJsonEncoder.java b/spring-cloud-function-rsocket/src/main/java/org/springframework/cloud/function/rsocket/MessageAwareJsonEncoder.java index cf70bedc9..02aa41045 100644 --- a/spring-cloud-function-rsocket/src/main/java/org/springframework/cloud/function/rsocket/MessageAwareJsonEncoder.java +++ b/spring-cloud-function-rsocket/src/main/java/org/springframework/cloud/function/rsocket/MessageAwareJsonEncoder.java @@ -16,6 +16,7 @@ package org.springframework.cloud.function.rsocket; +import java.nio.charset.StandardCharsets; import java.util.Collections; import java.util.List; import java.util.Map; @@ -80,12 +81,18 @@ class MessageAwareJsonEncoder extends AbstractEncoder { return Collections.singletonList(MimeTypeUtils.APPLICATION_JSON); } + @SuppressWarnings({ "unchecked", "rawtypes" }) @Override public DataBuffer encodeValue(Object value, DataBufferFactory bufferFactory, ResolvableType valueType, @Nullable MimeType mimeType, @Nullable Map hints) { if (value instanceof Message) { + Object payload = ((Message) value).getPayload(); value = FunctionRSocketUtils.sanitizeMessageToMap((Message) value); + if (payload instanceof byte[]) { + payload = new String((byte[]) payload, StandardCharsets.UTF_8); // safe for cases when we have JSON + ((Map) value).put(FunctionRSocketUtils.PAYLOAD, payload); + } } else if (!(value instanceof Map)) { if (JsonMapper.isJsonString(value)) { diff --git a/spring-cloud-function-rsocket/src/test/java/org/springframework/cloud/function/rsocket/MessagingTests.java b/spring-cloud-function-rsocket/src/test/java/org/springframework/cloud/function/rsocket/MessagingTests.java index 936354d67..a4191e729 100644 --- a/spring-cloud-function-rsocket/src/test/java/org/springframework/cloud/function/rsocket/MessagingTests.java +++ b/spring-cloud-function-rsocket/src/test/java/org/springframework/cloud/function/rsocket/MessagingTests.java @@ -227,6 +227,64 @@ public class MessagingTests { } } + @Test + public void testPojoMessageToPojoViaMessageExpectMessageRawPayload() { + int port = SocketUtils.findAvailableTcpPort(); + try ( + ConfigurableApplicationContext applicationContext = + new SpringApplicationBuilder(MessagingConfiguration.class) + .web(WebApplicationType.NONE) + .run("--logging.level.org.springframework.cloud.function=DEBUG", + "--spring.rsocket.server.port=" + port); + ) { + RSocketRequester.Builder rsocketRequesterBuilder = + applicationContext.getBean(RSocketRequester.Builder.class); + + Message message = MessageBuilder.withPayload("{\"name\":\"bob\"}".getBytes()) + .setHeader("someHeader", "foo") + .build(); + + Message result = rsocketRequesterBuilder.tcp("localhost", port) + .route("pojoMessageToPojo") + .data(message) + .retrieveMono(new ParameterizedTypeReference>() { + }) + .block(); + + assertThat(result.getPayload()).isEqualTo("{\"name\":\"BOB\"}".getBytes()); + assertThat(result.getHeaders().get("someHeader")).isEqualTo("foo"); + } + } + + @Test + public void testPojoMessageToPojoViaMessageExpectMessageStringPayload() { + int port = SocketUtils.findAvailableTcpPort(); + try ( + ConfigurableApplicationContext applicationContext = + new SpringApplicationBuilder(MessagingConfiguration.class) + .web(WebApplicationType.NONE) + .run("--logging.level.org.springframework.cloud.function=DEBUG", + "--spring.rsocket.server.port=" + port); + ) { + RSocketRequester.Builder rsocketRequesterBuilder = + applicationContext.getBean(RSocketRequester.Builder.class); + + Message message = MessageBuilder.withPayload("{\"name\":\"bob\"}") + .setHeader("someHeader", "foo") + .build(); + + Message result = rsocketRequesterBuilder.tcp("localhost", port) + .route("pojoMessageToPojo") + .data(message) + .retrieveMono(new ParameterizedTypeReference>() { + }) + .block(); + + assertThat(result.getPayload()).isEqualTo("{\"name\":\"BOB\"}"); + assertThat(result.getHeaders().get("someHeader")).isEqualTo("foo"); + } + } + @Test public void testPojoToMessageMap() { int port = SocketUtils.findAvailableTcpPort(); diff --git a/spring-cloud-function-rsocket/src/test/java/org/springframework/cloud/function/rsocket/RSocketAutoConfigurationTests.java b/spring-cloud-function-rsocket/src/test/java/org/springframework/cloud/function/rsocket/RSocketAutoConfigurationTests.java index 0bfb6f8e0..d83227f0c 100644 --- a/spring-cloud-function-rsocket/src/test/java/org/springframework/cloud/function/rsocket/RSocketAutoConfigurationTests.java +++ b/spring-cloud-function-rsocket/src/test/java/org/springframework/cloud/function/rsocket/RSocketAutoConfigurationTests.java @@ -541,7 +541,7 @@ public class RSocketAutoConfigurationTests { RSocketRequester requester = rsocketRequesterBuilder.tcp("localhost", server.address().getPort()); requester.route("reverse") - .data("\"hello\"") + .data("hello") .retrieveMono(String.class) .as(StepVerifier::create) .expectNext("olleh") @@ -575,7 +575,7 @@ public class RSocketAutoConfigurationTests { rsocketRequesterBuilder.tcp("localhost", port) .route(RoutingFunction.FUNCTION_NAME) .metadata("{\"function_definition\":\"uppercase|concat\"}", MimeTypeUtils.APPLICATION_JSON) - .data("\"hello\"") + .data("hello") .retrieveMono(String.class) .as(StepVerifier::create) .expectNext("HELLOHELLO") @@ -584,6 +584,36 @@ public class RSocketAutoConfigurationTests { } } + @Test + public void testByteArrayInOut() { + int port = SocketUtils.findAvailableTcpPort(); + try ( + ConfigurableApplicationContext applicationContext = + new SpringApplicationBuilder(SampleFunctionConfiguration.class) + .web(WebApplicationType.NONE) + .run("--logging.level.org.springframework.cloud.function=DEBUG", + "--spring.rsocket.server.port=" + port); + ) { + RSocketRequester.Builder rsocketRequesterBuilder = + applicationContext.getBean(RSocketRequester.Builder.class); + + String result = rsocketRequesterBuilder.tcp("localhost", port) + .route("uppercase") + .data("hello".getBytes()) + .retrieveMono(String.class) + .block(); + + assertThat(result).isEqualTo("HELLO"); + + byte[] resultBytes = rsocketRequesterBuilder.tcp("localhost", port) + .route("uppercase") + .data("hello".getBytes()) + .retrieveMono(byte[].class) + .block(); + + assertThat(resultBytes).isEqualTo("HELLO".getBytes()); + } + } @EnableAutoConfiguration @Configuration