GH-587 Add support for inferring 'accept' content type for simple types

This fix also introduces new Function property 'accept' with no default value which implicitely would default to application/json unless the output type of the function is String at which point it would default to text/plain. However, if it was explicitely set in FunctionProperties it will be used regardless of the function output type.
Resolves #587
This commit is contained in:
Oleg Zhurakousky
2020-09-16 18:14:40 +02:00
parent e1adb011ab
commit d3afd1fea4
7 changed files with 86 additions and 50 deletions

View File

@@ -41,7 +41,6 @@ import org.springframework.core.codec.ByteArrayEncoder;
import org.springframework.core.codec.Decoder;
import org.springframework.core.codec.Encoder;
import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.http.server.PathContainer;
import org.springframework.lang.Nullable;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageHeaders;
@@ -57,11 +56,8 @@ import org.springframework.messaging.rsocket.annotation.support.RSocketFrameType
import org.springframework.messaging.rsocket.annotation.support.RSocketMessageHandler;
import org.springframework.messaging.rsocket.annotation.support.RSocketPayloadReturnValueHandler;
import org.springframework.messaging.support.MessageBuilder;
import org.springframework.util.AntPathMatcher;
import org.springframework.util.MimeTypeUtils;
import org.springframework.util.ReflectionUtils;
import org.springframework.util.RouteMatcher;
import org.springframework.util.SimpleRouteMatcher;
import org.springframework.util.StringUtils;
import org.springframework.web.util.pattern.PathPatternRouteMatcher;
@@ -116,25 +112,18 @@ class FunctionRSocketMessageHandler extends RSocketMessageHandler {
* Will check if there is a function handler registered for destination before proceeding.
* This typically happens when user avoids using 'spring.cloud.function.definition' property.
*/
@SuppressWarnings("unchecked")
@Override
public Mono<Void> handleMessage(Message<?> message) throws MessagingException {
if (!FrameType.SETUP.equals(message.getHeaders().get("rsocketFrameType"))) {
String destination = this.getDestination(message).value();
if (!StringUtils.hasText(destination)) {
destination = this.functionProperties.getDefinition();
Map<String, Object> headersMap = (Map<String, Object>) ReflectionUtils
.getField(this.headersField, message.getHeaders());
PathPatternRouteMatcher matcher = new PathPatternRouteMatcher();
headersMap.put(DestinationPatternsMessageCondition.LOOKUP_DESTINATION_HEADER, matcher.parseRoute(destination));
destination = this.discoverAndInjectDestinationHeader(message);
}
Set<String> mappings = this.getDestinationLookup().keySet();
if (!mappings.contains(destination)) {
FunctionInvocationWrapper function = FunctionRSocketUtils
.registerFunctionForDestination(destination, functionCatalog, this.getApplicationContext());
.registerFunctionForDestination(destination, this.functionCatalog, this.getApplicationContext());
this.registerFunctionHandler(new RSocketListenerFunction(function), destination);
}
}
@@ -162,6 +151,18 @@ class FunctionRSocketMessageHandler extends RSocketMessageHandler {
getReactiveAdapterRegistry()));
}
@SuppressWarnings("unchecked")
private String discoverAndInjectDestinationHeader(Message<?> message) {
String destination = this.functionProperties.getDefinition();
Map<String, Object> headersMap = (Map<String, Object>) ReflectionUtils
.getField(this.headersField, message.getHeaders());
PathPatternRouteMatcher matcher = new PathPatternRouteMatcher();
headersMap.put(DestinationPatternsMessageCondition.LOOKUP_DESTINATION_HEADER, matcher.parseRoute(destination));
return destination;
}
protected static final class MessageHandlerMethodArgumentResolver implements SyncHandlerMethodArgumentResolver {
private final Decoder<byte[]> decoder = new ByteArrayDecoder();

View File

@@ -16,6 +16,7 @@
package org.springframework.cloud.function.rsocket;
import java.lang.reflect.Type;
import java.net.URI;
import java.util.regex.Pattern;
@@ -23,6 +24,7 @@ import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.cloud.function.context.FunctionCatalog;
import org.springframework.cloud.function.context.FunctionProperties;
import org.springframework.cloud.function.context.FunctionRegistration;
import org.springframework.cloud.function.context.FunctionRegistry;
import org.springframework.cloud.function.context.catalog.FunctionTypeUtils;
@@ -50,10 +52,25 @@ final class FunctionRSocketUtils {
}
static FunctionInvocationWrapper registerFunctionForDestination(String destination, FunctionCatalog functionCatalog,
static FunctionInvocationWrapper registerFunctionForDestination(String functionDefinition, FunctionCatalog functionCatalog,
ApplicationContext applicationContext) {
registerRSocketForwardingFunctionIfNecessary(destination, functionCatalog, applicationContext);
FunctionInvocationWrapper function = functionCatalog.lookup(destination, "application/json");
registerRSocketForwardingFunctionIfNecessary(functionDefinition, functionCatalog, applicationContext);
FunctionProperties functionProperties = applicationContext.getBean(FunctionProperties.class);
String acceptContentType = functionProperties.getAccept();
if (!StringUtils.hasText(acceptContentType)) {
FunctionInvocationWrapper function = functionCatalog.lookup(functionDefinition);
Type functionType = function.getFunctionType();
Type outputType = FunctionTypeUtils.getOutputType(functionType, 0);
if (outputType instanceof Class && String.class.isAssignableFrom((Class<?>) outputType)) {
acceptContentType = "text/plain";
}
else {
acceptContentType = "application/json";
}
}
FunctionInvocationWrapper function = functionCatalog.lookup(functionDefinition, acceptContentType);
return function;
}
@@ -73,6 +90,7 @@ final class FunctionRSocketUtils {
String forwardingUrl = functionToRSocketDefinition[1];
Builder rsocketRequesterBuilder = applicationContext.getBean(Builder.class);
RSocketRequester rsocketRequester = (WS_URI_PATTERN.matcher(forwardingUrl).matches())
? rsocketRequesterBuilder.websocket(URI.create(forwardingUrl))
: rsocketRequesterBuilder.tcp(hostPort[0], Integer.parseInt(hostPort[1]));

View File

@@ -61,6 +61,32 @@ public class RSocketAutoConfigurationTests {
RSocketRequester.Builder rsocketRequesterBuilder =
applicationContext.getBean(RSocketRequester.Builder.class);
rsocketRequesterBuilder.tcp("localhost", port)
.route("")
.data("\"hello\"")
.retrieveMono(String.class)
.as(StepVerifier::create)
.expectNext("HELLO")
.expectComplete()
.verify();
}
}
@Test
public void testImperativeFunctionAsRequestReplyWithDefinitionExplicitAccept() {
int port = SocketUtils.findAvailableTcpPort();
try (
ConfigurableApplicationContext applicationContext =
new SpringApplicationBuilder(SampleFunctionConfiguration.class)
.web(WebApplicationType.NONE)
.run("--logging.level.org.springframework.cloud.function=DEBUG",
"--spring.cloud.function.definition=uppercase",
"--spring.cloud.function.accept=application/json",
"--spring.rsocket.server.port=" + port);
) {
RSocketRequester.Builder rsocketRequesterBuilder =
applicationContext.getBean(RSocketRequester.Builder.class);
rsocketRequesterBuilder.tcp("localhost", port)
.route("")
.data("\"hello\"")
@@ -87,10 +113,10 @@ public class RSocketAutoConfigurationTests {
rsocketRequesterBuilder.tcp("localhost", port)
.route("uppercase")
.data("\"hello\"")
.data("hello")
.retrieveMono(String.class)
.as(StepVerifier::create)
.expectNext("\"HELLO\"")
.expectNext("HELLO")
.expectComplete()
.verify();
}
@@ -114,7 +140,7 @@ public class RSocketAutoConfigurationTests {
.data("\"hello\"")
.retrieveMono(String.class)
.as(StepVerifier::create)
.expectNext("\"HELLOHELLO\"")
.expectNext("HELLOHELLO")
.expectComplete()
.verify();
}
@@ -138,7 +164,7 @@ public class RSocketAutoConfigurationTests {
.data("\"hello\"")
.retrieveMono(String.class)
.as(StepVerifier::create)
.expectNext("\"test data\"")
.expectNext("test data")
.expectComplete()
.verify();
}
@@ -162,7 +188,7 @@ public class RSocketAutoConfigurationTests {
.data("\"hello\"")
.retrieveFlux(String.class)
.as(StepVerifier::create)
.expectNext("\"HELLO\"")
.expectNext("HELLO")
.expectComplete()
.verify();
}
@@ -186,7 +212,7 @@ public class RSocketAutoConfigurationTests {
.data(Flux.just("\"Ricky\"", "\"Julien\"", "\"Bubbles\""))
.retrieveFlux(String.class)
.as(StepVerifier::create)
.expectNext("\"RICKY\"", "\"JULIEN\"", "\"BUBBLES\"")
.expectNext("RICKY", "JULIEN", "BUBBLES")
.expectComplete()
.verify();
}
@@ -294,7 +320,7 @@ public class RSocketAutoConfigurationTests {
.data("\"hello\"")
.retrieveMono(String.class)
.as(StepVerifier::create)
.expectNext("\"(OLLEHOLLEH)\"")
.expectNext("(OLLEHOLLEH)")
.expectComplete()
.verify();
}
@@ -394,7 +420,7 @@ public class RSocketAutoConfigurationTests {
.data("\"hello\"")
.retrieveMono(String.class)
.as(StepVerifier::create)
.expectNext("\"olleh\"")
.expectNext("olleh")
.expectComplete()
.verify();
@@ -402,7 +428,7 @@ public class RSocketAutoConfigurationTests {
.data("\"hello\"")
.retrieveMono(String.class)
.as(StepVerifier::create)
.expectNext("\"(hello)\"")
.expectNext("(hello)")
.expectComplete()
.verify();
}
@@ -443,7 +469,9 @@ public class RSocketAutoConfigurationTests {
@Bean
public Function<String, String> uppercase() {
return String::toUpperCase;
return v -> {
return v.toUpperCase();
};
}
@Bean

View File

@@ -20,6 +20,7 @@ import java.util.function.Function;
import io.rsocket.routing.client.spring.RoutingMetadata;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import reactor.core.publisher.Mono;
import reactor.test.StepVerifier;
@@ -39,6 +40,7 @@ import org.springframework.util.SocketUtils;
* @author Oleg Zhurakousky
* @since 3.1
*/
@Disabled
public class RoutingBrokerTests {
ConfigurableApplicationContext functionContext;
@@ -70,7 +72,7 @@ public class RoutingBrokerTests {
StepVerifier
.create(result)
.expectNext("\"HELLO\"")
.expectNext("HELLO")
.expectComplete()
.verify();
}
@@ -87,7 +89,7 @@ public class RoutingBrokerTests {
StepVerifier
.create(result)
.expectNext("\"HELLO\"")
.expectNext("HELLO")
.expectComplete()
.verify();
}