diff --git a/spring-cloud-function-samples/function-sample-pojo/src/main/java/com/example/SampleApplication.java b/spring-cloud-function-samples/function-sample-pojo/src/main/java/com/example/SampleApplication.java index 4614753e8..68a6437c5 100644 --- a/spring-cloud-function-samples/function-sample-pojo/src/main/java/com/example/SampleApplication.java +++ b/spring-cloud-function-samples/function-sample-pojo/src/main/java/com/example/SampleApplication.java @@ -28,8 +28,8 @@ import reactor.core.publisher.Flux; public class SampleApplication { @Bean - public Function, Flux> uppercase() { - return flux -> flux.log().map(value -> new Bar(value.uppercase())); + public Function uppercase() { + return value -> new Bar(value.uppercase()); } @Bean diff --git a/spring-cloud-function-samples/function-sample-pojo/src/test/java/com/example/SampleApplicationTests.java b/spring-cloud-function-samples/function-sample-pojo/src/test/java/com/example/SampleApplicationTests.java index 4b662742c..767127d05 100644 --- a/spring-cloud-function-samples/function-sample-pojo/src/test/java/com/example/SampleApplicationTests.java +++ b/spring-cloud-function-samples/function-sample-pojo/src/test/java/com/example/SampleApplicationTests.java @@ -53,10 +53,9 @@ public class SampleApplicationTests { @Test public void single() { - // TODO: make this return a single value assertThat(new TestRestTemplate().postForObject( "http://localhost:" + port + "/uppercase", "{\"value\":\"foo\"}", - String.class)).isEqualTo("[{\"value\":\"FOO\"}]"); + String.class)).isEqualTo("{\"value\":\"FOO\"}"); } @Test diff --git a/spring-cloud-function-web/src/main/java/org/springframework/cloud/function/web/flux/FunctionController.java b/spring-cloud-function-web/src/main/java/org/springframework/cloud/function/web/flux/FunctionController.java index 51c1797b1..108f557eb 100644 --- a/spring-cloud-function-web/src/main/java/org/springframework/cloud/function/web/flux/FunctionController.java +++ b/spring-cloud-function-web/src/main/java/org/springframework/cloud/function/web/flux/FunctionController.java @@ -16,9 +16,11 @@ package org.springframework.cloud.function.web.flux; +import java.util.Optional; import java.util.function.Consumer; import java.util.function.Function; import java.util.function.Supplier; +import java.util.stream.Stream; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; @@ -26,15 +28,16 @@ import org.reactivestreams.Publisher; import org.springframework.cloud.function.context.catalog.FunctionInspector; import org.springframework.cloud.function.context.message.MessageUtils; +import org.springframework.cloud.function.web.flux.constants.WebRequestConstants; import org.springframework.cloud.function.web.flux.request.FluxRequest; import org.springframework.http.HttpStatus; import org.springframework.http.ResponseEntity; import org.springframework.stereotype.Component; import org.springframework.web.bind.annotation.GetMapping; import org.springframework.web.bind.annotation.PostMapping; -import org.springframework.web.bind.annotation.RequestAttribute; import org.springframework.web.bind.annotation.RequestBody; import org.springframework.web.bind.annotation.ResponseBody; +import org.springframework.web.context.request.WebRequest; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; @@ -65,11 +68,16 @@ public class FunctionController { @PostMapping(path = "/**") @ResponseBody - public ResponseEntity> post( - @RequestAttribute(required = false, name = "org.springframework.cloud.function.web.flux.constants.WebRequestConstants.function") Function, Flux> function, - @RequestAttribute(required = false, name = "org.springframework.cloud.function.web.flux.constants.WebRequestConstants.consumer") Consumer> consumer, - @RequestAttribute(required = false, name = "org.springframework.cloud.function.web.flux.constants.WebRequestConstants.input_single") Boolean single, + public ResponseEntity> post(WebRequest request, @RequestBody FluxRequest body) { + @SuppressWarnings("unchecked") + Function, Flux> function = (Function, Flux>) request + .getAttribute(WebRequestConstants.FUNCTION, WebRequest.SCOPE_REQUEST); + @SuppressWarnings("unchecked") + Consumer> consumer = (Consumer>) request + .getAttribute(WebRequestConstants.CONSUMER, WebRequest.SCOPE_REQUEST); + Boolean single = (Boolean) request.getAttribute(WebRequestConstants.INPUT_SINGLE, + WebRequest.SCOPE_REQUEST); if (function != null) { Flux flux = body.flux(); if (debug) { @@ -82,7 +90,8 @@ public class FunctionController { if (logger.isDebugEnabled()) { logger.debug("Handled POST with function"); } - return ResponseEntity.ok().body(debug ? result.log() : result); + return ResponseEntity.ok().body( + debug ? result.log() : response(request, function, single, result)); } if (consumer != null) { Flux flux = body.flux().cache(); // send a copy back to the caller @@ -98,16 +107,48 @@ public class FunctionController { throw new IllegalArgumentException("no such function"); } + private Publisher response(WebRequest request, Object handler, Boolean single, + Flux result) { + if (single != null && single && isOutputSingle(handler)) { + request.setAttribute(WebRequestConstants.OUTPUT_SINGLE, true, + WebRequest.SCOPE_REQUEST); + return Mono.from(result); + } + request.setAttribute(WebRequestConstants.OUTPUT_SINGLE, false, + WebRequest.SCOPE_REQUEST); + return result; + } + + private boolean isOutputSingle(Object handler) { + Class type = inspector.getOutputType(handler); + Class wrapper = inspector.getOutputWrapper(handler); + if (Stream.class.isAssignableFrom(type)) { + return false; + } + if (wrapper == type) { + return true; + } + if (Mono.class.equals(wrapper) || Optional.class.equals(wrapper)) { + return true; + } + return false; + } + @GetMapping(path = "/**") @ResponseBody - public Publisher get( - @RequestAttribute(required = false, name = "org.springframework.cloud.function.web.flux.constants.WebRequestConstants.function") Function, Flux> function, - @RequestAttribute(required = false, name = "org.springframework.cloud.function.web.flux.constants.WebRequestConstants.supplier") Supplier> supplier, - @RequestAttribute(required = false, name = "org.springframework.cloud.function.web.flux.constants.WebRequestConstants.argument") String argument) { + public Publisher get(WebRequest request) { + @SuppressWarnings("unchecked") + Function, Flux> function = (Function, Flux>) request + .getAttribute(WebRequestConstants.FUNCTION, WebRequest.SCOPE_REQUEST); + @SuppressWarnings("unchecked") + Supplier> supplier = (Supplier>) request + .getAttribute(WebRequestConstants.SUPPLIER, WebRequest.SCOPE_REQUEST); + String argument = (String) request.getAttribute(WebRequestConstants.ARGUMENT, + WebRequest.SCOPE_REQUEST); if (function != null) { return value(function, argument); } - return supplier(supplier); + return response(request, supplier, true, supplier(supplier)); } private Flux supplier(Supplier> supplier) { diff --git a/spring-cloud-function-web/src/main/java/org/springframework/cloud/function/web/flux/constants/WebRequestConstants.java b/spring-cloud-function-web/src/main/java/org/springframework/cloud/function/web/flux/constants/WebRequestConstants.java index 720313dbb..a8fcdd1f1 100644 --- a/spring-cloud-function-web/src/main/java/org/springframework/cloud/function/web/flux/constants/WebRequestConstants.java +++ b/spring-cloud-function-web/src/main/java/org/springframework/cloud/function/web/flux/constants/WebRequestConstants.java @@ -35,5 +35,7 @@ public abstract class WebRequestConstants { public static final String HANDLER = WebRequestConstants.class.getName() + ".handler"; public static final String INPUT_SINGLE = WebRequestConstants.class.getName() + ".input_single"; + public static final String OUTPUT_SINGLE = WebRequestConstants.class.getName() + + ".output_single"; } diff --git a/spring-cloud-function-web/src/main/java/org/springframework/cloud/function/web/flux/response/FluxReturnValueHandler.java b/spring-cloud-function-web/src/main/java/org/springframework/cloud/function/web/flux/response/FluxReturnValueHandler.java index 6489efe23..afb7e5ae9 100644 --- a/spring-cloud-function-web/src/main/java/org/springframework/cloud/function/web/flux/response/FluxReturnValueHandler.java +++ b/spring-cloud-function-web/src/main/java/org/springframework/cloud/function/web/flux/response/FluxReturnValueHandler.java @@ -19,10 +19,8 @@ package org.springframework.cloud.function.web.flux.response; import java.lang.reflect.Method; import java.time.Duration; import java.util.Arrays; +import java.util.Collection; import java.util.List; -import java.util.Optional; -import java.util.function.Supplier; -import java.util.stream.Stream; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; @@ -116,7 +114,7 @@ public class FluxReturnValueHandler implements AsyncHandlerMethodReturnValueHand if (ResponseEntity.class.isAssignableFrom(returnType.getParameterType())) { Class bodyType = ResolvableType.forMethodParameter(returnType) .getGeneric(0).resolve(); - return bodyType != null && Flux.class.isAssignableFrom(bodyType); + return bodyType != null && Publisher.class.isAssignableFrom(bodyType); } return false; } @@ -152,8 +150,7 @@ public class FluxReturnValueHandler implements AsyncHandlerMethodReturnValueHand NativeWebRequest.SCOPE_REQUEST); Class type = inspector.getOutputType(handler); - boolean inputSingle = isInputSingle(webRequest, handler); - if (inputSingle && isOutputSingle(handler)) { + if (isOutputSingle(webRequest, handler, type)) { Object result = Flux.from(flux).blockFirst(); if (result instanceof Message) { Message message = (Message) result; @@ -197,30 +194,18 @@ public class FluxReturnValueHandler implements AsyncHandlerMethodReturnValueHand } } - private boolean isInputSingle(NativeWebRequest webRequest, Object handler) { + private boolean isOutputSingle(NativeWebRequest webRequest, Object handler, + Class type) { Boolean single = (Boolean) webRequest.getAttribute( - WebRequestConstants.INPUT_SINGLE, NativeWebRequest.SCOPE_REQUEST); + WebRequestConstants.OUTPUT_SINGLE, NativeWebRequest.SCOPE_REQUEST); if (single == null) { - return handler instanceof Supplier; + // If the declared return type is a collection then we can render it as a + // "single" value + return Collection.class.isAssignableFrom(type); } return single; } - private boolean isOutputSingle(Object handler) { - Class type = inspector.getOutputType(handler); - Class wrapper = inspector.getOutputWrapper(handler); - if (Stream.class.isAssignableFrom(type)) { - return false; - } - if (wrapper == type) { - return true; - } - if (Mono.class.equals(wrapper) || Optional.class.equals(wrapper)) { - return true; - } - return false; - } - private MediaType findMediaType(NativeWebRequest webRequest) { List accepts = Arrays.asList(MediaType.ALL); MediaType mediaType = null;