Add SubscriptionExceptionResolver

See gh-398
This commit is contained in:
Nikita Ivchenko
2022-07-06 21:42:20 +03:00
committed by rstoyanchev
parent 96f158b682
commit a7e68d35a2
16 changed files with 583 additions and 96 deletions

View File

@@ -16,12 +16,6 @@
package org.springframework.graphql.execution;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.function.Consumer;
import graphql.GraphQL;
import graphql.execution.instrumentation.ChainedInstrumentation;
import graphql.execution.instrumentation.Instrumentation;
@@ -30,6 +24,9 @@ import graphql.schema.GraphQLSchema;
import graphql.schema.GraphQLTypeVisitor;
import graphql.schema.SchemaTraverser;
import java.util.*;
import java.util.function.Consumer;
/**
* Implementation of {@link GraphQlSource.Builder} that leaves it to subclasses
@@ -43,6 +40,8 @@ abstract class AbstractGraphQlSourceBuilder<B extends GraphQlSource.Builder<B>>
private final List<DataFetcherExceptionResolver> exceptionResolvers = new ArrayList<>();
private final List<SubscriptionExceptionResolver> subscriptionExceptionResolvers = new ArrayList<>();
private final List<GraphQLTypeVisitor> typeVisitors = new ArrayList<>();
private final List<Instrumentation> instrumentations = new ArrayList<>();
@@ -57,6 +56,12 @@ abstract class AbstractGraphQlSourceBuilder<B extends GraphQlSource.Builder<B>>
return self();
}
@Override
public B subscriptionExceptionResolvers(List<SubscriptionExceptionResolver> subscriptionExceptionResolvers) {
this.subscriptionExceptionResolvers.addAll(subscriptionExceptionResolvers);
return self();
}
@Override
public B typeVisitors(List<GraphQLTypeVisitor> typeVisitors) {
this.typeVisitors.addAll(typeVisitors);
@@ -105,8 +110,12 @@ abstract class AbstractGraphQlSourceBuilder<B extends GraphQlSource.Builder<B>>
protected abstract GraphQLSchema initGraphQlSchema();
private GraphQLSchema applyTypeVisitors(GraphQLSchema schema) {
SubscriptionExceptionResolver subscriptionExceptionResolver = new DelegatingSubscriptionExceptionResolver(
subscriptionExceptionResolvers);
GraphQLTypeVisitor visitor = ContextDataFetcherDecorator.createVisitor(subscriptionExceptionResolver);
List<GraphQLTypeVisitor> visitors = new ArrayList<>(this.typeVisitors);
visitors.add(ContextDataFetcherDecorator.TYPE_VISITOR);
visitors.add(visitor);
GraphQLCodeRegistry.Builder codeRegistry = GraphQLCodeRegistry.newCodeRegistry(schema.getCodeRegistry());
Map<Class<?>, Object> vars = Collections.singletonMap(GraphQLCodeRegistry.Builder.class, codeRegistry);

View File

@@ -17,22 +17,16 @@
package org.springframework.graphql.execution;
import graphql.ExecutionInput;
import graphql.schema.DataFetcher;
import graphql.schema.DataFetchingEnvironment;
import graphql.schema.GraphQLCodeRegistry;
import graphql.schema.GraphQLFieldDefinition;
import graphql.schema.GraphQLFieldsContainer;
import graphql.schema.GraphQLSchemaElement;
import graphql.schema.GraphQLTypeVisitor;
import graphql.schema.GraphQLTypeVisitorStub;
import graphql.schema.*;
import graphql.util.TraversalControl;
import graphql.util.TraverserContext;
import org.reactivestreams.Publisher;
import org.springframework.util.Assert;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.util.context.ContextView;
import org.springframework.util.Assert;
import java.util.function.Function;
/**
* Wrap a {@link DataFetcher} to enable the following:
@@ -51,10 +45,16 @@ final class ContextDataFetcherDecorator implements DataFetcher<Object> {
private final boolean subscription;
private ContextDataFetcherDecorator(DataFetcher<?> delegate, boolean subscription) {
private final SubscriptionExceptionResolver subscriptionExceptionResolver;
private ContextDataFetcherDecorator(
DataFetcher<?> delegate, boolean subscription,
SubscriptionExceptionResolver subscriptionExceptionResolver) {
Assert.notNull(delegate, "'delegate' DataFetcher is required");
Assert.notNull(subscriptionExceptionResolver, "'subscriptionExceptionResolver' is required");
this.delegate = delegate;
this.subscription = subscription;
this.subscriptionExceptionResolver = subscriptionExceptionResolver;
}
@Override
@@ -66,7 +66,8 @@ final class ContextDataFetcherDecorator implements DataFetcher<Object> {
ContextView contextView = ReactorContextManager.getReactorContext(environment.getGraphQlContext());
if (this.subscription) {
return (!contextView.isEmpty() ? Flux.from((Publisher<?>) value).contextWrite(contextView) : value);
Publisher<?> publisher = interceptSubscriptionPublisherWithExceptionHandler((Publisher<?>) value);
return (!contextView.isEmpty() ? Flux.from(publisher).contextWrite(contextView) : publisher);
}
if (value instanceof Flux) {
@@ -84,29 +85,48 @@ final class ContextDataFetcherDecorator implements DataFetcher<Object> {
return value;
}
@SuppressWarnings("unchecked")
private Publisher<?> interceptSubscriptionPublisherWithExceptionHandler(Publisher<?> publisher) {
Function<? super Throwable, Mono<?>> onErrorResumeFunction = e ->
subscriptionExceptionResolver.resolveException(e)
.flatMap(errors -> Mono.error(new SubscriptionStreamException(errors)));
if (publisher instanceof Flux) {
return ((Flux<Object>) publisher).onErrorResume(onErrorResumeFunction);
}
if (publisher instanceof Mono) {
return ((Mono<Object>) publisher).onErrorResume(onErrorResumeFunction);
}
throw new IllegalArgumentException("Unknown publisher type: '" + publisher.getClass().getName() +"'. " +
"Expected reactor.core.publisher.Mono or reactor.core.publisher.Flux");
}
/**
* {@link GraphQLTypeVisitor} that wraps non-GraphQL data fetchers and adapts them if
* they return {@link Flux} or {@link Mono}.
*/
static GraphQLTypeVisitor TYPE_VISITOR = new GraphQLTypeVisitorStub() {
static GraphQLTypeVisitor createVisitor(SubscriptionExceptionResolver subscriptionExceptionResolver) {
return new GraphQLTypeVisitorStub() {
@Override
public TraversalControl visitGraphQLFieldDefinition(GraphQLFieldDefinition fieldDefinition,
TraverserContext<GraphQLSchemaElement> context) {
@Override
public TraversalControl visitGraphQLFieldDefinition(GraphQLFieldDefinition fieldDefinition,
TraverserContext<GraphQLSchemaElement> context) {
GraphQLCodeRegistry.Builder codeRegistry = context.getVarFromParents(GraphQLCodeRegistry.Builder.class);
GraphQLFieldsContainer parent = (GraphQLFieldsContainer) context.getParentNode();
DataFetcher<?> dataFetcher = codeRegistry.getDataFetcher(parent, fieldDefinition);
GraphQLCodeRegistry.Builder codeRegistry = context.getVarFromParents(GraphQLCodeRegistry.Builder.class);
GraphQLFieldsContainer parent = (GraphQLFieldsContainer) context.getParentNode();
DataFetcher<?> dataFetcher = codeRegistry.getDataFetcher(parent, fieldDefinition);
if (dataFetcher.getClass().getPackage().getName().startsWith("graphql.")) {
return TraversalControl.CONTINUE;
}
if (dataFetcher.getClass().getPackage().getName().startsWith("graphql.")) {
boolean handlesSubscription = parent.getName().equals("Subscription");
dataFetcher = new ContextDataFetcherDecorator(dataFetcher, handlesSubscription, subscriptionExceptionResolver);
codeRegistry.dataFetcher(parent, fieldDefinition, dataFetcher);
return TraversalControl.CONTINUE;
}
boolean handlesSubscription = parent.getName().equals("Subscription");
dataFetcher = new ContextDataFetcherDecorator(dataFetcher, handlesSubscription);
codeRegistry.dataFetcher(parent, fieldDefinition, dataFetcher);
return TraversalControl.CONTINUE;
}
};
};
}
}

