@@ -0,0 +1,144 @@
|
||||
/*
|
||||
* Copyright 2020-2020 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.springframework.cloud.function.rsocket;
|
||||
|
||||
import java.lang.reflect.Method;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.function.Function;
|
||||
|
||||
import io.rsocket.frame.FrameType;
|
||||
import org.reactivestreams.Publisher;
|
||||
import reactor.core.publisher.Flux;
|
||||
import reactor.core.publisher.Mono;
|
||||
|
||||
import org.springframework.core.MethodParameter;
|
||||
import org.springframework.core.ReactiveAdapterRegistry;
|
||||
import org.springframework.core.ResolvableType;
|
||||
import org.springframework.core.codec.ByteArrayDecoder;
|
||||
import org.springframework.core.codec.ByteArrayEncoder;
|
||||
import org.springframework.core.codec.Decoder;
|
||||
import org.springframework.core.codec.Encoder;
|
||||
import org.springframework.core.io.buffer.DataBuffer;
|
||||
import org.springframework.lang.Nullable;
|
||||
import org.springframework.messaging.Message;
|
||||
import org.springframework.messaging.handler.CompositeMessageCondition;
|
||||
import org.springframework.messaging.handler.DestinationPatternsMessageCondition;
|
||||
import org.springframework.messaging.handler.invocation.reactive.HandlerMethodArgumentResolver;
|
||||
import org.springframework.messaging.handler.invocation.reactive.HandlerMethodReturnValueHandler;
|
||||
import org.springframework.messaging.handler.invocation.reactive.SyncHandlerMethodArgumentResolver;
|
||||
import org.springframework.messaging.rsocket.annotation.support.RSocketFrameTypeMessageCondition;
|
||||
import org.springframework.messaging.rsocket.annotation.support.RSocketMessageHandler;
|
||||
import org.springframework.messaging.rsocket.annotation.support.RSocketPayloadReturnValueHandler;
|
||||
import org.springframework.messaging.support.MessageBuilder;
|
||||
import org.springframework.util.ReflectionUtils;
|
||||
|
||||
/**
|
||||
* An {@link RSocketMessageHandler} extension for Spring Cloud Function specifics.
|
||||
*
|
||||
* @author Artem Bilan
|
||||
*
|
||||
* @since 3.1
|
||||
*/
|
||||
public class FunctionRSocketMessageHandler extends RSocketMessageHandler {
|
||||
|
||||
private static final Method FUNCTION_APPLY_METHOD =
|
||||
ReflectionUtils.findMethod(Function.class, "apply", (Class<?>[]) null);
|
||||
|
||||
private static final RSocketFrameTypeMessageCondition REQUEST_CONDITION =
|
||||
new RSocketFrameTypeMessageCondition(
|
||||
FrameType.REQUEST_FNF,
|
||||
FrameType.REQUEST_RESPONSE,
|
||||
FrameType.REQUEST_STREAM,
|
||||
FrameType.REQUEST_CHANNEL);
|
||||
|
||||
public FunctionRSocketMessageHandler() {
|
||||
setHandlerPredicate((clazz) -> false);
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public void afterPropertiesSet() {
|
||||
setEncoders(Collections.singletonList(new ByteArrayEncoder()));
|
||||
super.afterPropertiesSet();
|
||||
}
|
||||
|
||||
public void registerFunctionHandler(Function<?, ?> function, String route) {
|
||||
CompositeMessageCondition condition =
|
||||
new CompositeMessageCondition(REQUEST_CONDITION,
|
||||
new DestinationPatternsMessageCondition(new String[]{ route },
|
||||
obtainRouteMatcher()));
|
||||
registerHandlerMethod(function, FUNCTION_APPLY_METHOD, condition);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected List<? extends HandlerMethodArgumentResolver> initArgumentResolvers() {
|
||||
return Collections.singletonList(new MessageHandlerMethodArgumentResolver());
|
||||
}
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
@Override
|
||||
protected List<? extends HandlerMethodReturnValueHandler> initReturnValueHandlers() {
|
||||
return Collections.singletonList(new FunctionRSocketPayloadReturnValueHandler((List<Encoder<?>>) getEncoders(),
|
||||
getReactiveAdapterRegistry()));
|
||||
}
|
||||
|
||||
protected static final class MessageHandlerMethodArgumentResolver implements SyncHandlerMethodArgumentResolver {
|
||||
|
||||
private final Decoder<byte[]> decoder = new ByteArrayDecoder();
|
||||
|
||||
@Override
|
||||
public boolean supportsParameter(MethodParameter parameter) {
|
||||
return true;
|
||||
}
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
@Override
|
||||
public Object resolveArgumentValue(MethodParameter parameter, Message<?> message) {
|
||||
Flux<DataBuffer> data;
|
||||
Object payload = message.getPayload();
|
||||
if (payload instanceof DataBuffer) {
|
||||
data = Flux.just((DataBuffer) payload);
|
||||
}
|
||||
else {
|
||||
data = Flux.from((Publisher<DataBuffer>) payload);
|
||||
}
|
||||
Flux<byte[]> decoded = this.decoder.decode(data, ResolvableType.forType(byte[].class), null, null);
|
||||
return MessageBuilder.createMessage(decoded, message.getHeaders());
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
protected static final class FunctionRSocketPayloadReturnValueHandler extends RSocketPayloadReturnValueHandler {
|
||||
|
||||
public FunctionRSocketPayloadReturnValueHandler(List<Encoder<?>> encoders, ReactiveAdapterRegistry registry) {
|
||||
super(encoders, registry);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Mono<Void> handleReturnValue(@Nullable Object returnValue, MethodParameter returnType,
|
||||
Message<?> message) {
|
||||
|
||||
if (returnValue instanceof Publisher<?> && !message.getHeaders().containsKey(RESPONSE_HEADER)) {
|
||||
return Mono.from((Publisher<?>) returnValue).then();
|
||||
}
|
||||
return super.handleReturnValue(returnValue, returnType, message);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
@@ -16,32 +16,32 @@
|
||||
|
||||
package org.springframework.cloud.function.rsocket;
|
||||
|
||||
import java.net.InetSocketAddress;
|
||||
import java.net.URI;
|
||||
import java.util.regex.Pattern;
|
||||
|
||||
import io.rsocket.RSocket;
|
||||
import io.rsocket.SocketAcceptor;
|
||||
import org.apache.commons.logging.Log;
|
||||
import org.apache.commons.logging.LogFactory;
|
||||
|
||||
import org.springframework.beans.BeansException;
|
||||
import org.springframework.beans.factory.InitializingBean;
|
||||
import org.springframework.beans.factory.ObjectProvider;
|
||||
import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;
|
||||
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
|
||||
import org.springframework.boot.autoconfigure.rsocket.RSocketMessageHandlerCustomizer;
|
||||
import org.springframework.boot.context.properties.EnableConfigurationProperties;
|
||||
import org.springframework.boot.rsocket.context.RSocketServerBootstrap;
|
||||
import org.springframework.boot.rsocket.server.RSocketServerFactory;
|
||||
import org.springframework.cloud.function.context.FunctionCatalog;
|
||||
import org.springframework.cloud.function.context.FunctionProperties;
|
||||
import org.springframework.cloud.function.context.FunctionRegistration;
|
||||
import org.springframework.cloud.function.context.FunctionRegistry;
|
||||
import org.springframework.cloud.function.context.catalog.FunctionTypeUtils;
|
||||
import org.springframework.cloud.function.context.catalog.SimpleFunctionRegistry.FunctionInvocationWrapper;
|
||||
import org.springframework.cloud.function.json.JsonMapper;
|
||||
import org.springframework.context.ApplicationContext;
|
||||
import org.springframework.context.ApplicationContextAware;
|
||||
import org.springframework.context.annotation.Bean;
|
||||
import org.springframework.context.annotation.Configuration;
|
||||
import org.springframework.context.support.GenericApplicationContext;
|
||||
import org.springframework.context.annotation.Primary;
|
||||
import org.springframework.messaging.rsocket.RSocketRequester;
|
||||
import org.springframework.messaging.rsocket.RSocketRequester.Builder;
|
||||
import org.springframework.messaging.rsocket.RSocketStrategies;
|
||||
import org.springframework.util.Assert;
|
||||
import org.springframework.util.StringUtils;
|
||||
|
||||
@@ -55,104 +55,89 @@ import org.springframework.util.StringUtils;
|
||||
@Configuration(proxyBeanMethods = false)
|
||||
@EnableConfigurationProperties({ FunctionProperties.class, RSocketFunctionProperties.class })
|
||||
@ConditionalOnProperty(name = FunctionProperties.PREFIX + ".rsocket.enabled", matchIfMissing = true)
|
||||
class RSocketAutoConfiguration {
|
||||
class RSocketAutoConfiguration implements ApplicationContextAware {
|
||||
|
||||
private static Log logger = LogFactory.getLog(RSocketAutoConfiguration.class);
|
||||
private static final Log LOGGER = LogFactory.getLog(RSocketAutoConfiguration.class);
|
||||
|
||||
@Bean
|
||||
public FunctionToRSocketBinder functionToDestinationBinder(FunctionCatalog functionCatalog,
|
||||
FunctionProperties functionProperties, JsonMapper jsonMapper) {
|
||||
return new FunctionToRSocketBinder(functionCatalog, functionProperties, jsonMapper);
|
||||
private static final Pattern WS_URI_PATTERN = Pattern.compile("^(https?|wss?)://.+");
|
||||
|
||||
private ApplicationContext applicationContext;
|
||||
|
||||
@Override
|
||||
public void setApplicationContext(ApplicationContext applicationContext) throws BeansException {
|
||||
this.applicationContext = applicationContext;
|
||||
}
|
||||
|
||||
@Bean
|
||||
@ConditionalOnMissingBean
|
||||
@ConditionalOnProperty("spring.rsocket.server.port")
|
||||
RSocketServerBootstrap rSocketServerBootstrap(RSocketServerFactory rSocketServerFactory,
|
||||
FunctionToRSocketBinder binder) {
|
||||
return new RSocketServerBootstrap(rSocketServerFactory, SocketAcceptor.with(binder.getRSocket()));
|
||||
@Primary
|
||||
public FunctionRSocketMessageHandler functionRSocketMessageHandler(RSocketStrategies rSocketStrategies,
|
||||
ObjectProvider<RSocketMessageHandlerCustomizer> customizers, FunctionCatalog functionCatalog,
|
||||
FunctionProperties functionProperties) {
|
||||
|
||||
FunctionRSocketMessageHandler rsocketMessageHandler = new FunctionRSocketMessageHandler();
|
||||
rsocketMessageHandler.setRSocketStrategies(rSocketStrategies);
|
||||
customizers.orderedStream().forEach((customizer) -> customizer.customize(rsocketMessageHandler));
|
||||
registerFunctionsWithRSocketHandler(rsocketMessageHandler, functionCatalog, functionProperties);
|
||||
return rsocketMessageHandler;
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
*/
|
||||
static class FunctionToRSocketBinder implements InitializingBean, ApplicationContextAware {
|
||||
|
||||
private final FunctionCatalog functionCatalog;
|
||||
|
||||
private final FunctionProperties functionProperties;
|
||||
|
||||
private final JsonMapper jsonMapper;
|
||||
|
||||
private RSocketListenerFunction invocableFunction;
|
||||
|
||||
private GenericApplicationContext context;
|
||||
|
||||
FunctionToRSocketBinder(FunctionCatalog functionCatalog, FunctionProperties functionProperties, JsonMapper jsonMapper) {
|
||||
this.functionCatalog = functionCatalog;
|
||||
this.functionProperties = functionProperties;
|
||||
this.jsonMapper = jsonMapper;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void afterPropertiesSet() throws Exception {
|
||||
String definition = this.functionProperties.getDefinition();
|
||||
if (!StringUtils.hasText(definition)) {
|
||||
FunctionInvocationWrapper f = this.functionCatalog.lookup("");
|
||||
if (f != null) {
|
||||
definition = f.getFunctionDefinition();
|
||||
}
|
||||
}
|
||||
Assert.isTrue(StringUtils.hasText(definition), "Failed to determine target function for RSocket.");
|
||||
this.registerRsocketForwardingFunctionIfNecessary(definition);
|
||||
// TODO externalize content-type
|
||||
private void registerFunctionsWithRSocketHandler(FunctionRSocketMessageHandler rsocketMessageHandler,
|
||||
FunctionCatalog functionCatalog, FunctionProperties functionProperties) {
|
||||
String definition = functionProperties.getDefinition();
|
||||
if (StringUtils.hasText(definition)) {
|
||||
String rootFunctionName = registerRSocketForwardingFunctionIfNecessary(definition, functionCatalog);
|
||||
//TODO externalize content-type
|
||||
FunctionInvocationWrapper function = functionCatalog.lookup(definition, "application/json");
|
||||
if (function.isSupplier()) {
|
||||
throw new UnsupportedOperationException("Supplier is not currently supported for RSocket interaction");
|
||||
}
|
||||
|
||||
this.invocableFunction = new RSocketListenerFunction(function, this.jsonMapper);
|
||||
rsocketMessageHandler.registerFunctionHandler(new RSocketListenerFunction(function), rootFunctionName);
|
||||
}
|
||||
|
||||
RSocket getRSocket() {
|
||||
if (this.invocableFunction == null) {
|
||||
return null;
|
||||
}
|
||||
return this.invocableFunction.getRsocket();
|
||||
else {
|
||||
functionCatalog.getNames(null)
|
||||
.forEach((name) -> {
|
||||
FunctionInvocationWrapper function = functionCatalog.lookup(name, "application/json");
|
||||
rsocketMessageHandler.registerFunctionHandler(new RSocketListenerFunction(function), name);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
@SuppressWarnings({ "rawtypes", "unchecked" })
|
||||
private void registerRsocketForwardingFunctionIfNecessary(String definition) {
|
||||
String[] names = StringUtils.delimitedListToStringArray(definition.replaceAll(",", "|").trim(), "|");
|
||||
|
||||
for (String name : names) {
|
||||
if (!this.context.containsBean(name)) { // this means RSocket
|
||||
if (logger.isDebugEnabled()) {
|
||||
logger.debug("Registering rsocket forwarder for '" + name + "' function.");
|
||||
}
|
||||
String[] functionToRSocketDefinition = StringUtils.delimitedListToStringArray(name, ">");
|
||||
Assert.isTrue(functionToRSocketDefinition.length == 2, "Must only contain one output redirect");
|
||||
FunctionInvocationWrapper function = functionCatalog.lookup(functionToRSocketDefinition[0],
|
||||
"application/json");
|
||||
|
||||
String[] hostPort = StringUtils.delimitedListToStringArray(functionToRSocketDefinition[1], ":");
|
||||
InetSocketAddress outputAddress = InetSocketAddress.createUnresolved(hostPort[0],
|
||||
Integer.valueOf(hostPort[1]));
|
||||
|
||||
RSocketForwardingFunction rsocketFunction = new RSocketForwardingFunction(function, outputAddress);
|
||||
FunctionRegistration functionRegistration = new FunctionRegistration(rsocketFunction, name);
|
||||
|
||||
functionRegistration
|
||||
.type(FunctionTypeUtils.discoverFunctionTypeFromClass(RSocketListenerFunction.class));
|
||||
((FunctionRegistry) this.functionCatalog).register(functionRegistration);
|
||||
private String registerRSocketForwardingFunctionIfNecessary(String definition, FunctionCatalog functionCatalog) {
|
||||
String[] names = StringUtils.delimitedListToStringArray(definition.replaceAll(",", "|").trim(), "|");
|
||||
String rootFunctionName = names[0];
|
||||
for (String name : names) {
|
||||
if (!this.applicationContext.containsBean(name)) { // this means RSocket
|
||||
if (LOGGER.isDebugEnabled()) {
|
||||
LOGGER.debug("Registering RSocket forwarder for '" + name + "' function.");
|
||||
}
|
||||
String[] functionToRSocketDefinition = StringUtils.delimitedListToStringArray(name, ">");
|
||||
Assert.isTrue(functionToRSocketDefinition.length == 2, "Must only contain one output redirect");
|
||||
FunctionInvocationWrapper function = functionCatalog.lookup(functionToRSocketDefinition[0], "application/json");
|
||||
|
||||
String[] hostPort = StringUtils.delimitedListToStringArray(functionToRSocketDefinition[1], ":");
|
||||
|
||||
rootFunctionName = function.getFunctionDefinition();
|
||||
String forwardingUrl = functionToRSocketDefinition[1];
|
||||
RSocketRequester rsocketRequester;
|
||||
|
||||
Builder rsocketRequesterBuilder = RSocketRequester.builder();
|
||||
|
||||
if (WS_URI_PATTERN.matcher(forwardingUrl).matches()) {
|
||||
rsocketRequester = rsocketRequesterBuilder.websocket(URI.create(forwardingUrl));
|
||||
}
|
||||
else {
|
||||
rsocketRequester = rsocketRequesterBuilder.tcp(hostPort[0], Integer.parseInt(hostPort[1]));
|
||||
}
|
||||
|
||||
RSocketForwardingFunction rsocketFunction =
|
||||
new RSocketForwardingFunction(function, rsocketRequester, null);
|
||||
FunctionRegistration<RSocketForwardingFunction> functionRegistration =
|
||||
new FunctionRegistration<>(rsocketFunction, name);
|
||||
functionRegistration.type(
|
||||
FunctionTypeUtils.discoverFunctionTypeFromClass(RSocketForwardingFunction.class));
|
||||
((FunctionRegistry) functionCatalog).register(functionRegistration);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setApplicationContext(ApplicationContext applicationContext) throws BeansException {
|
||||
this.context = (GenericApplicationContext) applicationContext;
|
||||
}
|
||||
return rootFunctionName;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -16,38 +16,25 @@
|
||||
|
||||
package org.springframework.cloud.function.rsocket;
|
||||
|
||||
import java.net.InetSocketAddress;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.time.Duration;
|
||||
import java.util.function.Function;
|
||||
|
||||
import io.rsocket.Payload;
|
||||
import io.rsocket.RSocket;
|
||||
import io.rsocket.core.RSocketConnector;
|
||||
import io.rsocket.transport.netty.client.TcpClientTransport;
|
||||
import io.rsocket.util.DefaultPayload;
|
||||
import org.apache.commons.logging.Log;
|
||||
import org.apache.commons.logging.LogFactory;
|
||||
import org.reactivestreams.Publisher;
|
||||
import reactor.core.publisher.Mono;
|
||||
import reactor.util.retry.Retry;
|
||||
|
||||
import org.springframework.cloud.function.context.catalog.SimpleFunctionRegistry.FunctionInvocationWrapper;
|
||||
import org.springframework.messaging.Message;
|
||||
import org.springframework.messaging.support.MessageBuilder;
|
||||
import org.springframework.messaging.rsocket.RSocketRequester;
|
||||
import org.springframework.messaging.support.GenericMessage;
|
||||
|
||||
|
||||
/**
|
||||
*
|
||||
* An implementation of {@link Function} to support distributed function composition.
|
||||
* <br>
|
||||
* This function wraps target function and forwards the result of
|
||||
* the invocation of the target function to another RSocket returning the result of such forwarding as {@link Publisher}.
|
||||
* <br><br>
|
||||
* A typical example is `spring.cloud.function.definition=uppercase>localhost:8888'.
|
||||
* <br>
|
||||
* In this case 'uppercase' is targetFunction which will be invoked during the call to 'apply' and the result of
|
||||
* this invocation sent to RSocket reachable at localhost:8888.
|
||||
* Wrapper over an instance of target Function (represented by {@link FunctionInvocationWrapper})
|
||||
* which will use the result of the invocation of such function as an input to another RSocket
|
||||
* effectively composing two functions over RSocket.
|
||||
* <p>
|
||||
* Note: the remote RSocket route is not necessary to be as a Spring Cloud Function binding.
|
||||
*
|
||||
* @author Oleg Zhurakousky
|
||||
* @author Artem Bilan
|
||||
@@ -59,39 +46,37 @@ class RSocketForwardingFunction implements Function<Message<byte[]>, Publisher<M
|
||||
|
||||
private static final Log LOGGER = LogFactory.getLog(RSocketForwardingFunction.class);
|
||||
|
||||
private final Mono<RSocket> rsocketMono;
|
||||
|
||||
private final FunctionInvocationWrapper targetFunction;
|
||||
|
||||
RSocketForwardingFunction(FunctionInvocationWrapper targetFunction, InetSocketAddress outputAddress) {
|
||||
private final RSocketRequester rSocketRequester;
|
||||
|
||||
// private final String remoteFunctionName;
|
||||
|
||||
RSocketForwardingFunction(FunctionInvocationWrapper targetFunction, RSocketRequester rsocketRequester,
|
||||
String remoteFunctionName) {
|
||||
|
||||
this.targetFunction = targetFunction;
|
||||
this.rsocketMono =
|
||||
outputAddress == null
|
||||
? null
|
||||
: RSocketConnector.create()
|
||||
.reconnect(Retry.backoff(5, Duration.ofSeconds(1)))
|
||||
.connect(TcpClientTransport.create(outputAddress));
|
||||
this.rSocketRequester = rsocketRequester;
|
||||
// this.remoteFunctionName = remoteFunctionName;
|
||||
}
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
@Override
|
||||
public Publisher<Message<byte[]>> apply(Message<byte[]> input) {
|
||||
if (LOGGER.isDebugEnabled()) {
|
||||
LOGGER.debug("Executing: " + this.targetFunction);
|
||||
}
|
||||
|
||||
Object rawResult = this.targetFunction.apply(input);
|
||||
return this.rsocketMono
|
||||
.flatMapMany((rsocket) ->
|
||||
rsocket.requestStream(DefaultPayload.create(((Message<byte[]>) rawResult).getPayload())))
|
||||
.map(this::buildResultMessage);
|
||||
}
|
||||
Mono<Object> targetFunctionCall = Mono.just(input)
|
||||
.map(this.targetFunction)
|
||||
.cast(Message.class)
|
||||
.map(Message::getPayload);
|
||||
|
||||
private Message<byte[]> buildResultMessage(Payload payload) {
|
||||
ByteBuffer payloadBuffer = payload.getData();
|
||||
byte[] payloadData = new byte[payloadBuffer.remaining()];
|
||||
payloadBuffer.get(payloadData);
|
||||
return MessageBuilder.withPayload(payloadData).build();
|
||||
return this.rSocketRequester
|
||||
// .route(this.remoteFunctionName)
|
||||
.route("uppercase")
|
||||
.data(targetFunctionCall, byte[].class)
|
||||
.retrieveFlux(byte[].class)
|
||||
.map(GenericMessage::new);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -17,181 +17,93 @@
|
||||
package org.springframework.cloud.function.rsocket;
|
||||
|
||||
import java.lang.reflect.Type;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.util.Map;
|
||||
import java.util.function.Function;
|
||||
|
||||
import io.rsocket.Payload;
|
||||
import io.rsocket.RSocket;
|
||||
import io.rsocket.util.DefaultPayload;
|
||||
import org.apache.commons.logging.Log;
|
||||
import org.apache.commons.logging.LogFactory;
|
||||
import io.rsocket.frame.FrameType;
|
||||
import org.reactivestreams.Publisher;
|
||||
import reactor.core.publisher.Flux;
|
||||
import reactor.core.publisher.Mono;
|
||||
|
||||
import org.springframework.cloud.function.context.catalog.FunctionTypeUtils;
|
||||
import org.springframework.cloud.function.context.catalog.SimpleFunctionRegistry.FunctionInvocationWrapper;
|
||||
import org.springframework.cloud.function.json.JsonMapper;
|
||||
import org.springframework.messaging.Message;
|
||||
import org.springframework.messaging.rsocket.annotation.support.RSocketFrameTypeMessageCondition;
|
||||
import org.springframework.messaging.support.MessageBuilder;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
|
||||
|
||||
/**
|
||||
* Wrapper over an instance of target Function (represented by {@link FunctionInvocationWrapper})
|
||||
* which will use the result of the invocation of such function as an input to another RSocket
|
||||
* effectively composing two functions over RSocket.
|
||||
* A function wrapper which is bound onto an RSocket route.
|
||||
*
|
||||
* @author Oleg Zhurakousky
|
||||
* @author Artem Bilan
|
||||
*
|
||||
* @since 3.1
|
||||
*/
|
||||
class RSocketListenerFunction implements Function<Message<byte[]>, Publisher<Message<byte[]>>> {
|
||||
|
||||
private static String splash = " ____ _ _______ __ ____ __ _ ___ ____ __ __ \n" +
|
||||
" / __/__ ____(_)__ ___ _ / ___/ /__ __ _____/ / / __/_ _____ ____/ /_(_)__ ___ / _ \\/ __/__ ____/ /_____ / /_\n" +
|
||||
" _\\ \\/ _ \\/ __/ / _ \\/ _ `/ / /__/ / _ \\/ // / _ / / _// // / _ \\/ __/ __/ / _ \\/ _ \\ / , _/\\ \\/ _ \\/ __/ '_/ -_) __/\n" +
|
||||
"/___/ .__/_/ /_/_//_/\\_, / \\___/_/\\___/\\_,_/\\_,_/ /_/ \\_,_/_//_/\\__/\\__/_/\\___/_//_/ /_/|_/___/\\___/\\__/_/\\_\\\\__/\\__/ \n" +
|
||||
" /_/ /___/ \n" +
|
||||
"";
|
||||
|
||||
private static Log logger = LogFactory.getLog(RSocketListenerFunction.class);
|
||||
public class RSocketListenerFunction implements Function<Message<Flux<byte[]>>, Publisher<?>> {
|
||||
|
||||
private final FunctionInvocationWrapper targetFunction;
|
||||
|
||||
private RSocket rsocket;
|
||||
|
||||
private final JsonMapper jsonMapper;
|
||||
|
||||
RSocketListenerFunction(FunctionInvocationWrapper targetFunction, JsonMapper jsonMapper) {
|
||||
RSocketListenerFunction(FunctionInvocationWrapper targetFunction) {
|
||||
this.targetFunction = targetFunction;
|
||||
this.jsonMapper = jsonMapper;
|
||||
}
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
@Override
|
||||
public Publisher<Message<byte[]>> apply(Message<byte[]> input) {
|
||||
if (logger.isDebugEnabled()) {
|
||||
logger.debug("Executiing: " + this.targetFunction);
|
||||
public Publisher<?> apply(Message<Flux<byte[]>> input) {
|
||||
FrameType frameType = RSocketFrameTypeMessageCondition.getFrameType(input);
|
||||
switch (frameType) {
|
||||
case REQUEST_FNF:
|
||||
return handle(input);
|
||||
case REQUEST_RESPONSE:
|
||||
case REQUEST_STREAM:
|
||||
case REQUEST_CHANNEL:
|
||||
return handleAndReply(input);
|
||||
default:
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
Object rawResult = this.targetFunction.apply(input);
|
||||
return rawResult instanceof Publisher ? (Publisher<Message<byte[]>>) rawResult : Mono.just((Message<byte[]>) rawResult);
|
||||
}
|
||||
|
||||
public RSocket getRsocket() {
|
||||
if (this.rsocket == null) {
|
||||
Type functionType = this.targetFunction.getFunctionType();
|
||||
|
||||
if (this.rsocket == null) {
|
||||
this.rsocket = this.buildRSocket(this.targetFunction, functionType, this);
|
||||
@SuppressWarnings({ "unchecked", "rawtypes" })
|
||||
private Mono<Void> handle(Message<Flux<byte[]>> messageToProcess) {
|
||||
if (this.targetFunction.isConsumer()) {
|
||||
Flux<?> dataFlux =
|
||||
messageToProcess.getPayload()
|
||||
.map((payload) -> MessageBuilder.createMessage(payload, messageToProcess.getHeaders()));
|
||||
if (isFunctionInputReactive(this.targetFunction.getFunctionType())) {
|
||||
dataFlux = dataFlux.transform((Function) this.targetFunction);
|
||||
}
|
||||
this.printSplashScreen(this.targetFunction.getFunctionDefinition(), functionType);
|
||||
else {
|
||||
dataFlux = dataFlux.doOnNext(this.targetFunction);
|
||||
}
|
||||
return dataFlux.then();
|
||||
}
|
||||
else {
|
||||
return Mono.error(new IllegalStateException("Only 'Consumer' can handle 'fire-and-forget' RSocket frame."));
|
||||
}
|
||||
return this.rsocket;
|
||||
}
|
||||
|
||||
private RSocket buildRSocket(FunctionInvocationWrapper targetFunction, Type functionType, Function<Message<byte[]>, Publisher<Message<byte[]>>> function) {
|
||||
String definition = targetFunction.getFunctionDefinition();
|
||||
RSocket clientRSocket = new RSocket() { // imperative function or Function<?, Mono> = requestResponse
|
||||
@Override
|
||||
public Mono<Payload> requestResponse(Payload payload) {
|
||||
if (logger.isDebugEnabled()) {
|
||||
logger.debug("Invoking function '" + definition + "' as RSocket `requestResponse`.");
|
||||
}
|
||||
|
||||
if (isFunctionReactive(functionType)) {
|
||||
Flux<Payload> result = this.requestChannel(Flux.just(payload));
|
||||
return Mono.from(result);
|
||||
}
|
||||
else {
|
||||
Message<byte[]> inputMessage = deserealizePayload(payload);
|
||||
Mono<Message<byte[]>> result = Mono.from(function.apply(inputMessage));
|
||||
return result.map(message -> DefaultPayload.create(message.getPayload(), jsonMapper.toJson(message.getHeaders())));
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public Flux<Payload> requestStream(Payload payload) {
|
||||
if (logger.isDebugEnabled()) {
|
||||
logger.debug("Invoking function '" + definition + "' as RSocket `requestStream`.");
|
||||
}
|
||||
if (isFunctionReactive(functionType)) {
|
||||
return this.requestChannel(Flux.just(payload));
|
||||
}
|
||||
else {
|
||||
Message<byte[]> inputMessage = deserealizePayload(payload);
|
||||
Flux<Message<byte[]>> result = Flux.from(function.apply(inputMessage));
|
||||
return result.map(message -> DefaultPayload.create(message.getPayload()));
|
||||
}
|
||||
}
|
||||
|
||||
@SuppressWarnings({ "unchecked", "rawtypes" })
|
||||
@Override
|
||||
public Flux<Payload> requestChannel(Publisher<Payload> payloads) {
|
||||
if (logger.isDebugEnabled()) {
|
||||
logger.debug("Invoking function '" + definition + "' as RSocket `requestChannel`.");
|
||||
}
|
||||
if (isFunctionReactive(functionType)) {
|
||||
return Flux.from(payloads)
|
||||
.transform(inputFlux -> inputFlux.map(payload -> deserealizePayload(payload)))
|
||||
.transform((Function) targetFunction)
|
||||
.transform(outputFlux -> ((Flux<Message<byte[]>>) outputFlux).map(message -> DefaultPayload.create(message.getPayload())));
|
||||
}
|
||||
else {
|
||||
return Flux.from(payloads)
|
||||
.transform(flux -> {
|
||||
return flux.flatMap(payload -> {
|
||||
Message<byte[]> inputMessage = deserealizePayload(payload);
|
||||
Flux<Message<byte[]>> result = Flux.from(function.apply(inputMessage));
|
||||
return result;
|
||||
});
|
||||
})
|
||||
.doOnNext(System.out::println)
|
||||
.transform(outputFlux -> outputFlux.map(message -> DefaultPayload.create(message.getPayload())));
|
||||
}
|
||||
|
||||
}
|
||||
};
|
||||
return clientRSocket;
|
||||
@SuppressWarnings({ "unchecked", "rawtypes" })
|
||||
private Flux<?> handleAndReply(Message<Flux<byte[]>> messageToProcess) {
|
||||
Flux<?> dataFlux =
|
||||
messageToProcess.getPayload()
|
||||
.map((payload) -> MessageBuilder.createMessage(payload, messageToProcess.getHeaders()));
|
||||
if (isFunctionInputReactive(this.targetFunction.getFunctionType())) {
|
||||
dataFlux = dataFlux.transform((Function) this.targetFunction);
|
||||
}
|
||||
else {
|
||||
dataFlux = dataFlux.flatMap((data) -> {
|
||||
Object result = this.targetFunction.apply(data);
|
||||
return result instanceof Publisher<?>
|
||||
? (Publisher<Message<byte[]>>) result
|
||||
: Mono.just((Message<byte[]>) result);
|
||||
});
|
||||
}
|
||||
return dataFlux.cast(Message.class).map(Message::getPayload);
|
||||
}
|
||||
|
||||
private static boolean isFunctionReactive(Type functionType) {
|
||||
private static boolean isFunctionInputReactive(Type functionType) {
|
||||
Type inputType = FunctionTypeUtils.getInputType(functionType, 0);
|
||||
Type outputType = FunctionTypeUtils.getOutputType(functionType, 0);
|
||||
return FunctionTypeUtils.isPublisher(inputType) && FunctionTypeUtils.isFlux(outputType);
|
||||
}
|
||||
|
||||
@SuppressWarnings({ "rawtypes", "unchecked" })
|
||||
private Message<byte[]> deserealizePayload(Payload payload) {
|
||||
ByteBuffer buffer = payload.getData();
|
||||
byte[] rawData = new byte[buffer.remaining()];
|
||||
buffer.get(rawData);
|
||||
Map<String, Object> headers = null;
|
||||
if (payload.hasMetadata()) {
|
||||
try {
|
||||
ByteBuffer metadata = payload.getMetadata();
|
||||
byte[] metadataBytes = new byte[metadata.remaining()];
|
||||
metadata.get(metadataBytes);
|
||||
headers = this.jsonMapper.fromJson(metadataBytes, Map.class);
|
||||
}
|
||||
catch (Exception e) {
|
||||
//throw new IllegalStateException(e);
|
||||
logger.warn("Failed to extract headers from metadata", e);
|
||||
}
|
||||
}
|
||||
MessageBuilder builder = MessageBuilder.withPayload(rawData);
|
||||
if (!CollectionUtils.isEmpty(headers)) {
|
||||
builder.copyHeaders(headers);
|
||||
}
|
||||
Message<byte[]> inputMessage = builder.build();
|
||||
return inputMessage;
|
||||
|
||||
}
|
||||
|
||||
private void printSplashScreen(String definition, Type type) {
|
||||
System.out.println(splash);
|
||||
System.out.println("Function Definition: " + definition + "; T[" + type + "]");
|
||||
System.out.println("======================================================\n");
|
||||
return FunctionTypeUtils.isPublisher(inputType);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -16,7 +16,6 @@
|
||||
|
||||
package org.springframework.cloud.function.rsocket;
|
||||
|
||||
import io.rsocket.SocketAcceptor;
|
||||
import io.rsocket.routing.client.spring.RoutingClientAutoConfiguration;
|
||||
|
||||
import org.springframework.boot.autoconfigure.AutoConfigureAfter;
|
||||
@@ -24,7 +23,6 @@ import org.springframework.boot.autoconfigure.AutoConfigureBefore;
|
||||
import org.springframework.boot.autoconfigure.condition.ConditionalOnClass;
|
||||
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
|
||||
import org.springframework.cloud.function.context.FunctionProperties;
|
||||
import org.springframework.cloud.function.rsocket.RSocketAutoConfiguration.FunctionToRSocketBinder;
|
||||
import org.springframework.context.annotation.Bean;
|
||||
import org.springframework.context.annotation.Configuration;
|
||||
import org.springframework.messaging.rsocket.RSocketConnectorConfigurer;
|
||||
@@ -45,8 +43,8 @@ class RSocketRoutingAutoConfiguration {
|
||||
|
||||
@Bean
|
||||
public RSocketConnectorConfigurer functionRSocketConnectorConfigurer(
|
||||
FunctionToRSocketBinder binder) {
|
||||
return connector -> connector.acceptor(SocketAcceptor.with(binder.getRSocket()));
|
||||
FunctionRSocketMessageHandler handler) {
|
||||
return connector -> connector.acceptor(handler.responder());
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user