diff --git a/spring-cloud-function-web/src/main/java/org/springframework/cloud/function/web/flux/FunctionController.java b/spring-cloud-function-web/src/main/java/org/springframework/cloud/function/web/flux/FunctionController.java index 63c597ead..1eb556c26 100644 --- a/spring-cloud-function-web/src/main/java/org/springframework/cloud/function/web/flux/FunctionController.java +++ b/spring-cloud-function-web/src/main/java/org/springframework/cloud/function/web/flux/FunctionController.java @@ -60,8 +60,9 @@ public class FunctionController { @PostMapping(path = "/**") @ResponseBody public ResponseEntity> post( - @RequestAttribute(required = false, name = "org.springframework.cloud.function.web.flux.FunctionHandlerMapping.function") Function, Flux> function, - @RequestAttribute(required = false, name = "org.springframework.cloud.function.web.flux.FunctionHandlerMapping.consumer") Consumer> consumer, + @RequestAttribute(required = false, name = "org.springframework.cloud.function.web.flux.constants.WebRequestConstants.function") Function, Flux> function, + @RequestAttribute(required = false, name = "org.springframework.cloud.function.web.flux.constants.WebRequestConstants.consumer") Consumer> consumer, + @RequestAttribute(required = false, name = "org.springframework.cloud.function.web.flux.constants.WebRequestConstants.input_single") Boolean single, @RequestBody FluxRequest body) { if (function != null) { Flux result = (Flux) function.apply(body.flux()); @@ -84,9 +85,9 @@ public class FunctionController { @GetMapping(path = "/**") @ResponseBody public Object get( - @RequestAttribute(required = false, name = "org.springframework.cloud.function.web.flux.FunctionHandlerMapping.function") Function, Flux> function, - @RequestAttribute(required = false, name = "org.springframework.cloud.function.web.flux.FunctionHandlerMapping.supplier") Supplier> supplier, - @RequestAttribute(required = false, name = "org.springframework.cloud.function.web.flux.FunctionHandlerMapping.argument") String argument) { + @RequestAttribute(required = false, name = "org.springframework.cloud.function.web.flux.constants.WebRequestConstants.function") Function, Flux> function, + @RequestAttribute(required = false, name = "org.springframework.cloud.function.web.flux.constants.WebRequestConstants.supplier") Supplier> supplier, + @RequestAttribute(required = false, name = "org.springframework.cloud.function.web.flux.constants.WebRequestConstants.argument") String argument) { if (function != null) { return value(function, argument); } diff --git a/spring-cloud-function-web/src/main/java/org/springframework/cloud/function/web/flux/FunctionHandlerMapping.java b/spring-cloud-function-web/src/main/java/org/springframework/cloud/function/web/flux/FunctionHandlerMapping.java index feeba932a..6e8a72335 100644 --- a/spring-cloud-function-web/src/main/java/org/springframework/cloud/function/web/flux/FunctionHandlerMapping.java +++ b/spring-cloud-function-web/src/main/java/org/springframework/cloud/function/web/flux/FunctionHandlerMapping.java @@ -28,7 +28,7 @@ import org.springframework.beans.factory.annotation.Value; import org.springframework.boot.autoconfigure.condition.ConditionalOnClass; import org.springframework.cloud.function.context.FunctionInspector; import org.springframework.cloud.function.registry.FunctionCatalog; -import org.springframework.cloud.function.web.flux.request.FluxHandlerMethodArgumentResolver; +import org.springframework.cloud.function.web.flux.constants.WebRequestConstants; import org.springframework.context.annotation.Configuration; import org.springframework.web.method.HandlerMethod; import org.springframework.web.servlet.HandlerMapping; @@ -43,14 +43,6 @@ import org.springframework.web.servlet.mvc.method.annotation.RequestMappingHandl public class FunctionHandlerMapping extends RequestMappingHandlerMapping implements InitializingBean { - public static final String FUNCTION = FunctionHandlerMapping.class.getName() - + ".function"; - public static final String CONSUMER = FunctionHandlerMapping.class.getName() - + ".consumer"; - public static final String SUPPLIER = FunctionHandlerMapping.class.getName() - + ".supplier"; - public static final String ARGUMENT = FunctionHandlerMapping.class.getName() - + ".argument"; private final FunctionCatalog functions; private final FunctionController controller; @@ -95,7 +87,7 @@ public class FunctionHandlerMapping extends RequestMappingHandlerMapping if (logger.isDebugEnabled()) { logger.debug("Found function for GET: " + path); } - request.setAttribute(FluxHandlerMethodArgumentResolver.HANDLER, function); + request.setAttribute(WebRequestConstants.HANDLER, function); return handler; } function = findFunctionForPost(request, path); @@ -103,7 +95,7 @@ public class FunctionHandlerMapping extends RequestMappingHandlerMapping if (logger.isDebugEnabled()) { logger.debug("Found function for POST: " + path); } - request.setAttribute(FluxHandlerMethodArgumentResolver.HANDLER, function); + request.setAttribute(WebRequestConstants.HANDLER, function); return handler; } return null; @@ -116,12 +108,12 @@ public class FunctionHandlerMapping extends RequestMappingHandlerMapping path = path.startsWith("/") ? path.substring(1) : path; Consumer consumer = functions.lookupConsumer(path); if (consumer != null) { - request.setAttribute(CONSUMER, consumer); + request.setAttribute(WebRequestConstants.CONSUMER, consumer); return consumer; } Function function = functions.lookupFunction(path); if (function != null) { - request.setAttribute(FUNCTION, function); + request.setAttribute(WebRequestConstants.FUNCTION, function); return function; } return null; @@ -134,7 +126,7 @@ public class FunctionHandlerMapping extends RequestMappingHandlerMapping path = path.startsWith("/") ? path.substring(1) : path; Supplier supplier = functions.lookupSupplier(path); if (supplier != null) { - request.setAttribute(SUPPLIER, supplier); + request.setAttribute(WebRequestConstants.SUPPLIER, supplier); return supplier; } StringBuilder builder = new StringBuilder(); @@ -150,8 +142,8 @@ public class FunctionHandlerMapping extends RequestMappingHandlerMapping : null; Function function = functions.lookupFunction(name); if (function != null) { - request.setAttribute(FUNCTION, function); - request.setAttribute(ARGUMENT, value); + request.setAttribute(WebRequestConstants.FUNCTION, function); + request.setAttribute(WebRequestConstants.ARGUMENT, value); return function; } } diff --git a/spring-cloud-function-web/src/main/java/org/springframework/cloud/function/web/flux/constants/WebRequestConstants.java b/spring-cloud-function-web/src/main/java/org/springframework/cloud/function/web/flux/constants/WebRequestConstants.java new file mode 100644 index 000000000..720313dbb --- /dev/null +++ b/spring-cloud-function-web/src/main/java/org/springframework/cloud/function/web/flux/constants/WebRequestConstants.java @@ -0,0 +1,39 @@ +/* + * Copyright 2016-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.cloud.function.web.flux.constants; + +/** + * Common storage for web request attribute names (in a separate package to avoid cycles). + * + * @author Dave Syer + * + */ +public abstract class WebRequestConstants { + + public static final String FUNCTION = WebRequestConstants.class.getName() + + ".function"; + public static final String CONSUMER = WebRequestConstants.class.getName() + + ".consumer"; + public static final String SUPPLIER = WebRequestConstants.class.getName() + + ".supplier"; + public static final String ARGUMENT = WebRequestConstants.class.getName() + + ".argument"; + public static final String HANDLER = WebRequestConstants.class.getName() + ".handler"; + public static final String INPUT_SINGLE = WebRequestConstants.class.getName() + + ".input_single"; + +} diff --git a/spring-cloud-function-web/src/main/java/org/springframework/cloud/function/web/flux/request/FluxHandlerMethodArgumentResolver.java b/spring-cloud-function-web/src/main/java/org/springframework/cloud/function/web/flux/request/FluxHandlerMethodArgumentResolver.java index bf1f6f249..05a9d8ffa 100644 --- a/spring-cloud-function-web/src/main/java/org/springframework/cloud/function/web/flux/request/FluxHandlerMethodArgumentResolver.java +++ b/spring-cloud-function-web/src/main/java/org/springframework/cloud/function/web/flux/request/FluxHandlerMethodArgumentResolver.java @@ -30,6 +30,7 @@ import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.springframework.cloud.function.context.FunctionInspector; +import org.springframework.cloud.function.web.flux.constants.WebRequestConstants; import org.springframework.core.MethodParameter; import org.springframework.core.Ordered; import org.springframework.http.MediaType; @@ -52,9 +53,6 @@ public class FluxHandlerMethodArgumentResolver private static Log logger = LogFactory .getLog(FluxHandlerMethodArgumentResolver.class); - public static final String HANDLER = FluxHandlerMethodArgumentResolver.class.getName() - + ".HANDLER"; - private final ObjectMapper mapper; private FunctionInspector inspector; @@ -74,7 +72,8 @@ public class FluxHandlerMethodArgumentResolver public Object resolveArgument(MethodParameter parameter, ModelAndViewContainer mavContainer, NativeWebRequest webRequest, WebDataBinderFactory binderFactory) throws Exception { - Object handler = webRequest.getAttribute(HANDLER, NativeWebRequest.SCOPE_REQUEST); + Object handler = webRequest.getAttribute(WebRequestConstants.HANDLER, + NativeWebRequest.SCOPE_REQUEST); Class type = inspector.getInputType(inspector.getName(handler)); if (type == null) { type = Object.class; @@ -96,6 +95,7 @@ public class FluxHandlerMethodArgumentResolver .constructCollectionLikeType(ArrayList.class, type)); } catch (JsonMappingException e) { + nativeRequest.setAttribute(WebRequestConstants.INPUT_SINGLE, true); body = Arrays.asList( mapper.readValue(nativeRequest.getContentAsByteArray(), type)); } diff --git a/spring-cloud-function-web/src/main/java/org/springframework/cloud/function/web/flux/response/FluxReturnValueHandler.java b/spring-cloud-function-web/src/main/java/org/springframework/cloud/function/web/flux/response/FluxReturnValueHandler.java index 6665b3bd1..3ce1f76c4 100644 --- a/spring-cloud-function-web/src/main/java/org/springframework/cloud/function/web/flux/response/FluxReturnValueHandler.java +++ b/spring-cloud-function-web/src/main/java/org/springframework/cloud/function/web/flux/response/FluxReturnValueHandler.java @@ -16,6 +16,7 @@ package org.springframework.cloud.function.web.flux.response; +import java.lang.reflect.Method; import java.time.Duration; import java.util.Arrays; import java.util.List; @@ -24,18 +25,21 @@ import javax.servlet.http.HttpServletResponse; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; +import org.apache.el.stream.Optional; import org.reactivestreams.Publisher; import org.springframework.cloud.function.context.FunctionInspector; -import org.springframework.cloud.function.web.flux.request.FluxHandlerMethodArgumentResolver; +import org.springframework.cloud.function.web.flux.constants.WebRequestConstants; import org.springframework.core.MethodParameter; import org.springframework.core.ResolvableType; import org.springframework.http.MediaType; import org.springframework.http.ResponseEntity; import org.springframework.http.converter.HttpMessageConverter; +import org.springframework.util.ReflectionUtils; import org.springframework.web.context.request.NativeWebRequest; import org.springframework.web.method.support.AsyncHandlerMethodReturnValueHandler; import org.springframework.web.method.support.ModelAndViewContainer; +import org.springframework.web.servlet.mvc.method.annotation.RequestResponseBodyMethodProcessor; import org.springframework.web.servlet.mvc.method.annotation.ResponseBodyEmitter; import org.springframework.web.servlet.mvc.method.annotation.ResponseBodyEmitterReturnValueHandler; @@ -50,19 +54,28 @@ import reactor.core.publisher.Mono; */ public class FluxReturnValueHandler implements AsyncHandlerMethodReturnValueHandler { - private static Log logger = LogFactory - .getLog(FluxReturnValueHandler.class); + private static Log logger = LogFactory.getLog(FluxReturnValueHandler.class); private ResponseBodyEmitterReturnValueHandler delegate; + private RequestResponseBodyMethodProcessor single; private long timeout = 1000L; private static final MediaType EVENT_STREAM = MediaType.valueOf("text/event-stream"); private FunctionInspector inspector; + private MethodParameter singleReturnType; + public FluxReturnValueHandler(FunctionInspector inspector, List> messageConverters) { this.inspector = inspector; this.delegate = new ResponseBodyEmitterReturnValueHandler(messageConverters); + this.single = new RequestResponseBodyMethodProcessor(messageConverters); + Method method = ReflectionUtils.findMethod(getClass(), "singleValue"); + singleReturnType = new MethodParameter(method, -1); + } + + ResponseEntity singleValue() { + return null; } /** @@ -120,24 +133,46 @@ public class FluxReturnValueHandler implements AsyncHandlerMethodReturnValueHand } Publisher flux = (Publisher) adaptFrom; - Object handler = webRequest.getAttribute( - FluxHandlerMethodArgumentResolver.HANDLER, + Object handler = webRequest.getAttribute(WebRequestConstants.HANDLER, NativeWebRequest.SCOPE_REQUEST); Class type = inspector.getOutputType(inspector.getName(handler)); + Boolean inputSingle = (Boolean) webRequest.getAttribute( + WebRequestConstants.INPUT_SINGLE, NativeWebRequest.SCOPE_REQUEST); + if (inputSingle!=null && inputSingle && isOutputSingle(handler)) { + single.handleReturnValue(Flux.from(flux).blockFirst(), singleReturnType, + mavContainer, webRequest); + return; + } + MediaType mediaType = null; if (isPlainText(webRequest) && CharSequence.class.isAssignableFrom(type)) { mediaType = MediaType.TEXT_PLAIN; - } else { + } + else { mediaType = findMediaType(webRequest); } if (logger.isDebugEnabled()) { - logger.debug("Handling return value " + type + " with media type: " + mediaType); + logger.debug( + "Handling return value " + type + " with media type: " + mediaType); } delegate.handleReturnValue(getEmitter(timeout, flux, mediaType), returnType, mavContainer, webRequest); } + private boolean isOutputSingle(Object handler) { + String name = inspector.getName(handler); + Class type = inspector.getOutputType(name); + Class wrapper = inspector.getOutputWrapper(name); + if (wrapper==type) { + return true; + } + if (Mono.class.equals(wrapper) || Optional.class.equals(wrapper)) { + return true; + } + return false; + } + private MediaType findMediaType(NativeWebRequest webRequest) { List accepts = Arrays.asList(MediaType.ALL); MediaType mediaType = null; diff --git a/spring-cloud-function-web/src/test/java/org/springframework/cloud/function/web/RestApplicationTests.java b/spring-cloud-function-web/src/test/java/org/springframework/cloud/function/web/RestApplicationTests.java index 0a28c0b71..240cfcfdb 100644 --- a/spring-cloud-function-web/src/test/java/org/springframework/cloud/function/web/RestApplicationTests.java +++ b/spring-cloud-function-web/src/test/java/org/springframework/cloud/function/web/RestApplicationTests.java @@ -283,6 +283,25 @@ public class RestApplicationTests { .isEqualTo("[{\"value\":\"FOO\"}]"); } + @Test + public void bareUppercaseFoos() throws Exception { + ResponseEntity result = rest.exchange(RequestEntity + .post(new URI("/bareUpFoos")).contentType(MediaType.APPLICATION_JSON) + .body("[{\"value\":\"foo\"},{\"value\":\"bar\"}]"), String.class); + assertThat(result.getBody()) + .isEqualTo("[{\"value\":\"FOO\"},{\"value\":\"BAR\"}]"); + } + + @Test + public void bareUppercaseFoo() throws Exception { + // Single Foo can be parsed and returns a single value if the function is defined that way + ResponseEntity result = rest.exchange(RequestEntity + .post(new URI("/bareUpFoos")).contentType(MediaType.APPLICATION_JSON) + .body("{\"value\":\"foo\"}"), String.class); + assertThat(result.getBody()) + .isEqualTo("{\"value\":\"FOO\"}"); + } + @Test public void bareUppercase() throws Exception { ResponseEntity result = rest.exchange(RequestEntity @@ -382,6 +401,11 @@ public class RestApplicationTests { .map(value -> new Foo(value.getValue().trim().toUpperCase())); } + @Bean + public Function bareUpFoos() { + return value -> new Foo(value.getValue().trim().toUpperCase()); + } + @Bean public Function, Flux> wrap() { return flux -> flux.log().map(value -> ".." + value + "..");