diff --git a/spring-graphql/src/main/java/org/springframework/graphql/execution/SubscriptionExceptionResolverAdapter.java b/spring-graphql/src/main/java/org/springframework/graphql/execution/SubscriptionExceptionResolverAdapter.java
index 1d1ee44a..7c3c8cb3 100644
--- a/spring-graphql/src/main/java/org/springframework/graphql/execution/SubscriptionExceptionResolverAdapter.java
+++ b/spring-graphql/src/main/java/org/springframework/graphql/execution/SubscriptionExceptionResolverAdapter.java
@@ -44,9 +44,50 @@ import org.springframework.lang.Nullable;
*/
public abstract class SubscriptionExceptionResolverAdapter implements SubscriptionExceptionResolver {
+ private boolean threadLocalContextAware;
+
+
+ /**
+ * Subclasses can set this to indicate that ThreadLocal context from the
+ * transport handler (e.g. HTTP handler) should be restored when resolving
+ * exceptions.
+ *
Note: This property is applicable only if transports
+ * use ThreadLocal's' (e.g. Spring MVC) and if a {@link ThreadLocalAccessor}
+ * is registered to extract ThreadLocal values of interest. There is no
+ * impact from setting this property otherwise.
+ *
By default this is set to "false" in which case there is no attempt
+ * to propagate ThreadLocal context.
+ * @param threadLocalContextAware whether this resolver needs access to
+ * ThreadLocal context or not.
+ */
+ public void setThreadLocalContextAware(boolean threadLocalContextAware) {
+ this.threadLocalContextAware = threadLocalContextAware;
+ }
+
+ /**
+ * Whether ThreadLocal context needs to be restored for this resolver.
+ */
+ public boolean isThreadLocalContextAware() {
+ return this.threadLocalContextAware;
+ }
+
+
@Override
public final Mono> resolveException(Throwable exception) {
- return Mono.justOrEmpty(resolveToMultipleErrors(exception));
+ if (!this.threadLocalContextAware) {
+ return Mono.justOrEmpty(resolveToMultipleErrors(exception));
+ }
+ return Mono.deferContextual(contextView -> {
+ List errors;
+ try {
+ ReactorContextManager.restoreThreadLocalValues(contextView);
+ errors = resolveToMultipleErrors(exception);
+ }
+ finally {
+ ReactorContextManager.resetThreadLocalValues(contextView);
+ }
+ return Mono.justOrEmpty(errors);
+ });
}
/**
diff --git a/spring-graphql/src/test/java/org/springframework/graphql/ResponseHelper.java b/spring-graphql/src/test/java/org/springframework/graphql/ResponseHelper.java
index e2e2ab65..6aaf877a 100644
--- a/spring-graphql/src/test/java/org/springframework/graphql/ResponseHelper.java
+++ b/spring-graphql/src/test/java/org/springframework/graphql/ResponseHelper.java
@@ -150,6 +150,9 @@ public class ResponseHelper {
public static Flux forSubscription(ExecutionResult result) {
assertThat(result.getErrors()).as("Errors present in GraphQL response").isEmpty();
+ Object data = result.getData();
+ assertThat(data).as("Expected Publisher from subscription").isNotNull();
+ assertThat(data).as("Expected Publisher from subscription").isInstanceOf(Publisher.class);
Publisher publisher = result.getData();
return Flux.from(publisher).map(ResponseHelper::forResult);
}
diff --git a/spring-graphql/src/test/java/org/springframework/graphql/execution/CompositeSubscriptionExceptionResolverTests.java b/spring-graphql/src/test/java/org/springframework/graphql/execution/CompositeSubscriptionExceptionResolverTests.java
new file mode 100644
index 00000000..95f36f2d
--- /dev/null
+++ b/spring-graphql/src/test/java/org/springframework/graphql/execution/CompositeSubscriptionExceptionResolverTests.java
@@ -0,0 +1,133 @@
+/*
+ * Copyright 2002-2022 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.graphql.execution;
+
+import java.time.Duration;
+import java.util.List;
+
+import graphql.ExecutionInput;
+import graphql.GraphQL;
+import graphql.GraphQLError;
+import graphql.GraphqlErrorBuilder;
+import org.junit.jupiter.api.Test;
+import reactor.core.publisher.Flux;
+import reactor.core.publisher.Mono;
+import reactor.test.StepVerifier;
+import reactor.util.context.Context;
+import reactor.util.context.ContextView;
+
+import org.springframework.graphql.GraphQlSetup;
+import org.springframework.graphql.ResponseHelper;
+import org.springframework.graphql.TestThreadLocalAccessor;
+
+import static org.assertj.core.api.Assertions.assertThat;
+
+/**
+ * Tests for resolving exceptions via {@link SubscriptionExceptionResolver}.
+ * @author Rossen Stoyanchev
+ */
+public class CompositeSubscriptionExceptionResolverTests {
+
+ private static final Duration TIMEOUT = Duration.ofSeconds(5);
+
+
+ @Test
+ void subscriptionPublisherExceptionResolved() {
+ String query = "subscription { greetings }";
+ String schema = "type Subscription { greetings: String! } type Query { greeting: String! }";
+
+ GraphQL graphQL = GraphQlSetup.schemaContent(schema)
+ .subscriptionFetcher("greetings", env ->
+ Flux.create(emitter -> {
+ emitter.next("a");
+ emitter.error(new RuntimeException("Test Exception"));
+ emitter.next("b");
+ }))
+ .subscriptionExceptionResolvers(SubscriptionExceptionResolver.forSingleError(exception ->
+ GraphqlErrorBuilder.newError()
+ .message("Error: " + exception.getMessage())
+ .errorType(ErrorType.BAD_REQUEST)
+ .build()))
+ .toGraphQl();
+
+ ExecutionInput input = ExecutionInput.newExecutionInput(query).build();
+ Flux flux = Mono.fromFuture(graphQL.executeAsync(input))
+ .map(ResponseHelper::forSubscription)
+ .block(TIMEOUT);
+
+ StepVerifier.create(flux)
+ .consumeNextWith((helper) -> assertThat(helper.toEntity("greetings", String.class)).isEqualTo("a"))
+ .consumeErrorWith((ex) -> {
+ SubscriptionPublisherException theEx = (SubscriptionPublisherException) ex;
+ List errors = theEx.getErrors();
+ assertThat(errors).hasSize(1);
+ assertThat(errors.get(0).getMessage()).isEqualTo("Error: Test Exception");
+ assertThat(errors.get(0).getErrorType()).isEqualTo(ErrorType.BAD_REQUEST);
+ })
+ .verify(TIMEOUT);
+ }
+
+ @Test
+ void resolveExceptionWithThreadLocal() {
+ String query = "subscription { greetings }";
+ String schema = "type Subscription { greetings: String! } type Query { greeting: String! }";
+
+ ThreadLocal nameThreadLocal = new ThreadLocal<>();
+ nameThreadLocal.set("007");
+ TestThreadLocalAccessor accessor = new TestThreadLocalAccessor<>(nameThreadLocal);
+
+ try {
+ SubscriptionExceptionResolverAdapter resolver = SubscriptionExceptionResolver.forSingleError(exception ->
+ GraphqlErrorBuilder.newError()
+ .message("Error: " + exception.getMessage() + ", name=" + nameThreadLocal.get())
+ .errorType(ErrorType.BAD_REQUEST)
+ .build());
+ resolver.setThreadLocalContextAware(true);
+
+ GraphQL graphQL = GraphQlSetup.schemaContent(schema)
+ .subscriptionFetcher("greetings", env ->
+ Flux.create(emitter -> {
+ emitter.next("a");
+ emitter.error(new RuntimeException("Test Exception"));
+ }))
+ .subscriptionExceptionResolvers(resolver)
+ .toGraphQl();
+
+ ContextView view = ReactorContextManager.extractThreadLocalValues(accessor, Context.empty());
+ ExecutionInput input = ExecutionInput.newExecutionInput(query).build();
+ ReactorContextManager.setReactorContext(view, input.getGraphQLContext());
+
+ Flux flux = Mono.delay(Duration.ofMillis(10))
+ .flatMap((aLong) -> Mono.fromFuture(graphQL.executeAsync(input)).map(ResponseHelper::forSubscription))
+ .block(TIMEOUT);
+
+ StepVerifier.create(flux)
+ .consumeNextWith((helper) -> assertThat(helper.toEntity("greetings", String.class)).isEqualTo("a"))
+ .consumeErrorWith((ex) -> {
+ SubscriptionPublisherException theEx = (SubscriptionPublisherException) ex;
+ List errors = theEx.getErrors();
+ assertThat(errors).hasSize(1);
+ assertThat(errors.get(0).getMessage()).isEqualTo("Error: Test Exception, name=007");
+ assertThat(errors.get(0).getErrorType()).isEqualTo(ErrorType.BAD_REQUEST);
+ })
+ .verify(TIMEOUT);
+ }
+ finally {
+ nameThreadLocal.remove();
+ }
+ }
+
+}
diff --git a/spring-graphql/src/test/java/org/springframework/graphql/server/webflux/GraphQlWebSocketHandlerTests.java b/spring-graphql/src/test/java/org/springframework/graphql/server/webflux/GraphQlWebSocketHandlerTests.java
index 6bba8688..051fa27a 100644
--- a/spring-graphql/src/test/java/org/springframework/graphql/server/webflux/GraphQlWebSocketHandlerTests.java
+++ b/spring-graphql/src/test/java/org/springframework/graphql/server/webflux/GraphQlWebSocketHandlerTests.java
@@ -25,7 +25,6 @@ import java.util.Map;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.BiConsumer;
-import graphql.GraphqlErrorBuilder;
import org.assertj.core.api.InstanceOfAssertFactories;
import org.junit.jupiter.api.Test;
import reactor.core.publisher.Flux;
@@ -39,7 +38,6 @@ import org.springframework.core.io.buffer.DataBufferUtils;
import org.springframework.core.io.buffer.DefaultDataBufferFactory;
import org.springframework.graphql.GraphQlSetup;
import org.springframework.graphql.execution.ErrorType;
-import org.springframework.graphql.execution.SubscriptionExceptionResolver;
import org.springframework.graphql.server.ConsumeOneAndNeverCompleteInterceptor;
import org.springframework.graphql.server.WebGraphQlHandler;
import org.springframework.graphql.server.WebGraphQlInterceptor;
@@ -313,23 +311,20 @@ public class GraphQlWebSocketHandlerTests extends WebSocketHandlerTestSupport {
@Test
void subscriptionErrorPayloadIsArray() {
- final String GREETING_QUERY = "{" +
+ String query = "{" +
"\"id\":\"" + SUBSCRIPTION_ID + "\"," +
"\"type\":\"subscribe\"," +
- "\"payload\":{\"query\": \"" +
- " subscription TestTypenameSubscription {" +
- " greeting" +
- " }\"}" +
+ "\"payload\":{\"query\": \"subscription { greetings }\"}" +
"}";
- String schema = "type Subscription { greeting: String! } type Query { greetingUnused: String! }";
+ String schema = "type Subscription { greetings: String! } type Query { greeting: String! }";
TestWebSocketSession session = new TestWebSocketSession(Flux.just(
toWebSocketMessage("{\"type\":\"connection_init\"}"),
- toWebSocketMessage(GREETING_QUERY)));
+ toWebSocketMessage(query)));
WebGraphQlHandler webHandler = GraphQlSetup.schemaContent(schema)
- .subscriptionFetcher("greeting", env -> Flux.just("a", null, "b"))
+ .subscriptionFetcher("greetings", env -> Flux.just("a", null, "b"))
.toWebGraphQlHandler();
new GraphQlWebSocketHandler(webHandler, ServerCodecConfigurer.create(), TIMEOUT)
@@ -342,7 +337,7 @@ public class GraphQlWebSocketHandlerTests extends WebSocketHandlerTestSupport {
assertThat(actual.getId()).isEqualTo(SUBSCRIPTION_ID);
assertThat(actual.resolvedType()).isEqualTo(GraphQlWebSocketMessageType.NEXT);
assertThat(actual.