diff --git a/spring-cloud-function-rsocket/src/main/java/org/springframework/cloud/function/rsocket/RSocketAutoConfiguration.java b/spring-cloud-function-rsocket/src/main/java/org/springframework/cloud/function/rsocket/RSocketAutoConfiguration.java index 3e1d738fb..9f50a8318 100644 --- a/spring-cloud-function-rsocket/src/main/java/org/springframework/cloud/function/rsocket/RSocketAutoConfiguration.java +++ b/spring-cloud-function-rsocket/src/main/java/org/springframework/cloud/function/rsocket/RSocketAutoConfiguration.java @@ -93,7 +93,7 @@ public class RSocketAutoConfiguration { .createUnresolved(this.rSocketFunctionProperties.getBindAddress(), this.rSocketFunctionProperties.getBindPort()); if (this.invocableFunction == null) { - this.invocableFunction = new RSocketFunction(function, bindAddress, null); + this.invocableFunction = new RSocketFunction(function, bindAddress); this.invocableFunction.start(); } } @@ -102,8 +102,8 @@ public class RSocketAutoConfiguration { private void registerRsocketProxiesIfNecessary(String definition) { String[] names = StringUtils.delimitedListToStringArray(definition.replaceAll(",", "|").trim(), "|"); - InetSocketAddress listenAddress = InetSocketAddress - .createUnresolved(this.rSocketFunctionProperties.getBindAddress(), this.rSocketFunctionProperties.getBindPort()); +// InetSocketAddress listenAddress = InetSocketAddress +// .createUnresolved(this.rSocketFunctionProperties.getBindAddress(), this.rSocketFunctionProperties.getBindPort()); for (String name : names) { @@ -116,14 +116,14 @@ public class RSocketAutoConfiguration { InetSocketAddress outputAddress = InetSocketAddress .createUnresolved(hostPort[0], Integer.valueOf(hostPort[1])); - RSocketFunction rsocketFunction = new RSocketFunction(function, listenAddress, outputAddress); + RSocketForwardingFunction rsocketFunction = new RSocketForwardingFunction(function, outputAddress); FunctionRegistration functionRegistration = new FunctionRegistration(rsocketFunction, name); functionRegistration.type(FunctionTypeUtils.discoverFunctionTypeFromClass(RSocketFunction.class)); ((FunctionRegistry) this.functionCatalog).register(functionRegistration); - - this.invocableFunction = rsocketFunction; - this.invocableFunction.start(); +// +// this.invocableFunction = rsocketFunction; +// this.invocableFunction.start(); } } } diff --git a/spring-cloud-function-rsocket/src/main/java/org/springframework/cloud/function/rsocket/RSocketForwardingFunction.java b/spring-cloud-function-rsocket/src/main/java/org/springframework/cloud/function/rsocket/RSocketForwardingFunction.java new file mode 100644 index 000000000..3762d6e54 --- /dev/null +++ b/spring-cloud-function-rsocket/src/main/java/org/springframework/cloud/function/rsocket/RSocketForwardingFunction.java @@ -0,0 +1,58 @@ +package org.springframework.cloud.function.rsocket; + +import java.net.InetSocketAddress; +import java.nio.ByteBuffer; +import java.time.Duration; +import java.util.function.Function; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.reactivestreams.Publisher; +import org.springframework.cloud.function.context.catalog.SimpleFunctionRegistry.FunctionInvocationWrapper; +import org.springframework.lang.Nullable; +import org.springframework.messaging.Message; +import org.springframework.messaging.support.MessageBuilder; + +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.core.RSocketConnector; +import io.rsocket.transport.netty.client.TcpClientTransport; +import io.rsocket.util.DefaultPayload; +import reactor.util.retry.Retry; + +class RSocketForwardingFunction implements Function, Publisher>> { + private static Log logger = LogFactory.getLog(RSocketForwardingFunction.class); + + private final RSocket rSocket; + + private final FunctionInvocationWrapper targetFunction; + + RSocketForwardingFunction(FunctionInvocationWrapper targetFunction, InetSocketAddress outputAddress) { + this.targetFunction = targetFunction; + this.rSocket = outputAddress == null ? null + : RSocketConnector.connectWith(TcpClientTransport.create(outputAddress)) + .log() + .retryWhen(Retry.backoff(5, Duration.ofSeconds(1))) + .block(); + } + + @Override + public Publisher> apply(Message input) { + if (logger.isDebugEnabled()) { + logger.debug("Executiing: " + this.targetFunction); + } + + Object rawResult = this.targetFunction.apply(input); + Publisher> resultMessage = this.rSocket + .requestStream(DefaultPayload.create(((Message) rawResult).getPayload())) + .map(this::buildResultMessage); + return resultMessage; + } + + private Message buildResultMessage(Payload payload) { + ByteBuffer payloadBuffer = payload.getData(); + byte[] payloadData = new byte[payloadBuffer.remaining()]; + payloadBuffer.get(payloadData); + return MessageBuilder.withPayload(payloadData).build(); + } +} diff --git a/spring-cloud-function-rsocket/src/main/java/org/springframework/cloud/function/rsocket/RSocketFunction.java b/spring-cloud-function-rsocket/src/main/java/org/springframework/cloud/function/rsocket/RSocketFunction.java index 5f3aacacd..927e90caa 100644 --- a/spring-cloud-function-rsocket/src/main/java/org/springframework/cloud/function/rsocket/RSocketFunction.java +++ b/spring-cloud-function-rsocket/src/main/java/org/springframework/cloud/function/rsocket/RSocketFunction.java @@ -19,13 +19,10 @@ package org.springframework.cloud.function.rsocket; import java.lang.reflect.Type; import java.net.InetSocketAddress; import java.nio.ByteBuffer; -import java.time.Duration; import java.util.function.Function; import io.rsocket.Payload; import io.rsocket.RSocket; -import io.rsocket.core.RSocketConnector; -import io.rsocket.transport.netty.client.TcpClientTransport; import io.rsocket.util.DefaultPayload; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; @@ -33,11 +30,9 @@ import org.reactivestreams.Publisher; import reactor.core.Disposable; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; -import reactor.util.retry.Retry; import org.springframework.cloud.function.context.catalog.FunctionTypeUtils; import org.springframework.cloud.function.context.catalog.SimpleFunctionRegistry.FunctionInvocationWrapper; -import org.springframework.lang.Nullable; import org.springframework.messaging.Message; import org.springframework.messaging.support.MessageBuilder; @@ -62,23 +57,13 @@ class RSocketFunction implements Function, Publisher, Publisher> resultMessage = null; - if (this.outputAddress != null) { - resultMessage = this.rSocket - .requestStream(DefaultPayload.create(((Message) rawResult).getPayload())) - .map(this::buildResultMessage); - } - resultMessage = rawResult instanceof Publisher ? (Publisher>) rawResult : Mono.just((Message) rawResult); - return resultMessage; - } - else { - return (Publisher>) rawResult; - } - + return rawResult instanceof Publisher ? (Publisher>) rawResult : Mono.just((Message) rawResult); } void start() { @@ -136,15 +108,7 @@ class RSocketFunction implements Function, Publisher inputMessage = deserealizePayload(payload); Mono> result = Mono.from(function.apply(inputMessage)); - if (rSocket != null) { - return result.flatMap(message -> { - Mono requestResponse = rSocket.requestResponse(DefaultPayload.create(message.getPayload())); - return requestResponse; - }); - } - else { - return result.map(message -> DefaultPayload.create(message.getPayload())); - } + return result.map(message -> DefaultPayload.create(message.getPayload())); } } @@ -213,18 +177,11 @@ class RSocketFunction implements Function, Publisher buildResultMessage(Payload payload) { - ByteBuffer payloadBuffer = payload.getData(); - byte[] payloadData = new byte[payloadBuffer.remaining()]; - payloadBuffer.get(payloadData); - return MessageBuilder.withPayload(payloadData).build(); - } - private void printSplashScreen(String definition, Type type) { System.out.println(splash); System.out.println("Function Definition: " + definition + ":[" + type + "]"); System.out.println("RSocket Listen Address: " + this.listenAddress); - System.out.println("RSocket Target Address: " + this.outputAddress); +// System.out.println("RSocket Target Address: " + this.outputAddress); System.out.println("======================================================\n"); } diff --git a/spring-cloud-function-rsocket/src/test/java/org/springframework/cloud/function/rsocket/RSocketAutoConfigurationTests.java b/spring-cloud-function-rsocket/src/test/java/org/springframework/cloud/function/rsocket/RSocketAutoConfigurationTests.java index c0e19108e..38ca7dcf0 100644 --- a/spring-cloud-function-rsocket/src/test/java/org/springframework/cloud/function/rsocket/RSocketAutoConfigurationTests.java +++ b/spring-cloud-function-rsocket/src/test/java/org/springframework/cloud/function/rsocket/RSocketAutoConfigurationTests.java @@ -184,7 +184,7 @@ public class RSocketAutoConfigurationTests { new SpringApplicationBuilder(AdditionalFunctionConfiguration.class).web(WebApplicationType.NONE).run( "--logging.level.org.springframework.cloud.function=DEBUG", - "--spring.cloud.function.definition=reverse>localhost:" + portA, + "--spring.cloud.function.definition=reverse>localhost:" + portA + "|wrap", "--spring.cloud.function.rsocket.bind-address=localhost", "--spring.cloud.function.rsocket.bind-port=" + portB); @@ -192,7 +192,7 @@ public class RSocketAutoConfigurationTests { Mono result = socket.requestResponse(DefaultPayload.create("\"hello\"")).map(Payload::getDataUtf8); StepVerifier .create(result) - .expectNext("\"OLLEHOLLEH\"") + .expectNext("\"(OLLEHOLLEH)\"") .expectComplete() .verify(); }