diff --git a/spring-graphql/src/main/java/org/springframework/graphql/execution/ContextDataFetcherDecorator.java b/spring-graphql/src/main/java/org/springframework/graphql/execution/ContextDataFetcherDecorator.java index 4b888a87..fe0be98a 100644 --- a/spring-graphql/src/main/java/org/springframework/graphql/execution/ContextDataFetcherDecorator.java +++ b/spring-graphql/src/main/java/org/springframework/graphql/execution/ContextDataFetcherDecorator.java @@ -21,6 +21,7 @@ import java.util.List; import graphql.ExecutionInput; import graphql.GraphQLContext; import graphql.TrivialDataFetcher; +import graphql.execution.DataFetcherResult; import graphql.schema.DataFetcher; import graphql.schema.DataFetchingEnvironment; import graphql.schema.FieldCoordinates; @@ -39,6 +40,7 @@ import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; import org.springframework.graphql.ExecutionGraphQlRequest; +import org.springframework.lang.Nullable; import org.springframework.util.Assert; /** @@ -74,7 +76,6 @@ final class ContextDataFetcherDecorator implements DataFetcher { } - @SuppressWarnings("ReactiveStreamsUnusedPublisher") @Override public Object get(DataFetchingEnvironment env) throws Exception { @@ -83,10 +84,33 @@ final class ContextDataFetcherDecorator implements DataFetcher { ContextSnapshot snapshot = (env.getLocalContext() instanceof GraphQLContext localContext) ? snapshotFactory.captureFrom(graphQlContext, localContext) : snapshotFactory.captureFrom(graphQlContext); + Mono cancelledRequest = graphQlContext.get(ExecutionGraphQlRequest.CANCEL_PUBLISHER_CONTEXT_KEY); Object value = snapshot.wrap(() -> this.delegate.get(env)).call(); + if (value instanceof DataFetcherResult dataFetcherResult) { + Object adapted = updateValue(dataFetcherResult.getData(), snapshot, cancelledRequest); + value = DataFetcherResult.newResult() + .data(adapted) + .errors(dataFetcherResult.getErrors()) + .localContext(dataFetcherResult.getLocalContext()).build(); + } + else { + value = updateValue(value, snapshot, cancelledRequest); + } + + return value; + } + + @SuppressWarnings("ReactiveStreamsUnusedPublisher") + private @Nullable Object updateValue( + @Nullable Object value, ContextSnapshot snapshot, @Nullable Mono cancelledRequest) { + + if (value == null) { + return null; + } + if (this.subscription) { Flux subscriptionResult = ReactiveAdapterRegistryHelper.toSubscriptionFlux(value) .onErrorResume((exception) -> { diff --git a/spring-graphql/src/test/java/org/springframework/graphql/execution/ContextDataFetcherDecoratorTests.java b/spring-graphql/src/test/java/org/springframework/graphql/execution/ContextDataFetcherDecoratorTests.java index 7a637b70..b8eed439 100644 --- a/spring-graphql/src/test/java/org/springframework/graphql/execution/ContextDataFetcherDecoratorTests.java +++ b/spring-graphql/src/test/java/org/springframework/graphql/execution/ContextDataFetcherDecoratorTests.java @@ -30,6 +30,7 @@ import graphql.GraphQL; import graphql.GraphQLError; import graphql.GraphqlErrorBuilder; import graphql.TrivialDataFetcher; +import graphql.execution.DataFetcherResult; import graphql.schema.DataFetcher; import graphql.schema.DataFetcherFactories; import graphql.schema.FieldCoordinates; @@ -135,6 +136,32 @@ public class ContextDataFetcherDecoratorTests { .verifyComplete(); } + @Test + void fluxDataFetcherSubscriptionWithDataFetcherResult() throws Exception { + GraphQL graphQl = GraphQlSetup.schemaContent(SCHEMA_CONTENT) + .subscriptionFetcher("greetings", (env) -> { + Flux flux = Mono.delay(Duration.ofMillis(50)) + .flatMapMany((aLong) -> Flux.deferContextual((context) -> { + String name = context.get("name"); + return Flux.just("Hi", "Bonjour", "Hola").map((s) -> s + " " + name); + })); + return DataFetcherResult.newResult().data(flux).build(); + }) + .toGraphQl(); + + ExecutionInput input = ExecutionInput.newExecutionInput().query("subscription { greetings }").build(); + input.getGraphQLContext().put("name", "007"); + + ExecutionResult executionResult = graphQl.executeAsync(input).get(); + + Flux greetingsFlux = ResponseHelper.forSubscription(executionResult) + .map(response -> response.toEntity("greetings", String.class)); + + StepVerifier.create(greetingsFlux) + .expectNext("Hi 007", "Bonjour 007", "Hola 007") + .verifyComplete(); + } + @Test void fluxDataFetcherSubscriptionThrowingException() throws Exception {