diff --git a/spring-graphql/src/main/java/org/springframework/graphql/server/webflux/GraphQlSseHandler.java b/spring-graphql/src/main/java/org/springframework/graphql/server/webflux/GraphQlSseHandler.java index 1069d353..cb3bd1e3 100644 --- a/spring-graphql/src/main/java/org/springframework/graphql/server/webflux/GraphQlSseHandler.java +++ b/spring-graphql/src/main/java/org/springframework/graphql/server/webflux/GraphQlSseHandler.java @@ -1,5 +1,5 @@ /* - * Copyright 2020-2024 the original author or authors. + * Copyright 2020-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -51,31 +51,54 @@ import org.springframework.web.reactive.function.server.ServerResponse; */ public class GraphQlSseHandler extends AbstractGraphQlHttpHandler { - private static final Mono>> COMPLETE_EVENT = Mono.just( - ServerSentEvent.>builder(Collections.emptyMap()).event("complete").build()); + private static final ServerSentEvent> HEARTBEAT_EVENT = + ServerSentEvent.>builder().comment("").build(); + + private static final Mono>> COMPLETE_EVENT_MONO = + Mono.just(ServerSentEvent.>builder(Collections.emptyMap()).event("complete").build()); @Nullable private final Duration timeout; + @Nullable + private final Duration keepAliveDuration; + /** - * Constructor with the handler to delegate to, and no timeout by default, + * Basic constructor with the handler to delegate to, and no timeout by default, * which results in never timing out. * @param graphQlHandler the handler to delegate to */ public GraphQlSseHandler(WebGraphQlHandler graphQlHandler) { - this(graphQlHandler, null); + this(graphQlHandler, null, null); } /** - * Variant constructor with a timeout to use for SSE subscriptions. + * Constructor with a timeout on how long to wait for the application to return + * the {@link ServerResponse} that will start the stream. * @param graphQlHandler the handler to delegate to - * @param timeout the timeout value to use or {@code null} to never time out + * @param timeout the timeout value, or {@code null} to never time out * @since 1.3.3 */ public GraphQlSseHandler(WebGraphQlHandler graphQlHandler, @Nullable Duration timeout) { + this(graphQlHandler, null, null); + } + + /** + * Constructor with a keep-alive duration that determines how frequently to + * heartbeats during periods of inactivity. + * @param graphQlHandler the handler to delegate to + * @param timeout the timeout value to use or {@code null} to never time out + * @param keepAliveDuration how frequently to send empty comment messages + * when no other messages are sent + * @since 1.4.0 + */ + public GraphQlSseHandler( + WebGraphQlHandler graphQlHandler, @Nullable Duration timeout, @Nullable Duration keepAliveDuration) { + super(graphQlHandler, null); this.timeout = timeout; + this.keepAliveDuration = keepAliveDuration; } @@ -104,7 +127,12 @@ public class GraphQlSseHandler extends AbstractGraphQlHttpHandler { Flux>> sseFlux = resultFlux.map((event) -> ServerSentEvent.builder(event).event("next").build()) - .concatWith(COMPLETE_EVENT); + .concatWith(COMPLETE_EVENT_MONO); + + if (this.keepAliveDuration != null) { + KeepAliveHandler handler = new KeepAliveHandler(this.keepAliveDuration); + sseFlux = handler.compose(sseFlux); + } Mono responseMono = ServerResponse.ok() .contentType(MediaType.TEXT_EVENT_STREAM) @@ -124,4 +152,34 @@ public class GraphQlSseHandler extends AbstractGraphQlHttpHandler { .toSpecification()); } + + private static final class KeepAliveHandler { + + private final Duration keepAliveDuration; + + private boolean eventSent; + + KeepAliveHandler(Duration keepAliveDuration) { + this.keepAliveDuration = keepAliveDuration; + } + + public Flux>> compose(Flux>> flux) { + return flux.doOnNext((event) -> this.eventSent = true) + .mergeWith(getKeepAliveFlux()) + .takeUntil((sse) -> "complete".equals(sse.event())); + } + + private Flux>> getKeepAliveFlux() { + return Flux.interval(this.keepAliveDuration, this.keepAliveDuration) + .filter((aLong) -> !checkEventSentAndClear()) + .map((aLong) -> HEARTBEAT_EVENT); + } + + private boolean checkEventSentAndClear() { + boolean result = this.eventSent; + this.eventSent = false; + return result; + } + } + } diff --git a/spring-graphql/src/main/java/org/springframework/graphql/server/webmvc/GraphQlSseHandler.java b/spring-graphql/src/main/java/org/springframework/graphql/server/webmvc/GraphQlSseHandler.java index 50e340e0..fde91864 100644 --- a/spring-graphql/src/main/java/org/springframework/graphql/server/webmvc/GraphQlSseHandler.java +++ b/spring-graphql/src/main/java/org/springframework/graphql/server/webmvc/GraphQlSseHandler.java @@ -1,5 +1,5 @@ /* - * Copyright 2020-2024 the original author or authors. + * Copyright 2020-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -18,6 +18,7 @@ package org.springframework.graphql.server.webmvc; import java.io.IOException; import java.time.Duration; +import java.util.LinkedHashMap; import java.util.Map; import java.util.function.Consumer; @@ -29,6 +30,7 @@ import org.reactivestreams.Publisher; import reactor.core.publisher.BaseSubscriber; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; +import reactor.core.publisher.Sinks; import org.springframework.graphql.execution.SubscriptionPublisherException; import org.springframework.graphql.server.WebGraphQlHandler; @@ -50,9 +52,15 @@ import org.springframework.web.servlet.function.ServerResponse; */ public class GraphQlSseHandler extends AbstractGraphQlHttpHandler { + private static final Map HEARTBEAT_MAP = new LinkedHashMap<>(0); + + @Nullable private final Duration timeout; + @Nullable + private final Duration keepAliveDuration; + /** * Constructor with the handler to delegate to, and no timeout, @@ -60,7 +68,7 @@ public class GraphQlSseHandler extends AbstractGraphQlHttpHandler { * @param graphQlHandler the handler to delegate to */ public GraphQlSseHandler(WebGraphQlHandler graphQlHandler) { - this(graphQlHandler, null); + this(graphQlHandler, null, null); } /** @@ -71,8 +79,24 @@ public class GraphQlSseHandler extends AbstractGraphQlHttpHandler { * @since 1.3.3 */ public GraphQlSseHandler(WebGraphQlHandler graphQlHandler, @Nullable Duration timeout) { + this(graphQlHandler, timeout, null); + } + + /** + * Variant constructor with a timeout to use for SSE subscriptions. + * @param graphQlHandler the handler to delegate to + * @param timeout the timeout value to set on + * @param keepAliveDuration how frequently to send empty comment messages + * when no other messages are sent + * {@link org.springframework.web.context.request.async.AsyncWebRequest#setTimeout(Long)} + * @since 1.4.0 + */ + public GraphQlSseHandler( + WebGraphQlHandler graphQlHandler, @Nullable Duration timeout, @Nullable Duration keepAliveDuration) { + super(graphQlHandler, null); this.timeout = timeout; + this.keepAliveDuration = keepAliveDuration; } @@ -101,8 +125,8 @@ public class GraphQlSseHandler extends AbstractGraphQlHttpHandler { }); return ((this.timeout != null) ? - ServerResponse.sse(SseSubscriber.connect(resultFlux), this.timeout) : - ServerResponse.sse(SseSubscriber.connect(resultFlux))); + ServerResponse.sse(SseSubscriber.connect(resultFlux, this.keepAliveDuration), this.timeout) : + ServerResponse.sse(SseSubscriber.connect(resultFlux, this.keepAliveDuration))); } @@ -120,6 +144,10 @@ public class GraphQlSseHandler extends AbstractGraphQlHttpHandler { @Override protected void hookOnNext(Map value) { + if (value == HEARTBEAT_MAP) { + sendHeartbeat(); + return; + } sendNext(value); } @@ -133,6 +161,18 @@ public class GraphQlSseHandler extends AbstractGraphQlHttpHandler { } } + private void sendHeartbeat() { + try { + // Currently, comment cannot be empty: + // https://github.com/spring-projects/spring-framework/issues/34608 + this.sseBuilder.comment(" "); + this.sseBuilder.send(); + } + catch (IOException exception) { + cancelWithError(exception); + } + } + private void cancelWithError(Throwable ex) { this.cancel(); this.sseBuilder.error(ex); @@ -169,12 +209,53 @@ public class GraphQlSseHandler extends AbstractGraphQlHttpHandler { sendComplete(); } - static Consumer connect(Flux> resultFlux) { + static Consumer connect( + Flux> resultFlux, @Nullable Duration keepAliveDuration) { + return (sseBuilder) -> { SseSubscriber subscriber = new SseSubscriber(sseBuilder); - resultFlux.subscribe(subscriber); + if (keepAliveDuration != null) { + KeepAliveHandler handler = new KeepAliveHandler(keepAliveDuration); + handler.compose(resultFlux).subscribe(subscriber); + } + else { + resultFlux.subscribe(subscriber); + } }; } } + + private static final class KeepAliveHandler { + + private final Duration keepAliveDuration; + + private boolean eventSent; + + private final Sinks.Empty completionSink = Sinks.empty(); + + KeepAliveHandler(Duration keepAliveDuration) { + this.keepAliveDuration = keepAliveDuration; + } + + public Flux> compose(Flux> flux) { + return flux.doOnNext((event) -> this.eventSent = true) + .doOnComplete(this.completionSink::tryEmitEmpty) + .mergeWith(getKeepAliveFlux()) + .takeUntilOther(this.completionSink.asMono()); + } + + private Flux> getKeepAliveFlux() { + return Flux.interval(this.keepAliveDuration, this.keepAliveDuration) + .filter((aLong) -> !checkEventSentAndClear()) + .map((aLong) -> HEARTBEAT_MAP); + } + + private boolean checkEventSentAndClear() { + boolean result = this.eventSent; + this.eventSent = false; + return result; + } + } + } diff --git a/spring-graphql/src/test/java/org/springframework/graphql/server/webflux/GraphQlSseHandlerTests.java b/spring-graphql/src/test/java/org/springframework/graphql/server/webflux/GraphQlSseHandlerTests.java index 60cb2394..d7707e13 100644 --- a/spring-graphql/src/test/java/org/springframework/graphql/server/webflux/GraphQlSseHandlerTests.java +++ b/spring-graphql/src/test/java/org/springframework/graphql/server/webflux/GraphQlSseHandlerTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2020-2024 the original author or authors. + * Copyright 2020-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,6 +17,7 @@ package org.springframework.graphql.server.webflux; +import java.time.Duration; import java.util.Collections; import java.util.List; @@ -28,6 +29,7 @@ import reactor.core.publisher.Mono; import org.springframework.graphql.BookSource; import org.springframework.graphql.GraphQlRequest; import org.springframework.graphql.GraphQlSetup; +import org.springframework.graphql.server.WebGraphQlHandler; import org.springframework.graphql.server.support.SerializableGraphQlRequest; import org.springframework.http.MediaType; import org.springframework.http.codec.HttpMessageWriter; @@ -67,7 +69,7 @@ class GraphQlSseHandlerTests { @Test void shouldRejectQueryOperations() { SerializableGraphQlRequest request = initRequest("{ bookById(id: 42) {name} }"); - GraphQlSseHandler handler = createHandler(SEARCH_DATA_FETCHER); + GraphQlSseHandler handler = createSseHandler(SEARCH_DATA_FETCHER); MockServerHttpResponse response = handleRequest(this.httpRequest, handler, request); assertThat(response.getHeaders().getContentType().isCompatibleWith(MediaType.TEXT_EVENT_STREAM)).isTrue(); @@ -87,7 +89,7 @@ class GraphQlSseHandlerTests { SerializableGraphQlRequest request = initRequest( "subscription TestSubscription { bookSearch(author:\"Orwell\") { id name } }"); - GraphQlSseHandler handler = createHandler(SEARCH_DATA_FETCHER); + GraphQlSseHandler handler = createSseHandler(SEARCH_DATA_FETCHER); MockServerHttpResponse response = handleRequest(this.httpRequest, handler, request); assertThat(response.getHeaders().getContentType().isCompatibleWith(MediaType.TEXT_EVENT_STREAM)).isTrue(); @@ -113,7 +115,7 @@ class GraphQlSseHandlerTests { DataFetcher errorDataFetcher = env -> Flux.just(BookSource.getBook(1L)).concatWith(Flux.error(new IllegalStateException("test error"))); - GraphQlSseHandler handler = createHandler(errorDataFetcher); + GraphQlSseHandler handler = createSseHandler(errorDataFetcher); MockServerHttpResponse response = handleRequest(this.httpRequest, handler, request); assertThat(response.getHeaders().getContentType().isCompatibleWith(MediaType.TEXT_EVENT_STREAM)).isTrue(); @@ -130,12 +132,40 @@ class GraphQlSseHandlerTests { """); } - private GraphQlSseHandler createHandler(DataFetcher subscriptionDataFetcher) { - return new GraphQlSseHandler( - GraphQlSetup.schemaResource(BookSource.schema) - .queryFetcher("bookById", (env) -> BookSource.getBookWithoutAuthor(1L)) - .subscriptionFetcher("bookSearch", subscriptionDataFetcher) - .toWebGraphQlHandler()); + @Test + void shouldSendKeepAlivePings() { + SerializableGraphQlRequest request = initRequest( + "subscription TestSubscription { bookSearch(author:\"Orwell\") { id name } }"); + + WebGraphQlHandler webGraphQlHandler = createWebGraphQlHandler(env -> Mono.delay(Duration.ofMillis(50)).then()); + GraphQlSseHandler handler = new GraphQlSseHandler(webGraphQlHandler, null, Duration.ofMillis(10)); + + assertThat(handleRequest(this.httpRequest, handler, request).getBodyAsString().block()) + .startsWith(""" + : + + : + + """) + .endsWith(""" + : + + event:complete + data:{} + + """); + } + + private GraphQlSseHandler createSseHandler(DataFetcher subscriptionDataFetcher) { + WebGraphQlHandler webGraphQlHandler = createWebGraphQlHandler(subscriptionDataFetcher); + return new GraphQlSseHandler(webGraphQlHandler); + } + + private static WebGraphQlHandler createWebGraphQlHandler(DataFetcher subscriptionDataFetcher) { + return GraphQlSetup.schemaResource(BookSource.schema) + .queryFetcher("bookById", (env) -> BookSource.getBookWithoutAuthor(1L)) + .subscriptionFetcher("bookSearch", subscriptionDataFetcher) + .toWebGraphQlHandler(); } private static SerializableGraphQlRequest initRequest(String document) { diff --git a/spring-graphql/src/test/java/org/springframework/graphql/server/webmvc/GraphQlSseHandlerTests.java b/spring-graphql/src/test/java/org/springframework/graphql/server/webmvc/GraphQlSseHandlerTests.java index ecaf18dd..6225bb80 100644 --- a/spring-graphql/src/test/java/org/springframework/graphql/server/webmvc/GraphQlSseHandlerTests.java +++ b/spring-graphql/src/test/java/org/springframework/graphql/server/webmvc/GraphQlSseHandlerTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2020-2024 the original author or authors. + * Copyright 2020-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -31,9 +31,11 @@ import jakarta.servlet.ServletOutputStream; import jakarta.servlet.http.HttpServletResponse; import org.junit.jupiter.api.Test; import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; import org.springframework.graphql.BookSource; import org.springframework.graphql.GraphQlSetup; +import org.springframework.graphql.server.WebGraphQlHandler; import org.springframework.http.MediaType; import org.springframework.http.converter.HttpMessageConverter; import org.springframework.http.converter.json.MappingJackson2HttpMessageConverter; @@ -71,6 +73,10 @@ class GraphQlSseHandlerTests { .doOnCancel(() -> DATA_FETCHER_CANCELLED.set(true)); }; + private static final String BOOK_SEARCH_REQUEST = """ + { "query": "subscription TestSubscription { bookSearch(author:\\"Orwell\\") { id name } }" } + """; + @Test void shouldRejectQueryOperations() throws Exception { @@ -92,9 +98,7 @@ class GraphQlSseHandlerTests { @Test void shouldWriteMultipleEventsForSubscription() throws Exception { GraphQlSseHandler handler = createSseHandler(SEARCH_DATA_FETCHER); - MockHttpServletRequest request = createServletRequest(""" - { "query": "subscription TestSubscription { bookSearch(author:\\"Orwell\\") { id name } }" } - """); + MockHttpServletRequest request = createServletRequest(BOOK_SEARCH_REQUEST); MockHttpServletResponse response = handleAndAwait(request, handler); assertThat(response.getContentType()).isEqualTo(MediaType.TEXT_EVENT_STREAM_VALUE); @@ -111,6 +115,31 @@ class GraphQlSseHandlerTests { """); } + @Test + void shouldSendKeepAlivePings() throws Exception { + WebGraphQlHandler webGraphQlHandler = createWebGraphQlHandler(env -> Mono.delay(Duration.ofMillis(50)).then()); + GraphQlSseHandler handler = new GraphQlSseHandler(webGraphQlHandler, null, Duration.ofMillis(10)); + + MockHttpServletRequest request = createServletRequest(BOOK_SEARCH_REQUEST); + MockHttpServletResponse response = handleRequest(request, handler); + await().atMost(Duration.ofSeconds(1)).until(() -> response.getContentAsString().contains("complete")); + + assertThat(response.getContentAsString()) + .startsWith(""" + :\s + + :\s + + """) + .endsWith(""" + :\s + + event:complete + data: + + """); + } + @Test void shouldWriteEventsAndTerminalError() throws Exception { @@ -118,9 +147,7 @@ class GraphQlSseHandlerTests { .concatWith(Flux.error(new IllegalStateException("test error"))); GraphQlSseHandler handler = createSseHandler(errorDataFetcher); - MockHttpServletRequest request = createServletRequest(""" - { "query": "subscription TestSubscription { bookSearch(author:\\"Orwell\\") { id name } }" } - """); + MockHttpServletRequest request = createServletRequest(BOOK_SEARCH_REQUEST); MockHttpServletResponse response = handleAndAwait(request, handler); assertThat(response.getContentType()).isEqualTo(MediaType.TEXT_EVENT_STREAM_VALUE); @@ -140,9 +167,7 @@ class GraphQlSseHandlerTests { @Test void shouldCancelDataFetcherPublisherWhenWritingFails() throws Exception { GraphQlSseHandler handler = createSseHandler(SEARCH_DATA_FETCHER); - MockHttpServletRequest servletRequest = createServletRequest(""" - { "query": "subscription TestSubscription { bookSearch(author:\\"Orwell\\") { id name } }" } - """); + MockHttpServletRequest servletRequest = createServletRequest(BOOK_SEARCH_REQUEST); HttpServletResponse servletResponse = mock(HttpServletResponse.class); ServletOutputStream outputStream = mock(ServletOutputStream.class); @@ -165,9 +190,7 @@ class GraphQlSseHandlerTests { .delayElements(Duration.ofMillis(500)).doOnCancel(() -> DATA_FETCHER_CANCELLED.set(true)); GraphQlSseHandler handler = createSseHandler(errorDataFetcher); - MockHttpServletRequest servletRequest = createServletRequest(""" - { "query": "subscription TestSubscription { bookSearch(author:\\"Orwell\\") { id name } }" } - """); + MockHttpServletRequest servletRequest = createServletRequest(BOOK_SEARCH_REQUEST); MockHttpServletResponse servletResponse = handleRequest(servletRequest, handler); for (AsyncListener listener : ((MockAsyncContext) servletRequest.getAsyncContext()).getListeners()) { @@ -180,10 +203,14 @@ class GraphQlSseHandlerTests { } private GraphQlSseHandler createSseHandler(DataFetcher dataFetcher) { - return new GraphQlSseHandler(GraphQlSetup.schemaResource(BookSource.schema) + return new GraphQlSseHandler(createWebGraphQlHandler(dataFetcher)); + } + + private static WebGraphQlHandler createWebGraphQlHandler(DataFetcher dataFetcher) { + return GraphQlSetup.schemaResource(BookSource.schema) .queryFetcher("bookById", (env) -> BookSource.getBookWithoutAuthor(1L)) .subscriptionFetcher("bookSearch", dataFetcher) - .toWebGraphQlHandler()); + .toWebGraphQlHandler(); } private MockHttpServletRequest createServletRequest(String query) {