View File

@@ -0,0 +1,74 @@
/*
* Copyright 2002-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.graphql.execution;
import graphql.ErrorType;
import graphql.GraphQLError;
import graphql.GraphqlErrorBuilder;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.util.Assert;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import java.util.Collections;
import java.util.List;
/**
* An implementation of {@link SubscriptionExceptionResolver} that is trying to map exception to GraphQL error
* using provided implementation of {@link SubscriptionExceptionResolver}.
* <br/>
* If none of provided implementations resolve exception to error or if any of implementation throw an exception,
* this {@link SubscriptionExceptionResolver} will return a default error.
*
* @author Mykyta Ivchenko
* @see SubscriptionExceptionResolver
*/
public class DelegatingSubscriptionExceptionResolver implements SubscriptionExceptionResolver {
private static final Log logger = LogFactory.getLog(DelegatingSubscriptionExceptionResolver.class);
private final List<SubscriptionExceptionResolver> resolvers;
public DelegatingSubscriptionExceptionResolver(List<SubscriptionExceptionResolver> resolvers) {
Assert.notNull(resolvers, "'resolvers' list must be not null.");
this.resolvers = resolvers;
}
@Override
public Mono<List<GraphQLError>> resolveException(Throwable exception) {
return Flux.fromIterable(resolvers)
.flatMap(resolver -> resolver.resolveException(exception))
.next()
.onErrorResume(error -> Mono.just(handleMappingException(error, exception)))
.defaultIfEmpty(createDefaultErrors());
}
private List<GraphQLError> handleMappingException(Throwable resolverException, Throwable originalException) {
if (logger.isWarnEnabled()) {
logger.warn("Failure while resolving " + originalException.getClass().getName(), resolverException);
}
return createDefaultErrors();
}
private List<GraphQLError> createDefaultErrors() {
GraphQLError error = GraphqlErrorBuilder.newError()
.message("Unknown error")
.errorType(ErrorType.DataFetchingException)
.build();
return Collections.singletonList(error);
}
}

