SSE handlers support keep-alive

Closes gh-1048
This commit is contained in:
rstoyanchev
2025-03-14 16:57:08 +00:00
parent a5ec819748
commit 44ca9b0da0
4 changed files with 235 additions and 39 deletions

View File

@@ -1,5 +1,5 @@
/*
* Copyright 2020-2024 the original author or authors.
* Copyright 2020-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -51,31 +51,54 @@ import org.springframework.web.reactive.function.server.ServerResponse;
*/
public class GraphQlSseHandler extends AbstractGraphQlHttpHandler {
private static final Mono<ServerSentEvent<Map<String, Object>>> COMPLETE_EVENT = Mono.just(
ServerSentEvent.<Map<String, Object>>builder(Collections.emptyMap()).event("complete").build());
private static final ServerSentEvent<Map<String, Object>> HEARTBEAT_EVENT =
ServerSentEvent.<Map<String, Object>>builder().comment("").build();
private static final Mono<ServerSentEvent<Map<String, Object>>> COMPLETE_EVENT_MONO =
Mono.just(ServerSentEvent.<Map<String, Object>>builder(Collections.emptyMap()).event("complete").build());
@Nullable
private final Duration timeout;
@Nullable
private final Duration keepAliveDuration;
/**
* Constructor with the handler to delegate to, and no timeout by default,
* Basic constructor with the handler to delegate to, and no timeout by default,
* which results in never timing out.
* @param graphQlHandler the handler to delegate to
*/
public GraphQlSseHandler(WebGraphQlHandler graphQlHandler) {
this(graphQlHandler, null);
this(graphQlHandler, null, null);
}
/**
* Variant constructor with a timeout to use for SSE subscriptions.
* Constructor with a timeout on how long to wait for the application to return
* the {@link ServerResponse} that will start the stream.
* @param graphQlHandler the handler to delegate to
* @param timeout the timeout value to use or {@code null} to never time out
* @param timeout the timeout value, or {@code null} to never time out
* @since 1.3.3
*/
public GraphQlSseHandler(WebGraphQlHandler graphQlHandler, @Nullable Duration timeout) {
this(graphQlHandler, null, null);
}
/**
* Constructor with a keep-alive duration that determines how frequently to
* heartbeats during periods of inactivity.
* @param graphQlHandler the handler to delegate to
* @param timeout the timeout value to use or {@code null} to never time out
* @param keepAliveDuration how frequently to send empty comment messages
* when no other messages are sent
* @since 1.4.0
*/
public GraphQlSseHandler(
WebGraphQlHandler graphQlHandler, @Nullable Duration timeout, @Nullable Duration keepAliveDuration) {
super(graphQlHandler, null);
this.timeout = timeout;
this.keepAliveDuration = keepAliveDuration;
}
@@ -104,7 +127,12 @@ public class GraphQlSseHandler extends AbstractGraphQlHttpHandler {
Flux<ServerSentEvent<Map<String, Object>>> sseFlux =
resultFlux.map((event) -> ServerSentEvent.builder(event).event("next").build())
.concatWith(COMPLETE_EVENT);
.concatWith(COMPLETE_EVENT_MONO);
if (this.keepAliveDuration != null) {
KeepAliveHandler handler = new KeepAliveHandler(this.keepAliveDuration);
sseFlux = handler.compose(sseFlux);
}
Mono<ServerResponse> responseMono = ServerResponse.ok()
.contentType(MediaType.TEXT_EVENT_STREAM)
@@ -124,4 +152,34 @@ public class GraphQlSseHandler extends AbstractGraphQlHttpHandler {
.toSpecification());
}
private static final class KeepAliveHandler {
private final Duration keepAliveDuration;
private boolean eventSent;
KeepAliveHandler(Duration keepAliveDuration) {
this.keepAliveDuration = keepAliveDuration;
}
public Flux<ServerSentEvent<Map<String, Object>>> compose(Flux<ServerSentEvent<Map<String, Object>>> flux) {
return flux.doOnNext((event) -> this.eventSent = true)
.mergeWith(getKeepAliveFlux())
.takeUntil((sse) -> "complete".equals(sse.event()));
}
private Flux<ServerSentEvent<Map<String, Object>>> getKeepAliveFlux() {
return Flux.interval(this.keepAliveDuration, this.keepAliveDuration)
.filter((aLong) -> !checkEventSentAndClear())
.map((aLong) -> HEARTBEAT_EVENT);
}
private boolean checkEventSentAndClear() {
boolean result = this.eventSent;
this.eventSent = false;
return result;
}
}
}

