diff --git a/spring-cloud-function-adapters/spring-cloud-function-grpc/src/main/java/org/springframework/cloud/function/grpc/GrpcUtils.java b/spring-cloud-function-adapters/spring-cloud-function-grpc/src/main/java/org/springframework/cloud/function/grpc/GrpcUtils.java index ce38c883e..7ed7d7117 100644 --- a/spring-cloud-function-adapters/spring-cloud-function-grpc/src/main/java/org/springframework/cloud/function/grpc/GrpcUtils.java +++ b/spring-cloud-function-adapters/spring-cloud-function-grpc/src/main/java/org/springframework/cloud/function/grpc/GrpcUtils.java @@ -87,8 +87,13 @@ public final class GrpcUtils { MessagingServiceGrpc.MessagingServiceBlockingStub stub = MessagingServiceGrpc .newBlockingStub(channel); - GrpcSpringMessage response = stub.requestReply(toGrpcSpringMessage(inputMessage)); - return fromGrpcSpringMessage(response); + try { + GrpcSpringMessage response = stub.requestReply(toGrpcSpringMessage(inputMessage)); + return fromGrpcSpringMessage(response); + } + catch (Exception e) { + throw new IllegalStateException(e); + } } finally { channel.shutdownNow(); diff --git a/spring-cloud-function-adapters/spring-cloud-function-grpc/src/main/java/org/springframework/cloud/function/grpc/MessageHandlingHelper.java b/spring-cloud-function-adapters/spring-cloud-function-grpc/src/main/java/org/springframework/cloud/function/grpc/MessageHandlingHelper.java index caadb6e92..c24e3cd4f 100644 --- a/spring-cloud-function-adapters/spring-cloud-function-grpc/src/main/java/org/springframework/cloud/function/grpc/MessageHandlingHelper.java +++ b/spring-cloud-function-adapters/spring-cloud-function-grpc/src/main/java/org/springframework/cloud/function/grpc/MessageHandlingHelper.java @@ -24,12 +24,14 @@ import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; +import com.google.protobuf.ByteString; import com.google.protobuf.GeneratedMessageV3; import io.grpc.Status; import io.grpc.stub.ServerCallStreamObserver; import io.grpc.stub.StreamObserver; import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; import reactor.core.publisher.Sinks; import reactor.core.publisher.Sinks.Many; @@ -38,6 +40,7 @@ import org.apache.commons.logging.LogFactory; import org.reactivestreams.Publisher; import org.springframework.cloud.function.context.FunctionCatalog; import org.springframework.cloud.function.context.FunctionProperties; +import org.springframework.cloud.function.context.catalog.FunctionTypeUtils; import org.springframework.cloud.function.context.catalog.SimpleFunctionRegistry.FunctionInvocationWrapper; import org.springframework.context.SmartLifecycle; import org.springframework.messaging.Message; @@ -75,12 +78,29 @@ public class MessageHandlingHelper implements Smar public void requestReply(T request, StreamObserver responseObserver) { Message message = this.toSpringMessage(request); FunctionInvocationWrapper function = this.resolveFunction(message.getHeaders()); + if (FunctionTypeUtils.isFlux(function.getOutputType())) { + String errorMessage = "Flux reply is not supported for `requestReply` mode"; + responseObserver.onError(Status.UNKNOWN.withDescription(errorMessage) + .withCause(new UnsupportedOperationException(errorMessage)).asRuntimeException()); + return; + } - Message replyMessage = (Message) function.apply(message); - GeneratedMessageV3 reply = this.toGrpcMessage(replyMessage, (Class) request.getClass()); - - responseObserver.onNext((T) reply); - responseObserver.onCompleted(); + Object replyMessage = function.apply(message); + if (replyMessage instanceof Message) { + GeneratedMessageV3 reply = this.toGrpcMessage((Message) replyMessage, (Class) request.getClass()); + responseObserver.onNext((T) reply); + responseObserver.onCompleted(); + } + else if (replyMessage instanceof Publisher) { + if (replyMessage instanceof Mono) { + Mono.from((Publisher) replyMessage).doOnNext(reply -> { + GeneratedMessageV3 replyGrps = this.toGrpcMessage((Message) reply, (Class) request.getClass()); + responseObserver.onNext((T) replyGrps); + responseObserver.onCompleted(); + }) + .subscribe(); + } + } } @SuppressWarnings("unchecked") diff --git a/spring-cloud-function-adapters/spring-cloud-function-grpc/src/test/java/org/springframework/cloud/function/grpc/GrpcInteractionTests.java b/spring-cloud-function-adapters/spring-cloud-function-grpc/src/test/java/org/springframework/cloud/function/grpc/GrpcInteractionTests.java index 6d39125a1..83015c8e8 100644 --- a/spring-cloud-function-adapters/spring-cloud-function-grpc/src/test/java/org/springframework/cloud/function/grpc/GrpcInteractionTests.java +++ b/spring-cloud-function-adapters/spring-cloud-function-grpc/src/test/java/org/springframework/cloud/function/grpc/GrpcInteractionTests.java @@ -27,6 +27,7 @@ import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; import org.springframework.boot.WebApplicationType; import org.springframework.boot.autoconfigure.EnableAutoConfiguration; @@ -40,6 +41,7 @@ import org.springframework.util.MimeTypeUtils; import org.springframework.util.SocketUtils; import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Assertions.fail; /** * @@ -79,6 +81,49 @@ public class GrpcInteractionTests { } } + @Test + public void testRequestReplyWithMonoReturn() { + int port = SocketUtils.findAvailableTcpPort(); + try (ConfigurableApplicationContext context = new SpringApplicationBuilder( + SampleConfiguration.class).web(WebApplicationType.NONE).run( + "--spring.jmx.enabled=false", + "--spring.cloud.function.definition=uppercaseMonoReturn", + "--spring.cloud.function.grpc.port=" + port)) { + + Message message = MessageBuilder.withPayload("\"hello gRPC\"".getBytes()) + .setHeader("foo", "bar") + .setHeader(MessageHeaders.CONTENT_TYPE, MimeTypeUtils.TEXT_PLAIN) + .build(); + + Message reply = GrpcUtils.requestReply("localhost", port, message); + + assertThat(reply.getPayload()).isEqualTo("\"HELLO GRPC\"".getBytes()); + } + } + + @Test + public void testRequestReplyWithFluxReturn() { + int port = SocketUtils.findAvailableTcpPort(); + try (ConfigurableApplicationContext context = new SpringApplicationBuilder( + SampleConfiguration.class).web(WebApplicationType.NONE).run( + "--spring.jmx.enabled=false", + "--spring.cloud.function.definition=uppercaseFluxReturn", + "--spring.cloud.function.grpc.port=" + port)) { + + Message message = MessageBuilder.withPayload("\"hello gRPC\"".getBytes()) + .setHeader("foo", "bar") + .setHeader(MessageHeaders.CONTENT_TYPE, MimeTypeUtils.TEXT_PLAIN) + .build(); + try { + GrpcUtils.requestReply("localhost", port, message); + fail(); + } + catch (Exception e) { + assertThat(e.getMessage()).contains("Flux reply is not supported for `requestReply` mode"); + } + } + } + @Test public void testRequstReplyFunctionDefinitionInMessage() { int port = SocketUtils.findAvailableTcpPort(); @@ -263,6 +308,16 @@ public class GrpcInteractionTests { return v -> v.toUpperCase(); } + @Bean + public Function> uppercaseMonoReturn() { + return v -> Mono.just(v.toUpperCase()); + } + + @Bean + public Function> uppercaseFluxReturn() { + return v -> Flux.just(v.toUpperCase(), v.toUpperCase() + "-1", v.toUpperCase() + "-2"); + } + @Bean public Function reverse() { return v -> new StringBuilder(v).reverse().toString();