Fix distributed composition test

Added override of handleMessage(..) to FunctionRSocketMessageHandler to be able to register functions on demand instead of pre-registering all of them during the init
This commit is contained in:
Oleg Zhurakousky
2020-08-28 09:14:57 +02:00
parent 9ac98fd236
commit 123ced3fb6
5 changed files with 126 additions and 77 deletions

View File

@@ -19,6 +19,7 @@ package org.springframework.cloud.function.rsocket;
import java.lang.reflect.Method;
import java.util.Collections;
import java.util.List;
import java.util.Set;
import java.util.function.Function;
import io.rsocket.frame.FrameType;
@@ -26,6 +27,8 @@ import org.reactivestreams.Publisher;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import org.springframework.cloud.function.context.FunctionCatalog;
import org.springframework.cloud.function.context.catalog.SimpleFunctionRegistry.FunctionInvocationWrapper;
import org.springframework.core.MethodParameter;
import org.springframework.core.ReactiveAdapterRegistry;
import org.springframework.core.ResolvableType;
@@ -36,6 +39,7 @@ import org.springframework.core.codec.Encoder;
import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.lang.Nullable;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessagingException;
import org.springframework.messaging.handler.CompositeMessageCondition;
import org.springframework.messaging.handler.DestinationPatternsMessageCondition;
import org.springframework.messaging.handler.invocation.reactive.HandlerMethodArgumentResolver;
@@ -51,11 +55,14 @@ import org.springframework.util.ReflectionUtils;
* An {@link RSocketMessageHandler} extension for Spring Cloud Function specifics.
*
* @author Artem Bilan
* @author Oleg Zhurakousky
*
* @since 3.1
*/
public class FunctionRSocketMessageHandler extends RSocketMessageHandler {
private final FunctionCatalog functionCatalog;
private static final Method FUNCTION_APPLY_METHOD =
ReflectionUtils.findMethod(Function.class, "apply", (Class<?>[]) null);
@@ -66,8 +73,9 @@ public class FunctionRSocketMessageHandler extends RSocketMessageHandler {
FrameType.REQUEST_STREAM,
FrameType.REQUEST_CHANNEL);
public FunctionRSocketMessageHandler() {
public FunctionRSocketMessageHandler(FunctionCatalog functionCatalog) {
setHandlerPredicate((clazz) -> false);
this.functionCatalog = functionCatalog;
}
@@ -77,6 +85,21 @@ public class FunctionRSocketMessageHandler extends RSocketMessageHandler {
super.afterPropertiesSet();
}
@Override
public Mono<Void> handleMessage(Message<?> message) throws MessagingException {
if (!FrameType.SETUP.equals(message.getHeaders().get("rsocketFrameType"))) {
String destination = this.getDestination(message).value();
Set<String> mappings = this.getDestinationLookup().keySet();
if (!mappings.contains(destination)) {
FunctionRSocketUtils.registerRSocketForwardingFunctionIfNecessary(destination, functionCatalog, this.getApplicationContext());
FunctionInvocationWrapper function = functionCatalog.lookup(destination, "application/json");
this.registerFunctionHandler(new RSocketListenerFunction(function), destination);
}
}
return super.handleMessage(message);
}
public void registerFunctionHandler(Function<?, ?> function, String route) {
CompositeMessageCondition condition =
new CompositeMessageCondition(REQUEST_CONDITION,

View File

@@ -0,0 +1,93 @@
/*
* Copyright 2020-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.cloud.function.rsocket;
import java.net.URI;
import java.util.regex.Pattern;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.cloud.function.context.FunctionCatalog;
import org.springframework.cloud.function.context.FunctionRegistration;
import org.springframework.cloud.function.context.FunctionRegistry;
import org.springframework.cloud.function.context.catalog.FunctionTypeUtils;
import org.springframework.cloud.function.context.catalog.SimpleFunctionRegistry.FunctionInvocationWrapper;
import org.springframework.context.ApplicationContext;
import org.springframework.messaging.rsocket.RSocketRequester;
import org.springframework.messaging.rsocket.RSocketRequester.Builder;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;
/**
*
* @author Oleg Zhurakousky
*
* @since 3.1
*
*/
final class FunctionRSocketUtils {
private static final Log LOGGER = LogFactory.getLog(FunctionRSocketUtils.class);
private static final Pattern WS_URI_PATTERN = Pattern.compile("^(https?|wss?)://.+");
private FunctionRSocketUtils() {
}
static String registerRSocketForwardingFunctionIfNecessary(String definition, FunctionCatalog functionCatalog,
ApplicationContext applicationContext) {
String[] names = StringUtils.delimitedListToStringArray(definition.replaceAll(",", "|").trim(), "|");
String rootFunctionName = names[0];
for (String name : names) {
if (!applicationContext.containsBean(name)) { // this means RSocket
if (LOGGER.isDebugEnabled()) {
LOGGER.debug("Registering RSocket forwarder for '" + name + "' function.");
}
String[] functionToRSocketDefinition = StringUtils.delimitedListToStringArray(name, ">");
Assert.isTrue(functionToRSocketDefinition.length == 2, "Must only contain one output redirect");
FunctionInvocationWrapper function = functionCatalog.lookup(functionToRSocketDefinition[0], "application/json");
String[] hostPort = StringUtils.delimitedListToStringArray(functionToRSocketDefinition[1], ":");
rootFunctionName = function.getFunctionDefinition();
String forwardingUrl = functionToRSocketDefinition[1];
RSocketRequester rsocketRequester;
Builder rsocketRequesterBuilder = applicationContext.getBean(Builder.class);
if (WS_URI_PATTERN.matcher(forwardingUrl).matches()) {
rsocketRequester = rsocketRequesterBuilder.websocket(URI.create(forwardingUrl));
}
else {
rsocketRequester = rsocketRequesterBuilder.tcp(hostPort[0], Integer.parseInt(hostPort[1]));
}
RSocketForwardingFunction rsocketFunction =
new RSocketForwardingFunction(function, rsocketRequester, null);
FunctionRegistration<RSocketForwardingFunction> functionRegistration =
new FunctionRegistration<>(rsocketFunction, name);
functionRegistration.type(
FunctionTypeUtils.discoverFunctionTypeFromClass(RSocketForwardingFunction.class));
((FunctionRegistry) functionCatalog).register(functionRegistration);
}
}
return rootFunctionName;
}
}

View File

@@ -16,12 +16,6 @@
package org.springframework.cloud.function.rsocket;
import java.net.URI;
import java.util.regex.Pattern;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.beans.BeansException;
import org.springframework.beans.factory.ObjectProvider;
import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;
@@ -30,19 +24,12 @@ import org.springframework.boot.autoconfigure.rsocket.RSocketMessageHandlerCusto
import org.springframework.boot.context.properties.EnableConfigurationProperties;
import org.springframework.cloud.function.context.FunctionCatalog;
import org.springframework.cloud.function.context.FunctionProperties;
import org.springframework.cloud.function.context.FunctionRegistration;
import org.springframework.cloud.function.context.FunctionRegistry;
import org.springframework.cloud.function.context.catalog.FunctionTypeUtils;
import org.springframework.cloud.function.context.catalog.SimpleFunctionRegistry.FunctionInvocationWrapper;
import org.springframework.context.ApplicationContext;
import org.springframework.context.ApplicationContextAware;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.context.annotation.Primary;
import org.springframework.messaging.rsocket.RSocketRequester;
import org.springframework.messaging.rsocket.RSocketRequester.Builder;
import org.springframework.messaging.rsocket.RSocketStrategies;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;
/**
@@ -50,6 +37,8 @@ import org.springframework.util.StringUtils;
* spring-cloud-function.
*
* @author Oleg Zhurakousky
* @author Artem Bilan
*
* @since 3.1
*/
@Configuration(proxyBeanMethods = false)
@@ -57,10 +46,6 @@ import org.springframework.util.StringUtils;
@ConditionalOnProperty(name = FunctionProperties.PREFIX + ".rsocket.enabled", matchIfMissing = true)
class RSocketAutoConfiguration implements ApplicationContextAware {
private static final Log LOGGER = LogFactory.getLog(RSocketAutoConfiguration.class);
private static final Pattern WS_URI_PATTERN = Pattern.compile("^(https?|wss?)://.+");
private ApplicationContext applicationContext;
@Override
@@ -70,12 +55,11 @@ class RSocketAutoConfiguration implements ApplicationContextAware {
@Bean
@ConditionalOnMissingBean
@Primary
public FunctionRSocketMessageHandler functionRSocketMessageHandler(RSocketStrategies rSocketStrategies,
ObjectProvider<RSocketMessageHandlerCustomizer> customizers, FunctionCatalog functionCatalog,
FunctionProperties functionProperties) {
FunctionRSocketMessageHandler rsocketMessageHandler = new FunctionRSocketMessageHandler();
FunctionRSocketMessageHandler rsocketMessageHandler = new FunctionRSocketMessageHandler(functionCatalog);
rsocketMessageHandler.setRSocketStrategies(rSocketStrategies);
customizers.orderedStream().forEach((customizer) -> customizer.customize(rsocketMessageHandler));
registerFunctionsWithRSocketHandler(rsocketMessageHandler, functionCatalog, functionProperties);
@@ -86,58 +70,12 @@ class RSocketAutoConfiguration implements ApplicationContextAware {
FunctionCatalog functionCatalog, FunctionProperties functionProperties) {
String definition = functionProperties.getDefinition();
if (StringUtils.hasText(definition)) {
String rootFunctionName = registerRSocketForwardingFunctionIfNecessary(definition, functionCatalog);
FunctionRSocketUtils.registerRSocketForwardingFunctionIfNecessary(definition, functionCatalog, this.applicationContext);
//TODO externalize content-type
FunctionInvocationWrapper function = functionCatalog.lookup(definition, "application/json");
rsocketMessageHandler.registerFunctionHandler(new RSocketListenerFunction(function), rootFunctionName);
rsocketMessageHandler.registerFunctionHandler(new RSocketListenerFunction(function), definition);
rsocketMessageHandler.registerFunctionHandler(new RSocketListenerFunction(function), "");
}
else {
functionCatalog.getNames(null)
.forEach((name) -> {
FunctionInvocationWrapper function = functionCatalog.lookup(name, "application/json");
rsocketMessageHandler.registerFunctionHandler(new RSocketListenerFunction(function), name);
});
}
}
private String registerRSocketForwardingFunctionIfNecessary(String definition, FunctionCatalog functionCatalog) {
String[] names = StringUtils.delimitedListToStringArray(definition.replaceAll(",", "|").trim(), "|");
String rootFunctionName = names[0];
for (String name : names) {
if (!this.applicationContext.containsBean(name)) { // this means RSocket
if (LOGGER.isDebugEnabled()) {
LOGGER.debug("Registering RSocket forwarder for '" + name + "' function.");
}
String[] functionToRSocketDefinition = StringUtils.delimitedListToStringArray(name, ">");
Assert.isTrue(functionToRSocketDefinition.length == 2, "Must only contain one output redirect");
FunctionInvocationWrapper function = functionCatalog.lookup(functionToRSocketDefinition[0], "application/json");
String[] hostPort = StringUtils.delimitedListToStringArray(functionToRSocketDefinition[1], ":");
rootFunctionName = function.getFunctionDefinition();
String forwardingUrl = functionToRSocketDefinition[1];
RSocketRequester rsocketRequester;
Builder rsocketRequesterBuilder = RSocketRequester.builder();
if (WS_URI_PATTERN.matcher(forwardingUrl).matches()) {
rsocketRequester = rsocketRequesterBuilder.websocket(URI.create(forwardingUrl));
}
else {
rsocketRequester = rsocketRequesterBuilder.tcp(hostPort[0], Integer.parseInt(hostPort[1]));
}
RSocketForwardingFunction rsocketFunction =
new RSocketForwardingFunction(function, rsocketRequester, null);
FunctionRegistration<RSocketForwardingFunction> functionRegistration =
new FunctionRegistration<>(rsocketFunction, name);
functionRegistration.type(
FunctionTypeUtils.discoverFunctionTypeFromClass(RSocketForwardingFunction.class));
((FunctionRegistry) functionCatalog).register(functionRegistration);
}
}
return rootFunctionName;
}
}

View File

@@ -50,14 +50,11 @@ class RSocketForwardingFunction implements Function<Message<byte[]>, Publisher<M
private final RSocketRequester rSocketRequester;
// private final String remoteFunctionName;
RSocketForwardingFunction(FunctionInvocationWrapper targetFunction, RSocketRequester rsocketRequester,
String remoteFunctionName) {
this.targetFunction = targetFunction;
this.rSocketRequester = rsocketRequester;
// this.remoteFunctionName = remoteFunctionName;
}
@Override
@@ -72,11 +69,9 @@ class RSocketForwardingFunction implements Function<Message<byte[]>, Publisher<M
.map(Message::getPayload);
return this.rSocketRequester
// .route(this.remoteFunctionName)
.route("uppercase")
.route("")
.data(targetFunctionCall, byte[].class)
.retrieveFlux(byte[].class)
.map(GenericMessage::new);
}
}