diff --git a/spring-cloud-function-context/src/main/java/org/springframework/cloud/function/context/FunctionType.java b/spring-cloud-function-context/src/main/java/org/springframework/cloud/function/context/FunctionType.java index 46bc87bd6..3ca0f6166 100644 --- a/spring-cloud-function-context/src/main/java/org/springframework/cloud/function/context/FunctionType.java +++ b/spring-cloud-function-context/src/main/java/org/springframework/cloud/function/context/FunctionType.java @@ -98,6 +98,9 @@ public class FunctionType { } public static boolean isWrapper(Type type) { + if (type instanceof ParameterizedType) { + type = ((ParameterizedType)type).getRawType(); + } return Publisher.class.equals(type) || Flux.class.equals(type) || Mono.class.equals(type) || Optional.class.equals(type); } @@ -147,21 +150,34 @@ public class FunctionType { public static FunctionType compose(FunctionType input, FunctionType output) { ResolvableType inputGeneric = input(input); ResolvableType outputGeneric = output(output); + if (!isWrapper(outputGeneric.getType())) { + ResolvableType inputOutput = output(input); + if (isWrapper(inputOutput.getType())) { + outputGeneric = wrap(input, + extractClass(inputOutput.getType(), ParamType.OUTPUT_WRAPPER), + extractClass(outputGeneric.getType(), ParamType.OUTPUT)); + } + } return new FunctionType(ResolvableType .forClassWithGenerics(Function.class, inputGeneric, outputGeneric) .getType()); } private ResolvableType wrap(Class wrapper, Class type) { - return isMessage() ? wrap(wrapper, message(type)) + return wrap(this, wrapper, type); + } + + private static ResolvableType wrap(FunctionType input, Class wrapper, + Class type) { + return input.isMessage() ? wrap(wrapper, message(type)) : ResolvableType.forClassWithGenerics(wrapper, type); } - private ResolvableType wrap(Class wrapper, ResolvableType type) { + private static ResolvableType wrap(Class wrapper, ResolvableType type) { return ResolvableType.forClassWithGenerics(wrapper, type); } - private ResolvableType message(Class type) { + private static ResolvableType message(Class type) { return ResolvableType.forClassWithGenerics(Message.class, type); } @@ -224,7 +240,7 @@ public class FunctionType { return Object.class; } - private Class extractClass(Type param, ParamType paramType) { + private static Class extractClass(Type param, ParamType paramType) { if (param instanceof ParameterizedType) { ParameterizedType concrete = (ParameterizedType) param; param = concrete.getRawType(); diff --git a/spring-cloud-function-context/src/test/java/org/springframework/cloud/function/context/FunctionTypeTests.java b/spring-cloud-function-context/src/test/java/org/springframework/cloud/function/context/FunctionTypeTests.java index 8516e8f8e..61d138b98 100644 --- a/spring-cloud-function-context/src/test/java/org/springframework/cloud/function/context/FunctionTypeTests.java +++ b/spring-cloud-function-context/src/test/java/org/springframework/cloud/function/context/FunctionTypeTests.java @@ -165,6 +165,18 @@ public class FunctionTypeTests { assertThat(function.isMessage()).isEqualTo(true); } + @Test + public void compose() { + FunctionType input = FunctionType.from(Foo.class).to(Bar.class).wrap(Flux.class); + FunctionType output = FunctionType.from(Bar.class).to(String.class); + FunctionType function = FunctionType.compose(input, output); + assertThat(function.getInputType()).isEqualTo(Foo.class); + assertThat(function.getOutputType()).isEqualTo(String.class); + assertThat(function.getInputWrapper()).isEqualTo(Flux.class); + assertThat(function.getOutputWrapper()).isEqualTo(Flux.class); + assertThat(function.isMessage()).isEqualTo(false); + } + @Test public void idempotentMessage() { FunctionType function = FunctionType.from(Foo.class).to(Bar.class).message() diff --git a/spring-cloud-function-context/src/test/java/org/springframework/cloud/function/context/config/ContextFunctionPostProcessorTests.java b/spring-cloud-function-context/src/test/java/org/springframework/cloud/function/context/config/ContextFunctionPostProcessorTests.java index b6fa952bd..e76902b16 100644 --- a/spring-cloud-function-context/src/test/java/org/springframework/cloud/function/context/config/ContextFunctionPostProcessorTests.java +++ b/spring-cloud-function-context/src/test/java/org/springframework/cloud/function/context/config/ContextFunctionPostProcessorTests.java @@ -98,7 +98,8 @@ public class ContextFunctionPostProcessorTests { Function, Flux> foos = (Function, Flux>) processor .lookupFunction("foos,bars"); assertThat(foos.apply(Flux.just(2)).blockFirst()).isEqualTo("Hello 4"); - assertThat(processor.getRegistration(foos).getNames()).containsExactly("foos|bars"); + assertThat(processor.getRegistration(foos).getNames()) + .containsExactly("foos|bars"); } @Test @@ -109,7 +110,21 @@ public class ContextFunctionPostProcessorTests { Function, Flux> foos = (Function, Flux>) processor .lookupFunction("foos|bars"); assertThat(foos.apply(Flux.just(2)).blockFirst()).isEqualTo("Hello 4"); - assertThat(processor.getRegistration(foos).getNames()).containsExactly("foos|bars"); + assertThat(processor.getRegistration(foos).getNames()) + .containsExactly("foos|bars"); + } + + @Test + public void composeWrapper() { + processor.register(new FunctionRegistration<>(new WrappedSource()).names("ints")); + processor.register(new FunctionRegistration<>(new Foos()).names("foos")); + @SuppressWarnings("unchecked") + Supplier> foos = (Supplier>) processor + .lookupSupplier("ints|foos"); + assertThat(foos.get().blockFirst()).isEqualTo("8"); + assertThat(processor.getRegistration(foos).getNames()) + .containsExactly("ints|foos"); + assertThat(processor.getRegistration(foos).getType().getOutputWrapper()).isEqualTo(Flux.class); } @Test @@ -127,7 +142,8 @@ public class ContextFunctionPostProcessorTests { public void isolatedSupplier() { contextClassLoader = ClassUtils .overrideThreadContextClassLoader(getClass().getClassLoader()); - processor.register(new FunctionRegistration<>(create(Source.class)).names("source")); + processor.register( + new FunctionRegistration<>(create(Source.class)).names("source")); @SuppressWarnings("unchecked") Supplier> source = (Supplier>) processor .lookupSupplier("source"); @@ -145,7 +161,8 @@ public class ContextFunctionPostProcessorTests { .lookupConsumer("sink"); sink.accept(Flux.just("Hello")); @SuppressWarnings("unchecked") - List values = (List) ReflectionTestUtils.getField(target, "values"); + List values = (List) ReflectionTestUtils.getField(target, + "values"); assertThat(values).contains("Hello"); } @@ -201,6 +218,15 @@ public class ContextFunctionPostProcessorTests { } + public static class WrappedSource implements Supplier> { + + @Override + public Flux get() { + return Flux.just(4); + } + + } + public static class Foo { private String value; @@ -239,5 +265,5 @@ public class ContextFunctionPostProcessorTests { } } - + } diff --git a/spring-cloud-function-samples/function-sample-pojo/src/main/java/com/example/SampleApplication.java b/spring-cloud-function-samples/function-sample-pojo/src/main/java/com/example/SampleApplication.java index 68a6437c5..73fc0e004 100644 --- a/spring-cloud-function-samples/function-sample-pojo/src/main/java/com/example/SampleApplication.java +++ b/spring-cloud-function-samples/function-sample-pojo/src/main/java/com/example/SampleApplication.java @@ -33,8 +33,8 @@ public class SampleApplication { } @Bean - public Supplier> words() { - return () -> Flux.fromArray(new Bar[] { new Bar("foo"), new Bar("bar") }).log(); + public Supplier> words() { + return () -> Flux.fromArray(new Foo[] { new Foo("foo"), new Foo("bar") }).log(); } public static void main(String[] args) throws Exception { diff --git a/spring-cloud-function-samples/function-sample-pojo/src/test/java/com/example/SampleApplicationTests.java b/spring-cloud-function-samples/function-sample-pojo/src/test/java/com/example/SampleApplicationTests.java index 767127d05..86153f540 100644 --- a/spring-cloud-function-samples/function-sample-pojo/src/test/java/com/example/SampleApplicationTests.java +++ b/spring-cloud-function-samples/function-sample-pojo/src/test/java/com/example/SampleApplicationTests.java @@ -51,6 +51,13 @@ public class SampleApplicationTests { String.class)).isEqualTo("[{\"value\":\"FOO\"}]"); } + @Test + public void composite() { + assertThat(new TestRestTemplate() + .getForObject("http://localhost:" + port + "/words,uppercase", String.class)) + .isEqualTo("[{\"value\":\"FOO\"},{\"value\":\"BAR\"}]"); + } + @Test public void single() { assertThat(new TestRestTemplate().postForObject(