ThreadLocal context propagation for DataFetcher's

The WebMvc starter now supports propagation of ThreadLocal values
extracted at the level of the HTTP handler.

Closes gh-53
This commit is contained in:
Rossen Stoyanchev
2021-05-28 16:26:03 +01:00
parent ca18d03244
commit 56917d8d83
26 changed files with 757 additions and 150 deletions

View File

@@ -40,6 +40,7 @@ import org.springframework.core.io.Resource;
import org.springframework.core.io.ResourceLoader;
import org.springframework.graphql.GraphQlService;
import org.springframework.graphql.execution.GraphQlSource;
import org.springframework.graphql.execution.ThreadLocalAccessor;
import org.springframework.graphql.web.WebGraphQlHandler;
import org.springframework.graphql.web.WebInterceptor;
import org.springframework.graphql.web.webmvc.GraphQlHttpHandler;
@@ -70,9 +71,13 @@ public class WebMvcGraphQlAutoConfiguration {
@Bean
@ConditionalOnMissingBean
public WebGraphQlHandler webGraphQlHandler(ObjectProvider<WebInterceptor> interceptors, GraphQlService service) {
public WebGraphQlHandler webGraphQlHandler(
ObjectProvider<WebInterceptor> interceptorsProvider, GraphQlService service,
ObjectProvider<ThreadLocalAccessor> accessorsProvider) {
return WebGraphQlHandler.builder(service)
.interceptors(interceptors.orderedStream().collect(Collectors.toList()))
.interceptors(interceptorsProvider.orderedStream().collect(Collectors.toList()))
.threadLocalAccessors(accessorsProvider.orderedStream().collect(Collectors.toList()))
.build();
}

View File

@@ -23,9 +23,9 @@ import org.springframework.web.server.WebFilterChain;
/**
* WebFilter that inserts a key-value pair into the Reactor context which is
* transferred to and accessible to Reactor-based data fetchers.
* transferred to and accessible in Reactor data fetchers.
*/
public class ReactorContextWebFilter implements WebFilter {
public class ContextWebFilter implements WebFilter {
@Override
public Mono<Void> filter(ServerWebExchange exchange, WebFilterChain chain) {

View File

@@ -28,8 +28,8 @@ public class SampleApplication {
}
@Bean
ReactorContextWebFilter reactorContextWebFilter() {
return new ReactorContextWebFilter();
ContextWebFilter reactorContextWebFilter() {
return new ContextWebFilter();
}
}

View File

@@ -1,5 +1,5 @@
/*
* Copyright 2002-2020 the original author or authors.
* Copyright 2002-2021 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -21,7 +21,9 @@ import org.springframework.boot.autoconfigure.SpringBootApplication;
@SpringBootApplication
public class SampleApplication {
public static void main(String[] args) {
SpringApplication.run(SampleApplication.class, args);
}
}

View File

@@ -0,0 +1,24 @@
package io.spring.sample.graphql.greeting;
import graphql.schema.idl.RuntimeWiring;
import org.springframework.graphql.boot.RuntimeWiringCustomizer;
import org.springframework.stereotype.Component;
import org.springframework.web.context.request.RequestAttributes;
import org.springframework.web.context.request.RequestContextHolder;
import static org.springframework.web.context.request.RequestAttributes.SCOPE_REQUEST;
@Component
public class GreetingDataWiring implements RuntimeWiringCustomizer {
@Override
public void customize(RuntimeWiring.Builder builder) {
builder.type("Query", typeWiring ->
typeWiring.dataFetcher("greeting", env -> {
RequestAttributes attributes = RequestContextHolder.getRequestAttributes();
return "Hello " + attributes.getAttribute(RequestAttributeFilter.NAME_ATTRIBUTE, SCOPE_REQUEST);
}));
}
}

View File

@@ -0,0 +1,43 @@
/*
* Copyright 2002-2021 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.spring.sample.graphql.greeting;
import java.io.IOException;
import javax.servlet.Filter;
import javax.servlet.FilterChain;
import javax.servlet.ServletException;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import org.springframework.stereotype.Component;
/**
* Servlet Filter that adds a Servlet request attribute.
*/
@Component
public class RequestAttributeFilter implements Filter {
public static final String NAME_ATTRIBUTE = RequestAttributeFilter.class.getName() + ".name";
@Override
public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain) throws IOException, ServletException {
request.setAttribute(NAME_ATTRIBUTE, "007");
chain.doFilter(request, response);
}
}

View File

@@ -0,0 +1,60 @@
/*
* Copyright 2002-2021 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.spring.sample.graphql.greeting;
import java.util.Collections;
import java.util.Map;
import org.springframework.graphql.execution.ThreadLocalAccessor;
import org.springframework.stereotype.Component;
import org.springframework.web.context.request.RequestAttributes;
import org.springframework.web.context.request.RequestContextHolder;
/**
* {@link ThreadLocalAccessor} to expose a thread-bound RequestAttributes object
* to data fetchers in Spring GraphQL.
*/
@Component
public class RequestAttributesAccessor implements ThreadLocalAccessor {
private static final String ATTRIBUTES_KEY =
RequestAttributesAccessor.class.getName() + ".requestAttributes";
@Override
public void extractValues(Map<String, Object> container) {
RequestAttributes attributes = RequestContextHolder.getRequestAttributes();
if (attributes != null) {
container.put(ATTRIBUTES_KEY, attributes);
}
}
@Override
public void restoreValues(Map<String, Object> values) {
RequestAttributes attributes = (RequestAttributes) values.get(ATTRIBUTES_KEY);
if (attributes != null) {
RequestContextHolder.setRequestAttributes(attributes);
}
}
@Override
public void resetValues(Map<String, Object> values) {
if (values.get(ATTRIBUTES_KEY) != null) {
RequestContextHolder.resetRequestAttributes();
}
}
}

View File

@@ -0,0 +1,6 @@
@NonNullApi
@NonNullFields
package io.spring.sample.graphql.greeting;
import org.springframework.lang.NonNullApi;
import org.springframework.lang.NonNullFields;

View File

@@ -0,0 +1,6 @@
@NonNullApi
@NonNullFields
package io.spring.sample.graphql;
import org.springframework.lang.NonNullApi;
import org.springframework.lang.NonNullFields;

View File

@@ -0,0 +1,6 @@
@NonNullApi
@NonNullFields
package io.spring.sample.graphql.project;
import org.springframework.lang.NonNullApi;
import org.springframework.lang.NonNullFields;

View File

@@ -0,0 +1,6 @@
@NonNullApi
@NonNullFields
package io.spring.sample.graphql.repository;
import org.springframework.lang.NonNullApi;
import org.springframework.lang.NonNullFields;

View File

@@ -1,4 +1,5 @@
type Query {
greeting: String
artifactRepositories : [ArtifactRepository]
artifactRepository(id : ID!) : ArtifactRepository
project(slug: ID!): Project

View File

@@ -0,0 +1,49 @@
/*
* Copyright 2002-2021 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.graphql.execution;
import java.util.List;
import java.util.Map;
/**
* Default implementation of a composite accessor that is returned from
* {@link ThreadLocalAccessor#composite(List)}.
*/
class CompositeThreadLocalAccessor implements ThreadLocalAccessor {
private final List<ThreadLocalAccessor> accessors;
public CompositeThreadLocalAccessor(List<ThreadLocalAccessor> accessors) {
this.accessors = accessors;
}
@Override
public void extractValues(Map<String, Object> container) {
this.accessors.forEach(accessor -> accessor.extractValues(container));
}
@Override
public void restoreValues(Map<String, Object> values) {
this.accessors.forEach(accessor -> accessor.restoreValues(values));
}
@Override
public void resetValues(Map<String, Object> values) {
this.accessors.forEach(accessor -> accessor.resetValues(values));
}
}

View File

@@ -34,19 +34,22 @@ import reactor.util.context.ContextView;
import org.springframework.util.Assert;
/**
* Adapter that can wrap a registered {@link DataFetcher} and enable it to return
* {@link Flux} or {@link Mono}, also adding Reactor Context passed through
* the {@link ExecutionInput}. Also exposes a {@link #TYPE_VISITOR} to apply
* the adapter.
* Wrap a {@link DataFetcher} to enable the following:
* <ul>
* <li>Support {@link Mono} return value.
* <li>Support {@link Flux} return value as a shortcut to {@link Flux#collectList()}.
* <li>Re-establish Reactor Context passed via {@link ExecutionInput}.
* <li>Re-establish ThreadLocal context passed via {@link ExecutionInput}.
* </ul>
*/
class ReactorDataFetcherAdapter implements DataFetcher<Object> {
class ContextDataFetcherDecorator implements DataFetcher<Object> {
private final DataFetcher<?> delegate;
private final boolean subscription;
private ReactorDataFetcherAdapter(DataFetcher<?> delegate, boolean subscription) {
private ContextDataFetcherDecorator(DataFetcher<?> delegate, boolean subscription) {
Assert.notNull(delegate, "'delegate' DataFetcher is required");
this.delegate = delegate;
this.subscription = subscription;
@@ -55,11 +58,19 @@ class ReactorDataFetcherAdapter implements DataFetcher<Object> {
@Override
public Object get(DataFetchingEnvironment environment) throws Exception {
Object value = this.delegate.get(environment);
ContextView contextView = ContextManager.getReactorContext(environment);
Object value;
try {
ContextManager.restoreThreadLocalValues(contextView);
value = this.delegate.get(environment);
}
finally {
ContextManager.resetThreadLocalValues(contextView);
}
if (this.subscription) {
ContextView context = ContextManager.getReactorContext(environment);
return (context != null ? Flux.from((Publisher<?>) value).contextWrite(context) : value);
return (!contextView.isEmpty() ? Flux.from((Publisher<?>) value).contextWrite(contextView) : value);
}
if (value instanceof Flux) {
@@ -68,9 +79,8 @@ class ReactorDataFetcherAdapter implements DataFetcher<Object> {
if (value instanceof Mono) {
Mono<?> valueMono = (Mono<?>) value;
ContextView reactorContext = ContextManager.getReactorContext(environment);
if (reactorContext != null) {
valueMono = valueMono.contextWrite(reactorContext);
if (!contextView.isEmpty()) {
valueMono = valueMono.contextWrite(contextView);
}
value = valueMono.toFuture();
}
@@ -98,7 +108,7 @@ class ReactorDataFetcherAdapter implements DataFetcher<Object> {
}
boolean handlesSubscription = parent.getName().equals("Subscription");
dataFetcher = new ReactorDataFetcherAdapter(dataFetcher, handlesSubscription);
dataFetcher = new ContextDataFetcherDecorator(dataFetcher, handlesSubscription);
codeRegistry.dataFetcher(parent, fieldDefinition, dataFetcher);
return TraversalControl.CONTINUE;
}

View File

@@ -15,9 +15,13 @@
*/
package org.springframework.graphql.execution;
import java.util.LinkedHashMap;
import java.util.Map;
import graphql.ExecutionInput;
import graphql.GraphQLContext;
import graphql.schema.DataFetchingEnvironment;
import reactor.util.context.Context;
import reactor.util.context.ContextView;
import org.springframework.lang.Nullable;
@@ -27,10 +31,16 @@ import org.springframework.lang.Nullable;
* through the {@link ExecutionInput} and the {@link DataFetchingEnvironment}
* of a request.
*/
abstract class ContextManager {
public abstract class ContextManager {
private static final String REACTOR_CONTEXT_KEY =
ReactorDataFetcherAdapter.class.getName() + ".REACTOR_CONTEXT";
private static final String CONTEXT_VIEW_KEY =
ContextManager.class.getName() + ".CONTEXT_VIEW";
private static final String THREAD_LOCAL_VALUES_KEY =
ContextManager.class.getName() + ".THREAD_VALUES_ACCESSOR";
private static final String THREAD_LOCAL_ACCESSOR_KEY =
ContextManager.class.getName() + ".THREAD_LOCAL_ACCESSOR";
/**
@@ -38,17 +48,55 @@ abstract class ContextManager {
* later access through the {@link DataFetchingEnvironment}.
*/
static void setReactorContext(ContextView contextView, ExecutionInput input) {
((GraphQLContext) input.getContext()).put(REACTOR_CONTEXT_KEY, contextView);
((GraphQLContext) input.getContext()).put(CONTEXT_VIEW_KEY, contextView);
}
/**
* Return the Reactor ContextView saved in the given DataFetchingEnvironment,
* or null if not present.
* Return the Reactor ContextView saved in the given DataFetchingEnvironment.
*/
@Nullable
static ContextView getReactorContext(DataFetchingEnvironment environment) {
GraphQLContext graphQlContext = environment.getContext();
return graphQlContext.get(REACTOR_CONTEXT_KEY);
return graphQlContext.getOrDefault(CONTEXT_VIEW_KEY, Context.empty());
}
/**
* Use the given accessor to extract ThreadLocal value, and return a Reactor
* context that contains both the extracted values and the accessor.
* @param accessor the accessor to use
*/
public static ContextView extractThreadLocalValues(ThreadLocalAccessor accessor) {
Map<String, Object> valuesMap = new LinkedHashMap<>();
accessor.extractValues(valuesMap);
return Context.of(THREAD_LOCAL_VALUES_KEY, valuesMap, THREAD_LOCAL_ACCESSOR_KEY, accessor);
}
/**
* Look up saved ThreadLocal values and use them to re-establish ThreadLocal context.
*/
static void restoreThreadLocalValues(ContextView contextView) {
ThreadLocalAccessor accessor = getThreadLocalAccessor(contextView);
if (accessor != null) {
accessor.restoreValues(getThreadLocalValues(contextView));
}
}
/**
* Look up saved ThreadLocal values and remove associated ThreadLocal context.
*/
static void resetThreadLocalValues(ContextView contextView) {
ThreadLocalAccessor accessor = getThreadLocalAccessor(contextView);
if (accessor != null) {
accessor.resetValues(getThreadLocalValues(contextView));
}
}
@Nullable
private static ThreadLocalAccessor getThreadLocalAccessor(ContextView contextView) {
return (contextView.hasKey(THREAD_LOCAL_ACCESSOR_KEY) ? contextView.get(THREAD_LOCAL_ACCESSOR_KEY) : null);
}
private static Map<String, Object> getThreadLocalValues(ContextView contextView) {
return contextView.get(THREAD_LOCAL_VALUES_KEY);
}
}

View File

@@ -58,7 +58,7 @@ class DefaultGraphQlSourceBuilder implements GraphQlSource.Builder {
DefaultGraphQlSourceBuilder() {
this.typeVisitors.add(ReactorDataFetcherAdapter.TYPE_VISITOR);
this.typeVisitors.add(ContextDataFetcherDecorator.TYPE_VISITOR);
}

View File

@@ -16,7 +16,6 @@
package org.springframework.graphql.execution;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.CompletionException;
@@ -26,11 +25,14 @@ import graphql.execution.DataFetcherExceptionHandler;
import graphql.execution.DataFetcherExceptionHandlerParameters;
import graphql.execution.DataFetcherExceptionHandlerResult;
import graphql.schema.DataFetchingEnvironment;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import reactor.core.publisher.Flux;
import reactor.core.scheduler.Schedulers;
import reactor.core.publisher.Mono;
import reactor.util.context.ContextView;
import org.springframework.util.Assert;
import org.springframework.web.client.ExtractingResponseErrorHandler;
/**
* {@link DataFetcherExceptionHandler} that invokes {@link DataFetcherExceptionResolver}'s
@@ -38,6 +40,9 @@ import org.springframework.util.Assert;
*/
class ExceptionResolversExceptionHandler implements DataFetcherExceptionHandler {
private static Log logger = LogFactory.getLog(ExtractingResponseErrorHandler.class);
private final List<DataFetcherExceptionResolver> resolvers;
@@ -60,24 +65,35 @@ class ExceptionResolversExceptionHandler implements DataFetcherExceptionHandler
@SuppressWarnings("ConstantConditions")
public DataFetcherExceptionHandlerResult invokeChain(Throwable ex, DataFetchingEnvironment env) {
return Flux.fromIterable(this.resolvers)
.publishOn(Schedulers.boundedElastic()) // until GraphQL Java supports async exception handling
.flatMap(resolver -> resolver.resolveException(ex, env))
.next()
.defaultIfEmpty(Collections.singletonList(applyDefaultHandling(ex, env)))
.map(errors -> DataFetcherExceptionHandlerResult.newResult().errors(errors).build())
.contextWrite(context -> {
ContextView contextToAdd = ContextManager.getReactorContext(env);
return (contextToAdd != null ? context.putAll(contextToAdd) : context);
})
.block();
// For now we have to block:
// https://github.com/graphql-java/graphql-java/issues/2356
try {
return Flux.fromIterable(this.resolvers)
.flatMap(resolver -> resolver.resolveException(ex, env))
.next()
.map(errors -> DataFetcherExceptionHandlerResult.newResult().errors(errors).build())
.switchIfEmpty(Mono.fromCallable(() -> applyDefaultHandling(ex, env)))
.contextWrite(context -> {
ContextView contextView = ContextManager.getReactorContext(env);
return (contextView.isEmpty() ? context : context.putAll(contextView));
})
.toFuture()
.get();
}
catch (Exception ex2) {
if (logger.isWarnEnabled()) {
logger.warn("Failed to handle " + ex.getMessage(), ex2);
}
return applyDefaultHandling(ex, env);
}
}
private GraphQLError applyDefaultHandling(Throwable ex, DataFetchingEnvironment env) {
return GraphqlErrorBuilder.newError(env)
private DataFetcherExceptionHandlerResult applyDefaultHandling(Throwable ex, DataFetchingEnvironment env) {
GraphQLError error = GraphqlErrorBuilder.newError(env)
.message(ex.getMessage())
.errorType(ErrorType.INTERNAL_ERROR)
.build();
return DataFetcherExceptionHandlerResult.newResult(error).build();
}
}

View File

@@ -0,0 +1,49 @@
/*
* Copyright 2002-2021 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.graphql.execution;
import java.util.List;
import graphql.GraphQLError;
import graphql.schema.DataFetchingEnvironment;
import reactor.core.publisher.Mono;
import reactor.util.context.ContextView;
/**
* {@link DataFetcherExceptionResolver} that resolves exceptions synchronously.
*/
public interface SyncDataFetcherExceptionResolver extends DataFetcherExceptionResolver {
@Override
default Mono<List<GraphQLError>> resolveException(Throwable exception, DataFetchingEnvironment environment) {
ContextView contextView = ContextManager.getReactorContext(environment);
try {
ContextManager.restoreThreadLocalValues(contextView);
return Mono.just(doResolveException(exception, environment));
}
finally {
ContextManager.resetThreadLocalValues(contextView);
}
}
/**
* Implement this method to resolve exceptions.
* @param exception the exception to resolve
* @param environment the environment for the invoked {@code DataFetcher}
*/
List<GraphQLError> doResolveException(Throwable exception, DataFetchingEnvironment environment);
}

View File

@@ -0,0 +1,66 @@
/*
* Copyright 2002-2021 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.graphql.execution;
import java.util.List;
import java.util.Map;
import org.springframework.beans.factory.ObjectProvider;
/**
* Interface to be implemented by a framework or an application in order to
* assist with extracting ThreadLocal values at the web layer, which can then be
* re-established for DataFetcher's that are potentially executing on a
* different thread.
*
* <p>Implementations may be declared as beans in Spring configuration and
* ordered as defined in {@link ObjectProvider#orderedStream()}.
*/
public interface ThreadLocalAccessor {
/**
* Extract ThreadLocal values and add them to the given Map which is then
* passed to {@link #restoreValues(Map)} and {@link #resetValues(Map)}
* before and after the execution of a {@link graphql.schema.DataFetcher}.
* @param container container for ThreadLocal values
*/
void extractValues(Map<String, Object> container);
/**
* Re-establish ThreadLocal context by looking up values, previously
* extracted via {@link #extractValues(Map)}.
* @param values the saved ThreadLocal values
*/
void restoreValues(Map<String, Object> values);
/**
* Reset ThreadLocal context for the given values, previously extracted
* via {@link #extractValues(Map)}.
* @param values the saved ThreadLocal values
*/
void resetValues(Map<String, Object> values);
/**
* Create a composite accessor that delegates to all of the given accessors.
* @param accessors the accessors to aggregate
* @return the composite accessor
*/
static ThreadLocalAccessor composite(List<ThreadLocalAccessor> accessors) {
return new CompositeThreadLocalAccessor(accessors);
}
}

View File

@@ -20,10 +20,15 @@ import java.util.Collections;
import java.util.List;
import graphql.ExecutionInput;
import reactor.core.publisher.Mono;
import reactor.util.context.ContextView;
import org.springframework.graphql.GraphQlService;
import org.springframework.graphql.execution.ContextManager;
import org.springframework.graphql.execution.ThreadLocalAccessor;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
/**
* Default implementation of {@link WebGraphQlHandler.Builder}.
@@ -35,17 +40,31 @@ class DefaultWebGraphQlHandlerBuilder implements WebGraphQlHandler.Builder {
@Nullable
private List<WebInterceptor> interceptors;
@Nullable
private List<ThreadLocalAccessor> accessors;
DefaultWebGraphQlHandlerBuilder(GraphQlService service) {
Assert.notNull(service, "GraphQlService must not be null");
Assert.notNull(service, "GraphQlService is required");
this.service = service;
}
@Override
public WebGraphQlHandler.Builder interceptors(List<WebInterceptor> interceptors) {
this.interceptors = (this.interceptors != null ? this.interceptors : new ArrayList<>());
this.interceptors.addAll(interceptors);
if (!CollectionUtils.isEmpty(interceptors)) {
this.interceptors = (this.interceptors != null ? this.interceptors : new ArrayList<>());
this.interceptors.addAll(interceptors);
}
return this;
}
@Override
public WebGraphQlHandler.Builder threadLocalAccessors(List<ThreadLocalAccessor> accessors) {
if (!CollectionUtils.isEmpty(accessors)) {
this.accessors = (this.accessors != null ? this.accessors : new ArrayList<>());
this.accessors.addAll(accessors);
}
return this;
}
@@ -54,10 +73,13 @@ class DefaultWebGraphQlHandlerBuilder implements WebGraphQlHandler.Builder {
List<WebInterceptor> interceptorsToUse =
(this.interceptors != null ? this.interceptors : Collections.emptyList());
return interceptorsToUse.stream()
WebGraphQlHandler handler = interceptorsToUse.stream()
.reduce(WebInterceptor::andThen)
.map(interceptor -> (WebGraphQlHandler) input -> interceptor.intercept(input, createHandler()))
.orElse(createHandler());
return (CollectionUtils.isEmpty(this.accessors) ? handler :
new ThreadLocalExtractingHandler(handler, ThreadLocalAccessor.composite(this.accessors)));
}
private WebGraphQlHandler createHandler() {
@@ -67,4 +89,30 @@ class DefaultWebGraphQlHandlerBuilder implements WebGraphQlHandler.Builder {
};
}
/**
* {@link WebGraphQlHandler} that extracts ThreadLocal values and saves them
* in the Reactor context for subsequent use for DataFetcher's.
*/
private static class ThreadLocalExtractingHandler implements WebGraphQlHandler {
private final WebGraphQlHandler delegate;
private final ThreadLocalAccessor accessor;
ThreadLocalExtractingHandler(WebGraphQlHandler delegate, ThreadLocalAccessor accessor) {
this.delegate = delegate;
this.accessor = accessor;
}
@Override
public Mono<WebOutput> handle(WebInput input) {
return this.delegate.handle(input)
.contextWrite(context -> {
ContextView view = ContextManager.extractThreadLocalValues(this.accessor);
return (!view.isEmpty() ? context.putAll(view) : context);
});
}
}
}

View File

@@ -20,12 +20,11 @@ import java.util.List;
import reactor.core.publisher.Mono;
import org.springframework.graphql.GraphQlService;
import org.springframework.graphql.execution.ThreadLocalAccessor;
/**
* Contract to handle a GraphQL over HTTP or WebSocket request that forms the
* basis of a {@link WebInterceptor} delegation chain.
*
* @see WebInterceptor#createHandler(List, GraphQlService)
*/
public interface WebGraphQlHandler {
@@ -37,4 +36,41 @@ public interface WebGraphQlHandler {
*/
Mono<WebOutput> handle(WebInput input);
/**
* Provides access to a builder to create a {@link WebGraphQlHandler} instance.
* @param graphQlService the {@link GraphQlService} to use for actual
* execution of the request.
*/
static Builder builder(GraphQlService graphQlService) {
return new DefaultWebGraphQlHandlerBuilder(graphQlService);
}
/**
* Builder for {@link WebGraphQlHandler} that represents a {@link WebInterceptor}
* chain followed by a {@link GraphQlService}.
*/
interface Builder {
/**
* Configure interceptors to be invoked before the target {@code GraphQlService}.
* @param interceptors the interceptors to add
*/
Builder interceptors(List<WebInterceptor> interceptors);
/**
* Configure accessors for ThreadLocal variables to use to extract
* ThreadLocal values at the Web framework level, have those propagated
* and re-established at the DataFetcher level.
* @param accessors the accessors to add
*/
Builder threadLocalAccessors(List<ThreadLocalAccessor> accessors);
/**
* Build the {@link WebGraphQlHandler} instance.
*/
WebGraphQlHandler build();
}
}

View File

@@ -15,14 +15,11 @@
*/
package org.springframework.graphql.web;
import java.util.List;
import graphql.ExecutionInput;
import graphql.ExecutionResult;
import reactor.core.publisher.Mono;
import org.springframework.beans.factory.ObjectProvider;
import org.springframework.graphql.GraphQlService;
import org.springframework.util.Assert;
/**
@@ -58,38 +55,4 @@ public interface WebInterceptor {
return (currentInput, next) -> intercept(currentInput, nextInput -> interceptor.intercept(nextInput, next));
}
/**
* Return {@link WebGraphQlHandler} that invokes the current interceptor
* first and then the given {@link GraphQlService} for actual execution of
* the GraphQL operation.
*/
default WebGraphQlHandler apply(GraphQlService service) {
Assert.notNull(service, "GraphQlService must not be null");
return currentInput -> intercept(currentInput, createHandler(service));
}
/**
* Factory method for a {@link WebGraphQlHandler} with a chain of
* interceptors followed by a {@link GraphQlService} at the end.
*/
static WebGraphQlHandler createHandler(List<WebInterceptor> interceptors, GraphQlService service) {
return interceptors.stream()
.reduce(WebInterceptor::andThen)
.map(interceptor -> interceptor.apply(service))
.orElse(createHandler(service));
}
/**
* Factory method for a {@link WebGraphQlHandler} that simple invokes the
* given {@link GraphQlService} adapting to its input and output.
*/
static WebGraphQlHandler createHandler(GraphQlService graphQlService) {
Assert.notNull(graphQlService, "GraphQlService must not be null");
return webInput -> {
ExecutionInput executionInput = webInput.toExecutionInput();
return graphQlService.execute(executionInput).map(result -> new WebOutput(webInput, result));
};
}
}

View File

@@ -0,0 +1,64 @@
/*
* Copyright 2002-2021 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.graphql;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import graphql.GraphQL;
import graphql.schema.DataFetcher;
import graphql.schema.idl.RuntimeWiring;
import org.springframework.core.io.ByteArrayResource;
import org.springframework.graphql.execution.DataFetcherExceptionResolver;
import org.springframework.graphql.execution.GraphQlSource;
/**
* Utility methods for GraphQL tests.
*/
public abstract class GraphQlTestUtils {
public static GraphQL initGraphQl(
String schemaContent, String typeName, String fieldName, DataFetcher<?> fetcher) {
return initGraphQlSource(schemaContent, typeName, fieldName, fetcher)
.build()
.graphQl();
}
public static GraphQL initGraphQl(
String schemaContent, String typeName, String fieldName, DataFetcher<?> fetcher,
DataFetcherExceptionResolver... resolvers) {
return initGraphQlSource(schemaContent, typeName, fieldName, fetcher)
.exceptionResolvers(Arrays.asList(resolvers))
.build()
.graphQl();
}
public static GraphQlSource.Builder initGraphQlSource(
String schemaContent, String typeName, String fieldName, DataFetcher<?> fetcher) {
RuntimeWiring wiring = RuntimeWiring.newRuntimeWiring()
.type(typeName, builder -> builder.dataFetcher(fieldName, fetcher))
.build();
return GraphQlSource.builder()
.schemaResource(new ByteArrayResource(schemaContent.getBytes(StandardCharsets.UTF_8)))
.runtimeWiring(wiring);
}
}

View File

@@ -15,7 +15,6 @@
*/
package org.springframework.graphql.execution;
import java.nio.charset.StandardCharsets;
import java.time.Duration;
import java.util.List;
import java.util.Map;
@@ -23,33 +22,34 @@ import java.util.Map;
import graphql.ExecutionInput;
import graphql.ExecutionResult;
import graphql.GraphQL;
import graphql.schema.DataFetcher;
import graphql.schema.idl.RuntimeWiring;
import org.junit.jupiter.api.Test;
import org.reactivestreams.Publisher;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.util.context.Context;
import reactor.util.context.ContextView;
import org.springframework.core.io.ByteArrayResource;
import org.springframework.graphql.GraphQlTestUtils;
import static org.assertj.core.api.Assertions.assertThat;
/**
* Tests for {@link ReactorDataFetcherAdapter}.
* Tests for {@link ContextDataFetcherDecorator}.
*/
public class ReactorDataFetcherAdapterTests {
public class ContextDataFetcherDecoratorTests {
@Test
void monoDataFetcher() throws Exception {
GraphQL graphQl = graphQl("type Query { greeting: String }",
GraphQL graphQl = GraphQlTestUtils.initGraphQl("type Query { greeting: String }",
"Query", "greeting", env ->
Mono.deferContextual(context -> {
Object name = context.get("name");
return Mono.delay(Duration.ofMillis(50)).map(aLong -> "Hello " + name);
}));
ExecutionInput input = executionInput("{ greeting }", Context.of("name", "007"));
ExecutionInput input = ExecutionInput.newExecutionInput().query("{ greeting }").build();
ContextManager.setReactorContext(Context.of("name", "007"), input);
Map<String, Object> data = graphQl.executeAsync(input).get().getData();
assertThat(data).hasSize(1).containsEntry("greeting", "Hello 007");
@@ -57,7 +57,7 @@ public class ReactorDataFetcherAdapterTests {
@Test
void fluxDataFetcher() throws Exception {
GraphQL graphQl = graphQl("type Query { greetings: [String] }",
GraphQL graphQl = GraphQlTestUtils.initGraphQl("type Query { greetings: [String] }",
"Query", "greetings", env ->
Mono.delay(Duration.ofMillis(50)).flatMapMany(aLong ->
Flux.deferContextual(context -> {
@@ -65,7 +65,9 @@ public class ReactorDataFetcherAdapterTests {
return Flux.just("Hi", "Bonjour", "Hola").map(s -> s + " " + name);
})));
ExecutionInput input = executionInput("{ greetings }", Context.of("name", "007"));
ExecutionInput input = ExecutionInput.newExecutionInput().query("{ greetings }").build();
ContextManager.setReactorContext(Context.of("name", "007"), input);
Map<String, Object> data = graphQl.executeAsync(input).get().getData();
assertThat((List<String>) data.get("greetings")).containsExactly("Hi 007", "Bonjour 007", "Hola 007");
@@ -73,7 +75,7 @@ public class ReactorDataFetcherAdapterTests {
@Test
void fluxDataFetcherSubscription() throws Exception {
GraphQL graphQl = graphQl(
GraphQL graphQl = GraphQlTestUtils.initGraphQl(
"type Query { greeting: String } type Subscription { greetings: String }",
"Subscription", "greetings", env ->
Mono.delay(Duration.ofMillis(50)).flatMapMany(aLong ->
@@ -82,7 +84,9 @@ public class ReactorDataFetcherAdapterTests {
return Flux.just("Hi", "Bonjour", "Hola").map(s -> s + " " + name);
})));
ExecutionInput input = executionInput("subscription { greetings }", Context.of("name", "007"));
ExecutionInput input = ExecutionInput.newExecutionInput().query("subscription { greetings }").build();
ContextManager.setReactorContext(Context.of("name", "007"), input);
Publisher<String> publisher = graphQl.executeAsync(input).get().getData();
List<String> actual = Flux.from(publisher)
@@ -95,21 +99,32 @@ public class ReactorDataFetcherAdapterTests {
assertThat(actual).containsExactly("Hi 007", "Bonjour 007", "Hola 007");
}
private GraphQL graphQl(String schemaValue, String typeName, String fieldName, DataFetcher<?> dataFetcher) {
RuntimeWiring wiring = RuntimeWiring.newRuntimeWiring()
.type(typeName, builder -> builder.dataFetcher(fieldName, dataFetcher))
.build();
return GraphQlSource.builder()
.schemaResource(new ByteArrayResource(schemaValue.getBytes(StandardCharsets.UTF_8)))
.runtimeWiring(wiring)
.build()
.graphQl();
}
@Test
void dataFetcherWithThreadLocalContext() {
long threadId = Thread.currentThread().getId();
ThreadLocal<String> nameThreadLocal = new ThreadLocal<>();
nameThreadLocal.set("007");
try {
GraphQL graphQl = GraphQlTestUtils.initGraphQl("type Query { greeting: String }",
"Query", "greeting", env -> {
assertThat(Thread.currentThread().getId() != threadId).as("Not on async thread").isTrue();
return "Hello " + nameThreadLocal.get();
});
private ExecutionInput executionInput(String query, Context reactorContext) {
ExecutionInput input = ExecutionInput.newExecutionInput().query(query).build();
ContextManager.setReactorContext(reactorContext, input);
return input;
ExecutionInput input = ExecutionInput.newExecutionInput().query("{ greeting }").build();
ContextView view = ContextManager.extractThreadLocalValues(new TestThreadLocalAccessor(nameThreadLocal));
ContextManager.setReactorContext(view, input);
ExecutionResult result = Mono.delay(Duration.ofMillis(10))
.flatMap(aLong -> Mono.fromFuture(graphQl.executeAsync(input)))
.block();
Map<String, Object> data = result.getData();
assertThat(data).hasSize(1).containsEntry("greeting", "Hello 007");
}
finally {
nameThreadLocal.remove();
}
}
}

View File

@@ -15,8 +15,7 @@
*/
package org.springframework.graphql.execution;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.time.Duration;
import java.util.Collections;
import java.util.List;
import java.util.Map;
@@ -26,14 +25,12 @@ import graphql.ExecutionResult;
import graphql.GraphQL;
import graphql.GraphQLError;
import graphql.GraphqlErrorBuilder;
import graphql.schema.DataFetcher;
import graphql.schema.DataFetchingEnvironment;
import graphql.schema.idl.RuntimeWiring;
import org.junit.jupiter.api.Test;
import reactor.core.publisher.Mono;
import reactor.util.context.Context;
import reactor.util.context.ContextView;
import org.springframework.core.io.ByteArrayResource;
import org.springframework.graphql.GraphQlTestUtils;
import static org.assertj.core.api.Assertions.assertThat;
@@ -44,23 +41,17 @@ public class ExceptionResolversExceptionHandlerTests {
@Test
void resolveException() throws Exception {
GraphQL graphQl = graphQl("type Query { greeting: String }",
GraphQL graphQl = GraphQlTestUtils.initGraphQl("type Query { greeting: String }",
"Query", "greeting", env -> {
throw new IllegalArgumentException("Invalid greeting");
},
new SingleErrorExceptionResolver() {
@Override
protected Mono<GraphQLError> doResolve(Throwable ex, DataFetchingEnvironment env) {
return Mono.deferContextual(view ->
Mono.just(GraphqlErrorBuilder.newError(env)
.message("Resolved error: " + ex.getMessage() + ", name=" + view.get("name"))
.errorType(ErrorType.BAD_REQUEST)
.build()));
}});
(ex, env) -> Mono.just(Collections.singletonList(
GraphqlErrorBuilder.newError(env)
.message("Resolved error: " + ex.getMessage())
.errorType(ErrorType.BAD_REQUEST)
.build())));
ExecutionInput input = ExecutionInput.newExecutionInput().query("{ greeting }").build();
ContextManager.setReactorContext(Context.of("name", "007"), input);
ExecutionResult result = graphQl.executeAsync(input).get();
Map<String, Object> data = result.getData();
@@ -69,13 +60,69 @@ public class ExceptionResolversExceptionHandlerTests {
List<GraphQLError> errors = result.getErrors();
assertThat(errors).hasSize(1);
GraphQLError error = errors.get(0);
assertThat(error.getMessage()).isEqualTo("Resolved error: Invalid greeting, name=007");
assertThat(error.getMessage()).isEqualTo("Resolved error: Invalid greeting");
assertThat(error.getErrorType().toString()).isEqualTo("BAD_REQUEST");
}
@Test
void resolveExceptionWithReactorContext() throws Exception {
GraphQL graphQl = GraphQlTestUtils.initGraphQl("type Query { greeting: String }",
"Query", "greeting", env -> {
throw new IllegalArgumentException("Invalid greeting");
},
(ex, env) -> Mono.deferContextual(view -> Mono.just(Collections.singletonList(
GraphqlErrorBuilder.newError(env)
.message("Resolved error: " + ex.getMessage() + ", name=" + view.get("name"))
.errorType(ErrorType.BAD_REQUEST)
.build()))));
ExecutionInput input = ExecutionInput.newExecutionInput().query("{ greeting }").build();
ContextManager.setReactorContext(Context.of("name", "007"), input);
ExecutionResult result = graphQl.executeAsync(input).get();
GraphQLError error = result.getErrors().get(0);
assertThat(error.getMessage()).isEqualTo("Resolved error: Invalid greeting, name=007");
}
@Test
void resolveExceptionWithThreadLocal() {
long threadId = Thread.currentThread().getId();
ThreadLocal<String> nameThreadLocal = new ThreadLocal<>();
nameThreadLocal.set("007");
try {
GraphQL graphQl = GraphQlTestUtils.initGraphQl("type Query { greeting: String }",
"Query", "greeting", env -> {
throw new IllegalArgumentException("Invalid greeting");
},
(SyncDataFetcherExceptionResolver) (ex, env) -> {
assertThat(Thread.currentThread().getId() != threadId).as("Not on async thread").isTrue();
return Collections.singletonList(
GraphqlErrorBuilder.newError(env)
.message("Resolved error: " + ex.getMessage() + ", name=" + nameThreadLocal.get())
.errorType(ErrorType.BAD_REQUEST)
.build());
});
ExecutionInput input = ExecutionInput.newExecutionInput().query("{ greeting }").build();
ContextView view = ContextManager.extractThreadLocalValues(new TestThreadLocalAccessor(nameThreadLocal));
ContextManager.setReactorContext(view, input);
ExecutionResult result = Mono.delay(Duration.ofMillis(10))
.flatMap(aLong -> Mono.fromFuture(graphQl.executeAsync(input)))
.block();
GraphQLError error = result.getErrors().get(0);
assertThat(error.getMessage()).isEqualTo("Resolved error: Invalid greeting, name=007");
}
finally {
nameThreadLocal.remove();
}
}
@Test
void unresolvedException() throws Exception {
GraphQL graphQl = graphQl("type Query { greeting: String }",
GraphQL graphQl = GraphQlTestUtils.initGraphQl("type Query { greeting: String }",
"Query", "greeting", env -> {
throw new IllegalArgumentException("Invalid greeting");
},
@@ -96,7 +143,7 @@ public class ExceptionResolversExceptionHandlerTests {
@Test
void suppressedException() throws Exception {
GraphQL graphQl = graphQl("type Query { greeting: String }",
GraphQL graphQl = GraphQlTestUtils.initGraphQl("type Query { greeting: String }",
"Query", "greeting", env -> {
throw new IllegalArgumentException("Invalid greeting");
},
@@ -110,20 +157,4 @@ public class ExceptionResolversExceptionHandlerTests {
assertThat(result.getErrors()).hasSize(0);
}
private GraphQL graphQl(String schemaContent,
String typeName, String fieldName, DataFetcher<?> dataFetcher,
DataFetcherExceptionResolver... resolvers) {
RuntimeWiring wiring = RuntimeWiring.newRuntimeWiring()
.type(typeName, builder -> builder.dataFetcher(fieldName, dataFetcher))
.build();
return GraphQlSource.builder()
.schemaResource(new ByteArrayResource(schemaContent.getBytes(StandardCharsets.UTF_8)))
.runtimeWiring(wiring)
.exceptionResolvers(Arrays.asList(resolvers))
.build()
.graphQl();
}
}

View File

@@ -0,0 +1,53 @@
/*
* Copyright 2002-2021 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.graphql.execution;
import java.util.Map;
import org.springframework.util.Assert;
/**
* {@link ThreadLocalAccessor} that operates on the ThreadLocal it is given.
*/
class TestThreadLocalAccessor implements ThreadLocalAccessor {
private final ThreadLocal<String> threadLocal;
TestThreadLocalAccessor(ThreadLocal<String> threadLocal) {
this.threadLocal = threadLocal;
}
@Override
public void extractValues(Map<String, Object> container) {
String name = this.threadLocal.get();
Assert.notNull(name, "No ThreadLocal value");
container.put("name", name);
}
@Override
public void restoreValues(Map<String, Object> values) {
String name = (String) values.get("name");
Assert.notNull(name, "No value to set");
this.threadLocal.set(name);
}
@Override
public void resetValues(Map<String, Object> values) {
this.threadLocal.remove();
}
}