diff --git a/spring-graphql/src/main/java/org/springframework/graphql/server/webflux/GraphQlSseHandler.java b/spring-graphql/src/main/java/org/springframework/graphql/server/webflux/GraphQlSseHandler.java index f2054bcc..1069d353 100644 --- a/spring-graphql/src/main/java/org/springframework/graphql/server/webflux/GraphQlSseHandler.java +++ b/spring-graphql/src/main/java/org/springframework/graphql/server/webflux/GraphQlSseHandler.java @@ -24,6 +24,7 @@ import java.util.Map; import graphql.ErrorType; import graphql.ExecutionResult; import graphql.GraphQLError; +import graphql.GraphqlErrorBuilder; import org.reactivestreams.Publisher; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; @@ -86,7 +87,7 @@ public class GraphQlSseHandler extends AbstractGraphQlHttpHandler { if (response.getData() instanceof Publisher) { resultFlux = Flux.from((Publisher) response.getData()) .map(ExecutionResult::toSpecification) - .onErrorResume(SubscriptionPublisherException.class, (ex) -> Mono.just(ex.toMap())); + .onErrorResume(this::exceptionToResultMap); } else { if (this.logger.isDebugEnabled()) { @@ -102,14 +103,25 @@ public class GraphQlSseHandler extends AbstractGraphQlHttpHandler { } Flux>> sseFlux = - resultFlux.map((event) -> ServerSentEvent.builder(event).event("next").build()); + resultFlux.map((event) -> ServerSentEvent.builder(event).event("next").build()) + .concatWith(COMPLETE_EVENT); Mono responseMono = ServerResponse.ok() .contentType(MediaType.TEXT_EVENT_STREAM) - .body(BodyInserters.fromServerSentEvents(sseFlux.concatWith(COMPLETE_EVENT))) + .body(BodyInserters.fromServerSentEvents(sseFlux)) .onErrorResume(Throwable.class, (ex) -> ServerResponse.badRequest().build()); return ((this.timeout != null) ? responseMono.timeout(this.timeout) : responseMono); } + private Mono> exceptionToResultMap(Throwable ex) { + return Mono.just((ex instanceof SubscriptionPublisherException spe) ? + spe.toMap() : + GraphqlErrorBuilder.newError() + .message("Subscription error") + .errorType(org.springframework.graphql.execution.ErrorType.INTERNAL_ERROR) + .build() + .toSpecification()); + } + } diff --git a/spring-graphql/src/main/java/org/springframework/graphql/server/webmvc/GraphQlSseHandler.java b/spring-graphql/src/main/java/org/springframework/graphql/server/webmvc/GraphQlSseHandler.java index 30534c71..50e340e0 100644 --- a/spring-graphql/src/main/java/org/springframework/graphql/server/webmvc/GraphQlSseHandler.java +++ b/spring-graphql/src/main/java/org/springframework/graphql/server/webmvc/GraphQlSseHandler.java @@ -24,6 +24,7 @@ import java.util.function.Consumer; import graphql.ErrorType; import graphql.ExecutionResult; import graphql.GraphQLError; +import graphql.GraphqlErrorBuilder; import org.reactivestreams.Publisher; import reactor.core.publisher.BaseSubscriber; import reactor.core.publisher.Flux; @@ -119,10 +120,10 @@ public class GraphQlSseHandler extends AbstractGraphQlHttpHandler { @Override protected void hookOnNext(Map value) { - writeResult(value); + sendNext(value); } - private void writeResult(Map value) { + private void sendNext(Map value) { try { this.sseBuilder.event("next"); this.sseBuilder.data(value); @@ -139,18 +140,21 @@ public class GraphQlSseHandler extends AbstractGraphQlHttpHandler { @Override protected void hookOnError(Throwable ex) { - if (ex instanceof SubscriptionPublisherException spe) { - ExecutionResult result = ExecutionResult.newExecutionResult().errors(spe.getErrors()).build(); - writeResult(result.toSpecification()); - hookOnComplete(); - } - else { - this.sseBuilder.error(ex); - } + sendNext(exceptionToResultMap(ex)); + sendComplete(); } - @Override - protected void hookOnComplete() { + private static Map exceptionToResultMap(Throwable ex) { + return ((ex instanceof SubscriptionPublisherException spe) ? + spe.toMap() : + GraphqlErrorBuilder.newError() + .message("Subscription error") + .errorType(org.springframework.graphql.execution.ErrorType.INTERNAL_ERROR) + .build() + .toSpecification()); + } + + private void sendComplete() { try { this.sseBuilder.event("complete").data(""); } @@ -160,6 +164,11 @@ public class GraphQlSseHandler extends AbstractGraphQlHttpHandler { this.sseBuilder.complete(); } + @Override + protected void hookOnComplete() { + sendComplete(); + } + static Consumer connect(Flux> resultFlux) { return (sseBuilder) -> { SseSubscriber subscriber = new SseSubscriber(sseBuilder); diff --git a/spring-graphql/src/test/java/org/springframework/graphql/server/webflux/GraphQlSseHandlerTests.java b/spring-graphql/src/test/java/org/springframework/graphql/server/webflux/GraphQlSseHandlerTests.java index ff0fcce6..60cb2394 100644 --- a/spring-graphql/src/test/java/org/springframework/graphql/server/webflux/GraphQlSseHandlerTests.java +++ b/spring-graphql/src/test/java/org/springframework/graphql/server/webflux/GraphQlSseHandlerTests.java @@ -55,7 +55,8 @@ class GraphQlSseHandlerTests { private static final DataFetcher SEARCH_DATA_FETCHER = env -> { String author = env.getArgument("author"); - return Flux.fromIterable(BookSource.books()).filter((book) -> book.getAuthor().getFullName().contains(author)); + return Flux.fromIterable(BookSource.books()) + .filter((book) -> book.getAuthor().getFullName().contains(author)); }; private final MockServerHttpRequest httpRequest = MockServerHttpRequest.post("/graphql") diff --git a/spring-graphql/src/test/java/org/springframework/graphql/server/webmvc/GraphQlSseHandlerTests.java b/spring-graphql/src/test/java/org/springframework/graphql/server/webmvc/GraphQlSseHandlerTests.java index f870df10..ecaf18dd 100644 --- a/spring-graphql/src/test/java/org/springframework/graphql/server/webmvc/GraphQlSseHandlerTests.java +++ b/spring-graphql/src/test/java/org/springframework/graphql/server/webmvc/GraphQlSseHandlerTests.java @@ -56,6 +56,7 @@ import static org.mockito.Mockito.mock; * * @author Brian Clozel */ +@SuppressWarnings("ReactiveStreamsUnusedPublisher") class GraphQlSseHandlerTests { private static final List> MESSAGE_READERS = @@ -92,7 +93,7 @@ class GraphQlSseHandlerTests { void shouldWriteMultipleEventsForSubscription() throws Exception { GraphQlSseHandler handler = createSseHandler(SEARCH_DATA_FETCHER); MockHttpServletRequest request = createServletRequest(""" - { "query": "subscription TestSubscription { bookSearch(author:\\\"Orwell\\\") { id name } }" } + { "query": "subscription TestSubscription { bookSearch(author:\\"Orwell\\") { id name } }" } """); MockHttpServletResponse response = handleAndAwait(request, handler); @@ -118,7 +119,7 @@ class GraphQlSseHandlerTests { GraphQlSseHandler handler = createSseHandler(errorDataFetcher); MockHttpServletRequest request = createServletRequest(""" - { "query": "subscription TestSubscription { bookSearch(author:\\\"Orwell\\\") { id name } }" } + { "query": "subscription TestSubscription { bookSearch(author:\\"Orwell\\") { id name } }" } """); MockHttpServletResponse response = handleAndAwait(request, handler); @@ -140,7 +141,7 @@ class GraphQlSseHandlerTests { void shouldCancelDataFetcherPublisherWhenWritingFails() throws Exception { GraphQlSseHandler handler = createSseHandler(SEARCH_DATA_FETCHER); MockHttpServletRequest servletRequest = createServletRequest(""" - { "query": "subscription TestSubscription { bookSearch(author:\\\"Orwell\\\") { id name } }" } + { "query": "subscription TestSubscription { bookSearch(author:\\"Orwell\\") { id name } }" } """); HttpServletResponse servletResponse = mock(HttpServletResponse.class); ServletOutputStream outputStream = mock(ServletOutputStream.class); @@ -165,7 +166,7 @@ class GraphQlSseHandlerTests { GraphQlSseHandler handler = createSseHandler(errorDataFetcher); MockHttpServletRequest servletRequest = createServletRequest(""" - { "query": "subscription TestSubscription { bookSearch(author:\\\"Orwell\\\") { id name } }" } + { "query": "subscription TestSubscription { bookSearch(author:\\"Orwell\\") { id name } }" } """); MockHttpServletResponse servletResponse = handleRequest(servletRequest, handler);