diff --git a/spring-cloud-function-rsocket/src/main/java/org/springframework/cloud/function/rsocket/FunctionRSocketMessageHandler.java b/spring-cloud-function-rsocket/src/main/java/org/springframework/cloud/function/rsocket/FunctionRSocketMessageHandler.java index cfebb2b19..a7774b544 100644 --- a/spring-cloud-function-rsocket/src/main/java/org/springframework/cloud/function/rsocket/FunctionRSocketMessageHandler.java +++ b/spring-cloud-function-rsocket/src/main/java/org/springframework/cloud/function/rsocket/FunctionRSocketMessageHandler.java @@ -18,6 +18,7 @@ package org.springframework.cloud.function.rsocket; import java.lang.reflect.Field; import java.lang.reflect.Method; +import java.util.ArrayList; import java.util.Collections; import java.util.List; import java.util.Map; @@ -48,6 +49,7 @@ import org.springframework.messaging.MessageHeaders; import org.springframework.messaging.MessagingException; import org.springframework.messaging.handler.CompositeMessageCondition; import org.springframework.messaging.handler.DestinationPatternsMessageCondition; +import org.springframework.messaging.handler.MessageCondition; import org.springframework.messaging.handler.invocation.reactive.HandlerMethodArgumentResolver; import org.springframework.messaging.handler.invocation.reactive.HandlerMethodReturnValueHandler; import org.springframework.messaging.handler.invocation.reactive.SyncHandlerMethodArgumentResolver; @@ -59,6 +61,7 @@ import org.springframework.messaging.rsocket.annotation.support.RSocketPayloadRe import org.springframework.messaging.support.MessageBuilder; import org.springframework.util.MimeTypeUtils; import org.springframework.util.ReflectionUtils; +import org.springframework.util.RouteMatcher; import org.springframework.util.RouteMatcher.Route; import org.springframework.util.StringUtils; import org.springframework.web.util.pattern.PathPatternRouteMatcher; @@ -73,6 +76,8 @@ import org.springframework.web.util.pattern.PathPatternRouteMatcher; */ class FunctionRSocketMessageHandler extends RSocketMessageHandler { + public static final String RECONSILED_LOOKUP_DESTINATION_HEADER = "reconsiledLookupDestination"; + private final FunctionCatalog functionCatalog; private final FunctionProperties functionProperties; @@ -130,6 +135,27 @@ class FunctionRSocketMessageHandler extends RSocketMessageHandler { return super.handleMessage(message); } + @Override + protected RouteMatcher.Route getDestination(Message message) { + RouteMatcher.Route reconsiledDestination = (RouteMatcher.Route) message.getHeaders().get(RECONSILED_LOOKUP_DESTINATION_HEADER); + return reconsiledDestination == null ? super.getDestination(message) : reconsiledDestination; + } + + @Override + protected CompositeMessageCondition getMatchingMapping(CompositeMessageCondition mapping, Message message) { + List> result = new ArrayList<>(mapping.getMessageConditions().size()); + for (MessageCondition condition : mapping.getMessageConditions()) { + MessageCondition matchingCondition = condition instanceof DestinationPatternsMessageCondition + ? condition + : (MessageCondition) condition.getMatchingCondition(message); + if (matchingCondition == null) { + return null; + } + result.add(matchingCondition); + } + return new CompositeMessageCondition(result.toArray(new MessageCondition[0])); + } + void registerFunctionHandler(Function function, String route) { CompositeMessageCondition condition = new CompositeMessageCondition(REQUEST_CONDITION, @@ -180,7 +206,7 @@ class FunctionRSocketMessageHandler extends RSocketMessageHandler { Map headersMap = (Map) ReflectionUtils .getField(this.headersField, message.getHeaders()); PathPatternRouteMatcher matcher = new PathPatternRouteMatcher(); - headersMap.put(DestinationPatternsMessageCondition.LOOKUP_DESTINATION_HEADER, matcher.parseRoute(destination)); + headersMap.put(RECONSILED_LOOKUP_DESTINATION_HEADER, matcher.parseRoute(destination)); } protected static final class MessageHandlerMethodArgumentResolver implements SyncHandlerMethodArgumentResolver { diff --git a/spring-cloud-function-rsocket/src/main/java/org/springframework/cloud/function/rsocket/RSocketListenerFunction.java b/spring-cloud-function-rsocket/src/main/java/org/springframework/cloud/function/rsocket/RSocketListenerFunction.java index ac022de24..04ca9027e 100644 --- a/spring-cloud-function-rsocket/src/main/java/org/springframework/cloud/function/rsocket/RSocketListenerFunction.java +++ b/spring-cloud-function-rsocket/src/main/java/org/springframework/cloud/function/rsocket/RSocketListenerFunction.java @@ -96,7 +96,13 @@ class RSocketListenerFunction implements Function>, Publish } else { dataFlux = dataFlux.flatMap((data) -> { - Object result = this.targetFunction.isSupplier() ? this.targetFunction.apply(null) : this.targetFunction.apply(data); + Message incoming = (Message) data; + Message sanitizedMessage = MessageBuilder.withPayload(incoming.getPayload()).copyHeaders(incoming.getHeaders()) + .removeHeader("dataBufferFactory") + .removeHeader("rsocketRequester") + .removeHeader("rsocketResponse") + .build(); + Object result = this.targetFunction.isSupplier() ? this.targetFunction.apply(null) : this.targetFunction.apply(sanitizedMessage); return result instanceof Publisher ? (Publisher>) result : Mono.just((Message) result); diff --git a/spring-cloud-function-rsocket/src/test/java/org/springframework/cloud/function/rsocket/RSocketAutoConfigurationRoutingTests.java b/spring-cloud-function-rsocket/src/test/java/org/springframework/cloud/function/rsocket/RSocketAutoConfigurationRoutingTests.java index 8ec7fd459..287f9b68a 100644 --- a/spring-cloud-function-rsocket/src/test/java/org/springframework/cloud/function/rsocket/RSocketAutoConfigurationRoutingTests.java +++ b/spring-cloud-function-rsocket/src/test/java/org/springframework/cloud/function/rsocket/RSocketAutoConfigurationRoutingTests.java @@ -16,6 +16,7 @@ package org.springframework.cloud.function.rsocket; + import java.util.function.Consumer; import java.util.function.Function; import java.util.function.Supplier; @@ -32,10 +33,14 @@ import org.springframework.cloud.function.context.config.RoutingFunction; import org.springframework.context.ConfigurableApplicationContext; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; +import org.springframework.messaging.Message; +import org.springframework.messaging.handler.DestinationPatternsMessageCondition; import org.springframework.messaging.rsocket.RSocketRequester; import org.springframework.util.MimeTypeUtils; import org.springframework.util.SocketUtils; +import static org.assertj.core.api.Assertions.assertThat; + /** * * @author Oleg Zhurakousky @@ -139,6 +144,35 @@ public class RSocketAutoConfigurationRoutingTests { } } + @Test + public void testRoutingWithDefinitionMessageFunction() { + int port = SocketUtils.findAvailableTcpPort(); + try ( + ConfigurableApplicationContext applicationContext = + new SpringApplicationBuilder(SampleFunctionConfiguration.class) + .web(WebApplicationType.NONE) + .run("--logging.level.org.springframework.cloud.function=DEBUG", + "--spring.cloud.function.definition=uppercase", + "--spring.cloud.function.routing-expression=headers.func_name", + "--spring.cloud.function.expected-content-type=text/plain", + "--spring.rsocket.server.port=" + port); + ) { + RSocketRequester.Builder rsocketRequesterBuilder = + applicationContext.getBean(RSocketRequester.Builder.class); + + rsocketRequesterBuilder.tcp("localhost", port) + .route("uppercase") + .metadata("{\"func_name\":\"uppercaseMessage\"}", MimeTypeUtils.APPLICATION_JSON) + .data("hello") + .retrieveMono(String.class) + .as(StepVerifier::create) + .expectNext("HELLO") + .expectComplete() + .verify(); + + } + } + @EnableAutoConfiguration @Configuration public static class SampleFunctionConfiguration { @@ -150,6 +184,17 @@ public class RSocketAutoConfigurationRoutingTests { return v -> v.toUpperCase(); } + @Bean + public Function, String> uppercaseMessage() { + return msg -> { + assertThat(msg.getHeaders() + .get(DestinationPatternsMessageCondition.LOOKUP_DESTINATION_HEADER)).toString().equals("uppercase"); + assertThat(msg.getHeaders() + .get(FunctionRSocketMessageHandler.RECONSILED_LOOKUP_DESTINATION_HEADER)).toString().equals(RoutingFunction.FUNCTION_NAME); + return msg.getPayload().toUpperCase(); + }; + } + @Bean public Function concat() { return v -> v + v;