Fix output type conversion for composed function

Follow up to te provious commit
This ensures that only the last function in composition does the conversion
This commit is contained in:
Oleg Zhurakousky
2021-04-14 16:09:52 +02:00
parent de22a1e61b
commit 2bb6ea2a9f
2 changed files with 26 additions and 11 deletions

View File

@@ -285,8 +285,7 @@ public class SimpleFunctionRegistry implements FunctionRegistry, FunctionInspect
((FunctionInvocationWrapper) targetFunction).acceptedOutputMimeTypes = acceptedOutputTypes;
}
resultFunction = (Function<?, ?>) targetFunction;
//resultFunction = new FunctionInvocationWrapper(((FunctionInvocationWrapper) targetFunction).getTarget(), functionType, definition, acceptedOutputTypes);
resultFunction = new FunctionInvocationWrapper(((FunctionInvocationWrapper) targetFunction).getTarget(), functionType, definition, acceptedOutputTypes);
}
else {
resultFunction = new FunctionInvocationWrapper(targetFunction, functionType, definition, acceptedOutputTypes);
@@ -298,7 +297,8 @@ public class SimpleFunctionRegistry implements FunctionRegistry, FunctionInspect
String prefix = "";
Type originFunctionType = null;
for (String name : names) {
for (int i = 0; i < names.length; i++) {
String name = names[i];
Object function = this.locateFunction(name);
if (function == null) {
if (logger.isDebugEnabled()) {
@@ -349,13 +349,11 @@ public class SimpleFunctionRegistry implements FunctionRegistry, FunctionInspect
}
function = new FunctionInvocationWrapper(function, currentFunctionType, name, acceptedOutputTypes);
// function = new FunctionInvocationWrapper(function, currentFunctionType, name, names.length > 1 ? new String[] {} : acceptedOutputTypes);
//function = new FunctionInvocationWrapper(function, currentFunctionType, name, !names[0].equals("origin") && name.equals(names[names.length - 1]) ? acceptedOutputTypes : new String[] {});
if (originFunctionType == null) {
originFunctionType = currentFunctionType;
}
if (name.equals(names[names.length - 1]) && !names[0].equals("origin")) {
if (name.equals(names[names.length - 1]) /*&& !names[0].equals("origin")*/) {
((FunctionInvocationWrapper) function).setSkipOutputConversion(false);
}
else {
@@ -370,6 +368,9 @@ public class SimpleFunctionRegistry implements FunctionRegistry, FunctionInspect
originFunctionType = FunctionTypeUtils.compose(originFunctionType, currentFunctionType);
resultFunction = new FunctionInvocationWrapper(resultFunction.andThen((Function) function),
originFunctionType, composedNameBuilder.toString(), acceptedOutputTypes);
if (((FunctionInvocationWrapper) resultFunction).composed) { //if (i < names.length - 1) {
((FunctionInvocationWrapper) resultFunction).setSkipOutputConversion(true);
}
}
prefix = "|";
}
@@ -377,7 +378,6 @@ public class SimpleFunctionRegistry implements FunctionRegistry, FunctionInspect
FunctionRegistration<Object> registration = new FunctionRegistration<Object>(resultFunction, definition)
.type(originFunctionType);
registrationsByFunction.putIfAbsent(resultFunction, registration);
registrationsByName.putIfAbsent(definition, registration);
}
return resultFunction;
}
@@ -1002,6 +1002,4 @@ public class SimpleFunctionRegistry implements FunctionRegistry, FunctionInspect
return "org.springframework.kafka.support.KafkaNull".equals(payload.getClass().getName());
}
}
}

View File

@@ -247,6 +247,12 @@ public class BeanFactoryAwareFunctionRegistryTests {
@Test
public void testCompositionWithOutputConversion() {
FunctionCatalog catalog = this.configureCatalog();
Function<Message<byte[]>, Message<byte[]>> composedFunction = catalog.lookup("mapfrompojo|uppercase|reverse", "application/json");
Message<byte[]> resultMessage = composedFunction.apply(MessageBuilder.withPayload("{\"name\":\"Ricky\"}".getBytes()).build());
assertThat(new String(resultMessage.getPayload())).isEqualTo("\"YKCIR\"");
Function<Flux<String>, Flux<Message<byte[]>>> fluxFunction = catalog.lookup("uppercase|reverseFlux", "application/json");
List<Message<byte[]>> result = fluxFunction.apply(Flux.just("hello", "bye")).collectList().block();
assertThat(result.get(0).getPayload()).isEqualTo("\"OLLEH\"".getBytes());
@@ -877,6 +883,13 @@ public class BeanFactoryAwareFunctionRegistryTests {
return () -> "one";
}
@Bean
public Function<Person, String> mapfrompojo() {
return person -> {
return person.getName();
};
}
@Bean
public Function<Map<String, Object>, Person> maptopojo() {
return map -> {
@@ -887,7 +900,9 @@ public class BeanFactoryAwareFunctionRegistryTests {
@Bean
public Function<String, String> uppercase() {
return v -> v.toUpperCase();
return v -> {
return v.toUpperCase();
};
}
@Bean
@@ -920,7 +935,9 @@ public class BeanFactoryAwareFunctionRegistryTests {
@Bean
public Function<String, String> reverse() {
return value -> new StringBuilder(value).reverse().toString();
return value -> {
return new StringBuilder(value).reverse().toString();
};
}
@Bean