View File

@@ -1,5 +1,5 @@
/*
* Copyright 2020-2024 the original author or authors.
* Copyright 2020-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -18,6 +18,7 @@ package org.springframework.graphql.server.webmvc;
import java.io.IOException;
import java.time.Duration;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.function.Consumer;
@@ -29,6 +30,7 @@ import org.reactivestreams.Publisher;
import reactor.core.publisher.BaseSubscriber;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.core.publisher.Sinks;
import org.springframework.graphql.execution.SubscriptionPublisherException;
import org.springframework.graphql.server.WebGraphQlHandler;
@@ -50,9 +52,15 @@ import org.springframework.web.servlet.function.ServerResponse;
*/
public class GraphQlSseHandler extends AbstractGraphQlHttpHandler {
private static final Map<String, Object> HEARTBEAT_MAP = new LinkedHashMap<>(0);
@Nullable
private final Duration timeout;
@Nullable
private final Duration keepAliveDuration;
/**
* Constructor with the handler to delegate to, and no timeout,
@@ -60,7 +68,7 @@ public class GraphQlSseHandler extends AbstractGraphQlHttpHandler {
* @param graphQlHandler the handler to delegate to
*/
public GraphQlSseHandler(WebGraphQlHandler graphQlHandler) {
this(graphQlHandler, null);
this(graphQlHandler, null, null);
}
/**
@@ -71,8 +79,24 @@ public class GraphQlSseHandler extends AbstractGraphQlHttpHandler {
* @since 1.3.3
*/
public GraphQlSseHandler(WebGraphQlHandler graphQlHandler, @Nullable Duration timeout) {
this(graphQlHandler, timeout, null);
}
/**
* Variant constructor with a timeout to use for SSE subscriptions.
* @param graphQlHandler the handler to delegate to
* @param timeout the timeout value to set on
* @param keepAliveDuration how frequently to send empty comment messages
* when no other messages are sent
* {@link org.springframework.web.context.request.async.AsyncWebRequest#setTimeout(Long)}
* @since 1.4.0
*/
public GraphQlSseHandler(
WebGraphQlHandler graphQlHandler, @Nullable Duration timeout, @Nullable Duration keepAliveDuration) {
super(graphQlHandler, null);
this.timeout = timeout;
this.keepAliveDuration = keepAliveDuration;
}
@@ -101,8 +125,8 @@ public class GraphQlSseHandler extends AbstractGraphQlHttpHandler {
});
return ((this.timeout != null) ?
ServerResponse.sse(SseSubscriber.connect(resultFlux), this.timeout) :
ServerResponse.sse(SseSubscriber.connect(resultFlux)));
ServerResponse.sse(SseSubscriber.connect(resultFlux, this.keepAliveDuration), this.timeout) :
ServerResponse.sse(SseSubscriber.connect(resultFlux, this.keepAliveDuration)));
}
@@ -120,6 +144,10 @@ public class GraphQlSseHandler extends AbstractGraphQlHttpHandler {
@Override
protected void hookOnNext(Map<String, Object> value) {
if (value == HEARTBEAT_MAP) {
sendHeartbeat();
return;
}
sendNext(value);
}
@@ -133,6 +161,18 @@ public class GraphQlSseHandler extends AbstractGraphQlHttpHandler {
}
}
private void sendHeartbeat() {
try {
// Currently, comment cannot be empty:
// https://github.com/spring-projects/spring-framework/issues/34608
this.sseBuilder.comment(" ");
this.sseBuilder.send();
}
catch (IOException exception) {
cancelWithError(exception);
}
}
private void cancelWithError(Throwable ex) {
this.cancel();
this.sseBuilder.error(ex);
@@ -169,12 +209,53 @@ public class GraphQlSseHandler extends AbstractGraphQlHttpHandler {
sendComplete();
}
static Consumer<ServerResponse.SseBuilder> connect(Flux<Map<String, Object>> resultFlux) {
static Consumer<ServerResponse.SseBuilder> connect(
Flux<Map<String, Object>> resultFlux, @Nullable Duration keepAliveDuration) {
return (sseBuilder) -> {
SseSubscriber subscriber = new SseSubscriber(sseBuilder);
resultFlux.subscribe(subscriber);
if (keepAliveDuration != null) {
KeepAliveHandler handler = new KeepAliveHandler(keepAliveDuration);
handler.compose(resultFlux).subscribe(subscriber);
}
else {
resultFlux.subscribe(subscriber);
}
};
}
}
private static final class KeepAliveHandler {
private final Duration keepAliveDuration;
private boolean eventSent;
private final Sinks.Empty<Void> completionSink = Sinks.empty();
KeepAliveHandler(Duration keepAliveDuration) {
this.keepAliveDuration = keepAliveDuration;
}
public Flux<Map<String, Object>> compose(Flux<Map<String, Object>> flux) {
return flux.doOnNext((event) -> this.eventSent = true)
.doOnComplete(this.completionSink::tryEmitEmpty)
.mergeWith(getKeepAliveFlux())
.takeUntilOther(this.completionSink.asMono());
}
private Flux<Map<String, Object>> getKeepAliveFlux() {
return Flux.interval(this.keepAliveDuration, this.keepAliveDuration)
.filter((aLong) -> !checkEventSentAndClear())
.map((aLong) -> HEARTBEAT_MAP);
}
private boolean checkEventSentAndClear() {
boolean result = this.eventSent;
this.eventSent = false;
return result;
}
}
}

