diff --git a/spring-cloud-function-rsocket/src/main/java/org/springframework/cloud/function/rsocket/FunctionRSocketMessageHandler.java b/spring-cloud-function-rsocket/src/main/java/org/springframework/cloud/function/rsocket/FunctionRSocketMessageHandler.java index 867a093a7..cfebb2b19 100644 --- a/spring-cloud-function-rsocket/src/main/java/org/springframework/cloud/function/rsocket/FunctionRSocketMessageHandler.java +++ b/spring-cloud-function-rsocket/src/main/java/org/springframework/cloud/function/rsocket/FunctionRSocketMessageHandler.java @@ -150,38 +150,39 @@ class FunctionRSocketMessageHandler extends RSocketMessageHandler { getReactiveAdapterRegistry())); } - @SuppressWarnings("unchecked") private String discoverAndInjectDestinationHeader(Message message) { String destination; if (StringUtils.hasText(this.functionProperties.getRoutingExpression())) { destination = RoutingFunction.FUNCTION_NAME; - Map headersMap = (Map) ReflectionUtils - .getField(this.headersField, message.getHeaders()); - PathPatternRouteMatcher matcher = new PathPatternRouteMatcher(); - headersMap.put(DestinationPatternsMessageCondition.LOOKUP_DESTINATION_HEADER, matcher.parseRoute(destination)); + this.updateMessageHeaders(message, destination); } else { Route route = (Route) message.getHeaders().get(DestinationPatternsMessageCondition.LOOKUP_DESTINATION_HEADER); destination = route.value(); if (!StringUtils.hasText(destination)) { destination = this.functionProperties.getDefinition(); - Map headersMap = (Map) ReflectionUtils - .getField(this.headersField, message.getHeaders()); - PathPatternRouteMatcher matcher = new PathPatternRouteMatcher(); - headersMap.put(DestinationPatternsMessageCondition.LOOKUP_DESTINATION_HEADER, matcher.parseRoute(destination)); + this.updateMessageHeaders(message, destination); } } if (!StringUtils.hasText(destination) && logger.isDebugEnabled()) { logger.debug("Failed to discover function definition. Neither " + "`spring.cloud.function.definition`, nor `.route()`, nor " - + "`spring.cloud.function.routing-expression` were provided. Wil use empty string " + + "`spring.cloud.function.routing-expression` were provided. Will use empty string " + "for lookup, which will work only if there is one function in Function Catalog"); } return destination; } + @SuppressWarnings("unchecked") + private void updateMessageHeaders(Message message, String destination) { + Map headersMap = (Map) ReflectionUtils + .getField(this.headersField, message.getHeaders()); + PathPatternRouteMatcher matcher = new PathPatternRouteMatcher(); + headersMap.put(DestinationPatternsMessageCondition.LOOKUP_DESTINATION_HEADER, matcher.parseRoute(destination)); + } + protected static final class MessageHandlerMethodArgumentResolver implements SyncHandlerMethodArgumentResolver { private final Decoder decoder = new ByteArrayDecoder(); diff --git a/spring-cloud-function-rsocket/src/main/java/org/springframework/cloud/function/rsocket/FunctionRSocketUtils.java b/spring-cloud-function-rsocket/src/main/java/org/springframework/cloud/function/rsocket/FunctionRSocketUtils.java index 88a89ab5c..71359cc4e 100644 --- a/spring-cloud-function-rsocket/src/main/java/org/springframework/cloud/function/rsocket/FunctionRSocketUtils.java +++ b/spring-cloud-function-rsocket/src/main/java/org/springframework/cloud/function/rsocket/FunctionRSocketUtils.java @@ -1,5 +1,5 @@ /* - * Copyright 2020-2020 the original author or authors. + * Copyright 2020-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -33,6 +33,7 @@ import org.springframework.context.ApplicationContext; import org.springframework.messaging.rsocket.RSocketRequester; import org.springframework.messaging.rsocket.RSocketRequester.Builder; import org.springframework.util.Assert; +import org.springframework.util.MimeTypeUtils; import org.springframework.util.StringUtils; /** @@ -60,14 +61,9 @@ final class FunctionRSocketUtils { String acceptContentType = functionProperties.getExpectedContentType(); if (!StringUtils.hasText(acceptContentType)) { FunctionInvocationWrapper function = functionCatalog.lookup(functionDefinition); - //Type functionType = function.getFunctionType(); Type outputType = function.getOutputType(); - if (outputType instanceof Class && String.class.isAssignableFrom((Class) outputType)) { - acceptContentType = "text/plain"; - } - else { - acceptContentType = "application/json"; - } + acceptContentType = (outputType instanceof Class && String.class.isAssignableFrom((Class) outputType)) + ? MimeTypeUtils.TEXT_PLAIN_VALUE : MimeTypeUtils.APPLICATION_JSON_VALUE; } FunctionInvocationWrapper function = functionCatalog.lookup(functionDefinition, acceptContentType); @@ -78,13 +74,18 @@ final class FunctionRSocketUtils { ApplicationContext applicationContext) { String[] names = StringUtils.delimitedListToStringArray(definition.replaceAll(",", "|").trim(), "|"); for (String name : names) { - if (!applicationContext.containsBean(name)) { // this means RSocket + + if (functionCatalog.lookup(name) == null) { // this means RSocket + String[] functionToRSocketDefinition = StringUtils.delimitedListToStringArray(name, ">"); + if (functionToRSocketDefinition.length == 1) { + throw new IllegalArgumentException("Function definition '" + name + "' does not exist in Function Catalog"); + } if (LOGGER.isDebugEnabled()) { LOGGER.debug("Registering RSocket forwarder for '" + name + "' function."); } - String[] functionToRSocketDefinition = StringUtils.delimitedListToStringArray(name, ">"); - Assert.isTrue(functionToRSocketDefinition.length == 2, "Must only contain one output redirect"); - FunctionInvocationWrapper function = functionCatalog.lookup(functionToRSocketDefinition[0], "application/json"); + + Assert.isTrue(functionToRSocketDefinition.length == 2, "Must only contain one output redirect. Was '" + name + "'."); + FunctionInvocationWrapper function = functionCatalog.lookup(functionToRSocketDefinition[0], MimeTypeUtils.APPLICATION_JSON_VALUE); String[] hostPort = StringUtils.delimitedListToStringArray(functionToRSocketDefinition[1], ":"); diff --git a/spring-cloud-function-rsocket/src/test/java/org/springframework/cloud/function/rsocket/RSocketAutoConfigurationRoutingTests.java b/spring-cloud-function-rsocket/src/test/java/org/springframework/cloud/function/rsocket/RSocketAutoConfigurationRoutingTests.java index bfc5f5b6e..8ec7fd459 100644 --- a/spring-cloud-function-rsocket/src/test/java/org/springframework/cloud/function/rsocket/RSocketAutoConfigurationRoutingTests.java +++ b/spring-cloud-function-rsocket/src/test/java/org/springframework/cloud/function/rsocket/RSocketAutoConfigurationRoutingTests.java @@ -139,9 +139,6 @@ public class RSocketAutoConfigurationRoutingTests { } } - - - @EnableAutoConfiguration @Configuration public static class SampleFunctionConfiguration { diff --git a/spring-cloud-function-rsocket/src/test/java/org/springframework/cloud/function/rsocket/RSocketAutoConfigurationTests.java b/spring-cloud-function-rsocket/src/test/java/org/springframework/cloud/function/rsocket/RSocketAutoConfigurationTests.java index 0a1531050..e79654aee 100644 --- a/spring-cloud-function-rsocket/src/test/java/org/springframework/cloud/function/rsocket/RSocketAutoConfigurationTests.java +++ b/spring-cloud-function-rsocket/src/test/java/org/springframework/cloud/function/rsocket/RSocketAutoConfigurationTests.java @@ -47,6 +47,56 @@ import org.springframework.util.SocketUtils; * @since 3.1 */ public class RSocketAutoConfigurationTests { + + @Test + public void testNonExistingFunctionInRoute() { + int port = SocketUtils.findAvailableTcpPort(); + try ( + ConfigurableApplicationContext applicationContext = + new SpringApplicationBuilder(SampleFunctionConfiguration.class) + .web(WebApplicationType.NONE) + .run("--logging.level.org.springframework.cloud.function=DEBUG", + "--spring.rsocket.server.port=" + port); + ) { + RSocketRequester.Builder rsocketRequesterBuilder = + applicationContext.getBean(RSocketRequester.Builder.class); + + rsocketRequesterBuilder.tcp("localhost", port) + .route("foo") + .data("\"hello\"") + .retrieveMono(String.class) + .as(StepVerifier::create) + .expectError() + .verify(); + } + } + + @Test + public void testNonExistingFunctionInRouteSingleFunctionInCatalog() { + int port = SocketUtils.findAvailableTcpPort(); + try ( + ConfigurableApplicationContext applicationContext = + new SpringApplicationBuilder(SingleFunctionConfiguration.class) + .web(WebApplicationType.NONE) + .run("--logging.level.org.springframework.cloud.function=DEBUG", + "--spring.rsocket.server.port=" + port); + ) { + RSocketRequester.Builder rsocketRequesterBuilder = + applicationContext.getBean(RSocketRequester.Builder.class); + + rsocketRequesterBuilder.tcp("localhost", port) + .route("blah") + .data("\"hello\"") + .retrieveMono(String.class) + .as(StepVerifier::create) + .expectNext("hello") + .expectComplete() + .verify(); + } + } + + + @Test public void testImperativeFunctionAsRequestReplyWithDefinition() { int port = SocketUtils.findAvailableTcpPort(); @@ -535,4 +585,13 @@ public class RSocketAutoConfigurationTests { } + @EnableAutoConfiguration + @Configuration + public static class SingleFunctionConfiguration { + @Bean + public Function echo() { + return v -> v; + } + } + }