From 65be8ed35ff4524316c05395dd69916bff874128 Mon Sep 17 00:00:00 2001 From: rstoyanchev Date: Fri, 9 Sep 2022 07:16:47 +0100 Subject: [PATCH] Do not skip DataFetcherFactories Closes gh-440 --- .../ContextDataFetcherDecorator.java | 14 +++- .../ContextDataFetcherDecoratorTests.java | 68 ++++++++++++++++++- 2 files changed, 76 insertions(+), 6 deletions(-) diff --git a/spring-graphql/src/main/java/org/springframework/graphql/execution/ContextDataFetcherDecorator.java b/spring-graphql/src/main/java/org/springframework/graphql/execution/ContextDataFetcherDecorator.java index 1f2d3d9b..06696d99 100644 --- a/spring-graphql/src/main/java/org/springframework/graphql/execution/ContextDataFetcherDecorator.java +++ b/spring-graphql/src/main/java/org/springframework/graphql/execution/ContextDataFetcherDecorator.java @@ -109,14 +109,14 @@ final class ContextDataFetcherDecorator implements DataFetcher { return new GraphQLTypeVisitorStub() { @Override - public TraversalControl visitGraphQLFieldDefinition(GraphQLFieldDefinition fieldDefinition, - TraverserContext context) { + public TraversalControl visitGraphQLFieldDefinition( + GraphQLFieldDefinition fieldDefinition, TraverserContext context) { GraphQLCodeRegistry.Builder codeRegistry = context.getVarFromParents(GraphQLCodeRegistry.Builder.class); GraphQLFieldsContainer parent = (GraphQLFieldsContainer) context.getParentNode(); DataFetcher dataFetcher = codeRegistry.getDataFetcher(parent, fieldDefinition); - if (dataFetcher.getClass().getPackage().getName().startsWith("graphql.")) { + if (skipDataFetcher(dataFetcher)) { return TraversalControl.CONTINUE; } @@ -125,6 +125,14 @@ final class ContextDataFetcherDecorator implements DataFetcher { codeRegistry.dataFetcher(parent, fieldDefinition, dataFetcher); return TraversalControl.CONTINUE; } + + private boolean skipDataFetcher(DataFetcher dataFetcher) { + Class type = dataFetcher.getClass(); + if (type.getPackage().getName().startsWith("graphql.")) { + return !type.getSimpleName().startsWith("DataFetcherFactories"); + } + return false; + } }; } diff --git a/spring-graphql/src/test/java/org/springframework/graphql/execution/ContextDataFetcherDecoratorTests.java b/spring-graphql/src/test/java/org/springframework/graphql/execution/ContextDataFetcherDecoratorTests.java index 56befead..43fb90ae 100644 --- a/spring-graphql/src/test/java/org/springframework/graphql/execution/ContextDataFetcherDecoratorTests.java +++ b/spring-graphql/src/test/java/org/springframework/graphql/execution/ContextDataFetcherDecoratorTests.java @@ -19,12 +19,19 @@ package org.springframework.graphql.execution; import java.time.Duration; import java.util.Collections; import java.util.List; +import java.util.concurrent.CompletableFuture; +import java.util.function.BiConsumer; import graphql.ExecutionInput; import graphql.ExecutionResult; import graphql.GraphQL; import graphql.GraphQLError; import graphql.GraphqlErrorBuilder; +import graphql.schema.DataFetcher; +import graphql.schema.DataFetcherFactories; +import graphql.schema.GraphQLFieldDefinition; +import graphql.schema.idl.SchemaDirectiveWiring; +import graphql.schema.idl.SchemaDirectiveWiringEnvironment; import org.junit.jupiter.api.Test; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; @@ -42,10 +49,18 @@ import static org.assertj.core.api.Assertions.assertThat; * Tests for {@link ContextDataFetcherDecorator}. * @author Rossen Stoyanchev */ +@SuppressWarnings("ReactiveStreamsUnusedPublisher") public class ContextDataFetcherDecoratorTests { - private static final String SCHEMA_CONTENT = - "type Query { greeting: String, greetings: [String] } type Subscription { greetings: String }"; + private static final String SCHEMA_CONTENT = "" + + "directive @UpperCase on FIELD_DEFINITION " + + "type Query { " + + " greeting: String @UpperCase, " + + " greetings: [String] " + + "} " + + "type Subscription { " + + " greetings: String " + + "}"; @Test @@ -112,7 +127,7 @@ public class ContextDataFetcherDecoratorTests { } @Test - void fluxDataFetcherSubscriptionThrowException() throws Exception { + void fluxDataFetcherSubscriptionThrowingException() throws Exception { SubscriptionExceptionResolver resolver = SubscriptionExceptionResolver.forSingleError(exception -> @@ -177,4 +192,51 @@ public class ContextDataFetcherDecoratorTests { } } + @Test // gh-440 + void dataFetcherDecoratedWithDataFetcherFactories() { + + SchemaDirectiveWiring directiveWiring = new SchemaDirectiveWiring() { + + @SuppressWarnings("unchecked") + @Override + public GraphQLFieldDefinition onField(SchemaDirectiveWiringEnvironment env) { + if (env.getDirective("UpperCase") != null) { + return env.setFieldDataFetcher(DataFetcherFactories.wrapDataFetcher( + env.getFieldDataFetcher(), + ((dataFetchingEnv, value) -> { + if (value instanceof String) { + return ((String) value).toUpperCase(); + } + else if (value instanceof Mono) { + return ((Mono) value).map(String::toUpperCase); + } + else { + throw new IllegalArgumentException(); + } + }))); + } + else { + return env.getElement(); + } + } + }; + + BiConsumer> tester = (schemaDirectiveWiring, dataFetcher) -> { + + GraphQL graphQl = GraphQlSetup.schemaContent(SCHEMA_CONTENT) + .queryFetcher("greeting", dataFetcher) + .runtimeWiring(builder -> builder.directiveWiring(directiveWiring)) + .toGraphQl(); + + ExecutionInput input = ExecutionInput.newExecutionInput().query("{ greeting }").build(); + Mono resultMono = Mono.fromFuture(graphQl.executeAsync(input)); + + String greeting = ResponseHelper.forResult(resultMono).toEntity("greeting", String.class); + assertThat(greeting).isEqualTo("HELLO"); + }; + + tester.accept(directiveWiring, env -> CompletableFuture.completedFuture("hello")); + tester.accept(directiveWiring, env -> Mono.just("hello")); + } + }