View File

@@ -16,11 +16,6 @@
package org.springframework.graphql.execution;
import java.io.InputStream;
import java.util.List;
import java.util.function.BiFunction;
import java.util.function.Consumer;
import graphql.GraphQL;
import graphql.execution.instrumentation.Instrumentation;
import graphql.schema.GraphQLSchema;
@@ -28,9 +23,13 @@ import graphql.schema.GraphQLTypeVisitor;
import graphql.schema.TypeResolver;
import graphql.schema.idl.RuntimeWiring;
import graphql.schema.idl.TypeDefinitionRegistry;
import org.springframework.core.io.Resource;
import java.io.InputStream;
import java.util.List;
import java.util.function.BiFunction;
import java.util.function.Consumer;
/**
* Strategy to resolve a {@link GraphQL} and a {@link GraphQLSchema}.
@@ -91,6 +90,14 @@ public interface GraphQlSource {
*/
B exceptionResolvers(List<DataFetcherExceptionResolver> resolvers);
/**
* Add {@link SubscriptionExceptionResolver}s to map exceptions, thrown by
* GraphQL Subscription publisher.
* @param subscriptionExceptionResolver the subscription exception resolver
* @return the current builder
*/
B subscriptionExceptionResolvers(List<SubscriptionExceptionResolver> subscriptionExceptionResolvers);
/**
* Add {@link GraphQLTypeVisitor}s to visit all element of the created
* {@link graphql.schema.GraphQLSchema}.

View File

@@ -0,0 +1,50 @@
/*
* Copyright 2002-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.graphql.execution;
import graphql.GraphQLError;
import reactor.core.publisher.Mono;
import java.util.List;
/**
* Contract to resolve exceptions, that are thrown by subscription publisher.
* Implementations are typically declared as beans in Spring configuration and
* are invoked sequentially until one emits a List of {@link GraphQLError}s.
* <br/>
* Usually, it is enough to implement this interface by extending {@link SubscriptionExceptionResolverAdapter}
* and overriding one of its {@link SubscriptionExceptionResolverAdapter#resolveToSingleError(Throwable)}
* or {@link SubscriptionExceptionResolverAdapter#resolveToMultipleErrors(Throwable)}
*
* @author Mykyta Ivchenko
* @see SubscriptionExceptionResolverAdapter
* @see DelegatingSubscriptionExceptionResolver
* @see org.springframework.graphql.server.webflux.GraphQlWebSocketHandler
*/
@FunctionalInterface
public interface SubscriptionExceptionResolver {
/**
* Resolve given exception as list of {@link GraphQLError}s and send them as WebSocket message.
* @param exception the exception to resolve
* @return a {@code Mono} with errors to send in a WebSocket message;
* if the {@code Mono} completes with an empty List, the exception is resolved
* without any errors added to the response; if the {@code Mono} completes
* empty, without emitting a List, the exception remains unresolved and gives
* other resolvers a chance.
*/
Mono<List<GraphQLError>> resolveException(Throwable exception);
}

View File

@@ -0,0 +1,48 @@
/*
* Copyright 2002-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.graphql.execution;
import graphql.GraphQLError;
import reactor.core.publisher.Mono;
import java.util.Collections;
import java.util.List;
/**
* Abstract class for {@link SubscriptionExceptionResolver} implementations.
* This class provide an easy way to map an exception as GraphQL error synchronously.
* <br/>
* To use this class, you need to override either {@link SubscriptionExceptionResolverAdapter#resolveToSingleError(Throwable)}
* or {@link SubscriptionExceptionResolverAdapter#resolveToMultipleErrors(Throwable)}.
*
* @author Mykyta Ivchenko
* @see SubscriptionExceptionResolver
*/
public abstract class SubscriptionExceptionResolverAdapter implements SubscriptionExceptionResolver {
@Override
public Mono<List<GraphQLError>> resolveException(Throwable exception) {
return Mono.just(resolveToMultipleErrors(exception));
}
protected List<GraphQLError> resolveToMultipleErrors(Throwable exception) {
return Collections.singletonList(resolveToSingleError(exception));
}
protected GraphQLError resolveToSingleError(Throwable exception) {
return null;
}
}

