From 5c2164c5f2407f11ae1e63a3ce76a66661d6ea86 Mon Sep 17 00:00:00 2001 From: Brian Clozel Date: Tue, 11 Mar 2025 10:31:53 +0100 Subject: [PATCH] Propagate cancel signal from transport to data fetchers Prior to this commit, a WebSocket/SSE client disconnecting from the data stream would cause a CANCEL signal to be sent to upstream publishers. This signal would flow from the transport layer up to the `ExecutionGraphQlService`. Because the `GraphQL` engine itself relies on `CompletableFuture`, the CANCEL signal would not flow through and reactive data fetchers would not receive it. This means that costly reactive operations would not be cancelled and this could cause write failures as publishers would still produce values. This commit adds at the service level a Reactor `Sink` to the `GraphQLContext` that can be picked up by the `ContextDataFetcherDecorator` when decorating reactive data fetchers. This allows us to manually cancel publishers when the CANCEL signal is received at the transport level. Fixes gh-1149 --- .../graphql/ExecutionGraphQlRequest.java | 6 ++ .../ContextDataFetcherDecorator.java | 17 +++-- .../DefaultExecutionGraphQlService.java | 6 +- .../ContextDataFetcherDecoratorTests.java | 69 ++++++++++++++++++- .../DefaultExecutionGraphQlServiceTests.java | 20 ++++++ 5 files changed, 111 insertions(+), 7 deletions(-) diff --git a/spring-graphql/src/main/java/org/springframework/graphql/ExecutionGraphQlRequest.java b/spring-graphql/src/main/java/org/springframework/graphql/ExecutionGraphQlRequest.java index 25b20ed2..36f4e89a 100644 --- a/spring-graphql/src/main/java/org/springframework/graphql/ExecutionGraphQlRequest.java +++ b/spring-graphql/src/main/java/org/springframework/graphql/ExecutionGraphQlRequest.java @@ -36,6 +36,12 @@ import org.springframework.lang.Nullable; */ public interface ExecutionGraphQlRequest extends GraphQlRequest { + /** + * Key of the GraphQL context entry that holds a {@code Mono} that completes + * when the inbound GraphQL request is cancelled at the transport level. + */ + String CANCEL_PUBLISHER_CONTEXT_KEY = ExecutionGraphQlRequest.class.getName() + ".cancelled"; + /** * Return the transport assigned id for the request that in turn sets * {@link ExecutionInput.Builder#executionId(ExecutionId) executionId}. diff --git a/spring-graphql/src/main/java/org/springframework/graphql/execution/ContextDataFetcherDecorator.java b/spring-graphql/src/main/java/org/springframework/graphql/execution/ContextDataFetcherDecorator.java index 2e5daa1b..4b888a87 100644 --- a/spring-graphql/src/main/java/org/springframework/graphql/execution/ContextDataFetcherDecorator.java +++ b/spring-graphql/src/main/java/org/springframework/graphql/execution/ContextDataFetcherDecorator.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2024 the original author or authors. + * Copyright 2002-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -38,6 +38,7 @@ import org.reactivestreams.Publisher; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; +import org.springframework.graphql.ExecutionGraphQlRequest; import org.springframework.util.Assert; /** @@ -79,15 +80,15 @@ final class ContextDataFetcherDecorator implements DataFetcher { GraphQLContext graphQlContext = env.getGraphQlContext(); ContextSnapshotFactory snapshotFactory = ContextSnapshotFactoryHelper.getInstance(graphQlContext); - ContextSnapshot snapshot = (env.getLocalContext() instanceof GraphQLContext localContext) ? snapshotFactory.captureFrom(graphQlContext, localContext) : snapshotFactory.captureFrom(graphQlContext); + Mono cancelledRequest = graphQlContext.get(ExecutionGraphQlRequest.CANCEL_PUBLISHER_CONTEXT_KEY); Object value = snapshot.wrap(() -> this.delegate.get(env)).call(); if (this.subscription) { - return ReactiveAdapterRegistryHelper.toSubscriptionFlux(value) + Flux subscriptionResult = ReactiveAdapterRegistryHelper.toSubscriptionFlux(value) .onErrorResume((exception) -> { // Already handled, e.g. controller methods? if (exception instanceof SubscriptionPublisherException) { @@ -95,13 +96,19 @@ final class ContextDataFetcherDecorator implements DataFetcher { } return this.subscriptionExceptionResolver.resolveException(exception) .flatMap((errors) -> Mono.error(new SubscriptionPublisherException(errors, exception))); - }) - .contextWrite(snapshot::updateContext); + }); + if (cancelledRequest != null) { + subscriptionResult = subscriptionResult.takeUntilOther(cancelledRequest); + } + return subscriptionResult.contextWrite(snapshot::updateContext); } value = ReactiveAdapterRegistryHelper.toMonoIfReactive(value); if (value instanceof Mono mono) { + if (cancelledRequest != null) { + mono = mono.takeUntilOther(cancelledRequest); + } value = mono.contextWrite(snapshot::updateContext).toFuture(); } diff --git a/spring-graphql/src/main/java/org/springframework/graphql/execution/DefaultExecutionGraphQlService.java b/spring-graphql/src/main/java/org/springframework/graphql/execution/DefaultExecutionGraphQlService.java index 278af14f..94f9f2d0 100644 --- a/spring-graphql/src/main/java/org/springframework/graphql/execution/DefaultExecutionGraphQlService.java +++ b/spring-graphql/src/main/java/org/springframework/graphql/execution/DefaultExecutionGraphQlService.java @@ -31,6 +31,7 @@ import graphql.execution.instrumentation.dataloader.EmptyDataLoaderRegistryInsta import io.micrometer.context.ContextSnapshotFactory; import org.dataloader.DataLoaderRegistry; import reactor.core.publisher.Mono; +import reactor.core.publisher.Sinks; import org.springframework.graphql.ExecutionGraphQlRequest; import org.springframework.graphql.ExecutionGraphQlResponse; @@ -101,12 +102,15 @@ public class DefaultExecutionGraphQlService implements ExecutionGraphQlService { ContextSnapshotFactoryHelper.saveInstance(factory, graphQLContext); factory.captureFrom(contextView).updateContext(graphQLContext); + Sinks.Empty requestCancelled = Sinks.empty(); + graphQLContext.put(ExecutionGraphQlRequest.CANCEL_PUBLISHER_CONTEXT_KEY, requestCancelled.asMono()); ExecutionInput executionInputToUse = registerDataLoaders(executionInput); return Mono.fromFuture(this.graphQlSource.graphQl().executeAsync(executionInputToUse)) .onErrorResume((ex) -> ex instanceof GraphQLError, (ex) -> Mono.just(ExecutionResult.newExecutionResult().addError((GraphQLError) ex).build())) - .map((result) -> new DefaultExecutionGraphQlResponse(executionInputToUse, result)); + .map((result) -> new DefaultExecutionGraphQlResponse(executionInputToUse, result)) + .doOnCancel(requestCancelled::tryEmitEmpty); }); } diff --git a/spring-graphql/src/test/java/org/springframework/graphql/execution/ContextDataFetcherDecoratorTests.java b/spring-graphql/src/test/java/org/springframework/graphql/execution/ContextDataFetcherDecoratorTests.java index 698108ae..7a637b70 100644 --- a/spring-graphql/src/test/java/org/springframework/graphql/execution/ContextDataFetcherDecoratorTests.java +++ b/spring-graphql/src/test/java/org/springframework/graphql/execution/ContextDataFetcherDecoratorTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2024 the original author or authors. + * Copyright 2002-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -19,7 +19,9 @@ package org.springframework.graphql.execution; import java.time.Duration; import java.util.Collections; import java.util.List; +import java.util.Map; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.function.BiConsumer; import graphql.ExecutionInput; @@ -41,13 +43,16 @@ import io.micrometer.context.ContextSnapshotFactory; import org.junit.jupiter.api.Test; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; +import reactor.core.publisher.Sinks; import reactor.test.StepVerifier; +import org.springframework.graphql.ExecutionGraphQlRequest; import org.springframework.graphql.GraphQlSetup; import org.springframework.graphql.ResponseHelper; import org.springframework.graphql.TestThreadLocalAccessor; import static org.assertj.core.api.Assertions.assertThat; +import static org.awaitility.Awaitility.await; /** * Tests for {@link ContextDataFetcherDecorator}. @@ -257,4 +262,66 @@ public class ContextDataFetcherDecoratorTests { assertThat(dataFetcher).isInstanceOf(TrivialDataFetcher.class); } + @Test + void cancelMonoDataFetcherWhenRequestCancelled() throws Exception { + AtomicBoolean dataFetcherCancelled = new AtomicBoolean(); + GraphQL graphQl = GraphQlSetup.schemaContent(SCHEMA_CONTENT) + .queryFetcher("greeting", (env) -> + Mono.just("Hello") + .delayElement(Duration.ofSeconds(1)) + .doOnCancel(() -> dataFetcherCancelled.set(true)) + ) + .toGraphQl(); + + Sinks.Empty requestCancelled = Sinks.empty(); + ExecutionInput input = ExecutionInput.newExecutionInput().query("{ greeting }") + .graphQLContext(Map.of(ExecutionGraphQlRequest.CANCEL_PUBLISHER_CONTEXT_KEY, requestCancelled.asMono())).build(); + + CompletableFuture asyncResult = graphQl.executeAsync(input); + requestCancelled.tryEmitEmpty(); + await().atMost(Duration.ofSeconds(2)).until(dataFetcherCancelled::get); + } + + @Test + void cancelFluxDataFetcherWhenRequestCancelled() throws Exception { + AtomicBoolean dataFetcherCancelled = new AtomicBoolean(); + GraphQL graphQl = GraphQlSetup.schemaContent(SCHEMA_CONTENT) + .queryFetcher("greeting", (env) -> + Flux.just("Hello") + .delayElements(Duration.ofSeconds(1)) + .doOnCancel(() -> dataFetcherCancelled.set(true)) + ) + .toGraphQl(); + + Sinks.Empty requestCancelled = Sinks.empty(); + ExecutionInput input = ExecutionInput.newExecutionInput().query("{ greeting }") + .graphQLContext(Map.of(ExecutionGraphQlRequest.CANCEL_PUBLISHER_CONTEXT_KEY, requestCancelled.asMono())).build(); + + CompletableFuture asyncResult = graphQl.executeAsync(input); + requestCancelled.tryEmitEmpty(); + await().atMost(Duration.ofSeconds(2)).until(dataFetcherCancelled::get); + } + + @Test + void cancelFluxDataFetcherSubscriptionWhenRequestCancelled() throws Exception { + AtomicBoolean dataFetcherCancelled = new AtomicBoolean(); + GraphQL graphQl = GraphQlSetup.schemaContent(SCHEMA_CONTENT) + .subscriptionFetcher("greetings", (env) -> + Flux.just("Hi", "Bonjour", "Hola") + .delayElements(Duration.ofSeconds(1)) + .doOnCancel(() -> dataFetcherCancelled.set(true)) + ) + .toGraphQl(); + Sinks.Empty requestCancelled = Sinks.empty(); + ExecutionInput input = ExecutionInput.newExecutionInput().query("subscription { greetings }") + .graphQLContext(Map.of(ExecutionGraphQlRequest.CANCEL_PUBLISHER_CONTEXT_KEY, requestCancelled.asMono())).build(); + + ExecutionResult executionResult = graphQl.executeAsync(input).get(); + ResponseHelper.forSubscription(executionResult).subscribe(); + + requestCancelled.tryEmitEmpty(); + await().atMost(Duration.ofSeconds(2)).until(dataFetcherCancelled::get); + assertThat(dataFetcherCancelled).isTrue(); + } + } diff --git a/spring-graphql/src/test/java/org/springframework/graphql/execution/DefaultExecutionGraphQlServiceTests.java b/spring-graphql/src/test/java/org/springframework/graphql/execution/DefaultExecutionGraphQlServiceTests.java index 6431451f..7143c695 100644 --- a/spring-graphql/src/test/java/org/springframework/graphql/execution/DefaultExecutionGraphQlServiceTests.java +++ b/spring-graphql/src/test/java/org/springframework/graphql/execution/DefaultExecutionGraphQlServiceTests.java @@ -16,12 +16,16 @@ package org.springframework.graphql.execution; +import java.time.Duration; import java.util.Map; +import java.util.concurrent.atomic.AtomicBoolean; import graphql.ErrorType; import org.dataloader.DataLoaderRegistry; import org.junit.jupiter.api.Test; import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; import org.springframework.graphql.Author; import org.springframework.graphql.Book; @@ -77,4 +81,20 @@ public class DefaultExecutionGraphQlServiceTests { .hasFieldOrPropertyWithValue("errorType", ErrorType.ValidationError); } + @Test + void cancellationSupport() { + AtomicBoolean cancelled = new AtomicBoolean(); + Mono greetingMono = Mono.just("hi") + .delayElement(Duration.ofSeconds(3)) + .doOnCancel(() -> cancelled.set(true)); + + Mono execution = GraphQlSetup.schemaContent("type Query { greeting: String }") + .queryFetcher("greeting", (env) -> greetingMono) + .toGraphQlService() + .execute(TestExecutionRequest.forDocument("{ greeting }")); + + StepVerifier.create(execution).thenCancel().verify(); + assertThat(cancelled).isTrue(); + } + }