diff --git a/spring-cloud-function-rsocket/src/main/java/org/springframework/cloud/function/rsocket/FunctionRSocketMessageHandler.java b/spring-cloud-function-rsocket/src/main/java/org/springframework/cloud/function/rsocket/FunctionRSocketMessageHandler.java index 0f56d80f2..9205b7aed 100644 --- a/spring-cloud-function-rsocket/src/main/java/org/springframework/cloud/function/rsocket/FunctionRSocketMessageHandler.java +++ b/spring-cloud-function-rsocket/src/main/java/org/springframework/cloud/function/rsocket/FunctionRSocketMessageHandler.java @@ -16,7 +16,6 @@ package org.springframework.cloud.function.rsocket; -import java.lang.reflect.Field; import java.lang.reflect.Method; import java.util.Collections; import java.util.List; @@ -42,7 +41,6 @@ import org.springframework.core.codec.Encoder; import org.springframework.core.io.buffer.DataBuffer; import org.springframework.lang.Nullable; import org.springframework.messaging.Message; -import org.springframework.messaging.MessageHeaders; import org.springframework.messaging.MessagingException; import org.springframework.messaging.handler.CompositeMessageCondition; import org.springframework.messaging.handler.DestinationPatternsMessageCondition; @@ -70,8 +68,6 @@ class FunctionRSocketMessageHandler extends RSocketMessageHandler { private final FunctionCatalog functionCatalog; - private final Field headersField; - private static final Method FUNCTION_APPLY_METHOD = ReflectionUtils.findMethod(Function.class, "apply", (Class[]) null); @@ -85,8 +81,6 @@ class FunctionRSocketMessageHandler extends RSocketMessageHandler { FunctionRSocketMessageHandler(FunctionCatalog functionCatalog) { setHandlerPredicate((clazz) -> false); this.functionCatalog = functionCatalog; - this.headersField = ReflectionUtils.findField(MessageHeaders.class, "headers"); - this.headersField.setAccessible(true); } @@ -185,14 +179,17 @@ class FunctionRSocketMessageHandler extends RSocketMessageHandler { } } + /** + * This metadata extractor will ensure that any JSON data passed + * via metadata will be copied into Message headers. + */ private static class HeadersAwareMetadataExtractor extends DefaultMetadataExtractor { HeadersAwareMetadataExtractor(List> decoders) { super(decoders); super.metadataToExtract(MimeTypeUtils.APPLICATION_JSON, new ParameterizedTypeReference>() { - }, (jsonMap, outputMap) -> { - outputMap.putAll(jsonMap); - }); + }, (jsonMap, outputMap) -> outputMap.putAll(jsonMap) + ); } } diff --git a/spring-cloud-function-rsocket/src/test/java/org/springframework/cloud/function/rsocket/RSocketAutoConfigurationTests.java b/spring-cloud-function-rsocket/src/test/java/org/springframework/cloud/function/rsocket/RSocketAutoConfigurationTests.java index 7f288379c..79be8c982 100644 --- a/spring-cloud-function-rsocket/src/test/java/org/springframework/cloud/function/rsocket/RSocketAutoConfigurationTests.java +++ b/spring-cloud-function-rsocket/src/test/java/org/springframework/cloud/function/rsocket/RSocketAutoConfigurationTests.java @@ -48,7 +48,7 @@ import org.springframework.util.SocketUtils; */ public class RSocketAutoConfigurationTests { @Test - public void testImperativeFunctionAsRequestReply() { + public void testImperativeFunctionAsRequestReplyWithDefinition() { int port = SocketUtils.findAvailableTcpPort(); try ( ConfigurableApplicationContext applicationContext = @@ -61,6 +61,30 @@ public class RSocketAutoConfigurationTests { RSocketRequester.Builder rsocketRequesterBuilder = applicationContext.getBean(RSocketRequester.Builder.class); + rsocketRequesterBuilder.tcp("localhost", port) + .route("") + .data("\"hello\"") + .retrieveMono(String.class) + .as(StepVerifier::create) + .expectNext("\"HELLO\"") + .expectComplete() + .verify(); + } + } + + @Test + public void testImperativeFunctionAsRequestReply() { + int port = SocketUtils.findAvailableTcpPort(); + try ( + ConfigurableApplicationContext applicationContext = + new SpringApplicationBuilder(SampleFunctionConfiguration.class) + .web(WebApplicationType.NONE) + .run("--logging.level.org.springframework.cloud.function=DEBUG", + "--spring.rsocket.server.port=" + port); + ) { + RSocketRequester.Builder rsocketRequesterBuilder = + applicationContext.getBean(RSocketRequester.Builder.class); + rsocketRequesterBuilder.tcp("localhost", port) .route("uppercase") .data("\"hello\"") @@ -72,6 +96,30 @@ public class RSocketAutoConfigurationTests { } } + @Test + public void testImperativeFunctionAsRequestReplyWithComposition() { + int port = SocketUtils.findAvailableTcpPort(); + try ( + ConfigurableApplicationContext applicationContext = + new SpringApplicationBuilder(SampleFunctionConfiguration.class) + .web(WebApplicationType.NONE) + .run("--logging.level.org.springframework.cloud.function=DEBUG", + "--spring.rsocket.server.port=" + port); + ) { + RSocketRequester.Builder rsocketRequesterBuilder = + applicationContext.getBean(RSocketRequester.Builder.class); + + rsocketRequesterBuilder.tcp("localhost", port) + .route("uppercase|concat") + .data("\"hello\"") + .retrieveMono(String.class) + .as(StepVerifier::create) + .expectNext("\"HELLOHELLO\"") + .expectComplete() + .verify(); + } + } + @Test public void testSupplierAsRequestReply() { int port = SocketUtils.findAvailableTcpPort(); @@ -80,7 +128,6 @@ public class RSocketAutoConfigurationTests { new SpringApplicationBuilder(SampleFunctionConfiguration.class) .web(WebApplicationType.NONE) .run("--logging.level.org.springframework.cloud.function=DEBUG", - "--spring.cloud.function.definition=source", "--spring.rsocket.server.port=" + port); ) { RSocketRequester.Builder rsocketRequesterBuilder = @@ -105,7 +152,6 @@ public class RSocketAutoConfigurationTests { new SpringApplicationBuilder(SampleFunctionConfiguration.class) .web(WebApplicationType.NONE) .run("--logging.level.org.springframework.cloud.function=DEBUG", - "--spring.cloud.function.definition=uppercase", "--spring.rsocket.server.port=" + port); ) { RSocketRequester.Builder rsocketRequesterBuilder = @@ -130,7 +176,6 @@ public class RSocketAutoConfigurationTests { new SpringApplicationBuilder(SampleFunctionConfiguration.class) .web(WebApplicationType.NONE) .run("--logging.level.org.springframework.cloud.function=DEBUG", - "--spring.cloud.function.definition=uppercase", "--spring.rsocket.server.port=" + port); ) { RSocketRequester.Builder rsocketRequesterBuilder = @@ -155,7 +200,6 @@ public class RSocketAutoConfigurationTests { new SpringApplicationBuilder(SampleFunctionConfiguration.class) .web(WebApplicationType.NONE) .run("--logging.level.org.springframework.cloud.function=DEBUG", - "--spring.cloud.function.definition=uppercaseReactive", "--spring.rsocket.server.port=" + port); ) { RSocketRequester.Builder rsocketRequesterBuilder = @@ -180,7 +224,6 @@ public class RSocketAutoConfigurationTests { new SpringApplicationBuilder(SampleFunctionConfiguration.class) .web(WebApplicationType.NONE) .run("--logging.level.org.springframework.cloud.function=DEBUG", - "--spring.cloud.function.definition=uppercaseReactive", "--spring.rsocket.server.port=" + port); ) { RSocketRequester.Builder rsocketRequesterBuilder = @@ -205,7 +248,6 @@ public class RSocketAutoConfigurationTests { new SpringApplicationBuilder(SampleFunctionConfiguration.class) .web(WebApplicationType.NONE) .run("--logging.level.org.springframework.cloud.function=DEBUG", - "--spring.cloud.function.definition=uppercaseReactive", "--spring.rsocket.server.port=" + port); ) { RSocketRequester.Builder rsocketRequesterBuilder = @@ -223,7 +265,7 @@ public class RSocketAutoConfigurationTests { } @Test - public void testRequestReplyFunctionWithComposition() { + public void testRequestReplyFunctionWithDistributedComposition() { int portA = SocketUtils.findAvailableTcpPort(); int portB = SocketUtils.findAvailableTcpPort(); try ( @@ -303,20 +345,18 @@ public class RSocketAutoConfigurationTests { @Test public void testFireAndForgetConsumer() { + int port = SocketUtils.findAvailableTcpPort(); try ( ConfigurableApplicationContext applicationContext = new SpringApplicationBuilder(SampleFunctionConfiguration.class) .web(WebApplicationType.NONE) .run("--logging.level.org.springframework.cloud.function=DEBUG", - "--spring.cloud.function.definition=log", - "--spring.rsocket.server.port=0"); + "--spring.rsocket.server.port=" + port); ) { RSocketRequester.Builder rsocketRequesterBuilder = applicationContext.getBean(RSocketRequester.Builder.class); - RSocketServerBootstrap serverBootstrap = applicationContext.getBean(RSocketServerBootstrap.class); - RSocketServer server = (RSocketServer) ReflectionTestUtils.getField(serverBootstrap, "server"); - rsocketRequesterBuilder.tcp("localhost", server.address().getPort()) + rsocketRequesterBuilder.tcp("localhost", port) .route("log") .data("\"hello\"") .send()