View File

@@ -1,5 +1,5 @@
/*
* Copyright 2020-2024 the original author or authors.
* Copyright 2020-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -17,6 +17,7 @@
package org.springframework.graphql.server.webflux;
import java.time.Duration;
import java.util.Collections;
import java.util.List;
@@ -28,6 +29,7 @@ import reactor.core.publisher.Mono;
import org.springframework.graphql.BookSource;
import org.springframework.graphql.GraphQlRequest;
import org.springframework.graphql.GraphQlSetup;
import org.springframework.graphql.server.WebGraphQlHandler;
import org.springframework.graphql.server.support.SerializableGraphQlRequest;
import org.springframework.http.MediaType;
import org.springframework.http.codec.HttpMessageWriter;
@@ -67,7 +69,7 @@ class GraphQlSseHandlerTests {
@Test
void shouldRejectQueryOperations() {
SerializableGraphQlRequest request = initRequest("{ bookById(id: 42) {name} }");
GraphQlSseHandler handler = createHandler(SEARCH_DATA_FETCHER);
GraphQlSseHandler handler = createSseHandler(SEARCH_DATA_FETCHER);
MockServerHttpResponse response = handleRequest(this.httpRequest, handler, request);
assertThat(response.getHeaders().getContentType().isCompatibleWith(MediaType.TEXT_EVENT_STREAM)).isTrue();
@@ -87,7 +89,7 @@ class GraphQlSseHandlerTests {
SerializableGraphQlRequest request = initRequest(
"subscription TestSubscription { bookSearch(author:\"Orwell\") { id name } }");
GraphQlSseHandler handler = createHandler(SEARCH_DATA_FETCHER);
GraphQlSseHandler handler = createSseHandler(SEARCH_DATA_FETCHER);
MockServerHttpResponse response = handleRequest(this.httpRequest, handler, request);
assertThat(response.getHeaders().getContentType().isCompatibleWith(MediaType.TEXT_EVENT_STREAM)).isTrue();
@@ -113,7 +115,7 @@ class GraphQlSseHandlerTests {
DataFetcher<?> errorDataFetcher = env ->
Flux.just(BookSource.getBook(1L)).concatWith(Flux.error(new IllegalStateException("test error")));
GraphQlSseHandler handler = createHandler(errorDataFetcher);
GraphQlSseHandler handler = createSseHandler(errorDataFetcher);
MockServerHttpResponse response = handleRequest(this.httpRequest, handler, request);
assertThat(response.getHeaders().getContentType().isCompatibleWith(MediaType.TEXT_EVENT_STREAM)).isTrue();
@@ -130,12 +132,40 @@ class GraphQlSseHandlerTests {
""");
}
private GraphQlSseHandler createHandler(DataFetcher<?> subscriptionDataFetcher) {
return new GraphQlSseHandler(
GraphQlSetup.schemaResource(BookSource.schema)
.queryFetcher("bookById", (env) -> BookSource.getBookWithoutAuthor(1L))
.subscriptionFetcher("bookSearch", subscriptionDataFetcher)
.toWebGraphQlHandler());
@Test
void shouldSendKeepAlivePings() {
SerializableGraphQlRequest request = initRequest(
"subscription TestSubscription { bookSearch(author:\"Orwell\") { id name } }");
WebGraphQlHandler webGraphQlHandler = createWebGraphQlHandler(env -> Mono.delay(Duration.ofMillis(50)).then());
GraphQlSseHandler handler = new GraphQlSseHandler(webGraphQlHandler, null, Duration.ofMillis(10));
assertThat(handleRequest(this.httpRequest, handler, request).getBodyAsString().block())
.startsWith("""
:
:
""")
.endsWith("""
:
event:complete
data:{}
""");
}
private GraphQlSseHandler createSseHandler(DataFetcher<?> subscriptionDataFetcher) {
WebGraphQlHandler webGraphQlHandler = createWebGraphQlHandler(subscriptionDataFetcher);
return new GraphQlSseHandler(webGraphQlHandler);
}
private static WebGraphQlHandler createWebGraphQlHandler(DataFetcher<?> subscriptionDataFetcher) {
return GraphQlSetup.schemaResource(BookSource.schema)
.queryFetcher("bookById", (env) -> BookSource.getBookWithoutAuthor(1L))
.subscriptionFetcher("bookSearch", subscriptionDataFetcher)
.toWebGraphQlHandler();
}
private static SerializableGraphQlRequest initRequest(String document) {

View File

@@ -1,5 +1,5 @@
/*
* Copyright 2020-2024 the original author or authors.
* Copyright 2020-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -31,9 +31,11 @@ import jakarta.servlet.ServletOutputStream;
import jakarta.servlet.http.HttpServletResponse;
import org.junit.jupiter.api.Test;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import org.springframework.graphql.BookSource;
import org.springframework.graphql.GraphQlSetup;
import org.springframework.graphql.server.WebGraphQlHandler;
import org.springframework.http.MediaType;
import org.springframework.http.converter.HttpMessageConverter;
import org.springframework.http.converter.json.MappingJackson2HttpMessageConverter;
@@ -71,6 +73,10 @@ class GraphQlSseHandlerTests {
.doOnCancel(() -> DATA_FETCHER_CANCELLED.set(true));
};
private static final String BOOK_SEARCH_REQUEST = """
{ "query": "subscription TestSubscription { bookSearch(author:\\"Orwell\\") { id name } }" }
""";
@Test
void shouldRejectQueryOperations() throws Exception {
@@ -92,9 +98,7 @@ class GraphQlSseHandlerTests {
@Test
void shouldWriteMultipleEventsForSubscription() throws Exception {
GraphQlSseHandler handler = createSseHandler(SEARCH_DATA_FETCHER);
MockHttpServletRequest request = createServletRequest("""
{ "query": "subscription TestSubscription { bookSearch(author:\\"Orwell\\") { id name } }" }
""");
MockHttpServletRequest request = createServletRequest(BOOK_SEARCH_REQUEST);
MockHttpServletResponse response = handleAndAwait(request, handler);
assertThat(response.getContentType()).isEqualTo(MediaType.TEXT_EVENT_STREAM_VALUE);
@@ -111,6 +115,31 @@ class GraphQlSseHandlerTests {
""");
}
@Test
void shouldSendKeepAlivePings() throws Exception {
WebGraphQlHandler webGraphQlHandler = createWebGraphQlHandler(env -> Mono.delay(Duration.ofMillis(50)).then());
GraphQlSseHandler handler = new GraphQlSseHandler(webGraphQlHandler, null, Duration.ofMillis(10));
MockHttpServletRequest request = createServletRequest(BOOK_SEARCH_REQUEST);
MockHttpServletResponse response = handleRequest(request, handler);
await().atMost(Duration.ofSeconds(1)).until(() -> response.getContentAsString().contains("complete"));
assertThat(response.getContentAsString())
.startsWith("""
:\s
:\s
""")
.endsWith("""
:\s
event:complete
data:
""");
}
@Test
void shouldWriteEventsAndTerminalError() throws Exception {
@@ -118,9 +147,7 @@ class GraphQlSseHandlerTests {
.concatWith(Flux.error(new IllegalStateException("test error")));
GraphQlSseHandler handler = createSseHandler(errorDataFetcher);
MockHttpServletRequest request = createServletRequest("""
{ "query": "subscription TestSubscription { bookSearch(author:\\"Orwell\\") { id name } }" }
""");
MockHttpServletRequest request = createServletRequest(BOOK_SEARCH_REQUEST);
MockHttpServletResponse response = handleAndAwait(request, handler);
assertThat(response.getContentType()).isEqualTo(MediaType.TEXT_EVENT_STREAM_VALUE);
@@ -140,9 +167,7 @@ class GraphQlSseHandlerTests {
@Test
void shouldCancelDataFetcherPublisherWhenWritingFails() throws Exception {
GraphQlSseHandler handler = createSseHandler(SEARCH_DATA_FETCHER);
MockHttpServletRequest servletRequest = createServletRequest("""
{ "query": "subscription TestSubscription { bookSearch(author:\\"Orwell\\") { id name } }" }
""");
MockHttpServletRequest servletRequest = createServletRequest(BOOK_SEARCH_REQUEST);
HttpServletResponse servletResponse = mock(HttpServletResponse.class);
ServletOutputStream outputStream = mock(ServletOutputStream.class);
@@ -165,9 +190,7 @@ class GraphQlSseHandlerTests {
.delayElements(Duration.ofMillis(500)).doOnCancel(() -> DATA_FETCHER_CANCELLED.set(true));
GraphQlSseHandler handler = createSseHandler(errorDataFetcher);
MockHttpServletRequest servletRequest = createServletRequest("""
{ "query": "subscription TestSubscription { bookSearch(author:\\"Orwell\\") { id name } }" }
""");
MockHttpServletRequest servletRequest = createServletRequest(BOOK_SEARCH_REQUEST);
MockHttpServletResponse servletResponse = handleRequest(servletRequest, handler);
for (AsyncListener listener : ((MockAsyncContext) servletRequest.getAsyncContext()).getListeners()) {
@@ -180,10 +203,14 @@ class GraphQlSseHandlerTests {
}
private GraphQlSseHandler createSseHandler(DataFetcher<?> dataFetcher) {
return new GraphQlSseHandler(GraphQlSetup.schemaResource(BookSource.schema)
return new GraphQlSseHandler(createWebGraphQlHandler(dataFetcher));
}
private static WebGraphQlHandler createWebGraphQlHandler(DataFetcher<?> dataFetcher) {
return GraphQlSetup.schemaResource(BookSource.schema)
.queryFetcher("bookById", (env) -> BookSource.getBookWithoutAuthor(1L))
.subscriptionFetcher("bookSearch", dataFetcher)
.toWebGraphQlHandler());
.toWebGraphQlHandler();
}
private MockHttpServletRequest createServletRequest(String query) {