diff --git a/spring-cloud-function-context/src/main/java/org/springframework/cloud/function/context/FunctionType.java b/spring-cloud-function-context/src/main/java/org/springframework/cloud/function/context/FunctionType.java index fab2bb672..f79bf1206 100644 --- a/spring-cloud-function-context/src/main/java/org/springframework/cloud/function/context/FunctionType.java +++ b/spring-cloud-function-context/src/main/java/org/springframework/cloud/function/context/FunctionType.java @@ -149,12 +149,25 @@ public class FunctionType { .getType()); } - public FunctionType wrap(Class wrapper) { - if (wrapper.isAssignableFrom(getInputWrapper()) || !isWrapper(wrapper)) { + public FunctionType wrap(Class input, Class output) { + if (!isWrapper(input) && !isWrapper(output)) { + return this; + } + if (!isWrapper(input) || !isWrapper(output)) { + throw new IllegalArgumentException("Both wrapper types must be wrappers in (" + + input + ", " + output + ")"); + } + if (input.isAssignableFrom(getInputWrapper()) + && output.isAssignableFrom(getOutputWrapper())) { return this; } return new FunctionType(ResolvableType.forClassWithGenerics(Function.class, - wrap(wrapper, getInputType()), wrap(wrapper, getOutputType())).getType()); + wrapper(input, getInputType()), wrapper(output, getOutputType())) + .getType()); + } + + public FunctionType wrap(Class wrapper) { + return wrap(wrapper, wrapper); } public static FunctionType compose(FunctionType input, FunctionType output) { @@ -173,7 +186,7 @@ public class FunctionType { .getType()); } - private ResolvableType wrap(Class wrapper, Class type) { + private ResolvableType wrapper(Class wrapper, Class type) { return wrap(this, wrapper, type); } diff --git a/spring-cloud-function-context/src/test/java/org/springframework/cloud/function/context/config/BeanFactoryFunctionCatalogTests.java b/spring-cloud-function-context/src/test/java/org/springframework/cloud/function/context/config/BeanFactoryFunctionCatalogTests.java index 6bce9dbae..7d3cd3eaa 100644 --- a/spring-cloud-function-context/src/test/java/org/springframework/cloud/function/context/config/BeanFactoryFunctionCatalogTests.java +++ b/spring-cloud-function-context/src/test/java/org/springframework/cloud/function/context/config/BeanFactoryFunctionCatalogTests.java @@ -17,7 +17,9 @@ package org.springframework.cloud.function.context.config; import java.util.ArrayList; +import java.util.HashMap; import java.util.List; +import java.util.Map; import java.util.function.Consumer; import java.util.function.Function; import java.util.function.Supplier; @@ -78,6 +80,22 @@ public class BeanFactoryFunctionCatalogTests { assertThat(foos.apply(Flux.just(2)).blockFirst()).isEqualTo("i=2"); } + @Test + public void registerFunctionWithMonoType() { + processor.register( + new FunctionRegistration, Mono>>>( + flux -> flux.collect(HashMap::new, + (map, word) -> map.merge(word, 1, Integer::sum))) + .names("foos") + .type(FunctionType.from(String.class) + .to(Map.class) + .wrap(Flux.class, Mono.class).getType())); + Function, Mono>> foos = processor + .lookup(Function.class, ""); + assertThat(foos.apply(Flux.just("one", "one", "two")).block()) + .containsEntry("one", 2); + } + @Test public void lookupNonExistentConsumerWithEmptyName() { processor.register(new FunctionRegistration<>(new Foos()).names("foos")); diff --git a/spring-cloud-function-context/src/test/java/org/springframework/cloud/function/context/config/ContextFunctionCatalogAutoConfigurationTests.java b/spring-cloud-function-context/src/test/java/org/springframework/cloud/function/context/config/ContextFunctionCatalogAutoConfigurationTests.java index cab3c6adc..809514a69 100644 --- a/spring-cloud-function-context/src/test/java/org/springframework/cloud/function/context/config/ContextFunctionCatalogAutoConfigurationTests.java +++ b/spring-cloud-function-context/src/test/java/org/springframework/cloud/function/context/config/ContextFunctionCatalogAutoConfigurationTests.java @@ -19,6 +19,7 @@ package org.springframework.cloud.function.context.config; import java.net.URL; import java.net.URLClassLoader; import java.util.ArrayList; +import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.function.Consumer; @@ -236,6 +237,22 @@ public class ContextFunctionCatalogAutoConfigurationTests { .isAssignableFrom(Publisher.class); } + @Test + public void monoFunction() { + create(MonoConfiguration.class); + assertThat(context.getBean("function")).isInstanceOf(Function.class); + assertThat(catalog.>lookup(Function.class, "function")) + .isInstanceOf(Function.class); + assertThat(inspector.isMessage(catalog.lookup(Function.class, "function"))) + .isFalse(); + assertThat(inspector.getInputType(catalog.lookup(Function.class, "function"))) + .isAssignableFrom(String.class); + assertThat(inspector.getInputWrapper(catalog.lookup(Function.class, "function"))) + .isAssignableFrom(Flux.class); + assertThat(inspector.getOutputWrapper(catalog.lookup(Function.class, "function"))) + .isAssignableFrom(Mono.class); + } + @Test public void messageFunction() { create(MessageConfiguration.class); @@ -756,6 +773,16 @@ public class ContextFunctionCatalogAutoConfigurationTests { } } + @EnableAutoConfiguration + @Configuration + protected static class MonoConfiguration { + @Bean + public Function, Mono>> function() { + return flux -> flux.collect(HashMap::new, + (map, word) -> map.merge(word, 1, Integer::sum)); + } + } + @EnableAutoConfiguration @Configuration protected static class MessageConfiguration { diff --git a/spring-cloud-function-web/src/main/java/org/springframework/cloud/function/web/flux/FunctionController.java b/spring-cloud-function-web/src/main/java/org/springframework/cloud/function/web/flux/FunctionController.java index e760245f9..5b4be3332 100644 --- a/spring-cloud-function-web/src/main/java/org/springframework/cloud/function/web/flux/FunctionController.java +++ b/spring-cloud-function-web/src/main/java/org/springframework/cloud/function/web/flux/FunctionController.java @@ -16,6 +16,7 @@ package org.springframework.cloud.function.web.flux; +import java.util.Collection; import java.util.Optional; import java.util.function.Consumer; import java.util.function.Function; @@ -124,12 +125,24 @@ public class FunctionController { return Mono.from(result); } + if (isInputMultiple(handler) && isOutputSingle(handler)) { + request.setAttribute(WebRequestConstants.OUTPUT_SINGLE, true, + WebRequest.SCOPE_REQUEST); + return Mono.from(result); + } + request.setAttribute(WebRequestConstants.OUTPUT_SINGLE, false, WebRequest.SCOPE_REQUEST); return result; } + private boolean isInputMultiple(Object handler) { + Class type = inspector.getInputType(handler); + Class wrapper = inspector.getInputWrapper(handler); + return Collection.class.isAssignableFrom(type) || Flux.class.equals(wrapper); + } + private boolean isOutputSingle(Object handler) { Class type = inspector.getOutputType(handler); Class wrapper = inspector.getOutputWrapper(handler); diff --git a/spring-cloud-function-web/src/test/java/org/springframework/cloud/function/web/RestApplicationTests.java b/spring-cloud-function-web/src/test/java/org/springframework/cloud/function/web/RestApplicationTests.java index 0e772f30f..65eb60636 100644 --- a/spring-cloud-function-web/src/test/java/org/springframework/cloud/function/web/RestApplicationTests.java +++ b/spring-cloud-function-web/src/test/java/org/springframework/cloud/function/web/RestApplicationTests.java @@ -15,10 +15,24 @@ */ package org.springframework.cloud.function.web; +import java.net.URI; +import java.time.Duration; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.function.Consumer; +import java.util.function.Function; +import java.util.function.Supplier; +import java.util.stream.Collectors; + import org.junit.Before; import org.junit.Ignore; import org.junit.Test; import org.junit.runner.RunWith; + import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Qualifier; import org.springframework.boot.autoconfigure.EnableAutoConfiguration; @@ -38,18 +52,12 @@ import org.springframework.test.context.junit4.SpringRunner; import org.springframework.util.LinkedMultiValueMap; import org.springframework.util.MultiValueMap; import org.springframework.util.StringUtils; -import reactor.core.publisher.Flux; - -import java.net.URI; -import java.time.Duration; -import java.util.*; -import java.util.function.Consumer; -import java.util.function.Function; -import java.util.function.Supplier; -import java.util.stream.Collectors; import static org.assertj.core.api.Assertions.assertThat; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + /** * @author Dave Syer */ @@ -219,7 +227,7 @@ public class RestApplicationTests { assertThat(rest.exchange( RequestEntity.get(new URI("/sentences")).accept(MediaType.ALL).build(), String.class).getBody()) - .isEqualTo("[[\"go\",\"home\"],[\"come\",\"back\"]]"); + .isEqualTo("[[\"go\",\"home\"],[\"come\",\"back\"]]"); } @Test @@ -418,7 +426,7 @@ public class RestApplicationTests { // The new line in the middle is optional .body("[{\"value\":\"foo\"},\n{\"value\":\"bar\"}]"), String.class).getBody()) - .isEqualTo("[{\"value\":\"FOO\"},{\"value\":\"BAR\"}]"); + .isEqualTo("[{\"value\":\"FOO\"},{\"value\":\"BAR\"}]"); } @Test @@ -426,7 +434,7 @@ public class RestApplicationTests { assertThat(rest.exchange(RequestEntity.post(new URI("/uppercase")) .accept(EVENT_STREAM).contentType(MediaType.APPLICATION_JSON) .body("[\"foo\",\"bar\"]"), String.class).getBody()) - .isEqualTo(sse("(FOO)", "(BAR)")); + .isEqualTo(sse("(FOO)", "(BAR)")); } @Test @@ -437,10 +445,19 @@ public class RestApplicationTests { map.put("A", Arrays.asList("1", "2", "3")); map.put("B", Arrays.asList("5", "6")); - assertThat(rest.exchange(RequestEntity.post(new URI("/sum")) - .accept(MediaType.APPLICATION_JSON).contentType(MediaType.MULTIPART_FORM_DATA) - .body(map), String.class).getBody()) - .isEqualTo("[{\"A\":6,\"B\":11}]"); + assertThat(rest.exchange( + RequestEntity.post(new URI("/sum")).accept(MediaType.APPLICATION_JSON) + .contentType(MediaType.MULTIPART_FORM_DATA).body(map), + String.class).getBody()).isEqualTo("[{\"A\":6,\"B\":11}]"); + } + + @Test + public void count() throws Exception { + List list = Arrays.asList("A", "B", "A"); + assertThat(rest.exchange( + RequestEntity.post(new URI("/count")).accept(MediaType.APPLICATION_JSON) + .contentType(MediaType.APPLICATION_JSON).body(list), + String.class).getBody()).isEqualTo("{\"A\":2,\"B\":1}"); } private String sse(String... values) { @@ -594,18 +611,16 @@ public class RestApplicationTests { @Bean public Function, Map> sum() { - return valueMap -> valueMap - .entrySet() - .stream() - .collect( - Collectors - .toMap( - Map.Entry::getKey, - values -> values.getValue().stream().mapToInt(Integer::parseInt).sum() - ) - ); + return valueMap -> valueMap.entrySet().stream() + .collect(Collectors.toMap(Map.Entry::getKey, values -> values + .getValue().stream().mapToInt(Integer::parseInt).sum())); } + @Bean + public Function, Mono>> count() { + return flux -> flux.collect(HashMap::new, + (map, word) -> map.merge(word, 1, Integer::sum)); + } } public static class Foo {