diff --git a/spring-graphql/src/main/java/org/springframework/graphql/ExecutionGraphQlRequest.java b/spring-graphql/src/main/java/org/springframework/graphql/ExecutionGraphQlRequest.java index 25b20ed2..36f4e89a 100644 --- a/spring-graphql/src/main/java/org/springframework/graphql/ExecutionGraphQlRequest.java +++ b/spring-graphql/src/main/java/org/springframework/graphql/ExecutionGraphQlRequest.java @@ -36,6 +36,12 @@ import org.springframework.lang.Nullable; */ public interface ExecutionGraphQlRequest extends GraphQlRequest { + /** + * Key of the GraphQL context entry that holds a {@code Mono} that completes + * when the inbound GraphQL request is cancelled at the transport level. + */ + String CANCEL_PUBLISHER_CONTEXT_KEY = ExecutionGraphQlRequest.class.getName() + ".cancelled"; + /** * Return the transport assigned id for the request that in turn sets * {@link ExecutionInput.Builder#executionId(ExecutionId) executionId}. diff --git a/spring-graphql/src/main/java/org/springframework/graphql/execution/ContextDataFetcherDecorator.java b/spring-graphql/src/main/java/org/springframework/graphql/execution/ContextDataFetcherDecorator.java index 2e5daa1b..4b888a87 100644 --- a/spring-graphql/src/main/java/org/springframework/graphql/execution/ContextDataFetcherDecorator.java +++ b/spring-graphql/src/main/java/org/springframework/graphql/execution/ContextDataFetcherDecorator.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2024 the original author or authors. + * Copyright 2002-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -38,6 +38,7 @@ import org.reactivestreams.Publisher; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; +import org.springframework.graphql.ExecutionGraphQlRequest; import org.springframework.util.Assert; /** @@ -79,15 +80,15 @@ final class ContextDataFetcherDecorator implements DataFetcher { GraphQLContext graphQlContext = env.getGraphQlContext(); ContextSnapshotFactory snapshotFactory = ContextSnapshotFactoryHelper.getInstance(graphQlContext); - ContextSnapshot snapshot = (env.getLocalContext() instanceof GraphQLContext localContext) ? snapshotFactory.captureFrom(graphQlContext, localContext) : snapshotFactory.captureFrom(graphQlContext); + Mono cancelledRequest = graphQlContext.get(ExecutionGraphQlRequest.CANCEL_PUBLISHER_CONTEXT_KEY); Object value = snapshot.wrap(() -> this.delegate.get(env)).call(); if (this.subscription) { - return ReactiveAdapterRegistryHelper.toSubscriptionFlux(value) + Flux subscriptionResult = ReactiveAdapterRegistryHelper.toSubscriptionFlux(value) .onErrorResume((exception) -> { // Already handled, e.g. controller methods? if (exception instanceof SubscriptionPublisherException) { @@ -95,13 +96,19 @@ final class ContextDataFetcherDecorator implements DataFetcher { } return this.subscriptionExceptionResolver.resolveException(exception) .flatMap((errors) -> Mono.error(new SubscriptionPublisherException(errors, exception))); - }) - .contextWrite(snapshot::updateContext); + }); + if (cancelledRequest != null) { + subscriptionResult = subscriptionResult.takeUntilOther(cancelledRequest); + } + return subscriptionResult.contextWrite(snapshot::updateContext); } value = ReactiveAdapterRegistryHelper.toMonoIfReactive(value); if (value instanceof Mono mono) { + if (cancelledRequest != null) { + mono = mono.takeUntilOther(cancelledRequest); + } value = mono.contextWrite(snapshot::updateContext).toFuture(); } diff --git a/spring-graphql/src/main/java/org/springframework/graphql/execution/DefaultExecutionGraphQlService.java b/spring-graphql/src/main/java/org/springframework/graphql/execution/DefaultExecutionGraphQlService.java index 278af14f..94f9f2d0 100644 --- a/spring-graphql/src/main/java/org/springframework/graphql/execution/DefaultExecutionGraphQlService.java +++ b/spring-graphql/src/main/java/org/springframework/graphql/execution/DefaultExecutionGraphQlService.java @@ -31,6 +31,7 @@ import graphql.execution.instrumentation.dataloader.EmptyDataLoaderRegistryInsta import io.micrometer.context.ContextSnapshotFactory; import org.dataloader.DataLoaderRegistry; import reactor.core.publisher.Mono; +import reactor.core.publisher.Sinks; import org.springframework.graphql.ExecutionGraphQlRequest; import org.springframework.graphql.ExecutionGraphQlResponse; @@ -101,12 +102,15 @@ public class DefaultExecutionGraphQlService implements ExecutionGraphQlService { ContextSnapshotFactoryHelper.saveInstance(factory, graphQLContext); factory.captureFrom(contextView).updateContext(graphQLContext); + Sinks.Empty requestCancelled = Sinks.empty(); + graphQLContext.put(ExecutionGraphQlRequest.CANCEL_PUBLISHER_CONTEXT_KEY, requestCancelled.asMono()); ExecutionInput executionInputToUse = registerDataLoaders(executionInput); return Mono.fromFuture(this.graphQlSource.graphQl().executeAsync(executionInputToUse)) .onErrorResume((ex) -> ex instanceof GraphQLError, (ex) -> Mono.just(ExecutionResult.newExecutionResult().addError((GraphQLError) ex).build())) - .map((result) -> new DefaultExecutionGraphQlResponse(executionInputToUse, result)); + .map((result) -> new DefaultExecutionGraphQlResponse(executionInputToUse, result)) + .doOnCancel(requestCancelled::tryEmitEmpty); }); } diff --git a/spring-graphql/src/test/java/org/springframework/graphql/execution/ContextDataFetcherDecoratorTests.java b/spring-graphql/src/test/java/org/springframework/graphql/execution/ContextDataFetcherDecoratorTests.java index 698108ae..7a637b70 100644 --- a/spring-graphql/src/test/java/org/springframework/graphql/execution/ContextDataFetcherDecoratorTests.java +++ b/spring-graphql/src/test/java/org/springframework/graphql/execution/ContextDataFetcherDecoratorTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2024 the original author or authors. + * Copyright 2002-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -19,7 +19,9 @@ package org.springframework.graphql.execution; import java.time.Duration; import java.util.Collections; import java.util.List; +import java.util.Map; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.function.BiConsumer; import graphql.ExecutionInput; @@ -41,13 +43,16 @@ import io.micrometer.context.ContextSnapshotFactory; import org.junit.jupiter.api.Test; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; +import reactor.core.publisher.Sinks; import reactor.test.StepVerifier; +import org.springframework.graphql.ExecutionGraphQlRequest; import org.springframework.graphql.GraphQlSetup; import org.springframework.graphql.ResponseHelper; import org.springframework.graphql.TestThreadLocalAccessor; import static org.assertj.core.api.Assertions.assertThat; +import static org.awaitility.Awaitility.await; /** * Tests for {@link ContextDataFetcherDecorator}. @@ -257,4 +262,66 @@ public class ContextDataFetcherDecoratorTests { assertThat(dataFetcher).isInstanceOf(TrivialDataFetcher.class); } + @Test + void cancelMonoDataFetcherWhenRequestCancelled() throws Exception { + AtomicBoolean dataFetcherCancelled = new AtomicBoolean(); + GraphQL graphQl = GraphQlSetup.schemaContent(SCHEMA_CONTENT) + .queryFetcher("greeting", (env) -> + Mono.just("Hello") + .delayElement(Duration.ofSeconds(1)) + .doOnCancel(() -> dataFetcherCancelled.set(true)) + ) + .toGraphQl(); + + Sinks.Empty requestCancelled = Sinks.empty(); + ExecutionInput input = ExecutionInput.newExecutionInput().query("{ greeting }") + .graphQLContext(Map.of(ExecutionGraphQlRequest.CANCEL_PUBLISHER_CONTEXT_KEY, requestCancelled.asMono())).build(); + + CompletableFuture asyncResult = graphQl.executeAsync(input); + requestCancelled.tryEmitEmpty(); + await().atMost(Duration.ofSeconds(2)).until(dataFetcherCancelled::get); + } + + @Test + void cancelFluxDataFetcherWhenRequestCancelled() throws Exception { + AtomicBoolean dataFetcherCancelled = new AtomicBoolean(); + GraphQL graphQl = GraphQlSetup.schemaContent(SCHEMA_CONTENT) + .queryFetcher("greeting", (env) -> + Flux.just("Hello") + .delayElements(Duration.ofSeconds(1)) + .doOnCancel(() -> dataFetcherCancelled.set(true)) + ) + .toGraphQl(); + + Sinks.Empty requestCancelled = Sinks.empty(); + ExecutionInput input = ExecutionInput.newExecutionInput().query("{ greeting }") + .graphQLContext(Map.of(ExecutionGraphQlRequest.CANCEL_PUBLISHER_CONTEXT_KEY, requestCancelled.asMono())).build(); + + CompletableFuture asyncResult = graphQl.executeAsync(input); + requestCancelled.tryEmitEmpty(); + await().atMost(Duration.ofSeconds(2)).until(dataFetcherCancelled::get); + } + + @Test + void cancelFluxDataFetcherSubscriptionWhenRequestCancelled() throws Exception { + AtomicBoolean dataFetcherCancelled = new AtomicBoolean(); + GraphQL graphQl = GraphQlSetup.schemaContent(SCHEMA_CONTENT) + .subscriptionFetcher("greetings", (env) -> + Flux.just("Hi", "Bonjour", "Hola") + .delayElements(Duration.ofSeconds(1)) + .doOnCancel(() -> dataFetcherCancelled.set(true)) + ) + .toGraphQl(); + Sinks.Empty requestCancelled = Sinks.empty(); + ExecutionInput input = ExecutionInput.newExecutionInput().query("subscription { greetings }") + .graphQLContext(Map.of(ExecutionGraphQlRequest.CANCEL_PUBLISHER_CONTEXT_KEY, requestCancelled.asMono())).build(); + + ExecutionResult executionResult = graphQl.executeAsync(input).get(); + ResponseHelper.forSubscription(executionResult).subscribe(); + + requestCancelled.tryEmitEmpty(); + await().atMost(Duration.ofSeconds(2)).until(dataFetcherCancelled::get); + assertThat(dataFetcherCancelled).isTrue(); + } + } diff --git a/spring-graphql/src/test/java/org/springframework/graphql/execution/DefaultExecutionGraphQlServiceTests.java b/spring-graphql/src/test/java/org/springframework/graphql/execution/DefaultExecutionGraphQlServiceTests.java index 6431451f..7143c695 100644 --- a/spring-graphql/src/test/java/org/springframework/graphql/execution/DefaultExecutionGraphQlServiceTests.java +++ b/spring-graphql/src/test/java/org/springframework/graphql/execution/DefaultExecutionGraphQlServiceTests.java @@ -16,12 +16,16 @@ package org.springframework.graphql.execution; +import java.time.Duration; import java.util.Map; +import java.util.concurrent.atomic.AtomicBoolean; import graphql.ErrorType; import org.dataloader.DataLoaderRegistry; import org.junit.jupiter.api.Test; import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; import org.springframework.graphql.Author; import org.springframework.graphql.Book; @@ -77,4 +81,20 @@ public class DefaultExecutionGraphQlServiceTests { .hasFieldOrPropertyWithValue("errorType", ErrorType.ValidationError); } + @Test + void cancellationSupport() { + AtomicBoolean cancelled = new AtomicBoolean(); + Mono greetingMono = Mono.just("hi") + .delayElement(Duration.ofSeconds(3)) + .doOnCancel(() -> cancelled.set(true)); + + Mono execution = GraphQlSetup.schemaContent("type Query { greeting: String }") + .queryFetcher("greeting", (env) -> greetingMono) + .toGraphQlService() + .execute(TestExecutionRequest.forDocument("{ greeting }")); + + StepVerifier.create(execution).thenCancel().verify(); + assertThat(cancelled).isTrue(); + } + }