diff --git a/graphql-spring-boot-starter/src/main/java/org/springframework/graphql/boot/WebMvcGraphQlAutoConfiguration.java b/graphql-spring-boot-starter/src/main/java/org/springframework/graphql/boot/WebMvcGraphQlAutoConfiguration.java index e3844a66..0d37d8be 100644 --- a/graphql-spring-boot-starter/src/main/java/org/springframework/graphql/boot/WebMvcGraphQlAutoConfiguration.java +++ b/graphql-spring-boot-starter/src/main/java/org/springframework/graphql/boot/WebMvcGraphQlAutoConfiguration.java @@ -40,6 +40,7 @@ import org.springframework.core.io.Resource; import org.springframework.core.io.ResourceLoader; import org.springframework.graphql.GraphQlService; import org.springframework.graphql.execution.GraphQlSource; +import org.springframework.graphql.execution.ThreadLocalAccessor; import org.springframework.graphql.web.WebGraphQlHandler; import org.springframework.graphql.web.WebInterceptor; import org.springframework.graphql.web.webmvc.GraphQlHttpHandler; @@ -70,9 +71,13 @@ public class WebMvcGraphQlAutoConfiguration { @Bean @ConditionalOnMissingBean - public WebGraphQlHandler webGraphQlHandler(ObjectProvider interceptors, GraphQlService service) { + public WebGraphQlHandler webGraphQlHandler( + ObjectProvider interceptorsProvider, GraphQlService service, + ObjectProvider accessorsProvider) { + return WebGraphQlHandler.builder(service) - .interceptors(interceptors.orderedStream().collect(Collectors.toList())) + .interceptors(interceptorsProvider.orderedStream().collect(Collectors.toList())) + .threadLocalAccessors(accessorsProvider.orderedStream().collect(Collectors.toList())) .build(); } diff --git a/samples/webflux-websocket/src/main/java/io/spring/sample/graphql/ReactorContextWebFilter.java b/samples/webflux-websocket/src/main/java/io/spring/sample/graphql/ContextWebFilter.java similarity index 89% rename from samples/webflux-websocket/src/main/java/io/spring/sample/graphql/ReactorContextWebFilter.java rename to samples/webflux-websocket/src/main/java/io/spring/sample/graphql/ContextWebFilter.java index 37ef9bfd..13b4887e 100644 --- a/samples/webflux-websocket/src/main/java/io/spring/sample/graphql/ReactorContextWebFilter.java +++ b/samples/webflux-websocket/src/main/java/io/spring/sample/graphql/ContextWebFilter.java @@ -23,9 +23,9 @@ import org.springframework.web.server.WebFilterChain; /** * WebFilter that inserts a key-value pair into the Reactor context which is - * transferred to and accessible to Reactor-based data fetchers. + * transferred to and accessible in Reactor data fetchers. */ -public class ReactorContextWebFilter implements WebFilter { +public class ContextWebFilter implements WebFilter { @Override public Mono filter(ServerWebExchange exchange, WebFilterChain chain) { diff --git a/samples/webflux-websocket/src/main/java/io/spring/sample/graphql/SampleApplication.java b/samples/webflux-websocket/src/main/java/io/spring/sample/graphql/SampleApplication.java index 68e7bbfe..cfbd0c70 100644 --- a/samples/webflux-websocket/src/main/java/io/spring/sample/graphql/SampleApplication.java +++ b/samples/webflux-websocket/src/main/java/io/spring/sample/graphql/SampleApplication.java @@ -28,8 +28,8 @@ public class SampleApplication { } @Bean - ReactorContextWebFilter reactorContextWebFilter() { - return new ReactorContextWebFilter(); + ContextWebFilter reactorContextWebFilter() { + return new ContextWebFilter(); } } diff --git a/samples/webmvc-http/src/main/java/io/spring/sample/graphql/SampleApplication.java b/samples/webmvc-http/src/main/java/io/spring/sample/graphql/SampleApplication.java index 56ce0e69..400baa49 100644 --- a/samples/webmvc-http/src/main/java/io/spring/sample/graphql/SampleApplication.java +++ b/samples/webmvc-http/src/main/java/io/spring/sample/graphql/SampleApplication.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2020 the original author or authors. + * Copyright 2002-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -21,7 +21,9 @@ import org.springframework.boot.autoconfigure.SpringBootApplication; @SpringBootApplication public class SampleApplication { + public static void main(String[] args) { SpringApplication.run(SampleApplication.class, args); } + } diff --git a/samples/webmvc-http/src/main/java/io/spring/sample/graphql/greeting/GreetingDataWiring.java b/samples/webmvc-http/src/main/java/io/spring/sample/graphql/greeting/GreetingDataWiring.java new file mode 100644 index 00000000..99cf3410 --- /dev/null +++ b/samples/webmvc-http/src/main/java/io/spring/sample/graphql/greeting/GreetingDataWiring.java @@ -0,0 +1,24 @@ +package io.spring.sample.graphql.greeting; + +import graphql.schema.idl.RuntimeWiring; + +import org.springframework.graphql.boot.RuntimeWiringCustomizer; +import org.springframework.stereotype.Component; +import org.springframework.web.context.request.RequestAttributes; +import org.springframework.web.context.request.RequestContextHolder; + +import static org.springframework.web.context.request.RequestAttributes.SCOPE_REQUEST; + +@Component +public class GreetingDataWiring implements RuntimeWiringCustomizer { + + + @Override + public void customize(RuntimeWiring.Builder builder) { + builder.type("Query", typeWiring -> + typeWiring.dataFetcher("greeting", env -> { + RequestAttributes attributes = RequestContextHolder.getRequestAttributes(); + return "Hello " + attributes.getAttribute(RequestAttributeFilter.NAME_ATTRIBUTE, SCOPE_REQUEST); + })); + } +} diff --git a/samples/webmvc-http/src/main/java/io/spring/sample/graphql/greeting/RequestAttributeFilter.java b/samples/webmvc-http/src/main/java/io/spring/sample/graphql/greeting/RequestAttributeFilter.java new file mode 100644 index 00000000..ffca7461 --- /dev/null +++ b/samples/webmvc-http/src/main/java/io/spring/sample/graphql/greeting/RequestAttributeFilter.java @@ -0,0 +1,43 @@ +/* + * Copyright 2002-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.spring.sample.graphql.greeting; + +import java.io.IOException; + +import javax.servlet.Filter; +import javax.servlet.FilterChain; +import javax.servlet.ServletException; +import javax.servlet.ServletRequest; +import javax.servlet.ServletResponse; + +import org.springframework.stereotype.Component; + +/** + * Servlet Filter that adds a Servlet request attribute. + */ +@Component +public class RequestAttributeFilter implements Filter { + + public static final String NAME_ATTRIBUTE = RequestAttributeFilter.class.getName() + ".name"; + + + @Override + public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain) throws IOException, ServletException { + request.setAttribute(NAME_ATTRIBUTE, "007"); + chain.doFilter(request, response); + } + +} diff --git a/samples/webmvc-http/src/main/java/io/spring/sample/graphql/greeting/RequestAttributesAccessor.java b/samples/webmvc-http/src/main/java/io/spring/sample/graphql/greeting/RequestAttributesAccessor.java new file mode 100644 index 00000000..81aff2e6 --- /dev/null +++ b/samples/webmvc-http/src/main/java/io/spring/sample/graphql/greeting/RequestAttributesAccessor.java @@ -0,0 +1,60 @@ +/* + * Copyright 2002-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.spring.sample.graphql.greeting; + +import java.util.Collections; +import java.util.Map; + +import org.springframework.graphql.execution.ThreadLocalAccessor; +import org.springframework.stereotype.Component; +import org.springframework.web.context.request.RequestAttributes; +import org.springframework.web.context.request.RequestContextHolder; + +/** + * {@link ThreadLocalAccessor} to expose a thread-bound RequestAttributes object + * to data fetchers in Spring GraphQL. + */ +@Component +public class RequestAttributesAccessor implements ThreadLocalAccessor { + + private static final String ATTRIBUTES_KEY = + RequestAttributesAccessor.class.getName() + ".requestAttributes"; + + + @Override + public void extractValues(Map container) { + RequestAttributes attributes = RequestContextHolder.getRequestAttributes(); + if (attributes != null) { + container.put(ATTRIBUTES_KEY, attributes); + } + } + + @Override + public void restoreValues(Map values) { + RequestAttributes attributes = (RequestAttributes) values.get(ATTRIBUTES_KEY); + if (attributes != null) { + RequestContextHolder.setRequestAttributes(attributes); + } + } + + @Override + public void resetValues(Map values) { + if (values.get(ATTRIBUTES_KEY) != null) { + RequestContextHolder.resetRequestAttributes(); + } + } + +} diff --git a/samples/webmvc-http/src/main/java/io/spring/sample/graphql/greeting/package-info.java b/samples/webmvc-http/src/main/java/io/spring/sample/graphql/greeting/package-info.java new file mode 100644 index 00000000..3220fa71 --- /dev/null +++ b/samples/webmvc-http/src/main/java/io/spring/sample/graphql/greeting/package-info.java @@ -0,0 +1,6 @@ +@NonNullApi +@NonNullFields +package io.spring.sample.graphql.greeting; + +import org.springframework.lang.NonNullApi; +import org.springframework.lang.NonNullFields; diff --git a/samples/webmvc-http/src/main/java/io/spring/sample/graphql/package-info.java b/samples/webmvc-http/src/main/java/io/spring/sample/graphql/package-info.java new file mode 100644 index 00000000..78414e7f --- /dev/null +++ b/samples/webmvc-http/src/main/java/io/spring/sample/graphql/package-info.java @@ -0,0 +1,6 @@ +@NonNullApi +@NonNullFields +package io.spring.sample.graphql; + +import org.springframework.lang.NonNullApi; +import org.springframework.lang.NonNullFields; diff --git a/samples/webmvc-http/src/main/java/io/spring/sample/graphql/project/package-info.java b/samples/webmvc-http/src/main/java/io/spring/sample/graphql/project/package-info.java new file mode 100644 index 00000000..67a9733d --- /dev/null +++ b/samples/webmvc-http/src/main/java/io/spring/sample/graphql/project/package-info.java @@ -0,0 +1,6 @@ +@NonNullApi +@NonNullFields +package io.spring.sample.graphql.project; + +import org.springframework.lang.NonNullApi; +import org.springframework.lang.NonNullFields; diff --git a/samples/webmvc-http/src/main/java/io/spring/sample/graphql/repository/package-info.java b/samples/webmvc-http/src/main/java/io/spring/sample/graphql/repository/package-info.java new file mode 100644 index 00000000..08f88984 --- /dev/null +++ b/samples/webmvc-http/src/main/java/io/spring/sample/graphql/repository/package-info.java @@ -0,0 +1,6 @@ +@NonNullApi +@NonNullFields +package io.spring.sample.graphql.repository; + +import org.springframework.lang.NonNullApi; +import org.springframework.lang.NonNullFields; diff --git a/samples/webmvc-http/src/main/resources/graphql/schema.graphqls b/samples/webmvc-http/src/main/resources/graphql/schema.graphqls index 871ea03f..9d4d995e 100644 --- a/samples/webmvc-http/src/main/resources/graphql/schema.graphqls +++ b/samples/webmvc-http/src/main/resources/graphql/schema.graphqls @@ -1,4 +1,5 @@ type Query { + greeting: String artifactRepositories : [ArtifactRepository] artifactRepository(id : ID!) : ArtifactRepository project(slug: ID!): Project diff --git a/spring-graphql/src/main/java/org/springframework/graphql/execution/CompositeThreadLocalAccessor.java b/spring-graphql/src/main/java/org/springframework/graphql/execution/CompositeThreadLocalAccessor.java new file mode 100644 index 00000000..e053a5dc --- /dev/null +++ b/spring-graphql/src/main/java/org/springframework/graphql/execution/CompositeThreadLocalAccessor.java @@ -0,0 +1,49 @@ +/* + * Copyright 2002-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.graphql.execution; + +import java.util.List; +import java.util.Map; + +/** + * Default implementation of a composite accessor that is returned from + * {@link ThreadLocalAccessor#composite(List)}. + */ +class CompositeThreadLocalAccessor implements ThreadLocalAccessor { + + private final List accessors; + + + public CompositeThreadLocalAccessor(List accessors) { + this.accessors = accessors; + } + + + @Override + public void extractValues(Map container) { + this.accessors.forEach(accessor -> accessor.extractValues(container)); + } + + @Override + public void restoreValues(Map values) { + this.accessors.forEach(accessor -> accessor.restoreValues(values)); + } + + @Override + public void resetValues(Map values) { + this.accessors.forEach(accessor -> accessor.resetValues(values)); + } +} diff --git a/spring-graphql/src/main/java/org/springframework/graphql/execution/ReactorDataFetcherAdapter.java b/spring-graphql/src/main/java/org/springframework/graphql/execution/ContextDataFetcherDecorator.java similarity index 73% rename from spring-graphql/src/main/java/org/springframework/graphql/execution/ReactorDataFetcherAdapter.java rename to spring-graphql/src/main/java/org/springframework/graphql/execution/ContextDataFetcherDecorator.java index d197a04e..e3cea765 100644 --- a/spring-graphql/src/main/java/org/springframework/graphql/execution/ReactorDataFetcherAdapter.java +++ b/spring-graphql/src/main/java/org/springframework/graphql/execution/ContextDataFetcherDecorator.java @@ -34,19 +34,22 @@ import reactor.util.context.ContextView; import org.springframework.util.Assert; /** - * Adapter that can wrap a registered {@link DataFetcher} and enable it to return - * {@link Flux} or {@link Mono}, also adding Reactor Context passed through - * the {@link ExecutionInput}. Also exposes a {@link #TYPE_VISITOR} to apply - * the adapter. + * Wrap a {@link DataFetcher} to enable the following: + *
    + *
  • Support {@link Mono} return value. + *
  • Support {@link Flux} return value as a shortcut to {@link Flux#collectList()}. + *
  • Re-establish Reactor Context passed via {@link ExecutionInput}. + *
  • Re-establish ThreadLocal context passed via {@link ExecutionInput}. + *
*/ -class ReactorDataFetcherAdapter implements DataFetcher { +class ContextDataFetcherDecorator implements DataFetcher { private final DataFetcher delegate; private final boolean subscription; - private ReactorDataFetcherAdapter(DataFetcher delegate, boolean subscription) { + private ContextDataFetcherDecorator(DataFetcher delegate, boolean subscription) { Assert.notNull(delegate, "'delegate' DataFetcher is required"); this.delegate = delegate; this.subscription = subscription; @@ -55,11 +58,19 @@ class ReactorDataFetcherAdapter implements DataFetcher { @Override public Object get(DataFetchingEnvironment environment) throws Exception { - Object value = this.delegate.get(environment); + ContextView contextView = ContextManager.getReactorContext(environment); + + Object value; + try { + ContextManager.restoreThreadLocalValues(contextView); + value = this.delegate.get(environment); + } + finally { + ContextManager.resetThreadLocalValues(contextView); + } if (this.subscription) { - ContextView context = ContextManager.getReactorContext(environment); - return (context != null ? Flux.from((Publisher) value).contextWrite(context) : value); + return (!contextView.isEmpty() ? Flux.from((Publisher) value).contextWrite(contextView) : value); } if (value instanceof Flux) { @@ -68,9 +79,8 @@ class ReactorDataFetcherAdapter implements DataFetcher { if (value instanceof Mono) { Mono valueMono = (Mono) value; - ContextView reactorContext = ContextManager.getReactorContext(environment); - if (reactorContext != null) { - valueMono = valueMono.contextWrite(reactorContext); + if (!contextView.isEmpty()) { + valueMono = valueMono.contextWrite(contextView); } value = valueMono.toFuture(); } @@ -98,7 +108,7 @@ class ReactorDataFetcherAdapter implements DataFetcher { } boolean handlesSubscription = parent.getName().equals("Subscription"); - dataFetcher = new ReactorDataFetcherAdapter(dataFetcher, handlesSubscription); + dataFetcher = new ContextDataFetcherDecorator(dataFetcher, handlesSubscription); codeRegistry.dataFetcher(parent, fieldDefinition, dataFetcher); return TraversalControl.CONTINUE; } diff --git a/spring-graphql/src/main/java/org/springframework/graphql/execution/ContextManager.java b/spring-graphql/src/main/java/org/springframework/graphql/execution/ContextManager.java index 9408bbff..9beff012 100644 --- a/spring-graphql/src/main/java/org/springframework/graphql/execution/ContextManager.java +++ b/spring-graphql/src/main/java/org/springframework/graphql/execution/ContextManager.java @@ -15,9 +15,13 @@ */ package org.springframework.graphql.execution; +import java.util.LinkedHashMap; +import java.util.Map; + import graphql.ExecutionInput; import graphql.GraphQLContext; import graphql.schema.DataFetchingEnvironment; +import reactor.util.context.Context; import reactor.util.context.ContextView; import org.springframework.lang.Nullable; @@ -27,10 +31,16 @@ import org.springframework.lang.Nullable; * through the {@link ExecutionInput} and the {@link DataFetchingEnvironment} * of a request. */ -abstract class ContextManager { +public abstract class ContextManager { - private static final String REACTOR_CONTEXT_KEY = - ReactorDataFetcherAdapter.class.getName() + ".REACTOR_CONTEXT"; + private static final String CONTEXT_VIEW_KEY = + ContextManager.class.getName() + ".CONTEXT_VIEW"; + + private static final String THREAD_LOCAL_VALUES_KEY = + ContextManager.class.getName() + ".THREAD_VALUES_ACCESSOR"; + + private static final String THREAD_LOCAL_ACCESSOR_KEY = + ContextManager.class.getName() + ".THREAD_LOCAL_ACCESSOR"; /** @@ -38,17 +48,55 @@ abstract class ContextManager { * later access through the {@link DataFetchingEnvironment}. */ static void setReactorContext(ContextView contextView, ExecutionInput input) { - ((GraphQLContext) input.getContext()).put(REACTOR_CONTEXT_KEY, contextView); + ((GraphQLContext) input.getContext()).put(CONTEXT_VIEW_KEY, contextView); } /** - * Return the Reactor ContextView saved in the given DataFetchingEnvironment, - * or null if not present. + * Return the Reactor ContextView saved in the given DataFetchingEnvironment. */ - @Nullable static ContextView getReactorContext(DataFetchingEnvironment environment) { GraphQLContext graphQlContext = environment.getContext(); - return graphQlContext.get(REACTOR_CONTEXT_KEY); + return graphQlContext.getOrDefault(CONTEXT_VIEW_KEY, Context.empty()); + } + + /** + * Use the given accessor to extract ThreadLocal value, and return a Reactor + * context that contains both the extracted values and the accessor. + * @param accessor the accessor to use + */ + public static ContextView extractThreadLocalValues(ThreadLocalAccessor accessor) { + Map valuesMap = new LinkedHashMap<>(); + accessor.extractValues(valuesMap); + return Context.of(THREAD_LOCAL_VALUES_KEY, valuesMap, THREAD_LOCAL_ACCESSOR_KEY, accessor); + } + + /** + * Look up saved ThreadLocal values and use them to re-establish ThreadLocal context. + */ + static void restoreThreadLocalValues(ContextView contextView) { + ThreadLocalAccessor accessor = getThreadLocalAccessor(contextView); + if (accessor != null) { + accessor.restoreValues(getThreadLocalValues(contextView)); + } + } + + /** + * Look up saved ThreadLocal values and remove associated ThreadLocal context. + */ + static void resetThreadLocalValues(ContextView contextView) { + ThreadLocalAccessor accessor = getThreadLocalAccessor(contextView); + if (accessor != null) { + accessor.resetValues(getThreadLocalValues(contextView)); + } + } + + @Nullable + private static ThreadLocalAccessor getThreadLocalAccessor(ContextView contextView) { + return (contextView.hasKey(THREAD_LOCAL_ACCESSOR_KEY) ? contextView.get(THREAD_LOCAL_ACCESSOR_KEY) : null); + } + + private static Map getThreadLocalValues(ContextView contextView) { + return contextView.get(THREAD_LOCAL_VALUES_KEY); } } diff --git a/spring-graphql/src/main/java/org/springframework/graphql/execution/DefaultGraphQlSourceBuilder.java b/spring-graphql/src/main/java/org/springframework/graphql/execution/DefaultGraphQlSourceBuilder.java index e0dd4046..54497f6b 100644 --- a/spring-graphql/src/main/java/org/springframework/graphql/execution/DefaultGraphQlSourceBuilder.java +++ b/spring-graphql/src/main/java/org/springframework/graphql/execution/DefaultGraphQlSourceBuilder.java @@ -58,7 +58,7 @@ class DefaultGraphQlSourceBuilder implements GraphQlSource.Builder { DefaultGraphQlSourceBuilder() { - this.typeVisitors.add(ReactorDataFetcherAdapter.TYPE_VISITOR); + this.typeVisitors.add(ContextDataFetcherDecorator.TYPE_VISITOR); } diff --git a/spring-graphql/src/main/java/org/springframework/graphql/execution/ExceptionResolversExceptionHandler.java b/spring-graphql/src/main/java/org/springframework/graphql/execution/ExceptionResolversExceptionHandler.java index b113eafa..d7c485d2 100644 --- a/spring-graphql/src/main/java/org/springframework/graphql/execution/ExceptionResolversExceptionHandler.java +++ b/spring-graphql/src/main/java/org/springframework/graphql/execution/ExceptionResolversExceptionHandler.java @@ -16,7 +16,6 @@ package org.springframework.graphql.execution; import java.util.ArrayList; -import java.util.Collections; import java.util.List; import java.util.concurrent.CompletionException; @@ -26,11 +25,14 @@ import graphql.execution.DataFetcherExceptionHandler; import graphql.execution.DataFetcherExceptionHandlerParameters; import graphql.execution.DataFetcherExceptionHandlerResult; import graphql.schema.DataFetchingEnvironment; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; import reactor.core.publisher.Flux; -import reactor.core.scheduler.Schedulers; +import reactor.core.publisher.Mono; import reactor.util.context.ContextView; import org.springframework.util.Assert; +import org.springframework.web.client.ExtractingResponseErrorHandler; /** * {@link DataFetcherExceptionHandler} that invokes {@link DataFetcherExceptionResolver}'s @@ -38,6 +40,9 @@ import org.springframework.util.Assert; */ class ExceptionResolversExceptionHandler implements DataFetcherExceptionHandler { + private static Log logger = LogFactory.getLog(ExtractingResponseErrorHandler.class); + + private final List resolvers; @@ -60,24 +65,35 @@ class ExceptionResolversExceptionHandler implements DataFetcherExceptionHandler @SuppressWarnings("ConstantConditions") public DataFetcherExceptionHandlerResult invokeChain(Throwable ex, DataFetchingEnvironment env) { - return Flux.fromIterable(this.resolvers) - .publishOn(Schedulers.boundedElastic()) // until GraphQL Java supports async exception handling - .flatMap(resolver -> resolver.resolveException(ex, env)) - .next() - .defaultIfEmpty(Collections.singletonList(applyDefaultHandling(ex, env))) - .map(errors -> DataFetcherExceptionHandlerResult.newResult().errors(errors).build()) - .contextWrite(context -> { - ContextView contextToAdd = ContextManager.getReactorContext(env); - return (contextToAdd != null ? context.putAll(contextToAdd) : context); - }) - .block(); + // For now we have to block: + // https://github.com/graphql-java/graphql-java/issues/2356 + try { + return Flux.fromIterable(this.resolvers) + .flatMap(resolver -> resolver.resolveException(ex, env)) + .next() + .map(errors -> DataFetcherExceptionHandlerResult.newResult().errors(errors).build()) + .switchIfEmpty(Mono.fromCallable(() -> applyDefaultHandling(ex, env))) + .contextWrite(context -> { + ContextView contextView = ContextManager.getReactorContext(env); + return (contextView.isEmpty() ? context : context.putAll(contextView)); + }) + .toFuture() + .get(); + } + catch (Exception ex2) { + if (logger.isWarnEnabled()) { + logger.warn("Failed to handle " + ex.getMessage(), ex2); + } + return applyDefaultHandling(ex, env); + } } - private GraphQLError applyDefaultHandling(Throwable ex, DataFetchingEnvironment env) { - return GraphqlErrorBuilder.newError(env) + private DataFetcherExceptionHandlerResult applyDefaultHandling(Throwable ex, DataFetchingEnvironment env) { + GraphQLError error = GraphqlErrorBuilder.newError(env) .message(ex.getMessage()) .errorType(ErrorType.INTERNAL_ERROR) .build(); + return DataFetcherExceptionHandlerResult.newResult(error).build(); } } diff --git a/spring-graphql/src/main/java/org/springframework/graphql/execution/SyncDataFetcherExceptionResolver.java b/spring-graphql/src/main/java/org/springframework/graphql/execution/SyncDataFetcherExceptionResolver.java new file mode 100644 index 00000000..4f30fbe5 --- /dev/null +++ b/spring-graphql/src/main/java/org/springframework/graphql/execution/SyncDataFetcherExceptionResolver.java @@ -0,0 +1,49 @@ +/* + * Copyright 2002-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.graphql.execution; + +import java.util.List; + +import graphql.GraphQLError; +import graphql.schema.DataFetchingEnvironment; +import reactor.core.publisher.Mono; +import reactor.util.context.ContextView; + +/** + * {@link DataFetcherExceptionResolver} that resolves exceptions synchronously. + */ +public interface SyncDataFetcherExceptionResolver extends DataFetcherExceptionResolver { + + @Override + default Mono> resolveException(Throwable exception, DataFetchingEnvironment environment) { + ContextView contextView = ContextManager.getReactorContext(environment); + try { + ContextManager.restoreThreadLocalValues(contextView); + return Mono.just(doResolveException(exception, environment)); + } + finally { + ContextManager.resetThreadLocalValues(contextView); + } + } + + /** + * Implement this method to resolve exceptions. + * @param exception the exception to resolve + * @param environment the environment for the invoked {@code DataFetcher} + */ + List doResolveException(Throwable exception, DataFetchingEnvironment environment); + +} diff --git a/spring-graphql/src/main/java/org/springframework/graphql/execution/ThreadLocalAccessor.java b/spring-graphql/src/main/java/org/springframework/graphql/execution/ThreadLocalAccessor.java new file mode 100644 index 00000000..59297e94 --- /dev/null +++ b/spring-graphql/src/main/java/org/springframework/graphql/execution/ThreadLocalAccessor.java @@ -0,0 +1,66 @@ +/* + * Copyright 2002-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.graphql.execution; + +import java.util.List; +import java.util.Map; + +import org.springframework.beans.factory.ObjectProvider; + +/** + * Interface to be implemented by a framework or an application in order to + * assist with extracting ThreadLocal values at the web layer, which can then be + * re-established for DataFetcher's that are potentially executing on a + * different thread. + * + *

Implementations may be declared as beans in Spring configuration and + * ordered as defined in {@link ObjectProvider#orderedStream()}. + */ +public interface ThreadLocalAccessor { + + /** + * Extract ThreadLocal values and add them to the given Map which is then + * passed to {@link #restoreValues(Map)} and {@link #resetValues(Map)} + * before and after the execution of a {@link graphql.schema.DataFetcher}. + * @param container container for ThreadLocal values + */ + void extractValues(Map container); + + /** + * Re-establish ThreadLocal context by looking up values, previously + * extracted via {@link #extractValues(Map)}. + * @param values the saved ThreadLocal values + */ + void restoreValues(Map values); + + /** + * Reset ThreadLocal context for the given values, previously extracted + * via {@link #extractValues(Map)}. + * @param values the saved ThreadLocal values + */ + void resetValues(Map values); + + + /** + * Create a composite accessor that delegates to all of the given accessors. + * @param accessors the accessors to aggregate + * @return the composite accessor + */ + static ThreadLocalAccessor composite(List accessors) { + return new CompositeThreadLocalAccessor(accessors); + } + +} diff --git a/spring-graphql/src/main/java/org/springframework/graphql/web/DefaultWebGraphQlHandlerBuilder.java b/spring-graphql/src/main/java/org/springframework/graphql/web/DefaultWebGraphQlHandlerBuilder.java index cae299b2..d86bde2e 100644 --- a/spring-graphql/src/main/java/org/springframework/graphql/web/DefaultWebGraphQlHandlerBuilder.java +++ b/spring-graphql/src/main/java/org/springframework/graphql/web/DefaultWebGraphQlHandlerBuilder.java @@ -20,10 +20,15 @@ import java.util.Collections; import java.util.List; import graphql.ExecutionInput; +import reactor.core.publisher.Mono; +import reactor.util.context.ContextView; import org.springframework.graphql.GraphQlService; +import org.springframework.graphql.execution.ContextManager; +import org.springframework.graphql.execution.ThreadLocalAccessor; import org.springframework.lang.Nullable; import org.springframework.util.Assert; +import org.springframework.util.CollectionUtils; /** * Default implementation of {@link WebGraphQlHandler.Builder}. @@ -35,17 +40,31 @@ class DefaultWebGraphQlHandlerBuilder implements WebGraphQlHandler.Builder { @Nullable private List interceptors; + @Nullable + private List accessors; + DefaultWebGraphQlHandlerBuilder(GraphQlService service) { - Assert.notNull(service, "GraphQlService must not be null"); + Assert.notNull(service, "GraphQlService is required"); this.service = service; } @Override public WebGraphQlHandler.Builder interceptors(List interceptors) { - this.interceptors = (this.interceptors != null ? this.interceptors : new ArrayList<>()); - this.interceptors.addAll(interceptors); + if (!CollectionUtils.isEmpty(interceptors)) { + this.interceptors = (this.interceptors != null ? this.interceptors : new ArrayList<>()); + this.interceptors.addAll(interceptors); + } + return this; + } + + @Override + public WebGraphQlHandler.Builder threadLocalAccessors(List accessors) { + if (!CollectionUtils.isEmpty(accessors)) { + this.accessors = (this.accessors != null ? this.accessors : new ArrayList<>()); + this.accessors.addAll(accessors); + } return this; } @@ -54,10 +73,13 @@ class DefaultWebGraphQlHandlerBuilder implements WebGraphQlHandler.Builder { List interceptorsToUse = (this.interceptors != null ? this.interceptors : Collections.emptyList()); - return interceptorsToUse.stream() + WebGraphQlHandler handler = interceptorsToUse.stream() .reduce(WebInterceptor::andThen) .map(interceptor -> (WebGraphQlHandler) input -> interceptor.intercept(input, createHandler())) .orElse(createHandler()); + + return (CollectionUtils.isEmpty(this.accessors) ? handler : + new ThreadLocalExtractingHandler(handler, ThreadLocalAccessor.composite(this.accessors))); } private WebGraphQlHandler createHandler() { @@ -67,4 +89,30 @@ class DefaultWebGraphQlHandlerBuilder implements WebGraphQlHandler.Builder { }; } + + /** + * {@link WebGraphQlHandler} that extracts ThreadLocal values and saves them + * in the Reactor context for subsequent use for DataFetcher's. + */ + private static class ThreadLocalExtractingHandler implements WebGraphQlHandler { + + private final WebGraphQlHandler delegate; + + private final ThreadLocalAccessor accessor; + + ThreadLocalExtractingHandler(WebGraphQlHandler delegate, ThreadLocalAccessor accessor) { + this.delegate = delegate; + this.accessor = accessor; + } + + @Override + public Mono handle(WebInput input) { + return this.delegate.handle(input) + .contextWrite(context -> { + ContextView view = ContextManager.extractThreadLocalValues(this.accessor); + return (!view.isEmpty() ? context.putAll(view) : context); + }); + } + } + } diff --git a/spring-graphql/src/main/java/org/springframework/graphql/web/WebGraphQlHandler.java b/spring-graphql/src/main/java/org/springframework/graphql/web/WebGraphQlHandler.java index 808b6bf3..fe732f7f 100644 --- a/spring-graphql/src/main/java/org/springframework/graphql/web/WebGraphQlHandler.java +++ b/spring-graphql/src/main/java/org/springframework/graphql/web/WebGraphQlHandler.java @@ -20,12 +20,11 @@ import java.util.List; import reactor.core.publisher.Mono; import org.springframework.graphql.GraphQlService; +import org.springframework.graphql.execution.ThreadLocalAccessor; /** * Contract to handle a GraphQL over HTTP or WebSocket request that forms the * basis of a {@link WebInterceptor} delegation chain. - * - * @see WebInterceptor#createHandler(List, GraphQlService) */ public interface WebGraphQlHandler { @@ -37,4 +36,41 @@ public interface WebGraphQlHandler { */ Mono handle(WebInput input); + + /** + * Provides access to a builder to create a {@link WebGraphQlHandler} instance. + * @param graphQlService the {@link GraphQlService} to use for actual + * execution of the request. + */ + static Builder builder(GraphQlService graphQlService) { + return new DefaultWebGraphQlHandlerBuilder(graphQlService); + } + + + /** + * Builder for {@link WebGraphQlHandler} that represents a {@link WebInterceptor} + * chain followed by a {@link GraphQlService}. + */ + interface Builder { + + /** + * Configure interceptors to be invoked before the target {@code GraphQlService}. + * @param interceptors the interceptors to add + */ + Builder interceptors(List interceptors); + + /** + * Configure accessors for ThreadLocal variables to use to extract + * ThreadLocal values at the Web framework level, have those propagated + * and re-established at the DataFetcher level. + * @param accessors the accessors to add + */ + Builder threadLocalAccessors(List accessors); + + /** + * Build the {@link WebGraphQlHandler} instance. + */ + WebGraphQlHandler build(); + } + } diff --git a/spring-graphql/src/main/java/org/springframework/graphql/web/WebInterceptor.java b/spring-graphql/src/main/java/org/springframework/graphql/web/WebInterceptor.java index de855a81..203d1114 100644 --- a/spring-graphql/src/main/java/org/springframework/graphql/web/WebInterceptor.java +++ b/spring-graphql/src/main/java/org/springframework/graphql/web/WebInterceptor.java @@ -15,14 +15,11 @@ */ package org.springframework.graphql.web; -import java.util.List; - import graphql.ExecutionInput; import graphql.ExecutionResult; import reactor.core.publisher.Mono; import org.springframework.beans.factory.ObjectProvider; -import org.springframework.graphql.GraphQlService; import org.springframework.util.Assert; /** @@ -58,38 +55,4 @@ public interface WebInterceptor { return (currentInput, next) -> intercept(currentInput, nextInput -> interceptor.intercept(nextInput, next)); } - /** - * Return {@link WebGraphQlHandler} that invokes the current interceptor - * first and then the given {@link GraphQlService} for actual execution of - * the GraphQL operation. - */ - default WebGraphQlHandler apply(GraphQlService service) { - Assert.notNull(service, "GraphQlService must not be null"); - return currentInput -> intercept(currentInput, createHandler(service)); - } - - - /** - * Factory method for a {@link WebGraphQlHandler} with a chain of - * interceptors followed by a {@link GraphQlService} at the end. - */ - static WebGraphQlHandler createHandler(List interceptors, GraphQlService service) { - return interceptors.stream() - .reduce(WebInterceptor::andThen) - .map(interceptor -> interceptor.apply(service)) - .orElse(createHandler(service)); - } - - /** - * Factory method for a {@link WebGraphQlHandler} that simple invokes the - * given {@link GraphQlService} adapting to its input and output. - */ - static WebGraphQlHandler createHandler(GraphQlService graphQlService) { - Assert.notNull(graphQlService, "GraphQlService must not be null"); - return webInput -> { - ExecutionInput executionInput = webInput.toExecutionInput(); - return graphQlService.execute(executionInput).map(result -> new WebOutput(webInput, result)); - }; - } - } \ No newline at end of file diff --git a/spring-graphql/src/test/java/org/springframework/graphql/GraphQlTestUtils.java b/spring-graphql/src/test/java/org/springframework/graphql/GraphQlTestUtils.java new file mode 100644 index 00000000..d56becaf --- /dev/null +++ b/spring-graphql/src/test/java/org/springframework/graphql/GraphQlTestUtils.java @@ -0,0 +1,64 @@ +/* + * Copyright 2002-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.graphql; + +import java.nio.charset.StandardCharsets; +import java.util.Arrays; + +import graphql.GraphQL; +import graphql.schema.DataFetcher; +import graphql.schema.idl.RuntimeWiring; + +import org.springframework.core.io.ByteArrayResource; +import org.springframework.graphql.execution.DataFetcherExceptionResolver; +import org.springframework.graphql.execution.GraphQlSource; + +/** + * Utility methods for GraphQL tests. + */ +public abstract class GraphQlTestUtils { + + public static GraphQL initGraphQl( + String schemaContent, String typeName, String fieldName, DataFetcher fetcher) { + + return initGraphQlSource(schemaContent, typeName, fieldName, fetcher) + .build() + .graphQl(); + } + + public static GraphQL initGraphQl( + String schemaContent, String typeName, String fieldName, DataFetcher fetcher, + DataFetcherExceptionResolver... resolvers) { + + return initGraphQlSource(schemaContent, typeName, fieldName, fetcher) + .exceptionResolvers(Arrays.asList(resolvers)) + .build() + .graphQl(); + } + + public static GraphQlSource.Builder initGraphQlSource( + String schemaContent, String typeName, String fieldName, DataFetcher fetcher) { + + RuntimeWiring wiring = RuntimeWiring.newRuntimeWiring() + .type(typeName, builder -> builder.dataFetcher(fieldName, fetcher)) + .build(); + + return GraphQlSource.builder() + .schemaResource(new ByteArrayResource(schemaContent.getBytes(StandardCharsets.UTF_8))) + .runtimeWiring(wiring); + } + +} diff --git a/spring-graphql/src/test/java/org/springframework/graphql/execution/ReactorDataFetcherAdapterTests.java b/spring-graphql/src/test/java/org/springframework/graphql/execution/ContextDataFetcherDecoratorTests.java similarity index 60% rename from spring-graphql/src/test/java/org/springframework/graphql/execution/ReactorDataFetcherAdapterTests.java rename to spring-graphql/src/test/java/org/springframework/graphql/execution/ContextDataFetcherDecoratorTests.java index a45a2b6c..8b52ad17 100644 --- a/spring-graphql/src/test/java/org/springframework/graphql/execution/ReactorDataFetcherAdapterTests.java +++ b/spring-graphql/src/test/java/org/springframework/graphql/execution/ContextDataFetcherDecoratorTests.java @@ -15,7 +15,6 @@ */ package org.springframework.graphql.execution; -import java.nio.charset.StandardCharsets; import java.time.Duration; import java.util.List; import java.util.Map; @@ -23,33 +22,34 @@ import java.util.Map; import graphql.ExecutionInput; import graphql.ExecutionResult; import graphql.GraphQL; -import graphql.schema.DataFetcher; -import graphql.schema.idl.RuntimeWiring; import org.junit.jupiter.api.Test; import org.reactivestreams.Publisher; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; import reactor.util.context.Context; +import reactor.util.context.ContextView; -import org.springframework.core.io.ByteArrayResource; +import org.springframework.graphql.GraphQlTestUtils; import static org.assertj.core.api.Assertions.assertThat; /** - * Tests for {@link ReactorDataFetcherAdapter}. + * Tests for {@link ContextDataFetcherDecorator}. */ -public class ReactorDataFetcherAdapterTests { +public class ContextDataFetcherDecoratorTests { @Test void monoDataFetcher() throws Exception { - GraphQL graphQl = graphQl("type Query { greeting: String }", + GraphQL graphQl = GraphQlTestUtils.initGraphQl("type Query { greeting: String }", "Query", "greeting", env -> Mono.deferContextual(context -> { Object name = context.get("name"); return Mono.delay(Duration.ofMillis(50)).map(aLong -> "Hello " + name); })); - ExecutionInput input = executionInput("{ greeting }", Context.of("name", "007")); + ExecutionInput input = ExecutionInput.newExecutionInput().query("{ greeting }").build(); + ContextManager.setReactorContext(Context.of("name", "007"), input); + Map data = graphQl.executeAsync(input).get().getData(); assertThat(data).hasSize(1).containsEntry("greeting", "Hello 007"); @@ -57,7 +57,7 @@ public class ReactorDataFetcherAdapterTests { @Test void fluxDataFetcher() throws Exception { - GraphQL graphQl = graphQl("type Query { greetings: [String] }", + GraphQL graphQl = GraphQlTestUtils.initGraphQl("type Query { greetings: [String] }", "Query", "greetings", env -> Mono.delay(Duration.ofMillis(50)).flatMapMany(aLong -> Flux.deferContextual(context -> { @@ -65,7 +65,9 @@ public class ReactorDataFetcherAdapterTests { return Flux.just("Hi", "Bonjour", "Hola").map(s -> s + " " + name); }))); - ExecutionInput input = executionInput("{ greetings }", Context.of("name", "007")); + ExecutionInput input = ExecutionInput.newExecutionInput().query("{ greetings }").build(); + ContextManager.setReactorContext(Context.of("name", "007"), input); + Map data = graphQl.executeAsync(input).get().getData(); assertThat((List) data.get("greetings")).containsExactly("Hi 007", "Bonjour 007", "Hola 007"); @@ -73,7 +75,7 @@ public class ReactorDataFetcherAdapterTests { @Test void fluxDataFetcherSubscription() throws Exception { - GraphQL graphQl = graphQl( + GraphQL graphQl = GraphQlTestUtils.initGraphQl( "type Query { greeting: String } type Subscription { greetings: String }", "Subscription", "greetings", env -> Mono.delay(Duration.ofMillis(50)).flatMapMany(aLong -> @@ -82,7 +84,9 @@ public class ReactorDataFetcherAdapterTests { return Flux.just("Hi", "Bonjour", "Hola").map(s -> s + " " + name); }))); - ExecutionInput input = executionInput("subscription { greetings }", Context.of("name", "007")); + ExecutionInput input = ExecutionInput.newExecutionInput().query("subscription { greetings }").build(); + ContextManager.setReactorContext(Context.of("name", "007"), input); + Publisher publisher = graphQl.executeAsync(input).get().getData(); List actual = Flux.from(publisher) @@ -95,21 +99,32 @@ public class ReactorDataFetcherAdapterTests { assertThat(actual).containsExactly("Hi 007", "Bonjour 007", "Hola 007"); } - private GraphQL graphQl(String schemaValue, String typeName, String fieldName, DataFetcher dataFetcher) { - RuntimeWiring wiring = RuntimeWiring.newRuntimeWiring() - .type(typeName, builder -> builder.dataFetcher(fieldName, dataFetcher)) - .build(); - return GraphQlSource.builder() - .schemaResource(new ByteArrayResource(schemaValue.getBytes(StandardCharsets.UTF_8))) - .runtimeWiring(wiring) - .build() - .graphQl(); - } + @Test + void dataFetcherWithThreadLocalContext() { + long threadId = Thread.currentThread().getId(); + ThreadLocal nameThreadLocal = new ThreadLocal<>(); + nameThreadLocal.set("007"); + try { + GraphQL graphQl = GraphQlTestUtils.initGraphQl("type Query { greeting: String }", + "Query", "greeting", env -> { + assertThat(Thread.currentThread().getId() != threadId).as("Not on async thread").isTrue(); + return "Hello " + nameThreadLocal.get(); + }); - private ExecutionInput executionInput(String query, Context reactorContext) { - ExecutionInput input = ExecutionInput.newExecutionInput().query(query).build(); - ContextManager.setReactorContext(reactorContext, input); - return input; + ExecutionInput input = ExecutionInput.newExecutionInput().query("{ greeting }").build(); + ContextView view = ContextManager.extractThreadLocalValues(new TestThreadLocalAccessor(nameThreadLocal)); + ContextManager.setReactorContext(view, input); + + ExecutionResult result = Mono.delay(Duration.ofMillis(10)) + .flatMap(aLong -> Mono.fromFuture(graphQl.executeAsync(input))) + .block(); + + Map data = result.getData(); + assertThat(data).hasSize(1).containsEntry("greeting", "Hello 007"); + } + finally { + nameThreadLocal.remove(); + } } } diff --git a/spring-graphql/src/test/java/org/springframework/graphql/execution/ExceptionResolversExceptionHandlerTests.java b/spring-graphql/src/test/java/org/springframework/graphql/execution/ExceptionResolversExceptionHandlerTests.java index ae46c117..2e0dff09 100644 --- a/spring-graphql/src/test/java/org/springframework/graphql/execution/ExceptionResolversExceptionHandlerTests.java +++ b/spring-graphql/src/test/java/org/springframework/graphql/execution/ExceptionResolversExceptionHandlerTests.java @@ -15,8 +15,7 @@ */ package org.springframework.graphql.execution; -import java.nio.charset.StandardCharsets; -import java.util.Arrays; +import java.time.Duration; import java.util.Collections; import java.util.List; import java.util.Map; @@ -26,14 +25,12 @@ import graphql.ExecutionResult; import graphql.GraphQL; import graphql.GraphQLError; import graphql.GraphqlErrorBuilder; -import graphql.schema.DataFetcher; -import graphql.schema.DataFetchingEnvironment; -import graphql.schema.idl.RuntimeWiring; import org.junit.jupiter.api.Test; import reactor.core.publisher.Mono; import reactor.util.context.Context; +import reactor.util.context.ContextView; -import org.springframework.core.io.ByteArrayResource; +import org.springframework.graphql.GraphQlTestUtils; import static org.assertj.core.api.Assertions.assertThat; @@ -44,23 +41,17 @@ public class ExceptionResolversExceptionHandlerTests { @Test void resolveException() throws Exception { - GraphQL graphQl = graphQl("type Query { greeting: String }", + GraphQL graphQl = GraphQlTestUtils.initGraphQl("type Query { greeting: String }", "Query", "greeting", env -> { throw new IllegalArgumentException("Invalid greeting"); }, - new SingleErrorExceptionResolver() { - @Override - protected Mono doResolve(Throwable ex, DataFetchingEnvironment env) { - return Mono.deferContextual(view -> - Mono.just(GraphqlErrorBuilder.newError(env) - .message("Resolved error: " + ex.getMessage() + ", name=" + view.get("name")) - .errorType(ErrorType.BAD_REQUEST) - .build())); - }}); + (ex, env) -> Mono.just(Collections.singletonList( + GraphqlErrorBuilder.newError(env) + .message("Resolved error: " + ex.getMessage()) + .errorType(ErrorType.BAD_REQUEST) + .build()))); ExecutionInput input = ExecutionInput.newExecutionInput().query("{ greeting }").build(); - ContextManager.setReactorContext(Context.of("name", "007"), input); - ExecutionResult result = graphQl.executeAsync(input).get(); Map data = result.getData(); @@ -69,13 +60,69 @@ public class ExceptionResolversExceptionHandlerTests { List errors = result.getErrors(); assertThat(errors).hasSize(1); GraphQLError error = errors.get(0); - assertThat(error.getMessage()).isEqualTo("Resolved error: Invalid greeting, name=007"); + assertThat(error.getMessage()).isEqualTo("Resolved error: Invalid greeting"); assertThat(error.getErrorType().toString()).isEqualTo("BAD_REQUEST"); } + @Test + void resolveExceptionWithReactorContext() throws Exception { + GraphQL graphQl = GraphQlTestUtils.initGraphQl("type Query { greeting: String }", + "Query", "greeting", env -> { + throw new IllegalArgumentException("Invalid greeting"); + }, + (ex, env) -> Mono.deferContextual(view -> Mono.just(Collections.singletonList( + GraphqlErrorBuilder.newError(env) + .message("Resolved error: " + ex.getMessage() + ", name=" + view.get("name")) + .errorType(ErrorType.BAD_REQUEST) + .build())))); + + ExecutionInput input = ExecutionInput.newExecutionInput().query("{ greeting }").build(); + ContextManager.setReactorContext(Context.of("name", "007"), input); + + ExecutionResult result = graphQl.executeAsync(input).get(); + + GraphQLError error = result.getErrors().get(0); + assertThat(error.getMessage()).isEqualTo("Resolved error: Invalid greeting, name=007"); + } + + @Test + void resolveExceptionWithThreadLocal() { + long threadId = Thread.currentThread().getId(); + ThreadLocal nameThreadLocal = new ThreadLocal<>(); + nameThreadLocal.set("007"); + try { + GraphQL graphQl = GraphQlTestUtils.initGraphQl("type Query { greeting: String }", + "Query", "greeting", env -> { + throw new IllegalArgumentException("Invalid greeting"); + }, + (SyncDataFetcherExceptionResolver) (ex, env) -> { + assertThat(Thread.currentThread().getId() != threadId).as("Not on async thread").isTrue(); + return Collections.singletonList( + GraphqlErrorBuilder.newError(env) + .message("Resolved error: " + ex.getMessage() + ", name=" + nameThreadLocal.get()) + .errorType(ErrorType.BAD_REQUEST) + .build()); + }); + + ExecutionInput input = ExecutionInput.newExecutionInput().query("{ greeting }").build(); + ContextView view = ContextManager.extractThreadLocalValues(new TestThreadLocalAccessor(nameThreadLocal)); + ContextManager.setReactorContext(view, input); + + ExecutionResult result = Mono.delay(Duration.ofMillis(10)) + .flatMap(aLong -> Mono.fromFuture(graphQl.executeAsync(input))) + .block(); + + GraphQLError error = result.getErrors().get(0); + assertThat(error.getMessage()).isEqualTo("Resolved error: Invalid greeting, name=007"); + } + finally { + nameThreadLocal.remove(); + } + } + @Test void unresolvedException() throws Exception { - GraphQL graphQl = graphQl("type Query { greeting: String }", + GraphQL graphQl = GraphQlTestUtils.initGraphQl("type Query { greeting: String }", "Query", "greeting", env -> { throw new IllegalArgumentException("Invalid greeting"); }, @@ -96,7 +143,7 @@ public class ExceptionResolversExceptionHandlerTests { @Test void suppressedException() throws Exception { - GraphQL graphQl = graphQl("type Query { greeting: String }", + GraphQL graphQl = GraphQlTestUtils.initGraphQl("type Query { greeting: String }", "Query", "greeting", env -> { throw new IllegalArgumentException("Invalid greeting"); }, @@ -110,20 +157,4 @@ public class ExceptionResolversExceptionHandlerTests { assertThat(result.getErrors()).hasSize(0); } - private GraphQL graphQl(String schemaContent, - String typeName, String fieldName, DataFetcher dataFetcher, - DataFetcherExceptionResolver... resolvers) { - - RuntimeWiring wiring = RuntimeWiring.newRuntimeWiring() - .type(typeName, builder -> builder.dataFetcher(fieldName, dataFetcher)) - .build(); - - return GraphQlSource.builder() - .schemaResource(new ByteArrayResource(schemaContent.getBytes(StandardCharsets.UTF_8))) - .runtimeWiring(wiring) - .exceptionResolvers(Arrays.asList(resolvers)) - .build() - .graphQl(); - } - } diff --git a/spring-graphql/src/test/java/org/springframework/graphql/execution/TestThreadLocalAccessor.java b/spring-graphql/src/test/java/org/springframework/graphql/execution/TestThreadLocalAccessor.java new file mode 100644 index 00000000..4937003d --- /dev/null +++ b/spring-graphql/src/test/java/org/springframework/graphql/execution/TestThreadLocalAccessor.java @@ -0,0 +1,53 @@ +/* + * Copyright 2002-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.graphql.execution; + +import java.util.Map; + +import org.springframework.util.Assert; + +/** + * {@link ThreadLocalAccessor} that operates on the ThreadLocal it is given. + */ +class TestThreadLocalAccessor implements ThreadLocalAccessor { + + private final ThreadLocal threadLocal; + + + TestThreadLocalAccessor(ThreadLocal threadLocal) { + this.threadLocal = threadLocal; + } + + + @Override + public void extractValues(Map container) { + String name = this.threadLocal.get(); + Assert.notNull(name, "No ThreadLocal value"); + container.put("name", name); + } + + @Override + public void restoreValues(Map values) { + String name = (String) values.get("name"); + Assert.notNull(name, "No value to set"); + this.threadLocal.set(name); + } + + @Override + public void resetValues(Map values) { + this.threadLocal.remove(); + } +}