This commit is contained in:
Oleg Zhurakousky
2020-07-16 17:08:55 +02:00
parent 13aa4700b1
commit 638c98cfb7
4 changed files with 71 additions and 56 deletions

View File

@@ -93,7 +93,7 @@ public class RSocketAutoConfiguration {
.createUnresolved(this.rSocketFunctionProperties.getBindAddress(), this.rSocketFunctionProperties.getBindPort());
if (this.invocableFunction == null) {
this.invocableFunction = new RSocketFunction(function, bindAddress, null);
this.invocableFunction = new RSocketFunction(function, bindAddress);
this.invocableFunction.start();
}
}
@@ -102,8 +102,8 @@ public class RSocketAutoConfiguration {
private void registerRsocketProxiesIfNecessary(String definition) {
String[] names = StringUtils.delimitedListToStringArray(definition.replaceAll(",", "|").trim(), "|");
InetSocketAddress listenAddress = InetSocketAddress
.createUnresolved(this.rSocketFunctionProperties.getBindAddress(), this.rSocketFunctionProperties.getBindPort());
// InetSocketAddress listenAddress = InetSocketAddress
// .createUnresolved(this.rSocketFunctionProperties.getBindAddress(), this.rSocketFunctionProperties.getBindPort());
for (String name : names) {
@@ -116,14 +116,14 @@ public class RSocketAutoConfiguration {
InetSocketAddress outputAddress = InetSocketAddress
.createUnresolved(hostPort[0], Integer.valueOf(hostPort[1]));
RSocketFunction rsocketFunction = new RSocketFunction(function, listenAddress, outputAddress);
RSocketForwardingFunction rsocketFunction = new RSocketForwardingFunction(function, outputAddress);
FunctionRegistration functionRegistration = new FunctionRegistration(rsocketFunction, name);
functionRegistration.type(FunctionTypeUtils.discoverFunctionTypeFromClass(RSocketFunction.class));
((FunctionRegistry) this.functionCatalog).register(functionRegistration);
this.invocableFunction = rsocketFunction;
this.invocableFunction.start();
//
// this.invocableFunction = rsocketFunction;
// this.invocableFunction.start();
}
}
}

View File

@@ -0,0 +1,58 @@
package org.springframework.cloud.function.rsocket;
import java.net.InetSocketAddress;
import java.nio.ByteBuffer;
import java.time.Duration;
import java.util.function.Function;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.reactivestreams.Publisher;
import org.springframework.cloud.function.context.catalog.SimpleFunctionRegistry.FunctionInvocationWrapper;
import org.springframework.lang.Nullable;
import org.springframework.messaging.Message;
import org.springframework.messaging.support.MessageBuilder;
import io.rsocket.Payload;
import io.rsocket.RSocket;
import io.rsocket.core.RSocketConnector;
import io.rsocket.transport.netty.client.TcpClientTransport;
import io.rsocket.util.DefaultPayload;
import reactor.util.retry.Retry;
class RSocketForwardingFunction implements Function<Message<byte[]>, Publisher<Message<byte[]>>> {
private static Log logger = LogFactory.getLog(RSocketForwardingFunction.class);
private final RSocket rSocket;
private final FunctionInvocationWrapper targetFunction;
RSocketForwardingFunction(FunctionInvocationWrapper targetFunction, InetSocketAddress outputAddress) {
this.targetFunction = targetFunction;
this.rSocket = outputAddress == null ? null
: RSocketConnector.connectWith(TcpClientTransport.create(outputAddress))
.log()
.retryWhen(Retry.backoff(5, Duration.ofSeconds(1)))
.block();
}
@Override
public Publisher<Message<byte[]>> apply(Message<byte[]> input) {
if (logger.isDebugEnabled()) {
logger.debug("Executiing: " + this.targetFunction);
}
Object rawResult = this.targetFunction.apply(input);
Publisher<Message<byte[]>> resultMessage = this.rSocket
.requestStream(DefaultPayload.create(((Message<byte[]>) rawResult).getPayload()))
.map(this::buildResultMessage);
return resultMessage;
}
private Message<byte[]> buildResultMessage(Payload payload) {
ByteBuffer payloadBuffer = payload.getData();
byte[] payloadData = new byte[payloadBuffer.remaining()];
payloadBuffer.get(payloadData);
return MessageBuilder.withPayload(payloadData).build();
}
}

View File

@@ -19,13 +19,10 @@ package org.springframework.cloud.function.rsocket;
import java.lang.reflect.Type;
import java.net.InetSocketAddress;
import java.nio.ByteBuffer;
import java.time.Duration;
import java.util.function.Function;
import io.rsocket.Payload;
import io.rsocket.RSocket;
import io.rsocket.core.RSocketConnector;
import io.rsocket.transport.netty.client.TcpClientTransport;
import io.rsocket.util.DefaultPayload;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
@@ -33,11 +30,9 @@ import org.reactivestreams.Publisher;
import reactor.core.Disposable;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.util.retry.Retry;
import org.springframework.cloud.function.context.catalog.FunctionTypeUtils;
import org.springframework.cloud.function.context.catalog.SimpleFunctionRegistry.FunctionInvocationWrapper;
import org.springframework.lang.Nullable;
import org.springframework.messaging.Message;
import org.springframework.messaging.support.MessageBuilder;
@@ -62,23 +57,13 @@ class RSocketFunction implements Function<Message<byte[]>, Publisher<Message<byt
private final InetSocketAddress listenAddress;
private final InetSocketAddress outputAddress;
private final FunctionInvocationWrapper targetFunction;
private final RSocket rSocket;
private Disposable rsocketConnection;
RSocketFunction(FunctionInvocationWrapper targetFunction, InetSocketAddress listenAddress, @Nullable InetSocketAddress outputAddress) {
RSocketFunction(FunctionInvocationWrapper targetFunction, InetSocketAddress listenAddress) {
this.listenAddress = listenAddress;
this.outputAddress = outputAddress;
this.targetFunction = targetFunction;
this.rSocket = outputAddress == null ? null
: RSocketConnector.connectWith(TcpClientTransport.create(this.outputAddress))
.log()
.retryWhen(Retry.backoff(5, Duration.ofSeconds(1)))
.block();
}
@SuppressWarnings("unchecked")
@@ -89,20 +74,7 @@ class RSocketFunction implements Function<Message<byte[]>, Publisher<Message<byt
}
Object rawResult = this.targetFunction.apply(input);
if (rawResult instanceof Message) {
Publisher<Message<byte[]>> resultMessage = null;
if (this.outputAddress != null) {
resultMessage = this.rSocket
.requestStream(DefaultPayload.create(((Message<byte[]>) rawResult).getPayload()))
.map(this::buildResultMessage);
}
resultMessage = rawResult instanceof Publisher ? (Publisher<Message<byte[]>>) rawResult : Mono.just((Message<byte[]>) rawResult);
return resultMessage;
}
else {
return (Publisher<Message<byte[]>>) rawResult;
}
return rawResult instanceof Publisher ? (Publisher<Message<byte[]>>) rawResult : Mono.just((Message<byte[]>) rawResult);
}
void start() {
@@ -136,15 +108,7 @@ class RSocketFunction implements Function<Message<byte[]>, Publisher<Message<byt
else {
Message<byte[]> inputMessage = deserealizePayload(payload);
Mono<Message<byte[]>> result = Mono.from(function.apply(inputMessage));
if (rSocket != null) {
return result.flatMap(message -> {
Mono<Payload> requestResponse = rSocket.requestResponse(DefaultPayload.create(message.getPayload()));
return requestResponse;
});
}
else {
return result.map(message -> DefaultPayload.create(message.getPayload()));
}
return result.map(message -> DefaultPayload.create(message.getPayload()));
}
}
@@ -213,18 +177,11 @@ class RSocketFunction implements Function<Message<byte[]>, Publisher<Message<byt
}
private Message<byte[]> buildResultMessage(Payload payload) {
ByteBuffer payloadBuffer = payload.getData();
byte[] payloadData = new byte[payloadBuffer.remaining()];
payloadBuffer.get(payloadData);
return MessageBuilder.withPayload(payloadData).build();
}
private void printSplashScreen(String definition, Type type) {
System.out.println(splash);
System.out.println("Function Definition: " + definition + ":[" + type + "]");
System.out.println("RSocket Listen Address: " + this.listenAddress);
System.out.println("RSocket Target Address: " + this.outputAddress);
// System.out.println("RSocket Target Address: " + this.outputAddress);
System.out.println("======================================================\n");
}

View File

@@ -184,7 +184,7 @@ public class RSocketAutoConfigurationTests {
new SpringApplicationBuilder(AdditionalFunctionConfiguration.class).web(WebApplicationType.NONE).run(
"--logging.level.org.springframework.cloud.function=DEBUG",
"--spring.cloud.function.definition=reverse>localhost:" + portA,
"--spring.cloud.function.definition=reverse>localhost:" + portA + "|wrap",
"--spring.cloud.function.rsocket.bind-address=localhost",
"--spring.cloud.function.rsocket.bind-port=" + portB);
@@ -192,7 +192,7 @@ public class RSocketAutoConfigurationTests {
Mono<String> result = socket.requestResponse(DefaultPayload.create("\"hello\"")).map(Payload::getDataUtf8);
StepVerifier
.create(result)
.expectNext("\"OLLEHOLLEH\"")
.expectNext("\"(OLLEHOLLEH)\"")
.expectComplete()
.verify();
}