diff --git a/spring-graphql/src/main/java/org/springframework/graphql/execution/SubscriptionExceptionResolverAdapter.java b/spring-graphql/src/main/java/org/springframework/graphql/execution/SubscriptionExceptionResolverAdapter.java index 1d1ee44a..7c3c8cb3 100644 --- a/spring-graphql/src/main/java/org/springframework/graphql/execution/SubscriptionExceptionResolverAdapter.java +++ b/spring-graphql/src/main/java/org/springframework/graphql/execution/SubscriptionExceptionResolverAdapter.java @@ -44,9 +44,50 @@ import org.springframework.lang.Nullable; */ public abstract class SubscriptionExceptionResolverAdapter implements SubscriptionExceptionResolver { + private boolean threadLocalContextAware; + + + /** + * Subclasses can set this to indicate that ThreadLocal context from the + * transport handler (e.g. HTTP handler) should be restored when resolving + * exceptions. + *

Note: This property is applicable only if transports + * use ThreadLocal's' (e.g. Spring MVC) and if a {@link ThreadLocalAccessor} + * is registered to extract ThreadLocal values of interest. There is no + * impact from setting this property otherwise. + *

By default this is set to "false" in which case there is no attempt + * to propagate ThreadLocal context. + * @param threadLocalContextAware whether this resolver needs access to + * ThreadLocal context or not. + */ + public void setThreadLocalContextAware(boolean threadLocalContextAware) { + this.threadLocalContextAware = threadLocalContextAware; + } + + /** + * Whether ThreadLocal context needs to be restored for this resolver. + */ + public boolean isThreadLocalContextAware() { + return this.threadLocalContextAware; + } + + @Override public final Mono> resolveException(Throwable exception) { - return Mono.justOrEmpty(resolveToMultipleErrors(exception)); + if (!this.threadLocalContextAware) { + return Mono.justOrEmpty(resolveToMultipleErrors(exception)); + } + return Mono.deferContextual(contextView -> { + List errors; + try { + ReactorContextManager.restoreThreadLocalValues(contextView); + errors = resolveToMultipleErrors(exception); + } + finally { + ReactorContextManager.resetThreadLocalValues(contextView); + } + return Mono.justOrEmpty(errors); + }); } /** diff --git a/spring-graphql/src/test/java/org/springframework/graphql/ResponseHelper.java b/spring-graphql/src/test/java/org/springframework/graphql/ResponseHelper.java index e2e2ab65..6aaf877a 100644 --- a/spring-graphql/src/test/java/org/springframework/graphql/ResponseHelper.java +++ b/spring-graphql/src/test/java/org/springframework/graphql/ResponseHelper.java @@ -150,6 +150,9 @@ public class ResponseHelper { public static Flux forSubscription(ExecutionResult result) { assertThat(result.getErrors()).as("Errors present in GraphQL response").isEmpty(); + Object data = result.getData(); + assertThat(data).as("Expected Publisher from subscription").isNotNull(); + assertThat(data).as("Expected Publisher from subscription").isInstanceOf(Publisher.class); Publisher publisher = result.getData(); return Flux.from(publisher).map(ResponseHelper::forResult); } diff --git a/spring-graphql/src/test/java/org/springframework/graphql/execution/CompositeSubscriptionExceptionResolverTests.java b/spring-graphql/src/test/java/org/springframework/graphql/execution/CompositeSubscriptionExceptionResolverTests.java new file mode 100644 index 00000000..95f36f2d --- /dev/null +++ b/spring-graphql/src/test/java/org/springframework/graphql/execution/CompositeSubscriptionExceptionResolverTests.java @@ -0,0 +1,133 @@ +/* + * Copyright 2002-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.graphql.execution; + +import java.time.Duration; +import java.util.List; + +import graphql.ExecutionInput; +import graphql.GraphQL; +import graphql.GraphQLError; +import graphql.GraphqlErrorBuilder; +import org.junit.jupiter.api.Test; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; +import reactor.util.context.Context; +import reactor.util.context.ContextView; + +import org.springframework.graphql.GraphQlSetup; +import org.springframework.graphql.ResponseHelper; +import org.springframework.graphql.TestThreadLocalAccessor; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Tests for resolving exceptions via {@link SubscriptionExceptionResolver}. + * @author Rossen Stoyanchev + */ +public class CompositeSubscriptionExceptionResolverTests { + + private static final Duration TIMEOUT = Duration.ofSeconds(5); + + + @Test + void subscriptionPublisherExceptionResolved() { + String query = "subscription { greetings }"; + String schema = "type Subscription { greetings: String! } type Query { greeting: String! }"; + + GraphQL graphQL = GraphQlSetup.schemaContent(schema) + .subscriptionFetcher("greetings", env -> + Flux.create(emitter -> { + emitter.next("a"); + emitter.error(new RuntimeException("Test Exception")); + emitter.next("b"); + })) + .subscriptionExceptionResolvers(SubscriptionExceptionResolver.forSingleError(exception -> + GraphqlErrorBuilder.newError() + .message("Error: " + exception.getMessage()) + .errorType(ErrorType.BAD_REQUEST) + .build())) + .toGraphQl(); + + ExecutionInput input = ExecutionInput.newExecutionInput(query).build(); + Flux flux = Mono.fromFuture(graphQL.executeAsync(input)) + .map(ResponseHelper::forSubscription) + .block(TIMEOUT); + + StepVerifier.create(flux) + .consumeNextWith((helper) -> assertThat(helper.toEntity("greetings", String.class)).isEqualTo("a")) + .consumeErrorWith((ex) -> { + SubscriptionPublisherException theEx = (SubscriptionPublisherException) ex; + List errors = theEx.getErrors(); + assertThat(errors).hasSize(1); + assertThat(errors.get(0).getMessage()).isEqualTo("Error: Test Exception"); + assertThat(errors.get(0).getErrorType()).isEqualTo(ErrorType.BAD_REQUEST); + }) + .verify(TIMEOUT); + } + + @Test + void resolveExceptionWithThreadLocal() { + String query = "subscription { greetings }"; + String schema = "type Subscription { greetings: String! } type Query { greeting: String! }"; + + ThreadLocal nameThreadLocal = new ThreadLocal<>(); + nameThreadLocal.set("007"); + TestThreadLocalAccessor accessor = new TestThreadLocalAccessor<>(nameThreadLocal); + + try { + SubscriptionExceptionResolverAdapter resolver = SubscriptionExceptionResolver.forSingleError(exception -> + GraphqlErrorBuilder.newError() + .message("Error: " + exception.getMessage() + ", name=" + nameThreadLocal.get()) + .errorType(ErrorType.BAD_REQUEST) + .build()); + resolver.setThreadLocalContextAware(true); + + GraphQL graphQL = GraphQlSetup.schemaContent(schema) + .subscriptionFetcher("greetings", env -> + Flux.create(emitter -> { + emitter.next("a"); + emitter.error(new RuntimeException("Test Exception")); + })) + .subscriptionExceptionResolvers(resolver) + .toGraphQl(); + + ContextView view = ReactorContextManager.extractThreadLocalValues(accessor, Context.empty()); + ExecutionInput input = ExecutionInput.newExecutionInput(query).build(); + ReactorContextManager.setReactorContext(view, input.getGraphQLContext()); + + Flux flux = Mono.delay(Duration.ofMillis(10)) + .flatMap((aLong) -> Mono.fromFuture(graphQL.executeAsync(input)).map(ResponseHelper::forSubscription)) + .block(TIMEOUT); + + StepVerifier.create(flux) + .consumeNextWith((helper) -> assertThat(helper.toEntity("greetings", String.class)).isEqualTo("a")) + .consumeErrorWith((ex) -> { + SubscriptionPublisherException theEx = (SubscriptionPublisherException) ex; + List errors = theEx.getErrors(); + assertThat(errors).hasSize(1); + assertThat(errors.get(0).getMessage()).isEqualTo("Error: Test Exception, name=007"); + assertThat(errors.get(0).getErrorType()).isEqualTo(ErrorType.BAD_REQUEST); + }) + .verify(TIMEOUT); + } + finally { + nameThreadLocal.remove(); + } + } + +} diff --git a/spring-graphql/src/test/java/org/springframework/graphql/server/webflux/GraphQlWebSocketHandlerTests.java b/spring-graphql/src/test/java/org/springframework/graphql/server/webflux/GraphQlWebSocketHandlerTests.java index 6bba8688..051fa27a 100644 --- a/spring-graphql/src/test/java/org/springframework/graphql/server/webflux/GraphQlWebSocketHandlerTests.java +++ b/spring-graphql/src/test/java/org/springframework/graphql/server/webflux/GraphQlWebSocketHandlerTests.java @@ -25,7 +25,6 @@ import java.util.Map; import java.util.concurrent.atomic.AtomicBoolean; import java.util.function.BiConsumer; -import graphql.GraphqlErrorBuilder; import org.assertj.core.api.InstanceOfAssertFactories; import org.junit.jupiter.api.Test; import reactor.core.publisher.Flux; @@ -39,7 +38,6 @@ import org.springframework.core.io.buffer.DataBufferUtils; import org.springframework.core.io.buffer.DefaultDataBufferFactory; import org.springframework.graphql.GraphQlSetup; import org.springframework.graphql.execution.ErrorType; -import org.springframework.graphql.execution.SubscriptionExceptionResolver; import org.springframework.graphql.server.ConsumeOneAndNeverCompleteInterceptor; import org.springframework.graphql.server.WebGraphQlHandler; import org.springframework.graphql.server.WebGraphQlInterceptor; @@ -313,23 +311,20 @@ public class GraphQlWebSocketHandlerTests extends WebSocketHandlerTestSupport { @Test void subscriptionErrorPayloadIsArray() { - final String GREETING_QUERY = "{" + + String query = "{" + "\"id\":\"" + SUBSCRIPTION_ID + "\"," + "\"type\":\"subscribe\"," + - "\"payload\":{\"query\": \"" + - " subscription TestTypenameSubscription {" + - " greeting" + - " }\"}" + + "\"payload\":{\"query\": \"subscription { greetings }\"}" + "}"; - String schema = "type Subscription { greeting: String! } type Query { greetingUnused: String! }"; + String schema = "type Subscription { greetings: String! } type Query { greeting: String! }"; TestWebSocketSession session = new TestWebSocketSession(Flux.just( toWebSocketMessage("{\"type\":\"connection_init\"}"), - toWebSocketMessage(GREETING_QUERY))); + toWebSocketMessage(query))); WebGraphQlHandler webHandler = GraphQlSetup.schemaContent(schema) - .subscriptionFetcher("greeting", env -> Flux.just("a", null, "b")) + .subscriptionFetcher("greetings", env -> Flux.just("a", null, "b")) .toWebGraphQlHandler(); new GraphQlWebSocketHandler(webHandler, ServerCodecConfigurer.create(), TIMEOUT) @@ -342,7 +337,7 @@ public class GraphQlWebSocketHandlerTests extends WebSocketHandlerTestSupport { assertThat(actual.getId()).isEqualTo(SUBSCRIPTION_ID); assertThat(actual.resolvedType()).isEqualTo(GraphQlWebSocketMessageType.NEXT); assertThat(actual.>getPayload()) - .containsEntry("data", Collections.singletonMap("greeting", "a")); + .containsEntry("data", Collections.singletonMap("greetings", "a")); }) .consumeNextWith((message) -> { GraphQlWebSocketMessage actual = decode(message); @@ -358,63 +353,6 @@ public class GraphQlWebSocketHandlerTests extends WebSocketHandlerTestSupport { .verify(TIMEOUT); } - @Test - void subscriptionPublisherExceptionResolved() { - final String GREETING_QUERY = "{" + - "\"id\":\"" + SUBSCRIPTION_ID + "\"," + - "\"type\":\"subscribe\"," + - "\"payload\":{\"query\": \"" + - " subscription TestTypenameSubscription {" + - " greeting" + - " }\"}" + - "}"; - - String schema = "type Subscription { greeting: String! } type Query { greetingUnused: String! }"; - - TestWebSocketSession session = new TestWebSocketSession(Flux.just( - toWebSocketMessage("{\"type\":\"connection_init\"}"), - toWebSocketMessage(GREETING_QUERY))); - - WebGraphQlHandler webHandler = GraphQlSetup.schemaContent(schema) - .subscriptionFetcher("greeting", env -> - Flux.create(emitter -> { - emitter.next("a"); - emitter.error(new RuntimeException("Test Exception")); - emitter.next("b"); - })) - .subscriptionExceptionResolvers(SubscriptionExceptionResolver.forSingleError(exception -> - GraphqlErrorBuilder.newError() - .message("Error: " + exception.getMessage()) - .errorType(ErrorType.BAD_REQUEST) - .build())) - .toWebGraphQlHandler(); - - new GraphQlWebSocketHandler(webHandler, ServerCodecConfigurer.create(), TIMEOUT) - .handle(session).block(TIMEOUT); - - StepVerifier.create(session.getOutput()) - .consumeNextWith((message) -> assertMessageType(message, GraphQlWebSocketMessageType.CONNECTION_ACK)) - .consumeNextWith((message) -> { - GraphQlWebSocketMessage actual = decode(message); - assertThat(actual.getId()).isEqualTo(SUBSCRIPTION_ID); - assertThat(actual.resolvedType()).isEqualTo(GraphQlWebSocketMessageType.NEXT); - assertThat(actual.>getPayload()) - .containsEntry("data", Collections.singletonMap("greeting", "a")); - }) - .consumeNextWith((message) -> { - GraphQlWebSocketMessage actual = decode(message); - assertThat(actual.getId()).isEqualTo(SUBSCRIPTION_ID); - assertThat(actual.resolvedType()).isEqualTo(GraphQlWebSocketMessageType.ERROR); - List> errors = actual.getPayload(); - assertThat(errors).hasSize(1); - assertThat(errors.get(0)).containsEntry("message", "Error: Test Exception"); - assertThat(errors.get(0)).containsEntry("extensions", - Collections.singletonMap("classification", ErrorType.BAD_REQUEST.name())); - }) - .expectComplete() - .verify(TIMEOUT); - } - private TestWebSocketSession handle(Flux input, WebGraphQlInterceptor... interceptors) { GraphQlWebSocketHandler handler = new GraphQlWebSocketHandler( initHandler(interceptors), diff --git a/spring-graphql/src/test/java/org/springframework/graphql/server/webmvc/GraphQlWebSocketHandlerTests.java b/spring-graphql/src/test/java/org/springframework/graphql/server/webmvc/GraphQlWebSocketHandlerTests.java index 390f71d1..03aeaa9a 100644 --- a/spring-graphql/src/test/java/org/springframework/graphql/server/webmvc/GraphQlWebSocketHandlerTests.java +++ b/spring-graphql/src/test/java/org/springframework/graphql/server/webmvc/GraphQlWebSocketHandlerTests.java @@ -28,7 +28,6 @@ import java.util.concurrent.atomic.AtomicBoolean; import java.util.function.BiConsumer; import java.util.function.Consumer; -import graphql.GraphqlErrorBuilder; import org.assertj.core.api.InstanceOfAssertFactories; import org.junit.jupiter.api.Test; import reactor.core.publisher.Flux; @@ -38,7 +37,6 @@ import reactor.test.StepVerifier; import org.springframework.graphql.GraphQlSetup; import org.springframework.graphql.TestThreadLocalAccessor; import org.springframework.graphql.execution.ErrorType; -import org.springframework.graphql.execution.SubscriptionExceptionResolver; import org.springframework.graphql.execution.ThreadLocalAccessor; import org.springframework.graphql.server.ConsumeOneAndNeverCompleteInterceptor; import org.springframework.graphql.server.WebGraphQlHandler; @@ -330,16 +328,13 @@ public class GraphQlWebSocketHandlerTests extends WebSocketHandlerTestSupport { final String GREETING_QUERY = "{" + "\"id\":\"" + SUBSCRIPTION_ID + "\"," + "\"type\":\"subscribe\"," + - "\"payload\":{\"query\": \"" + - " subscription TestTypenameSubscription {" + - " greeting" + - " }\"}" + + "\"payload\":{\"query\": \"subscription { greetings }\"}" + "}"; - String schema = "type Subscription { greeting: String! }type Query { greetingUnused: String! }"; + String schema = "type Subscription { greetings: String! }type Query { greeting: String! }"; WebGraphQlHandler webHandler = GraphQlSetup.schemaContent(schema) - .subscriptionFetcher("greeting", env -> Flux.just("a", null, "b")) + .subscriptionFetcher("greetings", env -> Flux.just("a", null, "b")) .toWebGraphQlHandler(); handle(new GraphQlWebSocketHandler(webHandler, converter, TIMEOUT), @@ -353,7 +348,7 @@ public class GraphQlWebSocketHandlerTests extends WebSocketHandlerTestSupport { assertThat(actual.getId()).isEqualTo(SUBSCRIPTION_ID); assertThat(actual.resolvedType()).isEqualTo(GraphQlWebSocketMessageType.NEXT); assertThat(actual.>getPayload()) - .containsEntry("data", Collections.singletonMap("greeting", "a")); + .containsEntry("data", Collections.singletonMap("greetings", "a")); }) .consumeNextWith((message) -> { GraphQlWebSocketMessage actual = decode(message); @@ -370,62 +365,6 @@ public class GraphQlWebSocketHandlerTests extends WebSocketHandlerTestSupport { .verify(TIMEOUT); } - @Test - void subscriptionPublisherExceptionResolved() throws Exception { - final String GREETING_QUERY = "{" + - "\"id\":\"" + SUBSCRIPTION_ID + "\"," + - "\"type\":\"subscribe\"," + - "\"payload\":{\"query\": \"" + - " subscription TestTypenameSubscription {" + - " greeting" + - " }\"}" + - "}"; - - String schema = "type Subscription { greeting: String! }type Query { greetingUnused: String! }"; - - WebGraphQlHandler initHandler = GraphQlSetup.schemaContent(schema) - .subscriptionFetcher("greeting", env -> Flux.create(emitter -> { - emitter.next("a"); - emitter.error(new RuntimeException("Test Exception")); - emitter.next("b"); - })) - .subscriptionExceptionResolvers(SubscriptionExceptionResolver.forSingleError(exception -> - GraphqlErrorBuilder.newError() - .message("Error: " + exception.getMessage()) - .errorType(ErrorType.BAD_REQUEST) - .build())) - .toWebGraphQlHandler(); - - GraphQlWebSocketHandler handler = new GraphQlWebSocketHandler(initHandler, converter, Duration.ofSeconds(60)); - - handle(handler, - new TextMessage("{\"type\":\"connection_init\"}"), - new TextMessage(GREETING_QUERY)); - - StepVerifier.create(this.session.getOutput()) - .consumeNextWith((message) -> assertMessageType(message, GraphQlWebSocketMessageType.CONNECTION_ACK)) - .consumeNextWith((message) -> { - GraphQlWebSocketMessage actual = decode(message); - assertThat(actual.getId()).isEqualTo(SUBSCRIPTION_ID); - assertThat(actual.resolvedType()).isEqualTo(GraphQlWebSocketMessageType.NEXT); - assertThat(actual.>getPayload()) - .containsEntry("data", Collections.singletonMap("greeting", "a")); - }) - .consumeNextWith((message) -> { - GraphQlWebSocketMessage actual = decode(message); - assertThat(actual.getId()).isEqualTo(SUBSCRIPTION_ID); - assertThat(actual.resolvedType()).isEqualTo(GraphQlWebSocketMessageType.ERROR); - List> errors = actual.getPayload(); - assertThat(errors).hasSize(1); - assertThat(errors.get(0)).containsEntry("message", "Error: Test Exception"); - assertThat(errors.get(0)).containsEntry("extensions", - Collections.singletonMap("classification", ErrorType.BAD_REQUEST.name())); - }) - .then(this.session::close) - .expectComplete() - .verify(TIMEOUT); - } - @Test void contextPropagation() throws Exception { ThreadLocal threadLocal = new ThreadLocal<>();