diff --git a/spring-cloud-function-rsocket/src/main/java/org/springframework/cloud/function/rsocket/FunctionRSocketMessageHandler.java b/spring-cloud-function-rsocket/src/main/java/org/springframework/cloud/function/rsocket/FunctionRSocketMessageHandler.java index 14aedff59..2507c9921 100644 --- a/spring-cloud-function-rsocket/src/main/java/org/springframework/cloud/function/rsocket/FunctionRSocketMessageHandler.java +++ b/spring-cloud-function-rsocket/src/main/java/org/springframework/cloud/function/rsocket/FunctionRSocketMessageHandler.java @@ -264,23 +264,24 @@ class FunctionRSocketMessageHandler extends RSocketMessageHandler { return MessageBuilder.withPayload(structure).build(); } } + else { + return MessageBuilder.withPayload(structure).build(); + } } return MessageBuilder.withPayload(bytePayload).copyHeadersIfAbsent(message.getHeaders()).build(); }); return MessageBuilder.createMessage(argument, message.getHeaders()); } - else { + else { // delegate to the existing argument resolvers for (HandlerMethodArgumentResolver handlerMethodArgumentResolver : this.resolvers) { if (handlerMethodArgumentResolver.supportsParameter(parameter)) { Publisher arg = handlerMethodArgumentResolver.resolveArgument(parameter, message); return MessageBuilder.withPayload(arg).copyHeadersIfAbsent(message.getHeaders()).build(); } - } return message; } } - } protected static final class FunctionRSocketPayloadReturnValueHandler extends RSocketPayloadReturnValueHandler { diff --git a/spring-cloud-function-rsocket/src/main/java/org/springframework/cloud/function/rsocket/MessageAwareJsonDecoder.java b/spring-cloud-function-rsocket/src/main/java/org/springframework/cloud/function/rsocket/MessageAwareJsonDecoder.java index 4df89dac9..771de71f9 100644 --- a/spring-cloud-function-rsocket/src/main/java/org/springframework/cloud/function/rsocket/MessageAwareJsonDecoder.java +++ b/spring-cloud-function-rsocket/src/main/java/org/springframework/cloud/function/rsocket/MessageAwareJsonDecoder.java @@ -55,7 +55,7 @@ class MessageAwareJsonDecoder extends AbstractDecoder { @Override public boolean canDecode(ResolvableType elementType, @Nullable MimeType mimeType) { - return mimeType.isCompatibleWith(MimeTypeUtils.APPLICATION_JSON); + return mimeType != null && mimeType.isCompatibleWith(MimeTypeUtils.APPLICATION_JSON); } @SuppressWarnings("unchecked") diff --git a/spring-cloud-function-rsocket/src/main/java/org/springframework/cloud/function/rsocket/MessageAwareJsonEncoder.java b/spring-cloud-function-rsocket/src/main/java/org/springframework/cloud/function/rsocket/MessageAwareJsonEncoder.java index ab1ded943..cf70bedc9 100644 --- a/spring-cloud-function-rsocket/src/main/java/org/springframework/cloud/function/rsocket/MessageAwareJsonEncoder.java +++ b/spring-cloud-function-rsocket/src/main/java/org/springframework/cloud/function/rsocket/MessageAwareJsonEncoder.java @@ -66,7 +66,7 @@ class MessageAwareJsonEncoder extends AbstractEncoder { @Override public boolean canEncode(ResolvableType elementType, MimeType mimeType) { - boolean canEncode = mimeType.isCompatibleWith(MimeTypeUtils.APPLICATION_JSON); + boolean canEncode = mimeType != null && mimeType.isCompatibleWith(MimeTypeUtils.APPLICATION_JSON); if (canEncode && this.isClient) { canEncode = (FunctionTypeUtils.isMessage(elementType.getType()) || Map.class.isAssignableFrom(FunctionTypeUtils.getRawType(elementType.getType()))); diff --git a/spring-cloud-function-rsocket/src/main/java/org/springframework/cloud/function/rsocket/RSocketListenerFunction.java b/spring-cloud-function-rsocket/src/main/java/org/springframework/cloud/function/rsocket/RSocketListenerFunction.java index 1165fd1b9..53f81e688 100644 --- a/spring-cloud-function-rsocket/src/main/java/org/springframework/cloud/function/rsocket/RSocketListenerFunction.java +++ b/spring-cloud-function-rsocket/src/main/java/org/springframework/cloud/function/rsocket/RSocketListenerFunction.java @@ -46,56 +46,51 @@ class RSocketListenerFunction implements Function> { private final FunctionInvocationWrapper targetFunction; RSocketListenerFunction(FunctionInvocationWrapper targetFunction) { - this.targetFunction = targetFunction; - } - - /* - * We need to maintain the input typeless to ensure that no encoder/decoders will attempt any conversion. - * That said it will always be Message> - */ - @SuppressWarnings("unchecked") - @Override - public Publisher apply(Object input) { - Assert.isTrue(this.targetFunction != null, "Failed to discover target function. \n" + Assert.isTrue(targetFunction != null, "Failed to discover target function. \n" + "To fix it you should either provide 'spring.cloud.function.definition' property " + "or if you are using RSocketRequester provide valid function definition via 'route' " + "operator (e.g., requester.route(\"echo\"))"); -// if (input instanceof Message) { - Message> inputMessage = (Message>) input; - FrameType frameType = RSocketFrameTypeMessageCondition.getFrameType(inputMessage); - switch (frameType) { - case REQUEST_FNF: - return handle(inputMessage); - case REQUEST_RESPONSE: - case REQUEST_STREAM: - case REQUEST_CHANNEL: - return handleAndReply(inputMessage); - default: - throw new UnsupportedOperationException(); - } -// } -// throw new UnsupportedOperationException("Expecting input to be of type Message>"); + this.targetFunction = targetFunction; + } + + + @SuppressWarnings("unchecked") + @Override + public Publisher apply(Object input) { + /* + * We need to maintain the input typeless to ensure that no encoder/decoders will attempt any conversion. + * That said it will always be Message> + */ + Message> inputMessage = (Message>) input; + + FrameType frameType = RSocketFrameTypeMessageCondition.getFrameType(inputMessage); + switch (frameType) { + case REQUEST_FNF: + return handle(inputMessage); + case REQUEST_RESPONSE: + case REQUEST_STREAM: + case REQUEST_CHANNEL: + return handleAndReply(inputMessage); + default: + throw new UnsupportedOperationException(); + } } @SuppressWarnings({ "unchecked", "rawtypes" }) private Mono handle(Message> messageToProcess) { if (this.targetFunction.isRoutingFunction()) { Flux dataFlux = Flux.from(messageToProcess.getPayload()) - .map((payload) -> { - return MessageBuilder.createMessage(payload, messageToProcess.getHeaders()); - }); + .map(payload -> MessageBuilder.createMessage(payload, messageToProcess.getHeaders())); return dataFlux.doOnNext(this.targetFunction).then(); } else if (this.targetFunction.isConsumer()) { - Flux dataFlux = - Flux.from(messageToProcess.getPayload()) - .map((payload) -> MessageBuilder.createMessage(payload, messageToProcess.getHeaders())); - if (FunctionTypeUtils.isPublisher(this.targetFunction.getInputType())) { - dataFlux = dataFlux.transform((Function) this.targetFunction); - } - else { - dataFlux = dataFlux.doOnNext(this.targetFunction); - } + Flux dataFlux = Flux.from(messageToProcess.getPayload()) + .map(payload -> this.buildReceivedMessage(payload, messageToProcess.getHeaders())); + + dataFlux = FunctionTypeUtils.isPublisher(this.targetFunction.getInputType()) + ? dataFlux.transform((Function) this.targetFunction) + : dataFlux.doOnNext(this.targetFunction); + return dataFlux.then(); } else { @@ -105,13 +100,9 @@ class RSocketListenerFunction implements Function> { @SuppressWarnings({ "unchecked", "rawtypes" }) private Flux handleAndReply(Message> messageToProcess) { - Flux dataFlux = - Flux.from(messageToProcess.getPayload()) - .map((payload) -> { - return payload instanceof Message - ? MessageBuilder.fromMessage((Message) payload).copyHeadersIfAbsent(messageToProcess.getHeaders()).build() - : MessageBuilder.withPayload(payload).copyHeadersIfAbsent(messageToProcess.getHeaders()).build(); - }); + Flux dataFlux = Flux.from(messageToProcess.getPayload()) + .map(payload -> this.buildReceivedMessage(payload, messageToProcess.getHeaders())); + if (this.targetFunction.getInputType() != null && FunctionTypeUtils.isPublisher(this.targetFunction.getInputType())) { dataFlux = dataFlux.transform((Function) this.targetFunction); } @@ -132,6 +123,12 @@ class RSocketListenerFunction implements Function> { return dataFlux; } + private Message buildReceivedMessage(Object mayBeMessage, MessageHeaders messageHeaders) { + return mayBeMessage instanceof Message + ? MessageBuilder.fromMessage((Message) mayBeMessage).copyHeadersIfAbsent(messageHeaders).build() + : MessageBuilder.withPayload(mayBeMessage).copyHeadersIfAbsent(messageHeaders).build(); + } + /* * This will ensure that unless CT is application/json for which we provide Message aware encoder/decoder * the payload is extracted since no other available encoders/decoders understand Message. diff --git a/spring-cloud-function-rsocket/src/test/java/org/springframework/cloud/function/rsocket/RSocketAutoConfigurationTests.java b/spring-cloud-function-rsocket/src/test/java/org/springframework/cloud/function/rsocket/RSocketAutoConfigurationTests.java index 61bddb018..e33cce346 100644 --- a/spring-cloud-function-rsocket/src/test/java/org/springframework/cloud/function/rsocket/RSocketAutoConfigurationTests.java +++ b/spring-cloud-function-rsocket/src/test/java/org/springframework/cloud/function/rsocket/RSocketAutoConfigurationTests.java @@ -16,6 +16,7 @@ package org.springframework.cloud.function.rsocket; +import java.util.Map; import java.util.function.Consumer; import java.util.function.Function; import java.util.function.Supplier; @@ -38,9 +39,12 @@ import org.springframework.context.annotation.Configuration; import org.springframework.core.env.ConfigurableEnvironment; import org.springframework.messaging.rsocket.RSocketRequester; import org.springframework.test.util.ReflectionTestUtils; +import org.springframework.util.MimeType; import org.springframework.util.MimeTypeUtils; import org.springframework.util.SocketUtils; +import static org.assertj.core.api.Assertions.assertThat; + /** * * @author Oleg Zhurakousky @@ -122,6 +126,35 @@ public class RSocketAutoConfigurationTests { } } + @SuppressWarnings("unchecked") + @Test + public void testWithCborContentType() { + int port = SocketUtils.findAvailableTcpPort(); + try ( + ConfigurableApplicationContext applicationContext = + new SpringApplicationBuilder(SampleFunctionConfiguration.class) + .web(WebApplicationType.NONE) + .run("--logging.level.org.springframework.cloud.function=DEBUG", + "--spring.cloud.function.definition=uppercase", + "--spring.rsocket.server.port=" + port); + ) { + RSocketRequester.Builder rsocketRequesterBuilder = + applicationContext.getBean(RSocketRequester.Builder.class); + + Person p = new Person(); + p.setAge(23); + p.setName("Bob"); + Map m = rsocketRequesterBuilder + .dataMimeType(MimeType.valueOf("application/cbor")) + .tcp("localhost", port) + .route("echoMap") + .data(p) + .retrieveMono(Map.class).block(); + assertThat(m.get("name")).isEqualTo("Bob"); + assertThat(m.get("age")).isEqualTo(23); + } + } + @Test @Disabled public void testImperativeFunctionAsRequestReplyWithDefinitionExplicitExpectedOutputCt() { @@ -472,6 +505,10 @@ public class RSocketAutoConfigurationTests { .run("--logging.level.org.springframework.cloud.function=DEBUG", "--spring.rsocket.server.port=" + port); ) { + + SampleFunctionConfiguration config = applicationContext.getBean(SampleFunctionConfiguration.class); + + RSocketRequester.Builder rsocketRequesterBuilder = applicationContext.getBean(RSocketRequester.Builder.class); @@ -482,6 +519,8 @@ public class RSocketAutoConfigurationTests { .as(StepVerifier::create) .expectComplete() .verify(); + String result = config.consumerData.asMono().block(); + assertThat(result).isEqualTo("hello"); } } @@ -550,7 +589,7 @@ public class RSocketAutoConfigurationTests { @Configuration public static class SampleFunctionConfiguration { - final Sinks.One consumerData = Sinks.one(); + final Sinks.One consumerData = Sinks.one(); @Bean public Function uppercase() { @@ -567,6 +606,11 @@ public class RSocketAutoConfigurationTests { return v -> v; } + @Bean + public Function, Map> echoMap() { + return v -> v; + } + @Bean public Function, Flux> uppercaseReactive() { return flux -> flux.map(v -> { @@ -576,7 +620,7 @@ public class RSocketAutoConfigurationTests { } @Bean - public Consumer log() { + public Consumer log() { return this.consumerData::tryEmitValue; } @@ -612,4 +656,21 @@ public class RSocketAutoConfigurationTests { } } + public static class Person { + private String name; + private int age; + public String getName() { + return name; + } + public void setName(String name) { + this.name = name; + } + public int getAge() { + return age; + } + public void setAge(int age) { + this.age = age; + } + } + }