View File

@@ -0,0 +1,20 @@
package org.springframework.graphql.execution;
import graphql.GraphQLError;
import org.springframework.core.NestedRuntimeException;
import java.util.List;
@SuppressWarnings("serial")
public class SubscriptionStreamException extends NestedRuntimeException {
private final List<GraphQLError> errors;
public SubscriptionStreamException(List<GraphQLError> errors) {
super("An exception happened in GraphQL subscription stream.");
this.errors = errors;
}
public List<GraphQLError> getErrors() {
return errors;
}
}

View File

@@ -19,6 +19,7 @@ package org.springframework.graphql.server.support;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import graphql.GraphQLError;
@@ -188,10 +189,13 @@ public class GraphQlWebSocketMessage {
* @param id unique request id
* @param error the error to add as the message payload
*/
public static GraphQlWebSocketMessage error(String id, GraphQLError error) {
Assert.notNull(error, "GraphQlError is required");
List<Map<String, Object>> errors = Collections.singletonList(error.toSpecification());
return new GraphQlWebSocketMessage(id, GraphQlWebSocketMessageType.ERROR, errors);
public static GraphQlWebSocketMessage error(String id, List<GraphQLError> errors) {
Assert.notNull(errors, "GraphQlErrors list is required");
List<Map<String, Object>> errorsMap = errors.stream()
.map(GraphQLError::toSpecification)
.collect(Collectors.toList());
return new GraphQlWebSocketMessage(id, GraphQlWebSocketMessageType.ERROR, errorsMap);
}
/**

View File

@@ -15,6 +15,8 @@
*/
package org.springframework.graphql.server.webflux;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import graphql.GraphQLError;
@@ -97,9 +99,13 @@ final class CodecDelegate {
return encode(session, GraphQlWebSocketMessage.next(id, responseMap));
}
public WebSocketMessage encodeError(WebSocketSession session, String id, Throwable ex) {
public WebSocketMessage encodeUnknownError(WebSocketSession session, String id, Throwable ex) {
GraphQLError error = GraphqlErrorBuilder.newError().message(ex.getMessage()).build();
return encode(session, GraphQlWebSocketMessage.error(id, error));
return encodeError(session, id, Collections.singletonList(error));
}
public WebSocketMessage encodeError(WebSocketSession session, String id, List<GraphQLError> errors) {
return encode(session, GraphQlWebSocketMessage.error(id, errors));
}
public WebSocketMessage encodeComplete(WebSocketSession session, String id) {

View File

@@ -28,10 +28,12 @@ import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicReference;
import graphql.ExecutionResult;
import graphql.GraphQLError;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.reactivestreams.Publisher;
import org.reactivestreams.Subscription;
import org.springframework.graphql.execution.SubscriptionStreamException;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
@@ -216,7 +218,11 @@ public class GraphQlWebSocketHandler implements WebSocketHandler {
CloseStatus status = new CloseStatus(4409, "Subscriber for " + id + " already exists");
return GraphQlStatus.close(session, status);
}
return Mono.fromCallable(() -> this.codecDelegate.encodeError(session, id, ex));
if (ex instanceof SubscriptionStreamException) {
List<GraphQLError> errors = ((SubscriptionStreamException) ex).getErrors();
return Mono.fromCallable(() -> this.codecDelegate.encodeError(session, id, errors));
}
return Mono.fromCallable(() -> this.codecDelegate.encodeUnknownError(session, id, ex));
});
}

View File

@@ -41,6 +41,7 @@ import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.reactivestreams.Publisher;
import org.reactivestreams.Subscription;
import org.springframework.graphql.execution.SubscriptionStreamException;
import reactor.core.publisher.BaseSubscriber;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
@@ -286,9 +287,13 @@ public class GraphQlWebSocketHandler extends TextWebSocketHandler implements Sub
GraphQlStatus.closeSession(session, status);
return Flux.empty();
}
if (ex instanceof SubscriptionStreamException) {
List<GraphQLError> errors = ((SubscriptionStreamException) ex).getErrors();
return Mono.just(encode(GraphQlWebSocketMessage.error(id, errors)));
}
String message = ex.getMessage();
GraphQLError error = GraphqlErrorBuilder.newError().message(message).build();
return Mono.just(encode(GraphQlWebSocketMessage.error(id, error)));
return Mono.just(encode(GraphQlWebSocketMessage.error(id, Collections.singletonList(error))));
});
}

