diff --git a/spring-cloud-function-rsocket/src/main/java/org/springframework/cloud/function/rsocket/FunctionRSocketMessageHandler.java b/spring-cloud-function-rsocket/src/main/java/org/springframework/cloud/function/rsocket/FunctionRSocketMessageHandler.java index 6e8c2d612..0f56d80f2 100644 --- a/spring-cloud-function-rsocket/src/main/java/org/springframework/cloud/function/rsocket/FunctionRSocketMessageHandler.java +++ b/spring-cloud-function-rsocket/src/main/java/org/springframework/cloud/function/rsocket/FunctionRSocketMessageHandler.java @@ -16,9 +16,11 @@ package org.springframework.cloud.function.rsocket; +import java.lang.reflect.Field; import java.lang.reflect.Method; import java.util.Collections; import java.util.List; +import java.util.Map; import java.util.Set; import java.util.function.Function; @@ -30,6 +32,7 @@ import reactor.core.publisher.Mono; import org.springframework.cloud.function.context.FunctionCatalog; import org.springframework.cloud.function.context.catalog.SimpleFunctionRegistry.FunctionInvocationWrapper; import org.springframework.core.MethodParameter; +import org.springframework.core.ParameterizedTypeReference; import org.springframework.core.ReactiveAdapterRegistry; import org.springframework.core.ResolvableType; import org.springframework.core.codec.ByteArrayDecoder; @@ -39,16 +42,20 @@ import org.springframework.core.codec.Encoder; import org.springframework.core.io.buffer.DataBuffer; import org.springframework.lang.Nullable; import org.springframework.messaging.Message; +import org.springframework.messaging.MessageHeaders; import org.springframework.messaging.MessagingException; import org.springframework.messaging.handler.CompositeMessageCondition; import org.springframework.messaging.handler.DestinationPatternsMessageCondition; import org.springframework.messaging.handler.invocation.reactive.HandlerMethodArgumentResolver; import org.springframework.messaging.handler.invocation.reactive.HandlerMethodReturnValueHandler; import org.springframework.messaging.handler.invocation.reactive.SyncHandlerMethodArgumentResolver; +import org.springframework.messaging.rsocket.DefaultMetadataExtractor; +import org.springframework.messaging.rsocket.MetadataExtractor; import org.springframework.messaging.rsocket.annotation.support.RSocketFrameTypeMessageCondition; import org.springframework.messaging.rsocket.annotation.support.RSocketMessageHandler; import org.springframework.messaging.rsocket.annotation.support.RSocketPayloadReturnValueHandler; import org.springframework.messaging.support.MessageBuilder; +import org.springframework.util.MimeTypeUtils; import org.springframework.util.ReflectionUtils; /** @@ -63,6 +70,8 @@ class FunctionRSocketMessageHandler extends RSocketMessageHandler { private final FunctionCatalog functionCatalog; + private final Field headersField; + private static final Method FUNCTION_APPLY_METHOD = ReflectionUtils.findMethod(Function.class, "apply", (Class[]) null); @@ -76,6 +85,8 @@ class FunctionRSocketMessageHandler extends RSocketMessageHandler { FunctionRSocketMessageHandler(FunctionCatalog functionCatalog) { setHandlerPredicate((clazz) -> false); this.functionCatalog = functionCatalog; + this.headersField = ReflectionUtils.findField(MessageHeaders.class, "headers"); + this.headersField.setAccessible(true); } @@ -85,12 +96,19 @@ class FunctionRSocketMessageHandler extends RSocketMessageHandler { super.afterPropertiesSet(); } + @SuppressWarnings("unchecked") + @Override + public MetadataExtractor getMetadataExtractor() { + return new HeadersAwareMetadataExtractor((List>) this.getDecoders()); + } + /** * Will check if there is a function handler registered for destination before proceeding. * This typically happens when user avoids using 'spring.cloud.function.definition' property. */ @Override public Mono handleMessage(Message message) throws MessagingException { + if (!FrameType.SETUP.equals(message.getHeaders().get("rsocketFrameType"))) { String destination = this.getDestination(message).value(); Set mappings = this.getDestinationLookup().keySet(); @@ -165,7 +183,17 @@ class FunctionRSocketMessageHandler extends RSocketMessageHandler { } return super.handleReturnValue(returnValue, returnType, message); } + } + private static class HeadersAwareMetadataExtractor extends DefaultMetadataExtractor { + HeadersAwareMetadataExtractor(List> decoders) { + super(decoders); + super.metadataToExtract(MimeTypeUtils.APPLICATION_JSON, + new ParameterizedTypeReference>() { + }, (jsonMap, outputMap) -> { + outputMap.putAll(jsonMap); + }); + } } } diff --git a/spring-cloud-function-rsocket/src/test/java/org/springframework/cloud/function/rsocket/RSocketAutoConfigurationTests.java b/spring-cloud-function-rsocket/src/test/java/org/springframework/cloud/function/rsocket/RSocketAutoConfigurationTests.java index 8c33b3ec9..7f288379c 100644 --- a/spring-cloud-function-rsocket/src/test/java/org/springframework/cloud/function/rsocket/RSocketAutoConfigurationTests.java +++ b/spring-cloud-function-rsocket/src/test/java/org/springframework/cloud/function/rsocket/RSocketAutoConfigurationTests.java @@ -31,12 +31,14 @@ import org.springframework.boot.autoconfigure.EnableAutoConfiguration; import org.springframework.boot.builder.SpringApplicationBuilder; import org.springframework.boot.rsocket.context.RSocketServerBootstrap; import org.springframework.boot.rsocket.server.RSocketServer; +import org.springframework.cloud.function.context.config.RoutingFunction; import org.springframework.context.ConfigurableApplicationContext; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; import org.springframework.core.env.ConfigurableEnvironment; import org.springframework.messaging.rsocket.RSocketRequester; import org.springframework.test.util.ReflectionTestUtils; +import org.springframework.util.MimeTypeUtils; import org.springframework.util.SocketUtils; /** @@ -220,7 +222,6 @@ public class RSocketAutoConfigurationTests { } } -// @Disabled @Test public void testRequestReplyFunctionWithComposition() { int portA = SocketUtils.findAvailableTcpPort(); @@ -367,6 +368,32 @@ public class RSocketAutoConfigurationTests { } } + @Test + public void testRoutingWithRoutingFunction() { + int port = SocketUtils.findAvailableTcpPort(); + try ( + ConfigurableApplicationContext applicationContext = + new SpringApplicationBuilder(SampleFunctionConfiguration.class) + .web(WebApplicationType.NONE) + .run("--logging.level.org.springframework.cloud.function=DEBUG", + "--spring.cloud.function.routing-expression=headers.function_definition", + "--spring.rsocket.server.port=" + port); + ) { + RSocketRequester.Builder rsocketRequesterBuilder = + applicationContext.getBean(RSocketRequester.Builder.class); + + rsocketRequesterBuilder.tcp("localhost", port) + .route(RoutingFunction.FUNCTION_NAME) + .metadata("{\"function_definition\":\"uppercase|concat\"}", MimeTypeUtils.APPLICATION_JSON) + .data("\"hello\"") + .retrieveMono(String.class) + .as(StepVerifier::create) + .expectNext("\"HELLOHELLO\"") + .expectComplete() + .verify(); + } + } + @EnableAutoConfiguration @Configuration