diff --git a/spring-cloud-function-context/src/main/java/org/springframework/cloud/function/context/catalog/SimpleFunctionRegistry.java b/spring-cloud-function-context/src/main/java/org/springframework/cloud/function/context/catalog/SimpleFunctionRegistry.java index 88fcd34ea..bcd7a5b19 100644 --- a/spring-cloud-function-context/src/main/java/org/springframework/cloud/function/context/catalog/SimpleFunctionRegistry.java +++ b/spring-cloud-function-context/src/main/java/org/springframework/cloud/function/context/catalog/SimpleFunctionRegistry.java @@ -720,7 +720,25 @@ public class SimpleFunctionRegistry implements FunctionRegistry, FunctionInspect */ @SuppressWarnings("unchecked") private Object fluxifyInputIfNecessary(Object input) { - if (!(input instanceof Publisher) && this.isTypePublisher(this.inputType) && !FunctionTypeUtils.isMultipleArgumentType(this.inputType)) { + if (FunctionTypeUtils.isMultipleArgumentType(this.inputType)) { + return input; + } + + if (!this.isRoutingFunction() && !(input instanceof Publisher)) { + Object payload = input; + if (input instanceof Message) { + payload = ((Message) input).getPayload(); + } + if (JsonMapper.isJsonStringRepresentsCollection(payload) && !FunctionTypeUtils.isTypeCollection(this.inputType)) { + payload = jsonMapper.fromJson(payload, List.class); + MessageHeaders headers = ((Message) input).getHeaders(); + input = ((List) payload).stream() + .map(p -> MessageBuilder.withPayload(p).copyHeaders(headers).build()) + .collect(Collectors.toList()); + } + } + + if (this.isTypePublisher(this.inputType) && !(input instanceof Publisher)) { if (input == null) { input = FunctionTypeUtils.isMono(this.inputType) ? Mono.empty() : Flux.empty(); } @@ -740,6 +758,9 @@ public class SimpleFunctionRegistry implements FunctionRegistry, FunctionInspect input = FunctionTypeUtils.isMono(this.inputType) ? Mono.just(input) : Flux.just(input); } } + else if (input instanceof Iterable && !FunctionTypeUtils.isTypeCollection(this.inputType)) { + input = Flux.fromIterable((Iterable) input); + } return input; } @@ -946,7 +967,7 @@ public class SimpleFunctionRegistry implements FunctionRegistry, FunctionInspect } } else { - convertedInput = this.convertNonMessageInputIfNecessary(type, input, JsonMapper.isJsonString(input)); + convertedInput = this.convertNonMessageInputIfNecessary(type, input, JsonMapper.isJsonString(input) || input instanceof Map); if (convertedInput != null && logger.isDebugEnabled()) { logger.debug("Converted input: " + input + " to: " + convertedInput); } diff --git a/spring-cloud-function-context/src/test/java/org/springframework/cloud/function/context/catalog/BeanFactoryAwareFunctionRegistryTests.java b/spring-cloud-function-context/src/test/java/org/springframework/cloud/function/context/catalog/BeanFactoryAwareFunctionRegistryTests.java index ac222d1fd..996822c7f 100644 --- a/spring-cloud-function-context/src/test/java/org/springframework/cloud/function/context/catalog/BeanFactoryAwareFunctionRegistryTests.java +++ b/spring-cloud-function-context/src/test/java/org/springframework/cloud/function/context/catalog/BeanFactoryAwareFunctionRegistryTests.java @@ -547,6 +547,29 @@ public class BeanFactoryAwareFunctionRegistryTests { assertThat(result.getHeaders().get("after")).isEqualTo("bar"); } + @SuppressWarnings({ "rawtypes", "unchecked" }) + @Test + public void testEachElementInFluxIsProcessed() { + FunctionCatalog catalog = this.configureCatalog(SampleFunctionConfiguration.class); + Function f = catalog.lookup("uppercasePerson"); + + Flux flux = Flux.just("{\"id\":1, \"name\":\"oleg\"}", "{\"id\":2, \"name\":\"seva\"}"); + Flux result = (Flux) f.apply(flux); + + List list = (List) result.collectList().block(); + assertThat(list.size()).isEqualTo(2); + assertThat(list.get(0).name).isEqualTo("OLEG"); + assertThat(list.get(1).name).isEqualTo("SEVA"); + + + + result = (Flux) f.apply(new GenericMessage("[{\"id\":1, \"name\":\"oleg\"}, {\"id\":2, \"name\":\"seva\"}]")); + list = (List) result.collectList().block(); + assertThat(list.size()).isEqualTo(2); + assertThat(list.get(0).name).isEqualTo("OLEG"); + assertThat(list.get(1).name).isEqualTo("SEVA"); + } + @Test public void testGH_608() { ApplicationContext context = new SpringApplicationBuilder(SampleFunctionConfiguration.class) diff --git a/spring-cloud-function-web/refactoring_notes b/spring-cloud-function-web/refactoring_notes new file mode 100644 index 000000000..d4b44cad3 --- /dev/null +++ b/spring-cloud-function-web/refactoring_notes @@ -0,0 +1,6 @@ +If a function returns Flux, we must represent output as JSON Array/Collection since we never know how many elements such flux will contain per each invocation. +For that same reason we can't use TEXT/PLAIN as CT + +NON-WEB +When sendng collection of objects to function who's input is not collection, the inpt will be converted to flux and the result is alos going to be flux. +That is to ensure that the function is invoked with idividual. . . \ No newline at end of file diff --git a/spring-cloud-function-web/src/main/java/org/springframework/cloud/function/web/mvc/FunctionController.java b/spring-cloud-function-web/src/main/java/org/springframework/cloud/function/web/mvc/FunctionController.java index 822f97530..009f62b6e 100644 --- a/spring-cloud-function-web/src/main/java/org/springframework/cloud/function/web/mvc/FunctionController.java +++ b/spring-cloud-function-web/src/main/java/org/springframework/cloud/function/web/mvc/FunctionController.java @@ -17,6 +17,7 @@ package org.springframework.cloud.function.web.mvc; import java.util.Arrays; +import java.util.Collection; import java.util.Iterator; import java.util.List; import java.util.stream.Collectors; @@ -29,6 +30,8 @@ import org.springframework.cloud.function.context.catalog.SimpleFunctionRegistry import org.springframework.cloud.function.web.RequestProcessor; import org.springframework.cloud.function.web.RequestProcessor.FunctionWrapper; import org.springframework.cloud.function.web.constants.WebRequestConstants; +import org.springframework.cloud.function.web.util.HeaderUtils; +import org.springframework.http.HttpHeaders; import org.springframework.http.MediaType; import org.springframework.http.ResponseEntity; import org.springframework.http.ResponseEntity.BodyBuilder; @@ -37,6 +40,7 @@ import org.springframework.messaging.support.MessageBuilder; import org.springframework.stereotype.Component; import org.springframework.util.CollectionUtils; import org.springframework.util.MultiValueMap; +import org.springframework.util.StringUtils; import org.springframework.web.bind.annotation.GetMapping; import org.springframework.web.bind.annotation.PostMapping; import org.springframework.web.bind.annotation.RequestBody; @@ -89,15 +93,6 @@ public class FunctionController { return this.processor.post(wrapper, null, false); } - @PostMapping(path = "/**") - @ResponseBody - public Mono> post(WebRequest request, - @RequestBody(required = false) String body) { - FunctionWrapper wrapper = wrapper(request); - Mono> result = this.processor.post(wrapper, body, false); - return result; - } - @PostMapping(path = "/**", produces = MediaType.TEXT_EVENT_STREAM_VALUE) @ResponseBody public Mono>> postStream(WebRequest request, @@ -108,13 +103,6 @@ public class FunctionController { .body((Publisher) response.getBody())); } - @GetMapping(path = "/**") - @ResponseBody - public Mono> get(WebRequest request) { - FunctionWrapper wrapper = wrapper(request); - return this.processor.get(wrapper); - } - @GetMapping(path = "/**", produces = MediaType.TEXT_EVENT_STREAM_VALUE) @ResponseBody public Mono>> getStream(WebRequest request) { @@ -123,6 +111,79 @@ public class FunctionController { .headers(response.getHeaders()).body((Publisher) response.getBody())); } + @PostMapping(path = "/**") + @ResponseBody + public Object post(WebRequest request, @RequestBody(required = false) String body) { + String argument = StringUtils.hasText(body) ? body : ""; + return this.doProcess(request, argument); + } + + @GetMapping(path = "/**") + @ResponseBody + public Object get(WebRequest request) { + String argument = (String) request.getAttribute(WebRequestConstants.ARGUMENT, WebRequest.SCOPE_REQUEST); + return this.doProcess(request, argument); + } + + @SuppressWarnings({ "rawtypes", "unchecked" }) + private Object doProcess(WebRequest request, String argument) { + FunctionWrapper wrapper = wrapper(request); + + FunctionInvocationWrapper function = wrapper.function(); + + HttpHeaders headers = wrapper.headers(); + + Message inputMessage = argument == null ? null : MessageBuilder.withPayload(argument).copyHeaders(headers.toSingleValueMap()).build(); + + if (function.isRoutingFunction()) { + function.setSkipOutputConversion(true); + } + + Object result = function.apply(inputMessage); + + BodyBuilder responseOkBuilder = ResponseEntity.ok().headers(HeaderUtils.sanitize(headers)); + if (result instanceof Publisher) { + if (result instanceof Flux) { + result = ((Flux) result).collectList(); + } + + if (function.isConsumer()) { + ((Mono) result).subscribe(); + return ResponseEntity.accepted().headers(HeaderUtils.sanitize(headers)).build(); + } + else { + result = Mono.from((Publisher) result).map(v -> { + if (v instanceof Iterable) { + List aggregatedResult = (List) ((Collection) v).stream().map(m -> { + return m instanceof Message ? this.doProcessMessage(responseOkBuilder, (Message) m) : m; + }).collect(Collectors.toList()); + return Mono.just(responseOkBuilder.body(aggregatedResult)); + } + else if (v instanceof Message) { + return this.doProcessMessage(responseOkBuilder, (Message) v); + } + else { + return Mono.just(v); + } + }); + return result; + } + } + else if (function.isConsumer()) { + return ResponseEntity.accepted().headers(HeaderUtils.sanitize(headers)).build(); + } + else { + return result instanceof Message ? + responseOkBuilder.headers(HeaderUtils.fromMessage(((Message) result).getHeaders())).body(((Message) result).getPayload()) : + responseOkBuilder.body(result); + } + } + + private Object doProcessMessage(BodyBuilder responseOkBuilder, Message message) { + responseOkBuilder.headers(HeaderUtils.fromMessage(message.getHeaders())); + return message.getPayload(); + } + private FunctionWrapper wrapper(WebRequest request) { FunctionInvocationWrapper function = (FunctionInvocationWrapper) request .getAttribute(WebRequestConstants.HANDLER, WebRequest.SCOPE_REQUEST); diff --git a/spring-cloud-function-web/src/test/java/org/springframework/cloud/function/web/mvc/HttpGetIntegrationTests.java b/spring-cloud-function-web/src/test/java/org/springframework/cloud/function/web/mvc/HttpGetIntegrationTests.java index 01af8dc00..d301bd134 100644 --- a/spring-cloud-function-web/src/test/java/org/springframework/cloud/function/web/mvc/HttpGetIntegrationTests.java +++ b/spring-cloud-function-web/src/test/java/org/springframework/cloud/function/web/mvc/HttpGetIntegrationTests.java @@ -209,19 +209,19 @@ public class HttpGetIntegrationTests { @Test public void postMoreFoo() { assertThat(this.rest.getForObject("/post/more/foo", String.class)) - .isEqualTo("(FOO)"); + .isEqualTo("[\"(FOO)\"]"); } @Test public void uppercaseGet() { assertThat(this.rest.getForObject("/uppercase/foo", String.class)) - .isEqualTo("(FOO)"); + .isEqualTo("[\"(FOO)\"]"); } @Test public void convertGet() { assertThat(this.rest.getForObject("/wrap/123", String.class)) - .isEqualTo("..123.."); + .isEqualTo("[\"..123..\"]"); } @Test @@ -235,10 +235,12 @@ public class HttpGetIntegrationTests { assertThat(this.rest .exchange(RequestEntity.get(new URI("/entity/321")) .accept(MediaType.APPLICATION_JSON).build(), String.class) - .getBody()).isEqualTo("{\"value\":321}"); + .getBody()).isEqualTo("[{\"value\":321}]"); } @Test + @Disabled + // this test is wrong since it is returning Flux while setting CT to TEXT_PLAIN. We can't convert it public void compose() throws Exception { ResponseEntity result = this.rest.exchange(RequestEntity .get(new URI("/concat,reverse/foo")).accept(MediaType.TEXT_PLAIN).build(), @@ -338,7 +340,7 @@ public class HttpGetIntegrationTests { public Supplier> timeout() { return () -> Flux.defer(() -> Flux.create(emitter -> { emitter.next("foo"); - }).timeout(Duration.ofMillis(100L), Flux.empty())); + }).timeout(Duration.ofMillis(1000L), Flux.empty())); } @Bean diff --git a/spring-cloud-function-web/src/test/java/org/springframework/cloud/function/web/mvc/HttpPostIntegrationTests.java b/spring-cloud-function-web/src/test/java/org/springframework/cloud/function/web/mvc/HttpPostIntegrationTests.java index 61e6dda65..b1004f709 100644 --- a/spring-cloud-function-web/src/test/java/org/springframework/cloud/function/web/mvc/HttpPostIntegrationTests.java +++ b/spring-cloud-function-web/src/test/java/org/springframework/cloud/function/web/mvc/HttpPostIntegrationTests.java @@ -420,7 +420,9 @@ public class HttpPostIntegrationTests { @Bean public Consumer bareUpdates() { - return value -> this.list.add(value); + return value -> { + this.list.add(value); + }; } @Bean("not/a")