diff --git a/spring-cloud-function-grpc/src/main/java/org/springframework/cloud/function/grpc/GrpcServerMessageHandler.java b/spring-cloud-function-grpc/src/main/java/org/springframework/cloud/function/grpc/GrpcServerMessageHandler.java index 0b366666b..eb3bd22ab 100644 --- a/spring-cloud-function-grpc/src/main/java/org/springframework/cloud/function/grpc/GrpcServerMessageHandler.java +++ b/spring-cloud-function-grpc/src/main/java/org/springframework/cloud/function/grpc/GrpcServerMessageHandler.java @@ -85,12 +85,18 @@ class GrpcServerMessageHandler extends MessagingServiceImplBase { responseObserver.onNext(reply); responseObserver.onCompleted(); } -// -// @Override -// public void serverStream(GrpcMessage request, -// StreamObserver responseObserver) { -// -// } + + @Override + public void serverStream(GrpcMessage request, StreamObserver responseObserver) { + Message message = GrpcUtils.fromGrpcMessage(request); + Publisher> replyStream = (Publisher>) this.function.apply(message); + Flux.from(replyStream).doOnNext(replyMessage -> { + responseObserver.onNext(GrpcUtils.toGrpcMessage(replyMessage)); + }) + .doOnComplete(() -> responseObserver.onCompleted()) + .subscribe(); + } + @SuppressWarnings("unchecked") @Override diff --git a/spring-cloud-function-grpc/src/main/java/org/springframework/cloud/function/grpc/GrpcUtils.java b/spring-cloud-function-grpc/src/main/java/org/springframework/cloud/function/grpc/GrpcUtils.java index 3f2315588..8720cfc0c 100644 --- a/spring-cloud-function-grpc/src/main/java/org/springframework/cloud/function/grpc/GrpcUtils.java +++ b/spring-cloud-function-grpc/src/main/java/org/springframework/cloud/function/grpc/GrpcUtils.java @@ -17,7 +17,10 @@ package org.springframework.cloud.function.grpc; import java.util.HashMap; +import java.util.Iterator; import java.util.Map; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.TimeUnit; @@ -128,6 +131,33 @@ final class GrpcUtils { }); } + public static Flux> serverStream(String host, int port, Message inputMessage) { + ManagedChannel channel = ManagedChannelBuilder.forAddress(host, port) + .usePlaintext().build(); + MessagingServiceGrpc.MessagingServiceBlockingStub stub = MessagingServiceGrpc + .newBlockingStub(channel); + + Iterator serverStream = stub.serverStream(toGrpcMessage(inputMessage)); + + Many> sink = Sinks.many().unicast().onBackpressureBuffer(); + ExecutorService executor = Executors.newSingleThreadExecutor(); + executor.execute(() -> { + while (serverStream.hasNext()) { + GrpcMessage grpcMessage = serverStream.next(); + sink.tryEmitNext(GrpcUtils.fromGrpcMessage(grpcMessage)); + } + sink.tryEmitComplete(); + }); + + + return sink.asFlux() + .doOnComplete(() -> { + channel.shutdown(); + executor.shutdownNow(); + }); + } + + /** * Utility method to support client-side streaming interaction. Will connect to gRPC server using default host/port, * otherwise use {@link #clientStream(String, int, Flux)} method. diff --git a/spring-cloud-function-grpc/src/test/java/org/springframework/cloud/function/grpc/GrpcInteractionTests.java b/spring-cloud-function-grpc/src/test/java/org/springframework/cloud/function/grpc/GrpcInteractionTests.java index 38bb37438..b49ae1a6c 100644 --- a/spring-cloud-function-grpc/src/test/java/org/springframework/cloud/function/grpc/GrpcInteractionTests.java +++ b/spring-cloud-function-grpc/src/test/java/org/springframework/cloud/function/grpc/GrpcInteractionTests.java @@ -118,10 +118,10 @@ public class GrpcInteractionTests { .setHeader(MessageHeaders.CONTENT_TYPE, MimeTypeUtils.TEXT_PLAIN) .build()); - Flux> clientResponseObserver = + Flux> resultStream = GrpcUtils.biStreaming("localhost", FunctionGrpcProperties.GRPC_PORT, Flux.fromIterable(messages)); - List> results = clientResponseObserver.collectList().block(Duration.ofSeconds(5)); + List> results = resultStream.collectList().block(Duration.ofSeconds(5)); assertThat(results.size()).isEqualTo(3); assertThat(results.get(0).getPayload()).isEqualTo("\"RICKY\"".getBytes()); assertThat(results.get(1).getPayload()).isEqualTo("\"JULIEN\"".getBytes()); @@ -154,6 +154,28 @@ public class GrpcInteractionTests { } } + @Test + public void testServerStreaming() { + try (ConfigurableApplicationContext context = new SpringApplicationBuilder( + SampleConfiguration.class).web(WebApplicationType.NONE).run( + "--spring.jmx.enabled=false", + "--spring.cloud.function.definition=stringInStreamOut", + "--spring.cloud.function.grpc.port=" + + FunctionGrpcProperties.GRPC_PORT, + "--spring.cloud.function.grpc.mode=server")) { + + Message message = MessageBuilder.withPayload("\"Ricky\"".getBytes()).setHeader("foo", "bar").build(); + + Flux> reply = + GrpcUtils.serverStream("localhost", FunctionGrpcProperties.GRPC_PORT, message); + + List> results = reply.collectList().block(Duration.ofSeconds(5)); + assertThat(results.size()).isEqualTo(2); + assertThat(results.get(0).getPayload()).isEqualTo("\"Ricky\"".getBytes()); + assertThat(results.get(1).getPayload()).isEqualTo("\"RICKY\"".getBytes()); + } + } + @Test public void testBiStreamStreamInStringOutFailure() { try (ConfigurableApplicationContext context = new SpringApplicationBuilder( @@ -166,13 +188,10 @@ public class GrpcInteractionTests { List> messages = new ArrayList<>(); messages.add(MessageBuilder.withPayload("\"Ricky\"".getBytes()).setHeader("foo", "bar") - .setHeader(MessageHeaders.CONTENT_TYPE, MimeTypeUtils.TEXT_PLAIN) .build()); messages.add(MessageBuilder.withPayload("\"Julien\"".getBytes()).setHeader("foo", "bar") - .setHeader(MessageHeaders.CONTENT_TYPE, MimeTypeUtils.TEXT_PLAIN) .build()); messages.add(MessageBuilder.withPayload("\"Bubbles\"".getBytes()).setHeader("foo", "bar") - .setHeader(MessageHeaders.CONTENT_TYPE, MimeTypeUtils.TEXT_PLAIN) .build()); Flux> clientResponseObserver = @@ -200,13 +219,10 @@ public class GrpcInteractionTests { List> messages = new ArrayList<>(); messages.add(MessageBuilder.withPayload("\"Ricky\"".getBytes()).setHeader("foo", "bar") - .setHeader(MessageHeaders.CONTENT_TYPE, MimeTypeUtils.TEXT_PLAIN) .build()); messages.add(MessageBuilder.withPayload("\"Julien\"".getBytes()).setHeader("foo", "bar") - .setHeader(MessageHeaders.CONTENT_TYPE, MimeTypeUtils.TEXT_PLAIN) .build()); messages.add(MessageBuilder.withPayload("\"Bubbles\"".getBytes()).setHeader("foo", "bar") - .setHeader(MessageHeaders.CONTENT_TYPE, MimeTypeUtils.TEXT_PLAIN) .build()); Flux> clientResponseObserver = @@ -249,7 +265,7 @@ public class GrpcInteractionTests { @Bean public Function> stringInStreamOut() { - return value -> Flux.just(value); + return value -> Flux.just(value, value.toUpperCase()); } } }