diff --git a/spring-cloud-function-context/src/main/java/org/springframework/cloud/function/context/catalog/SimpleFunctionRegistry.java b/spring-cloud-function-context/src/main/java/org/springframework/cloud/function/context/catalog/SimpleFunctionRegistry.java index a474d81c9..e49d3d9b0 100644 --- a/spring-cloud-function-context/src/main/java/org/springframework/cloud/function/context/catalog/SimpleFunctionRegistry.java +++ b/spring-cloud-function-context/src/main/java/org/springframework/cloud/function/context/catalog/SimpleFunctionRegistry.java @@ -775,7 +775,7 @@ public class SimpleFunctionRegistry implements FunctionRegistry, FunctionInspect input = FunctionTypeUtils.isMono(this.inputType) ? Mono.just(input) : Flux.just(input); } } - else if (input instanceof Iterable && !FunctionTypeUtils.isTypeCollection(this.inputType)) { + else if (!(input instanceof Publisher) && input instanceof Iterable && !FunctionTypeUtils.isTypeCollection(this.inputType)) { input = Flux.fromIterable((Iterable) input); } return input; diff --git a/spring-cloud-function-grpc/src/main/java/org/springframework/cloud/function/grpc/GrpcAutoConfiguration.java b/spring-cloud-function-grpc/src/main/java/org/springframework/cloud/function/grpc/GrpcAutoConfiguration.java index 0e25b8a60..677cc61ba 100644 --- a/spring-cloud-function-grpc/src/main/java/org/springframework/cloud/function/grpc/GrpcAutoConfiguration.java +++ b/spring-cloud-function-grpc/src/main/java/org/springframework/cloud/function/grpc/GrpcAutoConfiguration.java @@ -20,6 +20,7 @@ import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty; import org.springframework.boot.context.properties.EnableConfigurationProperties; import org.springframework.cloud.function.context.FunctionCatalog; import org.springframework.cloud.function.context.FunctionProperties; +import org.springframework.cloud.function.grpc.MessagingServiceGrpc.MessagingServiceImplBase; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; @@ -31,16 +32,16 @@ import org.springframework.context.annotation.Configuration; @Configuration(proxyBeanMethods = false) @EnableConfigurationProperties(FunctionGrpcProperties.class) @ConditionalOnProperty(name = "spring.cloud.function.grpc.mode", havingValue = "server", matchIfMissing = false) -public class GrpcAutoConfiguration { +class GrpcAutoConfiguration { @Bean - public GrpcServer grpcServer(FunctionGrpcProperties grpcProperties, GrpcMessagingServiceImpl grpcMessagingService) { + public GrpcServer grpcServer(FunctionGrpcProperties grpcProperties, MessagingServiceImplBase grpcMessagingService) { return new GrpcServer(grpcProperties, grpcMessagingService); } @Bean - public GrpcMessagingServiceImpl grpcMessageService(FunctionProperties funcProperties, FunctionCatalog functionCatalog) { - return new GrpcMessagingServiceImpl(funcProperties, functionCatalog); + public GrpcServerMessageHandler grpcMessageService(FunctionProperties funcProperties, FunctionCatalog functionCatalog) { + return new GrpcServerMessageHandler(funcProperties, functionCatalog); } } diff --git a/spring-cloud-function-grpc/src/main/java/org/springframework/cloud/function/grpc/GrpcMessagingServiceImpl.java b/spring-cloud-function-grpc/src/main/java/org/springframework/cloud/function/grpc/GrpcServerMessageHandler.java similarity index 70% rename from spring-cloud-function-grpc/src/main/java/org/springframework/cloud/function/grpc/GrpcMessagingServiceImpl.java rename to spring-cloud-function-grpc/src/main/java/org/springframework/cloud/function/grpc/GrpcServerMessageHandler.java index 7be957145..9c5a39e76 100644 --- a/spring-cloud-function-grpc/src/main/java/org/springframework/cloud/function/grpc/GrpcMessagingServiceImpl.java +++ b/spring-cloud-function-grpc/src/main/java/org/springframework/cloud/function/grpc/GrpcServerMessageHandler.java @@ -38,6 +38,10 @@ import java.util.concurrent.atomic.AtomicBoolean; import io.grpc.Status; import io.grpc.stub.ServerCallStreamObserver; import io.grpc.stub.StreamObserver; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Sinks; +import reactor.core.publisher.Sinks.Many; + import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; @@ -46,6 +50,7 @@ import org.springframework.cloud.function.context.FunctionProperties; import org.springframework.cloud.function.context.catalog.SimpleFunctionRegistry.FunctionInvocationWrapper; import org.springframework.cloud.function.grpc.MessagingServiceGrpc.MessagingServiceImplBase; import org.springframework.messaging.Message; +import org.springframework.messaging.support.MessageBuilder; import org.springframework.util.Assert; /** @@ -54,13 +59,13 @@ import org.springframework.util.Assert; * @since 3.2 * */ -class GrpcMessagingServiceImpl extends MessagingServiceImplBase { +class GrpcServerMessageHandler extends MessagingServiceImplBase { - private Log logger = LogFactory.getLog(GrpcMessagingServiceImpl.class); + private Log logger = LogFactory.getLog(GrpcServerMessageHandler.class); private final FunctionInvocationWrapper function; - GrpcMessagingServiceImpl(FunctionProperties funcProperties, FunctionCatalog functionCatalog) { + GrpcServerMessageHandler(FunctionProperties funcProperties, FunctionCatalog functionCatalog) { this.function = functionCatalog.lookup(funcProperties.getDefinition(), "application/json"); Assert.notNull(this.function, "Failed to lookup function " + funcProperties.getDefinition()); } @@ -91,7 +96,6 @@ class GrpcMessagingServiceImpl extends MessagingServiceImplBase { // } // @Override - @SuppressWarnings("unchecked") public StreamObserver biStream(StreamObserver responseObserver) { ServerCallStreamObserver serverCallStreamObserver = (ServerCallStreamObserver) responseObserver; serverCallStreamObserver.disableAutoInboundFlowControl(); @@ -104,15 +108,26 @@ class GrpcMessagingServiceImpl extends MessagingServiceImplBase { serverCallStreamObserver.request(1); } }); + + if (function.isInputTypePublisher()) { + return this.biStreamReactive(responseObserver, serverCallStreamObserver); + } + else { + return this.biStreamImperative(responseObserver, serverCallStreamObserver, wasReady); + } + } + + private StreamObserver biStreamImperative(StreamObserver responseObserver, + ServerCallStreamObserver serverCallStreamObserver, AtomicBoolean wasReady) { return new StreamObserver() { + @SuppressWarnings("unchecked") @Override public void onNext(GrpcMessage request) { try { Message message = GrpcUtils.fromGrpcMessage(request); - Message replyMessage = (Message) function - .apply(message); + Message replyMessage = (Message) function.apply(message); GrpcMessage reply = GrpcUtils.toGrpcMessage(replyMessage); @@ -147,4 +162,40 @@ class GrpcMessagingServiceImpl extends MessagingServiceImplBase { } }; } + + @SuppressWarnings("unchecked") + private StreamObserver biStreamReactive(StreamObserver responseObserver, + ServerCallStreamObserver serverCallStreamObserver) { + Many> sink = Sinks.many().unicast().onBackpressureBuffer(); + Flux> flux = sink.asFlux(); + + Flux> connectedFlux = (Flux>) function.apply(flux); + + connectedFlux.subscribe(functionResult -> { + GrpcMessage reply = GrpcUtils.toGrpcMessage(functionResult); + responseObserver.onNext(reply); + }); + + return new StreamObserver() { + + @Override + public void onNext(GrpcMessage value) { + sink.tryEmitNext(GrpcUtils.fromGrpcMessage(value)); + serverCallStreamObserver.request(1); + } + + @Override + public void onError(Throwable t) { + t.printStackTrace(); + responseObserver.onCompleted(); + } + + @Override + public void onCompleted() { + logger.info("Server stream is complete"); + sink.tryEmitComplete(); + responseObserver.onCompleted(); + } + }; + } } diff --git a/spring-cloud-function-grpc/src/main/java/org/springframework/cloud/function/grpc/GrpcUtils.java b/spring-cloud-function-grpc/src/main/java/org/springframework/cloud/function/grpc/GrpcUtils.java index 4f8b8e009..ad5195e97 100644 --- a/spring-cloud-function-grpc/src/main/java/org/springframework/cloud/function/grpc/GrpcUtils.java +++ b/spring-cloud-function-grpc/src/main/java/org/springframework/cloud/function/grpc/GrpcUtils.java @@ -39,7 +39,7 @@ import org.springframework.messaging.support.MessageBuilder; * @since 3.2 * */ -public final class GrpcUtils { +final class GrpcUtils { private static Log logger = LogFactory.getLog(GrpcUtils.class); @@ -155,6 +155,7 @@ public final class GrpcUtils { @Override public void onNext(GrpcMessage message) { + System.out.println("RECEIVED: " + message); if (logger.isDebugEnabled()) { logger.debug("Receiving message: " + message); } diff --git a/spring-cloud-function-grpc/src/test/java/org/springframework/cloud/function/grpc/GrpcInteractionTests.java b/spring-cloud-function-grpc/src/test/java/org/springframework/cloud/function/grpc/GrpcInteractionTests.java index 022da1a02..4c4ac990b 100644 --- a/spring-cloud-function-grpc/src/test/java/org/springframework/cloud/function/grpc/GrpcInteractionTests.java +++ b/spring-cloud-function-grpc/src/test/java/org/springframework/cloud/function/grpc/GrpcInteractionTests.java @@ -64,7 +64,7 @@ public class GrpcInteractionTests { } @Test - public void testBidirectionalStream() { + public void testBidirectionalStreamWithImperativeFunction() { try (ConfigurableApplicationContext context = new SpringApplicationBuilder( SampleConfiguration.class).web(WebApplicationType.NONE).run( "--spring.jmx.enabled=false", @@ -95,6 +95,38 @@ public class GrpcInteractionTests { } } + @Test + public void testBidirectionalStreamWithReactiveFunction() { + try (ConfigurableApplicationContext context = new SpringApplicationBuilder( + SampleConfiguration.class).web(WebApplicationType.NONE).run( + "--spring.jmx.enabled=false", + "--spring.cloud.function.definition=uppercaseReactive", + "--spring.cloud.function.grpc.port=" + + FunctionGrpcProperties.GRPC_PORT, + "--spring.cloud.function.grpc.mode=server")) { + + List> messages = new ArrayList<>(); + messages.add(MessageBuilder.withPayload("\"Ricky\"".getBytes()).setHeader("foo", "bar") + .setHeader(MessageHeaders.CONTENT_TYPE, MimeTypeUtils.TEXT_PLAIN) + .build()); + messages.add(MessageBuilder.withPayload("\"Julien\"".getBytes()).setHeader("foo", "bar") + .setHeader(MessageHeaders.CONTENT_TYPE, MimeTypeUtils.TEXT_PLAIN) + .build()); + messages.add(MessageBuilder.withPayload("\"Bubbles\"".getBytes()).setHeader("foo", "bar") + .setHeader(MessageHeaders.CONTENT_TYPE, MimeTypeUtils.TEXT_PLAIN) + .build()); + + Flux> clientResponseObserver = + GrpcUtils.biStreaming("localhost", FunctionGrpcProperties.GRPC_PORT, Flux.fromIterable(messages)); + + List> results = clientResponseObserver.collectList().block(Duration.ofSeconds(5)); + assertThat(results.size()).isEqualTo(3); + assertThat(results.get(0).getPayload()).isEqualTo("\"RICKY\"".getBytes()); + assertThat(results.get(1).getPayload()).isEqualTo("\"JULIEN\"".getBytes()); + assertThat(results.get(2).getPayload()).isEqualTo("\"BUBBLES\"".getBytes()); + } + } + @EnableAutoConfiguration public static class SampleConfiguration { @@ -102,5 +134,10 @@ public class GrpcInteractionTests { public Function uppercase() { return v -> v.toUpperCase(); } + + @Bean + public Function, Flux> uppercaseReactive() { + return flux -> flux.map(v -> v.toUpperCase()); + } } }