diff --git a/spring-cloud-function-rsocket/pom.xml b/spring-cloud-function-rsocket/pom.xml index 2e4c0c323..c8906b6ed 100644 --- a/spring-cloud-function-rsocket/pom.xml +++ b/spring-cloud-function-rsocket/pom.xml @@ -20,7 +20,16 @@ - + + org.springframework.boot + spring-boot-starter-rsocket + + + com.fasterxml.jackson.dataformat + jackson-dataformat-cbor + + + io.rsocket rsocket-core diff --git a/spring-cloud-function-rsocket/src/main/java/org/springframework/cloud/function/rsocket/FunctionRSocketMessageHandler.java b/spring-cloud-function-rsocket/src/main/java/org/springframework/cloud/function/rsocket/FunctionRSocketMessageHandler.java new file mode 100644 index 000000000..c7492b5d7 --- /dev/null +++ b/spring-cloud-function-rsocket/src/main/java/org/springframework/cloud/function/rsocket/FunctionRSocketMessageHandler.java @@ -0,0 +1,144 @@ +/* + * Copyright 2020-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.cloud.function.rsocket; + +import java.lang.reflect.Method; +import java.util.Collections; +import java.util.List; +import java.util.function.Function; + +import io.rsocket.frame.FrameType; +import org.reactivestreams.Publisher; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +import org.springframework.core.MethodParameter; +import org.springframework.core.ReactiveAdapterRegistry; +import org.springframework.core.ResolvableType; +import org.springframework.core.codec.ByteArrayDecoder; +import org.springframework.core.codec.ByteArrayEncoder; +import org.springframework.core.codec.Decoder; +import org.springframework.core.codec.Encoder; +import org.springframework.core.io.buffer.DataBuffer; +import org.springframework.lang.Nullable; +import org.springframework.messaging.Message; +import org.springframework.messaging.handler.CompositeMessageCondition; +import org.springframework.messaging.handler.DestinationPatternsMessageCondition; +import org.springframework.messaging.handler.invocation.reactive.HandlerMethodArgumentResolver; +import org.springframework.messaging.handler.invocation.reactive.HandlerMethodReturnValueHandler; +import org.springframework.messaging.handler.invocation.reactive.SyncHandlerMethodArgumentResolver; +import org.springframework.messaging.rsocket.annotation.support.RSocketFrameTypeMessageCondition; +import org.springframework.messaging.rsocket.annotation.support.RSocketMessageHandler; +import org.springframework.messaging.rsocket.annotation.support.RSocketPayloadReturnValueHandler; +import org.springframework.messaging.support.MessageBuilder; +import org.springframework.util.ReflectionUtils; + +/** + * An {@link RSocketMessageHandler} extension for Spring Cloud Function specifics. + * + * @author Artem Bilan + * + * @since 3.1 + */ +public class FunctionRSocketMessageHandler extends RSocketMessageHandler { + + private static final Method FUNCTION_APPLY_METHOD = + ReflectionUtils.findMethod(Function.class, "apply", (Class[]) null); + + private static final RSocketFrameTypeMessageCondition REQUEST_CONDITION = + new RSocketFrameTypeMessageCondition( + FrameType.REQUEST_FNF, + FrameType.REQUEST_RESPONSE, + FrameType.REQUEST_STREAM, + FrameType.REQUEST_CHANNEL); + + public FunctionRSocketMessageHandler() { + setHandlerPredicate((clazz) -> false); + } + + + @Override + public void afterPropertiesSet() { + setEncoders(Collections.singletonList(new ByteArrayEncoder())); + super.afterPropertiesSet(); + } + + public void registerFunctionHandler(Function function, String route) { + CompositeMessageCondition condition = + new CompositeMessageCondition(REQUEST_CONDITION, + new DestinationPatternsMessageCondition(new String[]{ route }, + obtainRouteMatcher())); + registerHandlerMethod(function, FUNCTION_APPLY_METHOD, condition); + } + + @Override + protected List initArgumentResolvers() { + return Collections.singletonList(new MessageHandlerMethodArgumentResolver()); + } + + @SuppressWarnings("unchecked") + @Override + protected List initReturnValueHandlers() { + return Collections.singletonList(new FunctionRSocketPayloadReturnValueHandler((List>) getEncoders(), + getReactiveAdapterRegistry())); + } + + protected static final class MessageHandlerMethodArgumentResolver implements SyncHandlerMethodArgumentResolver { + + private final Decoder decoder = new ByteArrayDecoder(); + + @Override + public boolean supportsParameter(MethodParameter parameter) { + return true; + } + + @SuppressWarnings("unchecked") + @Override + public Object resolveArgumentValue(MethodParameter parameter, Message message) { + Flux data; + Object payload = message.getPayload(); + if (payload instanceof DataBuffer) { + data = Flux.just((DataBuffer) payload); + } + else { + data = Flux.from((Publisher) payload); + } + Flux decoded = this.decoder.decode(data, ResolvableType.forType(byte[].class), null, null); + return MessageBuilder.createMessage(decoded, message.getHeaders()); + } + + } + + protected static final class FunctionRSocketPayloadReturnValueHandler extends RSocketPayloadReturnValueHandler { + + public FunctionRSocketPayloadReturnValueHandler(List> encoders, ReactiveAdapterRegistry registry) { + super(encoders, registry); + } + + @Override + public Mono handleReturnValue(@Nullable Object returnValue, MethodParameter returnType, + Message message) { + + if (returnValue instanceof Publisher && !message.getHeaders().containsKey(RESPONSE_HEADER)) { + return Mono.from((Publisher) returnValue).then(); + } + return super.handleReturnValue(returnValue, returnType, message); + } + + } + +} diff --git a/spring-cloud-function-rsocket/src/main/java/org/springframework/cloud/function/rsocket/RSocketAutoConfiguration.java b/spring-cloud-function-rsocket/src/main/java/org/springframework/cloud/function/rsocket/RSocketAutoConfiguration.java index c295d7dc1..f32902874 100644 --- a/spring-cloud-function-rsocket/src/main/java/org/springframework/cloud/function/rsocket/RSocketAutoConfiguration.java +++ b/spring-cloud-function-rsocket/src/main/java/org/springframework/cloud/function/rsocket/RSocketAutoConfiguration.java @@ -16,32 +16,32 @@ package org.springframework.cloud.function.rsocket; -import java.net.InetSocketAddress; +import java.net.URI; +import java.util.regex.Pattern; -import io.rsocket.RSocket; -import io.rsocket.SocketAcceptor; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.springframework.beans.BeansException; -import org.springframework.beans.factory.InitializingBean; +import org.springframework.beans.factory.ObjectProvider; import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean; import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty; +import org.springframework.boot.autoconfigure.rsocket.RSocketMessageHandlerCustomizer; import org.springframework.boot.context.properties.EnableConfigurationProperties; -import org.springframework.boot.rsocket.context.RSocketServerBootstrap; -import org.springframework.boot.rsocket.server.RSocketServerFactory; import org.springframework.cloud.function.context.FunctionCatalog; import org.springframework.cloud.function.context.FunctionProperties; import org.springframework.cloud.function.context.FunctionRegistration; import org.springframework.cloud.function.context.FunctionRegistry; import org.springframework.cloud.function.context.catalog.FunctionTypeUtils; import org.springframework.cloud.function.context.catalog.SimpleFunctionRegistry.FunctionInvocationWrapper; -import org.springframework.cloud.function.json.JsonMapper; import org.springframework.context.ApplicationContext; import org.springframework.context.ApplicationContextAware; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; -import org.springframework.context.support.GenericApplicationContext; +import org.springframework.context.annotation.Primary; +import org.springframework.messaging.rsocket.RSocketRequester; +import org.springframework.messaging.rsocket.RSocketRequester.Builder; +import org.springframework.messaging.rsocket.RSocketStrategies; import org.springframework.util.Assert; import org.springframework.util.StringUtils; @@ -55,104 +55,89 @@ import org.springframework.util.StringUtils; @Configuration(proxyBeanMethods = false) @EnableConfigurationProperties({ FunctionProperties.class, RSocketFunctionProperties.class }) @ConditionalOnProperty(name = FunctionProperties.PREFIX + ".rsocket.enabled", matchIfMissing = true) -class RSocketAutoConfiguration { +class RSocketAutoConfiguration implements ApplicationContextAware { - private static Log logger = LogFactory.getLog(RSocketAutoConfiguration.class); + private static final Log LOGGER = LogFactory.getLog(RSocketAutoConfiguration.class); - @Bean - public FunctionToRSocketBinder functionToDestinationBinder(FunctionCatalog functionCatalog, - FunctionProperties functionProperties, JsonMapper jsonMapper) { - return new FunctionToRSocketBinder(functionCatalog, functionProperties, jsonMapper); + private static final Pattern WS_URI_PATTERN = Pattern.compile("^(https?|wss?)://.+"); + + private ApplicationContext applicationContext; + + @Override + public void setApplicationContext(ApplicationContext applicationContext) throws BeansException { + this.applicationContext = applicationContext; } @Bean @ConditionalOnMissingBean - @ConditionalOnProperty("spring.rsocket.server.port") - RSocketServerBootstrap rSocketServerBootstrap(RSocketServerFactory rSocketServerFactory, - FunctionToRSocketBinder binder) { - return new RSocketServerBootstrap(rSocketServerFactory, SocketAcceptor.with(binder.getRSocket())); + @Primary + public FunctionRSocketMessageHandler functionRSocketMessageHandler(RSocketStrategies rSocketStrategies, + ObjectProvider customizers, FunctionCatalog functionCatalog, + FunctionProperties functionProperties) { + + FunctionRSocketMessageHandler rsocketMessageHandler = new FunctionRSocketMessageHandler(); + rsocketMessageHandler.setRSocketStrategies(rSocketStrategies); + customizers.orderedStream().forEach((customizer) -> customizer.customize(rsocketMessageHandler)); + registerFunctionsWithRSocketHandler(rsocketMessageHandler, functionCatalog, functionProperties); + return rsocketMessageHandler; } - /** - * - */ - static class FunctionToRSocketBinder implements InitializingBean, ApplicationContextAware { - - private final FunctionCatalog functionCatalog; - - private final FunctionProperties functionProperties; - - private final JsonMapper jsonMapper; - - private RSocketListenerFunction invocableFunction; - - private GenericApplicationContext context; - - FunctionToRSocketBinder(FunctionCatalog functionCatalog, FunctionProperties functionProperties, JsonMapper jsonMapper) { - this.functionCatalog = functionCatalog; - this.functionProperties = functionProperties; - this.jsonMapper = jsonMapper; - } - - @Override - public void afterPropertiesSet() throws Exception { - String definition = this.functionProperties.getDefinition(); - if (!StringUtils.hasText(definition)) { - FunctionInvocationWrapper f = this.functionCatalog.lookup(""); - if (f != null) { - definition = f.getFunctionDefinition(); - } - } - Assert.isTrue(StringUtils.hasText(definition), "Failed to determine target function for RSocket."); - this.registerRsocketForwardingFunctionIfNecessary(definition); - // TODO externalize content-type + private void registerFunctionsWithRSocketHandler(FunctionRSocketMessageHandler rsocketMessageHandler, + FunctionCatalog functionCatalog, FunctionProperties functionProperties) { + String definition = functionProperties.getDefinition(); + if (StringUtils.hasText(definition)) { + String rootFunctionName = registerRSocketForwardingFunctionIfNecessary(definition, functionCatalog); + //TODO externalize content-type FunctionInvocationWrapper function = functionCatalog.lookup(definition, "application/json"); - if (function.isSupplier()) { - throw new UnsupportedOperationException("Supplier is not currently supported for RSocket interaction"); - } - - this.invocableFunction = new RSocketListenerFunction(function, this.jsonMapper); + rsocketMessageHandler.registerFunctionHandler(new RSocketListenerFunction(function), rootFunctionName); } - - RSocket getRSocket() { - if (this.invocableFunction == null) { - return null; - } - return this.invocableFunction.getRsocket(); + else { + functionCatalog.getNames(null) + .forEach((name) -> { + FunctionInvocationWrapper function = functionCatalog.lookup(name, "application/json"); + rsocketMessageHandler.registerFunctionHandler(new RSocketListenerFunction(function), name); + }); } + } - @SuppressWarnings({ "rawtypes", "unchecked" }) - private void registerRsocketForwardingFunctionIfNecessary(String definition) { - String[] names = StringUtils.delimitedListToStringArray(definition.replaceAll(",", "|").trim(), "|"); - - for (String name : names) { - if (!this.context.containsBean(name)) { // this means RSocket - if (logger.isDebugEnabled()) { - logger.debug("Registering rsocket forwarder for '" + name + "' function."); - } - String[] functionToRSocketDefinition = StringUtils.delimitedListToStringArray(name, ">"); - Assert.isTrue(functionToRSocketDefinition.length == 2, "Must only contain one output redirect"); - FunctionInvocationWrapper function = functionCatalog.lookup(functionToRSocketDefinition[0], - "application/json"); - - String[] hostPort = StringUtils.delimitedListToStringArray(functionToRSocketDefinition[1], ":"); - InetSocketAddress outputAddress = InetSocketAddress.createUnresolved(hostPort[0], - Integer.valueOf(hostPort[1])); - - RSocketForwardingFunction rsocketFunction = new RSocketForwardingFunction(function, outputAddress); - FunctionRegistration functionRegistration = new FunctionRegistration(rsocketFunction, name); - - functionRegistration - .type(FunctionTypeUtils.discoverFunctionTypeFromClass(RSocketListenerFunction.class)); - ((FunctionRegistry) this.functionCatalog).register(functionRegistration); + private String registerRSocketForwardingFunctionIfNecessary(String definition, FunctionCatalog functionCatalog) { + String[] names = StringUtils.delimitedListToStringArray(definition.replaceAll(",", "|").trim(), "|"); + String rootFunctionName = names[0]; + for (String name : names) { + if (!this.applicationContext.containsBean(name)) { // this means RSocket + if (LOGGER.isDebugEnabled()) { + LOGGER.debug("Registering RSocket forwarder for '" + name + "' function."); } + String[] functionToRSocketDefinition = StringUtils.delimitedListToStringArray(name, ">"); + Assert.isTrue(functionToRSocketDefinition.length == 2, "Must only contain one output redirect"); + FunctionInvocationWrapper function = functionCatalog.lookup(functionToRSocketDefinition[0], "application/json"); + + String[] hostPort = StringUtils.delimitedListToStringArray(functionToRSocketDefinition[1], ":"); + + rootFunctionName = function.getFunctionDefinition(); + String forwardingUrl = functionToRSocketDefinition[1]; + RSocketRequester rsocketRequester; + + Builder rsocketRequesterBuilder = RSocketRequester.builder(); + + if (WS_URI_PATTERN.matcher(forwardingUrl).matches()) { + rsocketRequester = rsocketRequesterBuilder.websocket(URI.create(forwardingUrl)); + } + else { + rsocketRequester = rsocketRequesterBuilder.tcp(hostPort[0], Integer.parseInt(hostPort[1])); + } + + RSocketForwardingFunction rsocketFunction = + new RSocketForwardingFunction(function, rsocketRequester, null); + FunctionRegistration functionRegistration = + new FunctionRegistration<>(rsocketFunction, name); + functionRegistration.type( + FunctionTypeUtils.discoverFunctionTypeFromClass(RSocketForwardingFunction.class)); + ((FunctionRegistry) functionCatalog).register(functionRegistration); } } - @Override - public void setApplicationContext(ApplicationContext applicationContext) throws BeansException { - this.context = (GenericApplicationContext) applicationContext; - } + return rootFunctionName; } } diff --git a/spring-cloud-function-rsocket/src/main/java/org/springframework/cloud/function/rsocket/RSocketForwardingFunction.java b/spring-cloud-function-rsocket/src/main/java/org/springframework/cloud/function/rsocket/RSocketForwardingFunction.java index 99dbb759c..96dabd2a8 100644 --- a/spring-cloud-function-rsocket/src/main/java/org/springframework/cloud/function/rsocket/RSocketForwardingFunction.java +++ b/spring-cloud-function-rsocket/src/main/java/org/springframework/cloud/function/rsocket/RSocketForwardingFunction.java @@ -16,38 +16,25 @@ package org.springframework.cloud.function.rsocket; -import java.net.InetSocketAddress; -import java.nio.ByteBuffer; -import java.time.Duration; import java.util.function.Function; -import io.rsocket.Payload; -import io.rsocket.RSocket; -import io.rsocket.core.RSocketConnector; -import io.rsocket.transport.netty.client.TcpClientTransport; -import io.rsocket.util.DefaultPayload; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.reactivestreams.Publisher; import reactor.core.publisher.Mono; -import reactor.util.retry.Retry; import org.springframework.cloud.function.context.catalog.SimpleFunctionRegistry.FunctionInvocationWrapper; import org.springframework.messaging.Message; -import org.springframework.messaging.support.MessageBuilder; +import org.springframework.messaging.rsocket.RSocketRequester; +import org.springframework.messaging.support.GenericMessage; /** - * - * An implementation of {@link Function} to support distributed function composition. - *
- * This function wraps target function and forwards the result of - * the invocation of the target function to another RSocket returning the result of such forwarding as {@link Publisher}. - *

- * A typical example is `spring.cloud.function.definition=uppercase>localhost:8888'. - *
- * In this case 'uppercase' is targetFunction which will be invoked during the call to 'apply' and the result of - * this invocation sent to RSocket reachable at localhost:8888. + * Wrapper over an instance of target Function (represented by {@link FunctionInvocationWrapper}) + * which will use the result of the invocation of such function as an input to another RSocket + * effectively composing two functions over RSocket. + *

+ * Note: the remote RSocket route is not necessary to be as a Spring Cloud Function binding. * * @author Oleg Zhurakousky * @author Artem Bilan @@ -59,39 +46,37 @@ class RSocketForwardingFunction implements Function, Publisher rsocketMono; - private final FunctionInvocationWrapper targetFunction; - RSocketForwardingFunction(FunctionInvocationWrapper targetFunction, InetSocketAddress outputAddress) { + private final RSocketRequester rSocketRequester; + +// private final String remoteFunctionName; + + RSocketForwardingFunction(FunctionInvocationWrapper targetFunction, RSocketRequester rsocketRequester, + String remoteFunctionName) { + this.targetFunction = targetFunction; - this.rsocketMono = - outputAddress == null - ? null - : RSocketConnector.create() - .reconnect(Retry.backoff(5, Duration.ofSeconds(1))) - .connect(TcpClientTransport.create(outputAddress)); + this.rSocketRequester = rsocketRequester; +// this.remoteFunctionName = remoteFunctionName; } - @SuppressWarnings("unchecked") @Override public Publisher> apply(Message input) { if (LOGGER.isDebugEnabled()) { LOGGER.debug("Executing: " + this.targetFunction); } - Object rawResult = this.targetFunction.apply(input); - return this.rsocketMono - .flatMapMany((rsocket) -> - rsocket.requestStream(DefaultPayload.create(((Message) rawResult).getPayload()))) - .map(this::buildResultMessage); - } + Mono targetFunctionCall = Mono.just(input) + .map(this.targetFunction) + .cast(Message.class) + .map(Message::getPayload); - private Message buildResultMessage(Payload payload) { - ByteBuffer payloadBuffer = payload.getData(); - byte[] payloadData = new byte[payloadBuffer.remaining()]; - payloadBuffer.get(payloadData); - return MessageBuilder.withPayload(payloadData).build(); + return this.rSocketRequester +// .route(this.remoteFunctionName) + .route("uppercase") + .data(targetFunctionCall, byte[].class) + .retrieveFlux(byte[].class) + .map(GenericMessage::new); } } diff --git a/spring-cloud-function-rsocket/src/main/java/org/springframework/cloud/function/rsocket/RSocketListenerFunction.java b/spring-cloud-function-rsocket/src/main/java/org/springframework/cloud/function/rsocket/RSocketListenerFunction.java index 1e0ab0770..8f2865378 100644 --- a/spring-cloud-function-rsocket/src/main/java/org/springframework/cloud/function/rsocket/RSocketListenerFunction.java +++ b/spring-cloud-function-rsocket/src/main/java/org/springframework/cloud/function/rsocket/RSocketListenerFunction.java @@ -17,181 +17,93 @@ package org.springframework.cloud.function.rsocket; import java.lang.reflect.Type; -import java.nio.ByteBuffer; -import java.util.Map; import java.util.function.Function; -import io.rsocket.Payload; -import io.rsocket.RSocket; -import io.rsocket.util.DefaultPayload; -import org.apache.commons.logging.Log; -import org.apache.commons.logging.LogFactory; +import io.rsocket.frame.FrameType; import org.reactivestreams.Publisher; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; import org.springframework.cloud.function.context.catalog.FunctionTypeUtils; import org.springframework.cloud.function.context.catalog.SimpleFunctionRegistry.FunctionInvocationWrapper; -import org.springframework.cloud.function.json.JsonMapper; import org.springframework.messaging.Message; +import org.springframework.messaging.rsocket.annotation.support.RSocketFrameTypeMessageCondition; import org.springframework.messaging.support.MessageBuilder; -import org.springframework.util.CollectionUtils; + + /** - * Wrapper over an instance of target Function (represented by {@link FunctionInvocationWrapper}) - * which will use the result of the invocation of such function as an input to another RSocket - * effectively composing two functions over RSocket. + * A function wrapper which is bound onto an RSocket route. * * @author Oleg Zhurakousky + * @author Artem Bilan + * * @since 3.1 */ -class RSocketListenerFunction implements Function, Publisher>> { - - private static String splash = " ____ _ _______ __ ____ __ _ ___ ____ __ __ \n" + - " / __/__ ____(_)__ ___ _ / ___/ /__ __ _____/ / / __/_ _____ ____/ /_(_)__ ___ / _ \\/ __/__ ____/ /_____ / /_\n" + - " _\\ \\/ _ \\/ __/ / _ \\/ _ `/ / /__/ / _ \\/ // / _ / / _// // / _ \\/ __/ __/ / _ \\/ _ \\ / , _/\\ \\/ _ \\/ __/ '_/ -_) __/\n" + - "/___/ .__/_/ /_/_//_/\\_, / \\___/_/\\___/\\_,_/\\_,_/ /_/ \\_,_/_//_/\\__/\\__/_/\\___/_//_/ /_/|_/___/\\___/\\__/_/\\_\\\\__/\\__/ \n" + - " /_/ /___/ \n" + - ""; - - private static Log logger = LogFactory.getLog(RSocketListenerFunction.class); +public class RSocketListenerFunction implements Function>, Publisher> { private final FunctionInvocationWrapper targetFunction; - private RSocket rsocket; - - private final JsonMapper jsonMapper; - - RSocketListenerFunction(FunctionInvocationWrapper targetFunction, JsonMapper jsonMapper) { + RSocketListenerFunction(FunctionInvocationWrapper targetFunction) { this.targetFunction = targetFunction; - this.jsonMapper = jsonMapper; } - @SuppressWarnings("unchecked") @Override - public Publisher> apply(Message input) { - if (logger.isDebugEnabled()) { - logger.debug("Executiing: " + this.targetFunction); + public Publisher apply(Message> input) { + FrameType frameType = RSocketFrameTypeMessageCondition.getFrameType(input); + switch (frameType) { + case REQUEST_FNF: + return handle(input); + case REQUEST_RESPONSE: + case REQUEST_STREAM: + case REQUEST_CHANNEL: + return handleAndReply(input); + default: + throw new UnsupportedOperationException(); } - - Object rawResult = this.targetFunction.apply(input); - return rawResult instanceof Publisher ? (Publisher>) rawResult : Mono.just((Message) rawResult); } - public RSocket getRsocket() { - if (this.rsocket == null) { - Type functionType = this.targetFunction.getFunctionType(); - - if (this.rsocket == null) { - this.rsocket = this.buildRSocket(this.targetFunction, functionType, this); + @SuppressWarnings({ "unchecked", "rawtypes" }) + private Mono handle(Message> messageToProcess) { + if (this.targetFunction.isConsumer()) { + Flux dataFlux = + messageToProcess.getPayload() + .map((payload) -> MessageBuilder.createMessage(payload, messageToProcess.getHeaders())); + if (isFunctionInputReactive(this.targetFunction.getFunctionType())) { + dataFlux = dataFlux.transform((Function) this.targetFunction); } - this.printSplashScreen(this.targetFunction.getFunctionDefinition(), functionType); + else { + dataFlux = dataFlux.doOnNext(this.targetFunction); + } + return dataFlux.then(); + } + else { + return Mono.error(new IllegalStateException("Only 'Consumer' can handle 'fire-and-forget' RSocket frame.")); } - return this.rsocket; } - private RSocket buildRSocket(FunctionInvocationWrapper targetFunction, Type functionType, Function, Publisher>> function) { - String definition = targetFunction.getFunctionDefinition(); - RSocket clientRSocket = new RSocket() { // imperative function or Function = requestResponse - @Override - public Mono requestResponse(Payload payload) { - if (logger.isDebugEnabled()) { - logger.debug("Invoking function '" + definition + "' as RSocket `requestResponse`."); - } - - if (isFunctionReactive(functionType)) { - Flux result = this.requestChannel(Flux.just(payload)); - return Mono.from(result); - } - else { - Message inputMessage = deserealizePayload(payload); - Mono> result = Mono.from(function.apply(inputMessage)); - return result.map(message -> DefaultPayload.create(message.getPayload(), jsonMapper.toJson(message.getHeaders()))); - } - } - - @Override - public Flux requestStream(Payload payload) { - if (logger.isDebugEnabled()) { - logger.debug("Invoking function '" + definition + "' as RSocket `requestStream`."); - } - if (isFunctionReactive(functionType)) { - return this.requestChannel(Flux.just(payload)); - } - else { - Message inputMessage = deserealizePayload(payload); - Flux> result = Flux.from(function.apply(inputMessage)); - return result.map(message -> DefaultPayload.create(message.getPayload())); - } - } - - @SuppressWarnings({ "unchecked", "rawtypes" }) - @Override - public Flux requestChannel(Publisher payloads) { - if (logger.isDebugEnabled()) { - logger.debug("Invoking function '" + definition + "' as RSocket `requestChannel`."); - } - if (isFunctionReactive(functionType)) { - return Flux.from(payloads) - .transform(inputFlux -> inputFlux.map(payload -> deserealizePayload(payload))) - .transform((Function) targetFunction) - .transform(outputFlux -> ((Flux>) outputFlux).map(message -> DefaultPayload.create(message.getPayload()))); - } - else { - return Flux.from(payloads) - .transform(flux -> { - return flux.flatMap(payload -> { - Message inputMessage = deserealizePayload(payload); - Flux> result = Flux.from(function.apply(inputMessage)); - return result; - }); - }) - .doOnNext(System.out::println) - .transform(outputFlux -> outputFlux.map(message -> DefaultPayload.create(message.getPayload()))); - } - - } - }; - return clientRSocket; + @SuppressWarnings({ "unchecked", "rawtypes" }) + private Flux handleAndReply(Message> messageToProcess) { + Flux dataFlux = + messageToProcess.getPayload() + .map((payload) -> MessageBuilder.createMessage(payload, messageToProcess.getHeaders())); + if (isFunctionInputReactive(this.targetFunction.getFunctionType())) { + dataFlux = dataFlux.transform((Function) this.targetFunction); + } + else { + dataFlux = dataFlux.flatMap((data) -> { + Object result = this.targetFunction.apply(data); + return result instanceof Publisher + ? (Publisher>) result + : Mono.just((Message) result); + }); + } + return dataFlux.cast(Message.class).map(Message::getPayload); } - private static boolean isFunctionReactive(Type functionType) { + private static boolean isFunctionInputReactive(Type functionType) { Type inputType = FunctionTypeUtils.getInputType(functionType, 0); - Type outputType = FunctionTypeUtils.getOutputType(functionType, 0); - return FunctionTypeUtils.isPublisher(inputType) && FunctionTypeUtils.isFlux(outputType); - } - - @SuppressWarnings({ "rawtypes", "unchecked" }) - private Message deserealizePayload(Payload payload) { - ByteBuffer buffer = payload.getData(); - byte[] rawData = new byte[buffer.remaining()]; - buffer.get(rawData); - Map headers = null; - if (payload.hasMetadata()) { - try { - ByteBuffer metadata = payload.getMetadata(); - byte[] metadataBytes = new byte[metadata.remaining()]; - metadata.get(metadataBytes); - headers = this.jsonMapper.fromJson(metadataBytes, Map.class); - } - catch (Exception e) { - //throw new IllegalStateException(e); - logger.warn("Failed to extract headers from metadata", e); - } - } - MessageBuilder builder = MessageBuilder.withPayload(rawData); - if (!CollectionUtils.isEmpty(headers)) { - builder.copyHeaders(headers); - } - Message inputMessage = builder.build(); - return inputMessage; - - } - - private void printSplashScreen(String definition, Type type) { - System.out.println(splash); - System.out.println("Function Definition: " + definition + "; T[" + type + "]"); - System.out.println("======================================================\n"); + return FunctionTypeUtils.isPublisher(inputType); } } diff --git a/spring-cloud-function-rsocket/src/main/java/org/springframework/cloud/function/rsocket/RSocketRoutingAutoConfiguration.java b/spring-cloud-function-rsocket/src/main/java/org/springframework/cloud/function/rsocket/RSocketRoutingAutoConfiguration.java index db0131cd0..d5e5feaed 100644 --- a/spring-cloud-function-rsocket/src/main/java/org/springframework/cloud/function/rsocket/RSocketRoutingAutoConfiguration.java +++ b/spring-cloud-function-rsocket/src/main/java/org/springframework/cloud/function/rsocket/RSocketRoutingAutoConfiguration.java @@ -16,7 +16,6 @@ package org.springframework.cloud.function.rsocket; -import io.rsocket.SocketAcceptor; import io.rsocket.routing.client.spring.RoutingClientAutoConfiguration; import org.springframework.boot.autoconfigure.AutoConfigureAfter; @@ -24,7 +23,6 @@ import org.springframework.boot.autoconfigure.AutoConfigureBefore; import org.springframework.boot.autoconfigure.condition.ConditionalOnClass; import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty; import org.springframework.cloud.function.context.FunctionProperties; -import org.springframework.cloud.function.rsocket.RSocketAutoConfiguration.FunctionToRSocketBinder; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; import org.springframework.messaging.rsocket.RSocketConnectorConfigurer; @@ -45,8 +43,8 @@ class RSocketRoutingAutoConfiguration { @Bean public RSocketConnectorConfigurer functionRSocketConnectorConfigurer( - FunctionToRSocketBinder binder) { - return connector -> connector.acceptor(SocketAcceptor.with(binder.getRSocket())); + FunctionRSocketMessageHandler handler) { + return connector -> connector.acceptor(handler.responder()); } } diff --git a/spring-cloud-function-rsocket/src/test/java/org/springframework/cloud/function/rsocket/RSocketAutoConfigurationTests.java b/spring-cloud-function-rsocket/src/test/java/org/springframework/cloud/function/rsocket/RSocketAutoConfigurationTests.java index be09b54ae..41857db03 100644 --- a/spring-cloud-function-rsocket/src/test/java/org/springframework/cloud/function/rsocket/RSocketAutoConfigurationTests.java +++ b/spring-cloud-function-rsocket/src/test/java/org/springframework/cloud/function/rsocket/RSocketAutoConfigurationTests.java @@ -18,23 +18,25 @@ package org.springframework.cloud.function.rsocket; import java.util.function.Consumer; import java.util.function.Function; +import java.util.function.Supplier; -import io.rsocket.Payload; -import io.rsocket.util.DefaultPayload; +import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; +import reactor.core.publisher.Sinks; import reactor.test.StepVerifier; import org.springframework.boot.WebApplicationType; import org.springframework.boot.autoconfigure.EnableAutoConfiguration; import org.springframework.boot.builder.SpringApplicationBuilder; -import org.springframework.context.ApplicationContext; +import org.springframework.boot.rsocket.context.RSocketServerBootstrap; +import org.springframework.boot.rsocket.server.RSocketServer; +import org.springframework.context.ConfigurableApplicationContext; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; -import org.springframework.context.annotation.Import; +import org.springframework.core.env.ConfigurableEnvironment; import org.springframework.messaging.rsocket.RSocketRequester; -import org.springframework.util.Assert; +import org.springframework.test.util.ReflectionTestUtils; import org.springframework.util.SocketUtils; /** @@ -44,235 +46,342 @@ import org.springframework.util.SocketUtils; */ public class RSocketAutoConfigurationTests { @Test - public void testImperativeFunctionAsRequestReply() throws Exception { + public void testImperativeFunctionAsRequestReply() { int port = SocketUtils.findAvailableTcpPort(); - ApplicationContext context = new SpringApplicationBuilder(SampleFunctionConfiguration.class).web(WebApplicationType.NONE).run( - "--logging.level.org.springframework.cloud.function=DEBUG", - "--spring.cloud.function.definition=uppercase", - "--spring.rsocket.server.port=" + port); + try ( + ConfigurableApplicationContext applicationContext = + new SpringApplicationBuilder(SampleFunctionConfiguration.class) + .web(WebApplicationType.NONE) + .run("--logging.level.org.springframework.cloud.function=DEBUG", + "--spring.cloud.function.definition=uppercase", + "--spring.rsocket.server.port=" + port); + ) { + RSocketRequester.Builder rsocketRequesterBuilder = + applicationContext.getBean(RSocketRequester.Builder.class); - RSocketRequester requester = context.getBean(RSocketRequester.class); - Mono result = requester.rsocket().requestResponse(DefaultPayload.create("\"hello\"")).map(Payload::getDataUtf8); - - StepVerifier - .create(result) - .expectNext("\"HELLO\"") - .expectComplete() - .verify(); + rsocketRequesterBuilder.tcp("localhost", port) + .route("uppercase") + .data("\"hello\"") + .retrieveMono(String.class) + .as(StepVerifier::create) + .expectNext("\"HELLO\"") + .expectComplete() + .verify(); + } } @Test - public void testImperativeFunctionAsRequestReplyWithMetadata() throws Exception { + public void testSupplierAsRequestReply() { int port = SocketUtils.findAvailableTcpPort(); - ApplicationContext context = new SpringApplicationBuilder(SampleFunctionConfiguration.class).web(WebApplicationType.NONE).run( - "--logging.level.org.springframework.cloud.function=DEBUG", - "--spring.cloud.function.definition=uppercase", - "--spring.rsocket.server.port=" + port); + try ( + ConfigurableApplicationContext applicationContext = + new SpringApplicationBuilder(SampleFunctionConfiguration.class) + .web(WebApplicationType.NONE) + .run("--logging.level.org.springframework.cloud.function=DEBUG", + "--spring.cloud.function.definition=source", + "--spring.rsocket.server.port=" + port); + ) { + RSocketRequester.Builder rsocketRequesterBuilder = + applicationContext.getBean(RSocketRequester.Builder.class); - RSocketRequester requester = context.getBean(RSocketRequester.class); - Mono result = requester.rsocket().requestResponse(DefaultPayload.create("\"hello\"", "{\"name\":\"bob\", \"age\":23}")) - .map(payload -> { - Assert.hasText(payload.getMetadataUtf8(), "Metadata must not be null"); - return payload.getDataUtf8(); - }); - - StepVerifier - .create(result) - .expectNext("\"HELLO\"") - .expectComplete() - .verify(); + rsocketRequesterBuilder.tcp("localhost", port) + .route("source") + .data("\"hello\"") + .retrieveMono(String.class) + .as(StepVerifier::create) + .expectNext("\"test data\"") + .expectComplete() + .verify(); + } } @Test - public void testImperativeFunctionAsRequestStream() throws Exception { + public void testImperativeFunctionAsRequestStream() { int port = SocketUtils.findAvailableTcpPort(); - ApplicationContext context = new SpringApplicationBuilder(SampleFunctionConfiguration.class).web(WebApplicationType.NONE).run( - "--logging.level.org.springframework.cloud.function=DEBUG", - "--spring.cloud.function.definition=uppercase", - "--spring.rsocket.server.port=" + port); + try ( + ConfigurableApplicationContext applicationContext = + new SpringApplicationBuilder(SampleFunctionConfiguration.class) + .web(WebApplicationType.NONE) + .run("--logging.level.org.springframework.cloud.function=DEBUG", + "--spring.cloud.function.definition=uppercase", + "--spring.rsocket.server.port=" + port); + ) { + RSocketRequester.Builder rsocketRequesterBuilder = + applicationContext.getBean(RSocketRequester.Builder.class); - RSocketRequester requester = context.getBean(RSocketRequester.class); - Flux result = requester.rsocket().requestStream(DefaultPayload.create("\"hello\"")).map(Payload::getDataUtf8); - - StepVerifier - .create(result) - .expectNext("\"HELLO\"") - .expectComplete() - .verify(); + rsocketRequesterBuilder.tcp("localhost", port) + .route("uppercase") + .data("\"hello\"") + .retrieveFlux(String.class) + .as(StepVerifier::create) + .expectNext("\"HELLO\"") + .expectComplete() + .verify(); + } } @Test - public void testImperativeFunctionAsRequestChannel() throws Exception { + public void testImperativeFunctionAsRequestChannel() { int port = SocketUtils.findAvailableTcpPort(); - ApplicationContext context = new SpringApplicationBuilder(SampleFunctionConfiguration.class).web(WebApplicationType.NONE).run( - "--logging.level.org.springframework.cloud.function=DEBUG", - "--spring.cloud.function.definition=uppercase", - "--spring.rsocket.server.port=" + port); + try ( + ConfigurableApplicationContext applicationContext = + new SpringApplicationBuilder(SampleFunctionConfiguration.class) + .web(WebApplicationType.NONE) + .run("--logging.level.org.springframework.cloud.function=DEBUG", + "--spring.cloud.function.definition=uppercase", + "--spring.rsocket.server.port=" + port); + ) { + RSocketRequester.Builder rsocketRequesterBuilder = + applicationContext.getBean(RSocketRequester.Builder.class); - RSocketRequester requester = context.getBean(RSocketRequester.class); - Flux result = requester.rsocket().requestChannel(Flux.just( - DefaultPayload.create("\"Ricky\""), - DefaultPayload.create("\"Julien\""), - DefaultPayload.create("\"Bubbles\"")) - ) - .map(Payload::getDataUtf8); - - StepVerifier.create(result) - .expectNext("\"RICKY\"") - .expectNext("\"JULIEN\"") - .expectNext("\"BUBBLES\"") - .expectComplete() - .verify(); + rsocketRequesterBuilder.tcp("localhost", port) + .route("uppercase") + .data(Flux.just("\"Ricky\"", "\"Julien\"", "\"Bubbles\"")) + .retrieveFlux(String.class) + .as(StepVerifier::create) + .expectNext("\"RICKY\"", "\"JULIEN\"", "\"BUBBLES\"") + .expectComplete() + .verify(); + } } @Test - public void testReactiveFunctionAsRequestReply() throws Exception { + public void testReactiveFunctionAsRequestReply() { int port = SocketUtils.findAvailableTcpPort(); - ApplicationContext context = new SpringApplicationBuilder(SampleFunctionConfiguration.class).web(WebApplicationType.NONE).run( - "--logging.level.org.springframework.cloud.function=DEBUG", - "--spring.cloud.function.definition=uppercaseReactive", - "--spring.rsocket.server.port=" + port); + try ( + ConfigurableApplicationContext applicationContext = + new SpringApplicationBuilder(SampleFunctionConfiguration.class) + .web(WebApplicationType.NONE) + .run("--logging.level.org.springframework.cloud.function=DEBUG", + "--spring.cloud.function.definition=uppercaseReactive", + "--spring.rsocket.server.port=" + port); + ) { + RSocketRequester.Builder rsocketRequesterBuilder = + applicationContext.getBean(RSocketRequester.Builder.class); - RSocketRequester requester = context.getBean(RSocketRequester.class); - Mono result = requester.rsocket().requestResponse(DefaultPayload.create("\"hello\"")).map(Payload::getDataUtf8); - - StepVerifier - .create(result) - .expectNext("\"HELLO\"") - .expectComplete() - .verify(); + rsocketRequesterBuilder.tcp("localhost", port) + .route("uppercaseReactive") + .data("\"hello\"") + .retrieveMono(String.class) + .as(StepVerifier::create) + .expectNext("\"HELLO\"") + .expectComplete() + .verify(); + } } @Test - public void testReactiveFunctionAsRequestStream() throws Exception { + public void testReactiveFunctionAsRequestStream() { int port = SocketUtils.findAvailableTcpPort(); - ApplicationContext context = new SpringApplicationBuilder(SampleFunctionConfiguration.class).web(WebApplicationType.NONE).run( - "--logging.level.org.springframework.cloud.function=DEBUG", - "--spring.cloud.function.definition=uppercaseReactive", - "--spring.rsocket.server.port=" + port); + try ( + ConfigurableApplicationContext applicationContext = + new SpringApplicationBuilder(SampleFunctionConfiguration.class) + .web(WebApplicationType.NONE) + .run("--logging.level.org.springframework.cloud.function=DEBUG", + "--spring.cloud.function.definition=uppercaseReactive", + "--spring.rsocket.server.port=" + port); + ) { + RSocketRequester.Builder rsocketRequesterBuilder = + applicationContext.getBean(RSocketRequester.Builder.class); - RSocketRequester requester = context.getBean(RSocketRequester.class); - Flux result = requester.rsocket().requestStream(DefaultPayload.create("\"hello\"")).map(Payload::getDataUtf8); - - StepVerifier - .create(result) - .expectNext("\"HELLO\"") - .expectComplete() - .verify(); + rsocketRequesterBuilder.tcp("localhost", port) + .route("uppercaseReactive") + .data("\"hello\"") + .retrieveFlux(String.class) + .as(StepVerifier::create) + .expectNext("\"HELLO\"") + .expectComplete() + .verify(); + } } @Test - public void testReactiveFunctionAsRequestChannel() throws Exception { + public void testReactiveFunctionAsRequestChannel() { int port = SocketUtils.findAvailableTcpPort(); - ApplicationContext context = new SpringApplicationBuilder(SampleFunctionConfiguration.class).web(WebApplicationType.NONE).run( - "--logging.level.org.springframework.cloud.function=DEBUG", - "--spring.cloud.function.definition=uppercaseReactive", - "--spring.rsocket.server.port=" + port); + try ( + ConfigurableApplicationContext applicationContext = + new SpringApplicationBuilder(SampleFunctionConfiguration.class) + .web(WebApplicationType.NONE) + .run("--logging.level.org.springframework.cloud.function=DEBUG", + "--spring.cloud.function.definition=uppercaseReactive", + "--spring.rsocket.server.port=" + port); + ) { + RSocketRequester.Builder rsocketRequesterBuilder = + applicationContext.getBean(RSocketRequester.Builder.class); - RSocketRequester requester = context.getBean(RSocketRequester.class); - Flux result = requester.rsocket().requestChannel(Flux.just( - DefaultPayload.create("\"Ricky\""), - DefaultPayload.create("\"Julien\""), - DefaultPayload.create("\"Bubbles\"")) - ) - .map(Payload::getDataUtf8); - - StepVerifier - .create(result) - .expectNext("\"RICKY\"") - .expectNext("\"JULIEN\"") - .expectNext("\"BUBBLES\"") - .expectComplete() - .verify(); + rsocketRequesterBuilder.tcp("localhost", port) + .route("uppercaseReactive") + .data(Flux.just("\"Ricky\"", "\"Julien\"", "\"Bubbles\"")) + .retrieveFlux(String.class) + .as(StepVerifier::create) + .expectNext("\"RICKY\"", "\"JULIEN\"", "\"BUBBLES\"") + .expectComplete() + .verify(); + } } + @Disabled @Test - public void testRequestReplyFunctionWithComposition() throws Exception { + public void testRequestReplyFunctionWithComposition() { int portA = SocketUtils.findAvailableTcpPort(); int portB = SocketUtils.findAvailableTcpPort(); - new SpringApplicationBuilder(SampleFunctionConfiguration.class).web(WebApplicationType.NONE).run( - "--logging.level.org.springframework.cloud.function=DEBUG", - "--spring.cloud.function.definition=uppercase|concat", - "--spring.rsocket.server.port=" + portA); + try ( + ConfigurableApplicationContext applicationContext = + new SpringApplicationBuilder(SampleFunctionConfiguration.class) + .web(WebApplicationType.NONE) + .run("--logging.level.org.springframework.cloud.function=DEBUG", + "--spring.cloud.function.definition=uppercase|concat", + "--spring.rsocket.server.port=" + portA); + ) { - ApplicationContext bContext = new SpringApplicationBuilder(AdditionalFunctionConfiguration.class).web(WebApplicationType.NONE).run( - "--logging.level.org.springframework.cloud.function=DEBUG", - "--spring.cloud.function.definition=reverse>localhost:" + portA + "|wrap", - "--spring.rsocket.server.port=" + portB); + try ( + ConfigurableApplicationContext applicationContext2 = + new SpringApplicationBuilder(AdditionalFunctionConfiguration.class) + .web(WebApplicationType.NONE) + .run("--logging.level.org.springframework.cloud.function=DEBUG", + "--spring.cloud.function.definition=reverse>localhost:" + portA + "|wrap", + "--spring.rsocket.server.port=" + portB); + ) { - RSocketRequester requester = bContext.getBean(RSocketRequester.class); + RSocketRequester.Builder rsocketRequesterBuilder = + applicationContext2.getBean(RSocketRequester.Builder.class); - Mono result = requester.rsocket().requestResponse(DefaultPayload.create("\"hello\"")).map(Payload::getDataUtf8); - StepVerifier - .create(result) - .expectNext("\"(OLLEHOLLEH)\"") - .expectComplete() - .verify(); + rsocketRequesterBuilder.tcp("localhost", portB) + .route("reverse") + .data("\"hello\"") + .retrieveMono(String.class) + .as(StepVerifier::create) + .expectNext("\"(OLLEHOLLEH)\"") + .expectComplete() + .verify(); + } + } + } + + @Disabled("TODO") + @Test + public void testCompositionOverWebSocket() { + try ( + ConfigurableApplicationContext applicationContext = + new SpringApplicationBuilder(SampleFunctionConfiguration.class) + .web(WebApplicationType.REACTIVE) + .run("--logging.level.org.springframework.cloud.function=DEBUG", + "--spring.cloud.function.definition=uppercase|concat", + "--spring.rsocket.server.transport=websocket", + "--spring.rsocket.server.mapping-path=rsockets", + "--server.port=0"); + ) { + ConfigurableEnvironment environment = applicationContext.getEnvironment(); + String httpServerPort = environment.getProperty("local.server.port"); + + try ( + ConfigurableApplicationContext applicationContext2 = + new SpringApplicationBuilder(AdditionalFunctionConfiguration.class) + .web(WebApplicationType.NONE) + .run("--logging.level.org.springframework.cloud.function=DEBUG", + "--spring.cloud.function.definition=reverse>http://localhost:" + httpServerPort + "/rsockets/uppercase|wrap", + "--spring.rsocket.server.port=0"); + ) { + RSocketServerBootstrap serverBootstrap = applicationContext2.getBean(RSocketServerBootstrap.class); + RSocketServer server = (RSocketServer) ReflectionTestUtils.getField(serverBootstrap, "server"); + + RSocketRequester.Builder rsocketRequesterBuilder = + applicationContext2.getBean(RSocketRequester.Builder.class); + + rsocketRequesterBuilder.tcp("localhost", server.address().getPort()) + .route("reverse") + .data("\"hello\"") + .retrieveMono(String.class) + .as(StepVerifier::create) + .expectNext("\"(OLLEHOLLEH)\"") + .expectComplete() + .verify(); + } + } } @Test - public void testRequestChannelFunction() throws Exception { - int port = SocketUtils.findAvailableTcpPort(); - ApplicationContext context = new SpringApplicationBuilder(SampleFunctionConfiguration.class).web(WebApplicationType.NONE).run( - "--logging.level.org.springframework.cloud.function=DEBUG", - "--spring.cloud.function.definition=uppercaseReactive", - "--spring.rsocket.server.port=" + port); + public void testFireAndForgetConsumer() { + try ( + ConfigurableApplicationContext applicationContext = + new SpringApplicationBuilder(SampleFunctionConfiguration.class) + .web(WebApplicationType.NONE) + .run("--logging.level.org.springframework.cloud.function=DEBUG", + "--spring.cloud.function.definition=log", + "--spring.rsocket.server.port=0"); + ) { + RSocketRequester.Builder rsocketRequesterBuilder = + applicationContext.getBean(RSocketRequester.Builder.class); + RSocketServerBootstrap serverBootstrap = applicationContext.getBean(RSocketServerBootstrap.class); + RSocketServer server = (RSocketServer) ReflectionTestUtils.getField(serverBootstrap, "server"); - RSocketRequester requester = context.getBean(RSocketRequester.class); - Flux result = requester.rsocket().requestChannel(Flux.just( - DefaultPayload.create("\"Ricky\""), - DefaultPayload.create("\"Julien\""), - DefaultPayload.create("\"Bubbles\"")) - ) - .map(Payload::getDataUtf8); + rsocketRequesterBuilder.tcp("localhost", server.address().getPort()) + .route("log") + .data("\"hello\"") + .send() + .as(StepVerifier::create) + .expectComplete() + .verify(); - StepVerifier - .create(result) - .expectNext("\"RICKY\"") - .expectNext("\"JULIEN\"") - .expectNext("\"BUBBLES\"") - .expectComplete() - .verify(); + applicationContext.getBean(SampleFunctionConfiguration.class).consumerData + .asMono() + .map(String::new) + .as(StepVerifier::create) + .expectNext("\"hello\"") + .expectComplete() + .verify(); + } } + @Test + public void testRsocketRoutesForAllFunctions() { + try ( + ConfigurableApplicationContext applicationContext = + new SpringApplicationBuilder(AdditionalFunctionConfiguration.class) + .web(WebApplicationType.NONE) + .run("--logging.level.org.springframework.cloud.function=DEBUG", + "--spring.rsocket.server.port=0"); + ) { + RSocketRequester.Builder rsocketRequesterBuilder = + applicationContext.getBean(RSocketRequester.Builder.class); + RSocketServerBootstrap serverBootstrap = applicationContext.getBean(RSocketServerBootstrap.class); + RSocketServer server = (RSocketServer) ReflectionTestUtils.getField(serverBootstrap, "server"); + RSocketRequester requester = rsocketRequesterBuilder.tcp("localhost", server.address().getPort()); + + requester.route("reverse") + .data("\"hello\"") + .retrieveMono(String.class) + .as(StepVerifier::create) + .expectNext("\"olleh\"") + .expectComplete() + .verify(); + + requester.route("wrap") + .data("\"hello\"") + .retrieveMono(String.class) + .as(StepVerifier::create) + .expectNext("\"(hello)\"") + .expectComplete() + .verify(); + } + } -// @Test -// public void testFireAndForgetConsumer() throws Exception { -// new SpringApplicationBuilder(SampleFunctionConfiguration.class) -// .run("--logging.level.org.springframework.cloud.function=DEBUG", -// "--spring.cloud.function.definition=log"); -// -// RSocket socket = RSocketConnector.connectWith(TcpClientTransport.create("localhost", 7000)) -// .log() -// .retryWhen(Retry.backoff(5, Duration.ofSeconds(1))) -// .block(); -// socket.fireAndForget(DefaultPayload.create("Hello")) -// .log() -// .onErrorContinue((e, x) -> { -// System.out.println(e); -// }) -// .block(); -// Thread.sleep(2000); -// System.out.println(); -// } @EnableAutoConfiguration @Configuration - @Import(RSocketTestConfiguration.class) public static class SampleFunctionConfiguration { + + final Sinks.One consumerData = Sinks.one(); + @Bean public Function uppercase() { - return v -> { - return v.toUpperCase(); - }; + return String::toUpperCase; } @Bean public Function concat() { - return v -> { - return v + v; - }; + return v -> v + v; } @Bean @@ -290,28 +399,30 @@ public class RSocketAutoConfigurationTests { @Bean public Consumer log() { - return v -> { - System.out.println("==> In Consumer: " + new String(v)); - }; + return this.consumerData::emitValue; } + + @Bean + public Supplier source() { + return () -> "test data"; + } + } @EnableAutoConfiguration @Configuration - @Import(RSocketTestConfiguration.class) public static class AdditionalFunctionConfiguration { + @Bean public Function reverse() { - return v -> { - return new StringBuilder(v).reverse().toString(); - }; + return v -> new StringBuilder(v).reverse().toString(); } @Bean public Function wrap() { - return v -> { - return "(" + v + ")"; - }; + return v -> "(" + v + ")"; } + } + } diff --git a/spring-cloud-function-rsocket/src/test/java/org/springframework/cloud/function/rsocket/RSocketTestConfiguration.java b/spring-cloud-function-rsocket/src/test/java/org/springframework/cloud/function/rsocket/RSocketTestConfiguration.java deleted file mode 100644 index 4706bb024..000000000 --- a/spring-cloud-function-rsocket/src/test/java/org/springframework/cloud/function/rsocket/RSocketTestConfiguration.java +++ /dev/null @@ -1,61 +0,0 @@ -/* - * Copyright 2020-2020 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.springframework.cloud.function.rsocket; - -import java.net.InetSocketAddress; -import java.time.Duration; - -import io.rsocket.RSocket; -import io.rsocket.core.RSocketConnector; -import io.rsocket.transport.netty.client.TcpClientTransport; -import reactor.util.retry.Retry; -import reactor.util.retry.RetrySpec; - -import org.springframework.context.annotation.Bean; -import org.springframework.context.annotation.Configuration; -import org.springframework.context.annotation.Scope; -import org.springframework.core.env.Environment; -import org.springframework.lang.Nullable; -import org.springframework.messaging.rsocket.RSocketRequester; -import org.springframework.messaging.rsocket.RSocketStrategies; -import org.springframework.util.Assert; -import org.springframework.util.MimeTypeUtils; - -/** - * - * @author Oleg Zhurakousky - * - */ -@Configuration -public class RSocketTestConfiguration { - - @Bean - @Scope("prototype") - RSocketRequester rSocketRequester(RSocketStrategies rSocketStrategies, Environment environment, - @Nullable RetrySpec retrySpec) { - String port = environment.getProperty("spring.rsocket.server.port"); - Assert.hasText(port, "'spring.rsocket.server.port' must be specified"); - String host = environment.getProperty("spring.rsocket.server.address", "localhost"); - RSocket socket = RSocketConnector - .connectWith( - TcpClientTransport.create(InetSocketAddress.createUnresolved(host, Integer.parseInt(port)))) - .log() - .retryWhen(retrySpec == null ? Retry.backoff(5, Duration.ofSeconds(1)) : retrySpec).block(); - return RSocketRequester.wrap(socket, MimeTypeUtils.APPLICATION_JSON, MimeTypeUtils.APPLICATION_JSON, - rSocketStrategies); - } -} diff --git a/spring-cloud-function-rsocket/src/test/java/org/springframework/cloud/function/rsocket/RoutingBrokerTests.java b/spring-cloud-function-rsocket/src/test/java/org/springframework/cloud/function/rsocket/RoutingBrokerTests.java index 03cb908f5..74d23130c 100644 --- a/spring-cloud-function-rsocket/src/test/java/org/springframework/cloud/function/rsocket/RoutingBrokerTests.java +++ b/spring-cloud-function-rsocket/src/test/java/org/springframework/cloud/function/rsocket/RoutingBrokerTests.java @@ -62,7 +62,7 @@ public class RoutingBrokerTests { public void testRoutingWithProperty() throws Exception { this.setup(true); RSocketRequester requester = clientContext.getBean(RSocketRequester.class); - Mono result = requester.route("toupper") // used to find a messagemapping, so unused here + Mono result = requester.route("uppercase") // used to find a messagemapping, so unused here // auto creates metadata .data("\"hello\"") .retrieveMono(String.class); @@ -79,7 +79,7 @@ public class RoutingBrokerTests { this.setup(false); RSocketRequester requester = clientContext.getBean(RSocketRequester.class); RoutingMetadata metadata = clientContext.getBean(RoutingMetadata.class); - Mono result = requester.route("toupper") // used to find a messagemapping, so unused here + Mono result = requester.route("uppercase") // used to find a messagemapping, so unused here .metadata(metadata.address("samplefn")) .data("\"hello\"") .retrieveMono(String.class);