diff --git a/spring-cloud-function-context/src/main/java/org/springframework/cloud/function/context/catalog/FunctionAroundWrapper.java b/spring-cloud-function-context/src/main/java/org/springframework/cloud/function/context/catalog/FunctionAroundWrapper.java new file mode 100644 index 000000000..da8bc508d --- /dev/null +++ b/spring-cloud-function-context/src/main/java/org/springframework/cloud/function/context/catalog/FunctionAroundWrapper.java @@ -0,0 +1,41 @@ +/* + * Copyright 2012-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.cloud.function.context.catalog; + +import java.util.function.BiFunction; + +import org.springframework.cloud.function.context.catalog.SimpleFunctionRegistry.FunctionInvocationWrapper; +import org.springframework.messaging.Message; + +/** + * + * @author Oleg Zhurakousky + * @since 3.1 + */ +public abstract class FunctionAroundWrapper implements BiFunction { + + @SuppressWarnings("unchecked") + @Override + public final Object apply(Object input, FunctionInvocationWrapper targetFunction) { + if (input instanceof Message) { + return this.doApply((Message) input, targetFunction); + } + return targetFunction.apply(input); + } + + protected abstract Object doApply(Message input, FunctionInvocationWrapper targetFunction); +} diff --git a/spring-cloud-function-context/src/main/java/org/springframework/cloud/function/context/catalog/SimpleFunctionRegistry.java b/spring-cloud-function-context/src/main/java/org/springframework/cloud/function/context/catalog/SimpleFunctionRegistry.java index 09045830a..9654405cb 100644 --- a/spring-cloud-function-context/src/main/java/org/springframework/cloud/function/context/catalog/SimpleFunctionRegistry.java +++ b/spring-cloud-function-context/src/main/java/org/springframework/cloud/function/context/catalog/SimpleFunctionRegistry.java @@ -49,6 +49,7 @@ import reactor.util.function.Tuples; import org.springframework.aop.framework.ProxyFactory; import org.springframework.beans.factory.BeanFactory; +import org.springframework.beans.factory.annotation.Autowired; import org.springframework.cloud.function.context.FunctionProperties; import org.springframework.cloud.function.context.FunctionRegistration; import org.springframework.cloud.function.context.FunctionRegistry; @@ -107,6 +108,9 @@ public class SimpleFunctionRegistry implements FunctionRegistry, FunctionInspect private List declaredFunctionDefinitions; + @Autowired(required = false) + private FunctionAroundWrapper functionAroundWrapper; + public SimpleFunctionRegistry(ConversionService conversionService, @Nullable CompositeMessageConverter messageConverter) { this.conversionService = conversionService; this.messageConverter = messageConverter; @@ -164,6 +168,15 @@ public class SimpleFunctionRegistry implements FunctionRegistry, FunctionInspect } FunctionInvocationWrapper function = (FunctionInvocationWrapper) this.compose(null, definition, acceptedOutputTypes); + + if (this.functionAroundWrapper != null && function != null) { + return (T) new FunctionInvocationWrapper(function) { + @Override + Object doApply(Object input, boolean consumer, Function enricher) { + return functionAroundWrapper.apply(input, function); + } + }; + } return (T) function; } @@ -406,6 +419,18 @@ public class SimpleFunctionRegistry implements FunctionRegistry, FunctionInspect private final Field headersField; + private FunctionInvocationWrapper delegate; + + FunctionInvocationWrapper(FunctionInvocationWrapper delegate) { + this.delegate = delegate; + this.target = delegate.target; + this.composed = delegate.composed; + this.functionType = delegate.functionType; + this.acceptedOutputMimeTypes = delegate.acceptedOutputMimeTypes; + this.functionDefinition = delegate.functionDefinition; + this.headersField = delegate.headersField; + } + FunctionInvocationWrapper(Object target, Type functionType, String functionDefinition, String... acceptedOutputMimeTypes) { this.target = target; this.composed = functionDefinition.contains("|") || target instanceof RoutingFunction; @@ -416,6 +441,22 @@ public class SimpleFunctionRegistry implements FunctionRegistry, FunctionInspect this.headersField.setAccessible(true); } + @Override + public int hashCode() { + if (this.delegate != null) { + return this.delegate.hashCode(); + } + return super.hashCode(); + } + + @Override + public boolean equals(Object o) { + if (this.delegate != null) { + return this.delegate.equals(o); + } + return super.equals(o); + } + public String getFunctionDefinition() { return this.functionDefinition; } @@ -536,7 +577,7 @@ public class SimpleFunctionRegistry implements FunctionRegistry, FunctionInspect } @SuppressWarnings({ "unchecked", "rawtypes" }) - private Object doApply(Object input, boolean consumer, Function enricher) { + Object doApply(Object input, boolean consumer, Function enricher) { if (logger.isDebugEnabled()) { logger.debug("Applying function: " + this.functionDefinition); }