View File

@@ -16,6 +16,7 @@
package org.springframework.graphql.client;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.function.Function;
@@ -105,7 +106,7 @@ public final class MockGraphQlWebSocketServer implements WebSocketHandler {
.map(response -> GraphQlWebSocketMessage.next(id, response.toMap()))
.concatWithValues(
request.getError() != null ?
GraphQlWebSocketMessage.error(id, request.getError()) :
GraphQlWebSocketMessage.error(id, Collections.singletonList(request.getError())) :
GraphQlWebSocketMessage.complete(id));
case COMPLETE:
return Flux.empty();

View File

@@ -16,24 +16,25 @@
package org.springframework.graphql.execution;
import java.time.Duration;
import java.util.List;
import graphql.ExecutionInput;
import graphql.ExecutionResult;
import graphql.GraphQL;
import graphql.*;
import org.junit.jupiter.api.Test;
import org.mockito.Mockito;
import org.springframework.graphql.GraphQlSetup;
import org.springframework.graphql.ResponseHelper;
import org.springframework.graphql.TestThreadLocalAccessor;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.test.StepVerifier;
import reactor.util.context.Context;
import reactor.util.context.ContextView;
import org.springframework.graphql.ResponseHelper;
import org.springframework.graphql.GraphQlSetup;
import org.springframework.graphql.TestThreadLocalAccessor;
import java.time.Duration;
import java.util.Collections;
import java.util.List;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.verify;
/**
* Tests for {@link ContextDataFetcherDecorator}.
@@ -104,6 +105,97 @@ public class ContextDataFetcherDecoratorTests {
.verifyComplete();
}
@Test
void fluxDataFetcherSubscriptionThrowException() throws Exception {
GraphQLError expectedError = GraphqlErrorBuilder.newError()
.message("Error: Example Error")
.errorType(ErrorType.INTERNAL_ERROR)
.extensions(Collections.singletonMap("a", "b"))
.build();
SubscriptionExceptionResolver subscriptionSingleExceptionResolverAdapter = Mockito.spy(
new SubscriptionExceptionResolverAdapter() {
@Override
protected GraphQLError resolveToSingleError(Throwable exception) {
return GraphqlErrorBuilder.newError()
.message("Error: " + exception.getMessage())
.errorType(ErrorType.INTERNAL_ERROR)
.extensions(Collections.singletonMap("a", "b"))
.build();
}
}
);
GraphQL graphQl = GraphQlSetup.schemaContent("type Query { greeting: String } type Subscription { greetings: String }")
.subscriptionExceptionResolvers(subscriptionSingleExceptionResolverAdapter)
.subscriptionFetcher("greetings", (env) ->
Mono.delay(Duration.ofMillis(50))
.flatMapMany((aLong) -> Flux.create(sink -> {
sink.next("Hi!");
sink.error(new RuntimeException("Example Error"));
})))
.toGraphQl();
ExecutionInput input = ExecutionInput.newExecutionInput().query("subscription { greetings }").build();
ExecutionResult executionResult = graphQl.executeAsync(input).get();
Flux<String> greetingsFlux = ResponseHelper.forSubscription(executionResult)
.map(message -> message.toEntity("greetings", String.class));
StepVerifier.create(greetingsFlux)
.expectNext("Hi!")
.expectErrorSatisfies(error -> assertThat(error)
.usingRecursiveComparison()
.isEqualTo(new SubscriptionStreamException(Collections.singletonList(expectedError))))
.verify();
verify(subscriptionSingleExceptionResolverAdapter).resolveException(any(RuntimeException.class));
}
@Test
void monoDataFetcherSubscriptionThrowException() throws Exception {
GraphQLError expectedError = GraphqlErrorBuilder.newError()
.message("Error: Example Error")
.errorType(ErrorType.INTERNAL_ERROR)
.extensions(Collections.singletonMap("a", "b"))
.build();
SubscriptionExceptionResolver subscriptionSingleExceptionResolverAdapter = Mockito.spy(
new SubscriptionExceptionResolverAdapter() {
@Override
protected GraphQLError resolveToSingleError(Throwable exception) {
return GraphqlErrorBuilder.newError()
.message("Error: " + exception.getMessage())
.errorType(ErrorType.INTERNAL_ERROR)
.extensions(Collections.singletonMap("a", "b"))
.build();
}
}
);
GraphQL graphQl = GraphQlSetup.schemaContent("type Query { greeting: String } type Subscription { greetings: String }")
.subscriptionExceptionResolvers(subscriptionSingleExceptionResolverAdapter)
.subscriptionFetcher("greetings", (env) ->
Mono.delay(Duration.ofMillis(50))
.then(Mono.error(new RuntimeException("Example Error"))))
.toGraphQl();
ExecutionInput input = ExecutionInput.newExecutionInput().query("subscription { greetings }").build();
ExecutionResult executionResult = graphQl.executeAsync(input).get();
Flux<ResponseHelper> greetingsFlux = ResponseHelper.forSubscription(executionResult);
StepVerifier.create(greetingsFlux)
.expectErrorSatisfies(error -> assertThat(error)
.usingRecursiveComparison()
.isEqualTo(new SubscriptionStreamException(Collections.singletonList(expectedError))))
.verify();
verify(subscriptionSingleExceptionResolverAdapter).resolveException(any(RuntimeException.class));
}
@Test
void dataFetcherWithThreadLocalContext() {
ThreadLocal<String> nameThreadLocal = new ThreadLocal<>();

View File

@@ -16,6 +16,31 @@
package org.springframework.graphql.server.webflux;
import graphql.GraphQLError;
import graphql.GraphqlErrorBuilder;
import org.assertj.core.api.InstanceOfAssertFactories;
import org.junit.jupiter.api.Test;
import org.mockito.Mockito;
import org.springframework.core.ResolvableType;
import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.core.io.buffer.DataBufferUtils;
import org.springframework.core.io.buffer.DefaultDataBufferFactory;
import org.springframework.graphql.GraphQlSetup;
import org.springframework.graphql.execution.ErrorType;
import org.springframework.graphql.execution.SubscriptionExceptionResolver;
import org.springframework.graphql.execution.SubscriptionExceptionResolverAdapter;
import org.springframework.graphql.server.*;
import org.springframework.graphql.server.support.GraphQlWebSocketMessage;
import org.springframework.graphql.server.support.GraphQlWebSocketMessageType;
import org.springframework.http.codec.ServerCodecConfigurer;
import org.springframework.http.codec.json.Jackson2JsonDecoder;
import org.springframework.web.reactive.socket.CloseStatus;
import org.springframework.web.reactive.socket.WebSocketMessage;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.core.publisher.Sinks;
import reactor.test.StepVerifier;
import java.nio.charset.StandardCharsets;
import java.time.Duration;
import java.util.ArrayList;
@@ -25,33 +50,11 @@ import java.util.Map;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.BiConsumer;
import org.assertj.core.api.InstanceOfAssertFactories;
import org.junit.jupiter.api.Test;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.core.publisher.Sinks;
import reactor.test.StepVerifier;
import org.springframework.core.ResolvableType;
import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.core.io.buffer.DataBufferUtils;
import org.springframework.core.io.buffer.DefaultDataBufferFactory;
import org.springframework.graphql.GraphQlSetup;
import org.springframework.graphql.server.ConsumeOneAndNeverCompleteInterceptor;
import org.springframework.graphql.server.WebGraphQlHandler;
import org.springframework.graphql.server.WebGraphQlInterceptor;
import org.springframework.graphql.server.WebSocketHandlerTestSupport;
import org.springframework.graphql.server.WebSocketGraphQlInterceptor;
import org.springframework.graphql.server.WebSocketSessionInfo;
import org.springframework.graphql.server.support.GraphQlWebSocketMessage;
import org.springframework.graphql.server.support.GraphQlWebSocketMessageType;
import org.springframework.http.codec.ServerCodecConfigurer;
import org.springframework.http.codec.json.Jackson2JsonDecoder;
import org.springframework.web.reactive.socket.CloseStatus;
import org.springframework.web.reactive.socket.WebSocketMessage;
import static org.assertj.core.api.Assertions.as;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
/**
* Unit tests for {@link GraphQlWebSocketHandler}.
@@ -356,7 +359,7 @@ public class GraphQlWebSocketHandlerTests extends WebSocketHandlerTestSupport {
.asInstanceOf(InstanceOfAssertFactories.map(String.class, Object.class))
.hasSize(3)
.hasEntrySatisfying("locations", loc -> assertThat(loc).asList().isEmpty())
.hasEntrySatisfying("message", msg -> assertThat(msg).asString().contains("null"))
.hasEntrySatisfying("message", msg -> assertThat(msg).asString().contains("Unknown error"))
.extractingByKey("extensions", as(InstanceOfAssertFactories.map(String.class, Object.class)))
.containsEntry("classification", "DataFetchingException"));
})
@@ -364,6 +367,77 @@ public class GraphQlWebSocketHandlerTests extends WebSocketHandlerTestSupport {
.verify(TIMEOUT);
}
@Test
void subscriptionStreamException() {
final String GREETING_QUERY = "{" +
"\"id\":\"" + SUBSCRIPTION_ID + "\"," +
"\"type\":\"subscribe\"," +
"\"payload\":{\"query\": \"" +
" subscription TestTypenameSubscription {" +
" greeting" +
" }\"}" +
"}";
String schema = "type Subscription { greeting: String! } type Query { greetingUnused: String! }";
WebGraphQlHandler initHandler = GraphQlSetup.schemaContent(schema)
.subscriptionFetcher("greeting", env -> Flux.create(emitter -> {
emitter.next("a");
emitter.error(new RuntimeException("Test Exception"));
emitter.next("b");
}))
.subscriptionExceptionResolvers(new SubscriptionExceptionResolverAdapter() {
@Override
protected GraphQLError resolveToSingleError(Throwable exception) {
return GraphqlErrorBuilder.newError()
.errorType(ErrorType.INTERNAL_ERROR)
.message("Error: " + exception.getMessage())
.extensions(Collections.singletonMap("key", "value"))
.build();
}
})
.interceptor()
.toWebGraphQlHandler();
GraphQlWebSocketHandler handler = new GraphQlWebSocketHandler(
initHandler,
ServerCodecConfigurer.create(),
Duration.ofSeconds(60));
TestWebSocketSession session = new TestWebSocketSession(Flux.just(
toWebSocketMessage("{\"type\":\"connection_init\"}"),
toWebSocketMessage(GREETING_QUERY)));
handler.handle(session).block(TIMEOUT);
StepVerifier.create(session.getOutput())
.consumeNextWith((message) -> assertMessageType(message, GraphQlWebSocketMessageType.CONNECTION_ACK))
.consumeNextWith((message) -> {
GraphQlWebSocketMessage actual = decode(message);
assertThat(actual.getId()).isEqualTo(SUBSCRIPTION_ID);
assertThat(actual.resolvedType()).isEqualTo(GraphQlWebSocketMessageType.NEXT);
assertThat(actual.<Map<String, Object>>getPayload())
.extractingByKey("data", as(InstanceOfAssertFactories.map(String.class, Object.class)))
.containsEntry("greeting", "a");
})
.consumeNextWith((message) -> {
GraphQlWebSocketMessage actual = decode(message);
assertThat(actual.getId()).isEqualTo(SUBSCRIPTION_ID);
assertThat(actual.resolvedType()).isEqualTo(GraphQlWebSocketMessageType.ERROR);
assertThat(actual.<List<Map<String, Object>>>getPayload())
.asList().hasSize(1)
.allSatisfy(theError -> assertThat(theError)
.asInstanceOf(InstanceOfAssertFactories.map(String.class, Object.class))
.hasSize(3)
.hasEntrySatisfying("locations", loc -> assertThat(loc).asList().isEmpty())
.hasEntrySatisfying("message", msg -> assertThat(msg).asString().isEqualTo("Error: Test Exception"))
.extractingByKey("extensions", as(InstanceOfAssertFactories.map(String.class, Object.class)))
.containsEntry("classification", "INTERNAL_ERROR")
.containsEntry("key", "value"));
})
.expectComplete()
.verify(TIMEOUT);
}
private TestWebSocketSession handle(Flux<WebSocketMessage> input, WebGraphQlInterceptor... interceptors) {
GraphQlWebSocketHandler handler = new GraphQlWebSocketHandler(
initHandler(interceptors),

View File

@@ -28,8 +28,12 @@ import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.BiConsumer;
import java.util.function.Consumer;
import graphql.GraphQLError;
import graphql.GraphqlErrorBuilder;
import org.assertj.core.api.InstanceOfAssertFactories;
import org.junit.jupiter.api.Test;
import org.springframework.graphql.execution.ErrorType;
import org.springframework.graphql.execution.SubscriptionExceptionResolverAdapter;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.test.StepVerifier;
@@ -366,7 +370,7 @@ public class GraphQlWebSocketHandlerTests extends WebSocketHandlerTestSupport {
.asInstanceOf(InstanceOfAssertFactories.map(String.class, Object.class))
.hasSize(3)
.hasEntrySatisfying("locations", loc -> assertThat(loc).asList().isEmpty())
.hasEntrySatisfying("message", msg -> assertThat(msg).asString().contains("null"))
.hasEntrySatisfying("message", msg -> assertThat(msg).asString().contains("Unknown error"))
.extractingByKey("extensions", as(InstanceOfAssertFactories.map(String.class, Object.class)))
.containsEntry("classification", "DataFetchingException"));
})
@@ -375,6 +379,74 @@ public class GraphQlWebSocketHandlerTests extends WebSocketHandlerTestSupport {
.verify(TIMEOUT);
}
@Test
void subscriptionStreamException() throws Exception {
final String GREETING_QUERY = "{" +
"\"id\":\"" + SUBSCRIPTION_ID + "\"," +
"\"type\":\"subscribe\"," +
"\"payload\":{\"query\": \"" +
" subscription TestTypenameSubscription {" +
" greeting" +
" }\"}" +
"}";
String schema = "type Subscription { greeting: String! }type Query { greetingUnused: String! }";
WebGraphQlHandler initHandler = GraphQlSetup.schemaContent(schema)
.subscriptionFetcher("greeting", env -> Flux.create(emitter -> {
emitter.next("a");
emitter.error(new RuntimeException("Test Exception"));
emitter.next("b");
}))
.subscriptionExceptionResolvers(new SubscriptionExceptionResolverAdapter() {
@Override
protected GraphQLError resolveToSingleError(Throwable exception) {
return GraphqlErrorBuilder.newError()
.message("Error: " + exception.getMessage())
.errorType(ErrorType.INTERNAL_ERROR)
.extensions(Collections.singletonMap("key", "value"))
.build();
}
})
.interceptor()
.toWebGraphQlHandler();
GraphQlWebSocketHandler handler = new GraphQlWebSocketHandler(initHandler, converter, Duration.ofSeconds(60));
handle(handler,
new TextMessage("{\"type\":\"connection_init\"}"),
new TextMessage(GREETING_QUERY));
StepVerifier.create(this.session.getOutput())
.consumeNextWith((message) -> assertMessageType(message, GraphQlWebSocketMessageType.CONNECTION_ACK))
.consumeNextWith((message) -> {
GraphQlWebSocketMessage actual = decode(message);
assertThat(actual.getId()).isEqualTo(SUBSCRIPTION_ID);
assertThat(actual.resolvedType()).isEqualTo(GraphQlWebSocketMessageType.NEXT);
assertThat(actual.<Map<String, Object>>getPayload())
.extractingByKey("data", as(InstanceOfAssertFactories.map(String.class, Object.class)))
.containsEntry("greeting", "a");
})
.consumeNextWith((message) -> {
GraphQlWebSocketMessage actual = decode(message);
assertThat(actual.getId()).isEqualTo(SUBSCRIPTION_ID);
assertThat(actual.resolvedType()).isEqualTo(GraphQlWebSocketMessageType.ERROR);
assertThat(actual.<List<Map<String, Object>>>getPayload())
.asList().hasSize(1)
.allSatisfy(theError -> assertThat(theError)
.asInstanceOf(InstanceOfAssertFactories.map(String.class, Object.class))
.hasSize(3)
.hasEntrySatisfying("locations", loc -> assertThat(loc).asList().isEmpty())
.hasEntrySatisfying("message", msg -> assertThat(msg).asString().contains("Error: Test Exception"))
.extractingByKey("extensions", as(InstanceOfAssertFactories.map(String.class, Object.class)))
.containsEntry("classification", "INTERNAL_ERROR")
.containsEntry("key", "value"));
})
.then(this.session::close)
.expectComplete()
.verify(TIMEOUT);
}
@Test
void contextPropagation() throws Exception {
ThreadLocal<String> threadLocal = new ThreadLocal<>();

View File

@@ -15,32 +15,26 @@
*/
package org.springframework.graphql;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import graphql.GraphQL;
import graphql.schema.DataFetcher;
import graphql.schema.GraphQLTypeVisitor;
import graphql.schema.TypeResolver;
import org.springframework.context.ApplicationContext;
import org.springframework.core.io.ByteArrayResource;
import org.springframework.core.io.Resource;
import org.springframework.graphql.data.method.annotation.support.AnnotatedControllerConfigurer;
import org.springframework.graphql.execution.DataFetcherExceptionResolver;
import org.springframework.graphql.execution.DataLoaderRegistrar;
import org.springframework.graphql.execution.DefaultExecutionGraphQlService;
import org.springframework.graphql.execution.GraphQlSource;
import org.springframework.graphql.execution.RuntimeWiringConfigurer;
import org.springframework.graphql.execution.ThreadLocalAccessor;
import org.springframework.graphql.execution.*;
import org.springframework.graphql.server.WebGraphQlHandler;
import org.springframework.graphql.server.WebGraphQlInterceptor;
import org.springframework.graphql.server.WebGraphQlSetup;
import org.springframework.graphql.server.webflux.GraphQlHttpHandler;
import org.springframework.lang.Nullable;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
/**
* Workflow for GraphQL tests setup that starts with {@link GraphQlSource.Builder}
* related input, and then optionally moving on to the creation of a
@@ -99,6 +93,11 @@ public class GraphQlSetup implements GraphQlServiceSetup {
return this;
}
public GraphQlSetup subscriptionExceptionResolvers(SubscriptionExceptionResolver... resolvers) {
this.graphQlSourceBuilder.subscriptionExceptionResolvers(Arrays.asList(resolvers));
return this;
}
public GraphQlSetup typeResolver(TypeResolver typeResolver) {
this.graphQlSourceBuilder.defaultTypeResolver(typeResolver);
return this;